总览
针对 Python 库 pytorch_tabnet 的 TabNetClassifier 模块,使用泰坦尼克号数据进行 TabNet 的训练过程与网络结构探究。
训练循环
运行 .fit() ,经过一些准备工作后,会进入 epoch 训练循环。
for epoch_idx in range(self.max_epochs):
self._callback_container.on_epoch_begin(epoch_idx)
self._train_epoch(train_dataloader)
···
if self._stop_training:
break
self._callback_container.on_epoch_begin(epoch_idx) 会运行三个 callback:
pytorch_tabnet.callbacks.Historypytorch_tabnet.callbacks.EarlyStoppingpytorch_tabnet.callbacks.LRSchedulerCallback
然后进入 self._train_epoch(train_dataloader)。会调用 TabNet._train_epoch() 并进一步调用 TabModel._train_batch(),完成一个 epoch 内的训练。
网络结构
网络整体都装到 TabNet 中。会经过两个结构:
EmbeddingGeneratorTabNetNoEmbeddings
EmbeddingGenerator 应该是针对 label 数据处理。由于我已经把 label 数据进行了 one-hot 处理,这个层被跳过了。
TabNetNoEmbeddings 正向传播代码如下。可见只需重点介绍 self.encoder。
def forward(self, x):
res = 0
steps_output, M_loss = self.encoder(x)
res = torch.sum(torch.stack(steps_output, dim=0), dim=0)
out = self.final_mapping(res) # 就一线性层,用来输出最终分类结果的
return out, M_loss
self.encoder(TabNetNoEmbeddings)
进入了 self.encoder(TabNetEncoder)。
self.initial_bn,是BatchNorm1d。momentum 从默认值 0.1 设为了 0.01- 创建一个值全为 1 的 tensor
prior,维度为 [batch, feature_dim] self.initial_splitter(FeatTransformer)。GLU(无残差无 scale) -> GLU -> GLU -> GLUself.shared(GLU_Block)- 用 获得一个
scale self.glu_layers(GLU_Layer)。有多个层。除了 BN 有个 trick,其实就是标准 GLUself.fc,线性层,升维到 32self.bn(BGN),Ghost Batch Normalization。比起直接的 BN,会将一个 batch 数据切分为更小的 virtual_batch_size 再 BN。后面会详细说明- 使用的 momentum 为 0.02
- 最后使用 降维到原来一半
- 第一个 GLU 层不用
scale且不进行残差 - 第二个 GLU 层先残差激活,再乘上
scale
- 用 获得一个
self.specifics(GLU_Block)。可见,与刚刚的self.shared是同一种网络层- 不同的是,
self.glu_layers中,两个 GLU 层都是先残差激活,再乘上scale
- 不同的是,
- 取
self.initial_splitter的后self.n_d个(实际是 8 个)特征,赋值给att - 进行
self.n_steps次(3 次)循环:self.att_transformers[step](AttentiveTransformer),获得M,代表各特征的贡献率(sum(M)==1),让模型有选择性地使用特征self.fc,维度从 8 映射到 21(这是数据特征的数量)self.bn(BGN),Ghost Batch Normalization- 还记得
prior吗,此时x与prior完全同维度,两者进行叉乘prior会补偿特征被关注的重要程度。上一 step 被重点关注的特征 这步会权重降低
self.selector(sparsemax.Sparsemax)。稀疏版的 Softmax,后面会详细说明
- 通过 计算负熵,作为损失的正则项
M_loss- 后面会有
loss = loss - self.lambda_sparse * M_loss。鼓励概率分布更加稀疏极端
- 后面会有
prior更新为mul(self.gamma - M, prior)以便下次使用- 减少已经被高度关注特征的权重。间接让每一步有机会关注不同的特征组合
- 若设置了
self.group_attention_matrix,会有torch.matmul(M, self.group_attention_matrix)将M进行重映射。默认是M没有变化 torch.mul(M, x),获得本次被选择的特征masked_xself.feat_transformers[step](FeatTransformer)。GLU(无残差无 scale) -> GLU -> GLU -> GLUsteps_output进行 append:取self.feat_transformers[step]的前self.n_d个特征且进行 ReLU 激活att进行更新:取self.feat_transformers[step]的后self.n_d个特征
M_loss取 step 的平均- 最终获得各步骤的输出
steps_output,和负熵正则项M_loss
重要的层的意义汇总:
| 层步骤 | 意义 |
|---|---|
self.initial_bn | 对输入特征首先进行 BN |
self.initial_splitter | 连续 GLU 激活。获得用于计算特征权重的 att |
self.att_transformers[step] | 输入 att prior,使用 Sparsemax 技巧计算稀疏的特征权重 M |
self.feat_transformers[step] | 连续 GLU 激活。前一半通道划为 step 输出,后一半划为新的 att |
一些变量的意义汇总:
| 变量 | 意义 |
|---|---|
att | 被用于特征权重 M 的计算 |
prior | 用在 Sparsemax 之前,一定程度降低 上一步被重点关注的特征 的权重 |
额外说明
Ghost Batch Normalization
Ghost Batch Normalization 是一种正则化方法。其做法是,比起直接的 BN,会将一个 batch 数据切分为更小的 virtual_batch_size 再 BN。
Ghost Batch Normalization 使得 batch_size 较大时仍能让 BN 结果产生较大随机性,可增加模型泛化能力。
!!! note "" 提出 Ghost Batch Normalization 的论文还说,学习率应该和batch size的平方根成正比。
参考来源:
Sparsemax
Softmax 常用于多分类问题,将一组数转换为总和为 1 的各分类的概率。但实践中发现 Softmax 会导致一个问题:分类较多时,模型预测的目标分类概率必须远远超出其他概率才能获得较低 loss,从而容易过拟合。实际生活中我们只需要模型预测的目标分类概率稍高于其他概率即可。避免过拟合、避免模型过度自信,这也是 Label Smoothing 这样的 trick 出现的动机。
Sparsemax 是稀疏版的 Softmax,实际使用上会将低得分的分类直接归为 0 概率。
求解过程如下。(对于一个 batch 的单组数据 x[0]):
- 对 的元素 进行降序排序
- 计算累计和数列 (
x.cumsum(dim) - 1) - 找到满足 的最小 值(可以通过
support = x * list(range(1, n+1)) > S间接获得) - 从而找到阈值 (可以通过
tau = x[support.sum() - 1]间接获得) - 通过阈值 输出结果
可见,计算重点是阈值 的获取。这样获得的 能使得最终获得的 满足 。
此时可以发现,Sparsemax 是完全的线性映射。不像 Softmax 涉及非线性的指数运算。
为了让 Sparsemax 在反向传播后不破坏稀疏性,需要做出两点自定义:
- 概率为 0 的特征梯度变为 0
- 其他特征的梯度减去均值(支持集内的梯度之和应为 0)
说个题外话。TabNet 中实现的 Sparsemax 重写的 backward 函数中有这段代码:
grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
完全可以省略为:
grad_input = torch.where(output != 0, grad_input - v_hat, 0)
这个代码只是为了将概率为 0 的特征梯度变为 0。原本的实现有点拐弯抹角了。
顺带一提,entmax 将 Sparsemax 和 Softmax 归为了自己的特殊情况。可以去了解一下 entmax。
参考来源: