分组锁

132 阅读3分钟

分组锁

思路

  1. 使用 map 存储每个组的锁对象
  2. 提供相应的方法进行并发控制
  3. 为避免分组不再使用后其锁对象仍常驻内存造成浪费, 需要提供删除机制; 删除时须保证该分组的锁已没有线程持有或等待, 否则可能破坏互斥

分组锁对象

package cn.moquan.tools.lock;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Group lock: keeps one {@link ReentrantLock} per key, so threads that use the same key
 * exclude each other while threads that use different keys do not.
 * <br />
 * The per-key lock is created lazily by {@link #lock(Object)}. To keep keys that are no longer
 * used from occupying memory forever, a key can be discarded with {@link #remove(Object)} or,
 * after a delay, with one of the {@code delayRemove} overloads. Removal is deferred while any
 * thread holds or is waiting for the key's lock through {@link #lock(Object)}, so a removal can
 * never let two threads into the same group at the same time.
 *
 * @author :<b> moquan </b><br />
 */
public class GroupLock<T> {

    private static final AtomicLong THREAD_COUNT = new AtomicLong(0);

    /**
     * Core size of the shared default scheduler. Its tasks only perform a map removal, so one
     * thread is enough (a ScheduledThreadPoolExecutor starts a new core thread per submitted task
     * until this size is reached).
     */
    private static final Integer CORE_THREAD_COUNT = 1;

    private static final ScheduledExecutorService DEFAULT_SCHEDULED_THREAD_POOL = newDefaultScheduler();

    /** Per-key state, created on first {@link #lock(Object)}. */
    private final Map<T, Slot> map = new ConcurrentHashMap<>();

    /**
     * Map value: the key's {@link GroupLockContext} plus the bookkeeping that makes removal safe.
     * {@code holders} and {@code removeRequested} are only read and written inside
     * {@link ConcurrentHashMap} compute functions, which run atomically per key.
     */
    private static final class Slot {

        private final GroupLockContext context = GroupLockContext.init();

        /** Calls to lock() for this key not yet matched by unlock() (owners plus waiters). */
        private int holders;

        /** Set by remove() while the slot was in use; the last unlock() then discards it. */
        private boolean removeRequested;
    }

    /** Builds the shared scheduler used by the delayRemove overloads that take no thread pool. */
    private static ScheduledExecutorService newDefaultScheduler() {
        ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(
                CORE_THREAD_COUNT,
                runnable -> {
                    Thread thread = new Thread(runnable);
                    thread.setName("GroupLock-DefaultScheduledThreadPool-Thread-" + THREAD_COUNT.incrementAndGet());
                    // The pool is static and never shut down: daemon threads let the JVM exit.
                    thread.setDaemon(true);
                    return thread;
                }
        );
        // Each repeated delayRemove() cancels the previous task; drop cancelled tasks from the
        // queue immediately instead of keeping them until their delay elapses.
        executor.setRemoveOnCancelPolicy(true);
        return executor;
    }

    /**
     * Acquires the lock of {@code key}'s group, creating it on first use. Blocks until available.
     *
     * @param key group key, must not be null
     * @throws Exception declared for source compatibility; no checked exception is thrown
     */
    public void lock(T key) throws Exception {
        keyCheck(key);
        // Register as a holder in the same atomic step as the lookup, so a concurrent remove()
        // cannot discard this slot between the lookup and the lock acquisition.
        Slot slot = map.compute(key, (k, existing) -> {
            Slot current = Objects.isNull(existing) ? new Slot() : existing;
            current.holders++;
            return current;
        });
        try {
            slot.context.lock();
        } catch (RuntimeException | Error e) {
            // Lock was not acquired: undo the registration before propagating.
            release(key, slot);
            throw e;
        }
    }

    /**
     * Releases the lock of {@code key}'s group held by the calling thread. Does nothing if the
     * key currently has no lock. If a removal was requested while the lock was in use and this
     * was the last holder, the key is discarded.
     *
     * @param key group key, must not be null
     * @throws IllegalMonitorStateException if the key has a lock the caller does not hold
     */
    public void unlock(T key) {
        keyCheck(key);
        Slot slot = map.get(key);
        if (Objects.nonNull(slot)) {
            // The slot cannot have been replaced while the caller holds it (holders > 0), so
            // this is the same lock the caller acquired. Throws before release() if not held.
            slot.context.unlock();
            release(key, slot);
        }
    }

    /** Drops one holder from {@code slot} and discards it if idle and removal was requested. */
    private void release(T key, Slot slot) {
        map.computeIfPresent(key, (k, current) -> {
            if (current != slot) {
                // Defensive: never touch a slot other than the one that was registered.
                return current;
            }
            current.holders--;
            return current.holders == 0 && current.removeRequested ? null : current;
        });
    }

    /**
     * Discards {@code key}'s lock. If a thread currently holds or waits for it via
     * {@link #lock(Object)}, removal is deferred until the last such thread unlocks.
     *
     * @param key group key, must not be null
     */
    public void remove(T key) {
        keyCheck(key);
        map.computeIfPresent(key, (k, slot) -> {
            if (slot.holders == 0) {
                return null;
            }
            slot.removeRequested = true;
            return slot;
        });
    }

    /** Schedules {@link #remove(Object)} on {@code threadPool} after {@code delayMilliSeconds} ms. */
    public void delayRemove(T key, long delayMilliSeconds, ScheduledThreadPoolExecutor threadPool) {
        delayRemove(key, delayMilliSeconds, TimeUnit.MILLISECONDS, threadPool);
    }

    /** Schedules {@link #remove(Object)} on the default scheduler after {@code delayMilliSeconds} ms. */
    public void delayRemove(T key, long delayMilliSeconds) {
        delayRemove(key, delayMilliSeconds, TimeUnit.MILLISECONDS, DEFAULT_SCHEDULED_THREAD_POOL);
    }

    /** Schedules {@link #remove(Object)} on the default scheduler after {@code time} {@code unit}. */
    public void delayRemove(T key, long time, TimeUnit unit) {
        delayRemove(key, time, unit, DEFAULT_SCHEDULED_THREAD_POOL);
    }

    /**
     * Schedules {@link #remove(Object)} for {@code key} after the given delay, replacing (and
     * cancelling) any earlier pending delayed removal of the same lock. Does nothing if the key
     * currently has no lock.
     *
     * @param key        group key, must not be null
     * @param time       delay before removal
     * @param unit       unit of {@code time}
     * @param threadPool scheduler that runs the removal
     */
    public void delayRemove(T key, long time, TimeUnit unit, ScheduledExecutorService threadPool) {
        keyCheck(key);
        Slot slot = map.get(key);
        if (Objects.isNull(slot)) {
            // Already removed: nothing to schedule.
            return;
        }
        ScheduledFuture<?> future = threadPool.schedule(() -> remove(key), time, unit);
        slot.context.setDelayRemoveFuture(future);
    }

    /**
     * Releases the caller's lock on {@code key} and then requests its removal (deferred if other
     * threads still hold or wait for it).
     *
     * @param key group key, must not be null
     */
    public void unlockAndRemove(T key) {
        unlock(key);
        remove(key);
    }

    /**
     * Returns the current lock of {@code key}, or null if the key has none. Locking the returned
     * object directly bypasses this class's holder tracking, so {@link #remove(Object)} may then
     * discard it while held; prefer {@link #lock(Object)} / {@link #unlock(Object)}.
     *
     * @param key group key, must not be null
     * @return the key's lock, or null
     */
    public ReentrantLock getLock(T key) {
        keyCheck(key);
        Slot slot = map.get(key);
        return Objects.isNull(slot) ? null : slot.context.getLock();
    }

    /**
     * Rejects null keys.
     *
     * @throws NullPointerException if {@code key} is null
     */
    public void keyCheck(T key) {
        Objects.requireNonNull(key, "分组锁: key 参数不可为 null, 请检查");
    }

}

分组锁上下文对象 (GroupLockContext)

package cn.moquan.tools.lock;

import java.util.Objects;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.locks.ReentrantLock;

/**
 * State kept per group key: the key's lock and its pending delayed-removal task.
 * <br />
 *
 * @author :<b> moquan </b><br />
 */
public class GroupLockContext {

    /** Lock shared by all threads using this context; never null. */
    private final ReentrantLock lock;

    /** Most recently registered delayed-removal task, or null; guarded by {@code this}. */
    private ScheduledFuture<?> delayRemoveFuture;

    /** Creates a context backed by a new {@link ReentrantLock}. */
    public GroupLockContext() {
        this(new ReentrantLock());
    }

    /**
     * Creates a context backed by the given lock.
     *
     * @param lock the lock to use, must not be null
     * @throws NullPointerException if {@code lock} is null
     */
    public GroupLockContext(ReentrantLock lock) {
        this.lock = Objects.requireNonNull(lock, "lock");
    }

    /** Creates a context backed by a new {@link ReentrantLock}. */
    public static GroupLockContext init() {
        return new GroupLockContext(new ReentrantLock());
    }

    /**
     * Replaces the pending delayed-removal task, cancelling the previous one if it has not
     * finished yet.
     * <p>
     * The method is synchronized so that two concurrent callers cannot both read the same old
     * {@code delayRemoveFuture}; otherwise one of the new futures could be overwritten without
     * being cancelled, and its removal would fire earlier than the latest requested delay.
     *
     * @param future the new delayed-removal task (may be null to just cancel the pending one)
     */
    public synchronized void setDelayRemoveFuture(ScheduledFuture<?> future){

        // Cancel the previous task only if it is still pending or running.
        if (Objects.nonNull(delayRemoveFuture) && !delayRemoveFuture.isDone()) {
            delayRemoveFuture.cancel(true);
        }

        // Remember the new delayed removal.
        delayRemoveFuture = future;
    }

    /** Acquires this context's lock, blocking until it is available. */
    public void lock(){
        lock.lock();
    }

    /**
     * Releases this context's lock.
     *
     * @throws IllegalMonitorStateException if the current thread does not hold the lock
     */
    public void unlock(){
        lock.unlock();
    }

    /** Returns the underlying lock; never null. */
    public ReentrantLock getLock() {
        return lock;
    }

}

测试类

package cn.moquan.tools.lock;

import org.junit.jupiter.api.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;

/**
 * Tests for {@link GroupLock}: different keys run in parallel, the same key is mutually
 * exclusive, and a key's lock can be removed immediately or after a delay.
 * <br />
 *
 * @author :<b> moquan </b><br />
 */
class GroupLockTest {

    private static final Logger log = LoggerFactory.getLogger(GroupLockTest.class);

    /** Number of sleeps each worker performs while holding its group lock. */
    private static final int WORKER_STEPS = 5;

    /** Length of one worker sleep, in milliseconds. */
    private static final long STEP_MILLIS = 100;

    /** Upper bound on how long a test waits for its workers before failing. */
    private static final long WORKER_TIMEOUT_SECONDS = 5;

    public GroupLock<String> groupLock = new GroupLock<>();

    /**
     * Bookkeeping shared by the worker threads of one test. Worker failures are recorded here
     * because exceptions thrown on a plain {@link Thread} never reach JUnit.
     */
    private static final class Probe {

        /** Workers currently inside their critical section. */
        final java.util.concurrent.atomic.AtomicInteger inside = new java.util.concurrent.atomic.AtomicInteger();

        /** Highest value {@link #inside} has reached. */
        final java.util.concurrent.atomic.AtomicInteger maxInside = new java.util.concurrent.atomic.AtomicInteger();

        /** First failure raised by any worker, or null. */
        final java.util.concurrent.atomic.AtomicReference<Throwable> failure = new java.util.concurrent.atomic.AtomicReference<>();

        /** Counted down once per finished worker, whether it succeeded or not. */
        final CountDownLatch done;

        Probe(int workers) {
            this.done = new CountDownLatch(workers);
        }

        /** Waits for all workers (bounded) and rethrows the first worker failure. */
        void awaitWorkers() throws InterruptedException {
            if (!done.await(WORKER_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                throw new AssertionError("workers did not finish within " + WORKER_TIMEOUT_SECONDS + " s");
            }
            Throwable failed = failure.get();
            if (failed != null) {
                throw new AssertionError("worker failed", failed);
            }
        }
    }

    /**
     * Returns a task that holds the group lock for {@code key} for
     * {@code WORKER_STEPS * STEP_MILLIS} ms, recording critical-section overlap in {@code probe}.
     * The latch is always counted down, so a failing worker cannot hang the test.
     */
    private Runnable worker(String key, String name, Probe probe) {
        return () -> {
            try {
                // Acquire before the inner try: unlock must only run once the lock is held.
                groupLock.lock(key);
                try {
                    probe.maxInside.accumulateAndGet(probe.inside.incrementAndGet(), Math::max);
                    for (int i = 1; i <= WORKER_STEPS; i++) {
                        TimeUnit.MILLISECONDS.sleep(STEP_MILLIS);
                        log.info(name + " slept for " + (i * STEP_MILLIS) + " milliseconds");
                    }
                    log.info("massageRun" + name);
                } finally {
                    probe.inside.decrementAndGet();
                    groupLock.unlock(key);
                }
            } catch (Throwable t) {
                if (t instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }
                probe.failure.compareAndSet(null, t);
            } finally {
                probe.done.countDown();
            }
        };
    }

    /** Locks and unlocks {@code key} once so that its group lock exists. */
    private void useLockOnce(String key) throws Exception {
        groupLock.lock(key);
        try {
            log.info("use [ " + key + " ] group lock");
        } finally {
            groupLock.unlock(key);
        }
    }

    @Test
    @DisplayName("AB锁测试")
    @Timeout(1)
    void abLockTest() throws InterruptedException {
        Probe probe = new Probe(2);
        new Thread(worker("a", "A", probe)).start();
        new Thread(worker("b", "B", probe)).start();
        probe.awaitWorkers();

        // Different keys must not block each other (@Timeout(1) also fails a serial run).
        if (probe.maxInside.get() != 2) {
            throw new AssertionError("workers on different keys did not run concurrently");
        }
    }

    @Test
    @DisplayName("AA锁测试")
    void aaLockTest() throws InterruptedException {
        Probe probe = new Probe(2);
        new Thread(worker("a", "A1", probe)).start();
        new Thread(worker("a", "A2", probe)).start();
        probe.awaitWorkers();

        // Same key: the two critical sections must never overlap.
        if (probe.maxInside.get() != 1) {
            throw new AssertionError("workers on the same key overlapped: " + probe.maxInside.get());
        }
    }

    @Test
    @DisplayName("锁删除测试")
    void lockRemoveTest() throws Exception {

        String lockKey = "a";

        useLockOnce(lockKey);
        assertNotNull(groupLock.getLock(lockKey));

        groupLock.remove(lockKey);
        assertNull(groupLock.getLock(lockKey));
    }

    @Test
    @DisplayName("锁延时删除测试")
    void lockDelayRemoveTest() throws Exception {

        String lockKey = "a";

        useLockOnce(lockKey);
        assertNotNull(groupLock.getLock(lockKey));

        groupLock.delayRemove(lockKey, 900);
        assertNotNull(groupLock.getLock(lockKey));

        TimeUnit.SECONDS.sleep(1);
        assertNull(groupLock.getLock(lockKey));
    }

    @Test
    @DisplayName("锁更新延时删除测试")
    void lockUpdateDelayRemoveTest() throws Exception {

        String lockKey = "a";

        useLockOnce(lockKey);
        assertNotNull(groupLock.getLock(lockKey));

        long timeMillis = System.currentTimeMillis();

        // Each call replaces the previous 100 ms removal before it fires, so the lock survives.
        for (int i = 0; i < 5; i++) {
            groupLock.delayRemove(lockKey, 100);
            TimeUnit.MILLISECONDS.sleep(50);
            assertNotNull(groupLock.getLock(lockKey));

            log.info("第 [ " + (i + 1) + " ] 次延时, 已延时 [ " + (System.currentTimeMillis() - timeMillis) + " ] 毫秒");
        }

        TimeUnit.SECONDS.sleep(1);
        assertNull(groupLock.getLock(lockKey));
    }

}