Offer 驾到,掘友接招!我正在参与2022春招打卡活动,点击查看活动详情。
题目链接:洛谷 P2505
题目大意:给定单向图,求每条边被多少条不同的最短路径经过。
题目分析:
为了解决这个问题,我们来介绍一个概念,叫做:最短路图。
在此处引入这个新的概念主要在于叙述更加方便。
最短路图的概念如下:
我们以 为起点对图 做一次最短路算法,如果存在 的子图 满足:
的任意一条边都在某一条最短路径上,且不在 的任意一条边都不在任意一条最短路径上,则将 称为点 的 最短路图。
简言之,一个点 的最短路图,就是从这个点做单源最短路径后,所有最短路径组成的新图(显然,这个新图是原图的子图)。
基于最短路图的概念,我们可以考虑枚举每个点作为起点,跑一趟最短路。随后讨论点 的最短路图 上的边对于答案的贡献。
对于 上的边 ,设从 到 的最短路径有 条,经过 且以 为出边的最短路径有 条。
显然,根据乘法原理,在图 中,边 对答案的贡献为 。
于是我们考虑求解 。
由于在最短路图上,不存在环(这一点根据最短路图的定义很容易发现),于是我们首先在该点的最短路图上进行拓扑排序。
对于 上的边 ,显然有递推式 ,且有初始化 。
然后我们考虑拓扑排序的逆序,有递推式 ,且有初始化 。
将上面结合起来,就可以进行求解。
参考代码:
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const ll N = 1505;
const ll M = 10005;
const ll Inf = 1e18;
const ll Mod = 1e9 + 7;
ll ans[M];
ll inc, from[M], to[M], edg[M], nxt[M], head[N];
ll n, m, in[N], dis[N], vis[M], cnt1[N], cnt2[N];
ll add(ll x, ll y) {
x += y;
return x >= Mod ? x - Mod : x;
}
ll mul(ll x, ll y) {
return x * y % Mod;
}
void insert(ll x, ll y, ll z) {
++inc;
from[inc] = x;
to[inc] = y;
edg[inc] = z;
nxt[inc] = head[x];
head[x] = inc;
}
void spfa(ll s) {
for (ll i = 1; i <= n; ++i) {
dis[i] = Inf, vis[i] = 0;
}
queue<ll> q;
dis[s] = 0;
vis[s] = 1;
q.push(s);
while (q.size()) {
ll x = q.front();
q.pop();
vis[x] = 0;
for (ll i = head[x]; i; i = nxt[i]) {
ll y = to[i], z = edg[i];
if (dis[y] > dis[x] + z) {
dis[y] = dis[x] + z;
if (!vis[y]) {
vis[y] = 1;
q.push(y);
}
}
}
}
}
void dfs1(ll x) {
vis[x] = 1;
for (ll i = head[x]; i; i = nxt[i]) {
ll y = to[i], z = edg[i];
if (dis[y] == dis[x] + z) {
in[y]++;
if (!vis[y]) {
dfs1(y);
}
}
}
}
void dfs2(ll x) {
for (ll i = head[x]; i; i = nxt[i]) {
ll y = to[i], z = edg[i];
if (dis[y] == dis[x] + z) {
vis[i] = 1;
cnt1[y] = add(cnt1[y], cnt1[x]);
if (--in[y] == 0) {
dfs2(y);
}
}
}
}
void dfs3(ll x) {
cnt2[x] = 1;
for (ll i = head[x]; i; i = nxt[i]) {
ll y = to[i], z = edg[i];
if (dis[y] == dis[x] + z) {
if (!cnt2[y]) {
dfs3(y);
}
cnt2[x] = add(cnt2[x], cnt2[y]);
}
}
}
int main() {
scanf("%lld%lld", &n, &m);
for (ll i = 1; i <= m; i++) {
ll x, y, z;
scanf("%lld%lld%lld", &x, &y, &z);
insert(x, y, z);
}
for (ll i = 1; i <= n; i++) {
spfa(i);
memset(vis, 0, sizeof(vis));
dfs1(i);
memset(vis, 0, sizeof(vis));
memset(cnt1, 0, sizeof(cnt1));
memset(cnt2, 0, sizeof(cnt2));
cnt1[i] = 1;
dfs2(i);
dfs3(i);
for (ll j = 1; j <= m; j++) {
if (vis[j]) {
ans[j] = add(ans[j], mul(cnt1[from[j]], cnt2[to[j]]));
}
}
}
for (ll i = 1; i <= m; i++) {
printf("%lld\n", ans[i]);
}
return 0;
}