【算法修炼】图论算法四(SPFA最短路径算法实战)

138 阅读3分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

先贴贴模板:

class Solution {
    public int networkDelayTime(int[][] times, int n, int k) {
        // 先建图
        List<int[]>[] graph = new LinkedList[n + 1];
//         也可以写成:
//        List<Integer>[] graph = new LinkedList[n];
        // List数组记得对每一个List进行初始化,才能使用
        // 注意节点从1开始,数组大小要开成:n + 1
        for (int i = 1; i <= n; i++) {
            graph[i] = new LinkedList<>();
        }
        for (int[] time : times) {
            int from = time[0];
            int to = time[1];
            int weight = time[2];
            // from -> List<(to, weight)>
            // 邻接表存储图结构,同时存储权重信息weight
            graph[from].add(new int[]{to, weight});
        }
        // 记录开始节点到任一节点的最短路径
        int[] distTo = new int[graph.length];
        Arrays.fill(distTo, Integer.MAX_VALUE);
        // 记录是否入队
        boolean[] vis = new boolean[graph.length];
        // 统计当前节点的遍历次数,用于判断负环
        int[] nums = new int[graph.length];
        // 初始条件
        distTo[k] = 0;
        vis[k] = true;
        // 只有一个点,不含边
        nums[k] = 0;
        // 是否为负环
        boolean flag = false;
        // SPFA开始,k为起点
        Queue<Integer> queue = new LinkedList<>();
        queue.offer(k);
        while (!queue.isEmpty()) {
            int curId = queue.poll();
            // 出队,与BFS题目的vis数组区分开
            vis[curId] = false;
            // 遍历与该节点相邻的节点
            for (int[] next : graph[curId]) {
                int nextId = next[0];
                int weight = next[1];
                // 如果当前的更新距离更小才能更新
                if (distTo[nextId] > distTo[curId] + weight) {
                    // 更新距离
                    distTo[nextId] = distTo[curId] + weight;
                    // 当前节点的最短路径包含的边数 + 1
                    nums[nextId] = nums[curId] + 1;
                    if (nums[nextId] == n) {
                        // 是负环
                        flag = true;
                        break;
                    }
                    // 如果队列中没有,就入队
                    if (vis[nextId] == false) {
                        vis[nextId] = true;
                        queue.offer(nextId);
                    }
                }
            }
            // 是负环
            if (flag) {
                break;
            }
        }
        // 是负环了
        if (flag) {
            return -1;
        }
        int res = 0;
        for (int i = 1; i < graph.length; i++) {
            if (distTo[i] == Integer.MAX_VALUE) {
                return -1;
            }
            res = Math.max(res, distTo[i]);
        }
        return res;
    }
}

注意上面的vis数组位置和值,我们将节点拿出来时,vis置为false,是因为vis记录的是当前节点是否在队列中,而BFS搜索中,vis记录的是当前节点是否访问,不是它是不是在队列中,两者不能混淆!

最小体力消耗路径(中等)

在这里插入图片描述 把矩阵中的一个个元素,转换为图中的点,注意最短路径的定义:这条路径上高度差绝对值的最大值!

class node {
    int x, y;
    node (int x, int y) {
        this.x = x;
        this.y = y;
    }
}
class Solution {
    // 方向数组
    int[] x = new int[] {0,0,1,-1};
    int[] y = new int[] {1,-1,0,0};
    int m, n;
    public int minimumEffortPath(int[][] heights) {
        m = heights.length;
        n = heights[0].length;
        Queue<node> queue = new LinkedList<>();
        // 起点入队
        queue.add(new node(0, 0));
        // 记录最小体力消耗
        // 从(0,0) -> (i,j)
        int[][] minTo = new int[m][n];
        for (int i = 0; i < m; i++) {
            Arrays.fill(minTo[i], Integer.MAX_VALUE);
        }
        minTo[0][0] = 0;
        // 记录当前节点是否入队
        boolean[][] vis = new boolean[m][n];
        vis[0][0] = true;
        while (!queue.isEmpty()) {
            node tmp = queue.poll();
            vis[tmp.x][tmp.y] = false;
            for (int[] next : makeGraph(heights, tmp.x, tmp.y)) {
                int xx = next[0];
                int yy = next[1];
                // 计算起点(0,0)到当前点的最小消耗:取决于路径上的最大高度差
                // 这里类似于DP数组的更新
                int curMin = Math.max(minTo[tmp.x][tmp.y], Math.abs(heights[tmp.x][tmp.y] - heights[xx][yy]));
                // 更新 dp table
                if (minTo[xx][yy] > curMin) {
                    minTo[xx][yy] = curMin;
                    if (vis[xx][yy] == false) {
                        vis[xx][yy] = true;
                        queue.offer(new node(xx, yy));
                    }
                }
            }
        }
        // 返回结果
        return minTo[m - 1][n - 1];
    }
    List<int[]> makeGraph(int[][] heights, int xx, int yy) {
        List<int[]> graph = new LinkedList<>();
        for (int i = 0; i < 4; i++) {
            int tmpx = xx + x[i];
            int tmpy = yy + y[i];
            if (tmpx < 0 || tmpx >= m || tmpy < 0 || tmpy >= n) {
                continue;
            }
            graph.add(new int[]{tmpx, tmpy});
        }
        return graph;
    }
}

这道题用Dijkstra + 优先队列,可以很快解出来,这里还是用的SPFA。其实Dijkstra代码和SPFA差不多,Dijkstra需要用优先队列,不需要使用vis数组,只需要判断如果能够更新,才入队,否则就不入队,Dijkstra由于使用了优先队列,可以在队列的遍历的过程中提前返回答案(因为优先队列保证了第一次到达终点的答案一定是最小答案,也正是因为用了优先队列,Dijkstra才能那么快)。

SPFA和Dijkstra的过程都有点像DP算法,特别是在更新最值,存储最值时,和dp table一样。

下面给出Dijkstra代码,就是在SPFA基础上改动的:

class Solution {
    int m, n;
    int[] xx = new int[] {1,-1,0,0};
    int[] yy = new int[] {0,0,1,-1};
    public int minimumEffortPath(int[][] heights) {
        m = heights.length;
        n = heights[0].length;
        // 答案数组
        int[][] minTo = new int[m][n];
        for (int i = 0; i < m; i++) {
            Arrays.fill(minTo[i], Integer.MAX_VALUE);
        }
        minTo[0][0] = 0;
        class node {
            int x, y, minDist;
            node (int x, int y, int minDist) {
                this.x = x;
                this.y = y;
                this.minDist = minDist;
            }
        }
        Queue<node> queue = new PriorityQueue<>(new Comparator<node>() {
            @Override
            public int compare(node o1, node o2) {
                return o1.minDist - o2.minDist;
            }
        });
        queue.offer(new node(0,0, 0));
        while (!queue.isEmpty()) {
            node tmp = queue.poll();
            // Dijkstra可以提前结束
            if (tmp.x == m - 1 && tmp.y == n - 1) {
                return tmp.minDist;
            }
            if (tmp.minDist < minTo[tmp.x][tmp.y]) {
                continue;
            }
            for (int[] next : makeGraph(heights, tmp.x, tmp.y)) {
                int tx = next[0];
                int ty = next[1];
                int curMin = Math.max(minTo[tmp.x][tmp.y], Math.abs(heights[tmp.x][tmp.y] - heights[tx][ty]));
                if (curMin < minTo[tx][ty]) {
                    minTo[tx][ty] = curMin;
                    queue.offer(new node(tx, ty, minTo[tx][ty]));
                }
            }
        }
        return minTo[m - 1][n - 1];
    }
    List<int[]> makeGraph(int[][] heights, int x, int y) {
        List<int[]> graph = new LinkedList<>();
        for (int i = 0; i < 4; i++) {
            int tmpx = x + xx[i];
            int tmpy = y + yy[i];
            if (tmpx < 0 || tmpx >= m || tmpy < 0 || tmpy >= n) continue;
            graph.add(new int[] {tmpx, tmpy});
        }
        return graph;
    }
}

概率最大的路径(中等)

在这里插入图片描述 之前一直在说最小路径、有向图,这道题,无向图,求最大值。无向图简单,就是把一条边存两次,求最大值也简单,改变下if判断条件,就Ok啦~。

鸡汤来咯~

class Solution {
    public double maxProbability(int n, int[][] edges, double[] succProb, int start, int end) {
        // 建图,节点下标从0开始
        List<double[]>[] graph = new LinkedList[n];
        for (int i = 0; i < n; i++) {
            graph[i] = new LinkedList<>();
        }
        for (int i = 0; i < edges.length; i++) {
            // 无向图嘛,换个理解方式:双向图~
            graph[edges[i][0]].add(new double[] {edges[i][1], succProb[i]});
            graph[edges[i][1]].add(new double[] {edges[i][0], succProb[i]});
        }
        // 记录结果的数组
        double[] maxProb = new double[n];
        // 找最大值,要初始化为最小值
        Arrays.fill(maxProb, Double.MIN_VALUE);
        // 起点的概率应该是1,不然后续乘权值的时候会报错
        maxProb[start] = 1;
        // 记录是否入队
        boolean[] vis = new boolean[n];
        vis[start] = true;
        Queue<Integer> queue = new LinkedList<>();
        queue.offer(start);
        while (!queue.isEmpty()) {
            int curId = queue.poll();
            vis[curId] = false;
            for (double[] next : graph[curId]) {
                int to = (int)next[0];
                double prob = next[1];
                if (maxProb[to] < maxProb[curId] * prob) {
                    maxProb[to] = maxProb[curId] * prob;
                    if (vis[to] == false) {
                        vis[to] = true;
                        queue.offer(to);
                    }
                }
            }
        }
        // 有可能无法到达
        return maxProb[end] == Double.MIN_VALUE ? 0.00000 : maxProb[end];
    }
}

3488、最短路径(上交机试题)

在这里插入图片描述 这题骚在路径长度的把握,因为后面的一条道路的长度一定大于前面的路径总和,所以需要用并查集 在保证起点到其它城市都连通的情况下,道路长度最短,在此基础上再求最短路径,妙哉妙哉!

import java.util.*;
import java.io.*;

class UF {
    // 连通分量个数
    int count;
    // 记录每棵树
    int[] parent;
    // 记录每棵树的大小
    int[] size;
    // 初始化
    UF(int n) {
        this.count = n;
        this.parent = new int[n];
        this.size = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
            size[i] = 1;
        }
    }
    // find
    public int find(int x) {
        // 路径压缩
        while (x != parent[x]) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }
    // connected
    public boolean connected(int p, int q) {
        return find(p) == find(q);
    }
    // 联通
    public void union(int p, int q) {
        int rootP = find(p);
        int rootQ = find(q);
        // 已经联通
        if (rootP == rootQ) return;
        // 小树接在大树下
        if (size[rootP] > size[rootQ]) {
            parent[rootQ] = rootP;
            size[rootP] += size[rootQ];
        } else {
            parent[rootP] = rootQ;
            size[rootQ] += size[rootP];
        }
        // 连通分量--
        count--;
    }
    // 查询
    public int count() {
        return count;
    }
}
public class Main {
    static BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
    static BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(System.out));
    static int MOD = 100000;
    public static void main(String[] args) throws IOException {
        String[] input = reader.readLine().trim().split(" ");
        int n = Integer.parseInt(input[0]);
        int m = Integer.parseInt(input[1]);
        // n个城市下标从0 - n-1
        List<long[]>[] graph = new LinkedList[n];
        for (int i = 0; i < n; i++) {
            graph[i] = new LinkedList<>();
        }
        UF uf = new UF(n);  // n个城市下标从0 - n-1
        for (int i = 0; i < m; i++) {
            input = reader.readLine().trim().split(" ");
            int u = Integer.parseInt(input[0]);
            int v = Integer.parseInt(input[1]);
            // 如果已经联通,那就不用再联通了,后面的道路长度一定大于前面的道路总和
            if (uf.connected(u, v)) continue;
            // 联通
            uf.union(u, v);
            // 算路径长
            long wei = quickPow(2, i);
            // 双向边
            graph[u].add(new long[] {v, wei});
            graph[v].add(new long[] {u, wei});
        }
        // 记录节点路径长
        long[] distTo = new long[n];
        Arrays.fill(distTo, Long.MAX_VALUE);
        distTo[0] = 0;  // 起点=0
        // 是否入队
        boolean[] vis = new boolean[n];
        vis[0] = true;
        Queue<Integer> queue = new LinkedList<>();
        queue.offer(0);
        while (!queue.isEmpty()) {
            int cur = queue.poll();
            vis[cur] = false;  // 出队
            for (long[] next: graph[cur]) {
                int nextId = (int)next[0];
                long wei = next[1];
                if (distTo[nextId] > distTo[cur] + wei) {
                    distTo[nextId] = distTo[cur] + wei;
                    if (vis[nextId] == false) {
                        vis[nextId] = true;
                        queue.offer(nextId);
                    }
                }
            }
        }
        for (int i = 1; i < n; i++) {
            if (distTo[i] == Long.MAX_VALUE) {
                writer.write(-1 + "\n");
            } else {
                // 在最后打印结果时才取模
                writer.write(distTo[i] % MOD + "\n");
            }
        }
        writer.flush();
    }
    // 快速幂
    // 底数、幂
    static long quickPow(long num, long n) {
        if (n == 0) return 1;
        else if (n % 2 == 1) {
            // 奇数
            return quickPow(num, n - 1) * num % MOD;
        } else {
            // 偶数
            long tmp = quickPow(num, n / 2) % MOD;
            return tmp * tmp % MOD;
        }
    }
}

图论算法的学习到这里就差不多结束了!一定要记得多刷题,多背模板!