图的最短路计数问题——BFS和Dijkstra堆优化

147 阅读10分钟

一、题目内容

www.luogu.com.cn/problem/P11…

题目描述

给出一个 NN 个顶点 MM 条边的无向无权图,顶点编号为 1N1\sim N。问从顶点 11 开始,到其他每个点的最短路有几条。

输入格式

第一行包含 22 个正整数 N,MN,M,为图的顶点数与边数。

接下来 MM 行,每行 22 个正整数 x,yx,y,表示有一条连接顶点 xx 和顶点 yy 的边,请注意可能有自环与重边。

输出格式

NN 行,每行一个非负整数,第 ii 行输出从顶点 11 到顶点 ii 有多少条不同的最短路,由于答案有可能会很大,你只需要输出 ansmod100003 ans \bmod 100003 后的结果即可。如果无法到达顶点 ii 则输出 00

样例 #1

样例输入 #1

5 7
1 2
1 3
2 4
3 4
2 3
4 5
4 5

样例输出 #1

1
1
1
2
4

提示

1155 的最短路有 44 条,分别为 2212451\to 2\to 4\to 52213451\to 3\to 4\to 5(由于 454\to 5 的边有 22 条)。

对于 20%20\% 的数据,1N1001\le N \le 100
对于 60%60\% 的数据,1N1031\le N \le 10^3
对于 100%100\% 的数据,1N1061\le N\le10^61M2×1061\le M\le 2\times 10^6

二、BFS

2.1 思路

  • 初始化:

    • 创建一个数组 dist,用于记录从顶点1到各个顶点的最短距离,初始值设为无穷大。
    • 创建一个数组 count,用于记录从顶点1到各个顶点的最短路径数量,初始值设为0。
    • dist[1] 设置为0,count[1] 设置为1,表示从顶点1到自身的最短路径长度为0,路径数量为1。
  • BFS遍历:

    • 使用队列 q 存储待处理的节点,初始时将顶点1入队。

    • 从队列中取出一个节点 u,遍历其所有邻接节点 v

      • 如果 v 尚未被访问(即 dist[v] 为无穷大),则更新 dist[v]dist[u] + 1,并将 v 入队。
      • 如果 v 已经被访问,且通过 u 到达 v 的路径长度等于当前已知的最短路径长度(即 dist[u] + 1 == dist[v]),则将 count[v] 增加 count[u],表示从顶点1到 v 的最短路径数量增加了通过 u 到达的路径数量
  • 结果输出:

    • 对于每个顶点 i,输出 count[i] % 100003,如果 dist[i] 为无穷大,则输出0,表示无法到达该顶点。

2.2 代码

from collections import deque, defaultdict

MOD = 100003

def bfs_shortest_paths(N, adj):
    dist = [float('inf')] * (N + 1)
    count = [0] * (N + 1)
    dist[1] = 0
    count[1] = 1
    q = deque([1])
    
    while q:
        u = q.popleft()
        for v in adj[u]:
            if dist[v] == float('inf'):
                dist[v] = dist[u] + 1
                count[v] = count[u]
                q.append(v)
            elif dist[v] == dist[u] + 1:
                count[v] = (count[v] + count[u]) % MOD
    
    return count, dist

def main():
    N, M = map(int, input().split())
    adj = defaultdict(list)
    
    for _ in range(M):
        x, y = map(int, input().split())
        adj[x].append(y)
        adj[y].append(x)
    
    count, dist = bfs_shortest_paths(N, adj)
    
    for i in range(1, N + 1):
        if dist[i] == float('inf'):
            print(0)
        else:
            print(count[i] % MOD)

if __name__ == "__main__":
    main()

2.3 复杂度分析

时间复杂度:

BFS解法的时间复杂度为O(N + M),其中N是顶点数,M是边数。

注意:总时间复杂度O(N+M)O(N + M),而不是 O(N×M)O(N \times M)。这里的 O(N+M)O(N + M) 是因为:

  • 节点和边的处理是独立的,分别计数。
  • BFS 遍历的是每个节点及其邻接的边,每条边在图中最多出现一次,因此边的访问次数不会达到 O(N×M)O(N \times M)

空间复杂度:

  • 使用了邻接表 adj,其空间复杂度为O(N + M)。
  • 数组 distcount 的空间复杂度为O(N)。
  • 因此,整体空间复杂度为O(N + M)。

2.4 超时分析

在N为 10610^6,M为 21062*10^6 时,该算法的需要执行 51065 * 10^6 次(无向图,在上述题解中一条边存了2次),理论上不会超时。但实际最后一个用例耗费 1.2s 超时:

image.png

2.4.1 尝试转为CPP写法(依然超时):

考虑到相同算法用不同语言实现有执行速度差异,使用c++实现上述python代码:

#include <iostream>
#include <vector>
#include <queue>
#include <climits>
using namespace std;

const int MOD = 100003;

void bfs_shortest_paths(int N, vector<vector<int>>& adj, vector<int>& count, vector<int>& dist) {
    dist[1] = 0;
    count[1] = 1;
    
    // 使用队列进行 BFS
    queue<int> q;
    q.push(1);
    
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        
        // 遍历所有邻接节点
        for (int v : adj[u]) {
            if (dist[v] == INT_MAX) {  // 如果 v 还没有被访问过
                dist[v] = dist[u] + 1;
                count[v] = count[u];
                q.push(v);
            } else if (dist[v] == dist[u] + 1) {  // 如果是最短路径
                count[v] = (count[v] + count[u]) % MOD;
            }
        }
    }
}

int main() {
    int N, M;
    cin >> N >> M;

    vector<vector<int>> adj(N + 1);  // 邻接表
    vector<int> count(N + 1, 0);      // 路径数量数组
    vector<int> dist(N + 1, INT_MAX); // 最短路径数组
    
    // 读取图的边
    for (int i = 0; i < M; i++) {
        int x, y;
        cin >> x >> y;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    // 调用 BFS 计算最短路径和路径数量
    bfs_shortest_paths(N, adj, count, dist);
    
    // 输出结果
    for (int i = 1; i <= N; i++) {
        if (dist[i] == INT_MAX) {
            cout << 0 << endl;
        } else {
            cout << count[i] % MOD << endl;
        }
    }

    return 0;
}

但实际用时还是1.2s,依然超时。

2.4.2 c++解法使用getchar()代替cin,优化输入,最终ac。

getchar()putchar() 是 C 语言中的低级输入输出函数,它们比 cincout 更高效,特别是在处理大量数据时。cincout 会进行更多的缓冲和格式化工作,而 getchar() 是直接从标准输入读取一个字符。

#include <iostream>
#include <vector>
#include <queue>
#include <climits>
using namespace std;

const int MOD = 100003;

// 自定义读取整数的函数
inline int read() {
    int x = 0, f = 1;
    char ch = getchar();
    while ((ch < '0' || ch > '9') && ch != '-') {
        ch = getchar();
    }
    if (ch == '-') {
        f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return f * x;
}

void bfs_shortest_paths(int N, vector<vector<int>>& adj, vector<int>& count, vector<int>& dist) {
    dist[1] = 0;
    count[1] = 1;
    
    // 使用队列进行 BFS
    queue<int> q;
    q.push(1);
    
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        
        // 遍历所有邻接节点
        for (int v : adj[u]) {
            if (dist[v] == INT_MAX) {  // 如果 v 还没有被访问过
                dist[v] = dist[u] + 1;
                count[v] = count[u];
                q.push(v);
            } else if (dist[v] == dist[u] + 1) {  // 如果是最短路径
                count[v] = (count[v] + count[u]) % MOD;
            }
        }
    }
}

int main() {
    int N = read(), M = read(); // 读取节点数和边数

    vector<vector<int>> adj(N + 1);  // 邻接表
    vector<int> count(N + 1, 0);      // 路径数量数组
    vector<int> dist(N + 1, INT_MAX); // 最短路径数组
    
    // 读取图的边
    for (int i = 0; i < M; i++) {
        int x = read(), y = read(); // 读取每条边
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    // 调用 BFS 计算最短路径和路径数量
    bfs_shortest_paths(N, adj, count, dist);
    
    // 输出结果
    for (int i = 1; i <= N; i++) {
        if (dist[i] == INT_MAX) {
            printf("0\n");
        } else {
            printf("%d\n", count[i] % MOD);
        }
    }

    return 0;
}

耗时299ms:

image.png

三、Dijkstra 堆优化

Dijkstra 算法是用于求解 加权图(即图中的边有不同权重)中 单源最短路径 的经典算法。它的核心思想是贪心策略,依次选择当前已知的最短路径最小的节点,更新与该节点相连的所有邻接节点的最短路径值。

3.1 思路

朴素dijkstra的具体步骤:

  1. 所有点分为两个集合 STS 最开始只包括源点 s,剩余点都位于 TS 集合表示已经计算出最短路径的点集合,T 表示尚未计算出最短路径的点集合。
  2. 每次从集合 T 中选出一个与集合 S 距离最短的点v,将点v加入集合S。通过点v对集合T中的点进行松弛(“松弛”是指在最短路算法中通过不断更新路径估计值以逼近真实最短路径的一种操作。)
  3. 不断重复此步骤2,直至T集合中无法找出与集合S相邻的点。

该题是无权图,所以dijkstra不是最优解通过堆优化可以ac

核心要点:

3.1.1 链式前向星存图:

int head[N], to[M], nxt[M], d[N], ans[N];
bool p[N];

其中:

  • head:邻接表的头指针数组。

  • to:存储边的终点。

  • nxt:存储每个节点的下一条边的索引。

  • d:存储从源点到各节点的最短距离。

  • ans:存储从源点到各节点的最短路径数量。

  • p:标记节点是否已被处理。

3.1.2 堆优化

在 C++ 中,priority_queue<pair<int, int>> 默认使用最大堆(大顶堆),即优先级最高的元素位于队首。 pair 类型的元素比较规则是:首先比较 first 元素,如果相等,则比较 second 元素。 因此,priority_queue 会根据 pair 的第一个元素(即 first)进行排序,确保 first 最大的元素位于队首。 如果存在多个元素的 first 相等,则会根据 second 元素进行排序。

使用最大堆 q 表示未被松弛的集合T

priority_queue<pair<int, int>> q;

既然用到了最大堆,但是又想取最短距离,所以在push时,将first值取负

q.push(make_pair(-d[y], y));

然后每次取堆顶元素(距离集合S最近的点):

int u = q.front();
q.pop();

3.2 代码

#include <cstdio>
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
using namespace std;

inline int read() {
    char ch = getchar();
    int x = 0, f = 1;
    while ((ch > '9' || ch < '0') && ch != '-') ch = getchar();
    if (ch == '-') {
        f = -1;
        ch = getchar();
    }
    while ('0' <= ch && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

// 定义取模常数
int mod = 100003;

// 定义图的节点数、边数、当前节点、邻接节点、边的数量等变量
int n, m, x, y, tot = 0;
const int N = 1000005, M = 4000005;
int head[N], to[M], nxt[M], d[N], ans[N];
bool p[N];

// 定义优先队列,存储节点的最短距离和节点编号
priority_queue<pair<int, int>> q;

// 添加边到邻接表
void add(int x, int y) {
    to[++tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}

int main() {
    // 读取节点数和边数
    n = read();
    m = read();

    // 读取每条边的信息,并构建邻接表
    for (int i = 1; i <= m; i++) {
        x = read();
        y = read();
        add(x, y);
        add(y, x);
    }

    // 初始化距离数组和标记数组
    for (int i = 1; i <= n; i++) {
        d[i] = 1e9; // 设置初始距离为无穷大
        p[i] = 0;   // 标记数组初始化为未访问
    }

    // 设置源点的距离为0,路径数量为1
    d[1] = 0;
    ans[1] = 1;

    // 将源点加入优先队列
    q.push(make_pair(0, 1));

    // Dijkstra 算法主体
    while (q.size()) {
        // 获取当前最短距离的节点
        x = q.top().second;
        q.pop();

        // 如果该节点已被访问,跳过
        if (p[x]) continue;

        // 标记该节点为已访问
        p[x] = 1;

        // 遍历当前节点的所有邻接节点
        for (int i = head[x]; i; i = nxt[i]) {
            y = to[i];

            // 如果发现更短的路径
            if (d[y] > d[x] + 1) {
                d[y] = d[x] + 1; // 更新最短距离
                ans[y] = ans[x]; // 更新路径数量
                q.push(make_pair(-d[y], y)); // 将更新后的节点加入优先队列
            } else if (d[y] == d[x] + 1) {
                // 如果发现相同长度的路径,累加路径数量
                ans[y] += ans[x];
                ans[y] %= mod; // 取模以防止溢出
            }
        }
    }

    // 输出每个节点的最短路径数量
    for (int i = 1; i <= n; i++)
        printf("%d\n", ans[i]);

    return 0;
}


耗时255ms:

image.png

3.3 复杂度分析

时间复杂度

1. 初始化:

  • 初始化距离数组 d 和标记数组 p 的时间复杂度为 O(n),其中 n 是节点数。

2. 读取边:

  • 读取每条边并构建链式前向星的时间复杂度为 O(m),其中 m 是边的数量。

3. Dijkstra 算法主体:

  • 在最坏情况下,每个节点会被加入优先队列一次,且每个节点的邻接节点会被遍历。对于每个节点的操作,优先队列的插入和删除操作的时间复杂度为 O(log n)。

  • 因此,Dijkstra 算法的时间复杂度为 O((n + m) log n),其中 n 是节点数,m 是边数。

空间复杂度

1. 链式前向星:

  • 使用链式前向星存储图的边,空间复杂度为 O(n + m),其中 n 是节点数,m 是边数。

2. 其他数组:

  • 数组 d、p 和 ans 的空间复杂度为 O(n)。