redis+lua实现分布式限流

442 阅读2分钟

1:安装好 Redis，创建一个普通的 Maven 项目。由于限流是一个通用组件，所以把这个项目作为通用子项目，其 pom.xml 依赖如下：

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
            <version>2.2.2.RELEASE</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
            <version>2.2.2.RELEASE</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.8</version>
        </dependency>
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>28.1-jre</version>
        </dependency>

2:编写redis的配置类

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;

/**
 * Spring configuration for the rate limiter's Redis access: a String-based template and the
 * Lua script that implements the limiter.
 */
@Configuration
public class RedisConfig {

    /**
     * Exposes a {@link StringRedisTemplate} (String keys and String values) under the bean name
     * {@code redisTemplate}.
     */
    @Bean
    public RedisTemplate<String, String> redisTemplate(RedisConnectionFactory factory) {
        return new StringRedisTemplate(factory);
    }

    /**
     * Loads {@code rateLimiter.lua} from the classpath.
     *
     * <p>The script's result is mapped to {@link Boolean}: {@code true} means the request may pass,
     * {@code false} means it was rate limited. The bean is typed as
     * {@code DefaultRedisScript<Boolean>} (instead of a raw type) so that it matches
     * {@code RedisScript<Boolean>} injection points by its generic type.
     */
    @Bean
    public DefaultRedisScript<Boolean> loadRedisScript() {
        DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
        redisScript.setLocation(new ClassPathResource("rateLimiter.lua"));
        redisScript.setResultType(Boolean.class);
        return redisScript;
    }
}

3:在 resources 目录（src/main/resources）下编写 rateLimiter.lua 脚本

-- Fixed-window rate limiter, executed atomically by Redis (EVAL/EVALSHA).
--   KEYS[1] : counter key identifying the rate-limited method/resource
--   ARGV[1] : maximum number of requests admitted per window
--   ARGV[2] : (optional) window length in whole seconds; defaults to 1
-- Returns true when the request is admitted, false when it is rate limited.
local methodKey = KEYS[1]
redis.log(redis.LOG_DEBUG,"key is ",methodKey)

local limit = tonumber(ARGV[1])
if limit == nil then
    -- Fail with an explicit message instead of a "compare nil with number" script error.
    return redis.error_reply("rateLimiter.lua: ARGV[1] (limit) must be a number")
end

local window = tonumber(ARGV[2])
if window == nil or window < 1 then
    window = 1
end
-- EXPIRE only accepts integer seconds.
window = math.floor(window)

-- Current request count in this window (GET returns false for a missing key).
local count = tonumber(redis.call('get',methodKey) or "0")

if count + 1 > limit then
    -- Over the threshold: reject without touching the counter or its TTL.
    return false
end

local current = redis.call("INCRBY",methodKey,1)
-- Start the window only when the counter is created. Re-arming EXPIRE on every admitted
-- request would keep pushing the window forward, so fewer than `limit` requests per
-- window would get through. The TTL == -1 check also repairs a key that somehow has no expiry,
-- which would otherwise block the resource forever.
if current == 1 or redis.call("TTL",methodKey) == -1 then
    redis.call("EXPIRE",methodKey,window)
end
return true

4:编写对外提供限流服务的入口类

  import com.google.common.collect.Lists;
  import lombok.extern.slf4j.Slf4j;
  import org.springframework.beans.factory.annotation.Autowired;
  import org.springframework.data.redis.core.StringRedisTemplate;
  import org.springframework.data.redis.core.script.RedisScript;
  import org.springframework.stereotype.Service;

  @Service
  @Slf4j
  public class AccessLimiter {

      @Autowired
      private StringRedisTemplate stringRedisTemplate;

      @Autowired
      private RedisScript<Boolean> redisScript;

      public void limitAccess(String key,Integer limit) {
          //1:调用lua脚本
          boolean acquired = stringRedisTemplate.execute(
                  redisScript,
                  Lists.newArrayList(key),
                  limit.toString()
          );
          if(!acquired) {
              log.error("");
              throw new RuntimeException("被限流了");
          }
      }
  }
  

5:重新创建一个SpringBoot项目,pom.xml如下:

 <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
            <version>2.2.2.RELEASE</version>
        </dependency>
        <!-- 引入分布式限流的通用组件 -->
        <dependency>
            <groupId>com.coco</groupId>
            <artifactId>ratelimiter</artifactId>
            <version>1.0-SNAPSHOT</version>
        </dependency>
    </dependencies>
    

6:编写测试类

@Slf4j
@RestController
public class TestController {

    /** Redis key holding the request counter of the {@code /test} endpoint. */
    private static final String TEST_LIMIT_KEY = "ratelimiter-test";

    /** Requests admitted per window (the provided Lua script uses a 1-second window). */
    private static final int TEST_LIMIT = 1;

    @Resource
    private AccessLimiter accessLimiter;

    /**
     * Demo endpoint guarded programmatically by {@link AccessLimiter}.
     *
     * <p>"ratelimiter-test" is the Redis key acting as this method's signature; a limit of 1 admits
     * a single request per second.
     *
     * @return {@code "success"} when the request is admitted
     */
    @GetMapping("/test")
    public String test() {
        accessLimiter.limitAccess(TEST_LIMIT_KEY, TEST_LIMIT);
        return "success";
    }
}
 

7:但这样需要在每个要限流的方法里都手动调用 accessLimiter.limitAccess("ratelimiter-test",1);，代码重复且容易遗漏。下面改用注解来实现：

 /**
  * Marks a method as rate limited by the Redis-backed limiter.
  *
  * <p>Retained at runtime so that it can be read reflectively from the invoked method.
  */
 @Documented
 @Retention(RetentionPolicy.RUNTIME)
 @Target(ElementType.METHOD)
 public @interface AccessLimiterAnnotation {

     /** Redis key used as this method's counter; empty by default. */
     String methodKey() default "";

     /**
      * Maximum number of calls admitted per window; 0 by default.
      *
      * <p>NOTE(review): the provided Lua script rejects when {@code count + 1 > limit}, so leaving
      * the default of 0 rejects every call — set this explicitly.
      */
     int limit() default 0;
 }
 

8:使用AOP实现统一拦截

  /**
   * Aspect that enforces {@link AccessLimiterAnnotation} on annotated methods.
   *
   * <p>Before an annotated method runs, one permit is consumed through {@link AccessLimiter};
   * if the limit is exceeded the exception from {@code AccessLimiter.limitAccess} propagates and
   * the method is not invoked.
   */
  @Aspect
  @Component
  public class AccessLimiterAspect {

      // Delegate to the shared limiter instead of duplicating the Redis/Lua call here.
      @Autowired
      private AccessLimiter accessLimiter;

      @Pointcut("@annotation(com.coco.AccessLimiterAnnotation)")
      public void cut() {

      }

      /**
       * Applies the rate limit declared on the intercepted method.
       *
       * <p>When {@code methodKey} is empty, a key derived from the declaring class, method name and
       * parameter types is used, so un-keyed methods do not all share the same Redis counter.
       *
       * @param joinPoint the intercepted method invocation
       */
      @Before("cut()")
      public void before(JoinPoint joinPoint) {
          MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
          Method method = methodSignature.getMethod();
          AccessLimiterAnnotation limiterAnnotation = method.getAnnotation(AccessLimiterAnnotation.class);
          if (limiterAnnotation == null) {
              return;
          }
          String methodKey = limiterAnnotation.methodKey();
          if (methodKey.isEmpty()) {
              methodKey = defaultMethodKey(method);
          }
          accessLimiter.limitAccess(methodKey, limiterAnnotation.limit());
      }

      /** Builds a key of the form {@code pkg.Class#method(pkg.Arg1,pkg.Arg2)}, unique per overload. */
      private static String defaultMethodKey(Method method) {
          StringBuilder key = new StringBuilder(method.getDeclaringClass().getName())
                  .append('#')
                  .append(method.getName())
                  .append('(');
          Class<?>[] parameterTypes = method.getParameterTypes();
          for (int i = 0; i < parameterTypes.length; i++) {
              if (i > 0) {
                  key.append(',');
              }
              key.append(parameterTypes[i].getName());
          }
          return key.append(')').toString();
      }
  }
  

9:测试

     /**
      * Demo endpoint rate limited declaratively through {@link AccessLimiterAnnotation}:
      * counter key "anno", at most 1 request per window.
      *
      * @return {@code "success"} when the request is admitted
      */
     @AccessLimiterAnnotation(methodKey = "anno", limit = 1)
     @GetMapping("/anno")
     public String anno() {
         return "success";
     }
     

10:虽然这里是单机演示，但可以看出限流原理：所有客户端（应用实例）都在 Redis 中对同一个 key 的计数器进行累加，由于 Lua 脚本在 Redis 中原子执行，多个实例并发访问时计数依然准确；一旦超出阈值就直接限流。因此在多实例部署时，同样可以实现分布式限流。