144 阅读3分钟

Snipaste_2023-03-07_13-10-29.png

  • 堆的结构是一棵完全二叉树

    • 上面一棵满二叉树
    • 最后一层的节点从左到右依次排布
  • 小根堆

    • 每个点都小于等于左右子节点
  • 存储:使用一维数组存储

    • 根节点:1
    • x的左子节点:2x
    • x的右子节点:2x+1
  • 基本操作

    • down(x):往下调整,有一个节点变大了,比左右子节点还大,那就与左右子节点中更小的那个交换,递归这个过程,直到调整完毕
    • up(x):往上调整,有一个节点变小了,小于它的根节点,那就与根节点交换,递归这个过程,直到调整完毕
  • 五种操作

    • 添加节点
      • heap[++size] = x; up(size);
    • 求集合中的最小值
      • heap[1];
    • 删除最小值
      • heap[1] = heap[size--]; down(1)
      • 注意当有其他东西维护heap[]下标时,要用heap_swap()实现交换
    • 删除任意一个元素
      • heap[k] = heap[size--]; down(k); up(k);
      • 注意当有其他东西维护heap[]下标时,要用heap_swap()实现交换
    • 修改任意一个元素
      • heap[k] = x; down(k); up(k);
  • 为什么建堆的时候使用for (int i = n / 2; i; i -- ) down(i);时间复杂度是O(n)的?

    • Snipaste_2023-03-06_17-26-50.png
    • 这里以4层为例,注意看,从倒数第二层(第三层)开始向上遍历,第三层(n/4)个节点都down()1次,第二层(n/8)个节点都down()2次,第一层(n/16)节点down()3次,求和(错位相减)后总数小于n,因此时间复杂度为O(n),比从头开始遍历插入O(nlogn)快一点

模板

  • C++
// h[N]存储堆中的值, h[1]是堆顶,x的左儿子是2x, 右儿子是2x + 1
// ph[k]存储第k个插入的点在堆中的位置
// hp[k]存储堆中下标是k的点是第几个插入的
int h[N], ph[N], hp[N], size;

// 交换两个点,及其映射关系
// 这里只有当题目中要求对第k个插入的元素进行操作时才定义这个函数
void heap_swap(int a, int b)
{
    swap(ph[hp[a]],ph[hp[b]]);
    swap(hp[a], hp[b]);
    swap(h[a], h[b]);
}

void down(int u)
{
    int t = u;  //t表示三个节点中最小值的下标
    if (u * 2 <= size && h[u * 2] < h[t]) t = u * 2;
    if (u * 2 + 1 <= size && h[u * 2 + 1] < h[t]) t = u * 2 + 1;
    if (u != t)
    {
        heap_swap(u, t); //普通情况用swap(u, t)就行
        down(t);
    }
}

void up(int u)
{
    while (u / 2 && h[u] < h[u / 2])
    {
        heap_swap(u, u / 2); //普通情况用swap(u, u / 2)就行
        u >>= 1;
    }
}

// O(n)建堆
for (int i = n / 2; i; i -- ) down(i);
  • Java
// h[N]存储堆中的值, h[1]是堆顶,x的左儿子是2x, 右儿子是2x + 1
// ph[k]存储第k个插入的点在堆中的位置
// hp[k]存储堆中下标是k的点是第几个插入的
public static int size;
public static int[] h = new int[N]; 
public static int[] ph = new int[N]; 
public static int[] hp = new int[N]; 

public static void heap_swap(int a, int b) {
    int t = ph[hp[a]];
    ph[hp[a]] = ph[hp[b]];
    ph[hp[b]] = t;
    
    t = hp[a];
    hp[a] = hp[b];
    hp[b] = t;
    
    t = h[a];
    h[a] = h[b];
    h[b] = t;
}

public static void down(int u) {
    int t = u;
    if (u * 2 <= size && h[u * 2] < h[t]) {
        t = u * 2;
    }
    if (u * 2 + 1 <= size && h[u *2 + 1] < h[t]) {
        t = u * 2 + 1;
    }
    if (u != t) {
        int tmp = h[u]; //特殊情况用heap_swap(u,t)
        h[u] = h[t];
        h[t] = tmp;
        
        down(t);
    }
}

public static void up(int u) {
    while (u / 2 != 0 && h[u] < h[u / 2]) {
        int tmp = h[u];  ////特殊情况用heap_swap(u, u / 2)
        h[u] = h[u / 2];
        h[u / 2] = tmp;
        
        u >>= 1;
    }
}

//O(n)建堆
for (int i = 2 / n; i != 0; i--) {
    down(i);
}

练习

01 堆排序

  • 题目

Snipaste_2023-03-10_23-49-38.png

  • 题解
import java.io.*;

public class Main {
    public static final int N = 100010;
    public static int m, n, size;
    public static int[] h = new int[N];

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter pw = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
        String[] str1 = br.readLine().split(" ");
        n = Integer.parseInt(str1[0]);
        m = Integer.parseInt(str1[1]);
        String[] str2 = br.readLine().split(" ");
        for (int i = 0; i < n; i++) {
            h[i + 1] = Integer.parseInt(str2[i]);
        }
        size = n;

        for (int i = size / 2; i != 0; i--) {
            down(i);
        }

        while (m-- > 0) {
            pw.print(h[1] + " ");
            h[1] = h[size--];
            down(1);
        }
        pw.close();
        br.close();
    }

    public static void down(int u) {
        int t = u;
        if (u * 2 <= size && h[u * 2] < h[t]) {
            t = u * 2;
        }
        if (u * 2 + 1 <= size && h[u * 2 + 1] < h[t]) {
            t = u * 2 + 1;
        }
        if (t != u) {
            int tmp = h[t];
            h[t] = h[u];
            h[u] = tmp;
            down(t);
        }
    }
}

02 模拟堆

  • 题目

Snipaste_2023-03-10_23-51-21.png

  • 题解
import java.io.*;

public class Main {
    public static final int N = 100010;
    public static int n, size, idx;
    public static int[] h = new int[N];
    public static int[] ph = new int[N];
    public static int[] hp = new int[N];

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter pw = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
        n = Integer.parseInt(br.readLine());

        while (n-- > 0) {
            String[] str1 = br.readLine().split(" ");
            if (str1[0].equals("I")) {
                h[++size] = Integer.parseInt(str1[1]);
                ph[++idx] = size;
                hp[size] = idx;
                up(size);
            } else if (str1[0].equals("PM")) {
                pw.println(h[1]);
            } else if (str1[0].equals("DM")) {
                heap_swap(1, size--);
                down(1);
            } else if (str1[0].equals("D")) {
                int k = ph[Integer.parseInt(str1[1])];
                heap_swap(k, size--);
                up(k);
                down(k);
            } else {
                int k = ph[Integer.parseInt(str1[1])];
                int x = Integer.parseInt(str1[2]);
                h[k] = x;
                up(k);
                down(k);
            }
        }
        pw.close();
        br.close();
    }

    public static void heap_swap(int a, int b) {
        int t = ph[hp[a]];
        ph[hp[a]] = ph[hp[b]];
        ph[hp[b]] = t;

        t = hp[a];
        hp[a] = hp[b];
        hp[b] = t;

        t = h[a];
        h[a] = h[b];
        h[b] = t;
    }

    public static void down(int u) {
        int t = u;
        if (u * 2 <= size && h[u * 2] < h[t]) {
            t = u * 2;
        }
        if (u * 2 + 1 <= size && h[u * 2 + 1] < h[t]) {
            t = u * 2 + 1;
        }
        if (t != u) {
            heap_swap(u, t);
            down(t);
        }
    }

    public static void up(int u) {
        while (u / 2 != 0 && h[u / 2] > h[u]) {
            heap_swap(u, u / 2);
            u >>= 1;
        }
    }
}