Java 过期缓存实现 | ConcurrentHashMap

4,219 阅读1分钟

image.png

最近写一个个人 project 时要用到一个可定时清理值的容器,就自己简单实现了一个(虽然有 Guava 等开源实现,但懒得引入了)

特征

  • 类似 Map,但是 put 时可以传入一个毫秒数,表示多长时间后过期,键值对自动清除
  • 实现上简单起见,key 固定 String 类型;底层容器 ConcurrentHashMap

参考

CODE

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

/**
 * A string-keyed cache whose entries may expire, modelled on the Redis key-expiry strategy.
 *
 * <p>Expired entries are removed in two ways:
 * <ul>
 *   <li><b>Passive</b>: {@link #get(String)} checks the entry it is about to return and drops
 *       it if it has expired.</li>
 *   <li><b>Active</b>: a periodic background task samples up to {@value #SCAN_SAMPLE_SIZE}
 *       expirable keys at random and removes those that have expired.</li>
 * </ul>
 *
 * <p>NOTE: the periodic task is never cancelled, so the shared scheduler keeps a reference to
 * every cache instance that was ever created. Create caches sparingly and keep them long-lived.
 *
 * <p>See <a href="https://redis.io/commands/expire">redis.io/commands/expire</a>.
 *
 * @param <T> type of the cached values
 */
public class ExpireCache<T> {

    /** Maximum number of expirable keys examined by one background scan. */
    private static final int SCAN_SAMPLE_SIZE = 20;

    /** Default interval, in seconds, between two background scans. */
    public static final long DEFAULT_PERIOD_SECOND = 30;

    /**
     * Scheduler shared by all caches. Its workers are daemon threads so that merely creating a
     * cache does not prevent the JVM from exiting.
     */
    private static final ScheduledExecutorService executor =
            Executors.newScheduledThreadPool(7, task -> {
                Thread worker = new Thread(task, "expire-cache-cleaner");
                worker.setDaemon(true);
                return worker;
            });

    private final ConcurrentMap<String, DelayValue<T>> map = new ConcurrentHashMap<>();

    /**
     * Creates a cache that is scanned every {@link #DEFAULT_PERIOD_SECOND} seconds.
     *
     * @see #ExpireCache(long, TimeUnit)
     */
    public ExpireCache() {
        this(DEFAULT_PERIOD_SECOND, TimeUnit.SECONDS);
    }

    /**
     * Creates a cache and starts its periodic background scan.
     *
     * @param period   time between two consecutive scans (also used as the initial delay)
     * @param timeUnit unit of {@code period}
     */
    public ExpireCache(long period, TimeUnit timeUnit) {
        // The scan is scheduled as soon as the cache exists; the first run happens after one period.
        executor.scheduleAtFixedRate(this::scanClean, period, period, timeUnit);
    }

    /**
     * Stores a value that never expires, replacing any previous value for the key.
     *
     * @param key   the key
     * @param value the value
     */
    public void put(String key, T value) {
        map.put(key, new DelayValue<>(value));
    }

    /**
     * Stores a value that expires once more than {@code delay} milliseconds have elapsed,
     * replacing any previous value for the key.
     *
     * <p>A {@code delay} of {@code -1} equals the internal "never expires" marker, so such an
     * entry behaves exactly like one stored with {@link #put(String, Object)}.
     *
     * @param key   the key
     * @param value the value
     * @param delay time to live in milliseconds
     */
    public void put(String key, T value, long delay) {
        map.put(key, new DelayValue<>(value, delay));
    }

    /**
     * Returns the value stored for {@code key}, removing it first if it has expired.
     *
     * @param key the key
     * @return the value, or {@code null} if the key is absent or its entry has expired
     */
    public T get(String key) {
        DelayValue<T> v = map.get(key);
        if (v == null) {
            return null;
        }
        if (expired(v)) {
            // Conditional remove: only drop the entry we inspected, never a fresh value that
            // another thread stored under the same key in the meantime.
            map.remove(key, v);
            return null;
        }
        return v.data;
    }

    /**
     * Background scan: picks up to {@value #SCAN_SAMPLE_SIZE} expirable keys at random and
     * removes those whose entries have expired.
     */
    private void scanClean() {
        List<String> canExpiredKeys = new ArrayList<>();
        for (Entry<String, DelayValue<T>> entry : map.entrySet()) {
            if (canExpired(entry.getValue())) {
                canExpiredKeys.add(entry.getKey());
            }
        }

        Collections.shuffle(canExpiredKeys);
        int end = Math.min(canExpiredKeys.size(), SCAN_SAMPLE_SIZE);
        for (int i = 0; i < end; i++) {
            String key = canExpiredKeys.get(i);
            DelayValue<T> v = map.get(key);
            if (v != null && expired(v)) {
                // Same conditional remove as in get(): leave concurrently replaced values alone.
                map.remove(key, v);
            }
        }
    }

    /** Returns whether the entry was stored with an expiry at all. */
    private boolean canExpired(DelayValue<T> v) {
        Objects.requireNonNull(v);
        return v.delay != DelayValue.FOREVER_FLAG;
    }

    /** Returns whether the entry has outlived its delay; never true for non-expiring entries. */
    private boolean expired(DelayValue<T> v) {
        Objects.requireNonNull(v);
        if (v.delay == DelayValue.FOREVER_FLAG) {
            return false;
        }
        // Compare the elapsed time with the delay rather than computing timestamp + delay:
        // the sum overflows for very large delays and would make the entry look expired at once.
        return System.currentTimeMillis() - v.timestamp > v.delay;
    }

    /** Immutable holder pairing a value with its creation time and time to live. */
    private static class DelayValue<T> {

        /** Delay value marking an entry that never expires. */
        static final long FOREVER_FLAG = -1;

        final T data;
        /** Creation time in milliseconds, taken from {@link System#currentTimeMillis()}. */
        final long timestamp = System.currentTimeMillis();
        /** Time to live in milliseconds; {@link #FOREVER_FLAG} means never expires. */
        final long delay;

        DelayValue(T data) {
            this.data = data;
            this.delay = FOREVER_FLAG;
        }

        DelayValue(T data, long delay) {
            this.data = data;
            this.delay = delay;
        }

        @Override
        public String toString() {
            return "DelayValue [data=" + data + ", delay=" + delay + ", timestamp=" + timestamp + "]";
        }
    }

}

UNIT TEST

    /**
     * Exercises both expiry paths of {@code ExpireCache}: passive removal on {@code get}, and
     * active removal by the background scan (scheduled every 5 seconds here).
     *
     * <p>An interrupted sleep is propagated rather than swallowed, so the assertions never run
     * before the intended amount of time has passed.
     */
    @Test
    @SuppressWarnings("unchecked")
    public void testExpire() throws NoSuchFieldException, SecurityException, IllegalArgumentException, IllegalAccessException, InterruptedException {

        ExpireCache<String> cache = new ExpireCache<>(5, TimeUnit.SECONDS);

        // Passive path: the entry is readable before its 1500 ms delay and gone after it.
        cache.put("test", "test", 1500);
        assertEquals("test", cache.get("test"));

        TimeUnit.MILLISECONDS.sleep(1600);

        assertNull(cache.get("test"));

        // Active path: wait long enough for the first background scan to have run.
        cache.put("test2", "test2", 1500);
        assertEquals("test2", cache.get("test2"));

        TimeUnit.MILLISECONDS.sleep(5500);

        // Read the private "map" field via reflection so the check does not go through get(),
        // which would itself remove the expired entry.
        Field mapField = ExpireCache.class.getDeclaredField("map");
        mapField.setAccessible(true);
        ConcurrentMap<String, Object> map = (ConcurrentMap<String, Object>) mapField.get(cache);

        assertNull(map.get("test2"));
        assertNull(cache.get("test2"));
    }

image.png


OLD/NOGOOD CODE

/**
 * A cache that closes and evicts each value a fixed number of seconds after it was stored.
 *
 * <p>Mainly intended for periodically cleaning up {@code JedisPool} instances by calling their
 * {@code close} method.
 *
 * @param <V> type of the cached, closeable values
 */
public class ExpireCache2<V extends Closeable> {

    private final ConcurrentMap<String, V> map = new ConcurrentHashMap<>();

    /**
     * Scheduler shared by all caches. Its workers are daemon threads so that pending clean-up
     * timers do not prevent the JVM from exiting.
     */
    private static final ScheduledExecutorService scheduler =
            new ScheduledThreadPoolExecutor(11, task -> {
                Thread worker = new Thread(task, "expire-cache2-closer");
                worker.setDaemon(true);
                return worker;
            });

    /**
     * Stores a value; after {@code delay} seconds the value is closed and its mapping removed.
     *
     * <p>If the key is overwritten before the timer fires, the earlier value is still closed
     * when its own timer fires, but the newer mapping is left in place.
     *
     * @param key   the key
     * @param value the closeable value
     * @param delay delay in seconds before the value is closed and evicted
     */
    public void put(String key, V value, long delay) {
        map.put(key, value);
        addClearTimer(key, value, delay);
    }

    /**
     * Returns the value currently stored for the key.
     *
     * @param key the key
     * @return the value, or {@code null} if there is none
     */
    public V get(String key) {
        return map.get(key);
    }

    /** Schedules the close-and-evict task for one stored value. */
    private void addClearTimer(String key, V value, long delay) {
        scheduler.schedule(() -> {
            try {
                value.close();
            } catch (IOException ignored) {
                // Best-effort close: the entry is evicted either way and no caller is left
                // to report the failure to.
            } finally {
                // Evict even if close() failed, and only if the key still maps to this value,
                // so a value stored later under the same key is not dropped by an old timer.
                map.remove(key, value);
            }
        }, delay, TimeUnit.SECONDS);
    }

}

Have a better idea? Please leave a comment.