深度学习的分类任务细节

47 阅读2分钟

基于PyTorch的深度学习模型,主要用于图像分类任务,设计数据预处理、数据集定义、半监督学习、模型构建以及训练和验证等部分。

首先导入所需的库:

其中包括torch, torch.nn, torch.utils.data用于构建和训练神经网络。

numpy, random处理数据。

PIL.Image读取图片。

tqdm可以显示进度条,方便观察数据加载进度。

torchvision.transforms处理图像数据。

matplotlib.pyplot用于绘制训练曲线。

接着设定了随机种子,作用是确保每次运行的结果一致,避免训练时的随机性带来的不确定性。

设定图像的尺寸为常用的224 x 224.

然后定义数据增强和预处理的函数,RandomResizedCrop(224),随机裁剪并调整到224x224,RandomRotation(50)随机旋转图片,角度范围为-50°~50°,ToTensor()转化为PyTorch张量。

训练集进行数据增强,验证集仅转化为Tensor,不进行数据增强。

接下来定义数据集类:继承Dataset,用于自定义数据集加载。mode决定数据集的类别,"train"是训练集,带有标签,"val"是验证集,带有标签。"semi"是半监督学习数据集,无标签。

接下来是数据加载,半监督模式下,只读取图片,无标签。

其中半监督数据集的模型,对无标签数据进行预测,并筛选出置信度大于0.99的数据作为伪标签。

下面定义CNN分类模型,Conv2d()为卷积层,将输入通道从3变成64,MaxPool2d()池化,减少计算量。

最后开始训练与验证的函数,训练模型,包括:

train_loader:训练数据

val_loader:验证数据

no_label_loader:半监督数据

loss:损失函数

optimizer:优化器

epochs:训练轮次

通过这个函数进行前向传播、计算损失、反向传播、更新参数。