文本分类是自然语言处理 (NLP) 中的经典监督任务。文本分类的热门用例包括情感分析、文档分类、意图检测、假新闻检测、垃圾邮件检测和主题标签等等。深度学习的最新进展,尤其是预训练语言模型 (PLM) 以及来自 HuggingFace (HF) 的转换器等库的流行极大地促进了这项任务。人们可以轻松地拿起 PTM 并在新数据集上对其进行微调,微调可以在几分钟内完成(取决于所选数据集和硬件的容量)。
然而,数据并不总是可用或完整的。实际上,对于行业中的许多文本分类任务,通常不存在带标签的数据集。在其他一些情况下,用于微调的数据集的体积非常小且不完整。要获得一个体面大小的数据集(例如,数据集imdb中的 50k 标记示例),手动标记数据集需要花费大量时间和精力,并且在某些情况下对于资源有限的公司来说是不可能的。
为了应对这一挑战,研究人员已经找到了一种可以利用带有小型标记数据集的 PLM 的方法,这称为少样本学习。在文献中,有许多提出的方法,例如ADAPET和T-FEW。本文讨论了SetFit,这是一种经过验证的文本分类性能的小样本学习方法。
快速复审
SetFit 利用 Sentence Transformer(Sentence BERT 或 SBERT)。在其原始论文中,该方法有两个步骤:
- 微调 SBERT 模型
- 训练分类器头
SBERT 是基于孪生网络或三重网络的模型。本质上,它力求获得文本嵌入,使得两个相似文本序列(例如,具有相同标签)的嵌入具有较小的余弦距离,而两个不同文本序列的嵌入具有较大的余弦距离。
在 SetFit 中,为了在具有C类的有限大小数据集上微调 SBERT 模型,它按如下方式对三元组进行采样:
- 样本锚文本ᵃ
- 在s ᵃ的同一类中采样正文本s ᵖ****
- 用不同的类别对负文本s ⁿ 进行采样
- 重复这些步骤 R 次(例如,R = 20)
这样,即使原始数据集很小,结果数据集也会大得多。三元组网络以三元组 ( sᵃ , sᵖ , sⁿ ) 作为输入并产生三个嵌入。然后通过优化称为“triplet loss”的特殊损失来学习网络参数。
图 1. SetFit 模型(来源:原论文)
微调网络后,生成文本嵌入。然后将这些嵌入与原始标签一起使用来训练分类模型。可以使用经典的分类头,如来自 sklearn 的SVM 。
我们可能会注意到几个有趣的观点:
- 该方法不是端到端的,因为使用了两个单独的学习步骤。然而在其发布的实现中,这些步骤被很好地集成在一起,非常易于使用。
- 与孪生网络类似,三元组网络中的组件共享权重,因此在生成嵌入时,可以使用它们中的任何一个。
执行
数据集:在 HF hub 上,有很多用于文本分类的数据集。我们将为二进制文本分类选择一个数据集,即cola(胶水)。该数据集有 8.55k 行用于训练,1.04k 行用于验证集。我们将使用一小部分训练集来微调一个名为paraphrase-mpnet-base-v2SetFit 的句子转换器模型。下面是代码,主要是根据原文。
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset`
`# Load a dataset from the Hugging Face Hub
dataset = load_dataset("glue", "cola")`
`N = 8 # Sampling N examples per class
train_dataset = sample_dataset(dataset[“train”], label_column=”label”, num_samples=N)
eval_dataset = dataset[“validation”]`
`# Load a SBERT model from HF hub
model = SetFitModel.from_pretrained(“sentence-transformers/paraphrase-mpnet-base-v2”)`
`# Create trainer
trainer = SetFitTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss_class=CosineSimilarityLoss,
metric="accuracy",
batch_size=16,
num_iterations=20,
num_epochs=1,
column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer)`
`# Train and evaluate
trainer.train()
metrics = trainer.evaluate()
print("evaluation result: ", metrics)`
`# Inference
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst"])
以下是不同 N 值的结果。
表 1. SetFit在具有不同数量训练示例的 Cola 上的准确性方面的表现。
在这个简单的实验中,还使用了 Roberta-base 和 SBERTparaphrase-mpnet-base-v2对整个训练集进行微调。可以看出,在整个数据集上的性能与非常小的子集相差不远。显然,当有大量完整的数据集可用时,没有理由使用小样本学习,因为通过标准训练/微调可以实现非常好的性能。但如前所述,在许多现实生活中情况并非如此。此外,对小子集的微调比整个数据集快得多。因此,在资源有限的情况下,SetFit 是一种非常有前途的方法。