基于Redis的分布式限流器(RateLimiter)

2,743 阅读2分钟

参考: Redisson 官方 Wiki(https://github.com/redisson/redisson/wiki)中“分布式对象 → 限流器(RateLimiter)”一节

基于Redis的分布式限流器(RateLimiter)可以用来在分布式环境下限制请求方的调用频率。既适用于不同Redisson实例下的多线程限流,也适用于相同Redisson实例下的多线程限流。该算法不保证公平性。

1. 限流切面

@Aspect
public class RateLimitAspect implements Ordered {

    private static final Logger ratelimitLog = LoggerFactory.getLogger("ratelimit");

    @Resource
    private RedissonClient redissonClient;

    @Before("@annotation(limitAnnotation)")
    public void limitBeforeExecute(JoinPoint joinPoint,RateLimit limitAnnotation) {
        int limit = limitAnnotation.limit();
        int perSeconds = limitAnnotation.perSeconds();

        String limitKey = buildLimitKey(limitAnnotation, joinPoint);
        if (limitKey == null || limitKey.trim().length() == 0) {
            return;
        }

        //获取限流器
        RRateLimiter rateLimiter = redissonClient.getRateLimiter(limitKey);

        // 最大流速 = 每perSeconds秒钟产生limit个令牌
        rateLimiter.trySetRate(RateType.OVERALL, limit, perSeconds, RateIntervalUnit.SECONDS);
        boolean tryAcquire = rateLimiter.tryAcquire();
        //rateLimiter.expire 要放在 rateLimiter.tryAcquire()后面才起作用
        rateLimiter.expire(perSeconds * 2L, TimeUnit.SECONDS);
        if (tryAcquire) {
            //获取到令牌结束方法
            return;
        }
        //不允许访问
        if (RateLimitTypeEnum.IP == limitAnnotation.type() && limitAnnotation.blockWhenExceed()) {
            // 若IP限流超过上限,且指定了要放入ip黑名单 TODO
            
            //String clientIp = 获取请求的ip;
            
            ratelimitLog.warn(">>>> {}超过限流{},被放入ip黑名单{}分钟", clientIp, limitKey, limitAnnotation.blockMinutes());
        }
        throw new RuntimeException(limitAnnotation.errmsg());


    }

    private String buildLimitKey(RateLimit limitAnnotation, JoinPoint joinPoint) {
        String limitKey = null;
        String key = limitAnnotation.key();
        RateLimitTypeEnum limitType = limitAnnotation.type();
        Object[] args = joinPoint.getArgs();
        switch (limitType) {
            case IP:
                String clientIp = 获取请求的ip;
                limitKey = CacheKeys.ratelimitKey(RateLimitTypeEnum.IP, key, clientIp);
                break;
            case UID:
                // 按用户登录id限流
                Integer memberId = 从缓存中获取登录的用户id;
                if (memberId != null) {
                    limitKey = ratelimitKey(RateLimitTypeEnum.UID, key, String.valueOf(memberId));
                } else {
                    ratelimitLog.warn(">> 未找到登录用户信息,限流失败: {}", joinPoint);
                }
                break;
            case POJO_FIELD:
                String[] fields = limitAnnotation.fields();
                if (fields.length == 0) {
                    ratelimitLog.warn(">> 未设置field,限流失败: {}", joinPoint);
                    break;
                }
                if (args == null || args.length == 0 || args[0] == null) {
                    ratelimitLog.warn(">> 未找到对象,限流失败: {}", joinPoint);
                    break;
                }
                StringBuilder buffer = new StringBuilder();
                for (String field : fields) {
                    String fieldValue = getPojoField(field, args[0]);
                    if (buffer.length() > 0) {
                        buffer.append(".");
                    }
                    buffer.append(fieldValue);
                }
                limitKey = ratelimitKey(RateLimitTypeEnum.POJO_FIELD, key, buffer.toString());
                break;
            case PARAM:
                int keyIndex = limitAnnotation.keyParamIndex();

                if (keyIndex < 0 || args == null || args.length < (keyIndex + 1) || args[keyIndex] == null) {
                    ratelimitLog.warn(">> 未找到参数或参数值为空,限流失败: {}, keyParamIndex={}", joinPoint, keyIndex);
                } else if (isValidKeyParamType(args[keyIndex])) {
                    limitKey = ratelimitKey(RateLimitTypeEnum.PARAM, key, String.valueOf(args[keyIndex]));
                } else {
                    ratelimitLog.warn(">> 设置的参数不是string/long/int/short/byte类型,限流失败: {}", joinPoint);
                }
                break;
            case KEY:
                limitKey = ratelimitKey(RateLimitTypeEnum.KEY, key,null);
                break;
            default:
                // nothing to do
        }
        return limitKey;
    }

    private boolean isValidKeyParamType(Object param) {
        return (param instanceof String) || (param instanceof Long) || (param instanceof Integer) || (param instanceof Short)
                || (param instanceof Byte);
    }

    private String getPojoField(String field, Object pojo) {
        try {
            Map<String, Object> map = JacksonUtils.pojo2map(pojo);
            Object value = map.get(field);
            return String.valueOf(value);
        } catch (Exception e) {
            return "null";
        }
    }
   
    public static String ratelimitKey(RateLimitTypeEnum type, String key, String param) {
        String result = "reate-limit";
        switch (type) {
            case IP:
                result += "ip.";
                break;
            case UID:
                result += "u.";
                break;
            case PARAM:
                result += "p.";
                break;
            case POJO_FIELD:
                result += "f.";
                break;
            case KEY:
                result += "k.";
                break;
            default:
                // nothing to do
        }
        result += key;
        if (StringUtils.isNotBlank(param)) {
            result += "." + param;
        }
        return result;
    }
}

2. 限流类型枚举(RateLimitTypeEnum)

/**
 * Dimension along which a rate limit partitions its limiter keys.
 *
 * <p>NOTE(review): keep the declaration order stable in case anything relies on {@link #ordinal()}.
 */
public enum RateLimitTypeEnum {

    /** One limit per client IP address. */
    IP,

    /** One limit per logged-in user ID. */
    UID,

    /** One limit per value of selected properties of an argument object. */
    POJO_FIELD,

    /** One limit per value of a single method argument. */
    PARAM,

    /** A single limit shared under the configured key. */
    KEY;
}

3. 使用示例

/**
 * Example controller showing how {@code RateLimit} is applied to an endpoint.
 */
@RestController
@RequestMapping("/app")
public class SmsApi {

    /**
     * Sends an SMS. Each logged-in user (UID) gets one permit every 60 seconds.
     *
     * <p>To limit per client IP instead, annotate with
     * {@code @RateLimit(type = RateLimitTypeEnum.IP, key = "sms-send", perSeconds = 60, errmsg = "")}.
     *
     * @param phone presumably the destination phone number (bound from the request) — TODO confirm
     */
    @PostMapping("/sms")
    @ApiOperation("短信发送")
    @RateLimit(key = "sms-send", type = RateLimitTypeEnum.UID, perSeconds = 60, limit = 1, errmsg = "一分钟只能发送一次")
    public void sendRegister(String phone) {
        // TODO: SMS sending business logic
    }
}