浅谈KM算法

1,541 阅读4分钟

背景介绍

最近测试团队在对项目中的算法进行性能测试,构造不同的压测场景发现KM算法的耗时差距显著。构造的场景如下:

  1. 订单数1000,车辆数1000,在全城范围匹配,每个订单匹配的车辆数限制为100,经过KM计算耗时在190ms左右。
  2. 订单数1000,车辆数1000,在一个六边形网格匹配,每个订单匹配的车辆限制为100,经过KM计算耗时在90ms左右。 场景1和场景2的订单数,车辆数,以及组成匹配对的数量都是一样,为什么场景1比场景2 KM算法的耗时要大一倍呢?

带着这样的疑惑,开始对KM算法进行调研和验证。

匹配相关的算法

KM算法的作者是Kuhn-Munkras,发明时间在1960年左右,你随手百度一把就能找到不少相关资料,这里我就不重复罗列了,稍微归总一下基于二分图匹配相关的几个算法。

  • 匈牙利算法:最大匹配算法,基于增广路径的思路来解决匹配问题。这个算法应该是二分图匹配算法的开端,后面的算法都是以它作为基础,衍生而来。
  • 完美匹配算法:完美匹配必然就是最大匹配,属于最大匹配的子集,保证了二分图中每个顶点都有一个唯一的匹配项。
  • KM算法:也叫最优匹配算法。和最大匹配不同的点是:KM匹配时之间带有权重,除了保证了数量最大外,还需要保证匹配整体的权重最优。为了保证数量最大化,KM需要对不存在的匹配关系,初始化一个不可能的值。

从产品发展的角度来看,最大匹配算法,完美匹配算法,最优匹配算法,应该是一个向前迭代的过程。完美匹配算法是最大匹配算法的子集,最优匹配算法又是完美匹配算法的子集。

KM算法性能分析

KM算法整体的时间复杂度是:O(V^3),V代表顶点的个数。这个只是KM算法整体视角的时间复杂度,并不能解决我上面遇到的问题,需要进一步对KM算法的原理有深入理解。

整体的解题思路:

  1. 初始化可行顶标的值。
  2. 匈牙利算法寻找完备匹配。
  3. 若未找到完备匹配则修改可行顶标的值。
  4. 重复(2)(3)直到找到相等子图的完备匹配为止。

代码实现如下:

package km;

import java.util.LinkedList;
import java.util.Queue;

public class KM {
    private double[][] weight;
    private int maxVoNum;
    private double[] lx, ly, slack;
    private boolean[] xUsed, yUsed;
    private int[] linkX, linkY, before;

    public int[] entrance(double[] inputWeight, int inputMaxVoNum) {
        this.weight = new double[inputMaxVoNum][inputMaxVoNum];
        for (int i=0; i<inputMaxVoNum*inputMaxVoNum; i++) {
            int x = i/inputMaxVoNum;
            int y = i%inputMaxVoNum;
            weight[x][y] = inputWeight[i];
        }
        init(weight, inputMaxVoNum);
        int isExecutor = executeBFS();
        if (isExecutor == -1) {
            int[] res = new int[1];
            res[0] = -1;
            return res;
        }
        int[] res = new int[inputMaxVoNum];
        for (int i=0; i<inputMaxVoNum; i++){
            res[i] = linkX[i];
        }
        return res;
    }

    private int executeBFS() {
        for(int i=0; i < maxVoNum; i++){
            lx[i] = 0;
            ly[i] = 0;
            for (int j=0; j<maxVoNum; j++){
                if (lx[i] < weight[i][j]) {
                    lx[i] = weight[i][j];
                }
            }
        }
        for (int i=0; i<maxVoNum; i++) {
            linkX[i] = -1;
            linkY[i] = -1;
            before[i] = -1;
        }
        for (int i=0; i<maxVoNum; i++){
            clearSlack();
            while (true) {
                clearXUsed();
                clearYUsed();
                if (bfs(i)) {
                    break;
                }else {
                    adjustLxy();
                }
            }

        }
        return 0;
    }

    private boolean doubleIsZero (double num) {
        return num >= -1*Math.pow(10, -6) && num <= 1*Math.pow(10, -6);
    }

    private int init(double[][] inputWeight, int inputMaxVoNum) {
        if (null == inputWeight || inputMaxVoNum <= 0 ){
            return -1;
        }
        this.maxVoNum = inputMaxVoNum;
        this.lx = new double[maxVoNum];
        this.ly = new double[maxVoNum];
        this.slack = new double[maxVoNum];
        this.xUsed = new boolean[maxVoNum];
        this.yUsed = new boolean[maxVoNum];
        this.linkX = new int[maxVoNum];
        this.linkY = new int[maxVoNum];
        this.before = new int[maxVoNum];
        return 0;
    }

    private boolean bfs(int x) {
        Queue<Integer> queue = new LinkedList<>();
        queue.add(x);
        while (!queue.isEmpty()) {
            int u = queue.poll();
            this.xUsed[u] = true;
            for(int v = 0; v < maxVoNum ; ++v) {
                if (yUsed[v]) continue;
                double t = lx[u] + ly[v] - weight[u][v];
                if (doubleIsZero(t)){
                    yUsed[v] = true;
                    if (-1 != linkY[v]) {
                        queue.add(linkY[v]);
                        before[linkY[v]] = u;
                    }else {
                        int left= u, right = v;
                        while (-1 != left) {
                            int nextRight = linkX[left];
                            linkX[left] = right;
                            linkY[right] = left;
                            left = before[left];
                            right = nextRight;
                        }
                        return true;
                    }

                }else {
                    if (slack[v] > t) {
                        slack[v] = t;
                    }
                }
            }

        }
        return false;
    }

    private void clearXUsed() {
        for (int i=0; i<maxVoNum; i++) {
            xUsed[i] = false;
        }
    }

    private void clearYUsed() {
        for (int i=0; i<maxVoNum; i++) {
            yUsed[i] = false;
        }
    }

    private void clearSlack() {
        for (int i=0; i<maxVoNum; i++) {
            slack[i] = Double.MAX_VALUE;
        }
    }

    private void adjustLxy() {
        double d = Double.MAX_VALUE;
        for (int i=0; i<maxVoNum; i++){
            if (yUsed[i]==false&&slack[i]<d){
                d = slack[i];
            }
        }
        for (int i=0; i< maxVoNum; i++){
            if (xUsed[i]) {
                lx[i] -= d;
            }
            if (yUsed[i]) {
                ly[i] += d;
            }else {
                slack[i] -= d;
            }
        }
    }
}

再回到上面的问题,场景1和场景2的耗时差别这么大。KM除了整体的时间复杂度是:O(V^3),还跟订单和车辆匹配权重不同值的个数相关,正好对应上面一个是全城范围匹配,一个是六边形网格范围匹配。

匹配权重不同值的个数越多,代表要循环越多才能找到最大值,消耗的时间就越多。

针对实验场景,KM的性能耗时跟如下因素相关:订单数量,车辆数量,匹配值不同的数量。

总结

KM算法运用的业务场景相对较好理解,但是要理解KM算法的原理和实现细节还是要花一定的精力。