文本分类还停留在BERT?对偶对比学习框架也太强了

1,702 阅读11分钟

论文简介:对偶对比学习:如何将对比学习用于有监督文本分类

论文标题:Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation

论文链接:arxiv.org/abs/2201.08…

代码链接:github.com/hiyouga/dua…

论文作者:{Qianben Chen}

知乎笔记:zhuanlan.zhihu.com/p/466685216

论文摘要

对比学习在无监督环境下通过自我监督进行表征学习方面取得了显著的成功。然而,有效地使对比学习适应于监督学习任务在实践中仍然是一个挑战。在这项工作中,作者提出了一个对偶对比学习(DualCL)框架,在同一空间内同时学习输入样本的特征和分类器的参数。具体来说,DualCL将分类器的参数视为关联到不同标签的增强样本,然后利用其进行输入样本和增强样本之间的对比学习。对5个基准文本分类数据集及对应低资源版本数据集的实验研究表明,DualCL分类精度明显得到提高,并证实了DualCL能够实现样本判别表示的效果

DualCL简介

表示法学习是当前深度学习的核心。在无监督学习的背景下,对比学习最近已被证明是获得下游任务的通用表征的有效方法。简单地说,无监督对比学习采用了一个损失函数,它迫使同一样本的不同“视角”的表示向量相似,而不同样本的表示向量则不同。最近对比学习的有效性方面被证明是由于同时实现了“对齐性”【alignment】和“一致性”【uniformity】。

对比学习方法也适用于监督表示学习 ,以往研究中也使用了类似的对比损失,基本原理是坚持同一类中样本的表示是相似的,不同类的样本表示是相似的。 清楚的。 然而,尽管取得了成功,但与无监督对比学习相比,这种方法的原则性似乎要差得多。 例如,表示的统一性不再有效; 也不是必需的,通俗来讲就是特征的在空间分布不再均匀,所以我们认为标准的监督对比学习方法对于监督表示学习来说并不自然。 另外存在一个事实就是,即这种对比学习方法的结果并没有直接给我们一个分类器,需要开发另一种分类算法来解决分类任务。

接下来我们讲讲DualCL提出的动机,作者为了旨在开发一种更自然的方法来实现在监督任务下的对比学习。作者关键动机是监督表示学习应该包括学习两种参数:一个是输入xx在适当的空间特征zz,用来满足分类任务需求,另一个是分类器的参数,或者作用于分类器空间的参数θ\theta;我们将这个分类器称为xx的“one example”分类器。在这种观点下,自然将样本xx联想到两种参数:一个为维度为ddzRdz\in \mathbb{R}^d,用来表示特征;一个是分类器参数θRd×K\theta \in \mathbb{R}^{d \times K},其中KK代表样本中分类总数。那么有监督的表示学习可以认为是为输入样本xx生成(z,θ)(z,\theta)

为了保证分类器θ\theta对于特征zz有效,只需要保证θTz\theta^Tz与样本xx的标签保持对齐,可以通过softmax归一化的概率与交叉熵进行约束。除此之外,在对比学习方法可以用来强制对这些(zθ)(z,θ)表示进行约束,具体来讲,我们将θ\theta^*记为样本xx的真实标签对应分类器θ\theta的理想参数,这里我们可以设计两种对比损失。第一个loss用来对比(z,θ)(z,\theta^*)与多个(z,θ)(z',\theta^*),其中zz'代表与样本xx不同类别的样本特征;第二个loss用来对比(z,θ)(z,\theta^*)与多个(z,θ)(z,\theta'^*),其中θ\theta'代表样本的不同类别对应分类器参数,作者将这种学习框架叫做dual contrastive learning(DualCL),对偶对比学习

在对比学习基础之上,正如论文标题,DualCL可以认为是一种独特的数据增强方法。具体来说,对于每个样本xx,其θθ的每一列都可以被视为“标签启发的输入表示”,或者是在特征空间中注入了标签信息的xx的增强视图。表1中说明了这种方法的强大之处,从左边的两个图片可以看出,标准的对比学习不能利用标签信息。相反,从右边的两个图来看,DualCL有效地利用标签信息对其类中的输入样本进行分类。

在论文实验中,作者在5个基准文本分类数据集上验证了DualCL的有效性。通过使用对偶比损失对预先训练好的语言模型(BERT和RoBERTa)进行微调,DualCL与现有的对比学习监督基线相比获得了最好的性能。作者还发现,DualCL提高了分类精度,特别是在低资源的场景下。此外通过可视化所学习的表征和注意力图,对DualCL给出了一些可解释性分析。

论文贡献可以总结如下:

  • 1)提出了双对比学习(DualCL),自然地将对比损失与有监督任务相结合;
  • 2)引入标签感知数据增强来获得输入样本的多个视图,用于DualCL的训练;
  • 3)在5个基准文本分类数据集上实证验证了DualCL框架的有效性;

DualCL原理

“对偶”表示有监督的对比学习方法目的就是:第一个是在适当空间中对分类任务的输入进行判别表示,第二个是构建监督任务的分类器,学习到分类器空间中分类器的参数。接下来我们看看DualCL的核心部分。

标签启发式的数据增强

为了获得训练样本的不同视图(views)表示,作者利用数据增强的思想来获得特征ziz_{i}和分类器θi\theta_{i}的表示。具体来讲就是将分类器每个类别的对应参数θi\theta_{i}作为ziz_{i}的独特表示,记为θik\theta_{i}^{k},称为标签感知输入表示,将标签kk信息注入到的xix_{i},作为额外增强视图。

在实践中,将标签集合{1,...,K}\{1,...,K\}插入到输入序列xix_{i},可以得到一个新的输入序列riRL+Kr_{i}\in \mathbb{R}^{L+K},然后通过PLMS(Bert或者Roberta)模型作为编码器ff,来获取输入序列的每个token特征,其中[CLS]特征作为样本xix_{i}的特征ziz_{i},插入的标签对应为标签启发的输入表示θik\theta_{i}^{k}。标签的名称作为标记,形成序列rir_{i},如“positive”、“negative”等。对于包含多个单词的标签,我们采用token特征的平均池化来获得具有标签感知的输入表示。这个操作和之前一篇论文很相似,大家可以有兴趣可以阅读:Bert还可以这么用:融合标签向量到BERT

对偶对比损失

利用输入样本xix_{i}的特征表示ziz_{i}和分类器θiθ_{i},DualCL作用就是将θiTziθ_{i}^{T}z_{i}的softmax归一化概率与xix_{i}的标签对齐。将θiθ^∗_{i}表示θiθ_{i}的某一列,对应于xix_{i}的真实标签索引,DualCL期望θiTziθ_{i}^{*T}z_{i}点积是最大化的。为了学习更好的ziz_{i}θiθ_{i},DualCL利用不同训练样本之间的关系定义了对偶对比损失,如果xjx_{j}xix_{i}有相同的标签,那么试图最大化θiTzjθ_{i}^{*T}z_{j},而如果xjx_{j}xix_{i}有不同的标签,则最小化θiTzjθ_{i}^{*T}z_{j}

给定一个来自输入样本xix_{i}的锚点ziz_{i}{θj}jPi\{ \theta^*_j \}_j\in P_{i}是正样本集合,{θj}jAi Pi\{ \theta^*_j \}_j\in A_i \ | P_i是负样本集合,关于z的对比损失可以定义如下:

同理,给定一个来自输入样本xix_{i}的锚点θi\theta_{i}^*{zj}jPi\{ z_j \}_j\in P_{i}是正样本集合,{zj}jAi Pi\{ z_j \}_j\in A_i \ | P_i是负样本集合,关于z的对比损失可以定义如下:

对偶比损失是上述两个对比损失项的组合: LDual=Lz+Lθ \mathcal{L_{Dual}}= \mathcal{L_{z}}+ \mathcal{L_{ \theta }}

对比训练和有监督预测

为了充分利用监督信号,DualCL还期望θiθ_{i}是一个很好的ziz_{i}分类器。因此作者使用一个改进版本的交叉熵损失来最大化每个输入样本xix_iθiTziθ_{i}^{*T}z_{i}

最后,最小化这两个训练目标来训练编码器ff。这两个目标同时提高了特征的表示质量和分类器的表示质量。总体损失应为: Loverall=LCE+λLDual \mathcal{L_{overall}}= \mathcal{L_{CE}}+\lambda \mathcal{L_{ Dual }} 其中,λλ是一个控制双对比损失项权重的超参数。

在分类过程中,我们使用训练好的编码器ff来生成输入句子xix_i的特征表示ziz_i和分类器θiθ_i。这里的θiθ_i可以看作是一个“one-example”的分类器,例如xix_i,我们将θiTziθ_{i}^{T}z_{i}的argmax结果作为模型预测: y^i=argmax(θikzi)\widehat{y}_i=argmax(\theta_i^k ·z_{i})

图1说明了对偶对比学习的框架,其中eCLSe_{CLS}是特征表示,ePOSe_{POS}eNEGe_{NEG}是分类器表示。在这个具体的例子中,我们假设具有“positive”类的目标样本作为锚点,并且有一个正样本具有相同的类标签,而有一个具有不同的类标签的负样本。对偶对比损失旨在同时将特征表示吸引到正样本之间的分类器表示上,并将特征表示排斥到负样本之间的分类器上。

表示之间的对偶性

对比损失采用点积函数作为表示之间相似性的度量,这就使得DualCL中的特征表示zz和分类器表示θθ之间存在双重关系。在线性分类器中,输入特征与参数之间的关系也出现了类似的现象。然后我们可以将θθ看作是一个线性分类器的参数,这样预先训练好的编码器ff就可以为每个输入样本生成一个线性分类器。因此,DualCL很自然地学习如何为每个输入样本生成一个线性分类器来执行分类任务。

实验设置

数据集

论文采用了SST-2、SUBJ、TREC、PC和CR四种数据集,数据集相关统计如下:

实验结果

从结果中可以看出,除了使用RoBERTa的TREC数据集外,同时使用BERT和RoBERTa编码器在几乎所有设置中都取得了最好的分类性能。与具有完整训练数据的CE+CL相比,DualCL对BERT和RoBERTa的平均改善率分别为0.46%和0.39%。此外,我们观察到,在10%的训练数据下,DualCL的性能明显大于CE+CL方法,在BERT和RoBERTa上分别高出0.74%和0.51%。同时,CE 和 CE+SCL 的性能无法超越 DualCL。 这是因为CE方法忽略了样本之间的关系,CE+SCL方法不能直接学习分类任务的分类器。

此外论文发现双重对比损失项有助于模型在所有五个数据集上实现更好的性能。 它表明利用样本之间的关系有助于模型在对比学习中学习更好的表示。

案例分析

为了验证DualCL是否能够捕获信息特征,作者还计算了[CLS]标记的特征与句子中每个单词之间的注意得分。首先在整个训练集上微调RoBERTa编码器。然后我们计算特征之间的l2l_2距离,并可视化图4中的注意图。结果表明,在对情绪进行分类时,所捕获的特征是不同的。上面的例子来自SST-2数据集,我们可以看到我们的模型更关注表达“积极”情绪的句子“predictably heart warming”。下面的例子来自CR数据集,我们可以看到我们的模型对表达“消极”情绪的句子更关注“small”。相反,CE方法没有集中于这些鉴别特征。结果表明DualCL能够成功地处理句子中的信息性关键词。

论文总结

  • 在本研究中,从文本分类任务的角度,提出了一种对偶对比学习方法DualCL,来解决监督学习任务。
  • 在DualCL中,作者使用PLMs同时学习两种表示形式。一个是输入示例的鉴别特征,另一个是该示例的分类器。我们引入了具有标签感知的数据增强功能来生成输入样本的不同视图,其中包含特征和分类器。然后设计了一个对偶对比损失,使分类器对输入特征有效。
  • 对偶对比损失利用训练样本之间的监督信号来学习更好的表示,通过大量的实验验证了对偶对比学习的有效性。

核心代码

关于Dual-Contrastive-Learning实现,大家可以查看开源代码:

github.com/hiyouga/Dua…

 def _contrast_loss(self, cls_feature, label_feature, labels):
        normed_cls_feature = F.normalize(cls_feature, dim=-1)
        normed_label_feature = F.normalize(label_feature, dim=-1)
        list_con_loss = []
        BS, LABEL_CLASS, HS = normed_label_feature.shape
        normed_positive_label_feature = torch.gather(normed_label_feature, dim=1,
                                                     index=labels.reshape(-1, 1, 1).expand(-1, 1, HS)).squeeze(1)  # (bs, 768)
        if "1" in self.opt.contrast_mode:
            loss1 = self._calculate_contrast_loss(normed_positive_label_feature, normed_cls_feature, labels)
            list_con_loss.append(loss1)
        if "2" in self.opt.contrast_mode:
            loss2 = self._calculate_contrast_loss(normed_cls_feature, normed_positive_label_feature, labels)
            list_con_loss.append(loss2)
        if "3" in self.opt.contrast_mode:
            loss3 = self._calculate_contrast_loss(normed_positive_label_feature, normed_positive_label_feature, labels)
            list_con_loss.append(loss3)
        if "4" in self.opt.contrast_mode:
            loss4 = self._calculate_contrast_loss(normed_cls_feature, normed_cls_feature, labels)
            list_con_loss.append(loss4)
        return list_con_loss

    def _calculate_contrast_loss(self, anchor, target, labels, mu=1.0):
        BS = len(labels)
        with torch.no_grad():
            labels = labels.reshape(-1, 1)
            mask = torch.eq(labels, labels.T)  # (bs, bs)
            # compute temperature using mask
            temperature_matrix = torch.where(mask == True, mu * torch.ones_like(mask),
                                             1 / self.opt.temperature * torch.ones_like(mask)).to(self.opt.device)
            # # mask-out self-contrast cases, 即自身对自身不考虑在内
            # logits_mask = torch.scatter(
            #     torch.ones_like(mask),
            #     1,
            #     torch.arange(BS).view(-1, 1).to(self.opt.device),
            #     0
            # )
            # mask = mask * logits_mask
        # compute logits
        anchor_dot_target = torch.multiply(torch.matmul(anchor, target.T), temperature_matrix)  # (bs, bs)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_target, dim=1, keepdim=True)
        logits = anchor_dot_target - logits_max.detach()  # (bs, bs)
        # compute log_prob
        exp_logits = torch.exp(logits)  # (bs, bs)
        exp_logits = exp_logits - torch.diag_embed(torch.diag(exp_logits))  # 减去对角线元素,对自身不可以
        log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)  # (bs, bs)
        # in case that mask.sum(1) has no zero
        mask_sum = mask.sum(dim=1)
        mask_sum = torch.where(mask_sum == 0, torch.ones_like(mask_sum), mask_sum)
        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(dim=1) / mask_sum.detach()
        loss = - mean_log_prob_pos.mean()
        return loss

参考资料

ICML 2020: 从Alignment 和 Uniformity的角度理解对比表征学习 blog.csdn.net/c2a2o2/arti…