Official NestJS docs on middleware: docs.nestjs.com/middleware
Usage
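Middleware can be implemented either as a function or as a class. A class middleware must implement the NestMiddleware interface. A function middleware does not support dependency injection.
Middleware is registered inside a module. That module class must implement the NestModule interface and provide a configure() method.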
Define a middleware in test.middleware.ts:
import { Injectable, NestMiddleware } from '@nestjs/common';
import { Request, Response, NextFunction } from 'express';

@Injectable()
export class TestMiddleware implements NestMiddleware {
  use(req: Request, res: Response, next: NextFunction) {
    console.log('TestMiddleware');
    next();
  }
}
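Then register it in AppModule:
import { MiddlewareConsumer, NestModule } from '@nestjs/common';
import { TestMiddleware } from './test.middleware';

export class AppModule implements NestModule {
  configure(consumer: MiddlewareConsumer) {
    // Run TestMiddleware for requests to /middleware
    consumer.apply(TestMiddleware).forRoutes('middleware')
  }
}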
A request to /middleware now prints TestMiddleware. forRoutes() also accepts an object that specifies both the route path and the request method:
forRoutes({ path: 'config', method: RequestMethod.GET })
To stop a middleware from running on certain routes, exclude them with exclude():
consumer
  .apply(TestMiddleware)
  .exclude(
    { path: 'middleware', method: RequestMethod.POST },
  )
  .forRoutes('middleware');
Implementation
Basic implementation
First, define the NestMiddleware interface:
import { Request, Response, NextFunction } from 'express';

export interface NestMiddleware {
  use(req: Request, res: Response, next: NextFunction): void
}
Then define RequestMethod, MiddlewareConsumer and NestModule. The complete file looks like this:
import { Request, Response, NextFunction } from 'express';

export interface NestMiddleware {
  use(req: Request, res: Response, next: NextFunction): void
}

export enum RequestMethod {
  GET = 'GET',
  POST = 'POST',
  PUT = 'PUT',
  DELETE = 'DELETE',
  PATCH = 'PATCH',
  ALL = 'ALL',
  OPTIONS = 'OPTIONS',
  HEAD = 'HEAD',
}

export interface MiddlewareConsumer {
  apply(...middleware: (Function | any)[]): this;
  forRoutes(...routes: (string | { path: string; method: RequestMethod } | Function)[]): this;
  exclude(...routes: (string | { path: string; method: RequestMethod } | Function)[]): this;
}

export interface NestModule {
  configure(consumer: MiddlewareConsumer): void
}
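In this simplified MiddlewareConsumer, every method returns `this`. In the real `@nestjs/common` API, `apply()` returns a `MiddlewareConfigProxy` that exposes `exclude()` and `forRoutes()`, and `forRoutes()` hands back the consumer. As a result, each `apply()` call starts its own group. The sketch below shows that grouping idea under simplifying assumptions. `GroupedConsumer`, `ConfigProxy` and `MiddlewareGroup` are hypothetical names, and routes are only recorded, never mounted:

```typescript
// Sketch only: a consumer whose apply() opens a new, independent middleware group.
// GroupedConsumer, ConfigProxy and MiddlewareGroup are hypothetical names.
type MiddlewareLike = Function;
type RouteLike = string | { path: string; method?: string };

interface MiddlewareGroup {
  middleware: MiddlewareLike[];
  routes: RouteLike[];
  excluded: RouteLike[];
}

// Returned by apply(); shaped like Nest's MiddlewareConfigProxy (exclude + forRoutes)
interface ConfigProxy {
  exclude(...routes: RouteLike[]): ConfigProxy;
  forRoutes(...routes: RouteLike[]): GroupedConsumer;
}

class GroupedConsumer {
  readonly groups: MiddlewareGroup[] = [];

  apply(...middleware: MiddlewareLike[]): ConfigProxy {
    // Excluded routes belong to this group only
    const excluded: RouteLike[] = [];
    const proxy: ConfigProxy = {
      exclude: (...routes) => {
        excluded.push(...routes);
        return proxy;
      },
      // forRoutes() commits the group and hands back the consumer for the next apply()
      forRoutes: (...routes) => {
        this.groups.push({ middleware, routes, excluded });
        return this;
      },
    };
    return proxy;
  }
}

// Usage: two groups that do not leak into each other
class A {}
class B {}
const consumer = new GroupedConsumer();
consumer
  .apply(A)
  .exclude({ path: 'a', method: 'POST' })
  .forRoutes('a')
  .apply(B)
  .forRoutes('b');

console.log(consumer.groups.map((g) => g.middleware.length)); // [ 1, 1 ]
```

With this shape, B never runs on route `a`, and the exclusion added for `a` does not affect route `b`.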
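Modifying nest-application.ts
Calling NestFactory.create(AppModule) essentially runs new NestApplication(AppModule). When the constructor calls initMiddlewares(), it invokes AppModule's configure() and passes the NestApplication instance itself as the consumer.
So NestApplication has to implement apply(), forRoutes() and exclude() itself.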
// Every middleware passed to apply()
private readonly middlewares: any[] = []

constructor(module: any) {
  this.module = module
  this.postMiddleware()
  // Let the root module register its middleware
  this.initMiddlewares();
}

private initMiddlewares() {
  // Call AppModule's configure(), passing this NestApplication as the consumer
  this.module.prototype.configure?.(this);
}

apply(...middleware: (Function | any)[]): this {
  this.middlewares.push(...middleware);
  return this;
}

forRoutes(...routes: any[]): this {
  for (const route of routes) {
    for (const middleware of this.middlewares) {
      const { routePath, routeMethod } = this.normalizeRouteInfo(route);
      // Mount an Express middleware on the route path
      this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
        // Only run the middleware when the HTTP method matches
        if (routeMethod === RequestMethod.ALL || routeMethod === req.method) {
          const middlewareInstance = new middleware();
          middlewareInstance.use(req, res, next);
        } else {
          next();
        }
      });
    }
  }
  return this;
}

// Turn 'middleware' or { path, method } into { routePath: '/middleware', routeMethod }
private normalizeRouteInfo(route) {
  let routePath = '';
  let routeMethod = RequestMethod.ALL;
  if (typeof route === 'string') {
    routePath = route;
  } else if ('path' in route) {
    routePath = route.path;
    routeMethod = route.method ?? RequestMethod.ALL;
  }
  routePath = path.posix.join('/', routePath);
  return { routePath, routeMethod };
}
At this point, visiting http://localhost:3000/middleware prints TestMiddleware in the console.
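The following standalone copy shows what `normalizeRouteInfo()` and the per-request method check inside `forRoutes()` produce, and it runs without Express. It makes two assumptions. First, `RequestMethod` is modeled as a const object instead of the enum. Second, the regex that prefixes `/` and collapses repeated slashes is a simplified stand-in for `path.posix.join('/', p)`, and it does not resolve `.` or `..`:

```typescript
// Standalone copy of the route normalization and method check, runnable without Express.
// Assumptions: RequestMethod is a const object (not the enum), and the regex join is a
// simplified stand-in for path.posix.join('/', p): it only collapses repeated slashes.
const RequestMethod = { GET: 'GET', POST: 'POST', ALL: 'ALL' } as const;
type RequestMethod = (typeof RequestMethod)[keyof typeof RequestMethod];

type RouteInfo = string | { path: string; method?: RequestMethod };

function normalizeRouteInfo(route: RouteInfo): { routePath: string; routeMethod: RequestMethod } {
  let routePath: string;
  let routeMethod: RequestMethod = RequestMethod.ALL;
  if (typeof route === 'string') {
    routePath = route;
  } else {
    routePath = route.path;
    routeMethod = route.method ?? RequestMethod.ALL;
  }
  // 'middleware' -> '/middleware', '/config' -> '/config', '' -> '/'
  routePath = ('/' + routePath).replace(/\/+/g, '/');
  return { routePath, routeMethod };
}

// The per-request filter used inside the app.use() callback
function methodMatches(routeMethod: RequestMethod, reqMethod: string): boolean {
  return routeMethod === RequestMethod.ALL || routeMethod === reqMethod;
}

console.log(normalizeRouteInfo('middleware')); // { routePath: '/middleware', routeMethod: 'ALL' }
console.log(normalizeRouteInfo({ path: '/config', method: RequestMethod.GET })); // { routePath: '/config', routeMethod: 'GET' }
console.log(methodMatches(RequestMethod.GET, 'POST')); // false
```

Express's `app.use(path, fn)` matches by path prefix. A middleware mounted this way on `/middleware` therefore also sees requests to `/middleware/anything`, and only the method filter runs inside the callback.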
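Dependency injection
Nest middleware fully supports dependency injection. Just like providers and controllers, middleware can inject dependencies that are available within the same module, and as usual this is done through the constructor. Here we inject a LoggerService: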
// logger.service.ts
import { Injectable } from '@nestjs/common';

@Injectable()
export class LoggerService {
  log(message: string) {
    console.log('log', message)
  }
}

// test.middleware.ts
import { Injectable, NestMiddleware } from '@nestjs/common';
import { Request, Response, NextFunction } from 'express';
import { LoggerService } from './logger.service';

@Injectable()
export class TestMiddleware implements NestMiddleware {
  constructor(private readonly loggerService: LoggerService) {}

  use(req: Request, res: Response, next: NextFunction) {
    // LoggerService.log() prints by itself and returns nothing
    this.loggerService.log('log - TestMiddleware')
    next();
  }
}
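Now that the middleware has constructor dependencies, they have to be resolved before it can be instantiated. For the lookup to succeed, LoggerService must also be registered in the module's providers.
Update forRoutes() so that the dependencies are injected when the middleware runs: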
forRoutes(...routes: any[]): this {
  for (const route of routes) {
    for (const middleware of this.middlewares) {
      const { routePath, routeMethod } = this.normalizeRouteInfo(route);
      this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
        if (routeMethod === RequestMethod.ALL || routeMethod === req.method) {
          // Attach the module metadata to the middleware so resolveDependencies()
          // looks up its dependencies among that module's providers
          defineModule(this.module, [middleware])
          // Resolved per request: providers are only initialized in listen(),
          // after configure() has already run
          const dependencies = this.resolveDependencies(middleware);
          const middlewareInstance = new middleware(...dependencies);
          middlewareInstance.use(req, res, next);
        } else {
          next();
        }
      });
    }
  }
  return this;
}
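The dependency step can be tried in isolation. This is a simplified sketch and not the mini framework's code. Instead of reading `design:paramtypes` through reflect-metadata, each class lists its constructor dependencies in a hypothetical static `inject` array. A hypothetical `Container` class plays the role of `providersInstance` plus `getProviderByToken()`:

```typescript
// Simplified DI sketch, not the mini framework's code. Assumptions: dependencies are
// listed in a hypothetical static `inject` array instead of 'design:paramtypes', and the
// hypothetical Container stands in for providersInstance + getProviderByToken().
type Ctor<T = any> = new (...args: any[]) => T;

class Container {
  private readonly instances = new Map<Ctor, unknown>();

  register(token: Ctor, instance: unknown): void {
    this.instances.set(token, instance);
  }

  // Like resolveDependencies(): map each declared token to its provider instance
  resolveDependencies(target: Ctor & { inject?: Ctor[] }): unknown[] {
    return (target.inject ?? []).map((token) => {
      if (!this.instances.has(token)) {
        throw new Error(`No provider for ${token.name}`);
      }
      return this.instances.get(token);
    });
  }
}

// Records messages instead of printing them, so the result can be checked
class LoggerService {
  readonly lines: string[] = [];
  log(message: string): void {
    this.lines.push(`log ${message}`);
  }
}

class TestMiddleware {
  static inject = [LoggerService];
  private readonly loggerService: LoggerService;

  constructor(loggerService: LoggerService) {
    this.loggerService = loggerService;
  }

  use(_req: unknown, _res: unknown, next: () => void): void {
    this.loggerService.log('log - TestMiddleware');
    next();
  }
}

// Per request, as in forRoutes(): resolve, instantiate, then call use()
const container = new Container();
const logger = new LoggerService();
container.register(LoggerService, logger);

const deps = container.resolveDependencies(TestMiddleware) as [LoggerService];
const instance = new TestMiddleware(...deps);
let nextCalled = false;
instance.use({}, {}, () => {
  nextCalled = true;
});
console.log(logger.lines, nextCalled); // [ 'log log - TestMiddleware' ] true
```

`getProviderByToken()` returns `null` for a missing token. This sketch throws instead, so a forgotten provider fails loudly rather than injecting `null`.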
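Excluding routes
Next, implement exclude(). It records the routes to exclude. forRoutes() is updated once more: when the request matches an entry in excludedRoutes, it calls next() right away and skips the middleware.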
// Routes excluded from middleware
private readonly excludedRoutes: any[] = []

exclude(...routes: any[]): this {
  this.excludedRoutes.push(...routes);
  return this;
}

private isExcluded(reqPath: string, method: string): boolean {
  return this.excludedRoutes.some(route => {
    const { routePath, routeMethod } = this.normalizeRouteInfo(route);
    return routePath === reqPath && (routeMethod === RequestMethod.ALL || routeMethod === method);
  });
}

forRoutes(...routes: any[]): this {
  for (const route of routes) {
    for (const middleware of this.middlewares) {
      const { routePath, routeMethod } = this.normalizeRouteInfo(route);
      this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
        // Excluded route: skip this middleware (originalUrl includes the query string, so strip it)
        if (this.isExcluded(req.originalUrl.split('?')[0], req.method)) {
          return next();
        }
        if (routeMethod === RequestMethod.ALL || routeMethod === req.method) {
          defineModule(this.module, [middleware])
          const dependencies = this.resolveDependencies(middleware);
          const middlewareInstance = new middleware(...dependencies);
          middlewareInstance.use(req, res, next);
        } else {
          next();
        }
      });
    }
  }
  return this;
}
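The exclusion check can also be exercised on its own. Express's `req.originalUrl` includes the query string, so this sketch strips it before comparing, and `/middleware?debug=1` is still treated as `/middleware`. Methods are plain strings here, and `normalize` is a hypothetical simplified stand-in for `normalizeRouteInfo()`:

```typescript
// Sketch of the exclusion check with plain strings for methods (no Express needed).
// Assumption: the request path comes from originalUrl, so the query string is stripped first.
type Route = string | { path: string; method?: string };

function normalize(route: Route): { routePath: string; routeMethod: string } {
  const rawPath = typeof route === 'string' ? route : route.path;
  const routeMethod = typeof route === 'string' ? 'ALL' : (route.method ?? 'ALL');
  // Simplified join: prefix '/' and collapse repeated slashes
  return { routePath: ('/' + rawPath).replace(/\/+/g, '/'), routeMethod };
}

function isExcluded(excluded: Route[], originalUrl: string, method: string): boolean {
  // '/middleware?debug=1' -> '/middleware'
  const reqPath = originalUrl.split('?')[0];
  return excluded.some((route) => {
    const { routePath, routeMethod } = normalize(route);
    return routePath === reqPath && (routeMethod === 'ALL' || routeMethod === method);
  });
}

const excluded: Route[] = [{ path: 'middleware', method: 'POST' }, 'health'];
console.log(isExcluded(excluded, '/middleware?debug=1', 'POST')); // true
console.log(isExcluded(excluded, '/middleware', 'GET')); // false
console.log(isExcluded(excluded, '/health', 'DELETE')); // true
```

A bare string entry such as `'health'` has no method, so it is treated as `ALL` and excludes every method on that path.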
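Function middleware
Now let's create a plain function middleware:
export function TestFunctionMiddleware(req: Request, res: Response, next: NextFunction) {
  console.log('function middleware')
  next()
}
Function middleware cannot use dependency injection, so when the middleware is a plain function we simply call it. The two kinds are told apart by checking whether the middleware's prototype has a use method. Update forRoutes() accordingly: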
forRoutes(...routes: any[]): this {
  for (const route of routes) {
    for (const middleware of this.middlewares) {
      const { routePath, routeMethod } = this.normalizeRouteInfo(route);
      this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
        if (this.isExcluded(req.originalUrl.split('?')[0], req.method)) {
          return next();
        }
        if (routeMethod === RequestMethod.ALL || routeMethod === req.method) {
          if (typeof middleware === 'function' && middleware.prototype && 'use' in middleware.prototype) {
            // Class middleware: resolve its dependencies, then call use()
            defineModule(this.module, [middleware])
            const dependencies = this.resolveDependencies(middleware);
            const middlewareInstance = new middleware(...dependencies);
            middlewareInstance.use(req, res, next);
          } else if (typeof middleware === 'function') {
            // Function middleware: no DI, call it directly
            middleware(req, res, next);
          } else {
            next();
          }
        } else {
          next()
        }
      });
    }
  }
  return this;
}
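The class-versus-function rule can be checked without Express. This sketch uses hypothetical helpers, `isClassMiddleware` and `runMiddleware`. It instantiates anything whose prototype defines `use()` and calls every other function directly with `(req, res, next)`. Dependency resolution is left out, so the class is assumed to have a no-argument constructor:

```typescript
// Sketch of the class-vs-function dispatch rule from forRoutes(), without Express.
// isClassMiddleware/runMiddleware are hypothetical helpers; DI is omitted, so class
// middleware is assumed to have a no-argument constructor.
type Next = () => void;
type FnMiddleware = (req: unknown, res: unknown, next: Next) => void;

// A class counts as middleware when its prototype defines use()
function isClassMiddleware(m: unknown): m is new () => { use: FnMiddleware } {
  return typeof m === 'function' && !!m.prototype && 'use' in m.prototype;
}

function runMiddleware(m: unknown, req: unknown, res: unknown, next: Next): void {
  if (isClassMiddleware(m)) {
    new m().use(req, res, next); // class middleware: instantiate, then call use()
  } else if (typeof m === 'function') {
    (m as FnMiddleware)(req, res, next); // function middleware: call directly
  } else {
    next(); // anything else is skipped
  }
}

const calls: string[] = [];
class ClassMiddleware {
  use(_req: unknown, _res: unknown, next: Next): void {
    calls.push('class');
    next();
  }
}
function functionMiddleware(_req: unknown, _res: unknown, next: Next): void {
  calls.push('function');
  next();
}

runMiddleware(ClassMiddleware, {}, {}, () => calls.push('next'));
runMiddleware(functionMiddleware, {}, {}, () => calls.push('next'));
runMiddleware(42, {}, {}, () => calls.push('skipped'));
console.log(calls.join(',')); // class,next,function,next,skipped
```

Arrow functions have no prototype, and ordinary functions have a prototype without `use()`. Both therefore take the function branch.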
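import { MiddlewareConsumer, NestModule } from '@nestjs/common';
import { TestMiddleware, TestFunctionMiddleware } from './test.middleware';

export class AppModule implements NestModule {
  configure(consumer: MiddlewareConsumer) {
    consumer
      // Class and function middleware can be applied together; they run in this order
      .apply(TestMiddleware, TestFunctionMiddleware)
      // .exclude({ path: '/middleware', method: RequestMethod.GET })
      .forRoutes('middleware')
  }
}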
When we visit /middleware, the console prints:
log log - TestMiddleware
function middleware
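Express runs middleware in the order it was registered. forRoutes() registers this.middlewares in the order apply() received them, so TestMiddleware runs before TestFunctionMiddleware. The toy `runChain()` below is a simplified model of that `next()`-driven chain, not Express itself. It also shows that a middleware which never calls `next()` stops the request from reaching later handlers:

```typescript
// Simplified model of an Express-style chain (not Express itself): handlers run in
// registration order, and each must call next() to pass control on.
type Next = () => void;
type Handler = (next: Next) => void;

function runChain(handlers: Handler[], done: () => void): void {
  let index = 0;
  const next: Next = () => {
    const handler = handlers[index++];
    if (handler) {
      handler(next);
    } else {
      done(); // every middleware called next(): reach the route handler
    }
  };
  next();
}

// Same order in which apply() received TestMiddleware and TestFunctionMiddleware
const output: string[] = [];
runChain(
  [
    (next) => { output.push('log log - TestMiddleware'); next(); },
    (next) => { output.push('function middleware'); next(); },
  ],
  () => output.push('route handler'),
);
console.log(output.join(' | ')); // log log - TestMiddleware | function middleware | route handler

// A middleware that never calls next() ends the chain early
const stopped: string[] = [];
runChain(
  [
    () => { stopped.push('blocked'); },
    (next) => { stopped.push('never runs'); next(); },
  ],
  () => stopped.push('route handler'),
);
console.log(stopped.join(' | ')); // blocked
```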
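Global middleware
Global middleware runs for every route and is registered directly with app.use(). Middleware registered this way cannot use dependency injection. If it needs injected dependencies, implement it as a class middleware and apply it in the module with forRoutes('*').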
// NestFactory creates the Nest application instance
import { NestFactory } from '@nestjs/core';
// The root module
import { AppModule } from './app.module';
import session from 'express-session'
import { TestFunctionMiddleware } from './test.middleware';

// Create the Nest application and start the server
async function bootstrap() {
  const app = await NestFactory.create(AppModule)
  app.use(session({
    secret: 'secret-key', // secret used to sign the session ID cookie
    resave: false, // whether to save the session on every request even if it was not modified
    saveUninitialized: false, // whether to save sessions that are new but not modified
    cookie: {
      maxAge: 1000 * 60 * 60 * 24 // session cookie lifetime: one day (in milliseconds)
    }
  }))
  // Global function middleware: runs for every route
  app.use(TestFunctionMiddleware)
  // Start the HTTP server on port 3000
  await app.listen(3000)
}
bootstrap()
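In this mini implementation, initMiddlewares() runs inside the NestApplication constructor. Module middleware is therefore registered with Express before anything passed to app.use() in bootstrap(), and controller routes are only registered later, in listen(). The hypothetical `FakeApp` class below records that registration order for a request to `/middleware`. It does prefix matching only and is not Express. For brevity, only TestMiddleware stands in for the module-level middleware:

```typescript
// Toy model of registration order in the mini NestApplication (not Express).
// FakeApp is hypothetical; it only does prefix matching and runs every matching handler.
type Handler = (log: string[]) => void;

class FakeApp {
  private readonly stack: { prefix: string; handler: Handler }[] = [];

  use(prefix: string, handler: Handler): void {
    this.stack.push({ prefix, handler });
  }

  // Run matching handlers in registration order and return what they logged
  handle(path: string): string[] {
    const log: string[] = [];
    for (const { prefix, handler } of this.stack) {
      if (prefix === '/' || path === prefix || path.startsWith(prefix + '/')) {
        handler(log);
      }
    }
    return log;
  }
}

const app = new FakeApp();
// 1. new NestApplication() -> initMiddlewares() -> forRoutes('middleware')
app.use('/middleware', (log) => log.push('TestMiddleware'));
// 2. bootstrap() -> app.use(TestFunctionMiddleware), mounted at '/'
app.use('/', (log) => log.push('TestFunctionMiddleware (global)'));
// 3. listen() -> init() would register controller routes after both

console.log(app.handle('/middleware').join(' | ')); // TestMiddleware | TestFunctionMiddleware (global)
console.log(app.handle('/other').join(' | ')); // TestFunctionMiddleware (global)
```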
Done!
Full code of nest-application.ts
import { defineModule, RequestMethod } from '@nestjs/common';
import express, { Express, Request as ExpressRequest, Response as ExpressResponse, NextFunction } from 'express';
import path from 'path'

export class NestApplication {
  private readonly app: Express = express()
  // All provider instances. Key: provider token; value: the provider instance (or the value itself)
  private readonly providersInstance = new Map()
  // Tokens of globally available providers
  private readonly globalProviders = new Set()
  // Which provider tokens each module owns
  private readonly moduleProviders = new Map()
  // All registered middleware
  private readonly middlewares: any[] = []
  // Routes excluded from middleware
  private readonly excludedRoutes: any[] = []

  constructor(module: any) {
    this.module = module
    this.postMiddleware()
    // Providers are initialized in listen() so that async dynamic modules are supported
    // this.initProviders()
    // Initialize middleware
    this.initMiddlewares();
  }

  private initMiddlewares() {
    // Call the root module's configure(), passing this application as the consumer
    this.module.prototype.configure?.(this);
  }

  private isExcluded(reqPath: string, method: string): boolean {
    return this.excludedRoutes.some(route => {
      const { routePath, routeMethod } = this.normalizeRouteInfo(route);
      return routePath === reqPath && (routeMethod === RequestMethod.ALL || routeMethod === method);
    });
  }

  apply(...middleware: (Function | any)[]): this {
    this.middlewares.push(...middleware);
    return this;
  }

  forRoutes(...routes: any[]): this {
    for (const route of routes) {
      for (const middleware of this.middlewares) {
        const { routePath, routeMethod } = this.normalizeRouteInfo(route);
        this.app.use(routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
          // Skip excluded routes (compare the path without the query string)
          if (this.isExcluded(req.originalUrl.split('?')[0], req.method)) {
            return next();
          }
          if (routeMethod === RequestMethod.ALL || routeMethod === req.method) {
            if (typeof middleware === 'function' && middleware.prototype && 'use' in middleware.prototype) {
              // Class middleware: resolve its dependencies, then call use()
              defineModule(this.module, [middleware])
              const dependencies = this.resolveDependencies(middleware);
              const middlewareInstance = new middleware(...dependencies);
              middlewareInstance.use(req, res, next);
            } else if (typeof middleware === 'function') {
              // Function middleware: call it directly
              middleware(req, res, next);
            } else {
              next();
            }
          } else {
            next()
          }
        });
      }
    }
    return this;
  }

  exclude(...routes: any[]): this {
    this.excludedRoutes.push(...routes);
    return this;
  }

  // Turn 'middleware' or { path, method } into { routePath: '/middleware', routeMethod }
  private normalizeRouteInfo(route) {
    let routePath = '';
    let routeMethod = RequestMethod.ALL;
    if (typeof route === 'string') {
      routePath = route;
    } else if ('path' in route) {
      routePath = route.path;
      routeMethod = route.method ?? RequestMethod.ALL;
    }
    routePath = path.posix.join('/', routePath);
    return { routePath, routeMethod };
  }
  private registerProvidersFromModule(module, ...parentModules) {
    // Is this a global module?
    const global = Reflect.getMetadata('globalModule', module)
    // Providers declared by the imported module
    const importedProviders = Reflect.getMetadata('providers', module) ?? []
    // A module may export only some of its providers, so filter by its exports
    const exports = Reflect.getMetadata('exports', module) ?? []
    for (const exportToken of exports) {
      if (this.isModule(exportToken)) {
        // Re-exported module: register its exports recursively
        this.registerProvidersFromModule(exportToken, module, ...parentModules)
      } else {
        // Provider token: register the matching provider for this module and its parents
        const provider = importedProviders.find(provider => provider === exportToken || provider.provide === exportToken);
        if (provider) {
          [module, ...parentModules].forEach((itemModule) => {
            this.addProvider(provider, itemModule, global);
          })
        }
      }
    }
  }

  async initProviders() {
    // Modules imported by the root module
    const imports = Reflect.getMetadata('imports', this.module) ?? [];
    for (const importModule of imports) {
      // Support async dynamic modules (Promises)
      let importedModule = importModule;
      if (importModule instanceof Promise) {
        importedModule = await importModule
      }
      // An object with a `module` property is a dynamic module
      if ('module' in importedModule) {
        const { module, providers = [], exports = [] } = importedModule;
        defineModule(this.module, providers);
        Reflect.defineMetadata('providers', [...(module.providers || []), ...providers], module);
        Reflect.defineMetadata('exports', [...(module.exports || []), ...exports], module);
        this.registerProvidersFromModule(module, this.module);
      } else {
        this.registerProvidersFromModule(importedModule, this.module)
      }
    }
    // Then register the root module's own providers (once, even if it imports nothing)
    const providers = Reflect.getMetadata('providers', this.module) ?? [];
    for (const provider of providers) {
      this.addProvider(provider, this.module)
    }
  }

  private isModule(exportToken) {
    return exportToken && exportToken instanceof Function && Reflect.getMetadata("isModule", exportToken)
  }

  addProvider(provider, module, global = false) {
    const providers = global ? this.globalProviders : (this.moduleProviders.get(module) || new Set());
    if (!global && !this.moduleProviders.has(module)) {
      this.moduleProviders.set(module, providers);
    }
    const injectToken = provider.provide ?? provider;
    // If the instance pool already holds this token, just record it for this module
    if (this.providersInstance.has(injectToken)) {
      providers.add(injectToken);
      return
    }
    if (provider.provide && provider.useClass) {
      // useClass: instantiate the class with its resolved dependencies
      const Clazz = provider.useClass;
      const dependencies = this.resolveDependencies(Clazz)
      const classInstance = new Clazz(...dependencies)
      this.providersInstance.set(provider.provide, classInstance)
      providers.add(provider.provide)
    } else if (provider.provide && 'useValue' in provider) {
      // useValue: the value is used as-is (falsy values included)
      this.providersInstance.set(provider.provide, provider.useValue)
      providers.add(provider.provide)
    } else if (provider.provide && provider.useFactory) {
      // useFactory: call the factory with the providers listed in `inject`
      const inject = provider.inject ?? []
      const injectedValues = inject.map(item => this.getProviderByToken(item, module))
      this.providersInstance.set(provider.provide, provider.useFactory(...injectedValues))
      providers.add(provider.provide)
    } else {
      // A bare class: the token is the class and the value is its instance
      const dependencies = this.resolveDependencies(provider)
      const classInstance = new provider(...dependencies);
      this.providersInstance.set(provider, classInstance)
      providers.add(provider)
    }
  }
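  // Body-parsing middleware (for POST request bodies)
  private readonly postMiddleware = () => {
    // Parse JSON request bodies into req.body
    this.app.use(express.json());
    // Parse URL-encoded form bodies into req.body
    this.app.use(express.urlencoded({
      extended: true
    }))
    // Custom decorator example: attach a mock user to every request
    this.app.use((req, res, next) => {
      (req as any).user = { name: 'admin', role: 'admin' }
      next()
    })
  }

  // Global middleware registration (app.use() in main.ts)
  use(middleware) {
    this.app.use(middleware)
  }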
  private readonly module: any

  private resolveParams(instance: any, methodName: string, req: ExpressRequest, res: ExpressResponse, next: NextFunction) {
    // Parameter metadata recorded by the parameter decorators
    const paramsMetaData = Reflect.getMetadata(`params`, instance, methodName) ?? []
    return paramsMetaData.sort((a, b) => a.parameterIndex - b.parameterIndex).map((paramsMetaData) => {
      const { key, data, factory } = paramsMetaData;
      // Minimal execution context handed to custom decorator factories
      const ctx = {
        switchToHttp() {
          return {
            getRequest: () => req,
            getResponse: () => res,
            getNext: () => next
          }
        }
      }
      switch (key) {
        case "Request":
        case "Req":
          return req;
        case "Response":
        case "Res":
          return res;
        case "Query":
          return data ? req.query[data] : req.query;
        case "Session":
          return req.session;
        case "Headers":
          return data ? req.headers[data] : req.headers;
        case "Ip":
          return req.ip;
        case "Param":
          return data ? req.params[data] : req.params;
        case "Body":
          return data ? req.body[data] : req.body;
        case "DecoratorFactory":
          return factory(data, ctx)
        default:
          return null;
      }
    })
  }
  private getProviderByToken = (injectedToken, module) => {
    // Only providers visible to this module (its own or global ones) can be injected
    if (this.moduleProviders.get(module)?.has(injectedToken) || this.globalProviders.has(injectedToken)) {
      return this.providersInstance.get(injectedToken);
    } else {
      return null;
    }
  }

  private resolveDependencies(Controller) {
    // Tokens injected explicitly with @Inject()
    const injectedTokens = Reflect.getMetadata('injectTokens', Controller)
    // Tokens inferred from the constructor parameter types
    const constructorParams = Reflect.getMetadata('design:paramtypes', Controller) ?? []
    // The module this class belongs to
    const module = Reflect.getMetadata('nestModule', Controller)
    return constructorParams.map((param, index) => {
      // Replace each token with the matching provider instance
      return this.getProviderByToken(injectedTokens?.[index] ?? param, module);
    })
  }

  async init() {
    // All controllers of the root module
    const controllers = Reflect.getMetadata('controllers', this.module) || [];
    for (const Controller of controllers) {
      // Resolve the controller's dependencies and instantiate it
      const dependencies = this.resolveDependencies(Controller);
      const controller = new Controller(...dependencies)
      // Route prefix set by @Controller()
      const prefix = Reflect.getMetadata('prefix', Controller) || '/';
      const controllerPrototype = Controller.prototype
      // Walk through the method names on the class prototype
      for (const methodName of Object.getOwnPropertyNames(controllerPrototype)) {
        // The method itself, e.g. getHello
        const method = controllerPrototype[methodName]
        // HTTP method and path recorded by decorators such as @Get()
        const httpMethod = Reflect.getMetadata('method', method);
        const pathMetadata = Reflect.getMetadata('path', method);
        if (!httpMethod) continue;
        const routePath = path.posix.join('/', prefix, pathMetadata)
        // Register the handler for this route
        this.app[httpMethod.toLowerCase()](routePath, (req: ExpressRequest, res: ExpressResponse, next: NextFunction) => {
          const args = this.resolveParams(controller, methodName, req, res, next)
          // Call the route handler to get its return value
          const result = method.call(controller, ...args)
          // A handler that injects @Res()/@Response() without passthrough sends the response itself
          const responseMetadata = this.getResponseMetaData(controller, methodName)
          if (!responseMetadata || responseMetadata?.data?.passthrough) {
            // Serialize the return value and send it to the client
            res.send(result)
          }
        })
      }
    }
  }

  private getResponseMetaData(controller, methodName) {
    const paramsMetaData = Reflect.getMetadata(`params`, controller, methodName) ?? []
    return paramsMetaData.find(param => param.key === 'Response' || param.key === 'Res')
  }

  async listen(port) {
    // Providers first (async dynamic modules), then controllers and routes
    await this.initProviders()
    await this.init()
    this.app.listen(port)
  }
}