自定义注解AOP限流

479 阅读2分钟

描述

根据AOP的环绕通知,给需要限流的接口打上注解,即被AOP所拦截。配合redis+lua脚本使其操作具有原子性,保证并发操作。每访问一次接口,redis的incr计数器增加一次,并设置过期时间。在有效的时间内,超出设置的最大访问值,就会被限制接口访问,从而达到接口限流的目的。

限流类型枚举

/**
 * 限流枚举,通过枚举可设置key限流或者ip限流
 *
 * @author: 苦瓜不苦
 * @date: 2021/8/4 20:35
 **/
public enum LimitType {
    /**
     * 自定义key
     */
    KEY,
    /**
     * 请求ip
     */
    IP,
}

自定义限流注解

/**
 * 限流注解
 */
@Documented
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface Limit {

    // 名字
    String name() default "限流";

    // 唯一标识key
    String key() default "";


    // 时间范围,单位(秒)
    int expire() default 5;

    // 最大访问次数
    int count() default 60;

    // 限流类型
    LimitType limitType() default LimitType.KEY;


}

获取IP工具类


/**
 * @author 苦瓜不苦
 * @date 2022/12/13 16:53
 **/
public class IpUtil {

    /**
     * 获取Http Servlet请求
     *
     * @return
     */
    public static HttpServletRequest request() {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        assert attributes != null;
        return attributes.getRequest();
    }


    /**
     * 获取请求IP地址
     *
     * @return
     */
    public static String getIpAddress() {
        HttpServletRequest request = request();
        return getIpAddress(request);
    }


    /**
     * 获取请求IP地址
     *
     * @param request
     * @return
     */
    public static String getIpAddress(HttpServletRequest request) {
        String ip = request.getHeader("x-forwarded-for");
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : ip;
    }
    

}


AOP切面实现限流

/**
 * AOP限流切面实现
 *
 * @author: 苦瓜不苦
 * @date: 2021/8/4 20:58
 **/
@Aspect
@Component
public class LimitConfig {


    @Autowired
    private RedisTemplate<String, Serializable> redisTemplate;


    /**
     * 环绕通知
     *
     * @param point
     * @return
     */
    @Around("execution(public * *(..)) && @annotation(top.citycode.annotation.Limit)")
    public Object interceptor(ProceedingJoinPoint point) {
        // 获取方法上面的注解
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Limit limit = method.getAnnotation(Limit.class);

        // 判断限流的类型
        String key = method.getName();
        switch (limit.limitType()) {
            case IP:
                key = IPUtil.getIpAddress();
                break;
            case KEY:
                key = limit.key();
                break;
            default:
                break;
        }

        try {
            // key值
            List<String> keyList = new ArrayList<>();
            keyList.add(key);

            // 编写lua脚本
            String luaScript = this.buildLuaScript();
            RedisScript<Number> redisScript = new DefaultRedisScript<>(luaScript, Number.class);
            // 执行lua脚本
            Number count = redisTemplate.execute(redisScript, keyList, limit.count(), limit.expire());
            if (count != null && count.intValue() <= limit.count()) {
                // 执行方法
                return point.proceed();
            } else {
                // 拦截,抛出异常
                throw new RuntimeException("You have been dragged into the blacklist");
            }
        } catch (Throwable throwable) {
            throwable.printStackTrace();
            throw new RuntimeException("server exception");
        }

    }


    /**
     * 编写 redis lua限流脚本
     *
     * @return
     */
    public String buildLuaScript() {
        StringBuilder lua = new StringBuilder();
        lua.append("local c");
        lua.append("\nc = redis.call('get',KEYS[1])");
        // 调用不超过最大值,则直接返回
        lua.append("\nif c and tonumber(c) > tonumber(ARGV[1]) then");
        lua.append("\nreturn c;");
        lua.append("\nend");
        // 执行计算器自加
        lua.append("\nc = redis.call('incr',KEYS[1])");
        lua.append("\nif tonumber(c) == 1 then");
        // 从第一次调用开始限流,设置对应键值的过期
        lua.append("\nredis.call('expire',KEYS[1],ARGV[2])");
        lua.append("\nend");
        lua.append("\nreturn c;");

        return lua.toString();
    }


}

代码测试

/**
 * @author: 苦瓜不苦
 * @date: 2021/8/4 21:50
 **/
@RestController
@RequestMapping("/test")
public class TestController {

    /**
     * 限流:60s内请求三次
     *
     * @return
     */
    @Limit(key = "user", expire = 60, count = 3, limitType = LimitType.KEY)
    @GetMapping("/user")
    public Map<String, Object> user() {
        Map<String, Object> map = new HashMap<>();
        map.put("姓名", "李四");
        map.put("年龄", 18);
        map.put("性别", "男");
        return map;
    }


}

扩展: 限流工具类

适用于单体服务的滑动窗口限流算法、固定窗口限流算法

public class LimitUtil {

    private static final Map<String, List<Long>> FIXED_MAP = new ConcurrentHashMap<>();
    private static final Map<String, List<Long>> SLIDE_MAP = new ConcurrentHashMap<>();


    /**
     * 滑动窗口限流算法
     *
     * @param key   唯一键值
     * @param count 限流次数
     * @param time  时间窗口,单位: ms
     * @return 是否限流
     */
    public static synchronized boolean slideLimit(String key, int count, long time) {
        // 获取当前key的限流队列
        List<Long> list = SLIDE_MAP.computeIfAbsent(key, linkedList -> new ArrayList<>());
        // 获取当前时间戳
        long millis = System.currentTimeMillis();
        // 判断队列是否满
        if (list.size() < count) {
            // 允许通过,并记录当前时间戳到队列
            list.add(millis);
            return true;
        }
        // 队列已满,达到限流次数,获取最早的时间戳
        Long beforeMillis = list.get(0);
        // 用当前时间戳减去最早的时间戳
        if (millis - beforeMillis <= time) {
            // 若小于等于time,表示在time时间内,通过次数大于count,不允许通过
            return false;
        } else {
            // 若大于time,表示在time时间内,通过次数小于等于count,允许通过
            // 删除最早添加的时间戳,并将当前时间错添加
            list.remove(0);
            list.add(millis);
            return true;
        }
    }


    /**
     * 固定窗口限流算法
     *
     * @param key   唯一键值
     * @param count 限流次数
     * @param time  时间窗口,单位: ms
     * @return 是否限流
     */
    public static synchronized boolean fixedLimit(String key, int count, long time) {
        // 获取当前时间戳
        long millis = System.currentTimeMillis();
        // 获取当前key的限流队列
        List<Long> list = FIXED_MAP.computeIfAbsent(key, linkedList -> new ArrayList<>());
        // 判断计数器是否超过总数量
        if (list.size() < count) {
            list.add(millis);
            return true;
        }
        // 计数已满,判断是否在窗口内
        Long beforeMillis = list.get(0);
        if (millis - beforeMillis <= time) {
            // 若小于等于time,表示在time时间内,通过次数大于count,不允许通过
            return false;
        } else {
            // 若大于time,表示在time时间内,通过次数小于等于count,允许通过
            // 计数器清零,重新计算
            list.clear();
            list.add(millis);
            return true;
        }

    }


}