Redis实现简单的分布式锁及限流工具

583 阅读6分钟

基于SpringBoot + StringRedisTemplate实现分布式锁和限流工具

Redis实现分布式锁

当项目中存在高并发争夺同一资源时,容易发生资源冲突导致数据不一致的情况。一般对于读取、修改(判断)、写入(Read-Modify-Write)这种逻辑操作时,当存在并发操作,那么就有可能导致数据被修改错。这种情况下通常需要使用锁来保护共享资源,当在读取判断修改时,堵塞其它并发请求,以实现资源保护的目的。

在单机应用中我们使用编程语言提供的锁机制即可,例如Java的synchronize和重入锁。但在分布式应用中,单机的锁并不能保证多台机器中逻辑以正确顺序运行。为此需要应用第三方平台来锁住并发资源,而Redis正是符合锁工具的使用。

原理

使用Redis作为分布式锁的实现时,我们可以使用string类型,那么就有两种情况:

  • 当key锁值为0或不存在时,表示没有线程获取到锁;
  • 当key锁值为1时(或不为0),表示当前已有线程获取到锁了。

实现分布式锁需要保证的是在加锁和解锁时命令都需要保证原子性。

加锁

需要注意加锁命令的原子性,同时需要设置过期时间

加锁命令可通过 SET key value [EX seconds | PX milliseconds] [NX] 来实现。该命令有两个功能:

  • SET key value NX 命令是指当key不存在时,那么key就会被创建,并将值设置为value;如果key已存在,那么set nx命令将不做任何赋值操作。
  • SET key value NX PX 10000 命令与前面的相比,多了PX参数;该参数是设置key过期时间的。设置过期时间是为了防止已获取锁的线程执行中异常,导致最后没有释放锁,从而锁一直被持有的场景,那么设置过期时间,就可以让锁自动失效。

解锁

解锁需要原子性,同时需要判断解锁的客户端

解锁操作同样需要是原子性的操作。在解锁过程中需要防止其它客户端将自身客户端获取的锁给释放掉。为此我们在解锁时需要区分来自不同客户端的锁操作,我们可以在锁变量的值上入手,将值设置为唯一的,每个客户端加锁时设置的值都是唯一的,那么在解锁时即可通过判断该唯一值将锁给释放掉了。

释放锁使用DEL key命令删除即可。由于存在锁被误释放的风险,为此还需要区分不同客户端的锁操作,所以这里需要通过Lua脚本来实现解锁。

//释放锁 比较unique_value是否相等,避免误释放
if redis.call("get",KEYS[1]) == ARGV[1] then
    return redis.call("del",KEYS[1])
else
    return 0
end

实现

  • RedisUtils 封装Redis工具,分别是SET nx px命令和执行lua脚本的方法。
@Component
@Slf4j
public class RedisUtils {

    @Autowired
    private StringRedisTemplate redisTemplate;

    public Boolean setNx(String key,String value,int timeout) {
        return redisTemplate.opsForValue().setIfAbsent(key,value,timeout, TimeUnit.SECONDS);
    }

    /**
     * 执行脚本
     * @param redisScript
     * @param keys
     * @param args
     * @return
     */
    public Boolean execLimitLua(RedisScript<Long> redisScript, List<String> keys,String... args) {
        try {
            Long result = redisTemplate.execute(redisScript,keys,args);
            return result != null && result.intValue() == 1;
        }catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }
}
  • 锁服务:LockService
@Service
@Slf4j
public class LockService {

    @Autowired
    private RedisUtils redisUtils;

    private DefaultRedisScript<Long> redisScript;


    @PostConstruct
    public void init() {
        redisScript = new DefaultRedisScript<>();
        redisScript.setResultType(Long.class);
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("unLock.lua")));
    }

    @Value("${redis.lock.sleep-time}")
    private int sleepTime;

    @Value("${redis.lock.time-out}")
    private int timeOut;

    /**
     * 在指定时间范围内获取锁,堵塞
     * @param key
     * @param value
     * @param blockTime 获取锁的堵塞时间
     * @return
     * @throws InterruptedException
     */
    public boolean lock(String key,String value,int blockTime,AtomicInteger tryCount) throws InterruptedException {
        while(blockTime >= 0) {
            tryCount.incrementAndGet();
            if(redisUtils.setNx(key,value,timeOut)) {
                return true;
            }
            log.info("未获取锁,休眠后尝试再获取:value:" + value);
            blockTime = blockTime - sleepTime;
            Thread.sleep(sleepTime);
        }
        return false;
    }

    /**
     * 尝试获取锁,不堵塞
     * @param key
     * @param value
     * @return
     */
    public boolean tryLock(String key, String value) {
        if(redisUtils.setNx(key,value,timeOut)) {
            return true;
        }
        return false;
    }

    /**
     * 解锁
     * @param key
     * @param sign
     * @return
     */
    public boolean unLock(String key,String sign) {
        return redisUtils.execLimitLua(redisScript, Collections.singletonList(key),sign);
    }

}

  • unLock.lua
-- 加锁步骤:通过 SET key value NX PX 即可
-- 1. 获取锁
-- 2. 判断锁是否存在;
-- 3. 设置锁
-- 解锁步骤:
-- 1. 获取锁;
-- 2. 判断锁变量是否是本地变量;
-- 3. 释放锁,即删除key;
if redis.call("get",KEYS[1]) == ARGV[1] then
    return redis.call("del",KEYS[1])
else
    return 0
end

测试

添加10个任务来模拟任务并发

@RunWith(SpringRunner.class)
@SpringBootTest
@Slf4j
public class RedisLockTest {

    String key = "lock_test";
    int blockTime = 1000;    // 最多堵塞 1s

    @Autowired
    private LockService lockService;

    @Test
    public void testLock() throws InterruptedException {
        long startTime = System.currentTimeMillis();
        int clientNum = 10;
        CountDownLatch countDownLatch = new CountDownLatch(10);
        ExecutorService service = Executors.newFixedThreadPool(10);
        for(int i = 0 ;i < clientNum; i++) {
            TryLockTask task = new TryLockTask(lockService,countDownLatch);
            service.submit(task);
        }
        Thread.sleep(10);
        countDownLatch.await();
        System.out.println("任务提交完成;执行时长:" + (System.currentTimeMillis() - startTime) + "ms");
//        Thread.currentThread().join();  // 堵塞线程,观察日志输出
    }

    public class TryLockTask implements Runnable {

        LockService service;
        AtomicInteger tryCount;
        CountDownLatch countDownLatch;
        public TryLockTask(LockService service,CountDownLatch countDownLatch) {
            this.service = service;
            tryCount = new AtomicInteger();
            this.countDownLatch = countDownLatch;
        }

        @Override
        public void run() {
            try {
                long startTime = System.currentTimeMillis();
                String name = Thread.currentThread().getName();
                log.info("{} :准备获取锁",name);
                boolean hasLock = service.lock(key,name,blockTime,tryCount);
                if(!hasLock) {
                    log.info("{}:获取锁超时失败;获取次数为:{}" ,name,tryCount.intValue());
                    return;
                }

                Thread.sleep(100);    // 假装在工作
                log.info("{} : 干活中",name);

                boolean hasUnLock = service.unLock(key,name);
                if(!hasUnLock) {
                    log.info("{}: 解锁失败",name);
                    return;
                }
                countDownLatch.countDown();
                log.info("{}:干活完成;获取次数为:{};运行时长:{}",name,tryCount.intValue(),(System.currentTimeMillis() - startTime));
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }
}

结论

通过简单的SET NX命令以及Lua脚本即可实现简单的分布式锁,而这样简单的锁实现还是存在很明显的缺点的:

  • 锁不可被重入;
  • 锁过期后的续约问题;
  • 基于单节点Redis实现的分布式锁,不具有可靠性;

使用Redis官方提供的Reddission工具就可以解决上述问题。

Redis 实现限流

在Web应用中有挺多可以用到限流的场景,例如单位时间内超过一定登录次数时出现弹窗验证、消息推送时同一内容或同一类型的短信在一天内只允许发送n条、接口调用频率限制防止调用过频...等等。

如果是在单机应用中,使用map相关的数据结构来保存目标-调用次数这样的数据即可实现简单的限流。而在分布式应用中需要使用第三方平台,这里使用Redis实现限流算法。

原理

实现算法有很多种,如计数器算法、漏桶算法、令牌算法、滑动窗口算法等。不同算法的适应场景不太一样。这里以“消息推送时同一内容或同一类型的短信在一天内只允许发送n条”这种场景为例,来实现滑动窗口算法。

在消息推送中,为防止消息的重复投递导致用户的不满,需要对消息发送做限流处理,例如有这样的限流规则:

  • 一天内一个用户只能收到某个渠道的消息N次;
  • 相同内容且相同模板下短时间内(一小时)只能发送给一个用户。

这种在单位时间范围内限制某条消息只能发送N次,滑动窗口恰好满足这种特性,随着时间的变化,次数也会跟着变化。

那么怎么实现呢?

可以知道这种限流策略是跟时间相关,而时间是线性增长的;我们采用Redis来实现,在Redis中满足线性增长的数据结构,可以想到是用ZSET 来实现,通过score来保存线性增长的时间戳,那么想要获取单位时间内的次数就可以通过下面的命令来实现。

ZRANGEBYSCORE key time1 time2

该命令表示获取key在time1到time2时间内的所有元素。

那么可以使用ZREMRANGEBYSCORE key min max命令先移除不满足时间范围内的元素,其中min为0,max为当前时间减去限流的时间,即

ZREMRANGEBYSCORE key 0 curTime-limitTime

移除元素后再统计ZSET中的元素,那么通过判断元素个数是否超过阈值来决定是否限流。

  • 当zcard返回nil或值小于阈值,那么添加元素,设置score为当前时间戳,设置值为唯一的value,返回0表示还没达到阈值。
  • 当zcard返回的值大于阈值,那么返回1表示已超过阈值。

执行算法期间其命令必须是原子性的,所以需要利用Lua来实现上述的逻辑,实现如下:

  1. 通过zremrangeByScore移除已过期的数据;当前时间减去限流时间表示不符合窗口的时间,即过期前的时间。通过移除 [0,过期前的时间] 即表示移除已过期的数据。
  2. 利用zcard命令统计当前当前元素数量;
  3. 判断数值是否超过阈值;
    • 当zcard返回nil或值小于阈值,那么添加元素,设置score为当前时间戳,设置值为唯一的value,返回0表示还没达到阈值。
    • 当zcard返回的值大于阈值,那么返回1表示已超过阈值。
--KEYS[1]: 限流 key
--ARGV[1]: 限流窗口,单位秒
--ARGV[2]: 当前时间戳(作为score)
--ARGV[3]: 阈值
--ARGV[4]: score 对应的唯一value
-- 1. 移除已过期的数据
redis.call('zremrangeByScore', KEYS[1], 0, ARGV[2]-ARGV[1])
-- 2. 统计当前元素数量
local res = redis.call('zcard', KEYS[1])
-- 3. 是否超过阈值
if (res == nil) or (res < tonumber(ARGV[3])) then
    redis.call('zadd', KEYS[1], ARGV[2], ARGV[4])
    redis.call('expire', KEYS[1], ARGV[1])
    return 0
else
    return 1
end

实现

  • 限流服务实现:LimitService
@Service
public class LimitService {
    private static final String PRE = "limit";

    @Autowired
    private RedisUtils redisUtils;

    private DefaultRedisScript<Long> redisScript;

    @PostConstruct
    public void init() {
        redisScript = new DefaultRedisScript<>();
        redisScript.setResultType(Long.class);
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("limit.lua")));
    }

    /**
     * @param key       key
     * @param limit     限制次数
     * @param limitTime 限制时间,单位秒
     * @return true 表示超过阈值
     */
    public boolean limit(String key,int limit,int limitTime) {
        return limit(key, limit, limitTime,TimeUnit.SECONDS);
    }

    public boolean limit(String key, int limit, int limitTime, TimeUnit unit) {
        switch (unit) {
            case DAYS:
                limitTime = limitTime * 24 * 60 * 60;
                break;
            case HOURS:
                limitTime = limitTime * 60 * 60;
                break;
            case MINUTES:
                limitTime = limitTime * 60;
                break;
            case MILLISECONDS:
                limitTime = limitTime / 1000;
                break;    
        }
        long nowTime = System.currentTimeMillis() / 1000;
        String uuid = UUID.randomUUID().toString();
        return redisUtils.execLimitLua(redisScript, Collections.singletonList(key),String.valueOf(limitTime),String.valueOf(nowTime),String.valueOf(limit),uuid);
    }

    /**
     * key生成
     * @param biz 业务
     * @param unionId 唯一id
     * @return
     */
    public String generateKey(int biz,String unionId) {
        return PRE + "_" + biz + "_" + unionId;
    }

}
  • 滑动窗口算法Lua实现
--KEYS[1]: 限流 key
--ARGV[1]: 限流窗口,单位秒
--ARGV[2]: 当前时间戳(作为score)
--ARGV[3]: 阈值
--ARGV[4]: score 对应的唯一value
-- 1. 移除已过期的数据
redis.call('zremrangeByScore', KEYS[1], 0, ARGV[2]-ARGV[1])
-- 2. 统计当前元素数量
local res = redis.call('zcard', KEYS[1])
-- 3. 是否超过阈值
if (res == nil) or (res < tonumber(ARGV[3])) then
    redis.call('zadd', KEYS[1], ARGV[2], ARGV[4])
    redis.call('expire', KEYS[1], ARGV[1])
    return 0
else
    return 1
end

测试

@RunWith(SpringRunner.class)
@SpringBootTest
public class RedisLimitTest {

    int limit = 10;

    @Autowired
    private LimitService limitService;

    /**
     * 手机号 135***997 在 1小时内限制10次发送
     */
    @Test
    public void testLimitService() {
        String key = limitService.generateKey(1,"135*****997");
        for(int i=0;i<11;i++) {
            System.out.println("是否限流:" + limitService.limit(key,limit,10, TimeUnit.HOURS));
        }
    }

}

可以看到,第11次调用时发生限流:

jh4wqJ.png

优化

在使用限流服务时,需要关联使用上LimitService服务。这很显然地会将此逻辑与真正的业务逻辑相耦合。为此可以使用注解的方式来优化使用,通过注解+Aop的方式将限流逻辑与业务逻辑拆分开。

注解优化

假设场景是同一个ip在单位时间内限制访问N次

  • 定义注解,添加限制次数、限制时间、限制时间单位属性

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface ControllerLimiter {

    /**
     * 限制次数
     * @return
     */
    int limit() default 1;

    /**
     * 限制时间
     */
    int limitTime() default 1;

    /**
     * 时间单位
     */
    TimeUnit unit() default TimeUnit.SECONDS;


}
  • 定义Aop,拦截有注解@ControllerLimiter的方法,进行限流;当超过阈值时,抛出异常。

ps:这里的异常实际上用自定义的异常比较好些,然后通过全局捕获该异常做特殊的输出返回。

@Aspect
@Component
@EnableAspectJAutoProxy(proxyTargetClass = true)
@Slf4j
public class ControllerLimiterAspect {

    private final static int BIZ = 1;   // Controller层

    @Autowired
    private LimitService limitService;

    @Pointcut("@annotation(com.example.springredisdemo.annotation.ControllerLimiter)")
    private void check() {}

    @Before("check()")
    public void before(JoinPoint joinPoint) throws Exception {
        MethodSignature methodSignature = (MethodSignature)joinPoint.getSignature();
        ControllerLimiter rateLimiter = methodSignature.getMethod().getAnnotation(ControllerLimiter.class);
        if(rateLimiter != null) {
            int limit = rateLimiter.limit();
            int limitTime = rateLimiter.limitTime();
            TimeUnit unit = rateLimiter.unit();

            ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            HttpServletRequest request =  attributes.getRequest();
            String ip = IpUtils.getRequestIp(request);

            String key = limitService.generateKey(BIZ,ip);
            if(limitService.limit(key,limit,limitTime,unit)) {
                log.info("method has been limited;ip:{},limit:{},limitTime:{}",ip,limit,limitTime);
                throw new RuntimeException("method has been limited");
            }
        }
    }
}
  • 定义Controller测试接口
@RestController
public class IndexController {

    /**
     * 同一个ip 在10s 内最多调用10次
     * @return
     */
    @ControllerLimiter(limit = 10,limitTime = 10)
    @RequestMapping("/index")
    public String index() {
        return "Hello World";
    }
    
}

访问该地址,当10内访问次数大于10次时就会返回RuntimeException异常。 jLRCut.png

总结

通过Lua实现滑动窗口算法来可以实现Redis的限流,此外还可以通过注解去优化限流服务的使用,达到解耦的作用。在该例子中,还存在几个可以优化的地方:

  • 直接抛异常体验不好;当超过阈值时,应当抛出自定义的异常,然后通过全局异常捕获该信息,来保证返回结果的正确性;
  • 限流实现算法单一;当滑动窗口不满足场景时应当可选择不同类型的方法实现。为此可以在注解中添加算法类型属性,在拦截中通过算法类型实例化出不同的算法实现。
  • 限流策略单一;这里仅通过ip去判断方法的调用量,而在实际业务中可能存在需要针对某个用户进行调用限制,或者售卖物品中某个商品的下单限制,那么就需要使用其它策略来进行业务判断。