手写LRU算法(Go&Java)

70 阅读3分钟

算法介绍

LRU算法,即Least Recently Used算法,最近最少使用算法。主要应用于缓存淘汰的场景,用于控制缓存的容量上限。即当缓存的容量达到阈值时,都取出最久远一次使用的元素进行淘汰。

算法实现要点

LRU算法的关键在于维护数据结构中各元素的使用时序序列,即以“最近一次使用时间”为排序规则对数据结构中所有的元素进行排序并维护一个排序列表,称之为时序队列,时序队列算法规则如下:

  1. 每次淘汰时,都将时序队列队尾(或对头)的元素进行淘汰
  2. 每次有元素进行查询或者变更操作时,都将元素节点从时序列表里取出(不存在于列表中则初始化一个元素节点),将元素节点重新放入时序队列队头(或队尾)

同时,为了保证时序队列操作的时间复杂度和空间复杂度都维持在O(1),需要用一个哈希列表记录每个时序队列元素节点的位置。该位置通常定义为LRU算法所归属的缓存数据结构的Key(例如缓存以int类型数字为key,则哈希列表对于元素节点的记录方式为key = int,value = itemNode) 维护排序列表通常也会使用双向链表来进行,保证排序列表操作的便捷性。

算法示例

设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构,实现 LRUCache 类:

  • LRUCache(int capacity):以正整数作为容量 capacity 初始化 LRU 缓存
  • int get(int key):如果关键字 key 存在于缓存中,则返回关键字的值,否则返回 -1 。
  • void put(int key, int value):如果关键字 key 已经存在,则变更其数据值 value;如果不存在,则向缓存中插入该组 key-value 。如果插入操作导致关键字数量超过 capacity则应该逐出最久未使用的关键字。

函数 get 和 put 必须以 O(1) 的平均时间复杂度运行。

Go实现

type LRUCache struct {
    capacity  int
    size      int
    cache     map[int]*Node
    dummyHead *Node
    tail      *Node
}

type Node struct {
    key   int
    value int
    pre   *Node
    next  *Node
}

func Constructor(capacity int) LRUCache {
    d := &Node{
        key:   -1,
        value: -1,
        pre:   nil,
        next:  nil,
    }

    return LRUCache{
        capacity:  capacity,
        size:      0,
        cache:     make(map[int]*Node),
        dummyHead: d,
        tail:      d,
    }
}

func (this *LRUCache) remove() *Node {
    dummyHead := this.dummyHead
    removeNode := dummyHead.next

    if removeNode != nil {

        // 如果要删除的节点是tail节点,先将tail节点前移,避免tail节点丢失
        if this.tail == removeNode {
            this.tail = removeNode.pre
        }

        next := removeNode.next
        dummyHead.next = next
        if next != nil {
            next.pre = dummyHead
        }

    }

    return removeNode
}

func (this *LRUCache) adjust(node *Node) {
    var pre, next *Node

    // 如果要移动的节点本来就是tail节点,则不进行任何动作
    if node == this.tail {
        return
    }

    // 节点从链表中拆除
    if node.pre != nil {
        pre = node.pre
        node.pre = nil
    }

    if node.next != nil {
        next = node.next
        node.next = nil
    }

    if pre != nil {
        pre.next = next
        if next != nil {
            next.pre = pre
        }
    }

    // 节点append到列表结尾
    this.tail.next = node
    node.pre = this.tail
    this.tail = node

}

func (this *LRUCache) Get(key int) int {
    node, exist := this.cache[key]
    if exist {
        this.adjust(node)
        return node.value
    } else {
        return -1
    }
}

func (this *LRUCache) Put(key int, value int) {
    node, exist := this.cache[key]

    if exist {
        // 如果Put的节点key Node已存在,则直接取用,修改value即可
        node.value = value
    } else {
        node = &Node{
            key:   key,
            value: value,
        }
        this.cache[key] = node
        this.size++
    }

    this.adjust(node)

    // 超出阈值,执行LRU淘汰
    if this.size > this.capacity {

        removeNode := this.remove()
        if removeNode != nil {
            delete(this.cache, removeNode.key)
            this.size--
        }

    }
}

Java实现

public class LRUCache {

    private final List list;
    private final Map<Integer, ListNode> cache = new HashMap<>();

    public LRUCache(int capacity) {
        list = new List(capacity);
    }

    public int get(int key) {
        if (cache.containsKey(key)) {
            ListNode node = cache.get(key);
            list.adjust(node);
            return node.getVal();
        }
        return -1;
    }

    public void put(int key, int value) {
        if (cache.containsKey(key)) {
            ListNode cur = cache.get(key);
            cur.setVal(value);
            list.adjust(cur);
        }
        else {
            ListNode node = new ListNode(value, key);
            cache.put(key, node);
            list.add(node);

            ListNode removed = list.remove();
            if (removed != null) {
                cache.remove(removed.getKey());
            }
        }
    }

    public static class List {
        int capacity;
        int size;
        ListNode head = new ListNode(-1);
        ListNode tail = head;

        public List(int capacity) {
            this.capacity = capacity;
            this.size = 0;
        }

        public void add(ListNode node) {
            tail.next = node;
            node.pre = tail;
            tail = tail.next;
            size++;
        }

        public void adjust(ListNode node) {
            ListNode preNode = node.pre;
            preNode.next = node.next;
            if (node.next != null)
                node.next.pre = preNode;
            if (tail == node)
                tail = preNode;
            size--;

            add(node);
        }

        public ListNode remove() {
            if (size > capacity) {
                ListNode removed = head.next;
                head.next = head.next.next;
                head.next.pre = head;
                size--;
                return removed;
            }

            return null;
        }
    }

    public static class ListNode {
        int val;
        int key;
        ListNode next;
        ListNode pre;

        public ListNode(int val, int key) {
            this.val = val;
            this.key = key;
        }

        public int getVal() {
            return val;
        }

        public void setVal(int val) {
            this.val = val;
        }

        public int getKey() {
            return key;
        }

        public void setKey(int key) {
            this.key = key;
        }

        ListNode(int x) {
            val = x;
            next = null;
        }
    }
}