手写大根堆-Java实现

696 阅读2分钟

实现大根堆的前提是满足完全二叉树(没看过完全二叉树的可以先去查阅一下),大根堆的规则:父节点永远大于它的子节点,实现小根堆只需将大于小于符号改变即可

举例:如数组{0,1,2,3,4,5,6};

其中对于任意一个节点K(除了根节点)其父节点为(K-1)/2,子节点2K+1,2K+2;

最后一个非叶子节点为:(heapSize-2)/2

public static class MyMaxHeap {
    private int[] elem;
    private int heapSize = 0;

    public MyMaxHeap() {
        elem = new int[16];
    }

    /*
     * 建立大根堆
     * */
    public void buildHeap(int[] array){
        for(int i = 0; i < array.length; i++) {
            this.elem[i] = array[i];
            this.heapSize++;
        }
        buildMaxHeap(elem,heapSize);
    }

    public void push(int value) {
        if (isFull()) {
            //扩容
            System.out.printf("触发扩容");
            this.elem = Arrays.copyOf(this.elem, this.elem.length*2);
        }
        elem[heapSize] = value;
        heapInsert(heapSize++);
    }

    //已经是大根堆,只需要向上调整
    //如果收了N个数,时间复杂度为logN
    private void heapInsert(int child) {
        int parent = (child-1) / 2;
        while(parent >= 0) {
            if(this.elem[parent] < this.elem[child]) {
                int tmp = this.elem[parent];
                this.elem[parent] = this.elem[child];
                this.elem[child] = tmp;
                child = parent;
                parent = (child-1) / 2;
            } else {
                break;
            }
        }
    }

    // 返回最大值,并在大根堆中把最大值删掉
    // 剩下的数,依然保持大根堆组织
    public int pop() {
        if(isEmpty()) {
            throw new RuntimeException("heap is empty");
        }
        int ans = elem[0];
        swap(elem, 0, --heapSize);
        maxHeapify(elem, 0, heapSize);
        return ans;
    }

    //从最后一个非叶子节点开始,构建大根堆
    private void buildMaxHeap(int[] elem, int heapSize){
        int top = (heapSize -2) /2;
        for(int i = top; i>=0; i--){
            maxHeapify(elem,i,heapSize);
        }
    }

    // 堆结构的个关键操作:从某个位置开始往下调整,时间复杂度logN
    public void maxHeapify(int[] arr, int parent,int heapSize){
        int left = parent * 2 + 1;
        int right = parent * 2 + 2;
        int largest = parent;

        if(left < heapSize && arr[left] > arr[largest]){
            largest = left;
        }

        if(right < heapSize && arr[right] > arr[largest]){
            largest = right;
        }

        //如果最大值的指针不是父节点,则交换父节点和当前最大值指针指向的子节点。
        if(largest != parent){
            swap(arr,largest,parent);
            //由于交换了父节点和子节点,因此可能对子节点的子树造成影响,所以对子节点的子树进行调整。
            maxHeapify(arr,largest,heapSize);
        }
    }

    private void swap(int[] arr, int i, int j) {
        int t = arr[i];
        arr[i] = arr[j];
        arr[j] = t;
    }

    public boolean isEmpty() {
        return heapSize==0;
    }

    public boolean isFull() {
        return heapSize==elem.length;
    }

    //堆的输出
    public void Print(){
        for (int i = 0; i < heapSize; i++) {
            System.out.printf("%d ",elem[i]);
        }
    }


}

测试代码

public class MaxHeap {
    public static void main(String[] args) {
        MyMaxHeap heap= new MyMaxHeap();
        int flag = 1;
        while (flag == 1) {
            System.out.println("请输入你想完成的操作(请先创建堆):");
            System.out.println("创建堆:create");
            System.out.println("插入数据:push");
            System.out.println("删除数据:pop");
            System.out.println("输出堆:print");
            System.out.println("结束程序:over");
            Scanner sc = new Scanner(System.in);
            String comm = sc.next();
            switch (comm) {
                case "create":
                    build(heap);
                    break;
                case "push":
                    push(heap);
                    break;
                case "pop": {
                    System.out.printf("%d ",heap.pop());
                    break;
                }
                case "print": heap.Print();
                    break;
                case  "over": flag = 0;
            }
        }
    }

    public static void build(MyMaxHeap heap) {
        Scanner sc = new Scanner(System.in);
        int n;
        System.out.print("输入堆的大小和数据:");
        n = sc.nextInt();
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextInt();
        }
        heap.buildHeap(a);
    }

    public static void push(MyMaxHeap heap) {
        try {
            Scanner sc = new Scanner(System.in);
            System.out.print("输入数据,多个用逗号隔开:");
            String n = sc.next();
            String[] split = n.split(",");
            for(int i=0; i< split.length;i++){
                heap.push(Integer.valueOf(split[i]));
            }
        } catch (NumberFormatException e) {
            e.printStackTrace();
        }
    }
}