Golang + Gin + tollbooth 请求限速

126 阅读2分钟

Golang + Gin + tollbooth 请求限速

环境

Golang1.19
github.com/gin-gonic/gin v1.9.1
github.com/didip/tollbooth v4.0.2+incompatible

代码

  1. 创建限速器
// 创建过期配置
expiraOption := &limiter.ExpirableOptions{
	// 默认过期时间设置为 1 秒
	DefaultExpirationTTL: 1 * time.Second,
}
// 创建限速器
// 设置每个TTL内的最大访问数为 100
nLimiter := tollbooth.NewLimiter(100, expiraOption)
// 设置限速器请求超过配置时的提示信息
nLimiter.SetMessage("calm down...")
  1. 创建 Gin 接口处理器
func limiterHandler(lmt *limiter.Limiter) gin.HandlerFunc {
	return func(c *gin.Context) {
		httpError := tollbooth.LimitByRequest(lmt, c.Writer, c.Request)
		if httpError != nil {
			c.Data(httpError.StatusCode, lmt.GetMessageContentType(), []byte(httpError.Message))
			c.Abort()
		} else {
			c.Next()
		}
	}
}
  1. 配置 Gin 接口
func ping(r *gin.Engine) {
	r.GET("/ping", limiterHandler(nLimiter), func(c *gin.Context) {
		c.JSON(http.StatusOK, "pong")
	})
}

源码分析

  1. Gin 接收请求后,将首先经过 limiterHandler 处理。

  2. 当其通过 tollbooth.LimitByRequest() 处理时,将出发请求限速配置。

    // 根据限速配置和 request 对象创建一个密钥片
    sliceKeys := BuildKeys(lmt, r)
    
    // 遍历密钥片并检查其中一个是否有错误
    for _, keys := range sliceKeys {
    	httpError := LimitByKeys(lmt, keys)
    	if httpError != nil {
    		return httpError
    	}
    }
    
    ...
    
    // LimitByKeys跟踪由管道分隔的键发出的请求数量。超过限制时返回HTTPError。
    func LimitByKeys(lmt *limiter.Limiter, keys []string) *errors.HTTPError {
    	if lmt.LimitReached(strings.Join(keys, "|")) {
    		return &errors.HTTPError{Message: lmt.GetMessage(), StatusCode: lmt.GetStatusCode()}
    	}
    
    	return nil
    }
    
    ...
    
    // lmt.LimitReached() 方法内容
    // LimitReached返回一个bool,指示由密钥标识的Bucket是否用完了令牌。
    func (l *Limiter) LimitReached(key string) bool {
    	ttl := l.GetTokenBucketExpirationTTL()
    
    	if ttl <= 0 {
    		ttl = l.generalExpirableOptions.DefaultExpirationTTL
    	}
    
    	return l.limitReachedWithTokenBucketTTL(key, ttl)
    }
    
    ...
    
    // l.limitReachedWithTokenBucketTTL() 方法内容
    func (l *Limiter) limitReachedWithTokenBucketTTL(key string, tokenBucketTTL time.Duration) bool {
    	lmtMax := l.GetMax()
    	lmtBurst := l.GetBurst()
    	l.Lock()
    	defer l.Unlock()
    	
    	// 检查指定key是否存在,若不存在则生成一个key用于记录本次请求时间
    	if _, found := l.tokenBuckets.Get(key); !found {
    		l.tokenBuckets.Set(
    			key,
    			rate.NewLimiter(rate.Limit(lmtMax), lmtBurst),
    			tokenBucketTTL,
    		)
    	}
    
    	// 判断当前key是否过期
    	expiringMap, found := l.tokenBuckets.Get(key)
    	if !found {
    		return false
    	}
    
    	// key有效时判断是否允许本次请求
    	return !expiringMap.(*rate.Limiter).Allow()
    }
    
    ...
    
    // 检查事件是否现在可能发生
    func (lim *Limiter) Allow() bool {
    	return lim.AllowN(time.Now(), 1)
    }
    
    // 检查时间t是否可能发生n个事件,即判断指定时间内是否可发生指定个请求
    func (lim *Limiter) AllowN(t time.Time, n int) bool {
    	return lim.reserveN(t, n, 0).ok
    }
    
    // maxFutureReserve指定允许的最大预订等待持续时间
    func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration) Reservation {
    	lim.mu.Lock()
    	defer lim.mu.Unlock()
    	
    	// 系统最大值与0的判定
    	if lim.limit == Inf {
    		return Reservation{
    			ok:        true,
    			lim:       lim,
    			tokens:    n,
    			timeToAct: t,
    		}
    	} else if lim.limit == 0 {
    		var ok bool
    		if lim.burst >= n {
    			ok = true
    			lim.burst -= n
    		}
    		return Reservation{
    			ok:        ok,
    			lim:       lim,
    			tokens:    lim.burst,
    			timeToAct: t,
    		}
    	}
    
    	// 计算并返回由于时间的推移而产生的lim的更新状态
    	t, tokens := lim.advance(t)
    
    	// Calculate the remaining number of tokens resulting from the request.
    	tokens -= float64(n)
    
    	// Calculate the wait duration
    	var waitDuration time.Duration
    	if tokens < 0 {
    		waitDuration = lim.limit.durationFromTokens(-tokens)
    	}
    
    	// Decide result
    	ok := n <= lim.burst && waitDuration <= maxFutureReserve
    
    	// Prepare reservation
    	r := Reservation{
    		ok:    ok,
    		lim:   lim,
    		limit: lim.limit,
    	}
    	if ok {
    		r.tokens = n
    		r.timeToAct = t.Add(waitDuration)
    
    		// Update state
    		lim.last = t
    		lim.tokens = tokens
    		lim.lastEvent = r.timeToAct
    	}
    
    	return r
    }