Spring + Redis + Lua 脚本 + AOP 实现接口限流

692 阅读2分钟

Spring + Redis + Lua 脚本 + AOP 实现接口限流

在系统开发中,我们常常会遇见一些脑子不太好的人,为了防止他们刷尽我们的系统资源,有必要对这些人进行ip限制。

1.定义注解

/**
 * 自定义操作日志记录注解
 *
 * @author ht
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
    /**
     * 限流key
     */
    String key() default "mb:rate_limit:";
​
    /**
     * 限流时间,单位秒
     */
    int time() default 60;
​
    /**
     * 限流次数
     */
    int count() default 100;
​
    /**
     * 限流类型
     */
    LimitType limitType() default LimitType.DEFAULT;
}

2.写Lua脚本

为什么要写?由于lua脚本的原子性,会一次性执行完,不被打断

我们在resources/lua下新增一个limit.lua文件

//获取传入注解中的 key count time
local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])
//从redis中用key获取 ,call方法则是获得执行结果
local current = redis.call('get', key)
//如果能获取到,就返回当前的执行结果
if current and tonumber(current) > count then
    return tonumber(current)
end
//将结果自增
current = redis.call('incr', key)
//如果结果为1 就设置过期时间
if tonumber(current) == 1 then
    redis.call('expire', key, time)
end
return tonumber(current)

3.Spring定义bean,引入lua脚本

路径大家自己修改即可

​
    @Bean
    public DefaultRedisScript<Long> limitScript() {
        DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
        redisScript.setResultType(Long.class);
        return redisScript;
    }

4.序列化Redis

 @Bean
    @SuppressWarnings("all")
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) throws UnknownHostException {
        // 自定义 String Object
        RedisTemplate<String, Object> template = new RedisTemplate();
        template.setConnectionFactory(redisConnectionFactory);
​
        // Json 序列化配置
        Jackson2JsonRedisSerializer<Object> objectJackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<Object>(Object.class);
        // ObjectMapper 转译
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        objectJackson2JsonRedisSerializer.setObjectMapper(objectMapper);
​
        // String 的序列化
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
​
        // key 采用String的序列化方式
        template.setKeySerializer(stringRedisSerializer);
        // hash 的key也采用 String 的序列化方式
        template.setHashKeySerializer(stringRedisSerializer);
        // value 序列化方式采用 jackson
        template.setValueSerializer(objectJackson2JsonRedisSerializer);
        // hash 的 value 采用 jackson
        template.setHashValueSerializer(objectJackson2JsonRedisSerializer);
        template.afterPropertiesSet();
​
        return template;
    }

5.定义Aspect

@Aspect
@Component
@Slf4j
public class RateLimiterAspect {
​
    @Resource
    private RedisTemplate<Object, Object> redisTemplate;
​
    @Autowired
    private RedisScript<Long> limitScript;
​
    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
        String key = rateLimiter.key();
        int time = rateLimiter.time();
        int count = rateLimiter.count();
​
        String combineKey = getCombineKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(combineKey);
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            if (number == null || number.intValue() > count) {
                throw new RateLimiterException(ErrorCode.FORBIDDEN_ERROR, "访问过于频繁,请稍候再试");
            }
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), key);
    }
​
    //将ip,方法名,方法类型拼接成key
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            stringBuffer.append(IpUtils.getIpAddr(ServletUtils.getRequestAttributes().getRequest())).append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
        return stringBuffer.toString();
    }
}
​

关于如何获取Ip,方法名,方法类型可以看我的上一篇文章

6.定义全局异常拦截器接收

@RestControllerAdvice
@Slf4j
public class GlobalExceptionHandler {
    @ExceptionHandler(RateLimiterException.class)
    public BaseResponse<?> rateLimiterExceptionHandler(RuntimeException e) {
        log.error("访问过于频繁,请稍后再试");
        return ResultUtils.error(ErrorCode.SYSTEM_ERROR, "你再刷试试");
    }
}

7.总结

我们知道限流的手段有很多,比如令牌桶算法,滑动窗口算法

感兴趣的同学可以看这篇文章

juejin.cn/post/714543…

当然,如果结合异步日志记录,效果会更佳噢~

可以看看我上篇文章juejin.cn/post/720172…