基于PyTorch的深度学习模型,主要用于图像分类任务,设计数据预处理、数据集定义、半监督学习、模型构建以及训练和验证等部分。
首先导入所需的库:
其中包括torch, torch.nn, torch.utils.data用于构建和训练神经网络。
numpy, random处理数据。
PIL.Image读取图片。
tqdm可以显示进度条,方便观察数据加载进度。
torchvision.transforms处理图像数据。
matplotlib.pyplot用于绘制训练曲线。
接着设定了随机种子,作用是确保每次运行的结果一致,避免训练时的随机性带来的不确定性。
设定图像的尺寸为常用的224 x 224.
然后定义数据增强和预处理的函数,RandomResizedCrop(224),随机裁剪并调整到224x224,RandomRotation(50)随机旋转图片,角度范围为-50°~50°,ToTensor()转化为PyTorch张量。
训练集进行数据增强,验证集仅转化为Tensor,不进行数据增强。
接下来定义数据集类:继承Dataset,用于自定义数据集加载。mode决定数据集的类别,"train"是训练集,带有标签,"val"是验证集,带有标签。"semi"是半监督学习数据集,无标签。
接下来是数据加载,半监督模式下,只读取图片,无标签。
其中半监督数据集的模型,对无标签数据进行预测,并筛选出置信度大于0.99的数据作为伪标签。
下面定义CNN分类模型,Conv2d()为卷积层,将输入通道从3变成64,MaxPool2d()池化,减少计算量。
最后开始训练与验证的函数,训练模型,包括:
train_loader:训练数据
val_loader:验证数据
no_label_loader:半监督数据
loss:损失函数
optimizer:优化器
epochs:训练轮次
通过这个函数进行前向传播、计算损失、反向传播、更新参数。