- 最近在一个项目中用到了 Redis 分布式锁。这里简化了 Redisson 分布式锁的逻辑,修复了 Redisson 的一个并发 bug,并实现了一个可重入的分布式锁。
- 还有一些点可以优化
- 上锁以后再加上一层cache,再次加锁可以通过缓存走,就不需要与redis的网络交互了。
- 宕机时可以加上shutdownHook处理,尽快释放本进程加锁成功的锁。
等有时间再回来优化下。
import com.xkh.util.StringUtil;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.api.sync.RedisCommands;
import io.netty.util.HashedWheelTimer;
import io.netty.util.Timeout;
import io.netty.util.concurrent.DefaultThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* @author kaihui.xia
* @since 2021/3/15
*/
public class RedisDistributeLock {
// Class-wide SLF4J logger.
private static final Logger LOGGER = LoggerFactory.getLogger(RedisDistributeLock.class);
// Redis command handle on which the lock, unlock and renewal Lua scripts are evaluated.
private RedisCommands redisCommands;
// Key prefix, usually the application name plus a business-flow identifier.
private String redisLockingKeyPrefix;
// Random identifier generated once per instance; combined with a thread id it names the lock owner.
private final String clientId = UUID.randomUUID().toString();
// Timer on which the periodic lock-renewal tasks are scheduled.
// NOTE(review): the arguments look like tickDuration=50ms, ticksPerWheel=1024, leakDetection=false;
// confirm against the Netty HashedWheelTimer constructor.
private final HashedWheelTimer timer = new HashedWheelTimer(new DefaultThreadFactory("reids-watchdog-timer"), 50, TimeUnit.MILLISECONDS, 1024, false);
// Renewal bookkeeping keyed by the full Redis lock key. It is static, so it is shared by every
// RedisDistributeLock instance in this JVM.
private static final ConcurrentHashMap<String, ExpirationEntry> LOCKED_CACHE = new ConcurrentHashMap<>();
// Pool whose threads are named "RedisDistributeLock-unLock-Pool<N>"; it is shut down by the JVM
// shutdown hook registered in the constructor.
// NOTE(review): the work queue is unbounded; per the ThreadPoolExecutor Javadoc the maximum pool
// size then has no effect. Confirm this sizing is intended.
private ExecutorService executor = new ThreadPoolExecutor(0,
Runtime.getRuntime().availableProcessors(),
60, TimeUnit.SECONDS,
new LinkedBlockingQueue<>(),
new ThreadFactory() {
// Sequence number appended to each new thread's name, starting at 1.
private final AtomicInteger threadNumber = new AtomicInteger(1);
private final String namePrefix = "RedisDistributeLock-unLock-Pool";
@Override
public Thread newThread(Runnable r) {
// New threads join the thread group of the thread that triggers their creation; stackSize is 0.
return new Thread(Thread.currentThread().getThreadGroup(), r,
namePrefix + threadNumber.getAndIncrement(),
0);
}
}, new ThreadPoolExecutor.CallerRunsPolicy());
/**
* Default timeout of the Redis lock, in milliseconds.
*/
private static final long DEFAULT_LOCK_REDIS_EXPIRE_MILLISECONDS = 200L;
/**
* Default interval between distributed-lock attempts, in milliseconds.
*/
private static final long DEFAULT_RETRY_DELAY_MILLISECONDS = 100L;
public RedisDistributeLock(RedisCommands redisCommands, String redisLockingKeyPrefix) {
if (null == redisCommands || StringUtil.isBlank(redisLockingKeyPrefix)) {
throw new IllegalArgumentException("invalid params, redisLockingKeyPrefix: " + redisLockingKeyPrefix
+ " redisCommands: " + redisCommands);
}
this.redisCommands = redisCommands;
this.redisLockingKeyPrefix = redisLockingKeyPrefix;
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
executor.shutdown();
LOGGER.info("shut down successfully to recycle RedisDistributeLock-unLock-Pool");
}));
}
/**
 * Convenience overload of the six-argument {@code handleInLock} that supplies the default
 * retry delay ({@code DEFAULT_RETRY_DELAY_MILLISECONDS}) and the default lock expiry
 * ({@code DEFAULT_LOCK_REDIS_EXPIRE_MILLISECONDS}).
 *
 * @param lockKey        business key of the lock
 * @param action         work to run while the lock is held
 * @param lockFailAction callback invoked after a failed lock attempt; may be null
 * @param retryCount     retry budget forwarded unchanged
 * @return whatever the six-argument overload returns
 * @throws Exception whatever the six-argument overload throws
 */
public <V> V handleInLock(String lockKey, Callable<V> action, Callable<LockFailActionResult<V>> lockFailAction, int retryCount) throws Exception {
    final long retryDelayMillis = DEFAULT_RETRY_DELAY_MILLISECONDS;
    final long lockTtlMillis = DEFAULT_LOCK_REDIS_EXPIRE_MILLISECONDS;
    return handleInLock(lockKey, action, lockFailAction, retryCount, retryDelayMillis, lockTtlMillis);
}
/**
 * Runs {@code action} while holding the distributed lock for {@code lockKey}, retrying the
 * acquisition a bounded number of times.
 *
 * <p>After every failed attempt {@code lockFailAction} is invoked. If it returns a result whose
 * control value is {@code BREAK}, that result's payload is returned immediately; any other
 * outcome (including a null result) means "keep trying".
 *
 * @param lockKey                business key; must not be blank
 * @param action                 work to run while the lock is held; must not be null
 * @param lockFailAction         callback invoked after each failed attempt; null means "always retry"
 * @param retryCount             maximum number of lock attempts, must be &gt;= 0; a value of 0 is
 *                               treated as a single attempt so the call is never a silent no-op
 * @param delayMillSeconds       pause between two attempts in milliseconds, must be &gt;= 0
 * @param lockExpireMilliSeconds expiry of the Redis lock key in milliseconds, must be &gt;= 0
 * @return the value produced by {@code action}; or the payload of a {@code BREAK} result; or
 *         {@code null} when every attempt failed or the waiting thread was interrupted
 * @throws IllegalArgumentException if an argument violates the constraints above
 * @throws Exception                anything thrown by {@code action} or {@code lockFailAction}
 */
public <V> V handleInLock(String lockKey, Callable<V> action, Callable<LockFailActionResult<V>> lockFailAction,
                          int retryCount, long delayMillSeconds, long lockExpireMilliSeconds) throws Exception {
    if (StringUtil.isBlank(lockKey) || null == action || retryCount < 0 || delayMillSeconds < 0 || lockExpireMilliSeconds < 0) {
        throw new IllegalArgumentException("illegal argument lockKey:" + lockKey
                + " action:" + action
                + " retryCount:" + retryCount
                + " delayMillSeconds" + delayMillSeconds
                + " lockExpireMilliSeconds" + lockExpireMilliSeconds);
    }
    if (null == lockFailAction) {
        lockFailAction = () -> new LockFailActionResult<>(null);
    }
    RedisLock lock = obtainLockSubject(lockKey, lockExpireMilliSeconds);
    // retryCount == 0 passes validation; previously it skipped the loop and returned null without
    // ever trying the lock. Positive values keep their old meaning (number of attempts).
    int attemptsLeft = Math.max(1, retryCount);
    while (attemptsLeft-- > 0) {
        if (lock.tryLock()) {
            try {
                return action.call();
            } finally {
                try {
                    lock.unLock();
                } catch (IllegalStateException e) {
                    // The key was gone from Redis at unlock time (it expired, or it was removed by a
                    // Redis restart or by hand). Either way the protected resource may have been
                    // accessed concurrently, so this is logged as an error.
                    LOGGER.error("no value in redis, redisLockingKeyPrefix:{},lock:{} e:{}", redisLockingKeyPrefix, lock.toString(), e.getMessage());
                }
            }
        }
        LOGGER.warn("try lock failed, redisLockingKeyPrefix:{}, lock:{}", redisLockingKeyPrefix, lock);
        LockFailActionResult<V> lockFailActionResult = lockFailAction.call();
        // Null-safe replacement for the former switch: a null result or null control value no
        // longer throws NullPointerException, it simply means CONTINUE.
        if (null != lockFailActionResult
                && LOCK_CONTROL_ENUM.BREAK == lockFailActionResult.getLock_control_enum()) {
            return lockFailActionResult.getResult();
        }
        if (attemptsLeft <= 0) {
            // No further attempt will follow, so sleeping would only delay the caller.
            break;
        }
        try {
            Thread.sleep(delayMillSeconds);
        } catch (InterruptedException e) {
            LOGGER.warn("try lock interrupted, redisLockingKeyPrefix:{}, lock:{}, e:{}", redisLockingKeyPrefix, lock, e.getMessage());
            // Preserve the interrupt for the caller and stop retrying instead of swallowing it.
            Thread.currentThread().interrupt();
            break;
        }
    }
    return null;
}
/**
 * Builds the lock object used for one {@code handleInLock} invocation.
 *
 * @param lockKey                business key (the prefix is added by the lock object itself)
 * @param lockExpireMilliSeconds expiry of the Redis key in milliseconds
 * @return a new {@code RedisLock}; a fresh instance is created on every call
 */
private RedisLock obtainLockSubject(String lockKey, long lockExpireMilliSeconds) {
    final RedisLock subject = new RedisLock(lockKey, lockExpireMilliSeconds);
    return subject;
}
/**
 * Renewal bookkeeping for one lock key: which local threads hold the lock (and how many times
 * each), plus the handle of the currently scheduled renewal task.
 *
 * <p>All access to the thread table goes through synchronized methods; the timeout handle is a
 * volatile field.
 */
public static class ExpirationEntry {
    /** Hold count per thread id, iterated in first-insertion order. */
    private final Map<Long, Integer> threadIds = new LinkedHashMap<>();
    /** Most recently scheduled renewal task, or null if none was recorded. */
    private volatile Timeout timeout;

    public ExpirationEntry() {
    }

    /** Records one more hold for {@code threadId} (1 on first use, otherwise incremented). */
    public synchronized void addThreadId(long threadId) {
        threadIds.merge(threadId, 1, Integer::sum);
    }

    /** @return true when no thread is recorded as holding the lock */
    public synchronized boolean hasNoThreads() {
        return threadIds.size() == 0;
    }

    /** @return the id of the earliest-inserted thread still present, or null if there is none */
    public synchronized Long getFirstThreadId() {
        return threadIds.isEmpty() ? null : threadIds.keySet().iterator().next();
    }

    /**
     * Drops one hold of {@code threadId}; the thread is forgotten once its count reaches zero.
     * Unknown thread ids are ignored.
     */
    public synchronized void removeThreadId(long threadId) {
        threadIds.computeIfPresent(threadId, (id, holds) -> holds == 1 ? null : holds - 1);
    }

    public void setTimeout(Timeout timeout) {
        this.timeout = timeout;
    }

    public Timeout getTimeout() {
        return this.timeout;
    }
}
private final class RedisLock {
/**
 * Actual Redis key of this lock: {@code redisLockingKeyPrefix + ":" + lockKey}.
 */
private final String lockKey;
/**
 * Expiry of the Redis key in milliseconds. It is assigned once in the constructor and never
 * changed, so it is a final primitive rather than a volatile boxed value.
 */
private final long lockExpireMilliSeconds;

/**
 * @param path                   business key supplied by the caller (without prefix)
 * @param lockExpireMilliSeconds expiry of the Redis key in milliseconds
 */
private RedisLock(String path, long lockExpireMilliSeconds) {
    this.lockKey = constructLockKey(path);
    this.lockExpireMilliSeconds = lockExpireMilliSeconds;
}

/** Prepends the enclosing instance's key prefix and a ':' separator to {@code path}. */
private String constructLockKey(String path) {
    return RedisDistributeLock.this.redisLockingKeyPrefix + ":" + path;
}
/**
 * Makes one blocking attempt to take the lock for the calling thread, using the lock's own
 * expiry (lease time -1).
 *
 * @return true if the attempt reported no remaining TTL, i.e. the lock was obtained
 * @throws RuntimeException wrapping the {@code InterruptedException} (the interrupt flag is
 *                          restored first) or the {@code ExecutionException} of the attempt
 */
public boolean tryLock() {
    final long callerThreadId = Thread.currentThread().getId();
    final CompletableFuture<Long> pendingTtl = tryLock(callerThreadId, -1, TimeUnit.MILLISECONDS);
    try {
        final Long remainingTtl = pendingTtl.get();
        return remainingTtl == null;
    } catch (InterruptedException | ExecutionException failure) {
        if (failure instanceof InterruptedException) {
            // Keep the interrupt visible to callers further up the stack.
            Thread.currentThread().interrupt();
        }
        throw new RuntimeException(failure);
    }
}
/**
 * Asynchronously attempts to take the lock for {@code threadId}.
 *
 * <p>When {@code leaseTime} is -1 the lock's own expiry is used and, on success, the thread is
 * registered in {@code LOCKED_CACHE}; the first registration for a key also starts the periodic
 * renewal. With an explicit lease time no renewal is set up.
 *
 * <p>The returned future is the stage that completes <em>after</em> that registration has run.
 * Returning the raw script future instead let a caller observe "locked" and even unlock again
 * before the entry was registered, which left a stale entry and renewal task behind.
 *
 * @param threadId  id of the thread that will own the lock
 * @param leaseTime explicit lease time, or -1 to use the lock's own expiry with renewal
 * @param unit      unit of {@code leaseTime}
 * @return future yielding null when the lock was obtained, otherwise the script's TTL reply
 */
private CompletableFuture<Long> tryLock(long threadId, long leaseTime, TimeUnit unit) {
    final boolean renewAutomatically = (-1 == leaseTime);
    final CompletableFuture<Long> ttlRemainFuture = renewAutomatically
            ? tryInnerLock(threadId, lockExpireMilliSeconds, TimeUnit.MILLISECONDS)
            : tryInnerLock(threadId, leaseTime, unit);
    return ttlRemainFuture.whenComplete((ttl, e) -> {
        // Only a successful acquisition (no error, nil TTL) in renewal mode needs bookkeeping.
        if (null != e || null != ttl || !renewAutomatically) {
            return;
        }
        ExpirationEntry expirationEntry = new ExpirationEntry();
        // Populate the entry before publishing it so no other thread can see it empty.
        expirationEntry.addThreadId(threadId);
        ExpirationEntry old = LOCKED_CACHE.putIfAbsent(lockKey, expirationEntry);
        if (null != old) {
            // Renewal is already running for this key; just record the additional hold.
            old.addThreadId(threadId);
        } else {
            renewExpiration();
        }
    });
}
/**
 * Schedules the next renewal of this lock's expiry on the timer, one third of the expiry from
 * now (at least 1 ms).
 *
 * <p>When the task fires it looks up the first registered thread and runs a script that
 * re-applies the expiry if that owner's field still exists in the hash. A positive reply
 * schedules the following renewal; a negative reply or an error drops the bookkeeping entry so
 * that a later acquisition starts a fresh renewal cycle.
 */
private void renewExpiration() {
    ExpirationEntry expirationEntry = LOCKED_CACHE.get(lockKey);
    if (null == expirationEntry) {
        return;
    }
    // Guard against a zero delay (expiry below 3 ms), which would re-arm the timer continuously.
    final long delayMillis = Math.max(1L, lockExpireMilliSeconds / 3);
    Timeout task = timer.newTimeout(timeout -> {
        ExpirationEntry entry = LOCKED_CACHE.get(lockKey);
        if (null == entry) {
            return;
        }
        Long tid = entry.getFirstThreadId();
        if (null == tid) {
            return;
        }
        CompletableFuture.supplyAsync(() ->
                // The key array must be passed explicitly, as in the lock and unlock scripts;
                // without it the expiry and owner name are sent as keys and ARGV is empty.
                Boolean.TRUE.equals(redisCommands.eval("if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then " +
                        "redis.call('pexpire', KEYS[1], ARGV[1]); " +
                        "return 1; " +
                        "end; " +
                        "return 0;", ScriptOutputType.BOOLEAN, new String[]{lockKey}, lockExpireMilliSeconds, getLockName(tid))))
                .whenComplete((res, ex) -> {
                    if (null != ex) {
                        LOGGER.error("can't update lock: " + lockKey + " expiration", ex);
                        LOCKED_CACHE.remove(lockKey, entry);
                        return;
                    }
                    if (res) {
                        renewExpiration();
                    } else {
                        // The owner's field is gone in Redis: stop renewing and forget the entry,
                        // otherwise the next acquisition would find it and never start a renewal.
                        LOCKED_CACHE.remove(lockKey, entry);
                    }
                });
    }, delayMillis, TimeUnit.MILLISECONDS);
    expirationEntry.setTimeout(task);
}
/**
 * Asynchronously evaluates the acquisition script.
 *
 * <p>The script creates the hash (or increments this owner's field if the owner already holds
 * it), re-applies the expiry and returns nil; if the key is held by a different owner it
 * returns the key's remaining TTL from {@code pttl}.
 *
 * <p>The reply is therefore either nil or an integer, so it is requested as
 * {@code ScriptOutputType.INTEGER}. The previous {@code VALUE} type with a {@code String} cast
 * does not match an integer reply. The conversion below accepts any non-null reply whose
 * {@code toString()} is a decimal number.
 *
 * @param threadId  id of the thread that will own the lock
 * @param leaseTime expiry to apply to the key
 * @param unit      unit of {@code leaseTime}
 * @return future yielding null when the lock was obtained, otherwise the remaining TTL in ms
 */
private CompletableFuture<Long> tryInnerLock(long threadId, long leaseTime, TimeUnit unit) {
    return CompletableFuture.<Long>supplyAsync(() -> {
        Object reply = redisCommands.eval("if (redis.call('exists', KEYS[1]) == 0) then " +
                        "redis.call('hincrby', KEYS[1], ARGV[2], 1); " +
                        "redis.call('pexpire', KEYS[1], ARGV[1]); " +
                        "return nil; " +
                        "end; " +
                        "if (redis.call('hexists', KEYS[1], ARGV[2]) == 1) then " +
                        "redis.call('hincrby', KEYS[1], ARGV[2], 1); " +
                        "redis.call('pexpire', KEYS[1], ARGV[1]); " +
                        "return nil; " +
                        "end; " +
                        "return redis.call('pttl', KEYS[1]);",
                ScriptOutputType.INTEGER, new String[]{lockKey}, unit.toMillis(leaseTime), getLockName(threadId));
        return null == reply ? null : Long.valueOf(reply.toString());
    });
}
/**
 * Builds the owner name used as the hash field inside the lock key.
 *
 * @param threadId id of the owning thread
 * @return {@code clientId + ":" + threadId}
 */
private String getLockName(long threadId) {
    return String.join(":", clientId, Long.toString(threadId));
}
/**
 * Releases one hold of the lock for the calling thread and waits for Redis to confirm.
 *
 * <p>The release script is evaluated exactly once. Previously it ran twice per call, which
 * decremented the re-entrant counter twice, and a string {@code GET} was issued against the
 * lock key, which is a hash. The wait uses {@code join()}, which is not cut short by a pending
 * interrupt, so the release is still confirmed for an interrupted caller.
 *
 * @throws IllegalStateException if the script reports that this thread no longer owns the lock
 *                               (for example because the key expired)
 * @throws RuntimeException      if evaluating the script failed
 */
public void unLock() {
    long threadId = Thread.currentThread().getId();
    final Long released;
    try {
        released = unLock(Long.valueOf(threadId)).join();
    } catch (CompletionException e) {
        Throwable cause = null != e.getCause() ? e.getCause() : e;
        throw new RuntimeException("failed to release lock: " + lockKey + " thread-id: " + threadId, cause);
    }
    if (null == released) {
        throw new IllegalStateException("Lock was released in the store due to expiration. " +
                "The integrity of data protected by this lock may have been compromised.");
    }
}

/**
 * Asynchronously evaluates the release script and then updates the local renewal bookkeeping.
 *
 * <p>Script replies: nil when the owner's field does not exist; 0 when the hold counter is
 * still positive after the decrement (the expiry is re-applied); 1 when the counter reached
 * zero and the key was deleted. The reply is requested as {@code ScriptOutputType.INTEGER} so
 * that nil can be told apart from 0.
 * NOTE(review): this relies on the client mapping a nil reply to null for INTEGER output;
 * confirm against the Lettuce version in use.
 *
 * <p>The bookkeeping runs whether or not the script succeeded: the thread's hold is removed,
 * and once no thread is left the renewal task is cancelled and the entry dropped.
 *
 * @param threadId id of the releasing thread
 * @return future yielding null (not owned), 0 (still held re-entrantly) or 1 (fully released)
 */
private CompletableFuture<Long> unLock(Long threadId) {
    return CompletableFuture.<Long>supplyAsync(() -> {
        Object reply = redisCommands.eval(
                "if (redis.call('hexists', KEYS[1], ARGV[2]) == 0) then " +
                        "return nil;" +
                        "end; " +
                        "local counter = redis.call('hincrby', KEYS[1], ARGV[2], -1); " +
                        "if (counter > 0) then " +
                        "redis.call('pexpire', KEYS[1], ARGV[1]); " +
                        "return 0; " +
                        "else " +
                        "redis.call('del', KEYS[1]); " +
                        "return 1; " +
                        "end; " +
                        "return nil;", ScriptOutputType.INTEGER, new String[]{lockKey}, lockExpireMilliSeconds, getLockName(threadId));
        return null == reply ? null : Long.valueOf(reply.toString());
    }).whenComplete((res, e) -> {
        ExpirationEntry task = LOCKED_CACHE.get(lockKey);
        if (task == null) {
            return;
        }
        if (threadId != null) {
            task.removeThreadId(threadId);
        }
        if (threadId == null || task.hasNoThreads()) {
            Timeout timeout = task.getTimeout();
            if (timeout != null) {
                timeout.cancel();
            }
            // Remove only this exact entry so a newer registration for the key is not lost.
            LOCKED_CACHE.remove(lockKey, task);
        }
    });
}
/** Renders the full Redis key and the expiry of this lock, for log output. */
@Override
public String toString() {
    final StringBuilder text = new StringBuilder("RedisLock{");
    text.append("lockKey='").append(lockKey).append('\'');
    text.append(", lockExpireMilliSeconds=").append(lockExpireMilliSeconds);
    text.append('}');
    return text.toString();
}
}
/**
 * Controls whether the lock is attempted again after a failed attempt.
 */
public enum LOCK_CONTROL_ENUM {
    /** Keep going: wait and attempt the lock again. */
    CONTINUE,
    /** Stop attempting and hand back the result carried by the fail action. */
    BREAK
}
/**
 * Outcome of a lock-fail callback: a control value saying whether to keep attempting the lock,
 * plus an optional payload.
 *
 * @param <T> type of the payload
 */
public static class LockFailActionResult<T> {
    private T result;
    /** Never null; a null passed to the constructor is replaced by {@code CONTINUE}. */
    private final LOCK_CONTROL_ENUM lock_control_enum;

    /**
     * @param lock_control_enum control value; null is treated as {@code CONTINUE} so that code
     *                          switching on it cannot hit a NullPointerException
     * @param result            payload, may be null
     */
    public LockFailActionResult(LOCK_CONTROL_ENUM lock_control_enum, T result) {
        this.lock_control_enum = null == lock_control_enum ? LOCK_CONTROL_ENUM.CONTINUE : lock_control_enum;
        this.result = result;
    }

    /**
     * Creates a result with control value {@code CONTINUE}.
     *
     * @param result payload, may be null
     */
    public LockFailActionResult(T result) {
        this(LOCK_CONTROL_ENUM.CONTINUE, result);
    }

    public T getResult() {
        return result;
    }

    public void setResult(T result) {
        this.result = result;
    }

    public LOCK_CONTROL_ENUM getLock_control_enum() {
        return lock_control_enum;
    }
}
}