持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第23天,点击查看活动详情
方法
我们使用最大均值差异(MMD)和 Kolmogorov-Smirnov (K-S) 检测器检测文本数据的漂移。
在这个示例中,我们将专注于检测协变量漂移, 因为检测预测的标签分布漂移与其他方式没有区别(在 CIFAR-10 上检查 K-S 和 MMD 漂移)。
然而,当我们想要获取输入数据漂移时,它变得更加复杂。
当我们处理表格或图像数据时,我们可以直接在输入上应用两个样本假设检验,或者在预处理步骤后进行测试。
例如:使用Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift(他们称之为Untrained AutoEncoder 或UAE)中建议的随机初始化编码器。
但是在处理文本时,无论是字符串还是tokenized格式,都不是那么简单,因为它们不直接表示输入的语义。
因此,我们提取文本的embeddings并检测它们的漂移。此过程对我们检测到的漂移类型有重大影响。严格来说,我们不再检测,因为(预)训练embeddings的整个训练过程(目标函数、训练数据等)对我们提取的embeddings有影响。
该库包含利用 HuggingFace transformer 包中预训练embeddings的功能,但也允许您轻松使用自己选择的embeddings。本文中的示例说明了这两个选项。
注意:
正如本文中所做的那样,建议将文本数据作为字符串列表 (
List[str]) 传递给检测器。 这允许与 HuggingFace 的 transformers 库无缝集成。上述情况的一个例外是使用自定义 embeddings 时。 在这里,确保数据以兼容的格式传递给自定义 embeddings 模型非常重要。 在最后一个示例中,定义了
preprocess_batch_fn以将list转换为自定义TensorFlow embedding所期望的np.ndarray。
后端
该方法适用于 PyTorch 和 TensorFlow 框架,用于统计测试和预处理步骤。 但是 Alibi Detect 不会为您安装 PyTorch。 如何执行此操作请查看 PyTorch 文档。
数据集
我们使用包含 25000 个用于训练的和 25000 个用于测试的电影评论情感分类数据集(二分类)。 安装 nlp 库以获取数据集:
!pip install nlp
import nlp
import numpy as np
import os
import tensorflow as tf
from transformers import AutoTokenizer
from alibi_detect.cd import KSDrift, MMDDrift
from alibi_detect.saving import save_detector, load_detector
加载 Tokenizer
model_name = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
加载数据
def load_dataset(dataset: str, split: str = 'test'):
# 包含训练、测试、无监督数据集
data = nlp.load_dataset(dataset)
X, y = [], []
for x in data[split]:
X.append(x['text'])
y.append(x['label'])
X = np.array(X)
y = np.array(y)
return X, y
# 训练集
X, y = load_dataset('imdb', split='train')
print(X.shape, y.shape)
运行结果:
(25000,) (25000,)
让我们分别看一下负面和正面的评论:
# 0 表示负面 , 1表示正面
labels = ['Negative', 'Positive']
print(labels[y[-1]])
print(X[-1])
运行结果:
Negative
This is one of the dumbest films, I've ever seen. It rips off nearly ever type of thriller and manages to make a mess of them all.<br /><br />There's not a single good line or character in the whole mess. If there was a plot, it was an afterthought and as far as acting goes, there's nothing good to say so Ill say nothing. I honestly cant understand how this type of nonsense gets produced and actually released, does somebody somewhere not at some stage think, 'Oh my god this really is a load of shite' and call it a day. Its crap like this that has people downloading illegally, the trailer looks like a completely different film, at least if you have download it, you haven't wasted your time or money Don't waste your time, this is painful.
print(labels[y[2]])
print(X[2])
运行结果:
Positive
Brilliant over-acting by Lesley Ann Warren. Best dramatic hobo lady I have ever seen, and love scenes in clothes warehouse are second to none. The corn on face is a classic, as good as anything in Blazing Saddles. The take on lawyers is also superb. After being accused of being a turncoat, selling out his boss, and being dishonest the lawyer of Pepto Bolt shrugs indifferently "I'm a lawyer" he says. Three funny words. Jeffrey Tambor, a favorite from the later Larry Sanders show, is fantastic here too as a mad millionaire who wants to crush the ghetto. His character is more malevolent than usual. The hospital scene, and the scene where the homeless invade a demolition site, are all-time classics. Look for the legs scene and the two big diggers fighting (one bleeds). This movie gets better each time I see it (which is quite often).
我们将原始测试集拆分为一个参考数据集和一个在统计测试的 H0 下不应拒绝的数据集。 我们还创建了不平衡的数据集,并在参考集中注入了选定的单词。
def random_sample(X: np.ndarray, y: np.ndarray, proba_zero: float, n: int):
if len(y.shape) == 1:
# 获取下标
idx_0 = np.where(y == 0)[0]
idx_1 = np.where(y == 1)[0]
else:
idx_0 = np.where(y[:, 0] == 1)[0]
idx_1 = np.where(y[:, 1] == 1)[0]
# 计算分别从标签为 0 或 1 的数据中取出多少数据
n_0, n_1 = int(n * proba_zero), int(n * (1 - proba_zero))
# 随机选择N个数据
idx_0_out = np.random.choice(idx_0, n_0, replace=False)
idx_1_out = np.random.choice(idx_1, n_1, replace=False)
# 拼接筛选出的标签为0或标签为1的数据
X_out = np.concatenate([X[idx_0_out], X[idx_1_out]])
y_out = np.concatenate([y[idx_0_out], y[idx_1_out]])
return X_out.tolist(), y_out.tolist()
def padding_last(x: np.ndarray, seq_len: int) -> np.ndarray:
try: # try not to replace padding token
last_token = np.where(x == 0)[0][0]
except: # no padding
last_token = seq_len - 1
return 1, last_token
def padding_first(x: np.ndarray, seq_len: int) -> np.ndarray:
try: # try not to replace padding token
first_token = np.where(x == 0)[0][-1] + 2
except: # no padding
first_token = 0
return first_token, seq_len - 1
def inject_word(token: int, X: np.ndarray, perc_chg: float, padding: str = 'last'):
seq_len = X.shape[1]
n_chg = int(perc_chg * .01 * seq_len)
X_cp = X.copy()
for _ in range(X.shape[0]):
if padding == 'last':
first_token, last_token = padding_last(X_cp[_, :], seq_len)
else:
first_token, last_token = padding_first(X_cp[_, :], seq_len)
if last_token <= n_chg:
choice_len = seq_len
else:
choice_len = last_token
idx = np.random.choice(np.arange(first_token, choice_len), n_chg, replace=False)
X_cp[_, idx] = token
return X_cp.tolist()
参考、H0 和不平衡数据集:
# proba_zero = fraction with label 0 (=negative sentiment)
n_sample = 1000
# 参考数据集
X_ref = random_sample(X, y, proba_zero=.5, n=n_sample)[0]
# H0数据集
X_h0 = random_sample(X, y, proba_zero=.5, n=n_sample)[0]
# 不平衡数据集
n_imb = [.1, .9]
X_imb = {_: random_sample(X, y, proba_zero=_, n=n_sample)[0] for _ in n_imb}
在参考数据集中注入单词:
# 好极了、好的、不好的,极差的
words = ['fantastic', 'good', 'bad', 'horrible']
# 受干扰的百分比,1% 或 5%
perc_chg = [1., 5.] # % of tokens to change in an instance
# input_ids: 你的 tokens 的数字表示
words_tf = tokenizer(words)['input_ids']
words_tf = [token[1:-1][0] for token in words_tf]
max_len = 100
tokens = tokenizer(X_ref, pad_to_max_length=True,
max_length=max_len, return_tensors='tf')
X_word = {}
for i, w in enumerate(words_tf):
X_word[words[i]] = {}
for p in perc_chg:
x = inject_word(w, tokens['input_ids'].numpy(), p)
dec = tokenizer.batch_decode(x, **dict(skip_special_tokens=True))
X_word[words[i]][p] = dec
tokens['input_ids']
<tf.Tensor: shape=(1000, 100), dtype=int32, numpy=
array([[ 101, 1188, 1794, ..., 0, 0, 0],
[ 101, 1556, 5122, ..., 1307, 1800, 102],
[ 101, 3406, 4720, ..., 5674, 2723, 102],
...,
[ 101, 2082, 1122, ..., 1641, 107, 102],
[ 101, 1124, 118, ..., 1155, 1104, 102],
[ 101, 1249, 24017, ..., 0, 0, 0]], dtype=int32)>
预处理
首先,我们需要指定要从 BERT 模型中提取的embedding类型。我们可以从…中提取embedding
- pooler_output:序列的第一个标记(分类标记;CLS)的最后一层隐藏状态,由线性层和 Tanh 激活函数进一步处理。线性层权重在预训练期间从下一个句子预测(分类)目标进行训练。注意:这个输出通常不能很好地总结输入的语义内容,你通常最好对整个输入序列的隐藏状态序列进行平均或池化。
- last_hidden_state:模型最后一层输出的隐藏状态序列,对tokens进行平均。
- hidden_state:模型在每层输出处的隐藏状态,对tokens进行平均。
- hidden_state_cls:查看 hidden_state 但使用 CLS tokens 输出。
如果 hidden_state 或 hidden_state_cls 用作 embedding 类型,您还需要传递用于从中提取 embedding 的层号。作为一个例子,我们从最后 8 个隐藏状态中提取 embeddings。
from alibi_detect.models.tensorflow import TransformerEmbedding
emb_type = 'hidden_state'
n_layers = 8
layers = [-_ for _ in range(1, n_layers + 1)]
embedding = TransformerEmbedding(model_name, emb_type, layers)
让我们检查一下 embedding 的样子:
tokens = tokenizer(list(X[:5]),
pad_to_max_length=True,
max_length=max_len,
return_tensors='tf')
# embedding模型
x_emb = embedding(tokens)
print(x_emb.shape)
运行结果:
(5, 768)
因此,漂移检测器使用的 BERT 模型的 embedding 空间由每个实例的 768 维向量组成。 因此,在进行统计假设检验之前,我们将首先使用未经训练的自动编码器 (UAE) 应用降维步骤。 我们使用embedding 模型作为 UAE 的输入,然后将 embedding 投影到低维空间。
tf.random.set_seed(0)
from alibi_detect.cd.tensorflow import UAE
# 低维
enc_dim = 32
shape = (x_emb.shape[1],)
uae = UAE(input_layer=embedding, shape=shape, enc_dim=enc_dim)
让我们再次测试一下:
emb_uae = uae(tokens)
print(emb_uae.shape)
运行结果:
(5, 32)
K-S 检测器
初始化
我们继续初始化漂移检测器。从这里开始,检测器的工作原理与其他方式(如图像)相同。 请查看图像示例或 K-S 检测器文档以获取有关每个可能参数的更多信息。
from functools import partial
from alibi_detect.cd.tensorflow import preprocess_drift
# 定义预处理函数
# define preprocessing function
preprocess_fn = partial(preprocess_drift, model=uae, tokenizer=tokenizer,
max_len=max_len, batch_size=32)
# 初始化检测器,指定参考数据集X_ref
# initialize detector
cd = KSDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn, input_shape=(max_len,))
# 保存/加载一个初始化检测器
# we can also save/load an initialised detector
filepath = 'my_path' # change to directory where detector is saved
save_detector(cd, filepath)
cd = load_detector(filepath)
检测漂移
让我们首先检查在训练集中与参考数据集相似的样本上是否发生漂移。
preds_h0 = cd.predict(X_h0)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
运行结果:
Drift? No!
p-value: [0.31356168 0.18111965 0.60991895 0.43243074 0.6852314 0.722555 0.28769323 0.18111965 0.50035924 0.9134755 0.40047103 0.79439443 0.79439443 0.722555 0.5726548 0.1640792 0.9540582 0.60991895 0.5726548 0.5726548 0.31356168 0.40047103 0.6852314 0.34099194 0.5726548 0.07762147 0.79439443 0.09710453 0.5726548 0.79439443 0.7590978 0.26338065]
检测不平衡和扰动数据集上的漂移:
# 不平衡数据集
for k, v in X_imb.items():
preds = cd.predict(v)
print('% negative sentiment {}'.format(k * 100))
print('Drift? {}'.format(labels[preds['data']['is_drift']]))
print('p-value: {}'.format(preds['data']['p_val']))
print('')
运行结果:
% negative sentiment 10.0
Drift? Yes!
p-value: [4.32430744e-01 4.00471032e-01 5.46463318e-02 7.76214674e-02 1.08282514e-01 1.12110768e-02 6.91903234e-02 2.82894098e-03 8.59294355e-01 6.47557259e-01 1.33834302e-01 7.94394433e-01 4.28151786e-02 2.87693232e-01 6.09918952e-01 1.33834302e-01 2.40603596e-01 9.71045271e-02 7.76214674e-02 9.35580969e-01 2.87693232e-01 2.92505771e-02 4.00471032e-01 6.09918952e-01 2.87693232e-01 5.06567594e-04 1.64079204e-01 6.09918952e-01 1.33834302e-01 2.19330013e-01 7.94394433e-01 2.56591532e-02]
% negative sentiment 90.0
Drift? Yes!
p-value: [7.36993998e-02 1.37563676e-01 5.86588383e-02 5.07961273e-01 8.37696046e-02 8.80799629e-03 1.23670578e-01 1.76981179e-04 3.21924835e-01 1.20594716e-02 8.43600273e-01 4.08206195e-01 1.69703156e-01 5.79056978e-01 6.32701874e-01 4.48510349e-02 5.07465303e-01 6.64306164e-04 5.23085408e-02 3.78374875e-01 6.65342569e-01 4.06090707e-01 6.21288121e-01 5.85612692e-02 5.87646782e-01 7.55570829e-03 8.99188042e-01 1.18489005e-02 6.68586135e-01 1.01421457e-02 7.97733963e-02 1.73885196e-01]
# 扰动数据集
for w, probas in X_word.items():
for p, v in probas.items():
preds = cd.predict(v)
print('Word: {} -- % perturbed: {}'.format(w, p))
print('Drift? {}'.format(labels[preds['data']['is_drift']]))
print('p-value: {}'.format(preds['data']['p_val']))
print('')
运行结果:
Word: fantastic -- % perturbed: 1.0
Drift? No!
p-value: [0.8879386 0.01711409 0.2406036 0.9134755 0.21933001 0.04281518
0.03778438 0.28769323 0.3699725 0.996931 0.8879386 0.43243074
0.01121108 0.6852314 0.99870795 0.996931 0.93558097 0.99365413
0.02246371 0.60991895 0.8879386 0.34099194 0.09710453 0.8879386
0.1338343 0.06155144 0.85929435 0.99365413 0.07762147 0.07762147
0.9882611 0.85929435]
Word: fantastic -- % perturbed: 5.0
Drift? Yes!
p-value: [1.29345525e-02 1.69780876e-14 1.52437299e-11 5.72654784e-01
1.85489473e-08 1.88342838e-17 6.14975981e-09 4.28151786e-02
5.62237052e-13 2.13202584e-05 4.28151786e-02 1.97469308e-09
0.00000000e+00 1.48931602e-02 9.68870163e-01 1.29345525e-02
2.63380647e-01 1.08282514e-01 1.04535818e-26 4.28151786e-02
2.13202584e-05 3.47411038e-14 1.09291570e-20 1.08282514e-01
5.68982140e-18 1.69780876e-14 1.64079204e-01 4.00471032e-01
3.12689441e-34 3.89208371e-27 2.86525619e-06 1.71956726e-05]
Word: good -- % perturbed: 1.0
Drift? Yes!
p-value: [3.40991944e-01 9.80161786e-01 1.08282514e-01 9.98707950e-01
1.48338065e-01 9.35580969e-01 7.59097815e-01 9.88261104e-01
8.87938619e-01 6.47557259e-01 9.68870163e-01 7.94394433e-01
8.69054198e-02 9.99999642e-01 9.96931016e-01 5.72654784e-01
9.99870896e-01 4.32430744e-01 9.99870896e-01 2.92505771e-02
9.13475513e-01 9.13475513e-01 4.65766221e-01 9.35580969e-01
8.87938619e-01 9.98707950e-01 9.80161786e-01 9.99972701e-01
7.59097815e-01 1.34916729e-04 9.96931016e-01 9.68870163e-01]
Word: good -- % perturbed: 5.0
Drift? Yes!
p-value: [6.1319246e-16 8.5929435e-01 8.4248814e-24 5.3605431e-01 6.1410643e-10
1.9951835e-01 2.9080641e-04 3.6997250e-01 2.4072561e-04 3.3837957e-10
9.5405817e-01 8.6666952e-04 5.2673625e-28 1.4893160e-02 9.7104527e-02
5.3955968e-11 1.6407920e-01 6.1410643e-10 7.2255498e-01 2.5362303e-18
7.9439443e-01 1.7943768e-06 1.5330249e-07 2.0378644e-03 1.4563050e-03
2.1933001e-01 1.9626908e-02 6.4755726e-01 1.4790693e-09 0.0000000e+00
1.9626908e-02 3.1356168e-01]
Word: bad -- % perturbed: 1.0
Drift? No!
p-value: [0.8879386 0.21933001 0.12050407 0.9540582 0.9134755 0.9540582
0.99870795 0.9540582 0.7590978 0.40047103 0.9801618 0.7590978
0.02925058 0.996931 0.9995433 0.79439443 0.26338065 0.04281518
0.93558097 0.14833806 0.50035924 0.82795686 0.18111965 0.43243074
0.99365413 0.9882611 0.9801618 0.99870795 0.96887016 0.10828251
0.07762147 0.9882611 ]
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: [7.04859247e-08 5.78442112e-12 7.08821891e-21 1.33834302e-01
7.13247118e-06 3.69972497e-01 9.68870163e-01 1.81119651e-01
2.13202584e-05 3.47411038e-14 5.00359237e-01 1.97830971e-07
9.82534992e-39 1.03241683e-03 1.96269080e-02 2.92505771e-02
8.76041099e-07 8.49670826e-18 1.08282514e-01 3.38379574e-10
8.07501343e-25 5.37760343e-07 2.79573150e-17 2.40344345e-03
1.99518353e-01 7.59097815e-01 8.69054198e-02 3.32311448e-03
2.15581372e-12 3.95873130e-15 1.95523170e-16 5.72654784e-01]
Word: horrible -- % perturbed: 1.0
Drift? Yes!
p-value: [2.63380647e-01 9.98707950e-01 9.98707950e-01 9.88261104e-01
6.47557259e-01 8.59294355e-01 9.96931016e-01 9.13475513e-01
3.50604125e-04 9.99870896e-01 9.99870896e-01 6.09918952e-01
1.33834302e-01 9.80161786e-01 9.35580969e-01 9.88261104e-01
9.71045271e-02 4.00471032e-01 6.85231388e-01 1.81119651e-01
4.65766221e-01 9.80161786e-01 8.69054198e-02 9.96931016e-01
9.99870896e-01 6.91903234e-02 9.80161786e-01 9.99972701e-01
9.93654132e-01 5.32228360e-03 1.20504074e-01 7.22554982e-01]
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: [1.6978088e-14 8.8793862e-01 2.8769323e-01 5.7265478e-01 1.3491673e-04
1.7114086e-02 4.3243074e-01 1.1211077e-02 8.5801831e-33 3.5060412e-04
8.6905420e-02 6.1497598e-09 1.4797455e-32 1.3383430e-01 1.7244401e-03
2.6338065e-01 1.4117470e-08 3.5060412e-04 5.7140245e-15 4.9547091e-14
5.9822431e-37 8.9143086e-06 8.4967083e-18 3.1356168e-01 8.7604110e-07
3.9584363e-20 1.4833806e-01 1.7244401e-03 1.1053569e-12 0.0000000e+00
1.3007273e-15 2.9250577e-02]
MMD TensorFlow 检测器
初始化
再次查看 图像示例 或 MMD 检测器文档以获取有关每个可能参数的更多信息。
cd = MMDDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn,
n_permutations=100, input_shape=(max_len,))
检测漂移
H0数据集:
preds_h0 = cd.predict(X_h0)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
运行结果:
Drift? No!
p-value: 0.6
不平衡数据集:
for k, v in X_imb.items():
preds = cd.predict(v)
print('% negative sentiment {}'.format(k * 100))
print('Drift? {}'.format(labels[preds['data']['is_drift']]))
print('p-value: {}'.format(preds['data']['p_val']))
print('')
运行结果:
% negative sentiment 10.0
Drift? Yes!
p-value: 0.01
% negative sentiment 90.0
Drift? Yes!
p-value: 0.0
扰动数据集:
for w, probas in X_word.items():
for p, v in probas.items():
preds = cd.predict(v)
print('Word: {} -- % perturbed: {}'.format(w, p))
print('Drift? {}'.format(labels[preds['data']['is_drift']]))
print('p-value: {}'.format(preds['data']['p_val']))
print('')
运行结果:
Word: fantastic -- % perturbed: 1.0
Drift? No!
p-value: 0.09
Word: fantastic -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
Word: good -- % perturbed: 1.0
Drift? No!
p-value: 0.71
Word: good -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
Word: bad -- % perturbed: 1.0
Drift? No!
p-value: 0.38
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
Word: horrible -- % perturbed: 1.0
Drift? No!
p-value: 0.18
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
MMD PyTorch 检测器
初始化
对于预处理步骤和 MMD 实现,我们可以使用 PyTorch 后端运行相同的检测器:
import torch
import torch.nn as nn
# set random seed and device
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
运行结果:
cuda
from alibi_detect.cd.pytorch import preprocess_drift
from alibi_detect.models.pytorch import TransformerEmbedding
embedding_pt = TransformerEmbedding(model_name, emb_type, layers)
model = nn.Sequential(
embedding_pt,
nn.Linear(768, 256),
nn.ReLU(),
nn.Linear(256, enc_dim)
).to(device).eval()
# define preprocessing function
preprocess_fn = partial(preprocess_drift, model=model, tokenizer=tokenizer,
max_len=max_len, batch_size=32, device=device)
# initialise drift detector
cd = MMDDrift(X_ref, backend='pytorch', p_val=.05, preprocess_fn=preprocess_fn,
n_permutations=100, input_shape=(max_len,))
检测漂移
H0数据集:
preds_h0 = cd.predict(X_h0)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
运行结果:
Drift? No!
p-value: 0.49000000953674316
不平衡数据:
for k, v in X_imb.items():
preds = cd.predict(v)
print('% negative sentiment {}'.format(k * 100))
print('Drift? {}'.format(labels[preds['data']['is_drift']]))
print('p-value: {}'.format(preds['data']['p_val']))
print('')
运行结果:
% negative sentiment 10.0
Drift? Yes!
p-value: 0.0
% negative sentiment 90.0
Drift? Yes!
p-value: 0.0
扰动数据:
for w, probas in X_word.items():
for p, v in probas.items():
preds = cd.predict(v)
print('Word: {} -- % perturbed: {}'.format(w, p))
print('Drift? {}'.format(labels[preds['data']['is_drift']]))
print('p-value: {}'.format(preds['data']['p_val']))
print('')
运行结果:
Word: fantastic -- % perturbed: 1.0
Drift? Yes!
p-value: 0.0
Word: fantastic -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
Word: good -- % perturbed: 1.0
Drift? No!
p-value: 0.10000000149011612
Word: good -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
Word: bad -- % perturbed: 1.0
Drift? Yes!
p-value: 0.0
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
Word: horrible -- % perturbed: 1.0
Drift? No!
p-value: 0.05999999865889549
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: 0.0
从头开始训练 embeddings
到目前为止,我们使用了来自 BERT 模型的预训练 embeddings。 然而,我们也可以使用从头开始训练的模型中的 embeddings。
首先,我们在 TensorFlow 中定义和训练一个由 embedding 和 LSTM 层组成的简单分类模型。
加载数据并训练模型
from tensorflow.keras.datasets import imdb, reuters
from tensorflow.keras.layers import Dense, Embedding, Input, LSTM
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.utils import to_categorical
INDEX_FROM = 3
NUM_WORDS = 10000
def print_sentence(tokenized_sentence: str, id2w: dict):
print(' '.join(id2w[_] for _ in tokenized_sentence))
print('')
print(tokenized_sentence)
def mapping_word_id(data):
w2id = data.get_word_index()
w2id = {k: (v + INDEX_FROM) for k, v in w2id.items()}
w2id["<PAD>"] = 0
w2id["<START>"] = 1
w2id["<UNK>"] = 2
w2id["<UNUSED>"] = 3
id2w = {v: k for k, v in w2id.items()}
return w2id, id2w
def get_dataset(dataset: str = 'imdb', max_len: int = 100):
if dataset == 'imdb':
data = imdb
elif dataset == 'reuters':
data = reuters
else:
raise NotImplementedError
w2id, id2w = mapping_word_id(data)
(X_train, y_train), (X_test, y_test) = data.load_data(
num_words=NUM_WORDS, index_from=INDEX_FROM)
X_train = sequence.pad_sequences(X_train, maxlen=max_len)
X_test = sequence.pad_sequences(X_test, maxlen=max_len)
y_train, y_test = to_categorical(y_train), to_categorical(y_test)
return (X_train, y_train), (X_test, y_test), (w2id, id2w)
def imdb_model(X: np.ndarray, num_words: int = 100, emb_dim: int = 128,
lstm_dim: int = 128, output_dim: int = 2) -> tf.keras.Model:
X = np.array(X)
inputs = Input(shape=(X.shape[1:]), dtype=tf.float32)
x = Embedding(num_words, emb_dim)(inputs)
x = LSTM(lstm_dim, dropout=.5)(x)
outputs = Dense(output_dim, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(
loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy']
)
return model
加载和tokenize数据:
(X_train, y_train), (X_test, y_test), (word2token, token2word) = \
get_dataset(dataset='imdb', max_len=max_len)
我们来看一个实例:
print_sentence(X_train[0], token2word)
运行结果:
cry at a film it must have been good and this definitely was also <UNK> to the two little boy's that played the <UNK> of norman and paul they were just brilliant children are often left out of the <UNK> list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all
[1415 33 6 22 12 215 28 77 52 5 14 407 16 82
2 8 4 107 117 5952 15 256 4 2 7 3766 5 723
36 71 43 530 476 26 400 317 46 7 4 2 1029 13
104 88 4 381 15 297 98 32 2071 56 26 141 6 194
7486 18 4 226 22 21 134 476 26 480 5 144 30 5535
18 51 36 28 224 92 25 104 4 226 65 16 38 1334
88 12 16 283 5 16 4472 113 103 32 15 16 5345 19
178 32]
定义和训练一个简单的模型:
model = imdb_model(X=X_train, num_words=NUM_WORDS, emb_dim=256, lstm_dim=128, output_dim=2)
model.fit(X_train, y_train, batch_size=32, epochs=2,
shuffle=True, validation_data=(X_test, y_test))
运行结果:
Epoch 1/2
782/782 [==============================] - 17s 17ms/step - loss: 0.4314 - accuracy: 0.7988 - val_loss: 0.3481 - val_accuracy: 0.8474
Epoch 2/2
782/782 [==============================] - 14s 18ms/step - loss: 0.2707 - accuracy: 0.8908 - val_loss: 0.3858 - val_accuracy: 0.8451
从训练好的模型中提取嵌入层并结合UAE预处理步骤:
embedding = tf.keras.Model(inputs=model.inputs, outputs=model.layers[1].output)
x_emb = embedding(X_train[:5])
print(x_emb.shape)
运行结果:
(5, 100, 256)
tf.random.set_seed(0)
shape = tuple(x_emb.shape[1:])
uae = UAE(input_layer=embedding, shape=shape, enc_dim=enc_dim)
同样,创建参考、H0 和扰动数据集。 还针对Reuters新闻主题分类数据集进行测试。
X_ref, y_ref = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
X_h0, y_h0 = random_sample(X_test, y_test, proba_zero=.5, n=n_sample)
tokens = [word2token[w] for w in words]
X_word = {}
for i, t in enumerate(tokens):
X_word[words[i]] = {}
for p in perc_chg:
X_word[words[i]][p] = inject_word(t, np.array(X_ref), p, padding='first')
# load and tokenize Reuters dataset
(X_reut, y_reut), (w2t_reut, t2w_reut) = \
get_dataset(dataset='reuters', max_len=max_len)[1:]
# sample random instances
idx = np.random.choice(X_reut.shape[0], n_sample, replace=False)
X_ood = X_reut[idx]
初始化检测器并检测漂移
from alibi_detect.cd.tensorflow import preprocess_drift
# define preprocess_batch_fn to convert list of str's to np.ndarray to be processed by `model`
def convert_list(X: list):
return np.array(X)
# define preprocessing function
preprocess_fn = partial(preprocess_drift,
model=uae,
batch_size=128,
preprocess_batch_fn=convert_list)
# initialize detector
cd = KSDrift(X_ref, p_val=.05, preprocess_fn=preprocess_fn)
H0数据集:
preds_h0 = cd.predict(X_h0)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_h0['data']['is_drift']]))
print('p-value: {}'.format(preds_h0['data']['p_val']))
运行结果:
Drift? No!
p-value: [0.18111965 0.50035924 0.5360543 0.722555 0.2406036 0.02925058 0.43243074 0.12050407 0.722555 0.60991895 0.19951835 0.60991895 0.50035924 0.79439443 0.722555 0.64755726 0.40047103 0.34099194 0.1338343 0.10828251 0.64755726 0.9995433 0.9540582 0.9134755 0.40047103 0.1640792 0.40047103 0.64755726 0.9134755 0.7590978 0.5726548 0.722555 ]
扰动数据集:
for w, probas in X_word.items():
for p, v in probas.items():
preds = cd.predict(v)
print('Word: {} -- % perturbed: {}'.format(w, p))
print('Drift? {}'.format(labels[preds['data']['is_drift']]))
print('p-value: {}'.format(preds['data']['p_val']))
print('')
运行结果:
Word: fantastic -- % perturbed: 1.0
Drift? No!
p-value: [0.9998709 0.7590978 0.99870795 0.9995433 0.9801618 0.9134755
0.82795686 0.99870795 0.9882611 0.8879386 0.9801618 0.79439443
0.85929435 0.96887016 0.9134755 0.996931 0.5726548 0.93558097
0.9882611 0.99870795 0.93558097 0.96887016 0.85929435 0.9882611
0.93558097 0.996931 0.996931 0.96887016 0.9882611 0.96887016
0.8879386 0.996931 ]
Word: fantastic -- % perturbed: 5.0
Drift? No!
p-value: [0.85929435 0.06155144 0.9540582 0.79439443 0.43243074 0.6852314
0.722555 0.9134755 0.28769323 0.996931 0.60991895 0.19951835
0.43243074 0.64755726 0.722555 0.8879386 0.18111965 0.18111965
0.43243074 0.14833806 0.50035924 0.43243074 0.01489316 0.01121108
0.722555 0.46576622 0.07762147 0.8879386 0.05464633 0.10828251
0.03327804 0.9801618 ]
Word: good -- % perturbed: 1.0
Drift? No!
p-value: [0.99365413 0.8879386 0.99870795 0.9801618 0.99870795 0.99870795
0.9134755 0.93558097 0.8879386 0.9995433 0.93558097 0.996931
0.99999607 0.9995433 0.99870795 0.9801618 0.99870795 0.9801618
0.8879386 0.996931 0.9134755 0.996931 0.7590978 0.99365413
0.9540582 0.99870795 0.99870795 0.9998709 0.9801618 0.64755726
0.9999727 0.8879386 ]
Word: good -- % perturbed: 5.0
Drift? No!
p-value: [0.9882611 0.6852314 0.79439443 0.60991895 0.28769323 0.3699725
0.28769323 0.6852314 0.79439443 0.31356168 0.99870795 0.85929435
0.34099194 0.34099194 0.8879386 0.996931 0.96887016 0.96887016
0.9540582 0.722555 0.19951835 0.9995433 0.3699725 0.722555
0.1338343 0.9134755 0.5360543 0.26338065 0.85929435 0.2406036
0.31356168 0.6852314 ]
Word: bad -- % perturbed: 1.0
Drift? No!
p-value: [0.93558097 0.996931 0.85929435 0.9540582 0.50035924 0.64755726
0.82795686 0.85929435 0.82795686 0.9882611 0.82795686 0.9540582
0.21933001 0.96887016 0.93558097 0.99870795 0.79439443 0.722555
0.93558097 0.93558097 0.64755726 0.99365413 0.5726548 0.9998709
0.93558097 0.96887016 0.9995433 0.99365413 0.7590978 0.93558097
0.9882611 0.9134755 ]
Word: bad -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 8.27956855e-01 2.87693232e-01 6.47557259e-01
3.89581337e-03 1.03241683e-03 3.40991944e-01 7.59097815e-01
2.82894098e-03 5.46463318e-02 1.20504074e-01 2.63380647e-01
1.11190266e-05 5.46463318e-02 4.65766221e-01 7.94394433e-01
9.69783217e-03 3.69972497e-01 9.35580969e-01 1.71140861e-02
6.91903234e-02 7.94394433e-01 9.07998619e-05 4.00471032e-01
8.27956855e-01 7.59097815e-01 1.64079204e-01 4.84188050e-02
1.71140861e-02 6.85231388e-01 5.46463318e-02 5.72654784e-01]
Word: horrible -- % perturbed: 1.0
Drift? No!
p-value: [0.996931 0.9801618 0.96887016 0.79439443 0.79439443 0.5726548
0.82795686 0.996931 0.43243074 0.93558097 0.79439443 0.82795686
0.06919032 0.3699725 0.96887016 0.9540582 0.5360543 0.6852314
0.60991895 0.79439443 0.9540582 0.9801618 0.40047103 0.5726548
0.82795686 0.8879386 0.9540582 0.9134755 0.99365413 0.60991895
0.82795686 0.79439443]
Word: horrible -- % perturbed: 5.0
Drift? Yes!
p-value: [4.00471032e-01 1.48931602e-02 4.84188050e-02 1.96269080e-02
1.12110768e-02 1.48931602e-02 4.00471032e-01 5.72654784e-01
1.45630504e-03 1.96269080e-02 7.59097815e-01 1.72444014e-03
1.30072730e-15 1.79437677e-06 2.63380647e-01 6.47557259e-01
1.11478073e-06 1.99518353e-01 1.20504074e-01 4.55808453e-03
7.21312594e-03 2.40603596e-01 2.24637091e-02 4.28151786e-02
4.28151786e-02 7.22554982e-01 1.08282514e-01 9.07998619e-05
5.36054313e-01 9.71045271e-02 1.64079204e-01 3.40991944e-01]
该检测器不如基于 Transformer 的 K-S 漂移检测器灵敏。从头开始训练的 embeddings 只在一个小数据集和一个具有交叉熵损失函数的简单模型上训练了 2 个 epoch。 另一方面,预训练的 BERT 模型可以更好地捕捉数据的语义。
来自 Reuters 数据集的样本:
preds_ood = cd.predict(X_ood)
labels = ['No!', 'Yes!']
print('Drift? {}'.format(labels[preds_ood['data']['is_drift']]))
print('p-value: {}'.format(preds_ood['data']['p_val']))
运行结果:
Drift? Yes!
p-value: [7.22554982e-01 1.07232365e-08 3.69972497e-01 9.54058170e-01 7.22554982e-01 4.84188050e-02 9.69783217e-03 1.71956726e-05 8.87938619e-01 4.01514189e-05 2.54783203e-07 1.22740539e-03 4.21853358e-04 3.49877549e-09 5.46463318e-02 1.79437677e-06 6.91903234e-02 4.20066499e-07 3.50604125e-04 2.87693232e-01 1.69780876e-14 1.69780876e-14 3.40991944e-01 2.53623026e-18 2.26972293e-06 3.18301190e-08 2.40344345e-03 5.32228360e-03 2.40725611e-04 2.56591532e-02 3.27475419e-07 5.69539361e-06]