基于滑动窗口的限流实现

394 阅读6分钟

在复杂不可控的网络环境中,我们的http API一般都是完全暴露在网络中的,只要遵守定义的api调用规范,都是能够对接口发起请求的。虽然我们可以做很多的安全措施去保护api不能被随意的调用,但是这并不能完全杜绝这种现象的发生,况且,攻击者也有可能是一个正常的用户。所以我们必须对接口做一些流量限制。防止过多的请求占用过多的服务资源导致其他资源不可用,当然,这主要是为了防止恶意刷接口。

限流的实现方式有很多种,包括:计数器、滑动窗口、令牌桶、漏桶。这里我们主要用一种的简单且有效的方式来实现限流-滑动窗口。

本次实现使用spring的面向切面编程来使得限流具有可移植性,然后借助Redis的zset数据结构来实现滑动窗口算法,不使用lua脚本来保证一致性(运维不会允许)。

直接上代码:

  1. 首先定义一个限流注解
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {

    /**
     * 限流的key前缀
     */
    String key() default Constants.RATE_LIMIT_KEY;

    /**
     * 限流时间(ms)
     */
    int time() default 10000;

    /**
     * 限流请求数
     */
    int count() default 2000;

    /**
     * 限流类型: 全局限流、ip限流(默认)、用户限流
     */
    LimitType limitType() default LimitType.IP;

}

限流的实现:

运用Redis的zset这个数据结构:zset是一个有序set容器,特点就是有序且唯一。其中,有序是每个元素的权重来决定的,其也支持分段扫描操作。基于以上特性,做如下设计:

  • key设计:将不同的限流类型做不同的key,比如ip限流就可以将ip作为key中的关键字,用户限流可以用token作为key中的关键字,全局限流就把全限定路径作为关键字
  • 将当前请求的时间戳作为zset中一个元素的权重,也就是决定当前请求所处的位置。
  • 请求进来的时候查询zset数据结构中当前时间戳以前的限流时间内的请求个数(使用Redis#rangeByScore)。这样就查出了以当前时间为结尾的这个时间段内的请求数量,用于判定是否需要限流。后续做出相应的回应即可。
@Component
@Aspect
public class RateLimiterAspect {
    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    @Autowired
    private RedisTemplate redisTemplate;

    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint joinPoint, RateLimiter rateLimiter) throws BusinessException {
        //限制key前缀
        final String key = rateLimiter.key();
        //时间
        final long time = rateLimiter.time();
        //次数
        final int count = rateLimiter.count();
        final String combineKey = getCombineKey(rateLimiter, joinPoint);
        limitFlow(combineKey, time, count);
    }

    /**
     * 限流实现
     * 运用redis的zset结构,以唯一请求标识为key,随机字符串为value,当前时间戳位权重
     * 算法:
     * 拼接出当前请求的唯一combineKey,然后查看redis中以这个combineKey为key的并且权重是在当前(时间-time)的数量,
     * 如果超过预设值说明超过流量限制,抛出提示性异常;否则则写入该请求
     * 支持ip限流,用户名限流;用redis中的key区别;
     * @param key 存放在Redis的唯一标识,全局限流默认是自定义key加上类名+方法名;IP限流则添加了ip,用户限流则添加了用户名
     * @param time 规定时间范围
     * @param count 规定时间内的请求数量
     * @throws BusinessException 自定义异常类,用于抛出限流异常
     */
    private void limitFlow(String key, long time, long count) throws BusinessException {
        long cur = System.currentTimeMillis();
        if(redisTemplate.hasKey(key)) {
            final Integer requestCount = redisTemplate.opsForZSet().rangeByScore(key, cur - time, cur).size();
            if(null != requestCount && requestCount >= count) {
                final String ip = IPUtil.getIpAddr(((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest());
                log.error(AddressUtils.getRealAddressByIP(ip) + " ip地址: " + ip +"请求速度过快,请稍后重试");
                throw new BusinessException(EmBusinessError.RATELIMITOE_ERROR);
            }
        }
        redisTemplate.opsForZSet().add(key,  UUID.randomUUID().toString(), cur);
        final Long aLong = redisTemplate.opsForZSet().removeRangeByScore(key, 0, cur - time);
    }

    private String getCombineKey(RateLimiter rateLimiter, JoinPoint joinPoint) {
        final StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
        //根据IP限流
        if (rateLimiter.limitType() == LimitType.IP) {
            // key后面跟上ip
            stringBuffer.append(IPUtil.getIpAddr(((ServletRequestAttributes)RequestContextHolder.getRequestAttributes()).getRequest())).append("-");
        }
        if(rateLimiter.limitType() == LimitType.USER) {
            // 跟上用户名
            UserDetails userDetails = (UserDetails) SecurityContextHolder.getContext().getAuthentication().getPrincipal();
            stringBuffer.append(userDetails.getUsername()).append("-");
        }
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        final Method method = signature.getMethod();
        final Class<?> declaringClass = method.getDeclaringClass();
        // 后面跟上类名,方法名
        stringBuffer.append(declaringClass.getName()).append("-").append(method.getName());
        return stringBuffer.toString();
    }
}

该种实现方式简单且有效,能够实现简单场景的限流。可能到这会有疑问:为什么不用计数器?计数器的操作似乎更加简单,且方便。计数器可以实现限流,但是简单的incr实现的限流方式是无效的。上面的这种实现方式是基于滑动窗口,简单的incr可不是滑动窗口,我的理解是:它更像是一个一维坐标轴上的翻转,简单来说只能实现固定区间的限流。大多数场景下,该种实现方式限流效果几乎为0。

以上方式其实也还可以做更多的优化:比如扩展不仅仅只支持一种限流方式,比如我要求全局限流10s内不超过3000的访问量,单个用户不能超过100的访问量这种特殊场景。这都需要根据实际需求进行扩展。

限流扩展:

除了以上的窗口算法,还有一类就是桶算法较为复杂,包括:漏桶算法,令牌桶算法。

漏桶算法:类似一个队列模型。请求不断的被加入桶中,而服务的处理速率是固定的,请求就会逐渐堆积,直到漏桶溢出,溢出的部分就是我们拒绝服务的部分。这种算法的一个好处是,消费的速率恒定,能够很好的保护自己,能够对流量整形,但是在面对突发流量的时候不能很好的响应。

令牌桶算法:类似一个生产者消费者模型。生产者负责生产令牌,消费者负责消费令牌。当请求进来之后,要想被系统处理就必须去桶里获取一个令牌。获取到令牌之后就能被及时处理。需要考虑的是生产令牌的速率不能过慢,也不能过快。这种算法的好处是能够处理突发的流量,但是要考虑到线程问题以及初次启动令牌放入不够及时的问题。

 final class Counter {
        volatile long timePeriodId;
        AtomicInteger requests = new AtomicInteger(0);
        public int addRequest(long currPeriodId) {
            if (currPeriodId != timePeriodId) {
                synchronized (this) {
                    if (currPeriodId != timePeriodId) {
                        timePeriodId = currPeriodId;
                        requests.set(0);
                    }
                }
            }
            // increment and return if we have gone above the limit
            return requests.incrementAndGet();
        }
        public synchronized long getTimePeriodId() {
            return timePeriodId;
        }
}

public boolean requestIncoming(Request request, long timeout) {
        if (!matcher.apply(request)) {
            return true;
        }

        boolean retval = true;
        long now = System.currentTimeMillis();
        long currPeriodId = now / timeInterval;
        String userKey = keyGenerator.getUserKey(request);

        // grab/generate the counter
        Counter counter = counters.get(userKey);
        if (counter == null) {
            userKey = canonicalizer.unique(userKey);
            synchronized (userKey) {
                counter = counters.get(userKey);
                if (counter == null) {
                    counter = new Counter();
                    counters.put(userKey, counter);
                }
            }
        }

        // update the counters
        int requests = counter.addRequest(currPeriodId);
        int residual = maxRequests - requests;

        if (residual < 0) {
            if (delay <= 0) {
                throw new HttpErrorCodeException(
                        429,
                        "Too many requests requests in the current time period, check X-Rate-Limit HTTP response headers");
            } else if (delay > timeout) {
                // no point in waiting
                return false;
            } else {
                if (LOGGER.isLoggable(Level.FINE)) {
                    LOGGER.fine(this + ", delaying current request");
                }
                try {
                    Thread.sleep(delay);
                } catch (InterruptedException e) {
                    LOGGER.log(Level.WARNING, this + ", the delay was abruptly interrupted", e);
                }
            }
        }
}