更新2020-12-22
请看文章【Redis】如何使用 Redis 实现一个通用的限频器,里面的实现更加优雅,带图说明,里面限频的时间复杂度都是O(1)。
前言
在项目中我们经常遇到需要进行限频的地方,比如短信验证码每分钟只能发送一次,每天只能发送10次。我们可以使用 Redis 来实现。
基本实现原理
假设24小时可以请求6次。
固定时间点刷新是,我们设置比如晚上24点或者早上6点刷新次数,这样到达时间点,用户的可请求次数就会刷新成6次。
固定延迟刷新是,在用户第一次请求开始那个时间点开始计算,延迟24小时后一次性刷新次数为6。
范围刷新是,在用户第一次请求开始那个时间点开始计算,如果用户这24小时内请求了6次,请求时间分别为第【0、6、7、8、15、20】个小时,那么第20个小时时用户的可请求次数为0,用户这时停止发送请求。第24小时时次数刷新为1(因为距离当前24小时前发送了第一次请求),再过6小时(第30个小时)次数刷新为2(因为距离当前24小时前发送了第二次请求),再过1小时(第31个小时)刷新为3(因为距离当前24小时前发送了第三次请求)。
如果可请求次数为1,那么固定延迟刷新和范围刷新效果相同。
参数
- key 需要限频键
- frequency 频率
- time 限频时间
固定时间点刷新原理
这个比较简单,比如每天晚上24点刷新,我们可以计算当前时间到晚上24点还有多少时间作为 Redis key 的 timeout,然后判断 Redis 该 key 的值是不是大于等于 frequency(空值当成0),若是则拒绝,否则使用 INCR 增加 key 的值,并设置过期时间。
范围刷新限频原理
我们的实现原理是,当用户请求时,需要申请一个 token,这个 token 会有过期时间,只要 token 达到限制的数量,用户的请求就会被拒绝。
我们使用 Redis 的 Set 结构存放 token,当 token 数量小于限定数量可以发放 token。然而 Redis 不支持 Set 里的元素过期,而我们必须使 token 定时过期。因此我们每次把 token 存放进入 Set 时,会把 token 作为 key 添加到 Redis 的 String 结构里,并设置过期时间,这样只需要判断是否有存在与 Set 里面的 token 相同的 key 即可判断 token 是否过期,若过期则使用 SREM 命令清除即可。当然,这样的前提是必须保证 token 唯一,否则会出现错误的结果。
这里我们也会设置 Set 的过期时间,这样就不需要手动清除过期的 Set。
该文章主要类的结构图
FrequencyLimiter 是核心接口,下面的 FixedPointRefreshFrequencyLimiter、RangeRefreshFrequencyLimiter、RepeatableFrequencyLimiter 都是具体的实现类,FrequencyLimiterManager 只是把它们整合起来。FrequencyLimitAspect 使用 FrequencyLimiter 的实现类去实现注解使用限频。
范围刷新限频实现
实现方式1:Java + Redis
这里直接调用 Redis 命令实现,使用 UUID 保证 token 的唯一。但是该方法并不是原子操作,在并发下可能会出现拿到比 frequency 多的 token。
/**
* 查询一个键是否被允许操作
*
* 如短信验证码服务使用 isAllowed("15333333333", 10, 1, TimeUnit.DAYS); 表示一天只能发送10次短信验证码
* 如短信验证码服务使用 isAllowed("15333333333", 1, 1, TimeUnit.MINUTES); 表示一分钟只能发送1次短信验证码
*
* @param key 需要限频的键,key的格式最好是 {服务名}:{具体业务名}:{唯一标识符},如 sms:auth-code:15333333333
* @param frequency 频率
* @param time 限频时间
* @param unit 时间单位
* @return 是否允许
*/
@Override
public Boolean isAllowed(String key, Long frequency, Long time, TimeUnit unit) {
// 获取 key 对应的 tokens
// smembers(key)
Set<String> tokens = redisTemplate.opsForSet().members(key);
// 删除 tokens 里面所有过期的 token
// srem(token)
int expiredTokensCount = 0;
for (String token : tokens) {
if (redisTemplate.opsForValue().get(token) == null) {
redisTemplate.opsForSet().remove(key, token);
expiredTokensCount++;
}
}
// 如果未过期的 token 数量大于等于 frequency,return false
int unexpiredTokensCount = tokens.size() - expiredTokensCount;
if (unexpiredTokensCount >= frequency) {
return false;
}
// 生成唯一的 token,添加 token 到 redis 里和 key 对应的 set 里,并设置过期时间
// set(token, "", time, unit)
// sadd(key, token)
// expire(key, time, unit)
String token = UUID.randomUUID().toString();
redisTemplate.opsForValue().set(token, "", time, unit);
redisTemplate.opsForSet().add(key, token);
redisTemplate.expire(key, time, unit);
return true;
}
实现方式2:Lua 脚本
由于该功能实现只是使用到了 Redis 的命令,因此我们可以通过 Lua 脚本来保证原子性。这里只是把上面的代码使用 Lua 实现,UUID 则通过 Redis 的 INCR 命令实现。
--[[
KEYS[1] 需要限频的 key
ARGV[1] 频率
ARGV[2] 限频时间
--]]
-- 获取 key 对应的 tokens
local tokens = redis.call('SMEMBERS', KEYS[1])
-- 删除 tokens 里面所有过期的 token
local expiredTokensCount = 0
for i = 1, #tokens do
if not redis.call('GET', tokens[i]) then
redis.call('SREM', KEYS[1], tokens[i])
expiredTokensCount = expiredTokensCount + 1
end
end
-- 如果未过期的 token 数量大于等于 frequency,return false
local unexpiredTokensCount = #tokens - expiredTokensCount;
if unexpiredTokensCount >= tonumber(ARGV[1]) then
return false
end
-- 生成唯一的 token,添加 token 到 redis 里和 key 对应的 set 里,并设置过期时间
local token = redis.call('INCR', 'frequency-limit:token:increment-id')
redis.call('SET', token, '', 'PX', ARGV[2])
redis.call('SADD', KEYS[1], token)
redis.call('PEXPIRE', KEYS[1], ARGV[2])
return true
在 Spring Boot 中使用该脚本
定义脚本的 Bean
/**
* 限频 Lua 脚本 classpath 路径
*/
private static final String FREQUENCY_LIMIT_LUA_SCRIPT_CLASS_PATH = "/redis/lua/FrequencyLimit.lua";
/**
* 限频 Redis 脚本 Bean
*/
@Bean("frequencyLimitRedisScript")
public RedisScript<Boolean> frequencyLimitRedisScript() {
DefaultRedisScript<Boolean> frequencyLimitRedisScript = new DefaultRedisScript<>();
frequencyLimitRedisScript.setResultType(Boolean.class);
frequencyLimitRedisScript.setScriptSource(new ResourceScriptSource(
new ClassPathResource(FREQUENCY_LIMIT_LUA_SCRIPT_CLASS_PATH)));
return frequencyLimitRedisScript;
}
使用 RedisTemplate 执行脚本
@Override
public Boolean isAllowed(String key, Long frequency, Long time, TimeUnit unit) {
return redisTemplate.execute(frequencyLimitRedisScript, Collections.singletonList(key), frequency.toString(),
String.valueOf(TimeoutUtils.toMillis(time, unit)));
}
固定时间点刷新实现
实现比较简单,因为设置了过期时间,因此到达时间点时,Redis 会自动清除该 key,从而实现固定时间点刷新。
--[[
KEYS[1] 需要限频的 key
ARGV[1] 频率
ARGV[2] 过期时间
--]]
-- 若频率为0直接拒绝请求
if tonumber(ARGV[1]) == 0 then
return false
end
-- 获取 key 对应的 token 数量
local tokenNumbers = redis.call('GET', KEYS[1])
-- 如果对应 key 存在,且数量大于等于 frequency,直接返回 false
if tokenNumbers and tonumber(tokenNumbers) >= tonumber(ARGV[1]) then
return false
end
-- 增加token数量,设置过期时间
redis.call('INCR', KEYS[1])
redis.call('PEXPIRE', KEYS[1], ARGV[2])
return true
扩展1:封装成限频器
直接使用 lua 脚本还是比较麻烦,需要配合 RedisTemplate,而且固定时间点刷新限频还需要自己处理过期时间,因此我们把他们封装成限频器,方便使用。
范围刷新限频器
这里添加了前缀 RANGE_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX 防止与项目中其他 Redis key 冲突。
/**
* 范围刷新限频器
*
* @author xhsf
* @create 2020/12/18 15:41
*/
public class RangeRefreshFrequencyLimiter implements FrequencyLimiter{
/**
* 范围刷新的限频 Redis Key 前缀
*/
private static final String RANGE_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX = "frequency-limit:range-fresh:";
/**
* 范围刷新的限频 Lua 脚本
*/
private static final String RANGE_REFRESH_FREQUENCY_LIMIT_LUA =
"-- 获取 key 对应的 tokens\n" +
"local tokens = redis.call('SMEMBERS', KEYS[1])\n" +
"\n" +
"-- 删除 tokens 里面所有过期的 token\n" +
"local expiredTokensCount = 0\n" +
"for i = 1, #tokens do\n" +
" if not redis.call('GET', tokens[i]) then\n" +
" redis.call('SREM', KEYS[1], tokens[i])\n" +
" expiredTokensCount = expiredTokensCount + 1\n" +
" end\n" +
"end\n" +
"\n" +
"-- 如果未过期的 token 数量大于等于 frequency,return false\n" +
"local unexpiredTokensCount = #tokens - expiredTokensCount;\n" +
"if unexpiredTokensCount >= tonumber(ARGV[1]) then\n" +
" return false\n" +
"end\n" +
"\n" +
"-- 生成唯一的 token,添加 token 到 redis 里和 key 对应的 set 里,并设置过期时间\n" +
"local token = '" + RANGE_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX + "' .. redis.call('INCR', 'frequency-limit:token:increment-id')\n" +
"redis.call('SET', token, '', 'PX', ARGV[2])\n" +
"redis.call('SADD', KEYS[1], token)\n" +
"redis.call('PEXPIRE', KEYS[1], ARGV[2])\n" +
"return true";
/**
* StringRedisTemplate
*/
private final StringRedisTemplate stringRedisTemplate;
/**
* 范围刷新的限频脚本
*/
private final RedisScript<Boolean> rangeRefreshFrequencyLimitRedisScript;
public RangeRefreshFrequencyLimiter(StringRedisTemplate stringRedisTemplate) {
this.stringRedisTemplate = stringRedisTemplate;
// 范围刷新的限频脚本
DefaultRedisScript<Boolean> rangeRefreshFrequencyLimitRedisScript = new DefaultRedisScript<>();
rangeRefreshFrequencyLimitRedisScript.setResultType(Boolean.class);
rangeRefreshFrequencyLimitRedisScript.setScriptText(RANGE_REFRESH_FREQUENCY_LIMIT_LUA);
this.rangeRefreshFrequencyLimitRedisScript = rangeRefreshFrequencyLimitRedisScript;
}
/**
* 查询一个键是否被允许操作,范围刷新
*
* 如短信验证码服务使用 isAllowed("15333333333", 10, 1, TimeUnit.DAYS); 表示一天只能发送10次短信验证码
* 如短信验证码服务使用 isAllowed("15333333333", 1, 1, TimeUnit.MINUTES); 表示一分钟只能发送1次短信验证码
*
* @param key 需要限频的键,key的格式最好是 {服务名}:{具体业务名}:{唯一标识符},如 sms:auth-code:15333333333
* @param frequency 频率
* @param time 限频时间
* @param unit 时间单位
* @return 是否允许
*/
public boolean isAllowed(String key, long frequency, long time, TimeUnit unit) {
key = RANGE_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX + key;
return stringRedisTemplate.execute(rangeRefreshFrequencyLimitRedisScript, Collections.singletonList(key),
String.valueOf(frequency), String.valueOf(TimeoutUtils.toMillis(time, unit)));
}
}
固定时间点刷新限频器
这里添加了 Redis key 前缀同时把时间点通过 cron 表达式表示,使用时只需要编写 cron 表达式即可。
cron 表达式可以参考文章cron表达式详解和根据CronSequenceGenerator计算cron表达式的时间。
/**
* 固定时间点刷新限频器
*
* @author xhsf
* @create 2020/12/18 15:41
*/
public class FixedPointRefreshFrequencyLimiter implements FrequencyLimiter{
/**
* 固定时间点刷新的限频 Redis Key 前缀
*/
private static final String FIXED_POINT_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX =
"frequency-limit:fixed-point-refresh:";
/**
* 固定时间点刷新的限频 lua 脚本
*/
private static final String FIXED_POINT_REFRESH_FREQUENCY_LIMIT_LUA =
"-- 若频率为0直接拒绝请求\n" +
"if tonumber(ARGV[1]) == 0 then\n" +
" return false\n" +
"end\n" +
"\n" +
"-- 获取 key 对应的 token 数量\n" +
"local tokenNumbers = redis.call('GET', KEYS[1])\n" +
"\n" +
"-- 如果对应 key 存在,且数量大于等于 frequency,直接返回 false\n" +
"if tokenNumbers and tonumber(tokenNumbers) >= tonumber(ARGV[1]) then\n" +
" return false\n" +
"end\n" +
"\n" +
"-- 增加token数量,设置过期时间\n" +
"redis.call('INCR', KEYS[1])\n" +
"redis.call('PEXPIRE', KEYS[1], ARGV[2])\n" +
"return true";
/**
* StringRedisTemplate
*/
private final StringRedisTemplate stringRedisTemplate;
/**
* cron 表达式的 CronSequenceGenerator 的缓存
*/
private final Map<String, CronSequenceGenerator> cronSequenceGeneratorMap = new HashMap<>();
/**
* 固定时间点刷新的限频脚本
*/
private final RedisScript<Boolean> fixedPointRefreshFrequencyLimitRedisScript;
public FixedPointRefreshFrequencyLimiter(StringRedisTemplate stringRedisTemplate) {
this.stringRedisTemplate = stringRedisTemplate;
// 固定时间点刷新的限频脚本
DefaultRedisScript<Boolean> fixedPointRefreshFrequencyLimitRedisScript = new DefaultRedisScript<>();
fixedPointRefreshFrequencyLimitRedisScript.setResultType(Boolean.class);
fixedPointRefreshFrequencyLimitRedisScript.setScriptText(FIXED_POINT_REFRESH_FREQUENCY_LIMIT_LUA);
this.fixedPointRefreshFrequencyLimitRedisScript = fixedPointRefreshFrequencyLimitRedisScript;
}
/**
* 查询一个键是否被允许操作,固定时间点刷新
*
* 如短信验证码服务使用 isAllowed("15333333333", 10, 1, TimeUnit.DAYS); 表示一天只能发送10次短信验证码
* 如短信验证码服务使用 isAllowed("15333333333", 1, 1, TimeUnit.MINUTES); 表示一分钟只能发送1次短信验证码
*
* @param key 需要限频的键,key的格式最好是 {服务名}:{具体业务名}:{唯一标识符},如 sms:auth-code:15333333333
* @param frequency 频率
* @param cron cron表达式
* @return 是否允许
*/
public boolean isAllowed(String key, long frequency, String cron) {
CronSequenceGenerator cronSequenceGenerator = cronSequenceGeneratorMap.getOrDefault(
cron, cronSequenceGeneratorMap.put(cron, new CronSequenceGenerator(cron)));
Date now = new Date();
Date next = cronSequenceGenerator.next(now);
long timeout = next.getTime() - now.getTime();
key = FIXED_POINT_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX + key;
return stringRedisTemplate.execute(fixedPointRefreshFrequencyLimitRedisScript, Collections.singletonList(key),
String.valueOf(frequency), String.valueOf(timeout));
}
}
使用 FrequencyLimiterManager 管理多个限频器
/**
* 限频器管理器
*
* @author xhsf
* @create 2020/12/18 15:41
*/
public class FrequencyLimiterManager implements FrequencyLimiter {
private final FixedPointRefreshFrequencyLimiter fixedPointRefreshFrequencyLimiter;
private final RangeRefreshFrequencyLimiter rangeRefreshFrequencyLimiter;
public FrequencyLimiterManager(StringRedisTemplate stringRedisTemplate) {
this.fixedPointRefreshFrequencyLimiter = new FixedPointRefreshFrequencyLimiter(stringRedisTemplate);
this.rangeRefreshFrequencyLimiter = new RangeRefreshFrequencyLimiter(stringRedisTemplate);
}
/**
* 查询一个键是否被允许操作,范围刷新
*
* 如短信验证码服务使用 isAllowed("15333333333", 10, 1, TimeUnit.DAYS); 表示一天只能发送10次短信验证码
* 如短信验证码服务使用 isAllowed("15333333333", 1, 1, TimeUnit.MINUTES); 表示一分钟只能发送1次短信验证码
*
* @param key 需要限频的键,key的格式最好是 {服务名}:{具体业务名}:{唯一标识符},如 sms:auth-code:15333333333
* @param frequency 频率
* @param time 限频时间
* @param unit 时间单位
* @return 是否允许
*/
public boolean isAllowed(String key, long frequency, long time, TimeUnit unit) {
return rangeRefreshFrequencyLimiter.isAllowed(key, frequency, time, unit);
}
/**
* 查询一个键是否被允许操作,固定时间点刷新
*
* 如短信验证码服务使用 isAllowed("15333333333", 10, 1, TimeUnit.DAYS); 表示一天只能发送10次短信验证码
* 如短信验证码服务使用 isAllowed("15333333333", 1, 1, TimeUnit.MINUTES); 表示一分钟只能发送1次短信验证码
*
* @param key 需要限频的键,key的格式最好是 {服务名}:{具体业务名}:{唯一标识符},如 sms:auth-code:15333333333
* @param frequency 频率
* @param cron cron表达式
* @return 是否允许
*/
public boolean isAllowed(String key, long frequency, String cron) {
return fixedPointRefreshFrequencyLimiter.isAllowed(key, frequency, cron);
}
}
使用示例
// 每天只能发送10次,每天晚上00:00:00刷新
frequencyLimiter.isAllowed("sms:auth-code:15333333333", 10, "0 0 0 * * *");
// 一分钟只能发送一次
frequencyLimiter.isAllowed("sms:auth-code:15333333333", 1, 1, TimeUnit.MINUTES);
扩展2:封装成注解形式
使用注解可以降低对业务代码的入侵,提高易用性,这里实现了 EL 表达式,同时支持多个注解,可以满足大部分需求。
EL 表达式及解析器
可以参考文章SpEL你感兴趣的实现原理浅析spring-expression。
注解 FrequencyLimit
/**
* 描述: 限频注解,由 {@link FrequencyLimitAspect} 实现
*
* @author xhsf
* @create 2020-12-18 21:16
*/
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Repeatable(FrequencyLimits.class)
public @interface FrequencyLimit {
/**
* 限频的 key,支持 EL 表达式,如#{#user.phone}
* @see #parameters() 配合该参数,支持占位符填充
* 如 value = "user:{0}:phone", parameters="#{#user.id}" 会转换成 value = "user:#{#user.id}:phone"
*/
String value();
/**
* 填充到占位符的参数
*/
String[] parameters() default {};
/**
* cron 表达式
* 1.cron一共有7位,但是最后一位是年(1970-2099),可以留空,所以我们可以写6位,按顺序依次为
*
* 秒(0~59)
* 分钟(0~59)
* 小时(0~23)
* 天(月)(0~31,但是你需要考虑你月的天数)
* 月(0~11)
* 星期(1~7 1=SUN,MON,TUE,WED,THU,FRI,SAT)
*
* cron的一些特殊符号
* (*)星号:
* 可以理解为每的意思,每秒,每分,每天,每月,每年
* (?)问号:
* 问号只能出现在日期和星期这两个位置,表示这个位置的值不确定,每天3点执行,所以第六位星期的位置,我们是不需要关注的,
* 就是不确定的值。同时:日期和星期是两个相互排斥的元素,通过问号来表明不指定值。比如,1月10日,比如是星期1,
* 如果在星期的位置是另指定星期二,就前后冲突矛盾了。
* (-)减号:
* 表达一个范围,如在小时字段中使用“10-12”,则表示从10到12点,即10,11,12
* (,)逗号:
* 表达一个列表值,如在星期字段中使用“1,2,4”,则表示星期一,星期二,星期四
* (/)斜杠:如:x/y,x是开始值,y是步长,比如在第一位(秒) 0/15就是,从0秒开始,每15秒,最后就是0,15,30,45,60
* 另: *\/y,等同于0/y
*
* @see org.springframework.scheduling.support.CronSequenceGenerator
*/
String cron() default "";
/**
* 频率,默认0
*/
long frequency() default 0;
/**
* 限频时间,默认0
*/
long time() default 0;
/**
* 时间单位,默认为秒
*/
TimeUnit unit() default TimeUnit.SECONDS;
/**
* 当获取 token 失败时的错误信息,支持 EL 表达式
*/
String errorMessage() default "Too many request.";
}
注解 FrequencyLimits
/**
* 描述: 限频注解数组
*
* @author xhsf
* @create 2020-12-18 21:16
*/
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface FrequencyLimits {
/**
* 限频的注解数组
*/
FrequencyLimit[] value();
}
FrequencyLimitAspect 切面实现类
/**
* 描述:限频切面,配合 {@link FrequencyLimit} 可以便捷的使用限频
*
* @author xhsf
* @create 2020/12/18 21:10
*/
@Aspect
public class FrequencyLimitAspect {
private final FrequencyLimiter frequencyLimiter;
/**
* EL 表达式解析器
*/
private static final ExpressionParser expressionParser = new SpelExpressionParser();
public FrequencyLimitAspect(FrequencyLimiter frequencyLimiter) {
this.frequencyLimiter = frequencyLimiter;
}
/**
* 限制请求频率
*
* @errorCode TooManyRequests: 请求太频繁
*
* @param joinPoint ProceedingJoinPoint
* @return Object
*/
@Around("@annotation(FrequencyLimits) || @annotation(FrequencyLimit)")
public Object handler(ProceedingJoinPoint joinPoint) throws Throwable {
FrequencyLimit[] frequencyLimits = getFrequencyLimits(joinPoint);
for (FrequencyLimit frequencyLimit : frequencyLimits) {
boolean isAllowed = isAllowed(joinPoint, frequencyLimit);
if (!isAllowed) {
String errorMessageExpression = frequencyLimit.errorMessage();
String errorMessage = getExpressionValue(errorMessageExpression, joinPoint);
return Result.fail(ErrorCodeEnum.TOO_MANY_REQUESTS, errorMessage);
}
}
// 执行业务逻辑
return joinPoint.proceed();
}
/**
* 是否允许
*
* @param joinPoint ProceedingJoinPoint
* @param frequencyLimit FrequencyLimit
* @return 是否允许
*/
private boolean isAllowed(ProceedingJoinPoint joinPoint, FrequencyLimit frequencyLimit) {
// 获取键
String key = getKey(joinPoint, frequencyLimit);
// 是否允许
if (frequencyLimit.cron().equals("")) {
return frequencyLimiter.isAllowed(
key, frequencyLimit.frequency(), frequencyLimit.time(),frequencyLimit.unit());
}
return frequencyLimiter.isAllowed(key, frequencyLimit.frequency(), frequencyLimit.cron());
}
/**
* 获取限频注解列表
*
* @param joinPoint ProceedingJoinPoint
* @return FrequencyLimit[]
*/
private FrequencyLimit[] getFrequencyLimits(ProceedingJoinPoint joinPoint) {
Method method;
try {
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
method = joinPoint.getTarget()
.getClass()
.getMethod(methodSignature.getName(), methodSignature.getParameterTypes());
return method.getAnnotationsByType(FrequencyLimit.class);
} catch (NoSuchMethodException ignored) {
}
return new FrequencyLimit[]{};
}
/**
* 获取限频的键
*
* @param joinPoint ProceedingJoinPoint
* @param frequencyLimit FrequencyLimit
* @return 限频键
*/
private String getKey(ProceedingJoinPoint joinPoint, FrequencyLimit frequencyLimit) {
// 获得键表达式模式
String keyExpressionPattern = frequencyLimit.value();
// 获取参数
Object[] parameters = frequencyLimit.parameters();
// 构造键表达式
String keyExpression = keyExpressionPattern;
// 若有参数则需要填充参数
if (parameters.length > 0) {
keyExpression = MessageFormat.format(keyExpressionPattern, parameters);
}
// 获取键
return getExpressionValue(keyExpression, joinPoint);
}
/**
* 获取表达式的值
*
* @param expression 表达式
* @param joinPoint ProceedingJoinPoint
* @return value
*/
private String getExpressionValue(String expression, ProceedingJoinPoint joinPoint) {
// 获得方法参数的 Map
String[] parameterNames = ((CodeSignature) joinPoint.getSignature()).getParameterNames();
Object[] parameterValues = joinPoint.getArgs();
Map<String, Object> parameterMap = new HashMap<>();
for (int i = 0; i < parameterNames.length; i++) {
parameterMap.put(parameterNames[i], parameterValues[i]);
}
// 解析 EL 表达式
return getExpressionValue(expression, parameterMap);
}
/**
* 获取 EL 表达式的值
*
* @param elExpression EL 表达式
* @param parameterMap 参数名-值 Map
* @return 表达式的值
*/
private String getExpressionValue(String elExpression, Map<String, Object> parameterMap) {
Expression expression = expressionParser.parseExpression(elExpression, new TemplateParserContext());
EvaluationContext context = new StandardEvaluationContext();
for (Map.Entry<String, Object> entry : parameterMap.entrySet()) {
context.setVariable(entry.getKey(), entry.getValue());
}
return expression.getValue(context, String.class);
}
}
使用注解示例
这里限制邮件验证码一分钟1次,每小时5次,每小时0分0秒时刷新。
@FrequencyLimit(value = "email:auth-code:{0}", parameters = "#{#createAndSendEmailAuthCodePO.email}", frequency = 1, time = 60)
@FrequencyLimit(value = "email:auth-code:{0}", parameters = "#{#createAndSendEmailAuthCodePO.email}", frequency = 5, cron = "0 0 0/1 * * ?")
@Override
public Result<Void> createAndSendEmailAuthCode(CreateAndSendEmailAuthCodePO createAndSendEmailAuthCodePO) {
// 业务代码
}
记得注册切面
/**
* 限频器
*
* @param stringRedisTemplate StringRedisTemplate
* @return FrequencyLimiter
*/
@Bean
public FrequencyLimiter frequencyLimiter(StringRedisTemplate stringRedisTemplate) {
return new FrequencyLimiter(stringRedisTemplate);
}
/**
* 限频切面
*
* @param frequencyLimiter FrequencyLimiter
* @return FrequencyLimitAspect
*/
@Bean
public FrequencyLimitAspect frequencyLimitAspect(FrequencyLimiter frequencyLimiter) {
return new FrequencyLimitAspect(frequencyLimiter);
}
扩展3:多次限制下会出现的问题
如果使用多次限频器进行限频,比如每分钟1次,一天10次,这时会出现一个问题:如果第一个限制通过,第二个限制没有通过,按照上面的实现,第一个限制不会被自动释放。这里第一个限制只有1分钟还好,用户只需要等待1分钟。但是如果这个限制是1个小时,那么用户必须一个小时后才可能再次尝试获取第二个限制。
因此,我们需要在多次限制的情况下,如果一个限制没有通过,需要把前面已经通过的限制给释放掉。
解决办法
我们可以把前面每次通过的限制都记录起来,然后如果某个限制没有通过,则释放掉已经通过的限制。
通过 Lua 脚本解决
虽然该 Lua 脚本比较复杂,但主要就是把前面两个 Lua 脚本整合起来,并添加失败情况下的清除代码。这里使用 tokenMap 记录已经被允许的 key: token 键值对,若出现不被允许的情况,则把 tokenMap 记录的 key: token 键值对释放掉。
--[[
KEYS[i] 需要限频的 key
ARGV[i] 频率
ARGV[#KEYS + i] 限频时间
ARGV[#KEYS * 2 + i] 限频类型
--]]
-- 记录已经获取成功的 key: token 键值对
local tokenMap = {}
-- 循环获取每个 token
for i = 1, #KEYS do
if ARGV[#KEYS * 2 + i] == 'FIXED_POINT_REFRESH' then
-- 若频率为0直接拒绝请求
if tonumber(ARGV[i]) == 0 then
break
end
-- 获取 key 对应的 token 数量
local tokenNumbers = redis.call('GET', KEYS[i])
-- 如果对应 key 存在,且数量大于等于 frequency,直接 break
if tokenNumbers and tonumber(tokenNumbers) >= tonumber(ARGV[i]) then
break
end
-- 增加token数量,设置过期时间
redis.call('INCR', KEYS[i])
redis.call('PEXPIRE', KEYS[i], ARGV[#KEYS + i])
tokenMap[KEYS[i]] = ''
else
-- 获取 key 对应的 tokens
local tokens = redis.call('SMEMBERS', KEYS[i])
-- 删除 tokens 里面所有过期的 token
local expiredTokensCount = 0
for j = 1, #tokens do
if not redis.call('GET', tokens[j]) then
redis.call('SREM', KEYS[i], tokens[j])
expiredTokensCount = expiredTokensCount + 1
end
end
-- 如果未过期的 token 数量大于等于 frequency,直接 break
local unexpiredTokensCount = #tokens - expiredTokensCount
if unexpiredTokensCount >= tonumber(ARGV[i]) then
break
end
-- 生成唯一的 token,添加 token 到 redis 里和 key 对应的 set 里,并设置过期时间
local token = redis.call('INCR', 'frequency-limit:token:increment-id')
redis.call('SET', token, '', 'PX', ARGV[#KEYS + i])
redis.call('SADD', KEYS[i], token)
redis.call('PEXPIRE', KEYS[i], ARGV[#KEYS + i])
tokenMap[KEYS[i]] = token
end
end
-- 获取 tokenMap 的大小
local tokenMapSize = 0
for key, token in pairs(tokenMap) do
tokenMapSize = tokenMapSize + 1
end
-- 判断是否获取所有 token 都成功,若失败需要释放已经获取的 token
if tokenMapSize < #KEYS then
for key, token in pairs(tokenMap) do
if token == '' then
redis.call('INCRBY', key, -1)
else
redis.call('SREM', key, token)
redis.call('DEL', token)
end
end
return false
end
return true
封装成限频器
上面的 Lua 脚本如果直接使用会特别麻烦,因此我们封装成限频器。同理,我们还是会在每个在 Redis 里生成的 key 前面添加前缀,防止冲突。
/**
* 可获取多个 key 的限频器
*
* @author xhsf
* @create 2020/12/18 15:41
*/
public class RepeatableFrequencyLimiter implements FrequencyLimiter {
/**
* 固定时间点刷新的限频 Redis Key 前缀
*/
private static final String FIXED_POINT_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX =
"frequency-limit:fixed-point-refresh:";
/**
* 范围刷新的限频 Redis Key 前缀
*/
private static final String RANGE_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX = "frequency-limit:range-fresh:";
/**
* 可同时获取多个 key 的限频 Lua 脚本
*/
private static final String REPEATABLE_FREQUENCY_LIMIT_LUA =
"--[[\n" +
"KEYS[i] 需要限频的 key\n" +
"ARGV[i] 频率\n" +
"ARGV[#KEYS + i] 限频时间\n" +
"ARGV[#KEYS * 2 + i] 限频类型\n" +
"--]]\n" +
"\n" +
"-- 记录已经获取成功的 key: token 键值对\n" +
"local tokenMap = {}\n" +
"\n" +
"-- 循环获取每个 token\n" +
"for i = 1, #KEYS do\n" +
" if ARGV[#KEYS * 2 + i] == 'FIXED_POINT_REFRESH' then\n" +
" -- 若频率为0直接拒绝请求\n" +
" if tonumber(ARGV[i]) == 0 then\n" +
" break\n" +
" end\n" +
"\n" +
" -- 获取 key 对应的 token 数量\n" +
" local tokenNumbers = redis.call('GET', KEYS[i])\n" +
"\n" +
" -- 如果对应 key 存在,且数量大于等于 frequency,直接 break\n" +
" if tokenNumbers and tonumber(tokenNumbers) >= tonumber(ARGV[i]) then\n" +
" break\n" +
" end\n" +
"\n" +
" -- 增加token数量,设置过期时间\n" +
" redis.call('INCR', KEYS[i])\n" +
" redis.call('PEXPIRE', KEYS[i], ARGV[#KEYS + i])\n" +
" tokenMap[KEYS[i]] = ''\n" +
" else\n" +
" -- 获取 key 对应的 tokens\n" +
" local tokens = redis.call('SMEMBERS', KEYS[i])\n" +
"\n" +
" -- 删除 tokens 里面所有过期的 token\n" +
" local expiredTokensCount = 0\n" +
" for j = 1, #tokens do\n" +
" if not redis.call('GET', tokens[j]) then\n" +
" redis.call('SREM', KEYS[i], tokens[j])\n" +
" expiredTokensCount = expiredTokensCount + 1\n" +
" end\n" +
" end\n" +
"\n" +
" -- 如果未过期的 token 数量大于等于 frequency,直接 break\n" +
" local unexpiredTokensCount = #tokens - expiredTokensCount\n" +
" if unexpiredTokensCount >= tonumber(ARGV[i]) then\n" +
" break\n" +
" end\n" +
"\n" +
" -- 生成唯一的 token,添加 token 到 redis 里和 key 对应的 set 里,并设置过期时间\n" +
" local token = '" + RANGE_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX + "' .. redis.call('INCR', 'frequency-limit:token:increment-id')\n" +
" redis.call('SET', token, '', 'PX', ARGV[#KEYS + i])\n" +
" redis.call('SADD', KEYS[i], token)\n" +
" redis.call('PEXPIRE', KEYS[i], ARGV[#KEYS + i])\n" +
" tokenMap[KEYS[i]] = token\n" +
" end\n" +
"end\n" +
"\n" +
"-- 获取 tokenMap 的大小\n" +
"local tokenMapSize = 0\n" +
"for key, token in pairs(tokenMap) do\n" +
"\ttokenMapSize = tokenMapSize + 1\n" +
"end\n" +
"\n" +
"-- 判断是否获取所有 token 都成功,若失败需要释放已经获取的 token\n" +
"if tokenMapSize < #KEYS then\n" +
" for key, token in pairs(tokenMap) do\n" +
" if token == '' then\n" +
" redis.call('INCRBY', key, -1)\n" +
" else\n" +
" redis.call('SREM', key, token)\n" +
" redis.call('DEL', token)\n" +
" end\n" +
" end\n" +
" return false\n" +
"end\n" +
"\n" +
"return true";
/**
* StringRedisTemplate
*/
private final StringRedisTemplate stringRedisTemplate;
/**
* 可获取多个 key 的限频脚本
*/
private final RedisScript<Boolean> repeatableFrequencyLimitRedisScript;
public RepeatableFrequencyLimiter(StringRedisTemplate stringRedisTemplate) {
this.stringRedisTemplate = stringRedisTemplate;
// 可获取多个 key 的限频脚本
DefaultRedisScript<Boolean> repeatableFrequencyLimitRedisScript = new DefaultRedisScript<>();
repeatableFrequencyLimitRedisScript.setResultType(Boolean.class);
repeatableFrequencyLimitRedisScript.setScriptText(REPEATABLE_FREQUENCY_LIMIT_LUA);
this.repeatableFrequencyLimitRedisScript = repeatableFrequencyLimitRedisScript;
}
/**
* 查询多个键是否被允许操作
*
* 只要其中一个不被允许,就会失败,并释放已经获取的 tokens
*
* @param frequencyLimiterTypes 限频类型
* @param keys 需要限频的键
* @param frequencies 频率
* @param timeouts 过期时间
* @return 是否允许
*/
public boolean isAllowed(FrequencyLimiterType[] frequencyLimiterTypes, List<String> keys, long[] frequencies,
long[] timeouts) {
String[] args = new String[frequencyLimiterTypes.length * 3];
for (int i = 0; i < frequencies.length; i++) {
args[i] = String.valueOf(frequencies[i]);
}
for (int i = 0; i < timeouts.length; i++) {
args[frequencyLimiterTypes.length + i] = String.valueOf(timeouts[i]);
}
for (int i = 0; i < frequencyLimiterTypes.length; i++) {
args[frequencyLimiterTypes.length * 2 + i] = frequencyLimiterTypes[i].name();
}
for (int i = 0; i < keys.size(); i++) {
if (frequencyLimiterTypes[i] == FrequencyLimiterType.FIXED_POINT_REFRESH) {
keys.set(i, FIXED_POINT_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX + timeouts[i] + ":" + keys.get(i));
} else {
keys.set(i, RANGE_REFRESH_FREQUENCY_LIMIT_REDIS_KEY_PREFIX + timeouts[i] + ":" + keys.get(i));
}
}
return stringRedisTemplate.execute(repeatableFrequencyLimitRedisScript, keys, args);
}
}
封装成切面
由于上面的限频器的参数都是面向 Redis Lua 脚本的,因此使用起来和前面两个限频器相比还是很繁琐,因此最好使用切面再次降低使用的复杂度。这里的切面实现与上面的类似,只是把 FrequencyLimit 数组的参数转换格式,封装成数组,然后调用限频器的 isAllowed() 方法。
/**
* 描述:限频切面,配合 {@link FrequencyLimit} 可以便捷的使用限频
*
* @author xhsf
* @create 2020/12/18 21:10
*/
@Aspect
public class FrequencyLimitAspect {
/**
* 限频器
*/
private final FrequencyLimiter frequencyLimiter;
/**
* EL 表达式解析器
*/
private static final ExpressionParser expressionParser = new SpelExpressionParser();
/**
* cron 表达式的 CronSequenceGenerator 的缓存
*/
private final Map<String, CronSequenceGenerator> cronSequenceGeneratorMap = new HashMap<>();
public FrequencyLimitAspect(FrequencyLimiter frequencyLimiter) {
this.frequencyLimiter = frequencyLimiter;
}
/**
* 限制请求频率
*
* @errorCode TooManyRequests: 请求太频繁
*
* @param joinPoint ProceedingJoinPoint
* @return Object
*/
@Around("@annotation(FrequencyLimits) || @annotation(FrequencyLimit)")
public Object handler(ProceedingJoinPoint joinPoint) throws Throwable {
FrequencyLimit[] frequencyLimits = getFrequencyLimits(joinPoint);
if (!isAllowed(joinPoint, frequencyLimits)) {
String errorMessageExpression = frequencyLimits[0].errorMessage();
String errorMessage = getExpressionValue(errorMessageExpression, joinPoint);
return Result.fail(ErrorCodeEnum.TOO_MANY_REQUESTS, errorMessage);
}
// 执行业务逻辑
return joinPoint.proceed();
}
/**
* 是否允许
*
* @param joinPoint ProceedingJoinPoint
* @param frequencyLimits FrequencyLimit[]
* @return 是否允许
*/
private boolean isAllowed(ProceedingJoinPoint joinPoint, FrequencyLimit[] frequencyLimits) {
FrequencyLimiterType[] frequencyLimiterTypes = new FrequencyLimiterType[frequencyLimits.length];
List<String> keys = new ArrayList<>(frequencyLimits.length);
long[] frequencies = new long[frequencyLimits.length];
long[] timeouts = new long[frequencyLimits.length];
for (int i = 0; i < frequencyLimits.length; i++) {
if (frequencyLimits[i].cron().equals("")) {
frequencyLimiterTypes[i] = FrequencyLimiterType.RANGE_REFRESH;
timeouts[i] = TimeoutUtils.toMillis(frequencyLimits[i].time(), frequencyLimits[i].unit());
} else {
frequencyLimiterTypes[i] = FrequencyLimiterType.FIXED_POINT_REFRESH;
String cron = frequencyLimits[i].cron();
CronSequenceGenerator cronSequenceGenerator = cronSequenceGeneratorMap.getOrDefault(
cron, cronSequenceGeneratorMap.put(cron, new CronSequenceGenerator(cron)));
Date now = new Date();
Date next = cronSequenceGenerator.next(now);
timeouts[i] = next.getTime() - now.getTime();
}
keys.add(getKey(joinPoint, frequencyLimits[i]));
frequencies[i] = frequencyLimits[i].frequency();
}
return frequencyLimiter.isAllowed(frequencyLimiterTypes, keys, frequencies, timeouts);
}
/**
* 获取限频注解列表
*
* @param joinPoint ProceedingJoinPoint
* @return FrequencyLimit[]
*/
private FrequencyLimit[] getFrequencyLimits(ProceedingJoinPoint joinPoint) {
Method method;
try {
MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
method = joinPoint.getTarget()
.getClass()
.getMethod(methodSignature.getName(), methodSignature.getParameterTypes());
return method.getAnnotationsByType(FrequencyLimit.class);
} catch (NoSuchMethodException ignored) {
}
return new FrequencyLimit[]{};
}
/**
* 获取限频的键
*
* @param joinPoint ProceedingJoinPoint
* @param frequencyLimit FrequencyLimit
* @return 限频键
*/
private String getKey(ProceedingJoinPoint joinPoint, FrequencyLimit frequencyLimit) {
// 获得键表达式模式
String keyExpressionPattern = frequencyLimit.value();
// 获取参数
Object[] parameters = frequencyLimit.parameters();
// 构造键表达式
String keyExpression = keyExpressionPattern;
// 若有参数则需要填充参数
if (parameters.length > 0) {
keyExpression = MessageFormat.format(keyExpressionPattern, parameters);
}
// 获取键
return getExpressionValue(keyExpression, joinPoint);
}
/**
* 获取表达式的值
*
* @param expression 表达式
* @param joinPoint ProceedingJoinPoint
* @return value
*/
private String getExpressionValue(String expression, ProceedingJoinPoint joinPoint) {
// 获得方法参数的 Map
String[] parameterNames = ((CodeSignature) joinPoint.getSignature()).getParameterNames();
Object[] parameterValues = joinPoint.getArgs();
Map<String, Object> parameterMap = new HashMap<>();
for (int i = 0; i < parameterNames.length; i++) {
parameterMap.put(parameterNames[i], parameterValues[i]);
}
// 解析 EL 表达式
return getExpressionValue(expression, parameterMap);
}
/**
* 获取 EL 表达式的值
*
* @param elExpression EL 表达式
* @param parameterMap 参数名-值 Map
* @return 表达式的值
*/
private String getExpressionValue(String elExpression, Map<String, Object> parameterMap) {
Expression expression = expressionParser.parseExpression(elExpression, new TemplateParserContext());
EvaluationContext context = new StandardEvaluationContext();
for (Map.Entry<String, Object> entry : parameterMap.entrySet()) {
context.setVariable(entry.getKey(), entry.getValue());
}
return expression.getValue(context, String.class);
}
}
使用示例
与前面完全相同。
@FrequencyLimit(value = "email:auth-code:{0}", parameters = "#{#createAndSendEmailAuthCodePO.email}", frequency = 1, time = 60)
@FrequencyLimit(value = "email:auth-code:{0}", parameters = "#{#createAndSendEmailAuthCodePO.email}", frequency = 5, cron = "0 0 0/1 * * ?")
@Override
public Result<Void> createAndSendEmailAuthCode(CreateAndSendEmailAuthCodePO createAndSendEmailAuthCodePO) {
// 业务代码
}
扩展4:返回失败的 key
上面的实现方式还有一个问题,就是不知道具体是获取哪个 key 的 token 时失败了。我们可以让 Lua 脚本返回失败时的下标,成功时返回 -1,Java 代码里就可以通过下标获取具体是哪个限频操作失败了。
Lua 代码实现
只是把返回值修改成了 tokenMapSize 和 -1。
--[[
KEYS[i] 需要限频的 key
ARGV[i] 频率
ARGV[#KEYS + i] 限频时间
ARGV[#KEYS * 2 + i] 限频类型
--]]
-- 记录已经获取成功的 key: token 键值对
local tokenMap = {}
-- 循环获取每个 token
for i = 1, #KEYS do
if ARGV[#KEYS * 2 + i] == 'FIXED_POINT_REFRESH' then
-- 若频率为0直接拒绝请求
if tonumber(ARGV[i]) == 0 then
break
end
-- 获取 key 对应的 token 数量
local tokenNumbers = redis.call('GET', KEYS[i])
-- 如果对应 key 存在,且数量大于等于 frequency,直接 break
if tokenNumbers and tonumber(tokenNumbers) >= tonumber(ARGV[i]) then
break
end
-- 增加token数量,设置过期时间
redis.call('INCR', KEYS[i])
redis.call('PEXPIRE', KEYS[i], ARGV[#KEYS + i])
tokenMap[KEYS[i]] = ''
else
-- 获取 key 对应的 tokens
local tokens = redis.call('SMEMBERS', KEYS[i])
-- 删除 tokens 里面所有过期的 token
local expiredTokensCount = 0
for j = 1, #tokens do
if not redis.call('GET', tokens[j]) then
redis.call('SREM', KEYS[i], tokens[j])
expiredTokensCount = expiredTokensCount + 1
end
end
-- 如果未过期的 token 数量大于等于 frequency,直接 break
local unexpiredTokensCount = #tokens - expiredTokensCount
if unexpiredTokensCount >= tonumber(ARGV[i]) then
break
end
-- 生成唯一的 token,添加 token 到 redis 里和 key 对应的 set 里,并设置过期时间
local token = redis.call('INCR', 'frequency-limit:token:increment-id')
redis.call('SET', token, '', 'PX', ARGV[#KEYS + i])
redis.call('SADD', KEYS[i], token)
redis.call('PEXPIRE', KEYS[i], ARGV[#KEYS + i])
tokenMap[KEYS[i]] = token
end
end
-- 获取 tokenMap 的大小
local tokenMapSize = 0
for key, token in pairs(tokenMap) do
tokenMapSize = tokenMapSize + 1
end
-- 判断是否获取所有 token 都成功,若失败需要释放已经获取的 token
if tokenMapSize < #KEYS then
for key, token in pairs(tokenMap) do
if token == '' then
redis.call('INCRBY', key, -1)
else
redis.call('SREM', key, token)
redis.call('DEL', token)
end
end
return tokenMapSize
end
return -1
时间复杂度分析
范围刷新限频器
由于每次都需要执行 SMEMBERS 命令,因此时间复杂度是 O(N),N 为 Set 里面的元素数量,最坏情况下 N = frequency。因此若 frequency 很大,该方案速度可能会比较慢,适合 frequency 不是特别大的情况。
固定时间点刷新限频器
时间复杂度为O(1),只使用了 GET, INCR, PEXPIRE。