KMP算法——两天半梦半醒终于理解透彻

359 阅读7分钟

问题描述:

给定两个字符串S=“s1 s2 s3...sn”和T="t1 t2 t3...tm",在主串S中寻找子串T的过程称为模式匹配,T称为模式,如果匹配成功,返回T在S中的位置。在文本处理、杀毒软件、操作系统、编译系统、数据库系统以及搜索引擎中,模式匹配是使用最频繁的操作。

常规思想

利用BF算法即蛮力匹配、暴力求解,即从主串S的第一个字符开始和模式T的第一个字符比较,若相等,则继续比较两者的后续字符;否则,从主串S的下一个字符开始和模式T的第一个字符进行比较。重复上述过程,直到S或T中所有字符比较完毕。若T中的字符全部比较完毕,则匹配成功,返回本趟匹配的开始位置;否则匹配失败。该算法的时间复杂度为O(n*m)。

KMP算法

思想

在描述KMP算法之前,先对D.E.Knuth,J.H.Morris和V.R.Pratt三位大神膜拜一下!该算法对BF算法进行了改进,改进的出发点是对主串不进行回溯。

如下图所示,此时S[03]=T[03],但S[4]≠T[4],根据BF算法的思想,下一步将会从主串S的下一个字符开始比较,即S[1]与T[0]进行比较,但我们发现,S[1]=T[1]且T[1]≠T[0],故S[1]≠T[0],同理,没有必要将T[0]与S[1]、S[2]进行比较,又因为S[3]=T[0],所以我们可以从S[4]与T[1]开始比较,也即主串下标i不进行回溯,模式T下标j回溯到1再进行比较,在数据量大的情况下,这种措施将会节省不少的时间。

模式串公式的推导

当我们进行模式匹配时,发现S[i]≠T[j]时,我们的下一步操作,是想要找到一个k,使得S[i]与T[k]继续进行比较,即实现了主串不回溯,模式串回溯来寻找正确的模式匹配。观察部分匹配正确时的特征,如下图所示。

对于部分匹配正确之前的模式匹配,我们有下图所示。

综上,我们得到一个只关于模式串T的表达式,即

                                              T[0]...T[k-1]=T[j-k]...T[j-1]

这个式子的具体含义是,是否存在k值,使得模式串T当中j下标之前,能够满足前缀与后缀完全一致,并且前缀与后缀的长度皆为k。

我们再诠释一下这条式子的目的,当我们在进行模式匹配时,发现S[i]≠T[j],我们想要找到一个k值,并从S[i]与T[k]开始比较,这样能更好地进行模式匹配(节省时间,拒绝无用功)。而关于k值的寻找,在我们对模式匹配的情况进行分析后,我们发现k值只与模式串T,以及下标j值有关,与主串无关,综上,我们可以利用next数组来储存下标为j时对应的k值,先让我们看看next数组的取值公式。

让我们来分析公式的各类情况,

1、当j=0时,显然next[j]=0,但是为了后面编程时k值得计算方便,我们不妨令next[0]=-1;

2、当j=1时,此时不存在k能够满足1≤k<j,或者根据公式上方解释易知,next[1]=0;

3、当j≠0、1时,我们会寻找满足T[0]...T[k-1]=T[j-k]...T[j-1]的最大k值并储存,若不存在则设  置next[j] = 0。

在求解得到next数组之后,模式匹配就变得简单起来了,我们设置变量i = j = 0,其中i表示主串S的下标即S[i],j表示模式串T的下标即T[j],然后开始我们的算法流程:

       比较S[i]是否等于T[j],若相等,则对变量i,j都进行自增(i++,j++);若不相等,则令j = next[j]。重复上述过程,直到i或j遍历到字符串的尾部,若i遍历到主串S尾部,则不能找到模式T的匹配;若j遍历到模式T的尾部,则模式匹配成功,下标为(i-j),表示为在主串S下标为(i-j)开始的位置起,找到了模式T的匹配串。

// 前提为已经求解获得next数组

int Kmp(string mainString,string pattern){
    int i = 0;
    int j = 0;
    while (i < mainString.length()){
        if (j==-1 || mainString[i]==pattern[j]){
            i++; j++;
            if (j == pattern.length())
                return (i-j);
        } else {
            j = next[j];
        }
    }

    return -1;    // 表示不能找到模式T的匹配
}

next数组的创建过程 (核心)

如果按照公式进行编码,我们将得到如下代码:

int* next_solve(string pattern){
    int length = pattern.length();
    int *next = new int[length];
    next[0] = -1;
    for (int j=1; j<length; j++){
        int max = 0;
        for (int k=1; k<j; j++){
            string str1 = pattern.substr(k,0);
            string str2 = pattern.substr(k,j-k);
            if (str1 == str2)
                max = k;
        }
        next[j] = max;
    }

    return next;
}

但我们发现这个求next数组的方法好像效率不高(时间复杂度为O(m^2)),当模式串T足够长时,模式匹配将要耗费大量的时间来求解next数组,那有没有更好求解next数组的办法呢?先让我们见识一下真正高效率却晦涩难懂的代码:

int* next_solve(string pattern){
    int length = pattern.length();
    int *next = new int[length];
    next[0] = -1;

    int j = 0;
    int k = -1;
    while (j < length-1){
        if (k==-1 || pattern[k]==pattern[j]){
            next[++j] = ++k;
        } else {
            k = next[k];    // ?
        }
    }

    return next;
}

先不管这段的代码的具体含义以及有效性,我们先来看一下时间复杂度,为O(m)!如果这段代码正确有效,那将带来巨大的时间收益。那接下来就让我们来逐层分析下这段代码的具体含义吧。

首先是特殊情况的处理,即next[0]=-1、next[1]=0,这里是为了符合公式要求不再进行赘述。

算法的主要思想:

           利用已知的next[j]、k以及模式串T,来求解next[j+1]。

若已知next[j],则有T[0]...T[k-1]=T[j-k]...T[j-1],此时我们对T[k]与T[j]进行比较,我们将分为下面两种情况。

1、T[k]=T[j]

如图,当T[k]=T[j]时,我们由之前的T[0]...T[k-1]=T[j-k]...T[j-1]可以得到,T[0]...T[k-1]+T[k]=T[j-k]...T[j-1]+T[j],也即T[0]...T[k]=T[j-k]...T[j],故next[j+1]=k+1,我们实现了利用已知的next[j]、k以及模式串T,来达到求解next[j+1]的目的。

2、T[k]≠T[j]

此段解释对应的代码正是" k = next[k]",这句代码短小精悍,也是全部代码当中最反人类的地方!不过不用担心,接下来我们将结合图以及解释来让你彻底理解"k = next[k]"背后的真正含义,我们先按"k = next[k]"流程走一遍。

在模式匹配过程中,我们发现T[k]≠T[j],此时我们令 k = next[k] = 2(读者可简单自行思考),

此时再继续T[k]与T[j]的比较,发现T[k]=T[j],再有之前的T[0]...T[k-1]=T[j-k]...T[j-1]得到T[0]...T[k]=T[j-k]...T[j],求得T[j+1] = k+1 = 3。

可是为什么"k = next[k]"这么简单的一个语句能实现这么强大的功能呢?

我们用一张图来诠释这个过程。

1、绿色代表元素相等,即T[0]...T[k-1]=T[j-k]...T[j-1];

2、粉色代表在绿色元素中,存在next[k]个元素相互相等,设k' = next[k],满足T[0]...T[k'-1]=T[k-k']...T[k-1],T[j-k]....T[j-k+k'-1]=T[j-k']...T[j-1],由这三个式子推出

 T[0]...T[k'-1]=T[j-k']...T[j-1]

3、此时再判断下T[k]是否等于T[j],若相等则T[j+1]=k'+1=next[k]+1;若不相等,则令 k'' = next[k'] 重复过程3。

至此,关于next数组的求解过程已经诠释完毕。

LeetCode例题试练

附上代码

class Solution {
public:
    int strStr(string haystack,string needle){
        if (needle.length() == 0)
            return 0;

        int *next = next_solve(needle);

        int i = 0,j = 0;
        while (i < haystack.length()){
            if (j==-1 || haystack[i]==needle[j]){
                i++; j++;
                if (j == needle.length())
                    return (i-j);
            } else {
                j = next[j];
            }
        }

        return -1;
    }

    int* next_solve(string pattern){
        int length = pattern.length();
        int *next = new int[length];
        next[0] = -1;

        int j = 0;
        int k = -1;
        while (j < length-1){
            if (k==-1 || pattern[k]==pattern[j]){
                next[++j] = ++k;
            } else {
                k = next[k];
            }
        }

        return next;
    }
};

运行结果: