SpringBoot+Redis限流

4 阅读2分钟

定义限流注解

/**
 * Declares a rate limit: at most {@link #count()} calls are allowed within
 * {@link #period()}.
 *
 * <p>The annotation is retained at runtime so it can be read reflectively.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface RedisLimitAnnotation {

    /**
     * Custom key for this limit.
     *
     * @return the key, or an empty string when none is given
     */
    String key() default "";

    /**
     * Maximum number of calls allowed within one {@link #period()}.
     *
     * @return the call limit
     */
    int count();

    /**
     * Length of the limiting window, e.g. "only {@code count} calls within 5 seconds".
     *
     * @return the window length
     */
    int period();
}

配置类

/**
 * Registers the Redis beans used for rate limiting: the Lua script loaded from
 * {@code limit.lua} and a {@code RedisTemplate} with String keys and JSON values.
 *
 * <p>Declared as {@code @Configuration} rather than {@code @Component} so the
 * {@code @Bean} methods are processed as a regular configuration class instead
 * of in "lite" mode.
 */
@Configuration
public class RedisConfiguration {

    /**
     * Builds the rate-limit script from the classpath resource {@code limit.lua}.
     *
     * @return a script definition whose result type is {@link Number}
     */
    @Bean
    public DefaultRedisScript<Number> redisLuaScript() {
        DefaultRedisScript<Number> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("limit.lua")));
        redisScript.setResultType(Number.class);
        return redisScript;
    }

    /**
     * Builds the {@code redisTemplate} bean on top of the given connection factory.
     *
     * @param redisConnectionFactory connection factory the template is bound to
     * @return a template that writes keys and hash keys as plain strings and
     *         values and hash values as JSON
     */
    @Bean("redisTemplate")
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);

        // Parameterized instead of raw, so no unchecked warnings are produced.
        Jackson2JsonRedisSerializer<Object> jsonSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        // One serializer instance is shared by keys and hash keys.
        StringRedisSerializer stringSerializer = new StringRedisSerializer();

        // Keys and hash keys are stored as strings.
        redisTemplate.setKeySerializer(stringSerializer);
        redisTemplate.setHashKeySerializer(stringSerializer);
        // Values and hash values are stored as JSON.
        redisTemplate.setValueSerializer(jsonSerializer);
        redisTemplate.setHashValueSerializer(jsonSerializer);
        redisTemplate.afterPropertiesSet();

        return redisTemplate;
    }

}

AOP拦截

/**
 * Around-advice that rate-limits methods annotated with {@link RedisLimitAnnotation}.
 *
 * <p>For each intercepted call a key is built from the client IP, the declaring
 * class, the method name and the annotation's {@code key()}, and passed to the
 * injected Lua script together with the configured count and period. The call
 * proceeds only when the script returns a non-zero value that does not exceed
 * the configured count.
 */
@Aspect
@Configuration
public class LimitRestAspect {

    private static final Logger logger = LoggerFactory.getLogger(LimitRestAspect.class);

    /** Proxy headers consulted, in this order, before falling back to the remote address. */
    private static final String[] CLIENT_IP_HEADERS = {
        "x-forwarded-for", "Proxy-Client-IP", "WL-Proxy-Client-IP"
    };

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @Autowired
    private DefaultRedisScript<Number> redisluaScript;


    /** Matches methods carrying {@code RedisLimitAnnotation}; the body is intentionally empty. */
    @Pointcut(value = "@annotation(cn.ccb.demo.redis.annoation.RedisLimitAnnotation)")
    public void rateLimit() {

    }

    /**
     * Runs the rate-limit script and either lets the call through or rejects it.
     *
     * @param joinPoint the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws RuntimeException when the script result is {@code null}, zero, or
     *         greater than the configured count
     * @throws Throwable anything thrown by the intercepted method
     */
    @Around("rateLimit()")
    public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        RedisLimitAnnotation rateLimit = method.getAnnotation(RedisLimitAnnotation.class);

        if (rateLimit == null) {
            // The annotation is not visible on the resolved method: nothing to enforce.
            return joinPoint.proceed();
        }

        // All parts are joined with a plain "-" (the previous "- " after the class
        // name left a stray space inside the key).
        String limitKey = String.join("-",
                currentClientIp(),
                method.getDeclaringClass().getName(),
                method.getName(),
                rateLimit.key());
        List<String> keys = Collections.singletonList(limitKey);

        // The script returns the number of calls made so far, or 0 once the limit is hit.
        Number number = redisTemplate.execute(
                redisluaScript,
                keys,
                rateLimit.count(),
                rateLimit.period()
        );
        if (number != null && number.intValue() != 0 && number.intValue() <= rateLimit.count()) {
            logger.info("限流时间段内访问了第:{} 次", number);
            return joinPoint.proceed();
        }

        throw new RuntimeException("访问频率过快,被限流了");
    }

    /**
     * Resolves the client IP of the request bound to the current thread.
     *
     * @return the client IP, or an empty string when no servlet request is bound
     *         (all such calls then share one key instead of failing with a
     *         {@code NullPointerException})
     */
    private static String currentClientIp() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (attributes instanceof ServletRequestAttributes) {
            return getIpAddr(((ServletRequestAttributes) attributes).getRequest());
        }
        return "";
    }

    /**
     * Extracts the client IP from the proxy headers, falling back to the remote address.
     *
     * <p>NOTE(review): these headers are supplied by the client unless a trusted
     * proxy overwrites them, so a caller can vary them to obtain a fresh limit key.
     *
     * @param request the current HTTP request
     * @return the first address of the resolved value, trimmed; an empty string
     *         when nothing could be resolved
     */
    private static String getIpAddr(HttpServletRequest request) {
        try {
            String ipAddress = null;
            for (String header : CLIENT_IP_HEADERS) {
                ipAddress = request.getHeader(header);
                if (isUsable(ipAddress)) {
                    break;
                }
            }
            if (!isUsable(ipAddress)) {
                ipAddress = request.getRemoteAddr();
            }
            if (ipAddress == null) {
                return "";
            }
            // With several proxies the value is a comma-separated chain whose first
            // entry is the client. Split regardless of length: "1.1.1.1,2.2.2.2" is
            // only 15 characters and was previously left unsplit.
            int comma = ipAddress.indexOf(',');
            if (comma >= 0) {
                ipAddress = ipAddress.substring(0, comma);
            }
            return ipAddress.trim();
        } catch (RuntimeException e) {
            // Best effort: keep limiting with an empty IP component, but leave a trace.
            logger.warn("Failed to resolve client IP, falling back to empty value", e);
            return "";
        }
    }

    /** Returns whether a header value is present and not the literal "unknown". */
    private static boolean isUsable(String value) {
        return value != null && value.length() != 0 && !"unknown".equalsIgnoreCase(value);
    }

}

lua脚本

在 resources 目录下新建 limit.lua 文件(配置类中通过 ClassPathResource("limit.lua") 加载该脚本):

-- Counter-based rate limit.
--   KEYS[1]  caller-supplied identifier; the counter lives under "rate.limit:" .. KEYS[1]
--   ARGV[1]  maximum number of allowed calls
--   ARGV[2]  time-to-live, in seconds, applied to the counter
-- Returns 0 when the call would exceed the limit, otherwise the new call count.
-- NOTE(review): the counter key is derived from KEYS[1] rather than passed as-is;
-- confirm this is acceptable if the script is ever run against Redis Cluster.
local bucket = "rate.limit:" .. KEYS[1]
local max_hits = tonumber(ARGV[1])
local ttl_seconds = ARGV[2]

-- A missing counter is treated as zero calls so far.
local hits = tonumber(redis.call('get', bucket) or "0")
local next_hits = hits + 1

if next_hits > max_hits then
  return 0
end

-- Still under the limit: count this call. The expiry is re-applied on every
-- allowed call (adjust to your own business needs).
redis.call("INCRBY", bucket, "1")
redis.call("expire", bucket, ttl_seconds)
return next_hits

测试访问

/**
 * Demo controller exposing one endpoint annotated with {@code RedisLimitAnnotation}.
 */
@RestController
public class RedisController {

    /**
     * Test endpoint declared as allowing one call per 10 seconds.
     *
     * @return the literal {@code "success"}
     */
    @RedisLimitAnnotation(key = "queryFromRedis", count = 1, period = 10)
    @GetMapping("/redis/limit")
    public String queryFromRedis() {
        final String body = "success";
        return body;
    }

}