手写LRU缓存淘汰算法

252 阅读5分钟

本文是字节飞书一面真题,以下针对 LRU 的基本实现和进阶提问做解析。

LRU

LRU 是一个缓存淘汰算法,全称是 Least recently used,最近最少使用,也就是最近没被使用的缓存,会被淘汰。

分析

核心关注点:

  1. 需要实现一个缓存,用来存数据,KV结构是最优选择。
  2. 需要淘汰最近没有使用的缓存,所以需要对K按「访问顺序」进行排序。

实现

算法的核心逻辑是在写入和查询缓存的时候,记录缓存的访问记录,从而在容量满的时候,将最近没被使用的缓存淘汰掉。

手动实现

实现起来可以用 LinkedList 记录缓存的访问记录,用 HashMap 记录缓存。

public class Main {

    /**
     * 手写LRU缓存
     *
     * LRU 是一个最近最少使用的缓存淘汰算法
     */
    public static void main(String[] args) {
        LRU<String> lru = new LRU<>(3);
        lru.put("a", "a");
        lru.put("b", "b");
        lru.put("c", "c");
        System.out.println(lru.get("b"));
        System.out.println(lru.get("c"));
        lru.put("d", "d");
        System.out.println(lru.get("a"));
    }

    static class LRU<T> {
        private final int capacity; // 容量
        private final LinkedList<String> cacheKey; // 记录访问顺序
        private final Map<String, T> cache; // 存数据

        public LRU(int capacity) {
            this.capacity = capacity;
            cache = new HashMap<>(capacity);
            cacheKey = new LinkedList<>();
        }

        public synchronized void put(String key, T value) {
            // 最近访问放队头
            if (cache.containsKey(key)) {
                cacheKey.remove(key);
            }
            cacheKey.addFirst(key);
            cache.put(key, value);

            // 淘汰队尾
            if (cacheKey.size() >= capacity) {
                String last = cacheKey.pollLast();
                cache.remove(last);
            }
        }

        public synchronized T get(String key) {
            if (cache.containsKey(key)) {
                // 最近访问放队头
                cacheKey.remove(key);
                cacheKey.addFirst(key);

                return cache.get(key);
            }
            return null;
        }
    }
}

以上代码实现了一个基础的同步 LRU 缓存。

复杂度分析

在跟面试官沟通过实现原理之后,被提问实现的复杂度如何。

数据使用 Map 进行存储,存取复杂度都是 O(1)。

访问顺序使用 LinkedList 实现,在调整顺序时需要先移除旧元素或者淘汰队头最近没访问的元素,后添加新元素到队尾。移除旧元素时需要遍历找到该元素,遍历的复杂度是 O(n),淘汰队头元素的复杂度是 O(1)。添加新元素,因为 LinkedList 是双向链表,所以直接在队尾追加即可,复杂度 O(1),所以调整顺序的平均复杂度为 O(n)。

这里还提问了单双向链表的区别。单双向链表的遍历复杂度一致,主要区别在于插入和删除。

所以上述代码实现的 LRU 缓存可以看成是一个复杂度为 O(n) 的实现,还有继续优化的空间。

复杂度优化

分析完复杂度,面试官问如何降低复杂度,提升性能?

上述实现中,复杂度较高的原因是因为移除旧元素的遍历比较,那么可以通过手动实现双向链表,并在 Map 中存储链表节点的方式来将复杂度降低到 O(1)。

代码如下:

public class Main {

    /**
     * 手写LRU缓存 - 优化版本
     *
     * LRU 是一个最近最少使用的缓存淘汰算法
     * 使用双向链表 + HashMap 实现 O(1) 时间复杂度
     */
    public static void main(String[] args) {
        LRU<String> lru = new LRU<>(3);
        lru.put("a", "a");
        lru.put("b", "b");
        lru.put("c", "c");
        System.out.println(lru.get("b")); // 应该返回 "b"
        System.out.println(lru.get("c")); // 应该返回 "c"
        lru.put("d", "d"); // 此时应该淘汰 "a"
        System.out.println(lru.get("a")); // 应该返回 null
        System.out.println("当前缓存内容:");
        lru.printCache();
    }

    static class LRU<T> {
        private final int capacity; // 容量
        private final Map<String, Node<T>> cache; // 存储键值对
        private final Node<T> head; // 虚拟头节点
        private final Node<T> tail; // 虚拟尾节点
        private int size; // 当前大小

        public LRU(int capacity) {
            this.capacity = capacity;
            this.cache = new HashMap<>(capacity);
            this.size = 0;
            
            // 初始化双向链表,使用虚拟头尾节点简化操作
            this.head = new Node<>(null, null, null, null);
            this.tail = new Node<>(null, null, null, null);
            this.head.next = this.tail;
            this.tail.prev = this.head;
        }

        public synchronized void put(String key, T value) {
            // 如果key已存在,先删除旧节点
            if (cache.containsKey(key)) {
                removeNode(cache.get(key));
                size--;
            }
            
            // 创建新节点并添加到头部
            Node<T> newNode = new Node<>(key, value, head, head.next);
            head.next.prev = newNode;
            head.next = newNode;
            cache.put(key, newNode);
            size++;
            
            // 如果超出容量,删除尾部节点
            if (size > capacity) {
                Node<T> lastNode = tail.prev;
                removeNode(lastNode);
                cache.remove(lastNode.key);
                size--;
            }
        }

        public synchronized T get(String key) {
            Node<T> node = cache.get(key);
            if (node == null) {
                return null;
            }
            
            // 将节点移动到头部(最近使用)
            moveToHead(node);
            return node.value;
        }
        
        // 删除指定节点
        private void removeNode(Node<T> node) {
            node.prev.next = node.next;
            node.next.prev = node.prev;
        }
        
        // 将节点移动到头部
        private void moveToHead(Node<T> node) {
            // 先从当前位置删除
            removeNode(node);
            // 添加到头部
            node.next = head.next;
            node.prev = head;
            head.next.prev = node;
            head.next = node;
        }
        
        // 打印缓存内容(用于调试)
        public void printCache() {
            Node<T> current = head.next;
            System.out.print("缓存内容: ");
            while (current != tail) {
                System.out.print(current.key + "=" + current.value + " ");
                current = current.next;
            }
            System.out.println();
        }
        
        // 双向链表节点
        static class Node<T> {
            String key;
            T value;
            Node<T> prev;
            Node<T> next;
            
            public Node(String key, T value, Node<T> prev, Node<T> next) {
                this.key = key;
                this.value = value;
                this.prev = prev;
                this.next = next;
            }
        }
    }
}

使用 LinkedHashMap 实现

还有一种更简单的实现方式,只需要继承 LinkedHashMap,并重写 removeEldestEntry 方法。

原理是 LinkedHashMap 内部维护了一个双向链表,在设置了 accessOrder 为 true 的时候,每次访问都会动态调整元素的位置,将最新的节点放到尾部,这样就保持了头部节点始终是最老的元素(最近最少访问的元素)。

并且在 LinkedHashMap 的源码中,每次插入新元素都会判断是否要删除最老的元素:

// LinkedHashMap 源码,新节点插入回调
void afterNodeInsertion(boolean evict) { // possibly remove eldest
    LinkedHashMap.Entry<K,V> first;
    if (evict && (first = head) != null && removeEldestEntry(first)) {
        // removeEldestEntry 返回true会移除最老的节点
        K key = first.key;
        removeNode(hash(key), key, null, false, true);
    }
}

所以使用 LinkedHashMap 可以快速的实现一个 LRU 缓存,代码如下:

public class Main {

    /**
* 手写LRU缓存
*
* LRU 是一个最近最少使用的缓存淘汰算法
*/
public static void main(String[] args) {
        LRU<String> lru = new LRU<>(3);
        lru.put("a", "a");
        lru.put("b", "b");
        lru.put("c", "c");
        System.out.println(lru.get("b"));
        System.out.println(lru.get("c"));
        lru.put("d", "d");
        System.out.println(lru.get("a"));
    }

    static class LRU<T> extends LinkedHashMap<String, T> {
        private final int capacity; // 容量

        public LRU(int capacity) {
            super(capacity, 0.75f, true);
            this.capacity = capacity;
        }

        @Override
        public boolean removeEldestEntry(Map.Entry<String, T> eldest) {
            // 超出容量,返回删除最老元素
            return this.size() > capacity;
        }
    }
}