【数据结构】索引堆

226 阅读2分钟

1. 介绍

  • 索引堆是对堆的优化:在堆中,构建堆、插入、删除过程都需要大量的交换操作。在普通堆的实现中,进行交换操作是直接交换真实元素数组中的两个元素。而索引堆交换的是这两个元素的索引,而不是直接交换元素,就有以下优势:

    1. 减小交换操作的消耗,尤其是对于元素交换需要很多资源的对象来说,比如大字符串
    2. 可以根据原位置找到元素,即便这个元素已经换了位置。比如说我们这些元素表示的是一个一个的系统任务,初始的时候,数值的索引是系统进程的 ID 号,可是将数组构建成普通堆之后,这些索引和系统任务之间就会失去关联
  • 在堆的基础上用一个索引数组来存储数据元素的位置,即索引堆里面包含两个数组

    public class IndexHeap<T> {
        // 真实元素数组
        private T[] data;
        // 索引数组
        // 普通的堆:data[0]就是堆中的最大/最小元素
        // 索引堆:data[index[0]]才是最大/最小元素
        private int[] index;
    }
    

2. 代码实现

/**
 * 索引堆
 */
public class IndexHeap<T> {
    // 真实元素数组
    private T[] data;
    // 索引数组
    // 普通的堆:data[0]就是堆中的最大/最小元素
    // 索引堆:data[index[0]]才是最大/最小元素
    private int[] index;
    // 元素的比较器,自己决定"大"的定义
    private Comparator<T> comparator;
    // 索引堆的总容量
    private int capacity;
    // 元素数量
    private int size = 0;

    public IndexHeap(Comparator<T> comparator) {
        this(10, comparator);
    }

    public IndexHeap(int capacity, Comparator<T> comparator) {
        if (capacity <= 0 || comparator == null) {
            throw new IllegalArgumentException("Illegal argument for constructor!");
        }
        this.capacity = capacity;
        this.comparator = comparator;
        this.data = (T[]) new Object[capacity];
        this.index = new int[capacity];
        Arrays.fill(this.index, -1);
    }

    public boolean isEmpty() {
        return size == 0;
    }

    public int size() {
        return size;
    }

    // 往索引i的位置添加元素
    public void offer(int i, T t) {
        if (i < 0 || i >= capacity || size + 1 > capacity) {
            throw new IllegalArgumentException("Illegal operate");
        }
        data[i] = t;
        index[size] = i;
        siftUp(size);
        size++;
    }

    // 替换索引i位置的元素
    public void replace(int i, T t) {
        if (i < 0 || i >= capacity) {
            throw new IllegalArgumentException("Illegal operate");
        }
        if (data[i] == null) {
            offer(i, t);
            return;
        }
        data[i] = t;
        for (int k = 0; k < size; k++) {
            if (index[k] == i) {
                siftUp(k);
                siftDown(k);
                return;
            }
        }
    }

    // 获取索引i上的元素
    public T indexOf(int i) {
        if (i < 0 || i >= capacity) {
            throw new IllegalArgumentException("Illegal operate");
        }
        return data[i];
    }

    public T peek() {
        return size == 0 ? null : data[index[0]];
    }

    public int peekIndex() {
        return size == 0 ? -1 : index[0];
    }

    public T poll() {
        if (size == 0) {
            return null;
        }
        T ret = data[index[0]];

        pollOperate();

        return ret;
    }

    public int pollIndex() {
        if (size == 0) {
            return -1;
        }
        int ret = index[0];

        pollOperate();

        return ret;
    }

    private void pollOperate() {
        size--;
        swapIndex(0, size);
        siftDown(0);
        data[index[size]] = null;
        index[size] = -1;
    }

    // 上浮操作
    private void siftUp(int k) {
        while (k > 0 && comparator.compare(data[index[k]], data[index[getParentIndex(k)]]) > 0) {
            swapIndex(k, getParentIndex(k));
            k = getParentIndex(k);
        }
    }

    // 下沉操作
    private void siftDown(int k) {
        int max;
        while (getLeftChildIndex(k) < size) {
            max = getLeftChildIndex(k);
            if (max + 1 < size && comparator.compare(data[index[max + 1]], data[index[max]]) > 0) {
                max++;
            }
            // 此时max指向左右孩子中较大的元素
            if (comparator.compare(data[index[max]], data[index[k]]) > 0) {
                swapIndex(max, k);
                k = max;
            } else {
                break;
            }
        }
    }

    private void swapIndex(int i, int j) {
        int temp = index[i];
        index[i] = index[j];
        index[j] = temp;
    }

    private int getParentIndex(int i) {
        return (i - 1) / 2;
    }

    private int getLeftChildIndex(int i) {
        return i * 2 + 1;
    }
}