动手学深度学习3.7 softmax回归简洁实现

421 阅读2分钟

参与11月更文挑战的第10天,活动详情查看:2021最后一次更文挑战

import torch
from torch import nn
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

参考手动实现softmax回归那一篇文章:动手学深度学习3.6-手动实现softmax回归 - 掘金 (juejin.cn)

这里会出现一个用户警告,可以直接忽略,如果你实在想知道这是什么可以看torchvision.transforms.ToTensor详解 | 使用transforms.ToTensor()出现用户警告 | 图像的H W C 代表什么

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights);
  • nn.Flatten(): PyTorch不会隐式地调整输入的形状,因此在线性层前定义了展平层(flatten)调整输入的形状。nn.Linear(784, 10)指定输入维度和输出维度,每次处理一张图,已知图片是28*28,展开成向量就是784。
  • net.apply(init_weights)对net的每一层都应用这个函数 init_weights:这个函数
    • 判断拿到的层是不是nn.Linear,当然type(m) == nn.Linear也可以用前边提到的isinstance(m,nn.Linear)
    • 如果是的话初始化该层权重设定均值为0,方差为0.01
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

loss直接使用nn自带的交叉熵loss,trainer也是直接使用nn自带的SGD函数。 对于交叉熵。

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

这个训练函数就是前一小节写的训练函数,就不搬过来了。见动手学深度学习3.6-手动实现softmax回归 - 掘金 (juejin.cn)

Softmax

上一节中实现的softmax使用:

softmax(X)ij=exp(Xij)kexp(Xik).\mathrm{softmax}(\mathbf{X})_{ij} = \frac{\exp(\mathbf{X}_{ij})}{\sum_k \exp(\mathbf{X}_{ik})}.

softmax函数y^j=exp(oj)kexp(ok)\hat y_j = \frac{\exp(o_j)}{\sum_k \exp(o_k)},其中y^j\hat y_j是预测的概率分布。ojo_j是未归一化的预测o\mathbf{o}的第jj个元素。如果oko_k中的一些数值非常大,那么exp(ok)\exp(o_k)可能大于数据类型容许的最大数字(即上溢(overflow))。这将使分母或分子变为inf(无穷大),我们最后遇到的是0、infnan(不是数字)的y^j\hat y_j。在这些情况下,我们不能得到一个明确定义的交叉熵的返回值。

解决这个问题的一个技巧是,在继续softmax计算之前,先从所有oko_k中减去max(ok)\max(o_k)。你可以证明每个oko_k按常数进行的移动不会改变softmax的返回值。在减法和归一化步骤之后,可能有些ojo_j具有较大的负值。由于精度受限,exp(oj)\exp(o_j)将有接近零的值,即下溢(underflow)。这些值可能会四舍五入为零,使y^j\hat y_j为零,并且使得log(y^j)\log(\hat y_j)的值为-inf。反向传播几步后,我们可能会发现自己面对一屏幕可怕的nan结果。

尽管我们要计算指数函数,但我们最终在计算交叉熵损失时会取它们的对数。 通过将softmax和交叉熵结合在一起,可以避免反向传播过程中可能会困扰我们的数值稳定性问题。如下面的等式所示,我们避免计算exp(oj)\exp(o_j),而可以直接使用ojo_j。因为log(exp())\log(\exp(\cdot))被抵消了。

log(y^j)=log(exp(oj)kexp(ok))=log(exp(oj))log(kexp(ok))=ojlog(kexp(ok)).\begin{aligned} \log{(\hat y_j)} & = \log\left( \frac{\exp(o_j)}{\sum_k \exp(o_k)}\right) \\ & = \log{(\exp(o_j))}-\log{\left( \sum_k \exp(o_k) \right)} \\ & = o_j -\log{\left( \sum_k \exp(o_k) \right)}. \end{aligned}
l(y,y^)=j=1qyjlogy^j.l(\mathbf{y}, \hat{\mathbf{y}}) = - \sum_{j=1}^q y_j \log \hat{y}_j.