gin的middleware机制以及源码剖析

1,577 阅读2分钟

一、gin的middleware使用

func main() {
	router := gin.New()
	router.Use(mid1)
	router.Use(mid2)
	router.GET("/", func(c *gin.Context) {
		fmt.Println("--- main ---")
		c.String(http.StatusOK, "Hello Middleware")
	})
	router.Run(":8080")
}
func mid1(c *gin.Context) {
	fmt.Println("hello 1")
	c.Next()
	fmt.Println("bye 1")
}
func mid2(c *gin.Context) {
	fmt.Println("hello 2")
	c.Next()
	fmt.Println("bye 2")
}

打开浏览器访问 http://localhost:8080/

控制台输出:
hello 1
hello 2
--- main ---
bye 2
bye 1

二、Next()和Abort()

  • Next() 暂停当前中间件剩余的代码,直接运行后面的中间件以及相应的处理函数, 当相应的handle函数处理完毕,再从后往前依次执行中间件
  • Abort() 剩余的中间件以及handle不会执行,但是当前中间件中剩余的代码会继续运行,如果剩余的代码没有必要执行,直接在Abort()之后return即可

三、源码剖析

Context结构体:把中间件和handle都合并到了HandlersChain

type Context struct {
	...
	handlers HandlersChain
	index    int8
}
type HandlersChain []HandlerFunc
type HandlerFunc func(*Context)

Next() : 注意这里是c.index 而不是每次for循环都重新定义的变量, 也就是说先去调用剩余的处理函数,然后再执行该中间件中处于c.Next()下面的代码

func (c *Context) Next() {
	c.index++
	for c.index < int8(len(c.handlers)) {
		c.handlers[c.index](c)
		c.index++
	}
}

Abort():

const abortIndex int8 = math.MaxInt8 / 2
func (c *Context) Abort() {
	c.index = abortIndex
}

为什么设置成abortIndex? 而不是更大的数? 因为在合并中间件和handle到c.HandlersChain中的时候就设置了最大数量是abortIndex, 也就是说len(c.handlers)肯定是小于abortIndex的, 这样就达到了终止for循环的目的。

func (group *RouterGroup) combineHandlers(handlers HandlersChain) HandlersChain {
	finalSize := len(group.Handlers) + len(handlers)
	if finalSize >= int(abortIndex) {
		panic("too many handlers")
	}
}

四、简单模仿一下

import ...

const abortIndex int8 = math.MaxInt8 / 2

type HandlerFunc func(*Context)

type Context struct {
	index    int8
	handlers []HandlerFunc
}

func (c *Context) execHandlers() {
	for c.index < int8(len(c.handlers)) {
		c.handlers[c.index](c)
		c.index++
	}
}

func (c *Context) Next() {
	c.index++
	for c.index < int8(len(c.handlers)) {
		c.handlers[c.index](c)
		c.index++
	}
}

func (c *Context) Abort() {
	c.index = abortIndex
}

func (c *Context) Use(handler HandlerFunc) {
	c.handlers = append(c.handlers, handler)
}

func main() {
	var c Context
	c.Use(middle1)
	c.Use(middle2)
	c.Use(func(context *Context) {
		fmt.Println("Hello World")
	})
	c.execHandlers()
}

func middle1(c *Context) {

	fmt.Println("hello")
	c.Next()
	//c.Abort()
	fmt.Println("bye bye")
}

func middle2(c *Context) {
	fmt.Println("hello2")
	c.Next()
	fmt.Println("bye bye2")
}

控制台输出:

hello
hello2
Hello World
bye bye2
bye bye