一文了解Nest中controller的注册

241 阅读6分钟

1.注册controller的本质

Nest 的装饰器（即常说的"注解"）是基于 Reflect Metadata 实现的。以 `@Module()` 为例，它只是把传入对象的每个属性，以属性名为 key 写入模块类的元数据：

// packages/common/decorators/modules/module.decorator.ts
export function Module(metadata: ModuleMetadata): ClassDecorator {
  // Validate the supplied keys once, when the decorator factory is called.
  validateModuleKeys(Object.keys(metadata));

  // Every own property of `metadata` (imports, controllers, providers, ...)
  // becomes its own Reflect metadata entry on the decorated class, keyed by
  // the property name — so Reflect.getMetadata('controllers', SomeModule)
  // returns exactly the array that was passed to @Module().
  return (target: Function) => {
    for (const property in metadata) {
      if (!metadata.hasOwnProperty(property)) {
        continue;
      }
      const value = (metadata as any)[property];
      Reflect.defineMetadata(property, value, target);
    }
  };
}

在Nest中注册controller实际上是为Module添加元数据

@Module({
    // ...
    controllers: [AController, BController]
})
export class AppModule {}

// Read it back: @Module() stored the `controllers` array under the 'controllers' key on AppModule
Reflect.getMetadata('controllers', AppModule);  // [AController, BController]

下面是一个标准的 Controller：

@Controller('api')
export class ApiController {
    // POST /api/create
    @Post('create')
    async create(@Body() dto: EmailCreateDto) {
        return {};
    }

    // GET /api/get
    // Each handler needs its own method name. Reusing `create` here is a
    // duplicate implementation (TS2393), and because only one `create`
    // property can exist on the prototype, one of the two routes is lost.
    @Get('get')
    async get() {
        return {};
    }
}

转换成底层实现

// Pseudo-code: what Nest effectively produces after analysing the metadata
// (the real handlers are wrapped proxies — see section 3)
express.post('/api/create', ApiController.create);
express.get('/api/get', ApiController.get);

底层实现就是为 Express（或 Fastify）注册路由，可参考 Express 官方文档的 Routing 章节。

进一步了解Nest是如何实现的

2.收集Controller

在 NestFactory.create 中实例化了 httpServer，默认情况下它就是 Express 适配器（ExpressAdapter）实例。

回顾一下Nestjs启动方法

// Application entry point: build the app from the root module, install a
// global transforming ValidationPipe, then listen on the configured port
// (falling back to 3000 when 'app.port' is unset).
async function bootstrap() {
  const app = await NestFactory.create(AppModule);
  const config = app.get(ConfigService);
  app.useGlobalPipes(new ValidationPipe({ transform: true }));
  const port = config.get<number>('app.port') ?? 3000;
  await app.listen(port);
}
bootstrap();

进入NestFactory.create方法

// packages/core/nest-factory.ts
public async create<T extends INestApplication = INestApplication>(
    moduleCls: any,  // root module, e.g. AppModule
    serverOrOptions?: AbstractHttpAdapter | NestApplicationOptions,
    options?: NestApplicationOptions,
  ): Promise<T> {
    // Create the HTTP server: use the adapter the caller passed in, otherwise
    // build the default one via createHttpAdapter() (Express).
    const [httpServer, appOptions] = this.isHttpServer(serverOrOptions)
      ? [serverOrOptions, options]
          : [this.createHttpAdapter(), serverOrOptions];
     // Key container: the scan in initialize() below stores every module and controller here.
     // NOTE(review): applicationConfig and graphInspector come from code not shown in this excerpt.
     const container = new NestContainer(applicationConfig);
     // ...
     await this.initialize(
      moduleCls,
      container,
      graphInspector,
      applicationConfig,
      appOptions,
      httpServer,
    );
    // Key step: create the NestApplication instance; the same container is passed in.
    const instance = new NestApplication(
      container,
      httpServer,
      applicationConfig,
      graphInspector,
      appOptions,
    );
    // Both calls below only wrap the instance in proxies — effectively this
    // returns the NestApplication.
    const target = this.createNestInstance(instance);
    return this.createAdapterProxy<T>(target, httpServer);
}

// Default HTTP adapter: loads @nestjs/platform-express through loadAdapter
// (the require runs inside the callback) and wraps the optional server in an ExpressAdapter.
private createHttpAdapter<T = any>(httpServer?: T): AbstractHttpAdapter {
    const expressPlatform = loadAdapter(
      '@nestjs/platform-express',
      'HTTP',
      () => require('@nestjs/platform-express'),
    );
    return new expressPlatform.ExpressAdapter(httpServer);
  }
  
// Builds a DependenciesScanner over the container and scans the root module into it.
private async initialize(
    module: any,  // root module, e.g. AppModule
    container: NestContainer,
    graphInspector: GraphInspector,
    config = new ApplicationConfig(),
    options: NestApplicationContextOptions = {},
    httpServer: HttpServer = null,
  ) {
      // ...
    const metadataScanner = new MetadataScanner();
    const dependenciesScanner = new DependenciesScanner(
      container,
      metadataScanner,
      graphInspector,
      config,
    );
    // Hand the HTTP adapter to the container.
    container.setHttpAdapter(httpServer);
    // ...
    // Key entry point: scan the dependency graph starting from the root module.
    await dependenciesScanner.scan(module);
 }

从 create -> initialize -> dependenciesScanner.scan 开始，Nest 对 AppModule 进行依赖分析。
步骤：先递归收集 AppModule 导入的所有子 Module，再逐个读取各 Module 注册的 Controller，并将其保存到 container 中。

// packages/core/scanner.ts
public async scan(
    module: Type<any>,
    options?: { overrides?: ModuleOverride[] },
  ) {
    // Register Nest's internal CoreModule (not relevant to this article)
    await this.registerCoreModule(options?.overrides);
    // Recursively register every module reachable from the root module
    await this.scanForModules({
      moduleDefinition: module,
      overrides: options?.overrides,
    });
    // Then reflect each registered module's imports, providers, controllers and exports
    await this.scanModulesForDependencies();
    
    // ...
  }
  
 public async scanForModules({
    moduleDefinition,
    lazy,
    scope = [],
    ctxRegistry = [],
    overrides = [],
  }: ModulesScanParameters): Promise<Module[]> {
     // Key entry point: insert (or override) this module in the container
    const { moduleRef: moduleInstance, inserted: moduleInserted } =
      (await this.insertOrOverrideModule(moduleDefinition, overrides, scope)) ??
      {};

    // Collect the imports: a class module's come from its @Module() metadata;
    // a dynamic module contributes the static imports of its `.module` class
    // plus its own `.imports` array.
    const modules = !this.isDynamicModule(
      moduleDefinition as Type<any> | DynamicModule,
    )
      ? this.reflectMetadata(
          MODULE_METADATA.IMPORTS,
          moduleDefinition as Type<any>,
        )
      : [
          ...this.reflectMetadata(
            MODULE_METADATA.IMPORTS,
            (moduleDefinition as DynamicModule).module,
          ),
          ...((moduleDefinition as DynamicModule).imports || []),
        ];

    // Recurse into each imported module, appending the current module to the scope chain.
    // NOTE(review): in this excerpt moduleInstance, moduleInserted, index and moduleRefs
    // are unused and nothing is returned; the code that uses them is elided at `// ...`.
    for (const [index, innerModule] of modules.entries()) {
      const moduleRefs = await this.scanForModules({
        moduleDefinition: innerModule,
        scope: [].concat(scope, moduleDefinition),
        ctxRegistry,
        overrides,
        lazy,
      });
    }
    // ...
  }
  
    // 插入或者覆盖Module
   private insertOrOverrideModule(
    moduleDefinition: ModuleDefinition,
    overrides: ModuleOverride[],
    scope: Type<unknown>[],
  ): Promise<
    | {
        moduleRef: Module;
        inserted: boolean;
      }
    | undefined
  > {
    // ...
    // The override lookup is elided in this excerpt; the path shown simply
    // inserts the definition under the current scope.
    return this.insertModule(moduleDefinition, scope);
  }
  
  // 插入Module到container
   public async insertModule(
    moduleDefinition: any,
    scope: Type<unknown>[],
  ): Promise<
    | {
        moduleRef: Module;
        inserted: boolean;
      }
    | undefined
  > {
    // An import may be a forward reference, forwardRef(() => SomeModule);
    // unwrap it so the container registers the actual module class.
    const moduleToAdd = this.isForwardReference(moduleDefinition)
      ? moduleDefinition.forwardRef()
      : moduleDefinition;
    // ...
    return this.container.addModule(moduleToAdd, scope);
  }
  
 // 遍历Module的收集各种依赖包括Controller
 /**
  * For every module already registered in the container, reflect its imports
  * (awaited), providers, controllers and exports — in that order.
  */
 public async scanModulesForDependencies(
    modules: Map<string, Module> = this.container.getModules(), // defaults to every module collected so far
  ) {
    for (const [moduleToken, moduleRef] of modules) {
      const moduleClass = moduleRef.metatype;
      await this.reflectImports(moduleClass, moduleToken, moduleClass.name);
      this.reflectProviders(moduleClass, moduleToken);
      this.reflectControllers(moduleClass, moduleToken);
      this.reflectExports(moduleClass, moduleToken);
    }
  }
  
  // 收集单个Module的Controller
  public reflectControllers(module: Type<any>, token: string) {
    // Controllers declared in @Module() metadata, followed by the controllers the
    // container holds as dynamic metadata for this token (presumably from a DynamicModule).
    const declaredControllers = this.reflectMetadata(MODULE_METADATA.CONTROLLERS, module);
    const dynamicControllers = this.container.getDynamicMetadataByToken(
      token,
      MODULE_METADATA.CONTROLLERS as 'controllers',
    );
    for (const controller of [...declaredControllers, ...dynamicControllers]) {
      this.insertController(controller, token);
      this.reflectDynamicMetadata(controller, token);
    }
  }
  
  // 将Controller保存到container中
  /** Registers `controller` in the container under the module identified by `token`. */
  public insertController(controller: Type<Controller>, token: string) {
    const { container } = this;
    container.addController(controller, token);
  }

至此container中收集了所有Module与Controller

3.遍历Controller为HttpServer注册路由

在nest-factory中create方法创建了NestApplication实例
让我们回顾一下Nest的启动方法

// Application entry point: build the app from the root module, install a
// global transforming ValidationPipe, then listen on the configured port
// (falling back to 3000 when 'app.port' is unset).
async function bootstrap() {
  const app = await NestFactory.create(AppModule);
  const config = app.get(ConfigService);
  app.useGlobalPipes(new ValidationPipe({ transform: true }));
  const port = config.get<number>('app.port') ?? 3000;
  await app.listen(port);
}
bootstrap();

进入app即NestApplication看listen方法

//  packages/core/nest-application.ts
// `super(...)` in the constructor is only valid in a derived class:
// NestApplication extends NestApplicationContext, which receives the container.
export class NestApplication
  extends NestApplicationContext<NestApplicationOptions>
  implements INestApplication
{
  private readonly routesResolver: Resolver;
  // ... (other fields elided)

  constructor(
    container: NestContainer, // key container: already holds every collected controller class
    private readonly httpAdapter: HttpServer,
    private readonly config: ApplicationConfig,
    private readonly graphInspector: GraphInspector,
    appOptions: NestApplicationOptions = {},
  ) {
    super(container, appOptions);

    // ...
    // The routes resolver works over the same container
    this.routesResolver = new RoutesResolver(
      this.container, // key container
      this.config,
      this.injector,
      this.graphInspector,
    );
  }

  // Start listening; runs init() first (which registers the routes) if not yet initialized
  public async listen(port: number | string, ...args: any[]): Promise<any> {
    !this.isInitialized && (await this.init());
    // ...
  }

  // Initialization
  public async init(): Promise<this> {
    // ...
    await this.registerRouter();
    // ...
    return this;
  }

  // Register routes: middleware first, then every controller route under the global prefix
  public async registerRouter() {
    await this.registerMiddleware(this.httpAdapter);

    const prefix = this.config.getGlobalPrefix();
    const basePath = addLeadingSlash(prefix);
    // resolve() is the key entry point for route registration
    this.routesResolver.resolve(this.httpAdapter, basePath);
  }
}

进入RoutesResolver

export class RoutesResolver implements Resolver {
    constructor(
        private readonly container: NestContainer,
        private readonly applicationConfig: ApplicationConfig,
        private readonly injector: Injector,
        graphInspector: GraphInspector,
      ) {
        // ...
        // The explorer that turns a controller's handler methods into HTTP routes
        this.routerExplorer = new RouterExplorer(
          metadataScanner,
          this.container,
          this.injector,
          this.routerProxy,
          this.routerExceptionsFilter,
          this.applicationConfig,
          this.routePathFactory,
          graphInspector,
        );
    }
    
  // Register the routes of every module in the container on `applicationRef`.
  public resolve<T extends HttpServer>(
    applicationRef: T,
    globalPrefix: string,
  ) {
    const modules = this.container.getModules();
    // Walk every module in the container; the Map key (the module token)
    // is passed on as moduleName together with the module's controllers.
    modules.forEach(({ controllers, metatype }, moduleName) => {
      const modulePath = this.getModulePathMetadata(metatype);
      this.registerRouters(
        controllers,
        moduleName,
        globalPrefix,
        modulePath,
        applicationRef,
      );
    });
  }
  
  // For each controller of one module, run explore() once per controller path.
  public registerRouters(
    routes: Map<string | symbol | Function, InstanceWrapper<Controller>>,
    moduleName: string,
    globalPrefix: string,
    modulePath: string,
    applicationRef: HttpServer,
  ) {
    routes.forEach(instanceWrapper => {
      const { metatype } = instanceWrapper;

      // Host metadata declared on the controller class (via @Controller())
      const host = this.getHostMetadata(metatype);
      // @Controller() may declare several paths; each is normalized to start with '/'
      const routerPaths = this.routerExplorer.extractRouterPath(
        metatype as Type<any>,
      );
      const controllerVersion = this.getVersionMetadata(metatype);
      const controllerName = metatype.name;
      // One pass per path declared on @Controller()
      routerPaths.forEach(path => {
        // ...
        const versioningOptions = this.applicationConfig.getVersioning();
        // Everything needed later to build the final path of each handler
        const routePathMetadata: RoutePathMetadata = {
          ctrlPath: path,
          modulePath,
          globalPrefix,
          controllerVersion,
          versioningOptions,
        };
        // Key entry point
        this.routerExplorer.explore(
          instanceWrapper,
          moduleName,
          applicationRef,
          host,
          routePathMetadata,
        );
      });
    });
  }
}

进入RouterExplorer

export class RouterExplorer {
    constructor(
        metadataScanner: MetadataScanner,
        private readonly container: NestContainer,
        private readonly injector: Injector,
        private readonly routerProxy: RouterProxy,
        private readonly exceptionsFilter: ExceptionsFilter,
        config: ApplicationConfig,
        private readonly routePathFactory: RoutePathFactory,
        private readonly graphInspector: GraphInspector,
      ) {
        // Scans controller instances for route handler methods (used by explore())
        this.pathsExplorer = new PathsExplorer(metadataScanner);

        // Factories/consumers for route params, pipes, guards and interceptors,
        // combined below into the execution context used for each handler
        const routeParamsFactory = new RouteParamsFactory();
        const pipesContextCreator = new PipesContextCreator(container, config);
        const pipesConsumer = new PipesConsumer();
        const guardsContextCreator = new GuardsContextCreator(container, config);
        const guardsConsumer = new GuardsConsumer();
        const interceptorsContextCreator = new InterceptorsContextCreator(
          container,
          config,
        );
        const interceptorsConsumer = new InterceptorsConsumer();

        this.executionContextCreator = new RouterExecutionContext(
          routeParamsFactory,
          pipesContextCreator,
          pipesConsumer,
          guardsContextCreator,
          guardsConsumer,
          interceptorsContextCreator,
          interceptorsConsumer,
          container.getHttpAdapterRef(),
        );
      }
      
      // Find all route handlers of one controller and register them on applicationRef.
      public explore<T extends HttpServer = any>(
        instanceWrapper: InstanceWrapper,
        moduleKey: string,
        applicationRef: T,
        host: string | RegExp | Array<string | RegExp>,
        routePathMetadata: RoutePathMetadata,
      ) {
        const { instance } = instanceWrapper;  // the controller instance
        const routerPaths = this.pathsExplorer.scanForPaths(instance);
        // How scanForPaths works:
        // it walks the instance's prototype and keeps the method-valued properties,
        // then reads each method's metadata — e.g. for @Post('/get')
        // the decorator wrote PATH_METADATA = '/get' and METHOD_METADATA = RequestMethod.POST;
        // the result is a RouteDefinition array:
        // export interface RouteDefinition {
        //   path: string[];  route path(s) of this handler within the controller
        //   requestMethod: RequestMethod;  HTTP method: Post, Get, ...
        //   targetCallback: RouterProxyCallback;  target callback (the method on the prototype)
        //   methodName: string;  method name
        //   version?: VersionValue;  version
        // }
        //  This yields a description of every request handler in the controller
        this.applyPathsToRouterProxy(
          applicationRef,
          routerPaths,
          instanceWrapper,
          moduleKey,
          routePathMetadata,
          host,
        );
      }
      
      // Register each RouteDefinition; note routePathMetadata is mutated per definition.
      public applyPathsToRouterProxy<T extends HttpServer>(
        router: T,
        routeDefinitions: RouteDefinition[],
        instanceWrapper: InstanceWrapper,
        moduleKey: string,
        routePathMetadata: RoutePathMetadata,
        host: string | RegExp | Array<string | RegExp>,
      ) {
        (routeDefinitions || []).forEach(routeDefinition => {
          const { version: methodVersion } = routeDefinition;
          routePathMetadata.methodVersion = methodVersion;

          this.applyCallbackToRouter(
            router,
            routeDefinition,
            instanceWrapper,
            moduleKey,
            routePathMetadata,
            host,
          );
        });
      }
      
  // Wrap one handler in Nest's request pipeline and register it for every final path.
  private applyCallbackToRouter<T extends HttpServer>(
    router: T,
    routeDefinition: RouteDefinition,
    instanceWrapper: InstanceWrapper,
    moduleKey: string,
    routePathMetadata: RoutePathMetadata,
    host: string | RegExp | Array<string | RegExp>,
  ) {
    const {
      path: paths,
      requestMethod,
      targetCallback,
      methodName,
    } = routeDefinition;

    const { instance } = instanceWrapper;
    // Resolve the HTTP server's registration method for this request method,
    // e.g. express.get or express.post (bound to the router)
    const routerMethodRef = this.routerMethodFactory
      .get(router, requestMethod)
      .bind(router);

    //  ...
    // Build the closure proxy around the handler: guards, pipes, interceptors and
    // exception filters are applied here (middleware is registered separately, in registerMiddleware)
    const proxy = this.createCallbackProxy(
          instance,
          targetCallback,
          methodName,
          moduleKey,
          requestMethod,
        );
    
    // Host filter: the handler only applies to requests matching the controller's host option
    let routeHandler = this.applyHostFilter(host, proxy);

     paths.forEach(path => {
      routePathMetadata.methodPath = path;
      // Combine global prefix, module path, controller path, versions and method path into the final path(s)
      const pathsToRegister = this.routePathFactory.create(
        routePathMetadata,
        requestMethod,
      );
      pathsToRegister.forEach(path => {
        this.copyMetadataToCallback(targetCallback, routeHandler);
        // Finally register the path on the HTTP server,
        // equivalent to e.g. express.get('/api/get', routeHandler)
        routerMethodRef(path, routeHandler);
      });
    });
  }
}

至此controller路由注册完成, 本文的目的已经达成

4. (彩蛋) 一个隐藏巨大能量的模块 DiscoveryModule

在 container 的 addController 方法中可以发现，所有 controllerRef 都会被添加到 DiscoverableMetaHostCollection 这个只包含静态成员的类中：

// packages/core/injector/container.ts
// Register `controller` on the module identified by `token`, then hand its
// InstanceWrapper to DiscoverableMetaHostCollection, keyed by this modules container.
public addController(controller: Type<any>, token: string) {
    const { modules } = this;
    if (!modules.has(token)) {
      throw new UnknownModuleException();
    }
    const hostModule = modules.get(token);
    hostModule.addController(controller);

    DiscoverableMetaHostCollection.inspectController(
      modules, // the modules container ("host container")
      hostModule.controllers.get(controller), // the controller's InstanceWrapper
    );
  }

DiscoverableMetaHostCollection 的静态属性中按元数据 key 存储了所有的 controllerRef 和 providerRef（下面只摘录了 controller 部分）：

// packages/core/discovery/discoverable-meta-host-collection.ts
export class DiscoverableMetaHostCollection {
  // One registry per ModulesContainer (WeakMap key): metadata key -> set of controller wrappers.
  private static readonly controllersByMetaKey = new WeakMap<
    ModulesContainer,
    Map<string, Set<InstanceWrapper>>
  >();
  
  // Called from NestContainer.addController: file a controller wrapper in controllersByMetaKey.
  public static inspectController(
    hostContainerRef: ModulesContainer,
    instanceWrapper: InstanceWrapper,
  ) {
    return this.inspectInstanceWrapper(
      hostContainerRef,
      instanceWrapper,
      this.controllersByMetaKey,
    );
  }
  
  // Look up (or, in elided code, create) the per-container collection and insert the wrapper.
  private static inspectInstanceWrapper(
    hostContainerRef: ModulesContainer,
    instanceWrapper: InstanceWrapper,
    wrapperByMetaKeyMap: WeakMap<
      ModulesContainer,
      Map<string, Set<InstanceWrapper>>
    >,
  ) {
    // ...
    let collection: Map<string, Set<InstanceWrapper>>;
    if (wrapperByMetaKeyMap.has(hostContainerRef)) {
      collection = wrapperByMetaKeyMap.get(hostContainerRef);
    }
    // ...
    // NOTE(review): how metaKey is derived, and the branch that creates a missing
    // collection, are in code elided from this excerpt.
    this.insertByMetaKey(metaKey, instanceWrapper, collection);
  }
  
  // Add the wrapper to the Set stored under metaKey, creating the Set on first use.
  public static insertByMetaKey(
    metaKey: string,
    instanceWrapper: InstanceWrapper,
    collection: Map<string, Set<InstanceWrapper>>,
  ) {
    if (collection.has(metaKey)) {
      const wrappers = collection.get(metaKey);
      wrappers.add(instanceWrapper);
    } else {
      const wrappers = new Set<InstanceWrapper>();
      wrappers.add(instanceWrapper);
      collection.set(metaKey, wrappers);
    }
  }
}

DiscoveryService访问静态类DiscoverableMetaHostCollection

/**
 * Looks up registered providers and controllers, either by discoverable
 * metadata key or by walking the given modules.
 */
@Injectable()
export class DiscoveryService {
  constructor(private readonly modulesContainer: ModulesContainer) {}

  /**
   * Provider wrappers: those filed under `options.metadataKey` when that key
   * is present, otherwise every provider of `modules`.
   */
  public getProviders(
    options: DiscoveryOptions = {},
    modules: Module[] = this.getModules(options),
  ): InstanceWrapper[] {
    if (!('metadataKey' in options)) {
      const providerLists = modules.map(moduleRef => [
        ...moduleRef.providers.values(),
      ]);
      return flatten(providerLists);
    }
    return Array.from(
      DiscoverableMetaHostCollection.getProvidersByMetaKey(
        this.modulesContainer,
        options.metadataKey,
      ),
    );
  }

  /**
   * Controller wrappers: those filed under `options.metadataKey` when that key
   * is present, otherwise every controller of `modules`.
   */
  public getControllers(
    options: DiscoveryOptions = {},
    modules: Module[] = this.getModules(options),
  ): InstanceWrapper[] {
    if (!('metadataKey' in options)) {
      const controllerLists = modules.map(moduleRef => [
        ...moduleRef.controllers.values(),
      ]);
      return flatten(controllerLists);
    }
    return Array.from(
      DiscoverableMetaHostCollection.getControllersByMetaKey(
        this.modulesContainer,
        options.metadataKey,
      ),
    );
  }
}

因此可以利用DiscoveryService获取全局注册的controllerRef与providerRef,在开发Module库时会有奇效
使用方法:

// app.module.ts
import { Module } from '@nestjs/common';
import { DiscoveryModule } from '@nestjs/core';
// (imports of AppController / AppService omitted)

@Module({
  // ...
  // Importing DiscoveryModule makes DiscoveryService injectable in this module
  imports: [DiscoveryModule],
  controllers: [AppController],
  providers: [AppService],
})
export class AppModule {}

// app.service.ts
import { Injectable, OnModuleInit } from '@nestjs/common';
import { DiscoveryService } from '@nestjs/core';

@Injectable()
export class AppService implements OnModuleInit {
  constructor(private readonly discoveryService: DiscoveryService) {}

  // Query discovery in the onModuleInit lifecycle hook rather than in the
  // constructor: the hook runs after the module's dependencies are resolved,
  // so wrappers of classes instantiated after AppService (e.g. controllers)
  // already carry their `instance`.
  onModuleInit() {
    const allControllerRefs = this.discoveryService.getControllers();
    const allProviderRefs = this.discoveryService.getProviders();
  }
}