All files / src/modules/auth / origin-validation.middleware.ts

68.96% Statements 20/29
66.66% Branches 10/15
100% Functions 4/4
65.38% Lines 17/26

Press n or j to go to the next uncovered block, b, p or k for the previous block.

Line numbers 1–62; execution counts for the 17 covered lines: 12x 12x 12x 12x 12x 24x 24x 24x 24x 24x 24x 24x 24x 162x 162x 162x 162x
import { Injectable, NestMiddleware, ForbiddenException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { Request, Response, NextFunction } from 'express';
import { Counter, register } from 'prom-client';
import { LoggerService } from '@app/common/services/logger.service';
 
/**
 * Rejects requests whose `Origin` header does not match one of the configured
 * frontend origins.
 *
 * Allowed origins are read from the `FRONTEND_URL` config value, which may be a
 * comma-separated list. Configured values and the incoming header are both passed
 * through the same normaliser before comparison.
 *
 * Every decision increments the `auth_origin_validation_total` counter with a
 * `status` label of `pass`, `fail` or `no_header`.
 *
 * Requests that carry no `Origin` header are passed through (only counted); the
 * check is enforced only when the header is present.
 */
@Injectable()
export class OriginValidationMiddleware implements NestMiddleware {
  /**
   * Normalised allow-list. When `FRONTEND_URL` is unset or empty this is empty,
   * so every request that carries an `Origin` header is rejected (fail closed).
   */
  private readonly allowed: ReadonlySet<string>;
  private readonly counter: Counter<string>;

  constructor(private readonly config: ConfigService, private readonly logger: LoggerService) {
    const raw = this.config.get<string>('FRONTEND_URL') || '';
    // Support a comma-separated list for multiple frontends.
    this.allowed = new Set(
      raw
        .split(',')
        .map(entry => entry.trim())
        .filter(Boolean)
        .map(entry => OriginValidationMiddleware.normalizeOrigin(entry)),
    );

    // Reuse an already-registered counter (e.g. when the middleware is instantiated
    // more than once), since prom-client refuses to register the same name twice.
    const existing = register.getSingleMetric('auth_origin_validation_total') as Counter<string> | undefined;
    this.counter =
      existing ??
      new Counter({
        name: 'auth_origin_validation_total',
        help: 'Auth origin validation outcomes',
        labelNames: ['status'] as const,
        registers: [register],
      });
  }

  /**
   * Validates the request's `Origin` header against the allow-list.
   *
   * @param req - incoming Express request; only `headers.origin`, `ip` and `originalUrl` are read
   * @param _res - unused
   * @param next - called when the request is allowed through
   * @throws ForbiddenException when an `Origin` header is present but not allow-listed
   */
  use(req: Request, _res: Response, next: NextFunction): void {
    const originHeader = req.headers.origin;

    if (!originHeader) {
      // No Origin header: may be a legitimate same-origin / non-browser request or a
      // CSRF attempt. It is let through but counted separately so it can be monitored.
      this.counter.inc({ status: 'no_header' });
      next();
      return;
    }

    const origin = OriginValidationMiddleware.normalizeOrigin(originHeader);
    if (!this.allowed.has(origin)) {
      this.counter.inc({ status: 'fail' });
      this.logger.warn('Rejected request due to invalid origin', {
        origin: originHeader,
        ip: req.ip,
        path: req.originalUrl,
      });
      throw new ForbiddenException('Invalid request origin');
    }

    this.counter.inc({ status: 'pass' });
    next();
  }

  /**
   * Reduces a URL or origin string to a comparable, lower-cased form.
   *
   * Absolute URLs with a tuple origin (http, https, ws, ...) are reduced to their
   * `URL.origin` (scheme + host + non-default port, no path or trailing slash).
   * Everything else is compared as the trimmed, lower-cased input with trailing
   * slashes removed. That covers values that are not absolute URLs, and URLs with
   * an opaque origin such as `capacitor://localhost` or `file://...`, whose
   * `URL.origin` is the literal string `'null'`. Without this fallback every
   * opaque-origin URL would collapse to `'null'` and match any `Origin: null`
   * request.
   */
  private static normalizeOrigin(value: string): string {
    try {
      const { origin } = new URL(value);
      if (origin !== 'null') {
        return origin.toLowerCase();
      }
    } catch {
      // Not an absolute URL; fall through to the raw comparison below.
    }
    return value.trim().toLowerCase().replace(/\/+$/, '');
  }
}