SpringBoot+Redis Lua限流最佳实践

3,242 阅读11分钟

这是我参与8月更文挑战的第26天,活动详情查看:8月更文挑战

常见的限流算法

计数器算法

计数器算法采用计数器实现限流有点简单粗暴,一般我们会限制一秒钟的能够通过的请求数,比如限流qps为100,算法的实现思路就是从第一个请求进来开始计时,在接下去的1s内,每来一个请求,就把计数加1,如果累加的数字达到了100,那么后续的请求就会被全部拒绝。等到1s结束后,把计数恢复成0,重新开始计数。具体的实现可以是这样的:对于每次服务调用,可以通过AtomicLong#incrementAndGet()方法来给计数器加1并返回最新值,通过这个最新值和阈值进行比较。这种实现方式,相信大家都知道有一个弊端:如果我在单位时间1s内的前10ms,已经通过了100个请求,那后面的990ms,只能眼巴巴的把请求拒绝,我们把这种现象称为“突刺现象”

漏桶算法

漏桶算法为了消除”突刺现象”,可以采用漏桶算法实现限流,漏桶算法这个名字就很形象,算法内部有一个容器,类似生活用到的漏斗,当请求进来时,相当于水倒入漏斗,然后从下端小口慢慢匀速的流出。不管上面流量多大,下面流出的速度始终保持不变。不管服务调用方多么不稳定,通过漏桶算法进行限流,每10毫秒处理一次请求。因为处理的速度是固定的,请求进来的速度是未知的,可能突然进来很多请求,没来得及处理的请求就先放在桶里,既然是个桶,肯定是有容量上限,如果桶满了,那么新进来的请求就丢弃。

image.png 在算法实现方面,可以准备一个队列,用来保存请求,另外通过一个线程池(ScheduledExecutorService)来定期从队列中获取请求并执行,可以一次性获取多个并发执行。这种算法,在使用过后也存在弊端:无法应对短时间的突发流量。

令牌桶算法

从某种意义上讲,令牌桶算法是对漏桶算法的一种改进,桶算法能够限制请求调用的速率,而令牌桶算法能够在限制调用的平均速率的同时还允许一定程度的突发调用。在令牌桶算法中,存在一个桶,用来存放固定数量的令牌。算法中存在一种机制,以一定的速率往桶中放令牌。每次请求调用需要先获取令牌,只有拿到令牌,才有机会继续执行,否则选择选择等待可用的令牌、或者直接拒绝。放令牌这个动作是持续不断的进行,如果桶中令牌数达到上限,就丢弃令牌,所以就存在这种情况,桶中一直有大量的可用令牌,这时进来的请求就可以直接拿到令牌执行,比如设置qps为100,那么限流器初始化完成一秒后,桶中就已经有100个令牌了,这时服务还没完全启动好,等启动完成对外提供服务时,该限流器可以抵挡瞬时的100个请求。所以,只有桶中没有令牌时,请求才会进行等待,最后相当于以一定的速率执行。

image.png

实现思路: 可以准备一个队列,用来保存令牌,另外通过一个线程池定期生成令牌放到队列中,每来一个请求,就从队列中获取一个令牌,并继续执行。

基于redis-lua实现令牌桶限流算法解读

-- 令牌桶在redis中的key值
local tokens_key = KEYS[1]
-- 该令牌桶上一次刷新的时间对应的key的值
local timestamp_key = KEYS[2]
-- 令牌单位时间填充速率
local rate = tonumber(ARGV[1])
-- 令牌桶容量
local capacity = tonumber(ARGV[2])
-- 当前时间
local now = tonumber(ARGV[3])
-- 请求需要的令牌数
local requested = tonumber(ARGV[4])
-- 令牌桶容量/令牌填充速率=令牌桶填满所需的时间
local fill_time = capacity/rate
-- 令牌过期时间 填充时间*2
local ttl = math.floor(fill_time*2)
-- 获取上一次令牌桶剩余的令牌数
local last_tokens = tonumber(redis.call("get", tokens_key))
-- 如果没有获取到,可能是令牌桶是新的,之前不存在该令牌桶,或者该令牌桶已经好久没有使用
-- 过期了,这里需要对令牌桶进行初始化,初始情况,令牌桶是满的
if last_tokens == nil then
  last_tokens = capacity
end
-- 获取上一次刷新的时间,如果没有,或者已经过期,那么初始化为0
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
-- 计算上一次刷新时间和本次刷新时间的时间差
local delta = math.max(0, now-last_refreshed)
-- delta*rate = 这个时间差可以填充的令牌数,
-- 令牌桶中先存在的令牌数 = 填充令牌数+令牌桶中原有的令牌数
-- 以为令牌桶有容量,所以如果计算的值大于令牌桶容量,那么以令牌容容量为准
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
-- 判断令牌桶中的令牌数是都满本次请求需要的令牌数,如果不满足,说明被限流了
local allowed = filled_tokens >= requested

-- 这里声明了两个变量,一个是新的令牌数,一个是是否被限流,0代表限流,1代表没有线路
local new_tokens = filled_tokens
local allowed_num = 0
-- 如果没有被限流,即,filled_tokens >= requested,
-- 新的令牌数=刚刚计算好的令牌桶中存在的令牌数减掉本次需要使用的令牌数
-- 并设置限流结果为未限流
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end
-- 存储本次操作后,令牌桶中的令牌数以及本次刷新时间
if ttl > 0 then
  redis.call("setex", tokens_key, ttl, new_tokens)
  redis.call("setex", timestamp_key, ttl, now)
end
-- 返回是否被限流标志以及令牌桶剩余令牌数
return { allowed_num, new_tokens }
  1. 这个KEYS[1],KEYS[2]ARGV[1],ARGV[2]... 表示调用该lua脚本时,传入的变量列表。
  2. KEYS[i] 表示调用lua脚本传过来的变量KEYS[i]作为一个key,从redis中获取具体的值
  3. ARGV[i] 表示调用lua脚本时传过来的变量ARGV[i]

举个例子吧, 我们在调用脚本的时候,传入了两组参数,一组是KEYS,一组是ARGV,这两组参数假设是
KEYS : [demo1,demo2]
ARGV: [3,3,11,1]
那么,
KEYS[1]等于redis.get(demo1)
ARGV[1]等于3

SpringBoot调用RedisLua

引入依赖

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>

资源目录新建scripts文件夹,将lua脚本放进去

image.png

local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)

--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)

local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)

if ttl > 0 then
  redis.call("setex", tokens_key, ttl, new_tokens)
  redis.call("setex", timestamp_key, ttl, now)
end

-- return { allowed_num, new_tokens, capacity, filled_tokens, requested, new_tokens }
return { allowed_num, new_tokens }

构造redis-script对象

springboot将每个lua脚本抽象为一个RedisScript对象,该类提供了两个方法,一个是设置lua脚本的io流,还有一个是直接将lua脚本以字符串的形式设置,这里用io流的形式。
该对象的泛型是lua脚本的返回值,我们的脚本返回的是两个long类型,所以使用List来接收。

    @Bean(name = "rateLimitRedisScript")
    public RedisScript<List<Long>> rateLimitRedisScript() {
        DefaultRedisScript redisScript = new DefaultRedisScript<>();
//        redisScript.setScriptText();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/redis_rate_limit.lua")));
        redisScript.setResultType(List.class);
        return redisScript;
    }

设置redis序列化规则

@Bean
    @ConditionalOnMissingBean(StringRedisTemplate.class)
    public StringRedisTemplate stringRedisTemplate(RedisConnectionFactory redisConnectionFactory) throws UnknownHostException {
        StringRedisTemplate template = new StringRedisTemplate();
        template.setConnectionFactory(redisConnectionFactory);
        return template;
    }

    @Bean
    @ConditionalOnMissingBean(RedisTemplate.class)
    public RedisTemplate<String, Object> redisTemplate(
            RedisConnectionFactory redisConnectionFactory)
            throws UnknownHostException {

        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<Object>(Object.class);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(om);

        RedisTemplate<String, Object> template = new RedisTemplate<String, Object>();
        template.setConnectionFactory(redisConnectionFactory);
        template.setKeySerializer(jackson2JsonRedisSerializer);
        template.setValueSerializer(jackson2JsonRedisSerializer);
        template.setHashKeySerializer(jackson2JsonRedisSerializer);
        template.setHashValueSerializer(jackson2JsonRedisSerializer);
        template.afterPropertiesSet();
        return template;
    }

调用lua脚本

    @Resource
    private RedisScript<List<Long>> rateLimitRedisScript;

    @Resource
    private StringRedisTemplate stringRedisTemplate;

    @GetMapping
    public List<Long> userToken() {
        // 设置lua脚本的ARGV的值
        List<String> scriptArgs = Arrays.asList(
                1 + "",
                3 + "",
                (Instant.now().toEpochMilli()) + "",
                "1");
        // 设置lua脚本的KEYS值
        List<String> keys = getKeys("test");
        return stringRedisTemplate.execute(rateLimitRedisScript,keys, scriptArgs.toArray());
    }

    private List<String> getKeys(String id) {
        // use `{}` around keys to use Redis Key hash tags
        // this allows for using redis cluster

        // Make a unique key per user.
        String prefix = "request_rate_limiter.{" + id;

        // You need two Redis keys for Token Bucket.
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

调用成功后会返回两个数,一个是是否成功标志,0代表限流,1代表未限流,还有一个是令牌桶中剩余的令牌数

[
  1,
  2
]

注意

  1. 这里在拼接key的时候,对id使用了大括号{}进行了包裹,这是因为lua脚本执行成功的前提条件是所用使用到的redis健值必须在一个hash槽中,使用大括号对key进行包裹后,redis在对key进行hash时,指挥hash大括号内部的字符,这样就可以保证lua脚本中的使用的key-value在同一个槽内。这样就确保了cluster模式下正常执行redis-lua脚本,但是需要注意的是,这里大括号内包裹的内容不能是不变的,如果是不变的话,会有大量的key-value被分配到同一个槽里,导致hash倾斜,key-value分布不均匀。

  2. 这里使用的不是RedisTemplate,而是使用的StringRedisTemplate执行lua脚本的,使用RedisTemplate执行lua脚本的时候,会报错。

AOP+RedisLua对接口进行限流

image.png

每次请求,获取令牌桶中的令牌,如果令牌获取成功,代表没有被限流,可以正常访问,如果获取失败代表被限流,访问失败,这时会抛出一个RateLimitException结束。

最终效果

  1. 我打算结合springboot的手动装配,制作一个限流的工具,最终可以被封装成一个jar包,其他项目需要,直接引入就可以,不用重复开发。

  2. 具体的用法是这样的

    1. 在配置类上标注@EnableRedisRateLimit注解,激活限流工具

    2. 在需要限流的接口上标注@RateLimit注解,并根据具体的场景设置限流规则

 @RateLimit(replenishRate = 3,burstCapacity = 300)
 @GetMapping("test-limit")
 public Result<Void> testLimit(){
     return Result.buildSuccess();
 }

核心代码介绍

  1. @RateLimit 为了方便拓展,使得使用不同的场景,这里通过实现KeyResolver接口来指定具体的限流维度

  2. 这里说一下limitProperties的作用,我们可以默认使用注解中的参数指定配置信息,但是为了方便拓展,这里提供了limitProperties,如果指定了limitProperties,那么会以limitProperties的配置为准。

  3. 上篇文章介绍的限流lua脚本只能针对秒为时间单位进行限流,我这里对它的lua脚本做了一个小小的改变,使得可以支持秒,分钟,小时,天 为时间单位的限流。

  4. 限流注解

@Documented
@Inherited
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimit {

    /**
     * 限流维度,默认使用uri进行限流
     *
     * @return uri
     */
    Class<? extends KeyResolver> keyResolver() default UriKeyResolver.class;

    /**
     * 限流配置,如果实现了该接口,默认以这个为准
     *
     * @return limitProp
     */
    Class<? extends LimitProperties> limitProperties() default DefaultLimitProperties.class;

    /**
     * 令牌桶每秒填充平均速率
     *
     * @return replenishRate
     */
    int replenishRate() default 1;

    /**
     * 令牌桶总容量
     *
     * @return burstCapacity
     */
    int burstCapacity() default 3;

    /**
     * 限流时间维度,默认为秒
     * 支持秒,分钟,小时,天
     * 即,
     * {@link TimeUnit#SECONDS},
     * {@link TimeUnit#MINUTES},
     * {@link TimeUnit#HOURS},
     * {@link TimeUnit#DAYS}
     *
     * @return TimeUnit
     * @since 1.0.2
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;
}
  1. 限流配置limitProperties
public interface LimitProperties {
    /**
     * 令牌桶每秒填充平均速率
     *
     * @return replenishRate
     */
    int replenishRate();

    /**
     * 令牌桶总容量
     *
     * @return burstCapacity
     */
    int burstCapacity();

    /**
     * 限流时间维度,默认为秒
     * 支持秒,分钟,小时,天
     * 即,
     * {@link TimeUnit#SECONDS},
     * {@link TimeUnit#MINUTES},
     * {@link TimeUnit#HOURS},
     * {@link TimeUnit#DAYS}
     *
     * @return TimeUnit
     * @since 1.0.2
     */
    TimeUnit timeUnit();
}
  1. 限流lua脚本
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local time_unit = tonumber(ARGV[5])
-- 填满令牌桶所需要的时间
local fill_time = capacity/rate
local ttl = math.floor((fill_time*time_unit)*2)

--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)

local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
    last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
    last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
    new_tokens = filled_tokens - requested
    allowed_num = 1
end

--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)

redis.call("setex", tokens_key, ttl, new_tokens)
redis.call("setex", timestamp_key, ttl, now)

return { allowed_num, new_tokens }
  1. 核心AOP
@Slf4j
@Aspect
public class RateLimitInterceptor implements ApplicationContextAware {

    @Resource
    private RedisTemplate<String, Object> stringRedisTemplate;

    @Resource
    private RedisScript<List<Long>> rateLimitRedisScript;

    private ApplicationContext applicationContext;

    @Around("execution(public * *(..)) && @annotation(org.ywb.aoplimiter.anns.RateLimit)")
    public Object interceptor(ProceedingJoinPoint pjp) throws Throwable {
        MethodSignature signature = (MethodSignature) pjp.getSignature();
        Method method = signature.getMethod();
        RateLimit rateLimit = method.getAnnotation(RateLimit.class);
        // 断言不会被限流
        assertNonLimit(rateLimit, pjp);
        return pjp.proceed();
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

    public void assertNonLimit(RateLimit rateLimit, ProceedingJoinPoint pjp) {
        Class<? extends KeyResolver> keyResolverClazz = rateLimit.keyResolver();
        KeyResolver keyResolver = applicationContext.getBean(keyResolverClazz);
        String resolve = keyResolver.resolve(HttpContentHelper.getCurrentRequest(), pjp);
        List<String> keys = getKeys(resolve);

        LimitProperties limitProperties = getLimitProperties(rateLimit);

        // 根据限流时间维度计算时间
        long timeLong = getCurrentTimeLong(limitProperties.timeUnit());

        // The arguments to the LUA script. time() returns unixtime in seconds.
        List<String> scriptArgs = Arrays.asList(limitProperties.replenishRate() + "",
                limitProperties.burstCapacity() + "", (Instant.now().toEpochMilli() / timeLong) + "", "1", timeLong + "");
        // 第一个参数是是否被限流,第二个参数是剩余令牌数
        List<Long> rateLimitResponse = this.stringRedisTemplate.execute(this.rateLimitRedisScript, keys, scriptArgs.toArray());
        Assert.notNull(rateLimitResponse, "redis execute redis lua limit failed.");
        Long isAllowed = rateLimitResponse.get(0);
        Long newTokens = rateLimitResponse.get(1);
        log.info("rate limit key [{}] result: isAllowed [{}] new tokens [{}].", resolve, isAllowed, newTokens);
        if (isAllowed <= 0) {
            throw new RateLimitException(resolve);
        }
    }

    private LimitProperties getLimitProperties(RateLimit rateLimit) {
        Class<? extends LimitProperties> aClass = rateLimit.limitProperties();
        if (aClass == DefaultLimitProperties.class) {
            // 选取注解中的配置
            return new DefaultLimitProperties(rateLimit.replenishRate(), rateLimit.burstCapacity(), rateLimit.timeUnit());
        }
        // 优先使用用户自己的配置类
        return applicationContext.getBean(aClass);
    }

    private long getCurrentTimeLong(TimeUnit timeUnit) {
        switch (timeUnit) {
            case SECONDS:
                return 1;
            case MINUTES:
                return 60;
            case HOURS:
                return 60 * 60;
            case DAYS:
                return 60 * 60 * 24;
            default:
                throw new IllegalArgumentException("timeUnit:" + timeUnit + " not support");
        }
    }

    private List<String> getKeys(String id) {
        // use `{}` around keys to use Redis Key hash tags
        // this allows for using redis cluster

        // Make a unique key per user.
        String prefix = "request_rate_limiter.{" + id;

        // You need two Redis keys for Token Bucket.
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }
}

核心代码就是这么多,完整的源码我已上传至github,

github.com/xiao-ren-wu…

如果感觉对您有帮助的话,请帮忙点个star,谢谢啦~