题目描述
LeetCode上的1786. 从第一个节点出发到最后一个节点的受限路径数,难度:中等
现有一个加权无向连通图。给你一个正整数 n ,表示图中有 n 个节点,并按从 1 到 n 给节点编号;另给你一个数组 edges ,其中每个 edges[i] = [ui, vi, weighti] 表示存在一条位于节点 ui 和 vi 之间的边,这条边的权重为 weighti 。
从节点 start 出发到节点 end 的路径是一个形如 [z0, z1, z2, ..., zk] 的节点序列,满足 z0 = start 、zk = end 且在所有符合 0 <= i <= k-1 的节点 zi 和 zi+1 之间存在一条边。
路径的距离定义为这条路径上所有边的权重总和。用 distanceToLastNode(x) 表示节点 n 和 x 之间路径的最短距离。受限路径 为满足 distanceToLastNode(zi) > distanceToLastNode(zi+1) 的一条路径,其中 0 <= i <= k-1 。
返回从节点 1 出发到节点 n 的 受限路径数 。由于数字可能很大,请返回对 109 + 7 取余 的结果。
算法示例
输入:n = 5, edges = [[1,2,3],[1,3,3],[2,3,1],[1,4,2],[5,2,2],[3,5,1],[5,4,10]]
输出:3
解释:每个圆包含黑色的节点编号和蓝色的 distanceToLastNode 值。三条受限路径分别是:
1) 1 --> 2 --> 5
2) 1 --> 2 --> 3 --> 5
3) 1 --> 3 --> 5
输入:n = 7, edges = [[1,3,1],[4,1,2],[7,3,4],[2,5,3],[5,6,1],[6,7,2],[7,5,3],[2,6,4]]
输出:1
解释:每个圆包含黑色的节点编号和蓝色的 distanceToLastNode 值。唯一一条受限路径是:1 --> 3 --> 7 。
算法思想
- 利用堆优化的Dijkstra得到每个点到达结尾的「最短路」
- 使用动态规划球的「起点」到「结尾」的受限路径数量。
定义f[i]为从第i个点到第n个点的受限路径数量,f[0]就是我们的答案,而f[n - 1]=1作为起始条件。
我们从n点逆着出发,要求前点a和下一个点b,节点b必须要满足最短路径距离比上一个点a要远,为了保证b都是由到达n节点距离小于它的节点转移过来,需要将最短路距离按从小到大进行排序,
不失一般性,当我们要求 f[i]的时候,其实找的所有满足「与点 i 相连,且最短路比点 i 要小的点 j」,符合条件的点 j 有很多个,将所有的 f[j] 累加即是 f[i]。
class Solution {
int MOD = 1000000007;
public int countRestrictedPaths(int n, int[][] edges) {
// 建图
Map<Integer, Map<Integer, Integer>> graph = new HashMap<>();
for(int[] edge : edges) {
int u = edge[0] - 1, v = edge[1] - 1, weight = edge[2];
Map<Integer, Integer> map_u = graph.getOrDefault(u, new HashMap<>());
map_u.put(v, weight);
graph.put(u, map_u);
Map<Integer, Integer> map_v = graph.getOrDefault(v, new HashMap<>());
map_v.put(u, weight);
graph.put(v, map_v);
}
PriorityQueue<int[]> queue = new PriorityQueue<>((a,b)->(a[1] - b[1]));
boolean[] visited = new boolean[n];
// 各个点到节点n的距离
int[] dist = new int[n];
Arrays.fill(dist, Integer.MAX_VALUE);
dist[n - 1] = 0;
queue.add(new int[] {n- 1, 0});
while(!queue.isEmpty()) {
int[] u_dis = queue.poll();
int u = u_dis[0];
if(visited[u]) continue;
visited[u] = true;
if(graph.containsKey(u)) {
Map<Integer, Integer> map = graph.get(u);
for(int v : map.keySet()) {
if(visited[v]) continue;
int v_dis = map.get(v);
int dv = dist[u] + v_dis;
if(dv < dist[v]) {
dist[v] = dv;
}
queue.add(new int[] {v, dist[v]});
}
}
}
// 接下来开始dp
int[][] arr = new int[n][2];
for(int i = 0; i < n; i++) {
arr[i] = new int[] {i, dist[i]};
}
Arrays.sort(arr, (a, b) -> (a[1] - b[1]));
// 定义f[i],表示节点i到n-1的受限数量
int[] f = new int[n];
f[n - 1] = 1;
// i为0,拿到的第一个节点应该是第n个节点
for(int i = 0; i < n; i++) {
int idx = arr[i][0], dis = arr[i][1];
Map<Integer, Integer> map = graph.get(idx);
if(map == null) continue;
for(int v : map.keySet()) {
if(dis > dist[v]) {
f[idx] += f[v];
f[idx] %= MOD;
}
}
if (idx == 0) break;
}
return f[0];
}
}