描述
根据AOP的环绕通知,给需要限流的接口打上注解,即被AOP所拦截。配合redis+lua脚本使其操作具有原子性,保证并发操作。每访问一次接口,redis的incr计数器增加一次,并设置过期时间。在有效的时间内,超出设置的最大访问值,就会被限制接口访问,从而达到接口限流的目的。
限流类型枚举
/**
* 限流枚举,通过枚举可设置key限流或者ip限流
*
* @author: 苦瓜不苦
* @date: 2021/8/4 20:35
**/
public enum LimitType {
/**
* 自定义key
*/
KEY,
/**
* 请求ip
*/
IP,
}
自定义限流注解
/**
* 限流注解
*/
@Documented
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface Limit {
// 名字
String name() default "限流";
// 唯一标识key
String key() default "";
// 时间范围,单位(秒)
int expire() default 5;
// 最大访问次数
int count() default 60;
// 限流类型
LimitType limitType() default LimitType.KEY;
}
获取IP工具类
/**
* @author 苦瓜不苦
* @date 2022/12/13 16:53
**/
public class IpUtil {
/**
* 获取Http Servlet请求
*
* @return
*/
public static HttpServletRequest request() {
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
assert attributes != null;
return attributes.getRequest();
}
/**
* 获取请求IP地址
*
* @return
*/
public static String getIpAddress() {
HttpServletRequest request = request();
return getIpAddress(request);
}
/**
* 获取请求IP地址
*
* @param request
* @return
*/
public static String getIpAddress(HttpServletRequest request) {
String ip = request.getHeader("x-forwarded-for");
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_CLIENT_IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : ip;
}
}
AOP切面实现限流
/**
* AOP限流切面实现
*
* @author: 苦瓜不苦
* @date: 2021/8/4 20:58
**/
@Aspect
@Component
public class LimitConfig {
@Autowired
private RedisTemplate<String, Serializable> redisTemplate;
/**
* 环绕通知
*
* @param point
* @return
*/
@Around("execution(public * *(..)) && @annotation(top.citycode.annotation.Limit)")
public Object interceptor(ProceedingJoinPoint point) {
// 获取方法上面的注解
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Limit limit = method.getAnnotation(Limit.class);
// 判断限流的类型
String key = method.getName();
switch (limit.limitType()) {
case IP:
key = IPUtil.getIpAddress();
break;
case KEY:
key = limit.key();
break;
default:
break;
}
try {
// key值
List<String> keyList = new ArrayList<>();
keyList.add(key);
// 编写lua脚本
String luaScript = this.buildLuaScript();
RedisScript<Number> redisScript = new DefaultRedisScript<>(luaScript, Number.class);
// 执行lua脚本
Number count = redisTemplate.execute(redisScript, keyList, limit.count(), limit.expire());
if (count != null && count.intValue() <= limit.count()) {
// 执行方法
return point.proceed();
} else {
// 拦截,抛出异常
throw new RuntimeException("You have been dragged into the blacklist");
}
} catch (Throwable throwable) {
throwable.printStackTrace();
throw new RuntimeException("server exception");
}
}
/**
* 编写 redis lua限流脚本
*
* @return
*/
public String buildLuaScript() {
StringBuilder lua = new StringBuilder();
lua.append("local c");
lua.append("\nc = redis.call('get',KEYS[1])");
// 调用不超过最大值,则直接返回
lua.append("\nif c and tonumber(c) > tonumber(ARGV[1]) then");
lua.append("\nreturn c;");
lua.append("\nend");
// 执行计算器自加
lua.append("\nc = redis.call('incr',KEYS[1])");
lua.append("\nif tonumber(c) == 1 then");
// 从第一次调用开始限流,设置对应键值的过期
lua.append("\nredis.call('expire',KEYS[1],ARGV[2])");
lua.append("\nend");
lua.append("\nreturn c;");
return lua.toString();
}
}
代码测试
/**
* @author: 苦瓜不苦
* @date: 2021/8/4 21:50
**/
@RestController
@RequestMapping("/test")
public class TestController {
/**
* 限流:60s内请求三次
*
* @return
*/
@Limit(key = "user", expire = 60, count = 3, limitType = LimitType.KEY)
@GetMapping("/user")
public Map<String, Object> user() {
Map<String, Object> map = new HashMap<>();
map.put("姓名", "李四");
map.put("年龄", 18);
map.put("性别", "男");
return map;
}
}
扩展: 限流工具类
适用于单体服务的滑动窗口限流算法、固定窗口限流算法
public class LimitUtil {
private static final Map<String, List<Long>> FIXED_MAP = new ConcurrentHashMap<>();
private static final Map<String, List<Long>> SLIDE_MAP = new ConcurrentHashMap<>();
/**
* 滑动窗口限流算法
*
* @param key 唯一键值
* @param count 限流次数
* @param time 时间窗口,单位: ms
* @return 是否限流
*/
public static synchronized boolean slideLimit(String key, int count, long time) {
// 获取当前key的限流队列
List<Long> list = SLIDE_MAP.computeIfAbsent(key, linkedList -> new ArrayList<>());
// 获取当前时间戳
long millis = System.currentTimeMillis();
// 判断队列是否满
if (list.size() < count) {
// 允许通过,并记录当前时间戳到队列
list.add(millis);
return true;
}
// 队列已满,达到限流次数,获取最早的时间戳
Long beforeMillis = list.get(0);
// 用当前时间戳减去最早的时间戳
if (millis - beforeMillis <= time) {
// 若小于等于time,表示在time时间内,通过次数大于count,不允许通过
return false;
} else {
// 若大于time,表示在time时间内,通过次数小于等于count,允许通过
// 删除最早添加的时间戳,并将当前时间错添加
list.remove(0);
list.add(millis);
return true;
}
}
/**
* 固定窗口限流算法
*
* @param key 唯一键值
* @param count 限流次数
* @param time 时间窗口,单位: ms
* @return 是否限流
*/
public static synchronized boolean fixedLimit(String key, int count, long time) {
// 获取当前时间戳
long millis = System.currentTimeMillis();
// 获取当前key的限流队列
List<Long> list = FIXED_MAP.computeIfAbsent(key, linkedList -> new ArrayList<>());
// 判断计数器是否超过总数量
if (list.size() < count) {
list.add(millis);
return true;
}
// 计数已满,判断是否在窗口内
Long beforeMillis = list.get(0);
if (millis - beforeMillis <= time) {
// 若小于等于time,表示在time时间内,通过次数大于count,不允许通过
return false;
} else {
// 若大于time,表示在time时间内,通过次数小于等于count,允许通过
// 计数器清零,重新计算
list.clear();
list.add(millis);
return true;
}
}
}