09 softmax/01 Data loading


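Softmax regression

Mean squared loss

  • One-hot encode the classes
  • Train with the mean squared loss
  • Predict the class with the largest output: $\hat{y} = \arg\max_i o_i$
  • We want the correct class to be recognized with more confidence, i.e. its score should beat every other class by a large margin: $o_y - o_i \geq \delta(y, i)$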
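
The mean-squared-loss recipe above can be sketched in plain Python. This is a minimal illustration that assumes a single example with hypothetical scores `o`. The helper names (`one_hot`, `mse_loss`, `predict`, `satisfies_margin`) are made up here, and the margin check uses a constant `delta` instead of a general $\delta(y, i)$.

```python
def one_hot(label, num_classes):
    """Return the one-hot encoding of an integer class label as a list."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def mse_loss(o, y):
    """Mean squared loss between output scores o and a one-hot target y."""
    return sum((oi - yi) ** 2 for oi, yi in zip(o, y)) / len(o)


def predict(o):
    """The predicted class is the index of the largest output score."""
    return max(range(len(o)), key=lambda i: o[i])


def satisfies_margin(o, label, delta):
    """Check o_y - o_i >= delta for every incorrect class i.

    Assumption: delta is a constant here; in general it may depend on (y, i).
    """
    return all(o[label] - o[i] >= delta for i in range(len(o)) if i != label)


o = [0.2, 2.5, 0.9]  # hypothetical output scores for 3 classes
y = one_hot(1, 3)    # the true class is 1
print(y)             # [0.0, 1.0, 0.0]
print(predict(o))    # 1
print(satisfies_margin(o, 1, delta=1.0))  # True: 2.5 - 0.2 = 2.3 and 2.5 - 0.9 = 1.6
print(satisfies_margin(o, 1, delta=2.0))  # False: 1.6 < 2.0
print(mse_loss(o, y))
```

With a fixed `delta`, the margin condition only asks that the correct score beat each competitor by at least that amount. The mean squared loss, in contrast, also pushes the correct score toward exactly 1.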

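Calibrated scale

  • Make the outputs match probabilities (non-negative and summing to one): $\hat{\mathbf{y}} = \mathrm{softmax}(\mathbf{o})$, where $\hat{y}_i = \frac{\exp(o_i)}{\sum_k \exp(o_k)}$
  • Use the difference between the probability vectors $\mathbf{y}$ and $\hat{\mathbf{y}}$ as the loss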
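
The softmax formula above can be written directly in Python. This is a minimal sketch. Subtracting the maximum score before exponentiating is a standard numerical-stability trick that is not part of the notes. It leaves the result unchanged because the factor $\exp(-\max_k o_k)$ cancels between numerator and denominator.

```python
import math


def softmax(o):
    """Map scores o to probabilities exp(o_i) / sum_k exp(o_k).

    Subtracting max(o) first avoids overflow in exp for large scores
    without changing the result.
    """
    m = max(o)
    exps = [math.exp(oi - m) for oi in o]
    total = sum(exps)
    return [e / total for e in exps]


y_hat = softmax([1.0, 2.0, 3.0])
print([round(p, 4) for p in y_hat])  # [0.09, 0.2447, 0.6652]
print(sum(y_hat))                    # 1 up to floating-point rounding
print(softmax([1000.0, 1000.0]))     # [0.5, 0.5], no overflow
```

Softmax preserves the ordering of the scores, so the predicted class (the argmax) is the same as before. Only the scale changes to a valid probability distribution.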

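Cross-entropy loss

  • With a one-hot label, cross-entropy depends only on the predicted probability of the correct class (how confident we are in it), not on how the remaining probability is spread across the incorrect classes
  • Used as the loss: $l(\mathbf{y}, \hat{\mathbf{y}}) = -\sum_i y_i \log \hat{y}_i = -\log \hat{y}_y$, where the last step uses the fact that only $y_y = 1$ in the one-hot $\mathbf{y}$
  • Its gradient with respect to the scores is the difference between the predicted and true probabilities: $\partial_{o_i} l(\mathbf{y}, \hat{\mathbf{y}}) = \mathrm{softmax}(\mathbf{o})_i - y_i$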
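
The cross-entropy loss and its gradient can be checked numerically. Below is a minimal plain-Python sketch, assuming a single example with hypothetical scores:

- `analytic_grad` implements $\mathrm{softmax}(\mathbf{o})_i - y_i$.
- `numeric_grad` estimates the same derivative with central finite differences, so the two should agree closely.
- `ce_from_probs` shows that moving probability mass between incorrect classes leaves the loss unchanged.

All helper names are illustrative.

```python
import math


def softmax(o):
    """Numerically stable softmax: exp(o_i - max) / sum_k exp(o_k - max)."""
    m = max(o)
    exps = [math.exp(oi - m) for oi in o]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(o, label):
    """Cross-entropy -log(y_hat_y) for a one-hot target at index `label`."""
    return -math.log(softmax(o)[label])


def ce_from_probs(y_hat, y):
    """-sum_i y_i log(y_hat_i), skipping terms where y_i = 0."""
    return -sum(yi * math.log(pi) for yi, pi in zip(y, y_hat) if yi > 0)


def analytic_grad(o, label):
    """Gradient of the loss w.r.t. o_i: softmax(o)_i - y_i with one-hot y."""
    y_hat = softmax(o)
    return [p - (1.0 if i == label else 0.0) for i, p in enumerate(y_hat)]


def numeric_grad(o, label, eps=1e-6):
    """Central finite-difference estimate of the same gradient, for checking."""
    grad = []
    for i in range(len(o)):
        plus, minus = list(o), list(o)
        plus[i] += eps
        minus[i] -= eps
        diff = cross_entropy(plus, label) - cross_entropy(minus, label)
        grad.append(diff / (2 * eps))
    return grad


o = [0.5, 1.5, -0.3]  # hypothetical scores
label = 1
print(cross_entropy(o, label))
print(analytic_grad(o, label))
print(numeric_grad(o, label))  # agrees with the analytic gradient to many decimals

# Same probability on the correct class, different split among wrong classes
print(ce_from_probs([0.1, 0.7, 0.2], [0, 1, 0]) == ce_from_probs([0.15, 0.7, 0.15], [0, 1, 0]))  # True
```

The analytic gradient components sum to zero, since both $\mathrm{softmax}(\mathbf{o})$ and $\mathbf{y}$ sum to one.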

Loss functions

  • L2 loss: $l(y, y') = \frac{1}{2}(y - y')^2$
  • L1 loss: $l(y, y') = |y - y'|$
  • Huber's robust loss: $l(y, y') = |y - y'| - \frac{1}{2}$ if $|y - y'| > 1$, otherwise $l(y, y') = \frac{1}{2}(y - y')^2$
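
The three losses can be compared side by side. This is a minimal sketch that assumes scalar targets and uses a threshold of 1 for Huber's loss. That threshold is an assumption here, since other values are also used in practice.

```python
def l2_loss(y, y_pred):
    """L2 loss: 0.5 * (y - y')^2."""
    return 0.5 * (y - y_pred) ** 2


def l1_loss(y, y_pred):
    """L1 loss: |y - y'|."""
    return abs(y - y_pred)


def huber_loss(y, y_pred):
    """Huber's robust loss with threshold 1: quadratic near 0, linear beyond 1."""
    diff = abs(y - y_pred)
    return diff - 0.5 if diff > 1 else 0.5 * diff ** 2


# Compare the three losses for a true value y = 0
for y_pred in [0.0, 0.5, 1.0, 3.0]:
    print(y_pred, l2_loss(0.0, y_pred), l1_loss(0.0, y_pred), huber_loss(0.0, y_pred))
```

The derivative of the Huber loss with respect to $y'$ equals that of the L2 loss when $|y - y'| \leq 1$ and is $\pm 1$ beyond that. Its gradient is therefore clipped to $[-1, 1]$, which makes it less sensitive to outliers than the L2 loss while staying smooth near zero, unlike the L1 loss.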

# Image classification dataset

```python
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

# ToTensor converts each PIL image to a float32 tensor of shape
# (channels, height, width) with pixel values scaled to [0, 1]
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True
)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True
)

len(mnist_train), len(mnist_test)
# (60000, 10000)

# Each example is an (image, label) pair; images are single-channel 28x28
mnist_train[0][0].shape
# torch.Size([1, 28, 28])
```
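```python
def get_fashion_mnist_labels(labels):
    """Map numeric Fashion-MNIST labels (0-9) to their text names."""
    text_labels = [
        "t-shirt", "trouser", "pullover", "dress", "coat",
        "sandal", "shirt", "sneaker", "bag", "ankle boot",
    ]
    return [text_labels[int(i)] for i in labels]


# Read the first 18 training examples as one batch:
# X has shape (18, 1, 28, 28) and y has shape (18,)
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
get_fashion_mnist_labels(y)
```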
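
`DataLoader` stacks individual `(image, label)` examples into batch tensors. The sketch below runs on a small synthetic dataset, an assumption made so that nothing is downloaded. It shows the resulting batch shapes and what happens to the last, incomplete batch.

```python
import torch
from torch.utils import data

# Hypothetical stand-in for Fashion-MNIST: 100 blank 1x28x28 images, labels cycling 0..9
features = torch.zeros(100, 1, 28, 28)
labels = torch.arange(100) % 10
dataset = data.TensorDataset(features, labels)

loader = data.DataLoader(dataset, batch_size=18)  # shuffle defaults to False
X, y = next(iter(loader))
print(X.shape)  # torch.Size([18, 1, 28, 28])
print(y.shape)  # torch.Size([18])

# Without drop_last=True, the final batch holds the remainder: 100 = 5 * 18 + 10
batch_sizes = [len(yb) for _, yb in loader]
print(batch_sizes)  # [18, 18, 18, 18, 18, 10]
```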

    
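```python
batch_size = 4


def get_dataloader_workers():
    """Use 4 worker processes to read the data."""
    return 4


# shuffle=True reorders the training examples in every epoch
train_iter = data.DataLoader(
    mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()
)

# Time one full pass over the training set
timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'
# '11.20 sec' (the timing depends on the hardware)
```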
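
Per-batch overhead grows with the number of batches. That overhead includes collating examples and, with worker processes, passing each batch back to the main process, so small batches such as `batch_size = 4` usually make a full pass slower. The rough sketch below compares batch sizes. It assumes a synthetic in-memory dataset and `num_workers=0`, a single process that avoids multiprocessing start-up issues when run as a plain script. `time_one_epoch` is a hypothetical helper standing in for `d2l.Timer`.

```python
import time

import torch
from torch.utils import data


def time_one_epoch(loader):
    """Return the wall-clock seconds needed to iterate once over `loader`."""
    start = time.perf_counter()
    for _ in loader:
        pass
    return time.perf_counter() - start


# Hypothetical in-memory dataset; real timings depend on hardware, transforms and num_workers
dataset = data.TensorDataset(torch.randn(5000, 1, 28, 28), torch.randint(0, 10, (5000,)))

for batch_size in [4, 64, 256]:
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    print(f"batch_size={batch_size}: {time_one_epoch(loader):.2f} sec")
```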

```python
def load_data_fashion_mnist(batch_size, resize=None):
    """Download Fashion-MNIST and return the training and test data iterators.

    If `resize` is given, images are resized before being converted to tensors.
    """
    trans = [transforms.ToTensor()]
    if resize:
        # Put Resize first so it is applied before ToTensor
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    # download=True is a no-op when the files are already present in ../data
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True
    )
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
```