label embedding做文本分类

1,823 阅读3分钟

1. Joint Embedding of Words and Labels for Text Classification(ACL 2018)

图1

1.传统模式:

这片文章首先给出了传统做Text Classification的模式(图1a): 对于一个序列X=x1,...,xLX={x_1,...,x_L}, 整个过程需要计算f=f0f1f2f = f_0*f_1*f_2。即:

  1. f0f_0,将token变为word embedding(维度为P),构造句子表示V(PL)V(P*L),。
  2. f1f_1,aggregation层,即输入sample(token拼接的word-embedding矩阵)变成向量表示zz。可以使用各种模型,比如:TextCNN,LSTM+Attention,BERT等。
  3. f2f_2,分类层,比如用交叉熵,softmax等。

这种模型的问题是,它只在f2f_2阶段才会利用label信息。 所以本文提出的方案是出发点有两个:

  1. 想在前面两个阶段也引入label信息来指导生成zz
  2. 文本分类往往不是利用整句话的向量来预测,往往使用某个关键词来进行。所以,在进行aggregation时,需要考虑每个token与label之间的关系,这也是attention的思想,只是他与label-embedding之间进行

2.使用label-embedding的模式:

  1. f0f_0阶段,生成label-embedding。如果分类是token,可以直接使用token-embedding。如果是标签,可以用高斯分布初始化,然后跟随模型一起训练。生成的label-embedding为C={c1,..,ck}C=\{c_1, .., c_k\},一共有k个类别。
  2. 在计算聚合向量时,引入label信息。
  • 2.1 首先计算输入句子矩阵与label-embedding矩阵之间的cosine相似度,用词来表示token与label之间的相似度: 截屏2020-08-07 下午7.32.31.png-9.7kB

  • 2.2 进一步获取连续词之间的相对时空信息,对于以l为中心长度为2r-1的文本序列做如下操作: 截屏2020-08-07 下午7.42.02.png-12.5kB 这一步,计算得到ulRKu_l\in{R^{K}}

  • 2.3 利用max-polling得到最大相关的系数,第l个token的最大兼容性得分(compatibility value)来表示对应的label: 公式4 mlml是一个长度的LL的向量,整个文本序列的兼容性/注意力得分(label-based attention score)为: 截屏2020-08-07 下午7.49.17.png-10.7kB

  • 2.4 然后,这个序列的representation可以用注意力得分与词向量加权计算得到: 截屏2020-08-07 下午7.52.06.png-9kB

2. Explicit Interaction Model towards Text Classification (AAAI 2019)

1.简介

本文跟上文思路一样,引入交互机制,计token与label的相关信息,将其带到文本分类中。模型结构如下: image.png-83.9kB 主要不同在Interaction的计算方式,上文中,是计算token与label-embedding的cosine相似度,而本文直接在label矩阵与token矩阵之间计算点机运算

2.模型简述

  • 2.1 encoder层 作者使用两种:GRU和Region Embedding
  • 2.2 交互层(interaction)
  1. 定义可训练的标签矩阵TRckT\in{R^{c*k}} (c表示类别数,k表示维度),文本的字级表示为HRnkH\in{R^{n*k}}(n表示文本长度,k表示维度)。
  2. 目标词t和类别s之间的匹配分数,计算如下: Its=Ts,:Ht,:TI_{ts}=T_{s,:}H_{t,:}^T
  3. 文本序列整体计算结果是,得到的结果IRcnI\in{R^{c*n}}I=THTI = TH^T 作者使用这种简单直接的方式,相比较计算元素对位乘法(elementt-wise multiply)和余弦相似度(cosine similarity),主要是考虑到计算效率。
  • 2.3 聚合层(aggregation) 该层的作用是将每个类的交互矩阵聚合到一个logits上,表示类于输入文本之间的匹配分数。作者简单使用两个FC层的MLP,中间加一个Relu激活函数实现,最后使用softmax或者sigmoid得到logits。

  • 2.4 分类层 使用交叉熵计算loss。