在智能客服领域,对意图分类是核心所在。而往往会出现某些意图,训练sample比较少,也就是所谓的长尾问题,不能用传统的supervised classification模型来解决。对待这种场景的分类,机器学习领域有一个专门的研究领域 -- Few-shot Learning。下面就对此做一些总结。
1. Match Network(NIPS2016)
这是比较早的一篇引入小样本学习(one-shot learning)的文章,来自谷歌的DeepMind。
模型简介
本文提出了小样本学习模式,不同于普通的机器学习模式,即在所有训练数据上训练得到一个模型,然后在测试集上预测。而小样本学习,是在训练阶段引入episode概念,一个episode指随机选取N个类别,每个类别随机选取k个样本(N-way-K-shot),每次计算loss就在这一个episode上执行。就是通过这种训练模型,来更新模型。
模型介绍
这个模型要解决的问题是:**在给定S(Support set)的情况下,对于一个测试样本,如何得到属于的概率。**用数学表达就是:。
作者给的计算方式如下:
其中其实就是计算attention,其实可以看作是support set中所有样本的线性组合,只是给他们赋予不同的权重。这就是模型的loss。
而的计算也比较简单,用Query set中的与Support set中的的embedding计算cosine距离的softmax来计算权重:
其中,和是各自的特征提取器,也就是encoder。只是作者对于Support-set和Query-set两端的encoder有各自的特殊设计,下了一番工作(详细可以看paper),基本上就是基于Bi-LSTM实现的序列编码器。
训练过程
迭代一次的流程如下:
- 选择少数几个类别(例如5类),在每个类别中选择少量样本(例如每类5个);
- 将选出的集合划分:支撑集(Support-set),测试集(Query-set);
- 利用本次迭代的参考集,计算测试集的误差;
- 计算梯度,更新参数; 这样的一个流程文中称为episode。
2. Relation Network
模型简介
这篇论文
模型介绍
此模型主要有两个部分: 1.embedding模块 也就是encoder编码器,即将Support-set和Query-set中的输入sample进行编码。本文具体采用的是4个卷积块的CNN网络。
2.relation模块
这里也是与Match Network的主要区别,它不在是直接使用编码器得到的embedding来计算属于标签的概率分布。而Relation Network是将Support-set中sample得到的embedding和Query-set中sample得到的embedding先进行concat,然后输入给关系网络(实际上就是神经网络),来预测属于标签的概率分布。
可以用下面的公式表示:
Loss函数是一个典型的回归问题(正例向1回归,负例向0回归):
简单来说,关系网络的创新点就是提出用神经网络,而不是欧氏距离去计算两个特征变量之间的匹配程度。
**
3. Induction Network
阿里小蜜团队提出的模型,主要分成三个模块。1.Encoder 2.Induction 3.Relation
首先给出Few-shot训练的基本流程,以及Support Set和 Query Dataset构建的方式(注意support-set和query-set都是从C类中构建出来的):
下图是本文的模型结构:
1.Encoder Module
他们使用一个双向LSTM+attention来对句子进行encode。最终得到的句子representation为。当然,你也可以使用transformer或者BERT来对句子进行编码。
2.Induction Module
这个模块的作用,是想设计一个非线性的映射,把Support Set中类别的全部K个句向量映射到一个分类向量上。
他们采用的是Hinton老爷子的胶囊网络的思想,主要算法是:
基本上就是套用了Capsule Network中的动态路由策略。这样得到的就可以看作是对于类别的一种类向量表示。
3.Relation Module
对于Query Set中的每个query,使用相同的Encoder Module得到对应的句向量,然后与Induction Module中得到的类向量表示进行比较。这是一个典型的比较两个向量关系的问题,采用一个神经网络层+sigmoid来处理:
然后在使用一个sigmod计算第个query与类别之间的relation score。
4.Objective Function
为了衡量relation score 与ground truth (类别相同是1,类别不相同是0),这是一个回归问题,使用MSE来计算loss。这其实可以看作一种pair-wise的训练方式,就是比较类别向量与Query set中对应的sample是否相似。
在一个episode中给定Support Set ,有个类别,和Query Set 每类有n个sample,损失函数如下:
实现细节
对于一个5-way 5-shot的设定,是指每个episode,训练集中包含5(C)个类别,Support Set中每个类别有5(K)个sample。而对于Query Set,每个类别包含20个sample。就是每个episode共需要:205 + 55 = 125个句子。
ctrip的实现细节:
训练: Support Set中随机选c(c=15)个类,每个类选k(k=20)个样本。 Query set是一个batch(batchsize=128),保证batch里被选中的c个类都至少存在1个正例,然后其余batchsize-c个sample随机从所有的未使用过的数据里均匀采样。Query Set的训练<query,label>格式,其中,label是one-hot标签,query的目标标签在c个类中,对应的one-hot位置为1,不在全为0。
encoder:两层的Transformer,8头,word_embedding与char_embedding 进行concat拼接起来。
类别数量:230
predit:
- 预保存模型的label embedding
- 输入query直接绝所有类的lable embedding比较,最后选择argmax,然后使用阈值来控制。