NestJS from Usage to Source Implementation - Middleware

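Official middleware documentation: docs.nestjs.com/middleware

Usage

Middleware can be implemented as a function or as a class. A class middleware must implement the NestMiddleware interface; a function middleware does not support dependency injection.

Middleware has to be registered inside a module: the module class implements the NestModule interface and provides a configure method.

Define a middleware in test.middleware.ts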

import { Injectable, NestMiddleware } from '@nestjs/common';
import { Request, Response, NextFunction } from 'express';

@Injectable()
export class TestMiddleware implements NestMiddleware {
  use(req: Request, res: Response, next: NextFunction) {
    console.log('TestMiddleware');
    next();
  }
}

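Then use it in AppModule

import { MiddlewareConsumer, NestModule } from '@nestjs/common';
import { TestMiddleware } from './test.middleware';

export class AppModule implements NestModule {
  configure(consumer: MiddlewareConsumer) {
    // Run TestMiddleware for requests to the 'middleware' route
    consumer.apply(TestMiddleware).forRoutes('middleware');
  }
}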

A request to /middleware now prints TestMiddleware. Note that forRoutes also accepts an object that specifies both the route path and the HTTP method:

forRoutes({ path: 'config', method: RequestMethod.GET })

To keep the middleware from running on certain routes, use exclude():

consumer
  .apply(TestMiddleware)
  .exclude(
    { path: 'middleware', method: RequestMethod.POST },
  )
  .forRoutes('middleware');

Implementation

Basic implementation

First, define the NestMiddleware interface:

import { Request, Response, NextFunction } from 'express';
export interface NestMiddleware {
 use(req: Request, res: Response, next: NextFunction): void
}

Then define RequestMethod, MiddlewareConsumer, and NestModule:

export enum RequestMethod {
  GET = 'GET',
  POST = 'POST',
  PUT = 'PUT',
  DELETE = 'DELETE',
  PATCH = 'PATCH',
  ALL = 'ALL',
  OPTIONS = 'OPTIONS',
  HEAD = 'HEAD',
}

export interface MiddlewareConsumer{
  apply(...middleware: (Function | any)[]):this,
  forRoutes (...routes: (string | { path: string; method: RequestMethod } | Function)[]): this
  exclude(...routes: (string | { path: string; method: RequestMethod } | Function)[]): this;
}

export interface NestModule {
  configure(consumer: MiddlewareConsumer): void
}

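Modify nest-application.ts

Calling NestFactory.create(AppModule) essentially runs new NestApplication(module) with AppModule as the argument. When initMiddlewares runs, it calls AppModule's configure method and passes the NestApplication instance itself as the consumer, so NestApplication has to provide the MiddlewareConsumer methods.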
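
A minimal sketch of that hand-off, with no Nest or Express involved. Consumer, ModuleClass, FakeApplication, and DemoModule are hypothetical names used only for illustration; the sketch assumes that configure is read from the module class's prototype, which is how the implementation in this article calls it.

```typescript
// Hypothetical, reduced version of the consumer interface.
interface Consumer {
  apply(...middleware: unknown[]): this;
}

// Anything whose prototype may carry a configure method.
type ModuleClass = { prototype: { configure?: (consumer: Consumer) => void } };

// Hypothetical stand-in for NestApplication: it acts as its own consumer.
class FakeApplication implements Consumer {
  readonly applied: unknown[] = [];

  constructor(private readonly moduleClass: ModuleClass) {
    // The module class is never instantiated here;
    // configure is taken from its prototype and receives the application.
    this.moduleClass.prototype.configure?.(this);
  }

  apply(...middleware: unknown[]): this {
    this.applied.push(...middleware);
    return this;
  }
}

class DemoModule {
  configure(consumer: Consumer): void {
    consumer.apply('demo-middleware');
  }
}

const demoApp = new FakeApplication(DemoModule);
console.log(demoApp.applied); // the entries recorded by configure
```

One consequence of this design is that configure cannot rely on instance state of the module, because no module instance exists when it is called.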

Then we implement the apply, forRoutes, and exclude methods on NestApplication:

private readonly middlewares = []

constructor(module: any) {
  this.module = module
  this.postMiddleware()
  // Initialise the middlewares
  this.initMiddlewares();
}

private initMiddlewares() {
  // Call the module's configure method with this application as the consumer
  this.module.prototype.configure?.(this);
}

apply(...middleware: (Function | any)[]): this {
  this.middlewares.push(...middleware);
  return this;
}

forRoutes(...routes: any[]): this {
  for (const route of routes) {
      for (const middleware of this.middlewares) {
          const { routePath, routeMethod } = this.normalizeRouteInfo(route);
          this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
              if ((routeMethod === RequestMethod.ALL || routeMethod === req.method)) {
                  const middlewareInstance = new middleware();
                  middlewareInstance.use(req, res, next);
              } else {
                  next();
              }
          });
      }
  }
  return this;
}

private normalizeRouteInfo(route) {
    let routePath = '';
    let routeMethod = RequestMethod.ALL;
    if (typeof route === 'string') {
        routePath = route;
    } else if ('path' in route) {
        routePath = route.path;
        routeMethod = route.method ?? RequestMethod.ALL;
    }
    routePath = path.posix.join('/', routePath);
    return { routePath, routeMethod };
}

Now a request to http://localhost:3000/middleware prints TestMiddleware in the console.
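
To see the apply/forRoutes flow in isolation, here is a rough sketch that replaces Express with a plain array of handlers. MiniConsumer, FakeRequest, and the handle method are hypothetical and exist only for this illustration; the prefix match imitates how app.use treats a mount path, and the method check mirrors the one in forRoutes.

```typescript
type Next = () => void;
interface FakeRequest { method: string; path: string }
type Handler = (req: FakeRequest, next: Next) => void;
interface RouteInfo { path: string; method?: string }

class MiniConsumer {
  private readonly middlewares: Handler[] = [];
  private readonly stack: Handler[] = [];

  apply(...middleware: Handler[]): this {
    this.middlewares.push(...middleware);
    return this;
  }

  forRoutes(...routes: (string | RouteInfo)[]): this {
    for (const route of routes) {
      const info: RouteInfo = typeof route === 'string' ? { path: route } : route;
      const routePath = '/' + info.path.replace(/^\/+/, '');
      const routeMethod = info.method ?? 'ALL';
      for (const middleware of this.middlewares) {
        // Wrap each middleware so it only runs for a matching path and method.
        this.stack.push((req, next) => {
          const pathMatches = req.path === routePath || req.path.startsWith(routePath + '/');
          const methodMatches = routeMethod === 'ALL' || routeMethod === req.method;
          if (pathMatches && methodMatches) middleware(req, next);
          else next();
        });
      }
    }
    return this;
  }

  // Run the handlers in registration order.
  handle(req: FakeRequest): void {
    let index = 0;
    const next: Next = () => {
      const handler = this.stack[index++];
      if (handler) handler(req, next);
    };
    next();
  }
}

const calls: string[] = [];
const consumer = new MiniConsumer()
  .apply((req, next) => { calls.push(`seen ${req.method} ${req.path}`); next(); })
  .forRoutes({ path: 'middleware', method: 'GET' });

consumer.handle({ method: 'GET', path: '/middleware' });  // recorded
consumer.handle({ method: 'POST', path: '/middleware' }); // skipped: method differs
consumer.handle({ method: 'GET', path: '/other' });       // skipped: path differs
console.log(calls);
```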

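Dependency injection

Nest middleware fully supports dependency injection. Just like providers and controllers, middleware can inject dependencies that are available within the same module. As usual, this is done through the constructor. Here we inject a LoggerService.

logger.service.ts

import { Injectable } from '@nestjs/common';

@Injectable()
export class LoggerService {
  log(message: string) {
    console.log('log', message);
  }
}

test.middleware.ts
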
import { Injectable, NestMiddleware } from '@nestjs/common';
import { Request, Response, NextFunction } from 'express';
import { LoggerService } from './logger.service';


@Injectable()
export class TestMiddleware implements NestMiddleware {

  constructor(
    private readonly loggerService: LoggerService
  ) {

  }

  use(req: Request, res: Response, next: NextFunction) {
    // log() already prints and returns nothing, so call it directly
    this.loggerService.log('log - TestMiddleware');
    next();
  }
}

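With dependency injection in place, the middleware's dependencies have to be resolved before it is instantiated.

Modify forRoutes so that the dependencies are injected when the middleware runs: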

  forRoutes(...routes: any[]): this {
    for (const route of routes) {
        for (const middleware of this.middlewares) {
            const { routePath, routeMethod } = this.normalizeRouteInfo(route);
            this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
                if ((routeMethod === RequestMethod.ALL || routeMethod === req.method)) {
                    // Register the middleware in the module's metadata
                    defineModule(this.module, [middleware])
                    // Resolve its constructor dependencies
                    const dependencies = this.resolveDependencies(middleware);
                    const middlewareInstance = new middleware(...dependencies);
                    middlewareInstance.use(req, res, next);
                } else {
                    next();
                }
            });
        }
    }
    return this;
  }
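
The resolution step can be sketched without reflect-metadata. In the sketch below a static inject array is a hypothetical stand-in for the design:paramtypes metadata that resolveDependencies reads, and the instances map plays the role of providersInstance; Token, Injectable, and instantiate are names made up for this illustration.

```typescript
// Hypothetical stand-in for decorator metadata:
// each class lists its constructor tokens in a static `inject` array.
type Token = new (...args: any[]) => unknown;
interface Injectable<T> {
  new (...args: any[]): T;
  inject?: Token[];
}

class LoggerService {
  readonly lines: string[] = [];
  log(message: string): void {
    this.lines.push(`log ${message}`);
  }
}

class TestMiddleware {
  static inject: Token[] = [LoggerService];
  constructor(private readonly loggerService: LoggerService) {}
  use(next: () => void): void {
    this.loggerService.log('TestMiddleware');
    next();
  }
}

// Token -> already created instance, like providersInstance in the article.
const instances = new Map<Token, unknown>([[LoggerService, new LoggerService()]]);

function instantiate<T>(target: Injectable<T>): T {
  // Look up every declared token; unknown tokens resolve to null.
  const dependencies = (target.inject ?? []).map((token) => instances.get(token) ?? null);
  return new target(...dependencies);
}

const middlewareInstance = instantiate(TestMiddleware);
middlewareInstance.use(() => console.log('next called'));
console.log((instances.get(LoggerService) as LoggerService).lines);
```

Because the logger comes from the shared map, every middleware instance created this way receives the same LoggerService object.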

Excluding routes

Next we implement exclude. It records the routes to skip, and forRoutes is modified again so that a request matching an entry in excludedRoutes goes straight to next():

// Routes excluded from the middleware
private readonly excludedRoutes = []

exclude(...routes: any[]): this {
  this.excludedRoutes.push(...routes);
  return this;
}

// req.method is a plain string, so the method parameter is typed as string
private isExcluded(reqPath: string, method: string): boolean {
  return this.excludedRoutes.some(route => {
    const { routePath, routeMethod } = this.normalizeRouteInfo(route);
    return routePath === reqPath && (routeMethod === RequestMethod.ALL || routeMethod === method);
  });
}
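
The matching rule can be tried on its own. This is only a sketch: normalize and ExcludedRoute are hypothetical helpers that imitate normalizeRouteInfo with string methods, and stripping the query string before comparing is an assumption added here (req.originalUrl includes the query string, so an exact comparison would otherwise miss URLs such as /middleware?x=1).

```typescript
interface ExcludedRoute { path: string; method?: string }

// Normalise a route to a leading-slash path plus a method ('ALL' when omitted).
function normalize(route: string | ExcludedRoute): { routePath: string; routeMethod: string } {
  const info = typeof route === 'string' ? { path: route, method: undefined } : route;
  return {
    routePath: '/' + info.path.replace(/^\/+/, ''),
    routeMethod: info.method ?? 'ALL',
  };
}

function isExcluded(excluded: (string | ExcludedRoute)[], url: string, method: string): boolean {
  // Assumption added for this sketch: drop the query string before comparing.
  const reqPath = url.split('?')[0];
  return excluded.some((route) => {
    const { routePath, routeMethod } = normalize(route);
    return routePath === reqPath && (routeMethod === 'ALL' || routeMethod === method);
  });
}

const excludedRoutes = [{ path: 'middleware', method: 'POST' }];
console.log(isExcluded(excludedRoutes, '/middleware', 'POST'));     // true
console.log(isExcluded(excludedRoutes, '/middleware?x=1', 'POST')); // true
console.log(isExcluded(excludedRoutes, '/middleware', 'GET'));      // false
```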


  forRoutes(...routes: any[]): this {
    for (const route of routes) {
      for (const middleware of this.middlewares) {
        const { routePath, routeMethod } = this.normalizeRouteInfo(route);
        this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
          // Skip excluded routes
          if (this.isExcluded(req.originalUrl, req.method)) {
            return next();
          }
          if (routeMethod === RequestMethod.ALL || routeMethod === req.method) {
            defineModule(this.module, [middleware])
            const dependencies = this.resolveDependencies(middleware);
            const middlewareInstance = new middleware(...dependencies);
            middlewareInstance.use(req, res, next);
          } else {
            next();
          }
        });
      }
    }
    return this;
  }

Function middleware

Let's create a plain function middleware:

export function TestFunctionMiddleware(req: Request, res: Response, next: NextFunction) {
  console.log('function middleware')
  next()
}

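A function middleware cannot use dependency injection, so when the middleware is a plain function we call it directly. The two cases are told apart by checking whether the prototype has a use method. Modify forRoutes once more: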

  forRoutes(...routes: any[]): this {
    for (const route of routes) {
      for (const middleware of this.middlewares) {
        const { routePath, routeMethod } = this.normalizeRouteInfo(route);
        this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
          if (this.isExcluded(req.originalUrl, req.method)) {
            return next();
          }
          if (routeMethod === RequestMethod.ALL || routeMethod === req.method) {
            if (typeof middleware === 'function' && middleware.prototype && 'use' in middleware.prototype) {
              // Class middleware: resolve dependencies, instantiate, then call use()
              defineModule(this.module, [middleware])
              const dependencies = this.resolveDependencies(middleware);
              const middlewareInstance = new middleware(...dependencies);
              middlewareInstance.use(req, res, next);
            } else if (typeof middleware === 'function') {
              // Function middleware: call it directly
              middleware(req, res, next);
            } else {
              next();
            }
          } else {
            next()
          }
        });
      }
    }
    return this;
  }

Then register both middlewares in AppModule:

export class AppModule implements NestModule {
  configure(consumer: MiddlewareConsumer) {
    consumer
      .apply(TestMiddleware)
      .apply(TestFunctionMiddleware)
      .forRoutes('middleware')
      // .exclude({ path: '/middleware', method: RequestMethod.GET })
  }
}
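
The class-versus-function test used in forRoutes can be pulled out into a small helper to check how it behaves. isClassMiddleware, ClassMiddleware, and functionMiddleware are hypothetical names for this sketch; the condition itself is the one from the article.

```typescript
type NextFn = () => void;

class ClassMiddleware {
  use(_req: unknown, _res: unknown, next: NextFn): void {
    next();
  }
}

function functionMiddleware(_req: unknown, _res: unknown, next: NextFn): void {
  next();
}

// True when the value is a constructor whose prototype carries a `use` method.
function isClassMiddleware(middleware: unknown): boolean {
  return (
    typeof middleware === 'function' &&
    Boolean(middleware.prototype) &&
    'use' in middleware.prototype
  );
}

console.log(isClassMiddleware(ClassMiddleware));    // true
console.log(isClassMiddleware(functionMiddleware)); // false
console.log(isClassMiddleware(() => undefined));    // false
```

A plain function declaration does have a prototype object, but it carries no use method, so it falls through to the direct-call branch.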

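Requesting /middleware now prints the following. The middlewares run in the order they were applied, so the class middleware logs first:

log log - TestMiddleware

function middleware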

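Global middleware

Global middleware takes effect on every route and can be registered directly with app.use. Note that middleware registered this way cannot receive injected dependencies; if you need dependency injection, implement it as a class middleware and set its route to *.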

// Import NestFactory, which creates the Nest application instance
import { NestFactory } from '@nestjs/core';

// Import the root module
import { AppModule } from './app.module';
import session from 'express-session'
import { TestFunctionMiddleware } from './test.middleware';

// Async function that creates the Nest instance and starts the application
async function bootstrap() {

  const app = await NestFactory.create(AppModule)

  app.use(session({
    secret: 'secret-key', // key used to sign the session
    resave: false, // whether to force-save the session after every request, even if unchanged
    saveUninitialized: false, // whether to save uninitialised sessions
    cookie: {
      maxAge: 1000 * 60 * 60 * 24
    } // session cookie options: maximum age of one day
  }))

  // Register the function middleware globally
  app.use(TestFunctionMiddleware)
  // Start the HTTP server on port 3000
  await app.listen(3000)
}

bootstrap()

That's a wrap.

Full code of nest-application.ts


import { defineModule, RequestMethod } from '@nestjs/common';
import express, { Express, Request as ExpressRequest, Response as ExpressResponse, NextFunction } from 'express';
import path from 'path'

export class NestApplication {

  private readonly app:Express = express()

  // Holds every provider: the key is the provider token, the value is the provider instance or the value itself
  private readonly providersInstance = new Map()

  // Tokens of providers that are available globally
  private readonly globalProviders = new Set()

  // Which provider tokens each module has access to
  private readonly moduleProviders = new Map()


  // All registered middlewares
  private readonly middlewares = []

  // Routes excluded from the middlewares
  private readonly excludedRoutes = []


  constructor(module:any){
    this.module = module
    this.postMiddleware()

    // To support async dynamic modules, providers are initialised in listen() instead
    // this.initProviders()

    // Initialise the middlewares
    this.initMiddlewares();
  }

  private initMiddlewares() {
    this.module.prototype.configure?.(this);
  }

  // req.method is a plain string, so the method parameter is typed as string
  private isExcluded(reqPath: string, method: string): boolean {
    return this.excludedRoutes.some(route => {
      const { routePath, routeMethod } = this.normalizeRouteInfo(route);
      return routePath === reqPath && (routeMethod === RequestMethod.ALL || routeMethod === method);
    });
  }

  apply(...middleware: (Function | any)[]): this {
    this.middlewares.push(...middleware);
    return this;
  }

  forRoutes(...routes: any[]): this {
    for (const route of routes) {
      for (const middleware of this.middlewares) {
        const { routePath, routeMethod } = this.normalizeRouteInfo(route);
        this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
          // Skip excluded routes
          if (this.isExcluded(req.originalUrl, req.method)) {
            return next();
          }
          if (routeMethod === RequestMethod.ALL || routeMethod === req.method) {
            if (typeof middleware === 'function' && middleware.prototype && 'use' in middleware.prototype) {
              // Class middleware: resolve dependencies, instantiate, then call use()
              defineModule(this.module, [middleware])
              const dependencies = this.resolveDependencies(middleware);
              const middlewareInstance = new middleware(...dependencies);
              middlewareInstance.use(req, res, next);
            } else if (typeof middleware === 'function') {
              // Function middleware: call it directly
              middleware(req, res, next);
            } else {
              next();
            }
          } else {
            next()
          }
        });
      }
    }
    return this;
  }

  exclude(...routes: any[]): this {
    this.excludedRoutes.push(...routes);
    return this;
  }

  private normalizeRouteInfo(route) {
      let routePath = '';
      let routeMethod = RequestMethod.ALL;
      if (typeof route === 'string') {
          routePath = route;
      } else if ('path' in route) {
          routePath = route.path;
          routeMethod = route.method ?? RequestMethod.ALL;
      }
      routePath = path.posix.join('/', routePath);
      return { routePath, routeMethod };
  }
 

  private registerProvidersFromModule(module, ...parentModules) {
    // Check whether this is a global module
    const global = Reflect.getMetadata('globalModule', module)

    // Providers declared by the imported module
    const importedProviders = Reflect.getMetadata('providers', module) ?? []
    // A module may export only some of its providers, so filter by exports
    const exports = Reflect.getMetadata('exports', module) ?? []
    // Walk through the exports
    for (const exportToken of exports) {
      // An export can itself be a module
      if (this.isModule(exportToken)) {
        // Recurse into the re-exported module
        this.registerProvidersFromModule(exportToken, module, ...parentModules)
      } else {
        // Otherwise it is a provider token
        const provider = importedProviders.find(provider => provider === exportToken || provider.provide === exportToken);
        if (provider) {
          [module, ...parentModules].forEach((itemModule) => {
            this.addProvider(provider, itemModule, global);
          })
        }
      }
    }
  }

  async initProviders() {
    // Read the imports metadata of the root module
    const imports = Reflect.getMetadata('imports', this.module) ?? [];

    // Walk through every imported module
    for (const importModule of imports) {
      // Support async (Promise-based) modules
      let importedModule = importModule;
      if (importModule instanceof Promise) {
        importedModule = await importModule
      }

      // A `module` property means this is a dynamic module
      if ('module' in importedModule) {
        const { module, providers, exports } = importedModule;
        defineModule(this.module, providers);
        Reflect.defineMetadata('providers', [...(module.providers || []), ...providers], module);
        Reflect.defineMetadata('exports', [...(module.exports || []), ...exports], module);
        this.registerProvidersFromModule(module, this.module);
      } else {
        this.registerProvidersFromModule(importedModule, this.module)
      }
    }

    // Register the root module's own providers once, after all imports are processed,
    // so they are registered even when the module has no imports
    const providers = Reflect.getMetadata('providers', this.module) ?? [];
    for (const provider of providers) {
      this.addProvider(provider, this.module)
    }
  }

  private isModule(exportToken) {
    return exportToken && exportToken instanceof Function && Reflect.getMetadata("isModule", exportToken)
  }

  addProvider(provider, module, global = false) {
    const providers = global ? this.globalProviders : (this.moduleProviders.get(module) || new Set());

    if (!global && !this.moduleProviders.has(module)) {
      this.moduleProviders.set(module, providers);
    }

    // The injection token is either provider.provide or the class itself
    const injectToken = provider.provide ?? provider;
    // If the instance pool already holds an instance for this token ...
    if (this.providersInstance.has(injectToken)) {
      // ... and this provider set does not list the token yet, record it and return
      if (!providers.has(injectToken)) {
        providers.add(injectToken);
        return
      }
    }

    if (provider.provide && provider.useClass) {
      const Clazz = provider.useClass;
      const dependencies = this.resolveDependencies(Clazz)
      const classInstance = new Clazz(...dependencies)
      this.providersInstance.set(provider.provide, classInstance)
      providers.add(provider.provide || provider)
    } else if (provider.provide && provider.useValue) {
      // The provided value is used as-is (for example an already created instance)
      this.providersInstance.set(provider.provide, provider.useValue)
      providers.add(provider.provide)
    } else if (provider.provide && provider.useFactory) {
      if (this.providersInstance.has(provider.provide)) return;

      const inject = provider.inject ?? []
      const injectedValues = inject.map(item => {
        // return this.providersInstance.get(item) ?? item
        return this.getProviderByToken(item, module)
      })
      this.providersInstance.set(provider.provide, provider.useFactory(...injectedValues))
      providers.add(provider.provide)
    } else {
      // Only a class was given: the value is an instance of that class
      const dependencies = this.resolveDependencies(provider)
      const classInstance = new provider(...dependencies);
      this.providersInstance.set(provider, classInstance)
      providers.add(provider)
    }
  }

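  // Middleware that parses request bodies (for example POST parameters)
  private readonly postMiddleware = () => {
    // Parse JSON request bodies into req.body
    this.app.use(express.json());
    // Parse form-encoded request bodies into req.body
    this.app.use(express.urlencoded({
      extended: true
    }))

    // Example data for the custom decorator demo: attach a user to every request
    this.app.use((req, res, next) => {
      req.user = {name: 'admin', role: 'admin'}
      next()
    })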

  }


  use(middleware) {
    this.app.use(middleware)
  }

  private readonly module:any

  private resolveParams(instance:any, methodName:string, req:ExpressRequest, res:ExpressResponse, next:NextFunction) {
    // Read the parameter metadata
    const paramsMetaData = Reflect.getMetadata(`params`, instance, methodName) ?? []

    return paramsMetaData.sort((a,b) =>a.parameterIndex - b.parameterIndex ).map((paramsMetaData) => {
      const { key,data, factory } = paramsMetaData;

      // Execution context handed to custom decorator factories
      const ctx = {
        switchToHttp() {
          return {
            getRequest: () => req,
            getResponse: () => res,
            getNext: () => next
          }
        }
      }

      switch (key) {
        case "Request":
        case "Req":
          return req;
        case "Response":
        case "Res":
          return res;
        case "Query":
          return data?req.query[data]:req.query;
        case "Session":
          return req.session;
        case "Headers":
          return data ? req.headers[data]: req.headers;
        case "Ip":
          return req.ip;
        case "Param":
          return data ? req.params[data]: req.params;
        case "Body":
          return data ? req.body[data]: req.body;

        case "DecoratorFactory":
            return factory(data, ctx)
        default:
            return null;
      }
    })

  }

 
  private getProviderByToken = (injectedToken, module) => {
    if (this.moduleProviders.get(module)?.has(injectedToken) || this.globalProviders.has(injectedToken)) {
      return this.providersInstance.get(injectedToken);
    } else {
      return null;
    }
  }

  private resolveDependencies(Controller) {
    // Tokens injected explicitly via @Inject
    const injectedTokens = Reflect.getMetadata('injectTokens', Controller)
    // Tokens derived from the constructor parameter types
    const constructorParams = Reflect.getMetadata('design:paramtypes', Controller) ?? []
    return constructorParams.map((param, index) => {
      // Find the module this class belongs to
      const module = Reflect.getMetadata('nestModule', Controller)
      // Replace each parameter token with the matching provider instance
      return this.getProviderByToken(injectedTokens?.[index] ?? param, module);
    })
  }


  async init() {
    // Collect all controllers of the module
    const controllers = Reflect.getMetadata('controllers', this.module) || [];

    for (const Controller of controllers) {
      // Resolve the controller's dependencies
      const dependencies = this.resolveDependencies(Controller);

      const controller = new Controller(...dependencies)
      // Read the controller's route prefix
      const prefix = Reflect.getMetadata('prefix', Controller) || '/';
      // Resolve the routes
      const controllerPrototype = Controller.prototype
      // Walk through the method names on the class prototype
      for (const methodName of Object.getOwnPropertyNames(controllerPrototype)) {
        // Get the method from the prototype, e.g. getHello
        const method = controllerPrototype[methodName]
        // Read the metadata set by the HTTP method decorator, e.g. @Get on getHello
        const httpMethod = Reflect.getMetadata('method', method);
        const pathMetadata = Reflect.getMetadata('path', method);
        if (!httpMethod) continue;
        const routePath = path.posix.join('/', prefix, pathMetadata)
        // When a request hits this route, let the matching method handle it
        this.app[httpMethod.toLowerCase()](routePath, (req:ExpressRequest, res: ExpressResponse, next:NextFunction) => {

          const args = this.resolveParams(controller, methodName, req, res, next)

          // Run the route handler and take its return value
          const result = method.call(controller, ...args)

          // If the handler injects the response object (without passthrough), it sends the response itself
          const responseMetadata = this.getResponseMetaData(controller, methodName)
          if (!responseMetadata || (responseMetadata?.data?.passthrough)) {
            // Send the return value back to the client
            res.send(result)
          }
        })

      }

    }

  }

  private getResponseMetaData(controller, methodName) {

    const paramsMetaData = Reflect.getMetadata(`params`, controller, methodName) ?? []

    return paramsMetaData.find(param => param.key === 'Response' || param.key === 'Res')

  }

  async listen(port) {
    await this.initProviders()
    await this.init()
    this.app.listen(port, () => {

    })
  }

}