[nlp系列] - 编辑距离

331 阅读2分钟

编辑距离指的就是,将一个字符串转化成另一个字符串,需要的最少编辑操作次数(比如增加一个字符、删除一个字符、替换一个字符)。编辑距离越大,说明两个字符串的相似程度越小;相反,编辑距离就越小,说明两个字符串的相似程度越大。对于两个完全相同的字符串来说,编辑距离就是 0。

这个问题是求把一个字符串变成另一个字符串,需要的最少编辑次数。整个求解过程,涉及多个决策阶段,我们需要依次考察一个字符串中的每个字符,跟另一个字符串中的字符是否匹配,匹配的话如何处理,不匹配的话又如何处理。所以,这个问题符合多阶段决策最优解模型

1. 回溯算法

回溯是一个递归处理的过程。如果 a[i]与 b[j]匹配,我们递归考察 a[i+1]和 b[j+1]。如果 a[i]与 b[j]不匹配,那我们有多种处理方式可选:

可以删除 a[i],然后递归考察 a[i+1]和 b[j];

可以删除 b[j],然后递归考察 a[i]和 b[j+1];

可以在 a[i]前面添加一个跟 b[j]相同的字符,然后递归考察 a[i]和 b[j+1];

可以在 b[j]前面添加一个跟 a[i]相同的字符,然后递归考察 a[i+1]和 b[j];

可以将 a[i]替换成 b[j],或者将 b[j]替换成 a[i],然后递归考察 a[i+1]和 b[j+1]。

Python代码实现:

def edit_dist_recur(s, t):
    m = len(s)
    n = len(t)
    min_dist = float('inf')
    
    def _edit_dist(i, j, edist):
        nonlocal min_dist
        
        if i == m or j == n:
            if i < m:
                edist += (m-i)
            if j < n:
                edist += (n-j)
            if edist < min_dist:
                min_dist = edist
            return
        
        if s[i] == t[j]:
            _edit_dist(i+1, j+1, edist)
        else:
            _edit_dist(i+1, j, edist+1)
            _edit_dist(i, j+1, edist+1)
            _edit_dist(i+1, j+1, edist+1)
            
    _edit_dist(0, 0, 0)
    return min_dist
    
s = "mitcmud"
t = "mtacnufgy"
print(edit_dist_recur(s, t))

2. 动态规划

如果:a[i]!=b[j],那么:min_edist(i, j)就等于: min(min_edist(i-1,j)+1, min_edist(i,j-1)+1, min_edist(i-1,j-1)+1)

如果:a[i]==b[j],那么:min_edist(i, j)就等于: min(min_edist(i-1,j)+1, min_edist(i,j-1)+1,min_edist(i-1,j-1))

其中,min表示求三数中的最小值。

Python代码实现:

def edit_dist_dynamic(s: str, t: str) -> int:
    m, n = len(s), len(t)
    table = [[0] * (n) for _ in range(m)]

    for i in range(n):  #填充第一行
        if s[0] == t[i]:
            table[0][i] = i
        elif i != 0:
            table[0][i] = table[0][i - 1] + 1
        else:
            table[0][i] = 1

    for i in range(m):  #填充第一列
        if s[i] == t[0]:
            table[i][0] = i
        elif i != 0:
            table[i][0] = table[i - 1][0] + 1
        else:
            table[i][0] = 1

    for i in range(1, m):
        for j in range(1, n):
            table[i][j] = min(1 + table[i - 1][j], 1 + table[i][j - 1], int(s[i] != t[j]) + table[i - 1][j - 1])
    
    return table[-1][-1]
    
s = "mitcmud"
t = "mtacnufgy"
print(edit_dist_dynamic(s, t))