Redisson限流器源码分析

221 阅读7分钟

Redisson限流器

限流器

基于Redis的分布式限流器RateLimiter可以用来在分布式环境下限制请求方的调用频率。既适用于不同Redisson实例下的多线程限流,也适用于相同Redisson实例下的多线程限流。

使用

  • 速率
  • 时间间隔
  • 类型:RateType.OVERALL所有实例共享、RateType.CLIENT单实例端共享

初始化

一、指定限流器名称:org.redisson.Redisson#getRateLimiter

    public RRateLimiter getRateLimiter(String name) {
        return new RedissonRateLimiter(commandExecutor, name);
    }

org.redisson.RedissonObject#RedissonObject(org.redisson.client.codec.Codec, org.redisson.command.CommandAsyncExecutor, java.lang.String)

    public RedissonObject(Codec codec, CommandAsyncExecutor commandExecutor, String name) {
        this.codec = codec;
        this.commandExecutor = commandExecutor;
        if (name == null) {
            throw new NullPointerException("name can't be null");
        }
        // 设置名称
        setName(name);
    }

如果没有自定义NameMapper,那么输入的name即为限流器名称。

二、初始化参数

org.redisson.api.RRateLimiter#trySetRate

boolean trySetRate(RateType mode, long rate, long rateInterval, RateIntervalUnit rateIntervalUnit)
  • RateType: RateType.OVERALL所有限流器实例共用,RateType.PER_CLIENT单实例端共享
  • rate:指定时间窗口生成令牌数
  • rateInterval:指定时间窗口
  • rateIntervalUnit:指定时间窗口单位

org.redisson.RedissonRateLimiter#trySetRateAsync

    public RFuture<Boolean> trySetRateAsync(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
        return commandExecutor.evalWriteNoRetryAsync(getRawName(), LongCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN,
                "redis.call('hsetnx', KEYS[1], 'rate', ARGV[1]);"
              + "redis.call('hsetnx', KEYS[1], 'interval', ARGV[2]);"
              + "return redis.call('hsetnx', KEYS[1], 'type', ARGV[3]);",
                Collections.singletonList(getRawName()), rate, unit.toMillis(rateInterval), type.ordinal());
    }
​

限流器是一个hash结构,下面三行就是设置速率和模式:

redis.call('hsetnx', KEYS[1], 'rate', ARGV[1]);
redis.call('hsetnx', KEYS[1], 'interval', ARGV[2]);
return redis.call('hsetnx', KEYS[1], 'type', ARGV[3]);

设置过期时间

org.redisson.RedissonExpirable#expire(java.time.Duration)

  public boolean expire(Duration duration) {
        return get(expireAsync(duration));
    }

org.redisson.RedissonExpirable#expireAsync(java.time.Duration)

public RFuture<Boolean> expireAsync(Duration duration) {
    return expireAsync(duration.toMillis(), TimeUnit.MILLISECONDS, "", getRawName());
}

org.redisson.RedissonExpirable#expireAsync(long, java.util.concurrent.TimeUnit, java.lang.String, java.lang.String...)

  protected RFuture<Boolean> expireAsync(long timeToLive, TimeUnit timeUnit, String param, String... keys) {
        return commandExecutor.evalWriteAsync(getRawName(), LongCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN,
                  "local result = 0;"
                + "for j = 1, #KEYS, 1 do "
                    + "local expireSet; "
                    + "if ARGV[2] ~= '' then "
                        + "expireSet = redis.call('pexpire', KEYS[j], ARGV[1], ARGV[2]); "
                    + "else "
                        + "expireSet = redis.call('pexpire', KEYS[j], ARGV[1]); "
                    + "end; "
                    + "if expireSet == 1 then "
                        + "result = expireSet;"
                    + "end; "
                + "end; "
                + "return result; ", Arrays.asList(keys), timeUnit.toMillis(timeToLive), param);
    }



		local result = 0;
                for j = 1, #KEYS, 1 do 
                   local expireSet; 
                    if ARGV[2] ~= '' then 
                       expireSet = redis.call('pexpire', KEYS[j], ARGV[1], ARGV[2]); 
                    else 
                        expireSet = redis.call('pexpire', KEYS[j], ARGV[1]); 
                    end; 
                    if expireSet == 1 then 
                        result = expireSet;
                    end; 
                end; 
        return result;

调用pexpire设置过期时间,如果成功则返回1

获取令牌

tryAcquire方式:

org.redisson.RedissonRateLimiter#tryAcquire(),这里默认取一个令牌

 public boolean tryAcquire() {
        return tryAcquire(1);
    }

org.redisson.RedissonRateLimiter#tryAcquireAsync(org.redisson.client.protocol.RedisCommand<T>, java.lang.Long)

    private <T> RFuture<T> tryAcquireAsync(RedisCommand<T> command, Long value) {
        byte[] random = new byte[8];
        ThreadLocalRandom.current().nextBytes(random);

        return commandExecutor.evalWriteAsync(getRawName(), LongCodec.INSTANCE, command,
                "local rate = redis.call('hget', KEYS[1], 'rate');"
              + "local interval = redis.call('hget', KEYS[1], 'interval');"
              + "local type = redis.call('hget', KEYS[1], 'type');"
              + "assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')"
              
              + "local valueName = KEYS[2];"
              + "local permitsName = KEYS[4];"
              + "if type == '1' then "
                  + "valueName = KEYS[3];"
                  + "permitsName = KEYS[5];"
              + "end;"

              + "assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate'); "

              + "local currentValue = redis.call('get', valueName); "
              + "local res;"
              + "if currentValue ~= false then "
                     + "local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); "
                     + "local released = 0; "
                     + "for i, v in ipairs(expiredValues) do "
                          + "local random, permits = struct.unpack('Bc0I', v);"
                          + "released = released + permits;"
                     + "end; "

                     + "if released > 0 then "
                          + "redis.call('zremrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); "
                          + "if tonumber(currentValue) + released > tonumber(rate) then "
                               + "currentValue = tonumber(rate) - redis.call('zcard', permitsName); "
                          + "else "
                               + "currentValue = tonumber(currentValue) + released; "
                          + "end; "
                          + "redis.call('set', valueName, currentValue);"
                     + "end;"

                     + "if tonumber(currentValue) < tonumber(ARGV[1]) then "
                         + "local firstValue = redis.call('zrange', permitsName, 0, 0, 'withscores'); "
                         + "res = 3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));"
                     + "else "
                         + "redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); "
                         + "redis.call('decrby', valueName, ARGV[1]); "
                         + "res = nil; "
                     + "end; "
              + "else "
                     + "redis.call('set', valueName, rate); "
                     + "redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); "
                     + "redis.call('decrby', valueName, ARGV[1]); "
                     + "res = nil; "
              + "end;"

              + "local ttl = redis.call('pttl', KEYS[1]); "
              + "if ttl > 0 then "
                  + "redis.call('pexpire', valueName, ttl); "
                  + "redis.call('pexpire', permitsName, ttl); "
              + "end; "
              + "return res;",
                Arrays.asList(getRawName(), getValueName(), getClientValueName(), getPermitsName(), getClientPermitsName()),
                value, System.currentTimeMillis(), random);
    }

第一步:判断限流器是否已经初始化了

local rate = redis.call('hget', KEYS[1], 'rate');
              local interval = redis.call('hget', KEYS[1], 'interval');
              local type = redis.call('hget', KEYS[1], 'type');
              assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')

第二步,通过判断是否为集群模式获取对应的valueName和permitName

local valueName = KEYS[2];
local permitsName = KEYS[4];
if type == '1' then 
   valueName = KEYS[3];
   permitsName = KEYS[5];
end;

第三步:判断rate是否比请求的令牌数大

assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate');

第四步:初始化value和permit值

local currentValue = redis.call('get', valueName);
if currentValue ~= false then
    // ...
    end
else 
   // 设置令牌总数
   redis.call('set', valueName, rate); 
   // 记录获取令牌的当前时间和请求令牌数
   redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3],ARGV[1])); 
   // 减去获取的令牌数
   redis.call('decrby', valueName, ARGV[1]); 
   res = nil; 
end;

struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3],ARGV[1])

数据打包,其中Bc0I B代表的数据标识是无符号字符,C0表示动态长度字符序列,I标识有符号long类型。这里使用了三种数据类型标识,那么相应的即是后面三个参数,最后结果为一个格式化的字符串,member会包含请求的令牌数。形如:

\x08`\xb5Q\x9bi=\x10\x03\x01\x00\x00\x00,其中x01表示本次获取的令牌数,如果修改为一次获取5个,那么结果形如:\x08\xab\xfb\xbeO|\x02\xef~\x05\x00\x00\x00,这个值将在后面释放令牌数时会用到。

第五步:再次获取令牌时,进入if currentValue ~= false逻辑

 if currentValue ~= false then 
    // 取出已经过期的数据(第二次请求时间戳 - 令牌生产的时间)
    local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); 
   	// 需要释放的令牌数
    local released = 0; 
    for i, v in ipairs(expiredValues) do 
        // 遍历table,并进行数据解包,random, permits是前面打包时的原始数据
        local random, permits = struct.unpack('Bc0I', v);
        released = released + permits;
     end; 

    if released > 0 then 
        // 将上一个周期的数据移除掉
       redis.call('zremrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); 
       if tonumber(currentValue) + released > tonumber(rate) then 
          // 如果令牌桶中令牌数加上需要释放的令牌数大于了限流器周期能生产的令牌数,目的是为了currentValue不能大于rate?目的是为了修正rate被修改后遗留的问题?           
          currentValue = tonumber(rate) - redis.call( 'zcard', permitsName); 
       else 
          currentValue = tonumber(currentValue) + released; 
       end; 
       redis.call('set', valueName, currentValue);
    end;

	// 当前桶中令牌数 < 请求的令牌数
    if tonumber(currentValue) < tonumber(ARGV[1]) then 
       local firstValue = redis.call('zrange', permitsName, 0, 0, 'withscores'); 
      // +3可能是为了避免短时间内生产的令牌超过最大限制数;返回还有多长时间才能生产出足够的令牌
       res = 3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));
    else 
       redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); 
       redis.call('decrby', valueName, ARGV[1]); 
       res = nil; 
     end; 

分成几块来看:

1、释放过期令牌

    // 取出已经过期的数据(第二次请求时间戳 - 令牌生产的时间)
    local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval);

例如 interval 为86400000,当前时间戳ARGV[2]为:1654687732000,permits中记录的时间戳为1654675087027。

第二次请求时间戳 - 令牌生产的时间=1654601332000,使用zrangebyscore找不到该范围的数据(0到1654601332000之间的),那么expiredValues是一个空列表,说明没有过期数据。如果有,则需要删除这部分数据。

2、计算还有多久生产出足够的令牌

3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));
// tonumber(ARGV[2]) - tonumber(firstValue[2]) 距离第一次请求已经过了多久

例如 interval 为86400000,当前时间戳ARGV[2]为:1654687732000,第一条数据时间戳为1654675087027,那么得出结果是

3+86400000-12644973=73755030 毫秒,这个计算有可能是负数。

带超时时间的tryAcquire

org.redisson.RedissonRateLimiter#tryAcquire(long, java.util.concurrent.TimeUnit),根据调用链会走到这里org.redisson.RedissonRateLimiter#tryAcquireAsync(long, long)

    private CompletableFuture<Boolean> tryAcquireAsync(long permits, long timeoutInMillis) {
        long s = System.currentTimeMillis();
        RFuture<Long> future = tryAcquireAsync(RedisCommands.EVAL_LONG, permits);
        return future.thenCompose(delay -> {
            // delay--》tryAcquireAsync返回的时间
            if (delay == null) {
                // 说明获取令牌成功了
                return CompletableFuture.completedFuture(true);
            }
            // 如果是显示设置超时时间,那么timeoutInMillis为-1
            if (timeoutInMillis == -1) {
                CompletableFuture<Boolean> f = new CompletableFuture<>();
                // //延迟delay时间后重新获取令牌
                commandExecutor.getConnectionManager().getGroup().schedule(() -> {
                    CompletableFuture<Boolean> r = tryAcquireAsync(permits, timeoutInMillis);
                    // 成功后将结果保存到f
                    commandExecutor.transfer(r, f);
                }, delay, TimeUnit.MILLISECONDS);
                return f;
            }

            // 上一次获取令牌经过的时长
            long el = System.currentTimeMillis() - s;
            // 剩余能获取令牌的时长
            long remains = timeoutInMillis - el;
            if (remains <= 0) {
                return CompletableFuture.completedFuture(false);
            }

            CompletableFuture<Boolean> f = new CompletableFuture<>();
            if (remains < delay) {
                // 剩余能获取令牌的时间小于令牌生产的时间
                // 假设timeoutInMillis为3秒,经过了1秒,还需要5秒生产令牌,那么2秒后通知失败
                commandExecutor.getConnectionManager().getGroup().schedule(() -> {
                    f.complete(false);
                }, remains, TimeUnit.MILLISECONDS);
            } else {
                long start = System.currentTimeMillis();
                commandExecutor.getConnectionManager().getGroup().schedule(() -> {
                    // 再次检查时长
                    long elapsed = System.currentTimeMillis() - start;
                    if (remains <= elapsed) {
                        f.complete(false);
                        return;
                    }
                    // 再次获取令牌
                    CompletableFuture<Boolean> r = tryAcquireAsync(permits, remains - elapsed);
                    commandExecutor.transfer(r, f);
                }, delay, TimeUnit.MILLISECONDS);
            }
            return f;
        }).toCompletableFuture();
    }

其他

  • value表示有多少个可以获取的令牌,permit表示当前时间窗口内发起了多少请求
  • 判断是否可以请求,是根据score的范围来的,数据该范围数据不为空,则表示已经过了这个令牌的生产时间,可以进行请求

参考

**欢迎关注:吴编程