CIFAR-10 图像分类(结合 Wandb 包可视化训练过程)

44 阅读10分钟

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb
import torch.nn.functional as F  # 导入 torch.nn.functional

print("PyTorch Version: ", torch.__version__)
print("CUDA Version: ", torch.version.cuda)
print("Is CUDA available: ", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA Device Name: ", torch.cuda.get_device_name(0))

# 使用API key进行自动登录
wandb.login(key="your_wandb_api_key")

# 初始化Wandb
# 注意这里your_wandb_username是organization名字而不是username
wandb.init(project="cifar10-classification", entity="your_wandb_username")


# 数据预处理和加载
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# 定义模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = SimpleCNN()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练函数
def train(epoch):
    net.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:  # 每100个批次打印一次日志
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            wandb.log({"Training Loss": running_loss / 100})
            running_loss = 0.0

# 测试函数
def test():
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f} %')
    wandb.log({"Test Accuracy": accuracy})
    

if __name__ == '__main__':
    # 训练模型
    for epoch in range(10):  # 训练10个epoch
        train(epoch)
        test()

    print('Finished Training')

笔记

  1. torchvision.transforms 提供了许多图像预处理属性,常见的包括:
  • Resize:调整图像大小。
  • CenterCrop:从图像中心裁剪。
  • RandomCrop:随机裁剪图像。
  • RandomHorizontalFlip:随机水平翻转图像。
  • ToTensor:将PIL Image或numpy.ndarray转换为Tensor。
  • Normalize:标准化图像像素值。

这些预处理技术常见应用于图像分类任务中。例如,对于CIFAR-10图像分类:

  • 随机水平翻转:增加数据多样性,提高模型的泛化能力。
  • 随机裁剪:引入数据增强,通过随机裁剪增加训练数据的多样性,减少过拟合风险。
  • 张量转换:将图像数据转换为模型能够处理的张量格式。
  • 归一化:将图像像素值标准化,加速模型收敛,提高训练效果和稳定性。

transforms.Compose 是一个函数,用于将多个图像转换操作组合在一起,形成一个变换管道,便于对图像进行连续处理。

# 数据预处理和加载
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

padding=4,填充4像素,如果原始图像的尺寸小于所需的裁剪尺寸(例如这里的32x32),则需要进行填充。填充的目的是确保裁剪的区域足够大,从而保持裁剪后图像的尺寸一致性。

  1. torch.nn.functional 是PyTorch中的一个功能模块,提供了神经网络中的各种功能函数,如激活函数、损失函数、卷积、池化等。它提供了一些基本的函数,可以直接随时调用而不需要将其绑定到某个特定的层对象上。 适合作为一些不需要学习参数的操作的函数形式使用。torch.nn.Module 是所有神经网络模块的基类,所有的神经网络层(如卷积层、全连接层等)都需要继承自这个类。它包含了定义网络层和前向传播的方法,能够自动管理模型的参数和状态
  • torch.nn.Module有状态的,每个层(如卷积层、全连接层)都包含自己的参数(如权重和偏置),这些参数会在训练过程中进行更新。
  • torch.nn.functional无状态的,它仅仅是提供了一些功能函数,函数本身不包含任何参数。通俗来说,无状态就是这些函数不会维护自己的状态信息,调用这些函数时需要手动传入所有的参数和数据。
  1. 在深度学习中,你可能会经常遇到以下英文词汇:
  • Batch:批次,一次训练中使用的数据批次。
  • Optimizer:优化器,用于更新模型参数以减小损失函数值。
  • Regularization:正则化,用于防止过拟合的技术。
  • Epoch:周期,完整训练数据集的一次传递。
  • Pooling:池化操作,如最大池化和平均池化。
  • Padding:填充,通常在卷积操作中使用。
  • Dropout:随机失活,用于减少过拟合。
  • Batch Normalization:批标准化,用于加速收敛和稳定性。
  • shuffle:随机播放,洗牌。
  • numpy: 用于科学计算的库
  • Conv2d: 二维卷积层
  • MaxPool2d: 二维最大池化层
  • Linear: 全连接层
  • CrossEntropyLoss: 交叉熵损失
  • transform: 预处理操作
  • compose: 组合多个操作
  • normalize: 归一化
  • backward: 反向传播
  1. transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))实际上执行的是均值/标准差的操作,其中,(0.4914, 0.4822, 0.4465) 是均值,(0.2023, 0.1994, 0.2010) 是标准差。
  2. 训练集中需要保持数据随机性的原因是为了增加模型训练的多样性和泛化能力。测试集中,为了能够评估模型在真实场景下的性能,需要保持数据的真实顺序。

shuffle=True,训练的随机性;shuffle=False,保持数据顺序(英语重要性shuffle 随机播放,洗牌)。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
  1. 除了交叉熵损失函数(CrossEntropyLoss),常用的损失函数还包括:
  • 均方误差损失(Mean Squared Error, MSE) :适用于回归问题,如预测房价。
  • 交叉熵损失函数(CrossEntropyLoss):用于分类问题,尤其是多类分类
  • 二进制交叉熵损失(Binary Cross Entropy Loss) :适用于二分类问题。
  • 多标签Soft Margin损失(MultiLabelSoftMarginLoss) :适用于多标签分类问题。
  • K-L散度损失(Kullback-Leibler Divergence Loss) :用于衡量两个概率分布之间的差异。
  1. 常用的优化算法包括:
  • SGD (Stochastic Gradient Descent) :随机梯度下降,适用于大多数深度学习任务。
  • Adam:结合了Adaptive Moment Estimation和RMSProp的优点,适用于广泛的深度学习任务。
  • RMSprop:自适应学习率优化方法,适用于非平稳目标和RNN训练。
  • Adagrad:自适应地为各个参数分配不同的学习率,适用于稀疏数据集。

optim.SGD() 是PyTorch中SGD优化器的函数接口,常用参数包括:

  • lr:学习率,控制参数更新的步长。

  • momentum:动量,用于加速SGD在相关方向上的更新。通常设置为较大的值(如0.9),以在训练过程中保持良好的学习方向和速度,从而加速收敛。动量参数的取值范围通常在[0,1)之间。常见的取值是 0.9。

  • weight_decay:权重衰减(L2惩罚),用于防止过拟合。是用于正则化的一种方法。它通过在损失函数中加入权重的 L2 范数来防止模型过拟合。通俗来说,它在更新权重时,会稍微减小权重值,从而避免权重过大导致模型对训练数据的过拟合。

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

net.parameters() 表示获取模型中所有需要学习的参数,这些参数在优化过程中会被更新。 8. net.train()和net.eval()是分别将模型设置为训练模式和评估模式。在特定模式下,模型中的一些特定层(如Dropout和Batch Normalization)将表现出不同的行为,比如,评估模式下,eval() 方法会关闭这些层的训练特性。

def train(epoch):
    net.train()
    ...
    
def test():
    net.eval()
    ...
  1. Python中常见的数据输出格式:
  • f'' 是Python中的格式化字符串,用于在字符串中嵌入变量或表达式的值。

  • :.3f 是一种格式化字符串的方式,表示浮点数的输出精度为小数点后三位。

  1. torch.no_grad() 是一个上下文管理器,它将所有在其内部的操作都禁用梯度计算。这在评估模型时特别有用,因为在评估时不需要计算梯度,而只是需要前向传播以获取模型的预测结果。
  • _ 在这里是一个惯例,用来表示一个不需要的变量。在Python中,通常使用 _占位表示忽略某些不需要的值

  • torch.max(outputs.data, 1) 是一个PyTorch函数,用于沿着指定维度(这里是维度1)找到张量中的最大值及其索引

  • labels.size(0) 是PyTorch张量的方法,用于返回张量在第一个维度上的大小(通常是批次大小)。在深度学习中,经常使用 size(0) 来获取批次数据的大小。

  • (predicted == labels) 是一个逐元素比较操作,返回一个布尔张量,表示预测值与真实标签是否相等。

  • .sum() 是对布尔张量进行求和操作,统计预测正确的样本数。

  • .item() 是将张量转换为Python标量,返回正确预测的样本数。

with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total

这段代码用于统计在一个批次中预测正确的样本数,并将这个数量累加到 correct 变量中。

 _, predicted = torch.max(outputs.data, 1)

上面这行代码中,可以理解为卷积神经网络测试函数中,输出层张量中的最大值即为预测数据

  1. super(SimpleCNN, self).__init__() 的意思是调用 SimpleCNN 的父类(即 nn.Module)的初始化方法。这行代码的主要目的是确保 SimpleCNN 类能够继承 nn.Module 类的所有属性和方法,并且正确初始化。
  • super() 函数是用于调用父类的方法。在这里,它用于调用 nn.Module 的初始化方法。

  • SimpleCNN 是子类,nn.Module 是父类。

  • self实例自身

  • .__init__(): 这是调用父类的初始化方法。

  • def __init__(self):def forward(self, x):类的构造函数,在实例化对象时自动调用

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = SimpleCNN()

x:表示输入数据,是一个形状为 (batch_size, C, H, W) 的张量。

  1. enumerate() 函数用于将一个可迭代对象(如列表)组合为一个索引序列。常用属性有:
  • start:指定索引的起始值。
list_data = ['a', 'b', 'c']
for index, value in enumerate(list_data, start=1):
    print(index, value)
    
# 输出
# 1 a
# 2 b
# 3 c
  1. Python 中字符串格式化
  • % 格式化字符串:
Copy code
name = 'Alice'
age = 30
print('Name: %s, Age: %d' % (name, age))

# 输出
# Name: Alice, Age: 30
  • str.format() 方法:
Copy code
name = 'Alice'
age = 30
print('Name: {}, Age: {}'.format(name, age))

# 输出
# Name: Alice, Age: 30
  • 格式化字面值字符串 (f-string):
Copy code
name = 'Alice'
age = 30
print(f'Name: {name}, Age: {age}')

# 输出
# Name: Alice, Age: 30

一些报错

  • 使用if __name__ == '__main__': 语句来保护代码入口。Windows上使用多进程时将数据加载部分和训练部分包裹在if __name__ == '__main__':语句下
#原代码
for epoch in range(10):  # 训练10个epoch
    train(epoch)
    test()
print('Finished Training')

#错误
RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
  
  
#解决
#这个错误是由于在Windows系统上多进程处理数据加载时没有使用正确的`if __name__ == '__main__':` 
#语句来保护代码入口。这是Windows上常见的问题,因为它不支持Unix风格的fork机制。
#需要确保在Windows上使用多进程时将数据加载部分和训练部分包裹在`if __name__ == '__main__':`语句下。
if __name__ == '__main__':
    # 训练模型
    for epoch in range(10):  # 训练10个epoch
        train(epoch)
        test()
    print('Finished Training')

结果

W&B Chart 11_06_2024, 15_37_14.png

W&B Chart 11_06_2024, 15_37_48.png