Redis分布式锁 基于lettuce

796 阅读3分钟
  • 最近在一个项目中用到了 Redis 分布式锁，于是参考 Redisson 简化了其分布式锁的逻辑，修复了 Redisson 的一个并发 bug，实现了一个可重入的分布式锁。
  • 还有以下几点可以优化：
  1. 加锁成功后再加一层本地缓存，同一线程重入加锁时可以直接走缓存，无需再与 Redis 进行网络交互。
  2. 进程关闭时可以通过 shutdownHook 尽快释放本进程已成功获取的锁（shutdownHook 只在进程正常退出时执行，kill -9 等强制终止不会触发）。

等有时间再回来优化下。


import com.xkh.util.StringUtil;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.api.sync.RedisCommands;
import io.netty.util.HashedWheelTimer;
import io.netty.util.Timeout;
import io.netty.util.concurrent.DefaultThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Reentrant distributed lock backed by Redis (Lettuce sync commands), modelled on Redisson's lock scripts.
 *
 * <p>Each lock is a Redis hash stored at {@code redisLockingKeyPrefix + ":" + lockKey}. The hash field is
 * {@code clientId + ":" + threadId} and its value is the reentrancy count. While a lock taken through
 * {@link #handleInLock} is held, a watchdog scheduled on a {@link HashedWheelTimer} re-applies the key's
 * PEXPIRE every third of the lease time.
 *
 * @author kaihui.xia
 * @since 2021/3/15
 */
public class RedisDistributeLock {
    private static final Logger LOGGER = LoggerFactory.getLogger(RedisDistributeLock.class);

    private final RedisCommands redisCommands;

    // Key prefix, usually appName + business-flow identifier.
    private final String redisLockingKeyPrefix;

    // Random id of this lock instance; combined with the thread id it forms the Redis hash field.
    private final String clientId = UUID.randomUUID().toString();

    // Daemon thread factory so the watchdog timer never keeps the JVM alive on its own.
    private final HashedWheelTimer timer = new HashedWheelTimer(new DefaultThreadFactory("reids-watchdog-timer", true), 50, TimeUnit.MILLISECONDS, 1024, false);

    // Watchdog bookkeeping per real Redis key: local holder threads and the pending renewal task.
    private static final ConcurrentHashMap<String, ExpirationEntry> LOCKED_CACHE = new ConcurrentHashMap<>();

    /**
     * Default lease time of the Redis lock key, in milliseconds.
     */
    private static final long DEFAULT_LOCK_REDIS_EXPIRE_MILLISECONDS = 200L;

    /**
     * Default pause between two lock attempts, in milliseconds.
     */
    private static final long DEFAULT_RETRY_DELAY_MILLISECONDS = 100L;

    /**
     * @param redisCommands         Lettuce sync commands; keys and script arguments are passed as Strings,
     *                              so a String-valued codec is assumed
     * @param redisLockingKeyPrefix prefix of every Redis lock key, must not be blank
     * @throws IllegalArgumentException if an argument is null/blank
     */
    public RedisDistributeLock(RedisCommands redisCommands, String redisLockingKeyPrefix) {
        if (null == redisCommands || StringUtil.isBlank(redisLockingKeyPrefix)) {
            throw new IllegalArgumentException("invalid params, redisLockingKeyPrefix: " + redisLockingKeyPrefix
                                                                + " redisCommands: " + redisCommands);
        }
        this.redisCommands = redisCommands;
        this.redisLockingKeyPrefix = redisLockingKeyPrefix;
        Runtime.getRuntime().addShutdownHook(new Thread(() -> {
            // Stop renewing on shutdown; locks still held will then expire after their lease.
            timer.stop();
            LOGGER.info("shut down successfully to recycle RedisDistributeLock watchdog timer");
        }));
    }

    /**
     * Runs {@code action} under the distributed lock {@code lockKey}, using the default retry delay
     * and lease time.
     *
     * @see #handleInLock(String, Callable, Callable, int, long, long)
     */
    public <V> V handleInLock(String lockKey, Callable<V> action, Callable<LockFailActionResult<V>> lockFailAction, int retryCount) throws Exception {
        return handleInLock(lockKey, action, lockFailAction, retryCount, DEFAULT_RETRY_DELAY_MILLISECONDS, DEFAULT_LOCK_REDIS_EXPIRE_MILLISECONDS);
    }

    /**
     * Makes up to {@code retryCount} attempts to acquire the lock and runs {@code action} while holding it.
     *
     * <p>After each failed attempt {@code lockFailAction} is consulted. A {@code BREAK} result ends the loop
     * and its result is returned. Any other result (or null) waits {@code delayMillSeconds} and tries again,
     * unless no attempts are left. After {@code action} finishes the lock is released on a best-effort basis.
     *
     * @param lockKey                business key, appended to the prefix to form the Redis key
     * @param action                 work to run while holding the lock
     * @param lockFailAction         called after every failed attempt; null means "continue" with a null result
     * @param retryCount             maximum number of lock attempts; 0 makes no attempt and returns null
     * @param delayMillSeconds       pause between attempts, in milliseconds
     * @param lockExpireMilliSeconds lease time of the Redis key in milliseconds (must be positive)
     * @return the action's result, the result of a {@code BREAK} from {@code lockFailAction}, or null when
     *         the lock was never acquired (also when the waiting thread is interrupted; its interrupt
     *         flag is restored in that case)
     * @throws IllegalArgumentException on invalid arguments
     * @throws Exception                whatever {@code action} or {@code lockFailAction} throws
     */
    public <V> V handleInLock(String lockKey, Callable<V> action, Callable<LockFailActionResult<V>> lockFailAction,
                              int retryCount, long delayMillSeconds, long lockExpireMilliSeconds) throws Exception {
        // A lease of 0 would make PEXPIRE delete the key immediately after it is taken.
        if (StringUtil.isBlank(lockKey) || null == action || retryCount < 0 || delayMillSeconds < 0 || lockExpireMilliSeconds <= 0) {
            throw new IllegalArgumentException("illegal argument lockKey:" + lockKey
                                                                + " action:" + action
                                                                + " retryCount:" + retryCount
                                                                + " delayMillSeconds" + delayMillSeconds
                                                                + " lockExpireMilliSeconds" + lockExpireMilliSeconds);
        }
        if (null == lockFailAction) {
            lockFailAction = () -> new LockFailActionResult<>(null);
        }

        RedisLock lock = obtainLockSubject(lockKey, lockExpireMilliSeconds);
        while (retryCount-- > 0) {
            if (lock.tryLock()) {
                try {
                    return action.call();
                } finally {
                    releaseQuietly(lock);
                }
            }
            LOGGER.warn("try lock failed, redisLockingKeyPrefix:{}, lock:{}", redisLockingKeyPrefix, lock);
            LockFailActionResult<V> lockFailActionResult = lockFailAction.call();
            if (null != lockFailActionResult && LOCK_CONTROL_ENUM.BREAK == lockFailActionResult.getLock_control_enum()) {
                return lockFailActionResult.getResult();
            }
            if (retryCount == 0) {
                // That was the last attempt; sleeping now would only delay the null result.
                break;
            }
            try {
                Thread.sleep(delayMillSeconds);
            } catch (InterruptedException e) {
                LOGGER.warn("try lock interrupted, redisLockingKeyPrefix:{}, lock:{}, e:{}", redisLockingKeyPrefix, lock, e.getMessage());
                // Honour the interruption: restore the flag and stop retrying.
                Thread.currentThread().interrupt();
                break;
            }
        }

        return null;
    }

    /**
     * Releases {@code lock} after the guarded action ran. Failures are logged rather than thrown so that
     * an unlock problem never replaces the action's own result or exception. The unlock bookkeeping still
     * stops the watchdog, so a key that could not be deleted expires after its lease.
     */
    private void releaseQuietly(RedisLock lock) {
        try {
            lock.unLock();
        } catch (IllegalStateException e) {
            // The hash field was missing at unlock time: the key expired (1) or was removed by a Redis
            // restart / manual delete (2). Either way the resource may have been accessed concurrently.
            LOGGER.error("no  value in redis, redisLockingKeyPrefix:{},lock:{} e:{}", redisLockingKeyPrefix, lock.toString(), e.getMessage());
        } catch (RuntimeException e) {
            LOGGER.error("unlock failed, redisLockingKeyPrefix:{}, lock:{}", redisLockingKeyPrefix, lock, e);
        }
    }

    /**
     * Creates the lock object used for one {@code handleInLock} call.
     *
     * @return a new {@link RedisLock} bound to {@code lockKey}
     */
    private RedisLock obtainLockSubject(String lockKey, long lockExpireMilliSeconds) {
        return new RedisLock(lockKey, lockExpireMilliSeconds);
    }

    /**
     * Watchdog state of one Redis lock key. It holds the ids of the local threads currently holding the
     * lock, with reentrancy counts in insertion order, and the pending renewal task.
     */
    public static class ExpirationEntry {
        private final Map<Long, Integer> threadIds = new LinkedHashMap<>();
        private volatile Timeout timeout;

        public ExpirationEntry() {
            super();
        }

        /** Records one more hold by {@code threadId}. */
        public synchronized void addThreadId(long threadId) {
            Integer counter = threadIds.get(threadId);
            if (counter == null) {
                counter = 1;
            } else {
                counter++;
            }
            threadIds.put(threadId, counter);
        }

        /** @return true when no local thread holds the lock any more */
        public synchronized boolean hasNoThreads() {
            return threadIds.isEmpty();
        }

        /** @return the earliest-registered holder thread id, or null when there is none */
        public synchronized Long getFirstThreadId() {
            if (threadIds.isEmpty()) {
                return null;
            }
            return threadIds.keySet().iterator().next();
        }

        /** Drops one hold by {@code threadId}; forgets the thread when its count reaches zero. */
        public synchronized void removeThreadId(long threadId) {
            Integer counter = threadIds.get(threadId);
            if (counter == null) {
                return;
            }
            counter--;
            if (counter == 0) {
                threadIds.remove(threadId);
            } else {
                threadIds.put(threadId, counter);
            }
        }


        public void setTimeout(Timeout timeout) {
            this.timeout = timeout;
        }
        public Timeout getTimeout() {
            return timeout;
        }
    }

    private final class RedisLock {
        /**
         * Real Redis key of the lock: redisLockingKeyPrefix + ":" + lockKey.
         */
        private final String lockKey;

        /**
         * Lease (ms) applied with PEXPIRE on lock, renewal and partial (reentrant) unlock.
         */
        private final long lockExpireMilliSeconds;

        private RedisLock(String path, long lockExpireMilliSeconds) {
            this.lockKey = constructLockKey(path);
            this.lockExpireMilliSeconds = lockExpireMilliSeconds;
        }

        private String constructLockKey(String path) {
            return RedisDistributeLock.this.redisLockingKeyPrefix + ":" + path;
        }

        /**
         * Makes a single attempt to acquire (or re-enter) the lock for the current thread.
         *
         * @return true if the lock is now held by the current thread, false if another owner holds it
         * @throws RuntimeException if the Redis call fails or the caller is interrupted while waiting
         */
        public boolean tryLock() {
            long threadId = Thread.currentThread().getId();
            CompletableFuture<Long> future = tryLock(threadId, -1, TimeUnit.MILLISECONDS);

            try {
                return null == future.get();
            } catch (InterruptedException e) {
                // The attempt may still succeed in the background. Release it then, otherwise the
                // watchdog would keep an orphaned lock alive forever.
                future.thenAccept(ttl -> {
                    if (null == ttl) {
                        unLock(threadId);
                    }
                });
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (ExecutionException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Runs the lock script asynchronously. A leaseTime of -1 uses the configured lease and starts the
         * watchdog on success.
         *
         * @return future of the remaining TTL (ms) of someone else's lock, or null when acquired
         */
        private CompletableFuture<Long> tryLock(long threadId, long leaseTime, TimeUnit unit) {
            CompletableFuture<Long> ttlRemainFuture = -1 != leaseTime
                    ? tryInnerLock(threadId, leaseTime, unit)
                    : tryInnerLock(threadId, lockExpireMilliSeconds, TimeUnit.MILLISECONDS);

            // Return the dependent stage, not ttlRemainFuture itself. That way the watchdog entry is
            // registered before anyone waiting on the result goes on to run the action or unlock.
            return ttlRemainFuture.whenComplete((ttl, e) -> {
                if (null != e || null != ttl || -1 != leaseTime) {
                    // Failed, held by another owner, or an explicit lease (no watchdog).
                    return;
                }
                ExpirationEntry expirationEntry = new ExpirationEntry();
                ExpirationEntry old = LOCKED_CACHE.putIfAbsent(lockKey, expirationEntry);
                if (null != old) {
                    old.addThreadId(threadId);
                } else {
                    expirationEntry.addThreadId(threadId);
                    renewExpiration();
                }
            });
        }

        /**
         * Schedules one watchdog tick after a third of the lease. The tick re-applies PEXPIRE if the first
         * registered holder still owns the hash field, and reschedules itself while that succeeds.
         */
        private void renewExpiration() {
            ExpirationEntry expirationEntry = LOCKED_CACHE.get(lockKey);
            if (null == expirationEntry) {
                return;
            }
            Timeout task = timer.newTimeout(timeout -> {
                ExpirationEntry entry = LOCKED_CACHE.get(lockKey);
                if (null == entry) {
                    return;
                }

                Long tid = entry.getFirstThreadId();
                if (null == tid) {
                    return;
                }

                // KEYS must be passed explicitly; otherwise Lettuce's eval(script, type, K... keys)
                // overload is chosen and the arguments end up in KEYS instead of ARGV.
                CompletableFuture.supplyAsync(() -> (Boolean) redisCommands.eval(
                                "if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then " +
                                "redis.call('pexpire', KEYS[1], ARGV[1]); " +
                                "return 1; " +
                                "end; " +
                                "return 0;", ScriptOutputType.BOOLEAN,
                                new String[]{lockKey}, String.valueOf(lockExpireMilliSeconds), getLockName(tid)))
                        .whenComplete((res, ex) -> {
                            if (null != ex) {
                                LOGGER.error("can't update lock: " + lockKey + " expiration", ex);
                                LOCKED_CACHE.remove(lockKey, entry);
                                return;
                            }

                            if (Boolean.TRUE.equals(res)) {
                                renewExpiration();
                            }
                        });

            }, lockExpireMilliSeconds / 3, TimeUnit.MILLISECONDS);
            expirationEntry.setTimeout(task);
        }

        /**
         * Lock script: create or re-enter the hash field (count + 1) and apply the lease. It returns nil when
         * the lock is now held and the key's PTTL otherwise. INTEGER output is used because the reply is
         * either nil or an integer.
         */
        private CompletableFuture<Long> tryInnerLock(long threadId, long leaseTime, TimeUnit unit) {
            return CompletableFuture.supplyAsync(() -> (Long) redisCommands.eval(
                    "if (redis.call('exists', KEYS[1]) == 0) then " +
                    "redis.call('hincrby', KEYS[1], ARGV[2], 1); " +
                    "redis.call('pexpire', KEYS[1], ARGV[1]); " +
                    "return nil; " +
                    "end; " +
                    "if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then " +
                    "redis.call('hincrby', KEYS[1], ARGV[2], 1); " +
                    "redis.call('pexpire', KEYS[1], ARGV[1]); " +
                    "return nil; " +
                    "end; " +
                    "return redis.call('pttl', KEYS[1]);",
                    ScriptOutputType.INTEGER, new String[]{lockKey},
                    String.valueOf(unit.toMillis(leaseTime)), getLockName(threadId)));
        }

        private String getLockName(long threadId) {
            return clientId + ":" + threadId;
        }

        /**
         * Releases one hold of the lock by the current thread and waits for the result.
         * {@code join()} is not interruptible, so an interrupted caller still releases the lock.
         *
         * @throws IllegalStateException if the current thread no longer owns the Redis hash field
         *                               (e.g. the key expired)
         * @throws java.util.concurrent.CompletionException if the Redis call fails
         */
        public void unLock() {
            long threadId = Thread.currentThread().getId();
            Boolean released = unLock(threadId).join();
            if (null == released) {
                throw new IllegalStateException("Lock was released in the store due to expiration. " +
                        "The integrity of data protected by this lock may have been compromised.");
            }
        }

        /**
         * Unlock script: decrement the holder's count and delete the key at zero, otherwise refresh the lease.
         * The watchdog bookkeeping for the thread is updated whatever the outcome.
         *
         * @return future of true (key deleted), false (still held re-entrantly), or null (not held by thread)
         */
        private CompletableFuture<Boolean> unLock(Long threadId) {
            return CompletableFuture.supplyAsync(() -> {
                Long res = (Long) redisCommands.eval(
                        "if (redis.call('hexists', KEYS[1], ARGV[2]) == 0) then " +
                        "return nil;" +
                        "end; " +
                        "local counter = redis.call('hincrby', KEYS[1], ARGV[2], -1); " +
                        "if (counter > 0) then " +
                        "redis.call('pexpire', KEYS[1], ARGV[1]); " +
                        "return 0; " +
                        "else " +
                        "redis.call('del', KEYS[1]); " +
                        "return 1; " +
                        "end; " +
                        "return nil;", ScriptOutputType.INTEGER, new String[]{lockKey},
                        String.valueOf(lockExpireMilliSeconds), getLockName(threadId));
                if (null == res) {
                    return null;
                }
                return res == 1L;
            }).whenComplete((res, e) -> {
                ExpirationEntry task = LOCKED_CACHE.get(lockKey);
                if (task == null) {
                    return;
                }

                if (threadId != null) {
                    task.removeThreadId(threadId);
                }

                if (threadId == null || task.hasNoThreads()) {
                    Timeout timeout = task.getTimeout();
                    if (timeout != null) {
                        timeout.cancel();
                    }
                    // Only remove this entry; a new holder may already have registered a fresh one.
                    LOCKED_CACHE.remove(lockKey, task);
                }
            });
        }

        @Override
        public String toString() {
            return "RedisLock{" +
                    "lockKey='" + lockKey + '\'' +
                    ", lockExpireMilliSeconds=" + lockExpireMilliSeconds +
                    '}';
        }
    }



    public enum LOCK_CONTROL_ENUM {
        /**
         * Tells handleInLock whether to keep trying (CONTINUE) or stop (BREAK) after a failed attempt.
         */
        CONTINUE, BREAK;
    }

    /**
     * Outcome of a lock-failure callback: a control decision (CONTINUE by default) and the value returned
     * by handleInLock on BREAK.
     */
    public static class LockFailActionResult<T> {
        private T result;

        private LOCK_CONTROL_ENUM lock_control_enum = LOCK_CONTROL_ENUM.CONTINUE;

        public LockFailActionResult(LOCK_CONTROL_ENUM lock_control_enum, T result) {
            this.lock_control_enum = lock_control_enum;
            this.result = result;
        }

        public LockFailActionResult(T result) {
            this.result = result;
        }

        public T getResult() {
            return result;
        }

        public void setResult(T result) {
            this.result = result;
        }

        public LOCK_CONTROL_ENUM getLock_control_enum() {
            return lock_control_enum;
        }
    }
}