聊聊MCMF算法

1,460 阅读4分钟

工作后不管我们研究一个领域,还是学习一门知识,不会无缘无故只是为了学习而学习,总是会先带着问题出发,为了解决问题去研究或者学习新的领域。

MCMF的完整语义是minimum cost maximum flow,翻译成中文就是最小费用最大流算法,这是一个经典的运筹学算法。作为一名工程领域的软件工程师,为什么要解决运筹学领域的算法问题?

  • 遇到了实际的业务问题,需要解决
  • 好奇心驱动

问题背景

在出行行业,人们上下班高峰期经常会遇到想要打车,但是周边却没有司机;而有些地方没有人打车,但是周边却有足够的司机在空等着;从而造成想要打车的人打不到车,想要接单的司机又接不到单。如何解决这种空间维度,供需不平衡的问题?

提前对空闲的司机进行调度,把司机调到有人打车的地方,但是缺车的地方很多,一个司机只能被调到一个地方。每个城市很多地方周边都缺车,也有很多空闲的司机等着被调度,那如何更合理的把所有空闲的司机都调到缺车的地方?

问题分析

从工程的视角来分析,哪里缺车就把空闲的车往哪里调不就行了,哪还有这么多讲究。其实不然,这不是一个贪心算法问题,需要全局的视角来分析:

  • 10个地方都缺车,但是只有6个空闲司机,这6个空闲司机去哪最合适呢?
    • 你可能会想哪里最近去哪里,这可不一定哦,因为你去了最近的地方,可能导致另一个司机要去很远的地方。
  • 6个地方缺车,但是有10个空闲司机,那选择哪6个空闲司机去呢?
    • 你可能会想,反正司机多,随便选6个就可以,其实问题跟上面一样

每个地方缺车的数量不一样,每个司机被调过去的路程不一样,总而言之每个司机被调到各个地方花费的代价是不同的。

经过分析那最终要求解的问题就是:用最小的花费代价,满足各个地方最多的缺车需求,正好和MCMF算法要解决的问题领域一致。

MCMF算法

MCMF算法主要解决网络流的问题,网络流如下:

image.png

S: 网络流的源点(起点)
T: 网络流的汇点(终点)
1,2,3:网络流的中间节点
边上面的黑色字体表示流量,括号的中的红色字体表示费用
从图中可以看到最终到T的最大流量只能是3,因为和T连接的两条边的流量,都是最小流量。
S -> T路径可以是:

  • S -> 3 -> T + S -> 3 -> 2 -> T (1)
  • S -> 3 -> T + S -> 1 -> 2 -> T (2)

(1)和(2)都到了最大流3,
(1) 最小费用:(3+4)*2+(3+1+1)*1=19
(2) 最小费用:(3+4)*2+(1+2+1)*1=18

从而可以得到这个网络流的最小费用:18,最大流:3.

实现代码

可以参考:(1条消息) 最小费用最大流问题与算法实现(Bellman-Ford、SPFA、Dijkstra)_WhiteJunior的博客-CSDN博客_最小费用最大流问题

以下是我的实现代码,Java版本的

package mcf;

import lombok.Builder;
import lombok.Getter;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.Queue;

/**
 * 类描述:
 *
 * @ClassName McmfAlg2
 * @Description TODO
 * @Author ***
 * @Date 2022/12/20 9:44
 * @Version 1.0
 */
public class McmfAlg2 {

    private static final int MAN = 5050;
    private final int start; // 源节点
    private final int end; // 汇节点
    @Getter
    private int maxFlow, minCost; // 最大流,最小费用
    private final boolean[] visited = new boolean[MAN]; // BFS 记录节点是否访问过
    private final int[] cost = new int[MAN]; // 记录费用
    private final int[] pre = new int[MAN]; // 记录指向to的 from节点
    private final int[] last = new int[MAN]; // 记录最后指向to的 是 哪条边
    private final int[] flow = new int[MAN]; // 流量记录
    private final int[] head = new int[MAN]; // 记录 from 节点有多少不同的指向的 边
    private final Edge[] edges = new Edge[MAN]; // 记录 各个边
    private int numEdge; // 边计数
    private final Queue<Integer> queue = new LinkedList<>();

    public McmfAlg2(int start, int end) {
        this.numEdge = -1;
        Arrays.fill(this.head, -1);
        this.start = start;
        this.end = end;
    }

    @Builder
    static class Edge {
        private final int from;// 开始节点
        private final int to; // 结束节点
        /**
         * 记录 from 下一个对应的边
         */
        private final int next; //开始节点 下一个对应的边
        private int cap; // 剩下的流量
        private final int origin; // 原始流量
        private final int cost; // 费用
    }

    /**
     * 打印车对应的网格
     * @param carArr carArr
     */
    public void printCarAndGirdRelation(int[] carArr) {
        for (int value : carArr) {
            for (int j = this.head[value]; j != -1; j = this.edges[j].next) {
                if (this.edges[j].origin - this.edges[j].cap == 1) {
                    System.out.println("车:" + value + ",网格:" + this.edges[j].to);
                }
            }
        }
    }

    public void addEdge(int from, int to, int flow, int cost) {
        Edge edge1 = Edge.builder()
                .from(from)
                .to(to)
                .origin(flow)
                .cap(flow)
                .cost(cost)
                .next(head[from])
                .build();
        this.numEdge++;
        this.edges[numEdge] = edge1;
        head[from] = numEdge;

        Edge edge2 = Edge.builder()
                .from(to)
                .to(from)
                .origin(0)
                .cap(0)
                .cost(-cost)
                .next(head[to])
                .build();
        this.numEdge++;
        this.edges[numEdge] = edge2;
        head[to] = numEdge;
    }

    public void MCMF() {
        while (SPFA(this.start, this.end)) {
            int now = this.end;
            this.maxFlow += this.flow[end];
            this.minCost += this.flow[end] * this.cost[end];
            while (now != this.start) {
                this.edges[this.last[now]].cap -= this.flow[this.end];
                this.edges[this.last[now] ^ 1].cap += this.flow[this.end];
                now = this.pre[now];
            }
        }
    }

    /*
     *  时间复杂度: O(m * E),
     *  m为所有顶点进队列的平均次数,一般小于等于2*顶点个数
     *  E为给定图的边集合
     */
    private boolean SPFA(int start, int end) {
        Arrays.fill(cost, MAN);
        Arrays.fill(flow, MAN);
        Arrays.fill(visited, false);
        this.queue.add(start);
        this.visited[start] = true;
        this.cost[start] = 0;
        this.pre[end] = -1;

        while (!this.queue.isEmpty()) {
            int now = this.queue.poll();
            this.visited[now] = false;
            for (int i = this.head[now]; i != -1; i = this.edges[i].next) {
                if (this.edges[i].cap > 0
                        && this.cost[this.edges[i].to] > this.cost[now] + this.edges[i].cost) {
                    this.cost[this.edges[i].to] = this.cost[now] + this.edges[i].cost;
                    this.pre[this.edges[i].to] = now;
                    this.last[this.edges[i].to] = i;
                    this.flow[this.edges[i].to] = Integer.min(this.flow[now], this.edges[i].cap);
                    if (!this.visited[this.edges[i].to]) {
                        this.visited[this.edges[i].to] = true;
                        queue.add(this.edges[i].to);
                    }
                }
            }
        }
        return this.pre[end] != -1;
    }
}

测试代码如下:

 @Test
    public void test002() {
        McmfAlg2 mcmfAlg2 = new McmfAlg2(0,4);
        mcmfAlg2.addEdge(0,1,3,1);
        mcmfAlg2.addEdge(0,3,4,3);
        mcmfAlg2.addEdge(1,2,2,2);
        mcmfAlg2.addEdge(3,2,3,1);
        mcmfAlg2.addEdge(3,4,2,4);
        mcmfAlg2.addEdge(2,4,1,1);
        mcmfAlg2.MCMF();
        System.out.println("最大流为:"+mcmfAlg2.getMaxFlow());
        System.out.println("最小费用为:"+mcmfAlg2.getMinCost());
    }