【Redis】如何使用 Redis 实现一个通用的限频器

2,485 阅读18分钟

前言

在项目中我们经常遇到需要进行限频的地方,比如短信验证码每分钟只能发送一次,每天只能发送10次。我们可以使用 Redis 来实现。

名词解释

固定时间点刷新

固定时间点刷新是,我们设置比如每天晚上24点或者早上6点刷新次数,这样到达时间点,用户的可请求次数就会刷新到最大上限。当然也不一定非得是一天刷新一次,我们使用 cron 表达式可以设置23点48分5秒时刷新,每小时刷新一次,每5分钟刷新一次等复杂要求。下图是每天可以请求6次,每天早上6点刷新的示意图。

固定延迟刷新

固定延迟刷新是,在用户第一次请求开始那个时间点开始计算,延迟设定的时间后一次性刷新次数到最大上限。如下图每5分钟可以请求2次,19.57第一次请求,因此在20.02时候刷新可请求次数为上限2,20.05再次请求,因此20.10刷新可请求次数为上限2。

范围刷新

范围刷新与固定延迟刷新类似,但可请求次数并非一次性刷新,而是逐渐刷新。如下图刷新时间5分钟,最大可请求次数为2。如图19.58用户发送请求,因此会在20.03可请求次数增加1次,20.00发送请求,在20.05可请求次数增加1次,20.04发送请求,在20.09可请求次数增加1。 若可请求次数为1,那么固定延迟刷新和范围刷新效果相同。

Redis 实现限频器原理

参数

  • key 需要限频键
  • frequency 频率
  • expireAt 过期时间戳
  • currentTime 当前时间戳

固定时间点刷新 Redis 实现原理

使用 string 结构。

比如每天早上6点刷新,我们可以用下一个早上6点的时间戳作为 expireAt。首先判断 Redis 该 key 的值(使用 GET 命令)是不是大于等于 frequency(key 不存在当成0),若是则拒绝,否则使用 INCR 增加 key 的值,并设置 key 在 expireAt 过期(使用 EXPIREAT 命令)。

固定延迟刷新 Redis 实现原理

使用 string 结构。

比如每5分钟可请求2次,我们可以让 expireAt = 当前时间 + 5分钟。首先判断 Redis 该 key 的值(使用 GET 命令)是不是大于等于 frequency(key 不存在当成0),若是则拒绝,否则使用 INCR 增加 key 的值,如果第一步的 key 不存在,则设置 key 在 expireAt 过期(使用 EXPIREAT 命令)。

这里与固定时间点刷新只有最后一步才不同,对于固定时间点刷新来说,判不判断 key 存在都无所谓,因为每次的 expireAt 都是相同的,除非到达下一个时间段。但是固定延迟刷新则必须不存在才设置 key 的过期时间,因为每次的 expireAt 都是当前时间 + 延迟时间。

范围刷新限频 Redis 实现原理

使用 list 结构。

我们在 list 里面存放 expireAt,然后把 list 的长度作为已经被使用的请求次数。这时候我们必须让 list 里面过期的 expireAt 被删除,因此每次的操作流程如下。

我们首先获取 list 的第一个元素(使用 LINDEX 命令),如果该元素的值小于 currentTime,则删除它(使用 LPOP 命令)。再判断 list 的长度(使用 LLEN 命令)是否大于等于 frequency,若是则拒绝,否则使用 RPUSH 把 expireAt 加入队列右边,并设置 key 在 expireAt 过期(使用 EXPIREAT 命令)。

可能会有的疑问

如果 list 有多个值小于 currentTime,为什么不一次性删除?

因为我们一次最多 RPUSH 一次,因此只需要一个空位即可。如果一次性删除 list 所有过期元素,在最坏情况下,list 里会有 frequency - 1 个元素过期,这样会导致这次请求突然时间复杂度从变成 O(frequency),而如果只删一个的话,每次请求的时间复杂度都是 O(1)。

该文章主要类的结构图

FrequencyLimiter 是核心接口,DefaultFrequencyLimiter是具体的实现类。FrequencyLimitAspect 使用 FrequencyLimiter 的实现类去实现注解使用限频。FixedPointRefreshFrequencyLimitFixedDelayRefreshFrequencyLimitRangeRefreshFrequencyLimit注解分别标识三种刷新类型的限频器。 结构图

限频器的 Redis Lua 脚本实现

为什么使用 Redis Lua 实现

如果使用 Java 代码配合 RedisTemplate 或者 Jedis 的方式实现,在并发下可能会使得用户获取到比 frequency 多的请求次数,且需要多次的 Redis 网络请求,因此使用 Redis Lua 实现保证原子性同时避免多次网络请求。

名词解释

我们这里把请求被允许时增加的值称为 token,其实也就是 string 里面的那个自增的整数值或者是 list 的长度。

固定时间点刷新和固定延迟刷新 Lua 脚本

这两个刷新只是 expireAt 的计算方式不同而已,因此我们使用一个脚本实现。

--[[
KEYS[1] 需要限频的 key
ARGV[1] 频率
ARGV[2] 过期时间戳
--]]

    -- 获取 key 对应的 token 数量
    local tokenNumbers = redis.call('GET', KEYS[1])

    -- 如果对应 key 存在,且数量大于等于 frequency,直接 return
    if tokenNumbers and tonumber(tokenNumbers) >= tonumber(ARGV[1]) then
        return false
    end

    -- 增加 token 数量,设置过期时间
    redis.call('INCR', KEYS[1])
    -- 如果原来的并不存在对应 key,需要设置过期时间
    if not tokenNumbers then
        redis.call('PEXPIREAT', KEYS[1], ARGV[2])
    end
    return true

范围刷新 Lua 脚本

--[[
KEYS[1] 需要限频的 key
ARGV[1] 频率
ARGV[2] 过期时间戳
ARGV[3] 当前时间戳
--]]

    -- 删除 tokens 里面第一个过期的 token
    local expireTime = redis.call('LINDEX', KEYS[1], 0)
    if expireTime and (tonumber(expireTime) <= tonumber(ARGV[3])) then
        redis.call('LPOP', KEYS[1])
    end

    -- 如果 token 的数量大于等于 frequency,直接 return
    if redis.call('LLEN', KEYS[1]) >= frequency) then
        return false
    end

    -- 把 expireAt 作为 value 添加到对于 key 的 list 里,并设置 list 的过期时间
    redis.call('RPUSH', KEYS[1], ARGV[2])
    redis.call('PEXPIREAT', KEYS[1], ARGV[2])
    return true

整合成一个 Lua 脚本

这里可以使用一个 Lua 脚本操作3种类型的限频。因此需要添加多一个限频类型的参数。

--[[
KEYS[1] 需要限频的 key
ARGV[1] 频率
ARGV[2] 过期时间戳
ARGV[3] 限频类型
ARGV[4] 当前时间戳
--]]

-- 若频率为0直接拒绝请求
local frequency = tonumber(ARGV[1])
if frequency == 0 then
    break
end

-- 范围刷新
if ARGV[3] == 'RANGE_REFRESH' then
    -- 删除 tokens 里面第一个过期的 token
    local expireTime = redis.call('LINDEX', KEYS[1], 0)
    if expireTime and (tonumber(expireTime) <= tonumber(ARGV[4])) then
        redis.call('LPOP', KEYS[1])
    end

    -- 如果 token 的数量大于等于 frequency,直接 return
    if redis.call('LLEN', KEYS[1]) >= frequency) then
        return false
    end

    -- 把 expireAt 作为 value 添加到对于 key 的 list 里,并设置 list 的过期时间
    redis.call('RPUSH', KEYS[1], ARGV[2])
    redis.call('PEXPIREAT', KEYS[1], ARGV[2])
-- 固定时间点刷新和固定延迟刷新
else
    -- 获取 key 对应的 token 数量
    local tokenNumbers = redis.call('GET', KEYS[1])

    -- 如果对应 key 存在,且数量大于等于 frequency,直接 return
    if tokenNumbers and tonumber(tokenNumbers) >= tonumber(ARGV[1]) then
        return false
    end

    -- 增加 token 数量,设置过期时间
    redis.call('INCR', KEYS[1])
    -- 如果原来的并不存在对应 key,需要设置过期时间
    if not tokenNumbers then
        redis.call('PEXPIREAT', KEYS[1], ARGV[2])
    end
end

return true

在 Spring Boot 中使用脚本

定义脚本的 Bean

    /**
     * 限频 Lua 脚本 classpath 路径
     */
    private static final String FREQUENCY_LIMIT_LUA_SCRIPT_CLASS_PATH = "/redis/lua/FrequencyLimit.lua";
    
    /**
     * 限频 Redis 脚本 Bean
     */
    @Bean("frequencyLimitRedisScript")
    public RedisScript<Boolean> frequencyLimitRedisScript() {
        DefaultRedisScript<Boolean> frequencyLimitRedisScript = new DefaultRedisScript<>();
        frequencyLimitRedisScript.setResultType(Boolean.class);
        frequencyLimitRedisScript.setScriptSource(new ResourceScriptSource(
                new ClassPathResource(FREQUENCY_LIMIT_LUA_SCRIPT_CLASS_PATH)));
        return frequencyLimitRedisScript;
    }

使用 RedisTemplate 执行脚本

    @Override
    public Boolean isAllowed(FrequencyLimitType frequencyLimitType, String key, long frequency, long expireAt, long currentTime) {
        return redisTemplate.execute(frequencyLimitRedisScript, Collections.singletonList(key), 
        	frequency.toString(), String.valueOf(expireAt), frequencyLimitType.name(), String.valueOf(currentTime));
    }

多次限频下会出现的问题

如果使用多次限频器进行限频,比如一分钟2次(范围刷新),每天10次(固定时间点刷新),这时会出现一个问题:如果第一个限制通过,第二个限制没有通过,按照上面的实现,第一个限制不会被自动释放(list 里面就会多出一个元素) 。这里第一个限制只有1分钟还好,用户只需要等待1分钟。但是如果这个限制是1个小时,那么用户必须一个小时后才可能再次尝试获取第二个限制。

因此,我们需要在多次限制的情况下,如果一个限制没有通过,需要把前面已经通过的限制给释放掉。

解决办法

我们可以把前面每次通过的限频都记录起来,然后如果某个限频没有通过,则还原已经通过的限频。

Lua 脚本实现

这里与前面的脚本类似,只是现在需要一次性传入需要获取的全部限频请求,然后循环处理,使用 tokenMap 记录已经通过的限频的 key,在有一个限频失败时需要把 tokenMap 里面的 key 都还原。

同时返回值不再是 true 和 false,因为当失败时我们需要告知调用方具体哪个限频操作失败,因此返回失败的限频操作的下标(从0开始的),而成功时返回-1。

--[[
KEYS[i] 需要限频的 key
ARGV[(i - 1) * 3 + 1] 频率
ARGV[(i - 1) * 3 + 2] 过期时间戳
ARGV[(i - 1) * 3 + 3] 限频类型
ARGV[#KEYS * 3 + 1] 当前时间戳
--]]

-- 记录已经获取成功的 key,值是失败时需要进行的操作
local tokenMap = {}
local currentTime = tonumber(ARGV[#KEYS * 3 + 1])

-- 最后一个获取成功的 key 下标
local lastIndex = 0

-- 循环处理每个限频请求
for i = 1, #KEYS do
    -- 若频率为0直接拒绝请求
    local frequency = tonumber(ARGV[(i - 1) * 3 + 1])
    if frequency == 0 then
        break
    end

    -- 范围刷新
    if ARGV[(i - 1) * 3 + 3] == 'RANGE_REFRESH' then
        -- 删除 tokens 里面第一个过期的 token
        local expireTime = redis.call('LINDEX', KEYS[i], 0)
        if expireTime and (tonumber(expireTime) <= currentTime) then
            redis.call('LPOP', KEYS[i])
        end

        -- 如果 token 的数量大于等于 frequency,直接 break
        if redis.call('LLEN', KEYS[i]) >= frequency then
            break
        end

        -- 把 expireAt 作为 value 添加到对于 key 的 list 里,并设置 list 的过期时间
        local expireAt = ARGV[(i - 1) * 3 + 2]
        redis.call('RPUSH', KEYS[i], expireAt)
        redis.call('PEXPIREAT', KEYS[i], expireAt)
        tokenMap[KEYS[i]] = 'RPOP'

    -- 固定时间点刷新和固定延迟刷新
    else
        -- 获取 key 对应的 token 数量
        local tokenNumbers = redis.call('GET', KEYS[i])

        -- 如果对应 key 存在,且数量大于等于 frequency,直接 break
        if tokenNumbers and tonumber(tokenNumbers) >= frequency then
            break
        end

        -- 增加 token 数量,设置过期时间
        redis.call('INCR', KEYS[i])
        -- 如果原来的并不存在对应 key,需要设置过期时间
        if not tokenNumbers then
            redis.call('PEXPIREAT', KEYS[i], ARGV[(i - 1) * 3 + 2])
        end
        tokenMap[KEYS[i]] = 'INCRBY'
    end
    lastIndex = i
end

-- 判断是否所有限频都成功,若失败需要释放已经成功的 token
if lastIndex < #KEYS then
    for key, token in pairs(tokenMap) do
        if token == 'RPOP' then
            redis.call('RPOP', key)
        else
            redis.call('INCRBY', key, -1)
        end
    end
    return lastIndex
end

return -1

扩展1:封装成限频器

直接使用 lua 脚本还是比较麻烦,需要配合 RedisTemplate 等,因此我们封装成限频器。

实现代码

这里实现了一个可以同时多个请求多个限频的方法和一个只能请求一个限频的方法,同时我们会给每个 key 添加对应类型限频的前缀,既可以防止相同 key 在不同限频类型下冲突,也可以防止与项目中其他 Redis key 冲突。

/**
 * 默认的限频器
 *
 * @author xhsf
 * @create 2020/12/18 15:41
 */
public class DefaultFrequencyLimiter implements FrequencyLimiter {

    /**
     * 固定时间点刷新的限频 Redis Key 前缀
     */
    private static final String FIXED_POINT_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX =
            "frequency-limit:fixed-point-refresh:";

    /**
     * 固定延迟刷新的限频 Redis Key 前缀
     */
    private static final String FIXED_DELAY_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX =
            "frequency-limit:fixed-delay-refresh:";

    /**
     * 范围刷新的限频 Redis Key 前缀
     */
    private static final String RANGE_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX =
            "frequency-limit:range-fresh:";

    /**
     * 被允许时的 Lua 脚本返回值
     */
    private static final int LUA_RETURN_VALUE_WHEN_ALLOWED = -1;

    /**
     * 可同时获取多个 key 的限频 Lua 脚本
     */
    private static final String REPEATABLE_FREQUENCY_LIMIT_LUA =
    "--[[\n" +
            "KEYS[i] 需要限频的 key\n" +
            "ARGV[(i - 1) * 3 + 1] 频率\n" +
            "ARGV[(i - 1) * 3 + 2] 过期时间戳\n" +
            "ARGV[(i - 1) * 3 + 3] 限频类型\n" +
            "ARGV[#KEYS * 3 + 1] 当前时间戳\n" +
            "--]]\n" +
            "\n" +
            "-- 记录已经获取成功的 key,值是失败时需要进行的操作\n" +
            "local tokenMap = {}\n" +
            "local currentTime = tonumber(ARGV[#KEYS * 3 + 1])\n" +
            "\n" +
            "-- 最后一个获取成功的 key 下标\n" +
            "local lastIndex = 0\n" +
            "\n" +
            "-- 循环处理每个限频请求\n" +
            "for i = 1, #KEYS do\n" +
            "    -- 若频率为0直接拒绝请求\n" +
            "    local frequency = tonumber(ARGV[(i - 1) * 3 + 1])\n" +
            "    if frequency == 0 then\n" +
            "        break\n" +
            "    end\n" +
            "\n" +
            "    -- 范围刷新\n" +
            "    if ARGV[(i - 1) * 3 + 3] == 'RANGE_REFRESH' then\n" +
            "        -- 删除 tokens 里面第一个过期的 token\n" +
            "        local expireTime = redis.call('LINDEX', KEYS[i], 0)\n" +
            "        if expireTime and (tonumber(expireTime) <= currentTime) then\n" +
            "            redis.call('LPOP', KEYS[i])\n" +
            "        end\n" +
            "\n" +
            "        -- 如果 token 的数量大于等于 frequency,直接 break\n" +
            "        if redis.call('LLEN', KEYS[i]) >= frequency then\n" +
            "            break\n" +
            "        end\n" +
            "\n" +
            "        -- 把 expireAt 作为 value 添加到对于 key 的 list 里,并设置 list 的过期时间\n" +
            "        local expireAt = ARGV[(i - 1) * 3 + 2]\n" +
            "        redis.call('RPUSH', KEYS[i], expireAt)\n" +
            "        redis.call('PEXPIREAT', KEYS[i], expireAt)\n" +
            "        tokenMap[KEYS[i]] = 'RPOP'\n" +
            "\n" +
            "    -- 固定时间点刷新和固定延迟刷新\n" +
            "    else\n" +
            "        -- 获取 key 对应的 token 数量\n" +
            "        local tokenNumbers = redis.call('GET', KEYS[i])\n" +
            "\n" +
            "        -- 如果对应 key 存在,且数量大于等于 frequency,直接 break\n" +
            "        if tokenNumbers and tonumber(tokenNumbers) >= frequency then\n" +
            "            break\n" +
            "        end\n" +
            "\n" +
            "        -- 增加 token 数量,设置过期时间\n" +
            "        redis.call('INCR', KEYS[i])\n" +
            "        -- 如果原来的并不存在对应 key,需要设置过期时间\n" +
            "        if not tokenNumbers then\n" +
            "            redis.call('PEXPIREAT', KEYS[i], ARGV[(i - 1) * 3 + 2])\n" +
            "        end\n" +
            "        tokenMap[KEYS[i]] = 'INCRBY'\n" +
            "    end\n" +
            "    lastIndex = i\n" +
            "end\n" +
            "\n" +
            "-- 判断是否所有限频都成功,若失败需要释放已经成功的 token\n" +
            "if lastIndex < #KEYS then\n" +
            "    for key, token in pairs(tokenMap) do\n" +
            "        if token == 'RPOP' then\n" +
            "            redis.call('RPOP', key)\n" +
            "        else\n" +
            "            redis.call('INCRBY', key, -1)\n" +
            "        end\n" +
            "    end\n" +
            "    return lastIndex\n" +
            "end\n" +
            "\n" +
            "return -1";

    /**
     * StringRedisTemplate
     */
    private final StringRedisTemplate stringRedisTemplate;

    /**
     * 可获取多个 key 的限频脚本
     */
    private final RedisScript<Long> repeatableFrequencyLimitRedisScript;

    public DefaultFrequencyLimiter(StringRedisTemplate stringRedisTemplate) {
        this.stringRedisTemplate = stringRedisTemplate;
        // 可获取多个 key 的限频脚本
        DefaultRedisScript<Long> repeatableFrequencyLimitRedisScript = new DefaultRedisScript<>();
        repeatableFrequencyLimitRedisScript.setResultType(Long.class);
        repeatableFrequencyLimitRedisScript.setScriptText(REPEATABLE_FREQUENCY_LIMIT_LUA);
        this.repeatableFrequencyLimitRedisScript = repeatableFrequencyLimitRedisScript;
    }

    /**
     * 查询一个键是否被允许操作
     *
     * @param frequencyLimitType 限频类型
     * @param key 需要限频的键,key的格式最好是 {服务名}:{具体业务名}:{唯一标识符},如 sms:auth-code:15333333333
     * @param frequency 频率
     * @param expireAt 过期时间戳
     * @param currentTime 当前时间戳
     * @return 是否允许
     */
    @Override
    public boolean isAllowed(FrequencyLimitType frequencyLimitType, String key, long frequency, long expireAt,
                             long currentTime) {
        String[] args = {String.valueOf(frequency), String.valueOf(expireAt), frequencyLimitType.name(),
                String.valueOf(currentTime)};
        return isAllowed(Collections.singletonList(key), args) == LUA_RETURN_VALUE_WHEN_ALLOWED;
    }

    /**
     * 查询多个键是否被允许操作
     *
     * 只要其中一个不被允许,就会失败,并释放已经获取的 tokens
     *
     * @param keys 需要限频的键
     * @param args 参数列表,格式为 [frequency1, expireAt1, frequencyLimitType1,
     *             frequency2, expireAt2, frequencyLimitType2, ..., frequencyN, expireAtN, frequencyLimitTypeN,
     *             currentTime]
     * @return 是否允许,-1表示允许,其他表示获取失败时的下标
     */
    public int isAllowed(List<String> keys, String[] args) {
        for (int i = 0; i < keys.size(); i++) {
            String frequencyLimitType = args[i * 3 + 2];
            if (frequencyLimitType.equals(FrequencyLimitType.FIXED_POINT_REFRESH.name())) {
                keys.set(i, FIXED_POINT_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX + keys.get(i));
            } else if (frequencyLimitType.equals(FrequencyLimitType.RANGE_REFRESH.name())){
                keys.set(i, RANGE_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX + keys.get(i));
            } else {
                keys.set(i, FIXED_DELAY_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX + keys.get(i));
            }
        }
        return stringRedisTemplate.execute(repeatableFrequencyLimitRedisScript, keys, args).intValue();
    }

}

接口 FrequencyLimiter

/**
 * 限频器
 *
 * @author xhsf
 * @create 2020/12/18 15:41
 */
public interface FrequencyLimiter {

    /**
     * 查询一个键是否被允许操作
     *
     * @param frequencyLimitType 限频类型
     * @param key 需要限频的键,key的格式最好是 {服务名}:{具体业务名}:{唯一标识符},如 sms:auth-code:15333333333
     * @param frequency 频率
     * @param expireAt 过期时间戳
     * @param currentTime 当前时间戳
     * @return 是否允许
     */
     boolean isAllowed(FrequencyLimitType frequencyLimitType, String key, long frequency, long expireAt,
                       long currentTime);

    /**
     * 查询多个键是否被允许操作
     *
     * 只要其中一个不被允许,就会失败,并释放已经获取的 tokens
     *
     * @param keys 需要限频的键
     * @param args 参数列表,格式为 [frequency1, expireAt1, frequencyLimitType1,
     *             frequency2, expireAt2, frequencyLimitType2, ..., frequencyN, expireAtN, frequencyLimitTypeN,
     *             currentTime]
     * @return 是否允许,-1表示允许,其他表示获取失败时的下标
     */
    int isAllowed(List<String> keys, String[] args);

}

限频类型 FrequencyLimitType

/**
 * 描述:限频类型
 *
 * @author xhsf
 * @create 2020/12/18 21:10
 */
public enum FrequencyLimitType {

    /**
     * 固定时间点刷新
     */
    FIXED_POINT_REFRESH,

    /**
     * 固定延迟刷新
     */
    FIXED_DELAY_REFRESH,

    /**
     * 范围刷新
     */
    RANGE_REFRESH

}

扩展2:封装成注解形式

使用限频器还是挺麻烦的,特别是使用 isAllowed(List<String> keys, String[] args) 方法的时候,因此需要进一步封装成注解。

使用注解可以降低对业务代码的入侵,提高易用性。这里 key 实现了 EL + 占位符表达式;同时支持多个限频类型的多个注解;固定时间点刷新注解使用了 cron 表达式表示,使用时只需要编写 cron 表达式即可;errorMessage 也可以使用 EL 表达;满足大部分需求。

推荐使用注解形式。

EL 表达式及解析器

可以参考文章SpEL你感兴趣的实现原理浅析spring-expression

如表达式=sms:auth-code:#{#user.phone},参数=(UserDTO user),其中user = {name: xxx, phone: 1333333333}。则 sms:auth-code:#{#user.phone} 解析结果为 sms:auth-code:1333333333

cron 表达式

cron 表达式可以参考文章cron表达式详解根据CronSequenceGenerator计算cron表达式的时间

0 0 6 * * *,表示每天早上6点;33 15 23 * * *,表示每天23点15分33秒;0 0/5 * * * *,表示每个小时的第0, 5, 10, 15, ..., 55分钟。

占位符表达式

其实就是 Java 的 MessageFormat.format(key, parameters)

MessageFormat.format("sms:auth-code:{0}:{1}", "fixed-refresh", "#{#phone}") -> sms:auth-code:fixed-refresh:#{#phone}

三个限频类型对应的注解

固定时间点刷新注解

/**
 * 描述: 固定时间点刷新限频注解,由 {@link FrequencyLimitAspect} 实现
 *
 * @author xhsf
 * @create 2020-12-18 21:16
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Repeatable(FixedPointRefreshFrequencyLimit.List.class)
public @interface FixedPointRefreshFrequencyLimit {

    /**
     * 限频的 key,支持 EL 表达式,如#{#user.phone}
     * @see #parameters() 配合该参数,支持占位符填充
     *      如 value = "user:{0}:phone", parameters="#{#user.id}" 会转换成 value = "user:#{#user.id}:phone"
     */
    String key();

    /**
     * 填充到占位符的参数
     */
    String[] parameters() default {};

    /**
     * cron 表达式
     * 1.cron一共有7位,但是最后一位是年(1970-2099),可以留空,所以我们可以写6位,按顺序依次为
     *
     * 秒(0~59)
     * 分钟(0~59)
     * 小时(0~23)
     * 天(月)(0~31,但是你需要考虑你月的天数)
     * 月(0~11)
     * 星期(1~7 1=SUN,MON,TUE,WED,THU,FRI,SAT)
     *
     * cron的一些特殊符号
     * (*)星号:
     * 可以理解为每的意思,每秒,每分,每天,每月,每年
     * (?)问号:
     * 问号只能出现在日期和星期这两个位置,表示这个位置的值不确定,每天3点执行,所以第六位星期的位置,我们是不需要关注的,
     * 就是不确定的值。同时:日期和星期是两个相互排斥的元素,通过问号来表明不指定值。比如,1月10日,比如是星期1,
     * 如果在星期的位置是另指定星期二,就前后冲突矛盾了。
     * (-)减号:
     * 表达一个范围,如在小时字段中使用“10-12”,则表示从10到12点,即10,11,12
     * (,)逗号:
     * 表达一个列表值,如在星期字段中使用“1,2,4”,则表示星期一,星期二,星期四
     * (/)斜杠:如:x/y,x是开始值,y是步长,比如在第一位(秒) 0/15就是,从0秒开始,每15秒,最后就是0,15,30,45,60
     * 另: *\/y,等同于0/y
     *
     * @see org.springframework.scheduling.support.CronSequenceGenerator
     */
    String cron();

    /**
     * 频率
     */
    long frequency();

    /**
     * 当获取 token 失败时的错误信息,支持 EL 表达式
     */
    String errorMessage() default "Too many request.";

    /**
     * 描述: 限频注解数组
     *
     * @author xhsf
     * @create 2020-12-18 21:16
     */
    @Target({ElementType.METHOD})
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    @interface List {

        /**
         * 限频的注解数组
         */
        FixedPointRefreshFrequencyLimit[] value();

    }
}

固定延迟刷新注解

/**
 * 描述: 限频注解数组
 *
 * @author xhsf
 * @create 2020-12-18 21:16
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface FrequencyLimits {

    /**
     * 限频的注解数组
     */
    FrequencyLimit[] value();

}

范围刷新注解

/**
 * 描述: 范围限频注解,由 {@link FrequencyLimitAspect} 实现
 *
 * @author xhsf
 * @create 2020-12-18 21:16
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Repeatable(RangeRefreshFrequencyLimit.List.class)
public @interface RangeRefreshFrequencyLimit {

    /**
     * 限频的 key,支持 EL 表达式,如#{#user.phone}
     * @see #parameters() 配合该参数,支持占位符填充
     *      如 value = "user:{0}:phone", parameters="#{#user.id}" 会转换成 value = "user:#{#user.id}:phone"
     */
    String key();

    /**
     * 填充到占位符的参数
     */
    String[] parameters() default {};

    /**
     * 频率
     */
    long frequency();

    /**
     * 刷新时间
     */
    long refreshTime();

    /**
     * 时间单位,默认为秒
     */
    TimeUnit timeUnit() default TimeUnit.SECONDS;

    /**
     * 当获取 token 失败时的错误信息,支持 EL 表达式
     */
    String errorMessage() default "Too many request.";

    /**
     * 描述: 限频注解数组
     *
     * @author xhsf
     * @create 2020-12-18 21:16
     */
    @Target({ElementType.METHOD})
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    @interface List {
        /**
         * 限频的注解数组
         */
        RangeRefreshFrequencyLimit[] value();
    }

}

FrequencyLimitAspect 切面实现类

这里使用 FrequencyLimiter 去实现,错误时返回项目里统一的 Result 对象(可以根据具体需求抛出异常或者封装其他对象)。

/**
 * 描述:限频切面,配合 {@link FixedDelayRefreshFrequencyLimit}、{@link FixedPointRefreshFrequencyLimit}、
 *                  {@link RangeRefreshFrequencyLimit} 可以便捷的使用限频
 *
 * @author xhsf
 * @create 2020/12/18 21:10
 */
@Aspect
public class FrequencyLimitAspect {

    /**
     * 限频器
     */
    private final FrequencyLimiter frequencyLimiter;

    public FrequencyLimitAspect(FrequencyLimiter frequencyLimiter) {
        this.frequencyLimiter = frequencyLimiter;
    }

    /**
     * 限制请求频率
     *
     * @errorCode TooManyRequests: 请求太频繁
     *
     * @param joinPoint ProceedingJoinPoint
     * @return Object
     */
    @Around("@annotation(FixedDelayRefreshFrequencyLimit.List) || @annotation(FixedDelayRefreshFrequencyLimit) " +
            "|| @annotation(FixedPointRefreshFrequencyLimit.List) || @annotation(FixedPointRefreshFrequencyLimit) " +
            "|| @annotation(RangeRefreshFrequencyLimit.List) || @annotation(RangeRefreshFrequencyLimit)")
    public Object isAllowed(ProceedingJoinPoint joinPoint) throws Throwable {
        List<Annotation> frequencyLimits = getFrequencyLimits(joinPoint);
        int notAllowIndex = isAllowed(joinPoint, frequencyLimits);
        if (notAllowIndex != -1) {
            String errorMessage = getErrorMessage(joinPoint, frequencyLimits.get(notAllowIndex));
            return Result.fail(ErrorCodeEnum.TOO_MANY_REQUESTS, errorMessage);
        }

        // 执行业务逻辑
        return joinPoint.proceed();
    }

    /**
     * 是否允许
     *
     * @param joinPoint ProceedingJoinPoint
     * @param frequencyLimits List<Annotation>
     * @return 是否允许,-1表示允许,其他表示获取失败时的下标
     */
    private int isAllowed(ProceedingJoinPoint joinPoint, List<Annotation> frequencyLimits) {
        List<String> keys = new ArrayList<>(frequencyLimits.size());
        String[] args = new String[frequencyLimits.size() * 3 + 1];
        Date now = new Date();
        long currentTime = now.getTime();
        args[frequencyLimits.size() * 3] = String.valueOf(currentTime);
        for (int i = 0; i < frequencyLimits.size(); i++) {
            // 范围刷新限频
            if (frequencyLimits.get(i) instanceof RangeRefreshFrequencyLimit) {
                RangeRefreshFrequencyLimit rangeRefreshFrequencyLimit =
                        (RangeRefreshFrequencyLimit) frequencyLimits.get(i);
                keys.add(getKey(joinPoint, rangeRefreshFrequencyLimit.key(), rangeRefreshFrequencyLimit.parameters()));
                args[i * 3] = String.valueOf(rangeRefreshFrequencyLimit.frequency());
                args[i * 3 + 1] = String.valueOf(FrequencyLimiterUtils.getExpireAt(
                        rangeRefreshFrequencyLimit.refreshTime(), rangeRefreshFrequencyLimit.timeUnit(), currentTime));
                args[i * 3 + 2] = FrequencyLimitType.RANGE_REFRESH.name();
            }
            // 固定时间点刷新限频
            else if (frequencyLimits.get(i) instanceof FixedPointRefreshFrequencyLimit) {
                FixedPointRefreshFrequencyLimit fixedPointRefreshFrequencyLimit =
                        (FixedPointRefreshFrequencyLimit) frequencyLimits.get(i);
                keys.add(getKey(joinPoint, fixedPointRefreshFrequencyLimit.key(),
                        fixedPointRefreshFrequencyLimit.parameters()));
                args[i * 3] = String.valueOf(fixedPointRefreshFrequencyLimit.frequency());
                args[i * 3 + 1] = String.valueOf(FrequencyLimiterUtils.getExpireAt(
                        fixedPointRefreshFrequencyLimit.cron(), now));
                args[i * 3 + 2] = FrequencyLimitType.FIXED_POINT_REFRESH.name();
            }
            // 固定延迟刷新限频
            else {
                FixedDelayRefreshFrequencyLimit fixedDelayRefreshFrequencyLimit =
                        (FixedDelayRefreshFrequencyLimit) frequencyLimits.get(i);
                keys.add(getKey(joinPoint, fixedDelayRefreshFrequencyLimit.key(),
                        fixedDelayRefreshFrequencyLimit.parameters()));
                args[i * 3] = String.valueOf(fixedDelayRefreshFrequencyLimit.frequency());
                args[i * 3 + 1] = String.valueOf(FrequencyLimiterUtils.getExpireAt(
                        fixedDelayRefreshFrequencyLimit.refreshTime(), fixedDelayRefreshFrequencyLimit.timeUnit(),
                        currentTime));
                args[i * 3 + 2] = FrequencyLimitType.FIXED_DELAY_REFRESH.name();
            }
        }

        // 判断是否允许
        return frequencyLimiter.isAllowed(keys, args);
    }

    /**
     * 获取限频注解列表
     *
     * @param joinPoint ProceedingJoinPoint
     * @return List<Annotation>
     */
    private List<Annotation> getFrequencyLimits(ProceedingJoinPoint joinPoint) {
        Method method;
        try {
            MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
            method = joinPoint.getTarget()
                    .getClass()
                    .getMethod(methodSignature.getName(), methodSignature.getParameterTypes());
            Annotation[] rangeRefreshFrequencyLimits =
                    method.getAnnotationsByType(RangeRefreshFrequencyLimit.class);
            Annotation[] fixedPointRefreshFrequencyLimits =
                    method.getAnnotationsByType(FixedPointRefreshFrequencyLimit.class);
            Annotation[] fixedDelayRefreshFrequencyLimits =
                    method.getAnnotationsByType(FixedDelayRefreshFrequencyLimit.class);
            List<Annotation> frequencyLimits = new ArrayList<>(rangeRefreshFrequencyLimits.length
                    + fixedPointRefreshFrequencyLimits.length + fixedDelayRefreshFrequencyLimits.length);
            frequencyLimits.addAll(Arrays.asList(rangeRefreshFrequencyLimits));
            frequencyLimits.addAll(Arrays.asList(fixedPointRefreshFrequencyLimits));
            frequencyLimits.addAll(Arrays.asList(fixedDelayRefreshFrequencyLimits));
            return frequencyLimits;
        } catch (NoSuchMethodException ignored) {
        }
        return new ArrayList<>();
    }

    /**
     * 获取错误信息
     *
     * @param joinPoint ProceedingJoinPoint
     * @param frequencyLimit frequencyLimit
     * @return 错误信息
     */
    private String getErrorMessage(ProceedingJoinPoint joinPoint, Annotation frequencyLimit) {
        String errorMessageExpression;
        if (frequencyLimit instanceof RangeRefreshFrequencyLimit) {
            errorMessageExpression = ((RangeRefreshFrequencyLimit) frequencyLimit).errorMessage();
        } else if (frequencyLimit instanceof FixedPointRefreshFrequencyLimit) {
            errorMessageExpression = ((FixedPointRefreshFrequencyLimit) frequencyLimit).errorMessage();
        } else {
            errorMessageExpression = ((FixedDelayRefreshFrequencyLimit) frequencyLimit).errorMessage();
        }
        return SpELUtils.getExpressionValue(errorMessageExpression, joinPoint);
    }

    /**
     * 获取限频的键
     *
     * @param joinPoint ProceedingJoinPoint
     * @param key Key
     * @param parameters parameters
     * @return 限频键
     */
    private String getKey(ProceedingJoinPoint joinPoint, String key, Object[] parameters) {
        // 构造键表达式
        String keyExpression = key;
        // 若有参数则需要填充参数
        if (parameters.length > 0) {
            keyExpression = MessageFormat.format(key, parameters);
        }
        // 获取键
        return SpELUtils.getExpressionValue(keyExpression, joinPoint);
    }

}

FrequencyLimiterUtils

/**
 * 描述:限频器工具类
 *
 * @author xhsf
 * @create 2020/12/22 19:09
 */
public class FrequencyLimiterUtils {

    /**
     * 通过 cron 表达式获取过期时间
     *
     * @param cron cron 表达式
     * @param now 当前时间
     * @return 过期时间
     */
    public static long getExpireAt(String cron, Date now) {
        return CronUtils.next(cron, now).getTime();
    }

    /**
     * 通过 time、timeUnit、currentTime 获取过期时间
     *
     * @param refreshTime 刷新时间
     * @param timeUnit    时间单位
     * @param currentTime 当前时间
     * @return 过期时间
     */
    public static long getExpireAt(long refreshTime, TimeUnit timeUnit, long currentTime) {
        return TimeoutUtils.toMillis(refreshTime, timeUnit) + currentTime;
    }

}

CronUtils

/**
 * 描述:Cron 表达式工具类
 *
 * @author xhsf
 * @create 2020/12/22 19:19
 */
public class CronUtils {

    /**
     * cron 表达式的 CronSequenceGenerator 的缓存
     */
    private static final Map<String, CronSequenceGenerator> cronSequenceGeneratorMap = new HashMap<>();

    /**
     * 获取下一个 cron 表达式的时间戳
     *
     * @param cron cron 表达式
     * @param now 当前时间
     * @return 下一个时间
     */
    public static Date next(String cron, Date now) {
        CronSequenceGenerator cronSequenceGenerator = cronSequenceGeneratorMap.getOrDefault(
                cron, cronSequenceGeneratorMap.put(cron, new CronSequenceGenerator(cron)));
        return cronSequenceGenerator.next(now);
    }

}

SpELUtils

/**
 * 描述:Spring Expression Language 工具类
 *
 * @author xhsf
 * @create 2020/12/22 19:14
 */
public class SpELUtils {

    /**
     * EL 表达式解析器
     */
    private static final ExpressionParser expressionParser = new SpelExpressionParser();

    /**
     * 获取表达式的值
     *
     * @param expression 表达式
     * @param joinPoint ProceedingJoinPoint
     * @return value
     */
    public static String getExpressionValue(String expression, ProceedingJoinPoint joinPoint) {
        // 获得方法参数的 Map
        String[] parameterNames = ((CodeSignature) joinPoint.getSignature()).getParameterNames();
        Object[] parameterValues = joinPoint.getArgs();
        Map<String, Object> parameterMap = new HashMap<>();
        for (int i = 0; i < parameterNames.length; i++) {
            parameterMap.put(parameterNames[i], parameterValues[i]);
        }

        // 解析 EL 表达式
        return getExpressionValue(expression, parameterMap);
    }

    /**
     * 获取 EL 表达式的值
     *
     * @param elExpression EL 表达式
     * @param parameterMap 参数名-值 Map
     * @return 表达式的值
     */
    public static String getExpressionValue(String elExpression, Map<String, Object> parameterMap) {
        Expression expression = expressionParser.parseExpression(elExpression, new TemplateParserContext());
        EvaluationContext context = new StandardEvaluationContext();
        for (Map.Entry<String, Object> entry : parameterMap.entrySet()) {
            context.setVariable(entry.getKey(), entry.getValue());
        }
        return expression.getValue(context, String.class);
    }

}

使用注解示例

    // 最多2次,刷新时间60秒
    @RangeRefreshFrequencyLimit(key = "email:auth-code:#{#createAndSendEmailAuthCodePO.email}",
            frequency = 2, refreshTime = 60)
    // 60秒最多2次
    @FixedDelayRefreshFrequencyLimit(key = "email:auth-code:#{#createAndSendEmailAuthCodePO.email}",
            frequency = 2, refreshTime = 60)
    // 最多10次,每小时的第[0, 5, 10, 15, ..., 55]刷新次数
    @FixedPointRefreshFrequencyLimit(key = "email:auth-code:#{#createAndSendEmailAuthCodePO.email}",
            frequency = 10, cron = "0 0/5 * * * *")
    public Result<Void> createAndSendEmailAuthCode(CreateAndSendEmailAuthCodePO createAndSendEmailAuthCodePO) {
    // 业务代码
    }

记得注册切面

    /**
     * 限频器
     *
     * @param stringRedisTemplate StringRedisTemplate
     * @return FrequencyLimiter
     */
    @Bean
    public FrequencyLimiter frequencyLimiter(StringRedisTemplate stringRedisTemplate) {
        return new DefaultFrequencyLimiter(stringRedisTemplate);
    }

    /**
     * 限频切面
     *
     * @param frequencyLimiter FrequencyLimiter
     * @return FrequencyLimitAspect
     */
    @Bean
    public FrequencyLimitAspect frequencyLimitAspect(FrequencyLimiter frequencyLimiter) {
        return new FrequencyLimitAspect(frequencyLimiter);
    }

扩展3:请求到达顺序问题

解决的思路,避免请求带时间戳参数。

范围刷新限频

问题描述

如果有相同 key 的请求A和请求B,这时候请求A的 expireAt 是 12:23:01,请求B的 expireAt 是 12:23:03,但是因为网络延迟等原因,请求B提前到达,这时候 key 的 expireAt 是 12:23:03,但是当请求A到达后, key 的 expireAt 被修改为 12:23:01。到达 12:23:01 时,虽然请求B 的 token 还不能刷新,但是因为 key 已经过期了,这时候请求B 的 token 也会被清除,也就是 key 对应的 list 已经被删除了。原本用户在 12:23:01 的可请求次数应该是1,但是因为请求B 的 token 被清除,可请求次数变为2。

解决办法

不再直接传入 expireAt,而是传入 timeout(refreshTime),然后在 Lua 里使用 TIME 命令获取当前时间,加上 timeout 作为 expireAt。

固定延迟刷新限频

问题描述

如果 key 不存在,对应 key 的请求A在11:59:50发送,expireAt 是 12:00:00,但是因为网络延迟等原因,请求 A 在 12:00:01 才到达,这时候相当于这次限频不被限制。因为 key 在执行完 PEXPIREAT 后马上过期了,可请求次数又恢复到最大。

解决办法

不再直接传入 expireAt,而是传入 timeout(delay),然后在 Lua 里使用 PEXPIRE 命令 + timeout去设置过期时间。

固定时间点刷新限频

问题描述

同固定延迟刷新限频。

解决办法

如果可以在 Lua 脚本里解析 cron 表达式,那么这个问题是可以解决的。但是目前还没找到 Lua 里解析 cron 比较好的库,因此我们去判断 expireAt 的值,如果小于当前时间 TIME,直接拒绝请求。

这里会导致新的问题,如果时间点太小,比如 0/1 * * * * *(一秒一个刷新时间点) 可能会因为 expireAt 跟不上网络传输速度而一直失败,所以最好不要设置太小间隔的时间。如果需要很小的刷新间隔,可用使用固定延迟刷新或者范围刷新代替。

解决后的 Lua 脚本

这里还需要改动 Aspect 里面范围刷新和固定延迟刷新关于时间的计算方式,还有不再需要传入当前时间参数。

--[[
KEYS[i] 需要限频的 key
ARGV[(i - 1) * 3 + 1] 频率
ARGV[(i - 1) * 3 + 2] expireAt or refreshTime or delayTime
ARGV[(i - 1) * 3 + 3] 限频类型
--]]

-- 由于调用了时间函数,因此需要调用此函数,让 Redis 只复制写命令,避免主从不一致
redis.replicate_commands()

-- 记录已经获取成功的 key,值是失败时需要进行的操作
local tokenMap = {}

-- 获取当前时间
local now = redis.call('TIME')
local currentTime = tonumber(now[1]) * 1000 + math.ceil(tonumber(now[2]) / 1000)

-- 最后一个获取成功的 key 下标
local lastIndex = 0

-- 循环处理每个限频请求
for i = 1, #KEYS do
    -- 若频率为0直接拒绝请求
    local frequency = tonumber(ARGV[(i - 1) * 3 + 1])
    if frequency == 0 then
        break
    end

    -- 范围刷新
    if ARGV[(i - 1) * 3 + 3] == 'RANGE_REFRESH' then
        -- 删除 tokens 里面第一个过期的 token
        local expireTime = redis.call('LINDEX', KEYS[i], 0)
        if expireTime and (tonumber(expireTime) <= currentTime) then
            redis.call('LPOP', KEYS[i])
        end

        -- 如果 token 的数量大于等于 frequency,直接 break
        if redis.call('LLEN', KEYS[i]) >= frequency then
            break
        end

        -- 把 expireAt 作为 value 添加到对于 key 的 list 里,并设置 list 的过期时间
        local expireAt = currentTime + tonumber(ARGV[(i - 1) * 3 + 2])
        redis.call('RPUSH', KEYS[i], expireAt)
        redis.call('PEXPIREAT', KEYS[i], expireAt)
        tokenMap[KEYS[i]] = 'RPOP'

    -- 固定延迟刷新
    elseif ARGV[(i - 1) * 3 + 3] == 'FIXED_DELAY_REFRESH' then
        -- 获取 key 对应的 token 数量
        local tokenNumbers = redis.call('GET', KEYS[i])

        -- 如果对应 key 存在,且数量大于等于 frequency,直接 break
        if tokenNumbers and tonumber(tokenNumbers) >= frequency then
            break
        end

        -- 增加 token 数量,设置过期时间
        redis.call('INCR', KEYS[i])
        -- 如果原来的并不存在对应 key,需要设置过期时间
        if not tokenNumbers then
            redis.call('PEXPIRE', KEYS[i], ARGV[(i - 1) * 3 + 2])
        end
        tokenMap[KEYS[i]] = 'INCRBY'

    -- 固定时间点刷新
    else
        -- expireAt 小于当前时间直接拒绝
        local expireAt = tonumber(ARGV[(i - 1) * 3 + 2])
        if expireAt <= currentTime then
            break
        end

        -- 获取 key 对应的 token 数量
        local tokenNumbers = redis.call('GET', KEYS[i])

        -- 如果对应 key 存在,且数量大于等于 frequency,直接 break
        if tokenNumbers and tonumber(tokenNumbers) >= frequency then
            break
        end

        -- 增加 token 数量,设置过期时间
        redis.call('INCR', KEYS[i])
        -- 如果原来的并不存在对应 key,需要设置过期时间
        if not tokenNumbers then
            redis.call('PEXPIREAT', KEYS[i], ARGV[(i - 1) * 3 + 2])
        end
        tokenMap[KEYS[i]] = 'INCRBY'
    end
    lastIndex = i
end

-- 判断是否所有限频都成功,若失败需要释放已经成功的 token
if lastIndex < #KEYS then
    for key, token in pairs(tokenMap) do
        if token == 'RPOP' then
            redis.call('RPOP', key)
        else
            redis.call('INCRBY', key, -1)
        end
    end
    return lastIndex
end

return -1

时间复杂度分析

范围刷新

时间复杂度O(1),使用了 LINDEXLPOPLLENRPUSHPEXPIREAT 命令。其中 LINDEX 由于每次只请求一次获取 list 第一个元素的操作,因此复杂度还是 O(1)。

固定时间点刷新和固定延迟刷新

时间复杂度为O(1),只使用了 GETINCRPEXPIREAT