【图论】「堆优化Dijkstra」+「动态规划」解决首尾节点的受限路径数

58 阅读2分钟

题目描述

LeetCode上的1786. 从第一个节点出发到最后一个节点的受限路径数,难度:中等

现有一个加权无向连通图。给你一个正整数 n ,表示图中有 n 个节点,并按从 1n 给节点编号;另给你一个数组 edges ,其中每个 edges[i] = [ui, vi, weighti] 表示存在一条位于节点 uivi 之间的边,这条边的权重为 weighti

从节点 start 出发到节点 end 的路径是一个形如 [z0, z1, z2, ..., zk] 的节点序列,满足 z0 = startzk = end 且在所有符合 0 <= i <= k-1 的节点 zizi+1 之间存在一条边。

路径的距离定义为这条路径上所有边的权重总和。用 distanceToLastNode(x) 表示节点 nx 之间路径的最短距离。受限路径 为满足 distanceToLastNode(zi) > distanceToLastNode(zi+1) 的一条路径,其中 0 <= i <= k-1

返回从节点 1 出发到节点 n受限路径数 。由于数字可能很大,请返回对 109 + 7 取余 的结果。

算法示例

image-20230316110552635.png

输入:n = 5, edges = [[1,2,3],[1,3,3],[2,3,1],[1,4,2],[5,2,2],[3,5,1],[5,4,10]]
输出:3
解释:每个圆包含黑色的节点编号和蓝色的 distanceToLastNode 值。三条受限路径分别是:
1) 1 --> 2 --> 5
2) 1 --> 2 --> 3 --> 5
3) 1 --> 3 --> 5

image-20230316110630082.png

输入:n = 7, edges = [[1,3,1],[4,1,2],[7,3,4],[2,5,3],[5,6,1],[6,7,2],[7,5,3],[2,6,4]]
输出:1
解释:每个圆包含黑色的节点编号和蓝色的 distanceToLastNode 值。唯一一条受限路径是:1 --> 3 --> 7

算法思想

  • 利用堆优化的Dijkstra得到每个点到达结尾的「最短路」
  • 使用动态规划球的「起点」到「结尾」的受限路径数量

定义f[i]为从第i个点到第n个点的受限路径数量,f[0]就是我们的答案,而f[n - 1]=1作为起始条件。

我们从n点逆着出发,要求前点a和下一个点b,节点b必须要满足最短路径距离比上一个点a要远,为了保证b都是由到达n节点距离小于它的节点转移过来,需要将最短路距离按从小到大进行排序,

不失一般性,当我们要求 f[i]的时候,其实找的所有满足「与点 i 相连,且最短路比点 i 要小的点 j」,符合条件的点 j 有很多个,将所有的 f[j] 累加即是 f[i]。

class Solution {
    int MOD = 1000000007;
    public int countRestrictedPaths(int n, int[][] edges) {
        // 建图
        Map<Integer, Map<Integer, Integer>> graph = new HashMap<>();
        for(int[] edge : edges) {
            int u = edge[0] - 1, v = edge[1] - 1, weight = edge[2];
            Map<Integer, Integer> map_u = graph.getOrDefault(u, new HashMap<>());
            map_u.put(v, weight);
            graph.put(u, map_u);
            
            Map<Integer, Integer> map_v = graph.getOrDefault(v, new HashMap<>());
            map_v.put(u, weight);
            graph.put(v, map_v);
        }
        PriorityQueue<int[]> queue = new PriorityQueue<>((a,b)->(a[1] - b[1]));
        boolean[] visited = new boolean[n];
        // 各个点到节点n的距离
        int[] dist = new int[n];
        Arrays.fill(dist, Integer.MAX_VALUE);
        dist[n - 1] = 0;
        queue.add(new int[] {n- 1, 0});
        while(!queue.isEmpty()) {
            int[] u_dis = queue.poll();
            int u = u_dis[0];
            if(visited[u]) continue;
            visited[u] = true;
            if(graph.containsKey(u)) {
                Map<Integer, Integer> map = graph.get(u);
                for(int v : map.keySet()) {
                    if(visited[v]) continue;
                    int v_dis = map.get(v);
                    int dv = dist[u] + v_dis;
                    if(dv < dist[v]) {
                        dist[v] = dv;
                    }
                    queue.add(new int[] {v, dist[v]});
                }
            }
            
        }
        // 接下来开始dp
        int[][] arr = new int[n][2];
        for(int i = 0; i < n; i++) {
            arr[i] = new int[] {i, dist[i]};
        }
        Arrays.sort(arr, (a, b) -> (a[1] - b[1]));
        // 定义f[i],表示节点i到n-1的受限数量
        int[] f = new int[n];
        f[n - 1] = 1;
        // i为0,拿到的第一个节点应该是第n个节点
        for(int i = 0; i < n; i++) {
            int idx = arr[i][0], dis = arr[i][1];
            Map<Integer, Integer> map = graph.get(idx);
        	if(map == null) continue;
            for(int v : map.keySet()) {
                if(dis > dist[v]) {
                    f[idx] += f[v];
                    f[idx] %= MOD;
                }
            }
            if (idx == 0) break;
        }
        return f[0];
    }
}