文本生成任务:beamsearch解码搜索

47 阅读2分钟

上一篇文章的末尾简单介绍了解码器在测试阶段的工作原理,每次预测下一时刻的输出单词时,其实是预测整个词表的概率分布,一般采用贪婪搜索算法 (greedy search),每次选择概率最大的单词作为输出。 贪婪搜索策略往往无法得到全局最优解,于是诞生了beam search,接下来我们将详细介绍这两个解码方式。

Greedy search

解码阶段相当于是一个语言模型,给定一个初始隐层状态 h,生成序列为 Y,则贪婪算法计算公式如下:

P(Y)=i=1Tlogp(y~i)P(Y) = \sum_{i=1}^{T}log p(\widetilde{y}_i)

其中,

p(y~i)=argmaxp(v1,v2,v3vnhi1yi1)p(\widetilde{y}_i) = argmax p(v_1, v_2, v_3……v_n|h_{i-1}, y_{i-1})

其中 P(Y) 表示最终生成序列的概率,Y 表示序列 y~1\widetilde{y}_1,y~2\widetilde{y}_2,y~3\widetilde{y}_3,……y~n\widetilde{y}_n ,每次生成的序列 y~i\widetilde{y}_i 是在给定前一时刻隐层和前一时刻的输出下,整个词表 V 上概率最大的位置对应的单词。

显然这种方法无法取得全局最优解,虽然当前时刻,某个单词的概率比较低,但语言模型是在序列上计算序列的概率之积,是一个累积的过程。获取全局状态最优解,与隐马尔科夫模型所使用的维特比算法相似,但维特比算法需要消耗大量的时间和空间复杂度,这对于生成任务来说,相当于每个时刻有三万多种不同的状态,句子的最大序列长度至少二十,如果直接采用维特比算法,显然是无法接受的。

因此,一种折中的算法:柱搜索策略 (beam search)应运而生。

柱搜索算法

柱搜索算法,也就是我们常说的 beam search 算法,整体示意图如下:

截屏2023-09-21 下午12.43.31.png

为了方便描述,我们简化词表大小为 3,实际上加上句子起始标记和结束标记,词表大小为 5,设置柱大小设置为 2,初始解码用“[S]”做选择,深色部分表示被挑选的概率值最大的前两个单词,以被选择的两个词对应的隐层状态和预测单词,分别衍生出下一时刻 6 种不同的状态,从这 6 中状态中,选择 2 条概率最大的路径,依次类推, 直到解码出 “[EOS]” 标志,解码过程结束,最终得到两个最优解。

相比于贪婪搜索策略,柱搜索扩大了搜索的路径,能够避免陷入局部最优。如果设置柱大小 (Beam Size) 为 K,柱搜索策略具体做法可以描述为, 在 t − 1 时刻,已经通过柱搜索策略,得到 K 个最大概率的句子,Yt1=y~1<t,y~2<t,...,y~<KtY_{t−1} = {\widetilde{y}_1<t, \widetilde{y}_2<t,...,\widetilde{y}<K_t} ,其中 y~k<t \widetilde{y}_k<t 表示在 t − 1 时刻生成的前 t − 1 个单词序列,上标 k ∈ {1, 2, ...K},表示概率最大的前 K 中序列中的第 k 个序列。那么下一时刻预 测的序列 Yt 可以表示为:

Yt=y~<t+11,y~<t+12,...,y~<t+1kYt ={\widetilde{y}^1_{<t+1},\widetilde{y}^2_{<t+1},...,\widetilde{y}^k_{<t+1}}

=argsortklogp(y~<tky~tk),i{1,2,...K},k{1,2,...K}=argsort_k logp(\widetilde{y}^k_{<t}\widetilde{y}^k_{t} ), i ∈ \{1, 2, ...K\}, k ∈ \{1, 2, ...K\}

y~tk1,...y~tkK=argsortKp(v1,v2,...vVht1,y~t1k){\widetilde{y}^{k1}_{t},...\widetilde{y}^{kK}_{t}} = argsort_Kp(v_1, v_2,...v_V|h_{t-1}, \widetilde{y}^{k}_{t-1})

其中,
̃y~t1iK\widetilde{y}^{iK}_{t-1} 表示,给定 t − 1 时刻概率最大的前 K 个序列中的第 i 个序列,生成 t时刻对应单词,在词表 V 上的概率分布,与传统的直接利用 V 个概率词的作为候选词的做法不同,本文选取概率最大的前 K 个单词构成 t 时刻的候选单词列表 {y~tiK,...y~tkK \widetilde{y}^{iK}_t, ...\widetilde{y}^{kK}_t},最后的 Y 序列由 Y 序列和当前时刻的候选词累乘,选择前 K 大概率值对应的候选序列作为 Yt 时刻的序列。以此类推,直到解码结束,将得到概率值最大的前 K 个句子。

为了防止模型倾向输出短句子,需要加入长度惩罚项,对目标函数归一化。除以生成句子的长度,相当于取每个单词的概率对数值的平均,这样很明显地减少对输出长的结果的惩罚。

总结

采用 greedy 算法,可能会导致陷入局部最优,然后生成重复序列的片段,但是解码速度比较快; 采用 beam search方法,可以避免陷入局部最优,而且会生成多个不同的句子,在多样性上比 greedy 好,但是速度比较慢。现在随着芯片技术发展,一般采用 beam search 搜索方式,近期在此基础上还出现了 topk、topP等方法。