小样本学习在语言理解任务中的突破

25 阅读3分钟

小样本学习在语言理解任务中的新突破

研究背景

当语音助手接收到新指令时,首先需要对其意图进行分类 - 比如播放音乐、查询天气或控制智能家居设备。随着新功能的开发,需要不断添加新的意图分类。由于这些新意图通常对应新设想的用例,训练数据往往很稀疏。在这种情况下,能够利用现有的意图分类能力,仅从少量示例(如5或10个)中学习新意图变得尤为重要。

研究方法

原型网络(ProtoNets)

原型网络用于元学习,即学习如何学习。通过原型网络,机器学习模型被训练来嵌入输入,将其表示为高维空间中的点。训练的目标是学习一种嵌入方式,最大化不同类别实例点之间的距离,同时最小化同类实例点之间的距离。

原型网络通过批次进行训练,每个批次包含多个不同类别的实例。在每个批次之后,使用随机梯度下降调整模型参数以优化嵌入之间的距离。这种方法不需要每个批次都包含模型将看到的所有类别的实例,使得原型网络在类别数量和每个类别的实例数量方面都非常灵活。

数据增强(ProtoDA)

我们在这一通用流程中加入了数据增强,以实现原型之间更好的分离。在小样本学习过程中,每个新类别的嵌入样本会传递到基于神经网络的生成器,该生成器会产生额外的嵌入样本,标记为与输入样本属于相同类别。

我们使用与训练原型网络相同的损失函数来训练样本生成器。也就是说,生成器学习生成新的样本,这些样本与真实样本结合时,能够最大化不同类别实例之间的分离,同时最小化同类实例之间的分离。

实验设计

在实验中,我们将样本生成器放置在网络中的两个不同位置:

  1. 在语义编码器和原型网络之间
  2. 在原型网络和分类层之间

文本输入首先通过执行初始嵌入的编码器,这个嵌入是可变长度句子的固定长度表示,利用双向长短期记忆网络来捕获输入的上下文信息。

实验结果

  • 在没有数据增强的情况下,原型网络在5样本情况下比基线方法F1分数提高约1%,在10样本情况下提高约5%
  • 添加神经数据增强后,在5样本情况下减少8.4%的F1错误,在10样本情况下减少12.4%的F1错误
  • 当生成器的输入是原型网络产生的嵌入时性能最佳

技术优势

我们相信,原型网络空间的较低维度(128个特征而不是768个特征)以及与训练目标函数(原型网络损失)的接近性,是性能差异的重要原因。这种方法为小样本学习在自然语言理解任务中的应用提供了新的思路和解决方案。