716. Max Stack

8 阅读1分钟

image.png

TreeMap + DLL

  • A doubly linked list (DLL) simulates the stack and allows O(1) removal of any node.
  • A TreeMap keeps the values in sorted order, so the current maximum can be found in O(log n).
  • Use TreeMap<Integer, List<Node>> instead of TreeMap<Integer, Node>, because multiple nodes in the stack can have the same value.
/**
 * A stack of ints that also supports reading and removing its current maximum.
 *
 * <p>Elements are kept in a doubly linked list bounded by two sentinel nodes; the top of the
 * stack is the node just before {@code tail}. A {@code TreeMap} maps every value to the list of
 * nodes holding that value, in push order, so the last node of a value's list is the occurrence
 * closest to the top of the stack.
 *
 * <p>{@code pop}, {@code top}, {@code peekMax} and {@code popMax} throw
 * {@link java.util.NoSuchElementException} when the stack is empty.
 */
class MaxStack {
    /** A doubly linked list node carrying one stack element. */
    class Node {
        int val;
        Node pre;
        Node next;
        Node(int val) {
            this.val = val;
        }
    }

    // value -> nodes holding that value, oldest first (most recently pushed is last)
    TreeMap<Integer, List<Node>> map;
    Node head; // dummy head (bottom sentinel, never holds an element)
    Node tail; // dummy tail (top sentinel, never holds an element)

    /** Creates an empty stack: the two sentinels are linked directly to each other. */
    public MaxStack() {
        map = new TreeMap<>();
        head = new Node(0);
        tail = new Node(0);
        head.next = tail;
        tail.pre = head;
    }

    /**
     * Links {@code node} in just before the tail sentinel, i.e. at the top of the stack.
     * Only touches the linked list; the value index ({@code map}) is not updated here.
     */
    public void addNode(Node node) {
        node.pre = tail.pre;
        node.next = tail;
        node.pre.next = node;
        tail.pre = node;
    }

    /**
     * Unlinks {@code node} from the list in O(1) and clears its own links.
     * Only touches the linked list; the value index ({@code map}) is not updated here.
     */
    public void removeNode(Node node) {
        node.pre.next = node.next;
        node.next.pre = node.pre;
        node.pre = null;
        node.next = null;
    }

    /**
     * Pushes {@code x} onto the top of the stack.
     * O(log n) due to the TreeMap lookup/insert.
     */
    public void push(int x) {
        Node node = new Node(x);
        addNode(node);
        List<Node> nodes = map.get(x);
        if (nodes == null) {
            nodes = new ArrayList<>();
            map.put(x, nodes);
        }
        nodes.add(node);
    }

    /**
     * Removes and returns the top element.
     * O(1) for the list removal + O(log n) for the TreeMap lookup (and key removal, if any).
     *
     * @throws java.util.NoSuchElementException if the stack is empty
     */
    public int pop() {
        requireNonEmpty();
        Node last = tail.pre;
        removeNode(last);
        // The top node is the most recently pushed node of its value, hence the last list entry.
        dropLastIndexed(last.val);
        return last.val;
    }

    /**
     * Returns the top element without removing it. O(1).
     *
     * @throws java.util.NoSuchElementException if the stack is empty
     */
    public int top() {
        // Without this guard the dummy head's value (0) would be returned for an empty stack.
        requireNonEmpty();
        return tail.pre.val;
    }

    /**
     * Returns the maximum element without removing it. O(log n) for {@code lastKey()}.
     *
     * @throws java.util.NoSuchElementException if the stack is empty
     */
    public int peekMax() {
        requireNonEmpty();
        return map.lastKey();
    }

    /**
     * Removes and returns the maximum element. If the maximum occurs several times, only the
     * occurrence closest to the top is removed. O(log n).
     *
     * @throws java.util.NoSuchElementException if the stack is empty
     */
    public int popMax() {
        requireNonEmpty();
        int max = map.lastKey(); // TreeMap keys are sorted ascending, so the last key is the max
        List<Node> nodes = map.get(max);
        // Remove by index: remove(Object) would scan the list linearly from the front.
        Node toRemove = nodes.remove(nodes.size() - 1);
        if (nodes.isEmpty()) {
            map.remove(max);
        }
        removeNode(toRemove);
        return max;
    }

    /** Throws if no element sits between the two sentinels. */
    private void requireNonEmpty() {
        if (head.next == tail) {
            throw new java.util.NoSuchElementException("MaxStack is empty");
        }
    }

    /** Drops the most recently pushed node recorded for {@code val}; removes the key if none remain. */
    private void dropLastIndexed(int val) {
        List<Node> nodes = map.get(val);
        nodes.remove(nodes.size() - 1);
        if (nodes.isEmpty()) {
            map.remove(val);
        }
    }
}