手把手教你一个注解搞定指定API的用户限流

742 阅读8分钟

一个注解实现用户限流

最近在写RPC框架,在实现完了负载均衡,熔断后开始准备实现限流器,原本只是做到API级别的限流,刚好有人问我对用户限流的操作,我想了想差别也不大,那就做到对用户的限流粒度吧!

项目地址:easy_rpc: 个人探索rpc框架中~~~ (gitee.com) 我自己的一个RPC框架,目前包含了rpc调用,负责均衡,服务熔断,降级。如果要找限流器的代码,在domain模块的limit包下


话不多说,先上成果,使用是有点类似与sentinel的,但是也有点不一样

2023-08-24 14:45:20.565  INFO 12128 --- [nio-8088-exec-1] com.rpc.domain.limit.LimitProcess        : 请求通过,当前QPS:0.0
执行放行
2023-08-24 14:45:20.964  INFO 12128 --- [io-8088-exec-10] com.rpc.domain.limit.LimitProcess        : 请求通过,当前QPS:1.0
执行放行
2023-08-24 14:45:21.271  INFO 12128 --- [nio-8088-exec-2] com.rpc.domain.limit.LimitProcess        : 请求通过,当前QPS:2.0
执行放行
2023-08-24 14:45:21.561  INFO 12128 --- [nio-8088-exec-6] com.rpc.domain.limit.LimitProcess        : 请求通过,当前QPS:3.0
执行放行
2023-08-24 14:45:21.879  INFO 12128 --- [nio-8088-exec-3] com.rpc.domain.limit.LimitProcess        : 请求拒绝,当前QPS:4.0
2023-08-24 14:45:22.165  INFO 12128 --- [nio-8088-exec-4] com.rpc.domain.limit.LimitProcess        : 请求拒绝,当前QPS:4.0
2023-08-24 14:45:22.446  INFO 12128 --- [io-8088-exec-10] com.rpc.domain.limit.LimitProcess        : 请求拒绝,当前QPS:4.0
2023-08-24 14:45:22.730  INFO 12128 --- [nio-8088-exec-2] com.rpc.domain.limit.LimitProcess        : 请求拒绝,当前QPS:4.0
2023-08-24 14:45:23.048  INFO 12128 --- [nio-8088-exec-6] com.rpc.domain.limit.LimitProcess        : 请求拒绝,当前QPS:4.0
2023-08-24 14:45:23.413  INFO 12128 --- [nio-8088-exec-3] com.rpc.domain.limit.LimitProcess        : 请求拒绝,当前QPS:4.0

然后这个是需要限流的API

  @RequestMapping("/order")
    @LimitingStrategy(strategyName = "test",limitKey = "id",QPS = 3)
    public void getOrder( Person p)  {
        System.out.println("执行放行");
    }
​

这里先解释下参数吧

  • strategyName 这个是用来表示一个限流名,多个API之间的限流名不允许一样。
  • limitKey,这个就是表示操作的key,也就是相同的key会被认为是同样的请求,如果你不写则是对这个API进行限流,比如我这里写了id,那么他会查找参数p的属性的id值,会对同一个id进行限流,这里也就是完成了对用户的限流,当然你可以写其他的,比如requestId,订单号等等。
  • QPS,也就是每秒的查询量,这个的秒数可以自定义的,比如我测试用的四秒,也就是四秒内最多三个请求通过(分粒度,可能是整个API,也可能是针对Key)

源码分析


首先限流算法是有很多种的,比如令牌桶算法,滑动窗口算法,我这里是借鉴的sentinel的滑动窗口算法,稍微比较复杂,我这里先贴一个视频讲解的www.bilibili.com/video/BV1My…

滑动窗口


*    B0     B1     B2      B3     B4      B5      B6
* |_____|______|_______|_______|_______|_______|_______||___
* 0     100    200     400     600     800     1000    1200  timestamp

如上图所示,这是一个底部是一个时间戳,假设单位为ms,也就是每个格子时间长度为100ms。

假设我们定义了一个时间窗口为200ms,QPS=3,也就是说200ms内最多通过三个请求。

整个滑动窗口实际上并不是每来一个请求就需要向右移动,而是将滑动窗口分成一个一个小块,每个小块被称为样本窗口SampleWindow(为什么要这么做,我认为一是为了性能,二是实现更加方便)

样本窗口实际为一个数组,你也可以认为是一个循环数组,因为取元素时进行了取模操作。每个样本窗口都有以下属性

  • startTimeInMs,窗口开始时间
  • intervalTime,窗口时间长度
  • SampleEntity,保存窗口内部的请求数据,如通过的数据,被拦截的数据

下面进入演示

假设样本窗口数量为2,也就是说每个样本窗口时间长为100ms(数量越多,精确度越高,后面会解释)

那么当一个请求进入时就如以下步骤:

  1. 根据请求到达的时间计算它应该在哪个样本窗口内,算法:index=(curTime / 100) % 样本窗口数量。如 在130ms到达的请求 index = (130/100) % 2 = 1 也就是它应该属于B1窗口。通过这个算法可以发现只要是在100-200之间的请求到达都会在s1样本窗口上。

    *    B0     B1     B2      B3     B4      B5      B6
    * |_____|______|_______|_______|_______|_______|_______||___
    * 0     100    200     400     600     800     1000    1200  timestamp
                ^
                s1
    
  2. 那么同理,200-400之间的请求会落在B2窗口上。那么你会问了,只有两个样本窗口,到达B3了怎么办呢?其实刷新之前的窗口就好了。

    *    B0     B1     B2      B3     B4      B5      B6
    * |_____|______|_______|_______|_______|_______|_______||___
    * 0     100    200     400     600     800     1000    1200  timestamp
                ^      ^        ^
                s1     s2       s3  
    
  3. 我们先计算这个请求会落在哪个窗口上,图上我们会发现应该是在B3。也就是开始时间为400ms的窗口。同时根据之前计算样本窗口下标计算方法可以得到index = 1 ,这是时候我们根据index来取样本窗口数组中的元素 ,然后再去比较B3窗口的开始时间s和取得的样本窗口起始时间,然后发现一个是400ms,另外一个是100ms,这时认为取出来的这个窗口是一个过去的窗口,于是我们重置窗口里面的数据,并更新开始时间为当前的400ms,最后添加一次请求,至此就是一个滑动窗口的滑动操作,你会发现其实每个样本窗口不一定都在时间滑动窗口里面,因为有可能那个样本窗口是一个过期了的样本窗口。

    *    B0     B1     B2      B3     B4      B5      B6
    * |_____|______|_______|_______|_______|_______|_______||___
    * 0     100    200     400     600     800     1000    1200  timestamp
                ^      ^       ^
             s1(被刷了) s2     s1  (新的)
    
  4. 如何计算一个请求到达的窗口的开始时间呢,算法: startTime = curTime - (curTime % 100ms)

由上图可以发现,一个请求到达的情况其实有两种:

  • 一种是计算出的时间窗口的起始时间就是取出来的样本窗口起始时间,那么直接获取样本窗口的SampleEntity,进行数据添加
  • 另一种是计算出的时间窗口的起始时间与取出来的样本窗口起始时间不一致,发现这个窗口是过期的数据,那么进行数据的跟新,然后添加一次请求

实际上你还需要考虑窗口没有被初始化,如果样本窗口为null,也进行初始化后放入值,这里贴一下核心的代码

滑动窗口实现类

// 滑动窗口实现类
public class SlidingWindow {
//    滑动窗口时间
    private final long windowIntervalInMs;
//    每个滑动窗口的样本窗口数量
    private final int sampleWindowAmount;
//    每个样本窗口的时间
    private final long sampleWindowIntervalInMs;
//    样本窗口数组
    private final AtomicReferenceArray<SampleWindow> sampleWindowAtomicReferenceArray;
​
    private final ReentrantLock updateLock = new ReentrantLock();
}

样本窗口类

public class SampleWindow {
    private long startTimeInMs;
    private final long intervalTime;
//    实际保持当前窗口的数据
    private SampleEntity sampleEntity;
    public SampleWindow reset(long startTime){
        startTimeInMs = startTime;
        sampleEntity.init();
        return this;
    }
​
}

样本窗口数据载体

@Data
public class SampleEntity {
//    键为limitKey
    private ConcurrentHashMap<Object,Integer> passMap;
//    键为limitKey
    private ConcurrentHashMap<Object,Integer> blockMap;
//    整个样本窗口内的API的通过数量
    private AtomicInteger pass;
//    整个样本窗口内的API的拒绝数量
    private AtomicInteger block;
}

这个方法也就是核心方法,根据当前时间获取样本窗口的方法

private SampleWindow getCurSampleWindow(){
    Long curSystemTime = System.currentTimeMillis();
    int curSampleWindowIndex = getCurSampleWindowIndex(curSystemTime);
​
    long curSampleWindowStartTime = getCurSampleWindowStartTime(curSystemTime);
​
    while (true){
//          若为null则初始化一个样本窗口
        if(sampleWindowAtomicReferenceArray.get(curSampleWindowIndex) == null){
            SampleWindow newSampleWindow = new SampleWindow(curSampleWindowStartTime,sampleWindowIntervalInMs,new SampleEntity());
            if (sampleWindowAtomicReferenceArray.compareAndSet(curSampleWindowIndex,null,newSampleWindow)) {
//                    log.info("当前创建一个样本窗口");
                return newSampleWindow;
            }else {
//                    CAS如果失败则礼让线程,重试
                Thread.yield();
            }
        }else {
            SampleWindow oldSampleWindow = sampleWindowAtomicReferenceArray.get(curSampleWindowIndex);
//                    当前为同一个样本窗,直接返回
            if(oldSampleWindow.getStartTimeInMs() == curSampleWindowStartTime){
//                    log.info("当前为同一个样本窗口");
                return oldSampleWindow;
            }else if(oldSampleWindow.getStartTimeInMs() < curSampleWindowStartTime){
                 if(updateLock.tryLock()){
                     try {
//                             log.info("当前跟新样本窗口,设置开始时间="+curSampleWindowStartTime);
                         return oldSampleWindow.reset(curSampleWindowStartTime);
                     }finally {
                         updateLock.unlock();
                     }
                 }else {
                     Thread.yield();
                 }
            }
//                一般不会发生以下情况(重置系统时间导致时间倒流)
            else if(oldSampleWindow.getStartTimeInMs() > curSampleWindowStartTime){
                return new SampleWindow(curSampleWindowStartTime,sampleWindowIntervalInMs,new SampleEntity());
            }
        }
    }
}
​
// 此外还有两个工具方法
//  根据当前时间获取样本窗口地址
private int getCurSampleWindowIndex(Long time){
    long temp = time / sampleWindowIntervalInMs;
    return (int)(temp % sampleWindowAmount);
}
​
//    获取样本窗口的开始时间
private Long getCurSampleWindowStartTime(Long time){
    return time - (time % windowIntervalInMs);
}

如果我这里写的不清楚,你可以翻看sentinel的LeapArray中的方法,基本一致。


当数据保存好后那么统计数据就变得非常简单了

也就是遍历每个窗口,查看窗口的起始时间距离现在的时间有没有超过我规定的时间窗口时长

//    判断当前样本窗口是否还在滑动窗口时间内
    private boolean isSampleWindowDeprecated(long time,SampleWindow sampleWindow){
        if(sampleWindow == null){
            return true;
        }
        return ((time - sampleWindow.getStartTimeInMs()) > windowIntervalInMs);
    }
//  获取有效窗口
    private List<SampleWindow> getValidSampleWindow(long time){
        List<SampleWindow> res = new ArrayList<>();
        int length = sampleWindowAtomicReferenceArray.length();
        for (int i = 0; i < length; i++) {
            SampleWindow sampleWindow = sampleWindowAtomicReferenceArray.get(i);
            if(!isSampleWindowDeprecated(time,sampleWindow)){
                res.add(sampleWindow);
            }
        }
        return res;
    }
// 这里就直接统计数据了    
       @Override
    public int getQPS() {
        long curTime = System.currentTimeMillis();
        int passCount = getPassCount(curTime);
        return passCount;
    }

那我是如何实现对用户级别的限流呢,其实也非常简单,我的SampleEntity 存储了一个Map,也就是同一个limitKey会进入同一个PassCount或者blockCount中,下面是AOP相关的操作,也就是如何从方法中获取limitKey的值,这里主要是一些反射的方法(降级还没写好,但是不难,根据class对象拿到bean,然后生成代理对象执行方法即可)

 @Around("annotationPointCut()")
    public Object annotationAround(ProceedingJoinPoint jp) {
        Method method = ((MethodSignature) jp.getSignature()).getMethod();
        LimitingStrategy limitingStrategy = method.getAnnotation(LimitingStrategy.class);
        // 把注解的值赋值到rule对象中
        LimitingRule rule = new LimitingRule(limitingStrategy);
        String limitKeyName = limitingStrategy.limitKey();
        Object[] args = jp.getArgs();
//        如果设置了limitKey则查找属性名
        if(!rule.getLimitKey().equals("")){
            for (Object arg : args) {
                for (Field declaredField : arg.getClass().getDeclaredFields()) {
                    if(declaredField.getName().equals(limitKeyName)){
                        try {
//                            设置可访问性
                            declaredField.setAccessible(true);
//                            拿到值
                            Object value = declaredField.get(arg);
                            rule.setLimitValue(value);
                            boolean pass = limitProcess.isPass(rule);
                            if(pass){
                                return jp.proceed();
                            }else {
                            // 这里执行降级处理
                                return null;
                            }
                        } catch (Throwable e ){
                            e.printStackTrace();
                        }
                    }
                }
            }
        }
        boolean pass = limitProcess.isPass(rule);
        if(pass){
            try {
                return jp.proceed();
            } catch (Throwable throwable) {
                throwable.printStackTrace();
            }
        }else {
            log.info("已拦截");
             // 这里执行降级处理
            return null;
        }
        return null;
    }

那么整个方法调用就是

limitProcess.isPass()-> limitStrategy.getQPS( ) -> getPassCount( ) ->

判断是否通过 获取当前QPS 获取通过数量

getValidSampleWindow() -> 返回有效窗口

获取有效的窗口

如果QPS 小于当前设置的QPS ,则添加一次通过请求数量

limitStrategy.incrPassCount()-> getCurSampleWindow() ->

增加通过数量 获取当前窗口

curSampleWindow.getSampleEntity().addPass()

添加一次通过请求

希望这个文章对你有帮助,你也可以为我的项目点个STAR。

easy_rpc: 个人探索rpc框架中~~~ (gitee.com)