代码层面上学习TabNet

447 阅读5分钟

总览

针对 Python 库 pytorch_tabnetTabNetClassifier 模块,使用泰坦尼克号数据进行 TabNet 的训练过程与网络结构探究。

训练循环

运行 .fit() ,经过一些准备工作后,会进入 epoch 训练循环。

for epoch_idx in range(self.max_epochs):
    self._callback_container.on_epoch_begin(epoch_idx)
    self._train_epoch(train_dataloader)

    ···

    if self._stop_training:
        break

self._callback_container.on_epoch_begin(epoch_idx) 会运行三个 callback:

  • pytorch_tabnet.callbacks.History
  • pytorch_tabnet.callbacks.EarlyStopping
  • pytorch_tabnet.callbacks.LRSchedulerCallback

然后进入 self._train_epoch(train_dataloader)。会调用 TabNet._train_epoch() 并进一步调用 TabModel._train_batch(),完成一个 epoch 内的训练。

网络结构

网络整体都装到 TabNet 中。会经过两个结构:

  • EmbeddingGenerator
  • TabNetNoEmbeddings

EmbeddingGenerator 应该是针对 label 数据处理。由于我已经把 label 数据进行了 one-hot 处理,这个层被跳过了。

TabNetNoEmbeddings 正向传播代码如下。可见只需重点介绍 self.encoder

def forward(self, x):
    res = 0
    steps_output, M_loss = self.encoder(x)
    res = torch.sum(torch.stack(steps_output, dim=0), dim=0)

    out = self.final_mapping(res)  # 就一线性层,用来输出最终分类结果的
    return out, M_loss

self.encoderTabNetNoEmbeddings

进入了 self.encoderTabNetEncoder)。

  • self.initial_bn,是 BatchNorm1d。momentum 从默认值 0.1 设为了 0.01
  • 创建一个值全为 1 的 tensor prior,维度为 [batch, feature_dim]
  • self.initial_splitterFeatTransformer)。GLU(无残差无 scale) -> GLU -> GLU -> GLU
    • self.sharedGLU_Block
      • 0.5\sqrt{0.5} 获得一个 scale
      • self.glu_layersGLU_Layer)。有多个层。除了 BN 有个 trick,其实就是标准 GLU
        • self.fc,线性层,升维到 32
        • self.bnBGN),Ghost Batch Normalization。比起直接的 BN,会将一个 batch 数据切分为更小的 virtual_batch_size 再 BN。后面会详细说明
          • 使用的 momentum 为 0.02
        • 最后使用 xasigmoid(xb)x_a \otimes \text{sigmoid}(x_b) 降维到原来一半
      • 第一个 GLU 层不用 scale 且不进行残差
      • 第二个 GLU 层先残差激活,再乘上 scale
    • self.specificsGLU_Block)。可见,与刚刚的 self.shared 是同一种网络层
      • 不同的是,self.glu_layers 中,两个 GLU 层都是先残差激活,再乘上 scale
  • self.initial_splitter 的后 self.n_d 个(实际是 8 个)特征,赋值给 att
  • 进行 self.n_steps 次(3 次)循环:
    • self.att_transformers[step]AttentiveTransformer),获得 M,代表各特征的贡献率(sum(M)==1),让模型有选择性地使用特征
      • self.fc,维度从 8 映射到 21(这是数据特征的数量)
      • self.bn(BGN),Ghost Batch Normalization
      • 还记得 prior 吗,此时 xprior 完全同维度,两者进行叉乘
        • prior 会补偿特征被关注的重要程度。上一 step 被重点关注的特征 这步会权重降低
      • self.selector(sparsemax.Sparsemax)。稀疏版的 Softmax,后面会详细说明
    • 通过 Mlog(M)\sum M\cdot \log(M) 计算负熵,作为损失的正则项 M_loss
      • 后面会有 loss = loss - self.lambda_sparse * M_loss。鼓励概率分布更加稀疏极端
    • prior 更新为 mul(self.gamma - M, prior) 以便下次使用
      • 减少已经被高度关注特征的权重。间接让每一步有机会关注不同的特征组合
    • 若设置了 self.group_attention_matrix,会有 torch.matmul(M, self.group_attention_matrix)M 进行重映射。默认是 M 没有变化
    • torch.mul(M, x),获得本次被选择的特征 masked_x
    • self.feat_transformers[step]FeatTransformer)。GLU(无残差无 scale) -> GLU -> GLU -> GLU
    • steps_output 进行 append:取 self.feat_transformers[step] 的前 self.n_d 个特征且进行 ReLU 激活
    • att 进行更新:取 self.feat_transformers[step] 的后 self.n_d 个特征
  • M_loss 取 step 的平均
  • 最终获得各步骤的输出 steps_output,和负熵正则项 M_loss

重要的层的意义汇总:

层步骤意义
self.initial_bn对输入特征首先进行 BN
self.initial_splitter连续 GLU 激活。获得用于计算特征权重的 att
self.att_transformers[step]输入 att prior,使用 Sparsemax 技巧计算稀疏的特征权重 M
self.feat_transformers[step]连续 GLU 激活。前一半通道划为 step 输出,后一半划为新的 att

一些变量的意义汇总:

变量意义
att被用于特征权重 M 的计算
prior用在 Sparsemax 之前,一定程度降低 上一步被重点关注的特征 的权重

额外说明

Ghost Batch Normalization

Ghost Batch Normalization 是一种正则化方法。其做法是,比起直接的 BN,会将一个 batch 数据切分为更小的 virtual_batch_size 再 BN。

Ghost Batch Normalization 使得 batch_size 较大时仍能让 BN 结果产生较大随机性,可增加模型泛化能力。

!!! note "" 提出 Ghost Batch Normalization 的论文还说,学习率应该和batch size的平方根成正比。

参考来源:

Sparsemax

Softmax 常用于多分类问题,将一组数转换为总和为 1 的各分类的概率。但实践中发现 Softmax 会导致一个问题:分类较多时,模型预测的目标分类概率必须远远超出其他概率才能获得较低 loss,从而容易过拟合。实际生活中我们只需要模型预测的目标分类概率稍高于其他概率即可。避免过拟合、避免模型过度自信,这也是 Label Smoothing 这样的 trick 出现的动机。

Sparsemax 是稀疏版的 Softmax,实际使用上会将低得分的分类直接归为 0 概率。

求解过程如下。(对于一个 batch 的单组数据 x[0]):

  • x=xmax(x)x=x-\max(x)
  • xx 的元素 x(1),x(2),,x(n)x_{(1)},x_{(2)},\cdots,x_{(n)} 进行降序排序
  • 计算累计和数列 Sk=j1kxj1S_k=\sum^{k}_{j-1}x_{j}-1x.cumsum(dim) - 1
  • 找到满足 x(k)>Skkx_{(k)}>\frac{S_k}{k} 的最小 kk 值(可以通过 support = x * list(range(1, n+1)) > S 间接获得)
  • 从而找到阈值 τ=Skk\tau=\frac{S_k}{k}(可以通过 tau = x[support.sum() - 1] 间接获得)
  • 通过阈值 τ\tau 输出结果 zi=max(xiτ,0)z_i=\max(x_i-\tau,0)

可见,计算重点是阈值 τ\tau 的获取。这样获得的 τ\tau 能使得最终获得的 ziz_i 满足 zi=1\sum z_i=1

此时可以发现,Sparsemax 是完全的线性映射。不像 Softmax 涉及非线性的指数运算。

为了让 Sparsemax 在反向传播后不破坏稀疏性,需要做出两点自定义:

  • 概率为 0 的特征梯度变为 0
  • 其他特征的梯度减去均值(支持集内的梯度之和应为 0)

说个题外话。TabNet 中实现的 Sparsemax 重写的 backward 函数中有这段代码:
grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
完全可以省略为:
grad_input = torch.where(output != 0, grad_input - v_hat, 0)
这个代码只是为了将概率为 0 的特征梯度变为 0。原本的实现有点拐弯抹角了。

顺带一提,entmax 将 Sparsemax 和 Softmax 归为了自己的特殊情况。可以去了解一下 entmax。

参考来源: