探索TextAttack框架:组件、功能与实践应用
引言
过去几年,测试自然语言处理模型对抗鲁棒性的兴趣日益增长。该领域的研究涵盖了生成对抗样本和防御对抗样本的新技术。由于这些攻击是在不同的数据和受害者模型上进行评估的,因此直接比较它们具有挑战性。
由于源代码缺失,复制早期工作作为基线需要时间且增加了出错风险。由于出版物中忽略的微小细节,完美复现结果也很困难。这些问题给该领域的基准比较带来了挑战。
为了解决这些挑战,开发了像TextAttack这样的框架。它是一个用于对抗性攻击、数据增强和对抗训练的NLP Python框架。该框架不仅解决了当前的挑战,还推动了对抗鲁棒性的进步。本文通过探索TextAttack,深入剖析其组件的细节,并通过深入的代码示例来研究其实际实现。
前提条件
- 对NLP的基本理解:熟悉自然语言处理概念,如分词、嵌入和序列模型。
- Python编程技能:具有使用Python的经验,包括安装库、编写脚本和处理数据。
- PyTorch知识:对PyTorch有基础了解,因为TextAttack通常与基于PyTorch的模型集成。
- TextAttack安装:使用
pip install textattack在环境中安装TextAttack框架。确保已安装NumPy、pandas和PyTorch等依赖项。 - 数据集熟悉度:了解如何处理数据集,例如Hugging Face Datasets库中可用的数据集。
TextAttack的关键组件
TextAttack通过将NLP攻击分解为四个组件来统一对抗性攻击方法:
- 目标函数:攻击的目标由目标函数定义,可以包括改变模型的预测或欺骗模型产生特定错误。
- 约束集:这是攻击必须遵守的规则,例如限制修改的单词数量或确保所有修改在语法上正确。
- 转换:转换决定了如何修改输入文本来生成对抗样本。例如,用同义词替换单词或插入/删除单词。
- 搜索方法:该组件决定了攻击如何探索可能的修改空间以找到最有效的对抗样本,例如使用基于梯度的方法或随机搜索算法。
TextAttack的主要功能
如上所述,TextAttack模块化设计的核心是四个主要组件。
在图中,“攻击模块”部分展示了TextAttack重新实现文献中攻击的能力。此部分提到有16篇论文和预构建的攻击方法可供使用。
该图展示了两种使用TextAttack构建攻击的方法:创建新攻击和对现有攻击进行基准测试。用户可以通过组合新的和现有的组件来创建新的对抗策略。
TextAttack拥有超过82个预训练模型,因此研究人员可以针对标准模型测试他们的新攻击。这使他们能够将自己的结果与之前的工作进行比较。
该图展示了两种攻击方法:数据增强和对抗训练。通过使用增强器模块,用户模型可以通过生成新样本来扩展现有的训练数据集,从而提高性能。借助TextAttack的训练流程,可以创建对抗样本并将其馈送到训练过程中,以改进模型。
TextAttack的预训练模型和数据集
TextAttack的预训练模型包括词级LSTM和CNN模块以及基于Transformer的BERT变体。这些模型已经在由HuggingFace提供的多样化数据集上进行了预训练。
TextAttack与NLP库的集成支持为测试和验证数据集自动加载相应的预训练模型。虽然许多先前文献都集中在分类和蕴含任务上,但TextAttack的预训练模型系列为研究提供了新途径。这使得研究人员能够深入研究所有GLUE任务中模型鲁棒性的相关问题。
对抗训练
TextAttack支持创建对抗样本的新训练集。该过程涉及几个步骤:
- 初始训练:在干净的训练集上对模型进行一定轮次的训练。
- 对抗样本生成:攻击为每个输入创建对抗版本。
- 数据集替换:此过程涉及用扰动变体替换原始数据集。
- 定期重新生成:根据模型当前的弱点,定期重新生成对抗数据集。
下表说明了标准LSTM分类器在使用和未使用对抗训练的情况下,针对TextAttack中使用的各种攻击方法的准确率。
它显示了LSTM模型在对抗deepwordbug、textfooler、pruthi、hotflip和bae攻击时的性能,与干净训练集上的基线Carlini分数进行了比较。该表比较了有攻击和无攻击时的准确率,评估了模型分别在20轮和75轮时对deepwordbug的鲁棒性,并评估了模型在20轮时对textfooler的脆弱性。
使用攻击方法对现有攻击进行基准测试
TextAttack允许采用模块化结构,将多个过去的研究攻击组合到一个框架中。这是通过添加一两个新组件来实现的,我们在创建新的攻击计划时获得了更大的灵活性和生产力。
攻击方法是一组预定义的步骤和配置,用于为NLP模型生成对抗样本。每个攻击方法有四个主要组件(目标函数、约束、转换、搜索方法)。
攻击方法是从最新研究中提炼出来的最佳实践和技术。它们加快了对抗样本的生产。这些方法使我们能够实施复杂的攻击策略,而无需了解所有底层细节,从而帮助研究人员和从业人员开展工作。我们在下图中展示了一些常见的攻击类型:
实际用例:情感分析
假设我们有一个情感分析模型,可以将电影评论分类为正面或负面类别。我们想使用textfooler攻击方法来测试该模型的强度和鲁棒性。输入的句子是“The movie was great”,模型预测为正面情感。让我们看看下图中的过程。
要开始攻击过程,我们必须找到语句中与正面情感相关的重要单词。例如,可以使用WordNet数据库将“fantastic”替换为其同义词。
我们希望尽可能保持语法正确性和原始含义。我们在贪心算法中使用迭代方法,在每次替换时检查模型的预测。
该过程持续进行,直到无法再进行实现所需结果的替换,或者发生预测改变为止。通过将单词“fantastic”替换为“great”之类的同义词,导致情感从正面翻转为负面,可能会对模型成功实施攻击。
在这种情况下,初始语句“The movie was fantastic”已被更改为“The movie was great”,我们可以看到模型的预测已从正面变为负面。这证明了模型存在漏洞。
textfooler是一种旨在暴露情感分析模型缺陷的巧妙方案。通过用同义词替换关键词,同时保持上下文连贯性,这种技术可以显著影响模型预测。这迫使我们优先考虑那些能够最大程度减少对此类对抗攻击脆弱性的训练和评估框架。
使用TextAttack的AttackedText对象增强对抗性NLP攻击
在使用传统的NLP攻击实现时,对分词后的文本进行修改通常会导致大写和分词问题。分词涉及将文本分解为单独的单词或标记,这可能会破坏其初始大写字母和单词之间的边界。
对“The movie was fantastic!”进行分词可能会导致“the”、“movie”、“was”或“fantastic!”,这省略了开头的大写字母。这使得转换更加困难。此类问题会削弱对抗样本的一致性和有效性,同时阻碍对NLP模型弹性的准确评估。
TextAttack的AttackedText对象通过允许直接在原始文本而不是分词版本上执行转换来解决这些问题。该对象保留初始输入并保留其所有属性,例如大写和单词边界。
该对象附带的辅助方法有助于进行受控修改。它们可以处理诸如大写转换或特定单词边界调整等挑战。让我们以将“Hello”替换为“Hi”的转换为例。它不会产生像“hi World!”这样的错误输出,而是正确地将“Hello World!”转换为“Hi World!”。
使用AttackedText对象具有各种优势。它通过保持原始文本结构和大写在转换过程中的真实性来生成对抗样本。分词和大写的精确性增强了攻击模型的可靠性。这为模型鲁棒性评估过程带来了可信度。
此外,开发人员可以专注于产生有效的更改,而无需处理阻碍新攻击采用的分词障碍。本质上,AttackedText组件极大地放大和增强了TextAttack的对抗转换能力。
在使用TextAttack的搜索方法进行攻击时,经常会多次遇到相同的输入。在这种情况下,存储先前计算的结果可以大大提高效率。
通过缓存提高TextAttack的效率
在使用TextAttack的搜索方法进行攻击时,经常会多次遇到相同的输入。存储先前计算的结果可以大大提高此类情况下的效率。
TextAttack可以加快预计算模型输出的检索速度,并在不重复任务的情况下验证是否满足所有约束。这个过程被称为记忆化。通过这种优化技术,搜索方法执行得更快,并提高了攻击期间的整体效率。
让我们考虑一个TextAttack的应用场景,我们通过使用对抗样本来测试情感分析模型的鲁棒性。让我们通过下图来可视化和解释这个过程。
- 初始输入:开头语句是“The movie was fantastic.”。情感分析器准确识别出将其表征为正面情感的情感基调。
- 对抗攻击:TextAttack使用对抗攻击技术,对原始句子应用巧妙的修改以操纵模型的预测。此过程可能包含带有修改措辞的示例——例如“The movie was fantastic!”或“The movie was fantastic”。
- 缓存:利用缓存的力量,实时识别和分类伪装成重复项的变体以提高效率,防止不必要的重新处理。例如,如果“The movie was fantastic.”被重复生成——无论是逐字还是稍作修改——系统将使用首次遇到时预先计算的缓存响应。在TextAttack复杂的算法模型对特定语言片段采取任何进一步操作之前,始终会交叉检查约束以确保连贯性。
- 效率增益:TextAttack的效率提升令人印象深刻。想象一下,在我们继续搜索旅程时,遇到了同样令人愉快的短语“The movie was fantastic”。我们的工具不再需要计算其输出或监控任何约束,因为它检索的是缓存结果。这节省了我们宝贵的时间,节省了计算能力,并带来了更顺畅的成功结果之旅。
自定义转换
我们将尝试一个简单的转换来启动在TextAttack中创建转换:将任何单词更改为单词“banana”。在TextAttack中,一个名为WordSwap的抽象类负责将句子分解为单词,同时避免交换停用词。通过扩展WordSwap并仅执行一个函数_get_replacement_words,所有术语都可以替换为“banana”。
在执行后续代码之前,请在环境中运行以下命令:
pip3 install textattack[tensorflow]
以下代码定义了一个自定义转换类BananaWordSwap,它继承自WordSwap。它将输入中的任何给定单词替换为单词“banana”。
from textattack.transformations import WordSwap
class BananaWordSwap(WordSwap):
"""Transforms an input by replacing any word with 'banana'."""
# We don't need a constructor, since our class doesn't require any parameters.
def _get_replacement_words(self, word):
"""Returns 'banana', no matter what 'word' was originally.
Returns a list with one item, since `_get_replacement_words` is intended to
return a list of candidate replacement words.
"""
return ["banana"]
使用转换
已经选择了转换。但是,要完成攻击,仍然缺少一些项目。还必须选择搜索方法和约束来实现攻击。此外,在使用此策略之前,我们需要一个目标函数、模型和数据集。(目标函数指示我们的模型执行的任务——在本例中是分类——以及攻击的类型——在本例中,我们将执行无目标攻击。)
创建目标函数、模型和数据集
我们的任务是发起对分类模型的攻击。因此,我们将使用UntargetedClassification类。对于我们的任务,让我们选择专门为使用AG新闻数据集进行新闻分类而优化的BERT。无需担心,因为许多模型已经准备就绪并方便地存储在HuggingFace的Model Hub中。TextAttack可以协同融合这些高质量模型及其数据集。
# 导入模型
import transformers
from textattack.models.wrappers import HuggingFaceModelWrapper
model = transformers.AutoModelForSequenceClassification.from_pretrained(
"textattack/bert-base-uncased-ag-news"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
"textattack/bert-base-uncased-ag-news"
)
model_wrapper = HuggingFaceModelWrapper(model, tokenizer)
# 使用模型创建目标函数
from textattack.goal_functions import UntargetedClassification
goal_function = UntargetedClassification(model_wrapper)
# 导入数据集
from textattack.datasets import HuggingFaceDataset
dataset = HuggingFaceDataset("ag_news", None, "test")
创建攻击
让我们使用贪心搜索方法。我们暂时不使用任何约束。
from textattack.search_methods import GreedySearch
from textattack.constraints.pre_transformation import (
RepeatModification,
StopwordModification,
)
from textattack import Attack
# 我们将使用我们的Banana word swap类作为攻击转换。
transformation = BananaWordSwap()
# 我们将约束已修改索引和停用词的修改
constraints = [RepeatModification(), StopwordModification()]
# 我们将使用贪心搜索方法
search_method = GreedySearch()
# 现在,让我们从4个组件构建攻击:
attack = Attack(goal_function, constraints, transformation, search_method)
我们可以打印我们的攻击以查看所有参数:
print(attack)
输出:
Attack(
(search_method): GreedySearch
(goal_function): UntargetedClassification
(transformation): BananaWordSwap
(constraints):
(0): RepeatModification
(1): StopwordModification
(is_black_box): True
)
使用攻击
让我们使用我们的攻击来成功攻击10个样本。
from tqdm import tqdm # tqdm为我们提供了一个不错的进度条。
from textattack.loggers import CSVLogger # 为我们跟踪一个数据框。
from textattack.attack_results import SuccessfulAttackResult
from textattack import Attacker
from textattack import AttackArgs
from textattack.datasets import Dataset
attack_args = AttackArgs(num_examples=10)
attacker = Attacker(attack, dataset, attack_args)
attack_results = attacker.attack_dataset()
# 以下传统教程代码详细展示了Attack API的工作方式。
# logger = CSVLogger(color_method='html')
# num_successes = 0
# i = 0
# while num_successes < 10:
# result = next(results_iterable)
# example, ground_truth_output = dataset[i]
# i += 1
# result = attack.attack(example, ground_truth_output)
# if isinstance(result, SuccessfulAttackResult):
# logger.log_attack_result(result)
# num_successes += 1
# print(f'{num_successes} of 10 successes complete.')
输出:
+-------------------------------+--------+
| Attack Results | |
+-------------------------------+--------+
| Number of successful attacks: | 8 |
| Number of failed attacks: | 2 |
| Number of skipped attacks: | 0 |
| Original accuracy: | 100.0% |
| Accuracy under attack: | 20.0% |
| Attack success rate: | 80.0% |
| Average perturbed word %: | 18.71% |
| Average num. words per input: | 63.0 |
| Avg num queries: | 934.0 |
+-------------------------------+--------+
在上面的代码中,利用TextAttack库对数据集执行对抗攻击。导入了必要的模块,包括用于进度条的tqdm和用于记录攻击结果的CSVLogger。为了对10个示例运行攻击,我们使用了AttackArgs类。
我们使用指定的攻击类型、相应的数据集和任何必要的参数创建了Attacker对象。接下来,我们执行了attack_dataset方法来执行攻击并获取结果。或者,我们可以通过遍历数据集中的每个元素、执行攻击并在攻击成功时记录结果来手动记录成功的攻击——正如这段脚本代码注释部分所包含的代码所证明的那样。
在该文本分类模型的对抗攻击中实现了80%的成功率。在十次尝试中,有八次成功地导致输入被错误分类,导致准确率从之前的满分下降到仅20%。
攻击策略证明了其鲁棒性和有效性,因为没有记录到被跳过的攻击。平均而言,每个包含大约63个单词的输入中,有18.71%的单词被修改。此外,在攻击过程中,每个输入的模型大约被查询了934次。
可视化攻击结果
在下面的代码中,我们使用了CSVLogger方法来记录AttackResult对象。此记录器有效地将所有结果攻击存储到数据框中。它便于信息的轻松访问和显示。通过将color_method设置为'html',攻击结果的差异通过HTML着色来表示,以增强视觉清晰度。我们在此过程中使用了IPython实用程序和pandas。
import pandas as pd
pd.options.display.max_colwidth = (
480 # 增加列宽以便我们能够实际读取示例
)
logger = CSVLogger(color_method="html")
for result in attack_results:
if isinstance(result, SuccessfulAttackResult):
logger.log_attack_result(result)
from IPython.core.display import display, HTML
results = pd.DataFrame.from_records(logger.row_list)
display(HTML(results[["original_text", "perturbed_text"]].to_html(escape=False)))
注意:读者可以运行代码并可视化结果。
结论
本文对NLP模型中的对抗鲁棒性进行了广泛分析,特别强调了TextAttack Python框架。这是一个基于Python的框架,旨在简化和增强对抗攻击的生成过程和防御机制。
该框架呈现了一个模块化架构,包含目标函数、约束、转换和搜索方法等基本组件。这种结构使用户能够有效地创建自定义攻击。它还通过基准测试便于轻松适应和标准化性能评估。
该框架扩展了对具有数据增强技术的对抗训练实践的支持。这增强了现实世界应用场景中的模型鲁棒性。该框架的效率通过经验示例得到了证明。这些示例突出了其在各个领域的实际效用。
参考资料
- TextAttack论文
- TextAttack文档
- 代码参考
感谢您通过某中心的社区学习。请查看我们在计算、存储、网络和托管数据库方面的产品。 了解更多关于我们的产品