
NestJS Source Code Analysis: The Middleware Mechanism

A deep dive into MiddlewareModule and MiddlewareBuilder that shows how middleware is registered and executed.

Middleware Basics

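NestJS middleware is, by default, equivalent to Express middleware. A class middleware is marked @Injectable(), implements the NestMiddleware interface, and receives the familiar (req, res, next) arguments: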

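import { Injectable, NestMiddleware } from '@nestjs/common';
import { Request, Response, NextFunction } from 'express';

@Injectable()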
export class LoggerMiddleware implements NestMiddleware {
  use(req: Request, res: Response, next: NextFunction) {
    console.log('Request...');
    next();
  }
}

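Registering Middleware

Middleware is bound in a module that implements NestModule, inside its configure() method. The chain below applies two middleware classes to all of CatsController's routes except GET /cats: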

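import { MiddlewareConsumer, Module, NestModule, RequestMethod } from '@nestjs/common';
// LoggerMiddleware, AuthMiddleware, CatsModule and CatsController come from your own files

@Module({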
  imports: [CatsModule],
})
export class AppModule implements NestModule {
  configure(consumer: MiddlewareConsumer) {
    consumer
      .apply(LoggerMiddleware, AuthMiddleware)
      .exclude({ path: 'cats', method: RequestMethod.GET })
      .forRoutes(CatsController);
  }
}

MiddlewareBuilder

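MiddlewareBuilder implements MiddlewareConsumer, so it is the consumer object passed to configure(). It turns each apply() → exclude() → forRoutes() chain into a configuration object (simplified excerpt):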

// packages/core/middleware/builder.ts
export class MiddlewareBuilder implements MiddlewareConsumer {
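  private readonly middlewareCollection = new Set<MiddlewareConfiguration>();

  constructor(
    private readonly routesMapper: RoutesMapper,
    private readonly httpAdapter: HttpServer,
    private readonly routeInfoPathExtractor: RouteInfoPathExtractor,
  ) {}
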
  public apply(
    ...middleware: Array<Type<any> | Function | Array<Type<any> | Function>>
  ): MiddlewareConfigProxy {
    return new MiddlewareBuilder.ConfigProxy(
      this,
      middleware.flat(),
      this.routeInfoPathExtractor,
    );
  }

  public build(): MiddlewareConfiguration[] {
    return [...this.middlewareCollection];
  }

  // Inner proxy class returned by apply(); it collects the exclude() and forRoutes() settings
  private static readonly ConfigProxy = class implements MiddlewareConfigProxy {
    private excludedRoutes: RouteInfo[] = [];

    constructor(
      private readonly builder: MiddlewareBuilder,
      private readonly middleware: Array<Type<any> | Function>,
      private routeInfoPathExtractor: RouteInfoPathExtractor,
    ) {}

    // Exclude routes: each route is expanded into concrete paths before being stored
    public exclude(...routes: Array<string | RouteInfo>): MiddlewareConfigProxy {
      this.excludedRoutes = [
        ...this.excludedRoutes,
        ...this.getRoutesFlatList(routes).reduce((excludedRoutes, route) => {
          for (const routePath of this.routeInfoPathExtractor.extractPathFrom(route)) {
            excludedRoutes.push({ ...route, path: routePath });
          }
          return excludedRoutes;
        }, [] as RouteInfo[]),
      ];
      return this;
    }

    // Bind the middleware to routes and close the chain
    public forRoutes(
      ...routes: Array<string | Type<any> | RouteInfo>
    ): MiddlewareConsumer {
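      // getRoutesFlatList (not shown) normalizes strings, controllers and RouteInfo
      // objects into a flat RouteInfo[] via routesMapper
      const forRoutes = this.getRoutesFlatList(routes);

      // Store the configuration in the builder's collection. (Simplified: the actual
      // source wraps this.middleware with filterMiddleware() so the exclusions travel
      // with the middleware instead of being stored as excludedRoutes.)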
      this.builder.middlewareCollection.add({
        middleware: this.middleware,
        forRoutes,
        excludedRoutes: this.excludedRoutes,
      });

      return this.builder;
    }
  };
}
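The fluent chain is easier to follow without the framework around it. The following is a minimal, dependency-free sketch of the same apply() → exclude() → forRoutes() → build() flow. It assumes plain string routes instead of RouteInfo objects and leaves out routesMapper and path extraction. MiniMiddlewareBuilder, MiniConfigProxy and MiddlewareConfig are hypothetical names, not NestJS APIs.

```typescript
// Hypothetical, simplified model of MiddlewareBuilder; not the NestJS API.
type MiddlewareFn = (req: unknown, res: unknown, next: () => void) => void;

interface MiddlewareConfig {
  middleware: MiddlewareFn[];
  forRoutes: string[];
  excludedRoutes: string[];
}

class MiniMiddlewareBuilder {
  // A Set keeps configurations in insertion order, like middlewareCollection.
  private readonly collection = new Set<MiddlewareConfig>();

  apply(...middleware: Array<MiddlewareFn | MiddlewareFn[]>): MiniConfigProxy {
    // Flatten one level so apply(a, [b]) behaves like apply(a, b).
    const flat = ([] as MiddlewareFn[]).concat(...middleware);
    return new MiniConfigProxy(this, flat);
  }

  // Called by the proxy's forRoutes() to store a finished configuration.
  add(config: MiddlewareConfig): void {
    this.collection.add(config);
  }

  build(): MiddlewareConfig[] {
    return Array.from(this.collection);
  }
}

class MiniConfigProxy {
  private excludedRoutes: string[] = [];

  constructor(
    private readonly builder: MiniMiddlewareBuilder,
    private readonly middleware: MiddlewareFn[],
  ) {}

  exclude(...routes: string[]): this {
    // Several exclude() calls accumulate.
    this.excludedRoutes = this.excludedRoutes.concat(routes);
    return this;
  }

  forRoutes(...routes: string[]): MiniMiddlewareBuilder {
    this.builder.add({
      middleware: this.middleware,
      forRoutes: routes,
      excludedRoutes: this.excludedRoutes,
    });
    // Returning the builder, not the proxy, lets a new apply() chain start.
    return this.builder;
  }
}

// Usage: two chains produce two configuration entries.
const logger: MiddlewareFn = (_req, _res, next) => next();
const auth: MiddlewareFn = (_req, _res, next) => next();

const configs = new MiniMiddlewareBuilder()
  .apply(logger, [auth])
  .exclude('cats/health')
  .forRoutes('cats')
  .apply(logger)
  .forRoutes('dogs')
  .build();
```

Returning the builder from forRoutes() is what makes consumer.apply(...).forRoutes(...).apply(...) chains possible, just as the excerpt above returns this.builder.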

MiddlewareModule

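MiddlewareModule drives the whole process. It calls configure() on every module that defines it, collects the resulting configurations, resolves the middleware instances, and binds them to routes (simplified excerpt):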

// packages/core/middleware/middleware-module.ts
export class MiddlewareModule {
  public async register(
    middlewareContainer: MiddlewareContainer,
    container: NestContainer,
    config: ApplicationConfig,
    injector: Injector,
    httpAdapter: HttpServer,
    graphInspector: GraphInspector,
  ) {
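    // (Omitted: the actual method first creates this.routesMapper, this.resolver,
    // this.routeInfoPathExtractor and a few other helpers.)
    const modules = container.getModules();

    // Iterate over all modules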
    for (const [token, moduleRef] of modules) {
      const instance = moduleRef.instance;

      // Only modules that implement configure() contribute middleware
      if (instance && isFunction(instance.configure)) {
        const middlewareBuilder = new MiddlewareBuilder(
          this.routesMapper,
          httpAdapter,
          this.routeInfoPathExtractor,
        );

        await instance.configure(middlewareBuilder);

        // Collect this module's middleware configurations
        const middlewareConfigs = middlewareBuilder.build();
        middlewareContainer.insertConfig(middlewareConfigs, token);
      }
    }

    // Resolve (instantiate) the middleware classes
    await this.resolver.resolveInstances(middlewareContainer, container);

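    // Bind to the HTTP adapter. (Simplified: in the actual source, registerMiddleware()
    // is invoked later by NestApplication, before the route handlers are registered.)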
    await this.registerMiddleware(middlewareContainer, httpAdapter);
  }

  // Bind every module's middleware configurations to their routes
  private async registerMiddleware(
    middlewareContainer: MiddlewareContainer,
    httpAdapter: HttpServer,
  ) {
    const configs = middlewareContainer.getConfigurations();

    for (const [moduleKey, moduleConfigs] of configs) {
      for (const config of moduleConfigs) {
        await this.registerRouteMiddleware(
          middlewareContainer,
          config,
          moduleKey,
          httpAdapter,
        );
      }
    }
  }

  private async registerRouteMiddleware(
    middlewareContainer: MiddlewareContainer,
    config: MiddlewareConfiguration,
    moduleKey: string,
    httpAdapter: HttpServer,
  ) {
    const { forRoutes, middleware, excludedRoutes } = config;

    for (const route of forRoutes) {
      // Get the middleware instances
      const middlewareInstances = await this.resolveMiddleware(
        middlewareContainer,
        middleware,
        moduleKey,
      );

      // Bind each instance to the current route
      for (const middlewareInstance of middlewareInstances) {
        this.bindHandler(
          middlewareInstance,
          route,
          httpAdapter,
          excludedRoutes,
        );
      }
    }
  }
}
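Without dependency injection and routing, the loop in register() comes down to one step. For every module whose instance has a configure() method, pass in a fresh consumer and store whatever it collects under that module's key. The sketch below models only that loop. ModuleLike, Consumer and collectConfigurations are hypothetical, and a configuration is reduced to the list of routes passed to forRoutes().

```typescript
// Hypothetical model of the configure() collection loop in register().
interface Consumer {
  forRoutes(...routes: string[]): Consumer;
}

interface ModuleLike {
  // Only modules that implement NestModule have configure().
  configure?: (consumer: Consumer) => void | Promise<void>;
}

async function collectConfigurations(
  modules: Map<string, ModuleLike>,
): Promise<Map<string, string[][]>> {
  const result = new Map<string, string[][]>();

  for (const [token, instance] of Array.from(modules.entries())) {
    if (typeof instance.configure !== 'function') {
      continue; // no middleware configured in this module
    }

    // A fresh consumer per module, recording every forRoutes() call.
    const collected: string[][] = [];
    const consumer: Consumer = {
      forRoutes(...routes: string[]) {
        collected.push(routes);
        return consumer;
      },
    };

    await instance.configure(consumer); // configure() may be async
    result.set(token, collected); // keyed by module, like insertConfig()
  }

  return result;
}

// Usage: only AppModule defines configure().
const modules = new Map<string, ModuleLike>([
  ['AppModule', { configure: (c: Consumer) => { c.forRoutes('cats').forRoutes('dogs'); } }],
  ['CatsModule', {}],
]);
```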

MiddlewareResolver

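MiddlewareResolver instantiates class middleware through the injector, which is why class middleware can use dependency injection (simplified excerpt):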

// packages/core/middleware/resolver.ts
export class MiddlewareResolver {
  public async resolveInstances(
    middlewareContainer: MiddlewareContainer,
    container: NestContainer,
  ) {
    const configs = middlewareContainer.getConfigurations();

    for (const [moduleKey, moduleConfigs] of configs) {
      const moduleRef = container.getModules().get(moduleKey);

      for (const config of moduleConfigs) {
        await this.resolveMiddlewareInstances(
          config.middleware,
          moduleRef,
          middlewareContainer,
        );
      }
    }
  }

  private async resolveMiddlewareInstances(
    middleware: Array<Type<any> | Function>,
    moduleRef: Module,
    middlewareContainer: MiddlewareContainer,
  ) {
    for (const metatype of middleware) {
      // Functional middleware is used as-is
      if (!isClass(metatype)) {
        continue;
      }

      // Class middleware is instantiated by the injector
      const wrapper = moduleRef.middlewares.get(metatype);
      if (wrapper) {
        await this.injector.loadInstance(
          wrapper,
          moduleRef.middlewares,
          moduleRef,
        );
      }
    }
  }
}
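The resolver's central decision is class versus function: only classes go through the injector. This sketch imitates that with a hypothetical MiniInjector. It takes each class's dependency tokens from an explicit map and caches one instance per class (singleton scope). isMiddlewareClass here is a simple heuristic that checks for a use() method on the prototype. Real NestJS reads constructor dependencies from TypeScript metadata and supports other injection scopes, none of which is modeled.

```typescript
// Hypothetical sketch: instantiate class middleware once, leave functions as-is.
type Handler = (req: unknown, res: unknown, next: () => void) => void;

interface MiddlewareInstance {
  use(req: unknown, res: unknown, next: () => void): void;
}

type MiddlewareClass = new (...deps: any[]) => MiddlewareInstance;

// Heuristic: a middleware class has use() on its prototype; arrow/plain functions do not.
function isMiddlewareClass(m: MiddlewareClass | Handler): m is MiddlewareClass {
  return typeof (m as any).prototype?.use === 'function';
}

class MiniInjector {
  private readonly instances = new Map<MiddlewareClass, MiddlewareInstance>();

  constructor(
    private readonly providers: Map<string, unknown>, // token -> existing provider
    private readonly deps: Map<MiddlewareClass, string[]>, // constructor tokens, in order
  ) {}

  resolve(m: MiddlewareClass | Handler): MiddlewareInstance | Handler {
    if (!isMiddlewareClass(m)) {
      return m; // functional middleware: nothing to instantiate
    }
    const cached = this.instances.get(m);
    if (cached) {
      return cached; // singleton scope: reuse the first instance
    }
    const args = (this.deps.get(m) || []).map(token => this.providers.get(token));
    const instance = new m(...args);
    this.instances.set(m, instance);
    return instance;
  }
}

// Usage: a class middleware with one injected dependency.
class LoggerService {
  readonly lines: string[] = [];
  log(message: string): void {
    this.lines.push(message);
  }
}

class LoggerMiddleware implements MiddlewareInstance {
  constructor(private readonly logger: LoggerService) {}
  use(_req: unknown, _res: unknown, next: () => void): void {
    this.logger.log('Request...');
    next();
  }
}

const service = new LoggerService();
const injector = new MiniInjector(
  new Map<string, unknown>([['LoggerService', service]]),
  new Map<MiddlewareClass, string[]>([[LoggerMiddleware, ['LoggerService']]]),
);
```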

Middleware Execution

Binding to the HTTP Adapter

// packages/core/middleware/middleware-module.ts
private bindHandler(
  middlewareInstance: NestMiddleware | Function,
  routeInfo: RouteInfo,
  httpAdapter: HttpServer,
  excludedRoutes: RouteInfo[],
) {
  const { path, method } = routeInfo;

  // Turn the middleware into a plain handler function
  const handler = isFunction(middlewareInstance)
    ? middlewareInstance
    : middlewareInstance.use.bind(middlewareInstance);

  // Wrap it with the exclusion check
  const wrappedHandler = this.createExcludedMiddleware(
    handler,
    excludedRoutes,
  );

  // Register it on the adapter for the route's path and HTTP method
  const routerMethod = this.routerMethodFactory.get(httpAdapter, method);
  routerMethod.call(httpAdapter, path, wrappedHandler);
}

// Wrap a handler so that requests matching an excluded route skip it
private createExcludedMiddleware(
  handler: Function,
  excludedRoutes: RouteInfo[],
) {
  if (!excludedRoutes.length) {
    return handler;
  }

  return (req: any, res: any, next: Function) => {
    const isExcluded = excludedRoutes.some(route =>
      this.isRouteExcluded(req, route),
    );

    if (isExcluded) {
      return next();
    }
    return handler(req, res, next);
  };
}
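bindHandler() and createExcludedMiddleware() can be exercised without an HTTP server. The sketch below normalizes function middleware and use()-style middleware into one handler type. It then wraps each handler with an exclusion check and runs the handlers in registration order, the way Express walks its middleware stack. Req, runChain and the exact-match isRouteExcluded are hypothetical simplifications; the real check matches path patterns.

```typescript
// Hypothetical, server-free sketch of bindHandler() + createExcludedMiddleware().
interface Req { path: string; method: string }
type Next = () => void;
type Handler = (req: Req, res: unknown, next: Next) => void;
interface ClassMiddleware { use(req: Req, res: unknown, next: Next): void }
interface RouteInfo { path: string; method: string } // 'ALL' matches any method

// Like bindHandler(): functions are used directly, instances through use().
function toHandler(m: Handler | ClassMiddleware): Handler {
  return typeof m === 'function' ? m : m.use.bind(m);
}

// Simplification: exact path comparison, no patterns or parameters.
function isRouteExcluded(req: Req, route: RouteInfo): boolean {
  const methodMatches = route.method === 'ALL' || route.method === req.method;
  return methodMatches && route.path === req.path;
}

function createExcludedMiddleware(handler: Handler, excluded: RouteInfo[]): Handler {
  if (excluded.length === 0) {
    return handler; // nothing excluded: no wrapper needed
  }
  return (req, res, next) =>
    excluded.some(route => isRouteExcluded(req, route)) ? next() : handler(req, res, next);
}

// Runs handlers in registration order; each must call next() to continue.
function runChain(handlers: Handler[], req: Req, done: Next): void {
  const step = (i: number): void => {
    if (i === handlers.length) {
      done();
      return;
    }
    handlers[i](req, {}, () => step(i + 1));
  };
  step(0);
}

// Usage: one functional and one use()-style middleware, GET /cats excluded.
const log: string[] = [];
const logger: Handler = (req, _res, next) => {
  log.push(`logger ${req.method}`);
  next();
};
const auth: ClassMiddleware = {
  use(req, _res, next) {
    log.push(`auth ${req.method}`);
    next();
  },
};

const excluded: RouteInfo[] = [{ path: '/cats', method: 'GET' }];
const handlers = [toHandler(logger), toHandler(auth)].map(h =>
  createExcludedMiddleware(h, excluded),
);

runChain(handlers, { path: '/cats', method: 'POST' }, () => log.push('handler POST'));
runChain(handlers, { path: '/cats', method: 'GET' }, () => log.push('handler GET'));
```

The exclusion wrapper calls next() instead of stopping, so an excluded middleware is skipped while the rest of the chain, and finally the route handler, still runs.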

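Functional Middleware

Besides class middleware, NestJS also supports functional middleware: a plain function with the (req, res, next) signature. It is the simpler choice when the middleware needs no dependencies: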

import { Request, Response, NextFunction } from 'express';

// Functional middleware
export function logger(req: Request, res: Response, next: NextFunction) {
  console.log('Request...');
  next();
}

// Usage (inside configure())
consumer.apply(logger).forRoutes('*');
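One way for a framework to treat both kinds of middleware uniformly is to wrap a plain function in a class with a use() method. As far as we can tell, NestJS's internal middleware utilities take a broadly similar approach. The wrapFunctional helper below is still a hypothetical sketch, not that code.

```typescript
// Hypothetical sketch: wrap functional middleware in a class so the rest of the
// pipeline can treat every middleware as "a class with use()".
type Next = () => void;
type FnMiddleware = (req: unknown, res: unknown, next: Next) => void;

interface MiddlewareLike {
  use(req: unknown, res: unknown, next: Next): void;
}

type MiddlewareType = new () => MiddlewareLike;

function wrapFunctional(fn: FnMiddleware): MiddlewareType {
  return class implements MiddlewareLike {
    use(req: unknown, res: unknown, next: Next): void {
      fn(req, res, next); // delegate to the original function
    }
  };
}

// Usage: a functional logger, now instantiable like a class middleware.
const calls: string[] = [];
function logger(_req: unknown, _res: unknown, next: Next): void {
  calls.push('Request...');
  next();
}

const LoggerHost = wrapFunctional(logger);
new LoggerHost().use({}, {}, () => calls.push('next'));
```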

Global Middleware

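// main.ts (helmet and compression are third-party npm packages)
import { NestFactory } from '@nestjs/core';
import helmet from 'helmet';
import * as compression from 'compression';
// ...inside async function bootstrap():
const app = await NestFactory.create(AppModule);
app.use(helmet());
app.use(compression());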

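Global middleware skips MiddlewareBuilder entirely. app.use() forwards its arguments straight to the HTTP adapter, so the middleware applies to every route and cannot access the DI container: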

// packages/core/nest-application.ts
public use(...args: any[]): this {
  this.httpAdapter.use(...args);
  return this;
}

Middleware Execution Order

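1. Global middleware (app.use)
   ↓
2. Module middleware (sorted by module distance: global modules first, then smaller distances first)
   ↓
3. Within a module: in the order of the forRoutes() calls, and of the middleware passed to apply()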
   ↓
4. Guards
   ↓
5. Interceptors
   ↓
6. Pipes
   ↓
7. Route Handler
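To make the order concrete, here is a toy request pipeline. Everything in it (runPipeline, the Pipeline shape, the string trace) is a hypothetical model of the lifecycle, not NestJS code. It shows three things: middleware runs first and in registration order, a guard returning false stops the request before interceptors, pipes and the handler, and an interceptor wraps the pipe and handler call.

```typescript
// Hypothetical toy model of the request lifecycle; not NestJS code.
type Next = () => void;
type Middleware = (trace: string[], next: Next) => void;
type Guard = (trace: string[]) => boolean;

interface Pipeline {
  globalMiddleware: Middleware[]; // registered with app.use()
  moduleMiddleware: Middleware[]; // from configure(), already sorted by distance
  guards: Guard[];
  interceptor: (trace: string[], handle: () => string) => string;
  pipe: (raw: string) => string;
  handler: (arg: string) => string;
}

function runPipeline(p: Pipeline, rawArg: string): string[] {
  const trace: string[] = [];
  const middleware = [...p.globalMiddleware, ...p.moduleMiddleware];

  const afterMiddleware = (): void => {
    // Every guard must return true; Nest would answer 403 (ForbiddenException) otherwise.
    if (!p.guards.every(guard => guard(trace))) {
      trace.push('403');
      return;
    }
    // The interceptor wraps pipe + handler, so it sees both "before" and "after".
    const result = p.interceptor(trace, () => {
      const arg = p.pipe(rawArg);
      trace.push('pipe');
      trace.push('handler');
      return p.handler(arg);
    });
    trace.push(`result:${result}`);
  };

  // Middleware runs first, in order; each one calls next() to hand over.
  const step = (i: number): void =>
    i < middleware.length ? middleware[i](trace, () => step(i + 1)) : afterMiddleware();
  step(0);
  return trace;
}

// Usage
const mw = (name: string): Middleware => (trace, next) => {
  trace.push(name);
  next();
};

const pipeline: Pipeline = {
  globalMiddleware: [mw('global')],
  moduleMiddleware: [mw('module')],
  guards: [trace => { trace.push('guard'); return true; }],
  interceptor: (trace, handle) => {
    trace.push('interceptor:before');
    const value = handle();
    trace.push('interceptor:after');
    return value;
  },
  pipe: raw => raw.trim(),
  handler: arg => `cat ${arg}`,
};

const allowed = runPipeline(pipeline, ' 42 ');
const denied = runPipeline({ ...pipeline, guards: [() => false] }, ' 42 ');
```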

In the source, the sorting by module distance happens inside registerMiddleware():

// packages/core/middleware/middleware-module.ts
const entriesSortedByDistance = [...configs.entries()].sort(
  ([moduleA], [moduleB]) => {
    const moduleARef = this.container.getModuleByKey(moduleA)!;
    const moduleBRef = this.container.getModuleByKey(moduleB)!;
    const isModuleAGlobal = moduleARef.distance === Number.MAX_VALUE;
    const isModuleBGlobal = moduleBRef.distance === Number.MAX_VALUE;
    
    if (isModuleAGlobal && isModuleBGlobal) return 0;
    if (isModuleAGlobal) return -1;  // global modules come first
    if (isModuleBGlobal) return 1;
    
    return moduleARef.distance - moduleBRef.distance;
  },
);
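The comparator can be run on its own. In this sketch, getModule stands in for container.getModuleByKey() and ModuleRefLike is a hypothetical stand-in for the module reference. The distance values are made up for illustration, and Number.MAX_VALUE marks a global module, as in the excerpt above.

```typescript
// Hypothetical stand-in for a module reference: only the field the sort reads.
interface ModuleRefLike {
  distance: number;
}

function sortByDistance<T>(
  entries: Array<[string, T]>,
  getModule: (key: string) => ModuleRefLike, // plays the role of getModuleByKey()
): Array<[string, T]> {
  return [...entries].sort(([keyA], [keyB]) => {
    const a = getModule(keyA);
    const b = getModule(keyB);
    const aGlobal = a.distance === Number.MAX_VALUE;
    const bGlobal = b.distance === Number.MAX_VALUE;
    if (aGlobal && bGlobal) return 0;
    if (aGlobal) return -1; // global modules first
    if (bGlobal) return 1;
    return a.distance - b.distance; // then ascending distance
  });
}

// Usage with made-up distances; Number.MAX_VALUE marks a global module.
const modules = new Map<string, ModuleRefLike>([
  ['CatsModule', { distance: 2 }],
  ['ConfigModule', { distance: Number.MAX_VALUE }],
  ['AppModule', { distance: 1 }],
]);

const sorted = sortByDistance(
  Array.from(modules.keys()).map((key): [string, string[]] => [key, [`${key}Middleware`]]),
  key => modules.get(key)!,
);
```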

Route Wildcards

consumer
  .apply(LoggerMiddleware)
  .forRoutes({ path: 'ab*cd', method: RequestMethod.ALL });

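Path matching relies on path-to-regexp. RouteInfoPathExtractor first turns each RouteInfo into the concrete paths to register (simplified excerpt):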

// packages/core/middleware/route-info-path-extractor.ts
public extractPathFrom(route: RouteInfo): string[] {
  const { path } = route;

  // Handle wildcards (the actual extractor also applies the global prefix and API versioning)
  if (path.includes('*')) {
    return this.extractPathsFromWildcard(path);
  }

  return [path];
}
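To show what a pattern like 'ab*cd' means, here is a tiny matcher that treats * as "any run of characters" and anchors the whole pattern. It is a deliberately simplified, hypothetical stand-in for path-to-regexp. It ignores parameters such as :id and does not follow the stricter named-wildcard syntax of newer path-to-regexp releases.

```typescript
// Simplified glob-style matcher: '*' matches any run of characters, including none.
// This is NOT path-to-regexp; parameters like :id and other syntax are ignored.
function wildcardToRegExp(pattern: string): RegExp {
  const source = pattern
    .split('*')
    .map(part => part.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')) // escape regex metacharacters
    .join('.*');
  return new RegExp(`^${source}$`);
}

// Compare without leading/trailing slashes so 'ab*cd' matches '/abcd'.
function matchesWildcard(pattern: string, path: string): boolean {
  const trimSlashes = (s: string) => s.replace(/^\/+|\/+$/g, '');
  return wildcardToRegExp(trimSlashes(pattern)).test(trimSlashes(path));
}

// Usage
const samples = ['/abcd', '/ab_cd', '/abecd', '/abce', '/xabcd'];
const matches = samples.filter(path => matchesWildcard('ab*cd', path));
```

With this matcher, 'ab*cd' accepts /abcd, /ab_cd and /abecd but rejects /abce and /xabcd.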

Middleware vs. Guards

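| Aspect | Middleware | Guard |
| --- | --- | --- |
| When it runs | Before guards; does not know which handler will run | After all middleware, before interceptors and pipes |
| Context | req, res, next | ExecutionContext (includes the target class and handler) |
| Return value | void (calls next() to continue) | boolean (or a Promise/Observable of boolean) |
| Dependency injection | Supported (class middleware) | Supported |
| Typical use | Generic request processing (logging, headers, parsing) | Authorization and access control |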
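The ExecutionContext row is the decisive difference. A guard knows which handler will run and can read metadata attached to it, while middleware sees only req, res and next. The framework-free model below illustrates this; rolesMetadata, ContextLike and RolesGuardLike are hypothetical names. Real guards implement CanActivate and usually read handler metadata through Reflector.

```typescript
// Hypothetical model: metadata attached to route handlers (think of a @Roles() decorator).
const rolesMetadata = new WeakMap<Function, string[]>();

interface RequestLike {
  user?: { roles: string[] };
}

// Stand-in for ExecutionContext: exposes the request AND the handler that will run.
interface ContextLike {
  getRequest(): RequestLike;
  getHandler(): Function;
}

// Guard-style check: decides per handler and returns a boolean.
class RolesGuardLike {
  canActivate(context: ContextLike): boolean {
    const required = rolesMetadata.get(context.getHandler()) || [];
    if (required.length === 0) {
      return true; // handler declares no roles
    }
    const user = context.getRequest().user;
    const roles = user ? user.roles : [];
    return required.some(role => roles.indexOf(role) !== -1);
  }
}

// Usage: two handlers with different requirements.
function findAll(): string {
  return 'all cats';
}
function remove(): string {
  return 'deleted';
}
rolesMetadata.set(remove, ['admin']);

const contextFor = (handler: Function, req: RequestLike): ContextLike => ({
  getRequest: () => req,
  getHandler: () => handler,
});

const guard = new RolesGuardLike();
const visitor: RequestLike = { user: { roles: ['user'] } };
const admin: RequestLike = { user: { roles: ['admin'] } };
```

A middleware placed in front of both routes could not make this distinction, because nothing in (req, res, next) identifies the handler that will eventually run.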

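Summary

The core pieces of the NestJS middleware mechanism:

  1. MiddlewareBuilder: builds middleware configurations
  2. MiddlewareModule: collects configurations and binds middleware to routes
  3. MiddlewareResolver: instantiates class middleware
  4. Execution order: global → module → route
  5. Route matching: supports wildcards and exclusion rules
  6. Compatibility: Express-style (req, res, next) middleware works as-is with the default Express adapter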

In the next article, we will analyze how AOP is implemented (Guard, Pipe, Interceptor).


📦 Source location: packages/core/middleware/
