简单说说NSM在KBQA的应用

1,190 阅读5分钟

简单说说NSM在KBQA的应用

论文:Improving Multi-hop Knowledge Base Question Answering by Learning Intermediate Supervision Signals

地址:dl.acm.org/doi/abs/10.…

源码:github.com/RichardHGL/…

提出背景

传统KBQA方案缺点:一个问题在KB中需要多跳才能得到答案,但是只有得到最终的答案了才能得到反馈,中间的结果浪费了。也就是在训练的过程中,模型只看答案而不给他具体的推理过程,使得模型学习不稳定和效率低。

所以该论文提出的方案:分为学生网络(student)和老师网络(teacher),借鉴了知识蒸馏的思想(一般业界考虑知识蒸馏是压缩模型的方法,后面也有证明soft label能够帮助学生网络提高表现)。

  • 学生网络的目标是找到问题的正确答案

  • 老师网络在学习中正向推理和反向推理结合,生成中间的推理过程(也就是实体-关系路径),后续可由学生网络从老师网络的输出中继续学习

    • 正向推理:从topic entity -> answer entity

    • 反向推理:从answer entity -> topic entity

fig1.png

  • 绿色的是topic entity,红色的是answer entity,黄色的推理中间entity,灰色则是不相干的实体

  • 蓝色的线就是虚假推理,而红色的线则是真正的推理

  • 可以看到正向推理过程中得到coffin rock和devil's doorway两个中间结果,而反向推理是到不了这两个中间结果的,所以就可以把通过robert taylor这条路径过滤掉

具体方法

训练学生网络的目的就是使其专注于KBQA的任务,学生网络只进行正向推理。训练老师网络的目的是为了给学生网络提供中间的推理过程,也就是说还是知识蒸馏的步骤,首先训练teacher然后使用teacher和ground truth两者结合loss来训练student。

  • student和teacher都是使用的NSM模型(Neural State Machine),在多跳推理的过程中逐步学习实体的分布(并未使用到KGE)

  • NSM又包含两个模块:instruction component和reasoning component

fig2.png

在介绍两个模块前,首先使用GloVe来对问题的word序列做embedding,然后使用LSTM encoder得到问题的hidden states,{hj}j=1n\{h_j\}_{j=1}^{n}ll是问题的长度,q=hlq = h_lqq就是LSTM的最终输出。

instruction component

instruction component具体则是读取问题的不同侧重点,得到一个list的instruction向量。

i(k)Rdi^{(k) \in R^d}:第k步推导的instruction向量,如图所示后续输入到reasoning component中

instruction component的公式:

fig3.png

fig4.png:这些参数都是训练过程中需要学习的

{i(k)}k=1n\{i^{(k)}\}_{k=1}^n:n步推导后会有n个instruction向量

reasoning component

reasoning component首先将instruction向量和当前取得的topic entity得到的邻居子图包含的所有关系做匹配,然后乘上上一步得到的实体概率分布求和,得到当前步实体的embedding。当前步实体embedding和上一步实体embedding结合输入到FF网络中,整合二者信息,更新得到当前所有实体的embedding,然后使用该embedding矩阵更新实体概率分布。

reasoning component至始至终都是在和topic entity相关的子图下进行计算,还是子图推理的坏处,如果子图不包含答案,那么再怎么推理结果也是错误的。

该模块的输入包含当前步的instruction向量和上一步推理的实体分布和实体embedding,输出则是实体分布p(k)p^{(k)}和实体embedding{e(k)}\{e^{(k)}\}

初始的实体embedding在特定的e下则是:

fig5.png

NeN_e是实体ee的邻居集,WTRd×dW_T \in R^{d \times d}

给定的三元组<e,r,e><e',r,e>,match向量m<e,r,e>(k)m_{<e',r,e>}^{(k)}则是由当前步的instruction向量和关系向量得到的:

fig6.png

其中WRRd×dW_R \in R^{d \times d}

match关系和instruction后,整合match向量和从上一步得到ee'的概率

fig7.png

pe(k1)p_{e'}^{(k-1)}:则是实体ee'的概率

然后更新实体embedding:fig8.png

最后计算中间实体的分布概率:fig9.png

E(k)E^{(k)}是一个第k步推理的矩阵,每一列则是实体的embedding向量

两种推理架构

fig10.png

  • 并行(parallel reasoning):两个NSM模型,一个做前向推理,一个做反向推理,两者并行,不共享参数

  • 混合(hybrid reasoning):共享instruction模块,前向反向循环,首先做前向推理,然后前向推理最终的输出为后向推理的起始输入

fig11.png

下标b则是反向推理的结果,下标f是前向推理的结果

loss的考虑

teacher

主要考虑两个loss:

  • reason loss:即最终结果是否正确,带*号的则是Ground Truth,即概率转化为频率的分布

fig12.png

  • correspondence loss:即中间结果是否正确,前向推理的第k步和反向推理的n-k步是否匹配

fig13.png

  • 最终考虑上面的结合:λb(0,1)and λc(0,1)\lambda_b \in (0 , 1) \, and \, \lambda_c \in (0 , 1),都是超参,控制相应loss的权重

fig14.png

student

teacher网络训练完成收敛后,就可以利用teacher网络得到的中间entity概率分布

fig15.png

然后student只是执行前向推理,但是利用到了teacher得到的中间结果,L1就是最终结果的匹配,而L2则是student在推理过程中得出中间结果和teacher得出中间结果的匹配𝜆也是超参来调整训练的权重。

fig16.png

最后

大概有一年没有更新博客了吧,上次更新还是在研究生开学前的一段时间,当时还在恶补机器学习的基础。研一一年的学习,完全转成了面向Pytorch编程。。每天看看论文,跑跑模型啥的。现在暑假出来实习,把每天学的东西总结一下,发点存货,哈哈