Redis实战基础应用:SpringBoot+Redis实现对高频访问ip的限制

395 阅读4分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第1天,点击查看活动详情

前言

作为一个后端开发人员,在我们日常的项目开发中访问速率的限制是必不可少的。今天分享下我用的一种方案。

实现思路

想要限制指定ip,首先要获取该ip。然后将该ip记录到Redis中,同时记录该ip的访问时间。该ip每次访问就对对应的计数加一,当达到设置的阈值时,用当前时间减去该ip的访问时间。用该时间与阈值时间对比,小于阈值时间则说明访问频率过高,需要限制;大于则说明访问频率符合要求,此时将访问时间置为当前时间,并重置访问次数。 思路理清楚了,接下来开始实践,明白过程只是需要demo的直接文末自取。

正文

按照思路,我们需要先获取用户ip,目前来说我们的应用都会经过nginx的代理,这会导致我们通过 request.getRemoteAddr()得到的不是用户ip而是nginx的ip,所以我们需要在nginx配置文件的location里加上这段配置(如果没用到代理的话就不需要这一步):

    proxy_set_header        X-Real-IP       $remote_addr;
    proxy_set_header        X-Forwarded-For $proxy_add_x_forwarded_for;

这些请求头的具体含义其它博客有更详细、专业的解读,我这里就不解释了。 这个处理好之后就可以下一步了。

首先准备一个过滤器,我这里采用的是通过注解@WebFilter实现的,参考代码如下:

@WebFilter(filterName = "frequencyFilter", urlPatterns = "/*",
        initParams = @WebInitParam(name = "noFilterUrl", value = "/webjars,/v2/api-docs,/swagger-resources"))
@Slf4j
public class FrequencyFilter implements Filter {
    @Override
    public void init(FilterConfig filterConfig) {
    }
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
    }
    @Override
    public void destroy() {

    }
}

简单解释下注解参数:

  • filterName:当前过滤器的名字
  • urlPatterns:过滤器的匹配路径,/*代表匹配所有
  • initParams:初始化参数,定义一个参数可以在过滤器里使用的初始化参数。我这里定义了一个过滤白名单,即名单里的路径不被过滤

过滤器准备好之后,在init里处理初始化参数,方便后续在处理逻辑里使用:

@WebFilter(filterName = "frequencyFilter", urlPatterns = "/*",
        initParams = @WebInitParam(name = "noFilterUrl", value = "/webjars,/v2/api-docs,/swagger-resources"))
@Slf4j
public class FrequencyFilter implements Filter {
    private List<String> noFilterUrls;
    @Override
    public void init(FilterConfig filterConfig) {
        // 从过滤器配置中获取initParams参数
        String noFilterUrl = filterConfig.getInitParameter("noFilterUrl");
        // 将排除的URL放入成员变量noFilterUrls中
        if (StringUtils.isNotBlank(noFilterUrl)) {
            noFilterUrls = new ArrayList<>(Arrays.asList(noFilterUrl.split(",")));
        }
    }
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
    }
    @Override
    public void destroy() {

    }
}

之后就可以在doFilter里实现自己的判断以及限制逻辑了

首先定义需要的变量

@Resource
private RedisTemplate<String, Integer> redisTemplate;
public static final String frequencyKey = "frequency:control";
/**
 * normal:正常访问
 * black:黑名单
 * time:最近一次访问时间
 */
public static final String normal = "normal:";
public static final String black = "black:";
public static final String time = "time:";

/**
 * 单位时间内最大访问数:30 次
 */
private static final Integer MAX_COUNT = 20;
/**
 * 单位时间:1 s
 */
private static final Integer UNIT_TIME = 1 * 1000;
/**
 * 限制时长:1 hour
 */
private static final Long REJECT_TIME = 1 * 60 * 60 * 1000L;

在处理逻辑中,首先判断当前路径是否在白名单中:是,则直接放行;否,进行下一步处理。 处理未放行的请求:

  1. 获取用户ip,前面提到获取用户要通过X-Forwarded-For获取。对应Java代码:request.getHeader("x-forwarded-for")
  2. 判断当前ip是否已被封禁,否:进行下一步
  3. 判断当前ip是否是本轮第一次访问,是:重置时间和次数;否:下一步处理
  4. 判断当前ip的访问次数是否达到阈值
    1. 是:进一步判断时间是否到达阈值
      1. 大于阈值:限制ip,重置时间和访问次数
      2. 小于阈值:重置时间和访问次数
    2. 否:访问次数+1,放行 清楚每一步之后,我们来看最终代码
@WebFilter(filterName = "frequencyFilter", urlPatterns = "/*",
        initParams = @WebInitParam(name = "noFilterUrl", value = "/webjars,/v2/api-docs,/swagger-resources"))
@Slf4j
public class FrequencyFilter implements Filter {

    private List<String> noFilterUrls;

    @Resource
    private RedisTemplate<String, Integer> redisTemplate;
    public static final String frequencyKey = "frequency:control";
    /**
     * normal:正常访问
     * black:黑名单
     * time:最近一次访问时间
     */
    public static final String normal = "normal:";
    public static final String black = "black:";
    public static final String time = "time:";

    /**
     * 单位时间内最大访问数:30 次
     */
    private static final Integer MAX_COUNT = 20;
    /**
     * 单位时间:1 s
     */
    private static final Integer UNIT_TIME = 1 * 1000;
    /**
     * 限制时长:1 hour
     */
    private static final Long REJECT_TIME = 1 * 60 * 60 * 1000L;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // 从过滤器配置中获取initParams参数
        String noFilterUrl = filterConfig.getInitParameter("noFilterUrl");
        // 将排除的URL放入成员变量noFilterUrls中
        if (StringUtils.isNotBlank(noFilterUrl)) {
            noFilterUrls = new ArrayList<>(Arrays.asList(noFilterUrl.split(",")));
        }
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        // 若请求中包含noFilterUrls中的片段则直接跳过过滤器进入下一步请求中
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        String url = request.getRequestURI();
        Boolean flag = false;
        if (!CollectionUtils.isEmpty(noFilterUrls)) {
            for (String noFilterUrl : noFilterUrls) {
                if (url.contains(noFilterUrl)) {
                    flag = true;
                    break;
                }
            }
        }
        if (!flag) {
            //过滤请求响应逻辑
            String ip = null;
            if (request.getHeader("x-forwarded-for") == null) {
                log.info("from RemoteAddr");
                ip = request.getRemoteAddr();
            }else {
                log.info("from x-forwarded-for");
                ip = request.getHeader("x-forwarded-for");
            }

            String key = frequencyKey + ip;
            HashOperations<String, String, Object> hashOps = redisTemplate.opsForHash();
            //过滤黑名单
            if (redisTemplate.hasKey(frequencyKey + black + ip)) {
                log.error("ip访问过于频繁,已被限制=>" + ip + " 倒计时" + redisTemplate.getExpire(frequencyKey + black + ip));
                return;
            }
            //判断ip是否首次访问
            if (hashOps.hasKey(frequencyKey, normal + ip)) {
                //判断最大访问次数
                Integer maxCount = Integer.valueOf(hashOps.get(frequencyKey, normal + ip).toString());
                log.info("ip:" + ip + " 访问" + maxCount + "次");
                if (maxCount > MAX_COUNT) {
                    // 获取从0达到上限次数所用时间
                    Long maxTime = Long.valueOf(hashOps.get(frequencyKey, time + ip).toString());
                    if (System.currentTimeMillis() - maxTime < UNIT_TIME) {
                        log.error("ip访问过于频繁,已被限制=>" + ip + " 倒计时" + REJECT_TIME);
                        redisTemplate.opsForValue().set(frequencyKey + black + ip, 1, REJECT_TIME, TimeUnit.MILLISECONDS);
                        String str[] = {normal + ip, time + ip};
                        hashOps.delete(frequencyKey, str);
                        return;
                    }
                    initVisitsIP(ip);
                }
            } else {
                initVisitsIP(ip);
            }
            hashOps.increment(frequencyKey, normal + ip, 1);
            filterChain.doFilter(servletRequest, servletResponse);
        } else {
            filterChain.doFilter(servletRequest, servletResponse);
        }

    }

    /**
     * 初始化访问ip
     *
     * @param ip
     */
    private void initVisitsIP(String ip) {
        redisTemplate.opsForHash().put(frequencyKey, normal + ip, 0);
        redisTemplate.opsForHash().put(frequencyKey, time + ip, String.valueOf(System.currentTimeMillis()));
    }

    @Override
    public void destroy() {

    }
}

看起来稍微有点复杂,但其实弄清每一步,自己试着写出来之后会发现并没有想象中那么困难。 本文到此就结束了,有问题欢迎评论区交流。