简单说说NSM在KBQA的应用
论文:Improving Multi-hop Knowledge Base Question Answering by Learning Intermediate Supervision Signals
提出背景
传统KBQA方案缺点:一个问题在KB中需要多跳才能得到答案,但是只有得到最终的答案了才能得到反馈,中间的结果浪费了。也就是在训练的过程中,模型只看答案而不给他具体的推理过程,使得模型学习不稳定和效率低。
所以该论文提出的方案:分为学生网络(student)和老师网络(teacher),借鉴了知识蒸馏的思想(一般业界考虑知识蒸馏是压缩模型的方法,后面也有证明soft label能够帮助学生网络提高表现)。
-
学生网络的目标是找到问题的正确答案
-
老师网络在学习中正向推理和反向推理结合,生成中间的推理过程(也就是实体-关系路径),后续可由学生网络从老师网络的输出中继续学习
-
正向推理:从topic entity -> answer entity
-
反向推理:从answer entity -> topic entity
-
-
绿色的是topic entity,红色的是answer entity,黄色的推理中间entity,灰色则是不相干的实体
-
蓝色的线就是虚假推理,而红色的线则是真正的推理
-
可以看到正向推理过程中得到coffin rock和devil's doorway两个中间结果,而反向推理是到不了这两个中间结果的,所以就可以把通过robert taylor这条路径过滤掉
具体方法
训练学生网络的目的就是使其专注于KBQA的任务,学生网络只进行正向推理。训练老师网络的目的是为了给学生网络提供中间的推理过程,也就是说还是知识蒸馏的步骤,首先训练teacher然后使用teacher和ground truth两者结合loss来训练student。
-
student和teacher都是使用的NSM模型(Neural State Machine),在多跳推理的过程中逐步学习实体的分布(并未使用到KGE)
-
NSM又包含两个模块:instruction component和reasoning component
在介绍两个模块前,首先使用GloVe来对问题的word序列做embedding,然后使用LSTM encoder得到问题的hidden states,。是问题的长度,,就是LSTM的最终输出。
instruction component
instruction component具体则是读取问题的不同侧重点,得到一个list的instruction向量。
:第k步推导的instruction向量,如图所示后续输入到reasoning component中
instruction component的公式:
:这些参数都是训练过程中需要学习的
:n步推导后会有n个instruction向量
reasoning component
reasoning component首先将instruction向量和当前取得的topic entity得到的邻居子图包含的所有关系做匹配,然后乘上上一步得到的实体概率分布求和,得到当前步实体的embedding。当前步实体embedding和上一步实体embedding结合输入到FF网络中,整合二者信息,更新得到当前所有实体的embedding,然后使用该embedding矩阵更新实体概率分布。
reasoning component至始至终都是在和topic entity相关的子图下进行计算,还是子图推理的坏处,如果子图不包含答案,那么再怎么推理结果也是错误的。
该模块的输入包含当前步的instruction向量和上一步推理的实体分布和实体embedding,输出则是实体分布和实体embedding
初始的实体embedding在特定的e下则是:
是实体的邻居集,
给定的三元组,match向量则是由当前步的instruction向量和关系向量得到的:
其中
match关系和instruction后,整合match向量和从上一步得到的概率
:则是实体的概率
然后更新实体embedding:
最后计算中间实体的分布概率:
是一个第k步推理的矩阵,每一列则是实体的embedding向量
两种推理架构
-
并行(parallel reasoning):两个NSM模型,一个做前向推理,一个做反向推理,两者并行,不共享参数
-
混合(hybrid reasoning):共享instruction模块,前向反向循环,首先做前向推理,然后前向推理最终的输出为后向推理的起始输入
下标b则是反向推理的结果,下标f是前向推理的结果
loss的考虑
teacher
主要考虑两个loss:
- reason loss:即最终结果是否正确,带*号的则是Ground Truth,即概率转化为频率的分布
- correspondence loss:即中间结果是否正确,前向推理的第k步和反向推理的n-k步是否匹配
- 最终考虑上面的结合:,都是超参,控制相应loss的权重
student
teacher网络训练完成收敛后,就可以利用teacher网络得到的中间entity概率分布
然后student只是执行前向推理,但是利用到了teacher得到的中间结果,L1就是最终结果的匹配,而L2则是student在推理过程中得出中间结果和teacher得出中间结果的匹配𝜆也是超参来调整训练的权重。
最后
大概有一年没有更新博客了吧,上次更新还是在研究生开学前的一段时间,当时还在恶补机器学习的基础。研一一年的学习,完全转成了面向Pytorch编程。。每天看看论文,跑跑模型啥的。现在暑假出来实习,把每天学的东西总结一下,发点存货,哈哈