持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第1天,点击查看活动详情
前言
作为一个后端开发人员,在我们日常的项目开发中访问速率的限制是必不可少的。今天分享下我用的一种方案。
实现思路
想要限制指定ip,首先要获取该ip。然后将该ip记录到Redis中,同时记录该ip的访问时间。该ip每次访问就对对应的计数加一,当达到设置的阈值时,用当前时间减去该ip的访问时间。用该时间与阈值时间对比,小于阈值时间则说明访问频率过高,需要限制;大于则说明访问频率符合要求,此时将访问时间置为当前时间,并重置访问次数。
思路理清楚了,接下来开始实践,明白过程只是需要demo的直接文末自取。
正文
按照思路,我们需要先获取用户ip,目前来说我们的应用都会经过nginx的代理,这会导致我们通过 request.getRemoteAddr()得到的不是用户ip而是nginx的ip,所以我们需要在nginx配置文件的location里加上这段配置(如果没用到代理的话就不需要这一步):
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
这些请求头的具体含义其它博客有更详细、专业的解读,我这里就不解释了。 这个处理好之后就可以下一步了。
首先准备一个过滤器,我这里采用的是通过注解@WebFilter实现的,参考代码如下:
@WebFilter(filterName = "frequencyFilter", urlPatterns = "/*",
initParams = @WebInitParam(name = "noFilterUrl", value = "/webjars,/v2/api-docs,/swagger-resources"))
@Slf4j
public class FrequencyFilter implements Filter {
@Override
public void init(FilterConfig filterConfig) {
}
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
}
@Override
public void destroy() {
}
}
简单解释下注解参数:
- filterName:当前过滤器的名字
- urlPatterns:过滤器的匹配路径,/*代表匹配所有
- initParams:初始化参数,定义一个参数可以在过滤器里使用的初始化参数。我这里定义了一个过滤白名单,即名单里的路径不被过滤
过滤器准备好之后,在init里处理初始化参数,方便后续在处理逻辑里使用:
@WebFilter(filterName = "frequencyFilter", urlPatterns = "/*",
initParams = @WebInitParam(name = "noFilterUrl", value = "/webjars,/v2/api-docs,/swagger-resources"))
@Slf4j
public class FrequencyFilter implements Filter {
private List<String> noFilterUrls;
@Override
public void init(FilterConfig filterConfig) {
// 从过滤器配置中获取initParams参数
String noFilterUrl = filterConfig.getInitParameter("noFilterUrl");
// 将排除的URL放入成员变量noFilterUrls中
if (StringUtils.isNotBlank(noFilterUrl)) {
noFilterUrls = new ArrayList<>(Arrays.asList(noFilterUrl.split(",")));
}
}
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
}
@Override
public void destroy() {
}
}
之后就可以在doFilter里实现自己的判断以及限制逻辑了
首先定义需要的变量
@Resource
private RedisTemplate<String, Integer> redisTemplate;
public static final String frequencyKey = "frequency:control";
/**
* normal:正常访问
* black:黑名单
* time:最近一次访问时间
*/
public static final String normal = "normal:";
public static final String black = "black:";
public static final String time = "time:";
/**
* 单位时间内最大访问数:30 次
*/
private static final Integer MAX_COUNT = 20;
/**
* 单位时间:1 s
*/
private static final Integer UNIT_TIME = 1 * 1000;
/**
* 限制时长:1 hour
*/
private static final Long REJECT_TIME = 1 * 60 * 60 * 1000L;
在处理逻辑中,首先判断当前路径是否在白名单中:是,则直接放行;否,进行下一步处理。 处理未放行的请求:
- 获取用户ip,前面提到获取用户要通过
X-Forwarded-For获取。对应Java代码:request.getHeader("x-forwarded-for") - 判断当前ip是否已被封禁,否:进行下一步
- 判断当前ip是否是本轮第一次访问,是:重置时间和次数;否:下一步处理
- 判断当前ip的访问次数是否达到阈值
- 是:进一步判断时间是否到达阈值
- 大于阈值:限制ip,重置时间和访问次数
- 小于阈值:重置时间和访问次数
- 否:访问次数+1,放行 清楚每一步之后,我们来看最终代码
- 是:进一步判断时间是否到达阈值
@WebFilter(filterName = "frequencyFilter", urlPatterns = "/*",
initParams = @WebInitParam(name = "noFilterUrl", value = "/webjars,/v2/api-docs,/swagger-resources"))
@Slf4j
public class FrequencyFilter implements Filter {
private List<String> noFilterUrls;
@Resource
private RedisTemplate<String, Integer> redisTemplate;
public static final String frequencyKey = "frequency:control";
/**
* normal:正常访问
* black:黑名单
* time:最近一次访问时间
*/
public static final String normal = "normal:";
public static final String black = "black:";
public static final String time = "time:";
/**
* 单位时间内最大访问数:30 次
*/
private static final Integer MAX_COUNT = 20;
/**
* 单位时间:1 s
*/
private static final Integer UNIT_TIME = 1 * 1000;
/**
* 限制时长:1 hour
*/
private static final Long REJECT_TIME = 1 * 60 * 60 * 1000L;
@Override
public void init(FilterConfig filterConfig) throws ServletException {
// 从过滤器配置中获取initParams参数
String noFilterUrl = filterConfig.getInitParameter("noFilterUrl");
// 将排除的URL放入成员变量noFilterUrls中
if (StringUtils.isNotBlank(noFilterUrl)) {
noFilterUrls = new ArrayList<>(Arrays.asList(noFilterUrl.split(",")));
}
}
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
// 若请求中包含noFilterUrls中的片段则直接跳过过滤器进入下一步请求中
HttpServletRequest request = (HttpServletRequest) servletRequest;
String url = request.getRequestURI();
Boolean flag = false;
if (!CollectionUtils.isEmpty(noFilterUrls)) {
for (String noFilterUrl : noFilterUrls) {
if (url.contains(noFilterUrl)) {
flag = true;
break;
}
}
}
if (!flag) {
//过滤请求响应逻辑
String ip = null;
if (request.getHeader("x-forwarded-for") == null) {
log.info("from RemoteAddr");
ip = request.getRemoteAddr();
}else {
log.info("from x-forwarded-for");
ip = request.getHeader("x-forwarded-for");
}
String key = frequencyKey + ip;
HashOperations<String, String, Object> hashOps = redisTemplate.opsForHash();
//过滤黑名单
if (redisTemplate.hasKey(frequencyKey + black + ip)) {
log.error("ip访问过于频繁,已被限制=>" + ip + " 倒计时" + redisTemplate.getExpire(frequencyKey + black + ip));
return;
}
//判断ip是否首次访问
if (hashOps.hasKey(frequencyKey, normal + ip)) {
//判断最大访问次数
Integer maxCount = Integer.valueOf(hashOps.get(frequencyKey, normal + ip).toString());
log.info("ip:" + ip + " 访问" + maxCount + "次");
if (maxCount > MAX_COUNT) {
// 获取从0达到上限次数所用时间
Long maxTime = Long.valueOf(hashOps.get(frequencyKey, time + ip).toString());
if (System.currentTimeMillis() - maxTime < UNIT_TIME) {
log.error("ip访问过于频繁,已被限制=>" + ip + " 倒计时" + REJECT_TIME);
redisTemplate.opsForValue().set(frequencyKey + black + ip, 1, REJECT_TIME, TimeUnit.MILLISECONDS);
String str[] = {normal + ip, time + ip};
hashOps.delete(frequencyKey, str);
return;
}
initVisitsIP(ip);
}
} else {
initVisitsIP(ip);
}
hashOps.increment(frequencyKey, normal + ip, 1);
filterChain.doFilter(servletRequest, servletResponse);
} else {
filterChain.doFilter(servletRequest, servletResponse);
}
}
/**
* 初始化访问ip
*
* @param ip
*/
private void initVisitsIP(String ip) {
redisTemplate.opsForHash().put(frequencyKey, normal + ip, 0);
redisTemplate.opsForHash().put(frequencyKey, time + ip, String.valueOf(System.currentTimeMillis()));
}
@Override
public void destroy() {
}
}
看起来稍微有点复杂,但其实弄清每一步,自己试着写出来之后会发现并没有想象中那么困难。 本文到此就结束了,有问题欢迎评论区交流。