前言
大家好,我是阿光。
本专栏整理了《PyTorch深度学习项目实战100例》,内包含了各种不同的深度学习项目,包含项目原理以及源码,每一个项目实例都附带有完整的代码+数据集。
正在更新中~ ✨
🚨 我的项目环境:
- 平台:Windows10
- 语言环境:python3.7
- 编译器:PyCharm
- PyTorch版本:1.8.1
💥 项目专栏:【PyTorch深度学习项目实战100例】
一、基于vgg16进行迁移学习服装分类
本项目在研究服装图像分类过程中,借助人工智能算法的优势,开展基于深度神经网络的图像分类实验。借助vgg-16模型进行迁移学习,经迭代100次后分类准确率达到95% 。
在这里插入图片描述
二、数据集介绍
「clothing-dataset」 — 服装数据集,超过5000张图片,20个不同类别。
在这里插入图片描述
「此数据集可自由用于任何目的,包括商业:例如:」
- 创建教程或课程(免费或付费)
- 写一本书
- Kaggle竞赛(作为外部数据集)
- 在任何公司培训内部模型
「数据文件images.csv包括:」
- image - 图像的 ID(使用它从images/.jpg加载图像)
- sender_id - 贡献图像的人的 ID
- label - 图像的类
- kids - 标记为“true”,说明它是孩子们的衣服
「数据集下载地址:」
三、vgg网络介绍
3.1 特点
- 结构简洁。VGG由5层卷积层、3层全连接层、softmax输出层构成,层与层之间使用max-pooling分开,所有隐层的激活单元都采用ReLU函数。
- 小卷积核和多卷积子层。VGG使用多个较小卷积核(3x3)的卷积层代替一个卷积核较大的卷积层,一方面可以减少参数,另一方面相当于进行了更多的非线性映射,可以增加网络的拟合/表达能力。VGG通过降低卷积核的大小(3x3),增加卷积子层数来达到同样的性能。
- 小池化核。相比AlexNet的3x3的池化核,VGG全部采用2x2的池化核。
- 通道数多。VGG网络第一层的通道数为64,后面每层都进行了翻倍,最多到512个通道,通道数的增加,使得更多的信息可以被提取出来。
- 层数更深、特征图更宽。使用连续的小卷积核代替大的卷积核,网络的深度更深,并且对边缘进行填充,卷积的过程并不会降低图像尺寸。
- 全连接转卷积(测试阶段)。在网络测试阶段将训练阶段的三个全连接替换为三个卷积,使得测试得到的全卷积网络因为没有全连接的限制,因而可以接收任意宽或高为的输入。
3.2 网络结构
在这里插入图片描述
- 输入224x224x3的图片,经64个3x3的卷积核作两次卷积+ReLU,卷积后的尺寸变为224x224x64
- 作max pooling(最大化池化),池化单元尺寸为2x2(效果为图像尺寸减半),池化后的尺寸变为112x112x64
- 经128个3x3的卷积核作两次卷积+ReLU,尺寸变为112x112x128
- 作2x2的max pooling池化,尺寸变为56x56x128
- 经256个3x3的卷积核作三次卷积+ReLU,尺寸变为56x56x256
- 作2x2的max pooling池化,尺寸变为28x28x256
- 经512个3x3的卷积核作三次卷积+ReLU,尺寸变为28x28x512
- 作2x2的max pooling池化,尺寸变为14x14x512
- 经512个3x3的卷积核作三次卷积+ReLU,尺寸变为14x14x512
- 作2x2的max pooling池化,尺寸变为7x7x512
- 与两层1x1x4096,一层1x1x1000进行全连接+ReLU(共三层)
- 通过softmax输出1000个预测结果
四、vgg迁移学习
首先需要调用torchvision中的models,通过models调用vgg16模型,如果将参数pretrained参数置为true,则会自动下载vgg16训练好的模型,但是网络较差很难下载完成,所以可以预先去其它网站下载好网络权重,然后将其加载到模型中。
本项目采用的是后者做法。
model_path = './checkpoints/vgg16-397923af.pth'
model = models.vgg16(pretrained=False)
model.load_state_dict(torch.load(model_path, 'cpu'))
pretrained=True即会返回一个预训练好的模型,vgg16在ImageNet上是224*224,之后只要将自己数据集中调整到符合vgg16要求即可了,但是要注意vgg16的输出,因为它是基于ImageNet训练的,所以输出会是一个1000维的向量,所以我们需要修改最后的全连接层来符合我们自己的任务需求,之后就可以使用该模型做预测了,或者进行微调和二次训练。
for parma in model.parameters(): # 设置自动梯度为false
parma.requires_grad = False
由于我们已经加载好了已经训练好的vgg模型,所有模型中的权重不需要重新训练,只需将参数的自动梯度设为false即可,这样现有模型中的权重就会被锁定。
首先我们先输出一下vgg16的网络结构
model._modules
在这里插入图片描述
可以看到最终的分类器在classifier中,关键层是最终的全连接层为model.classifier[6],所以我们需要修改它。
model.classifier[6] = nn.Linear(in_features=4096, out_features=2)
修改后的模型即可看到最终的全连接层输出维度为19,即对应我们服装数据的类别数量。
五、模型训练
❝
train epoch[1/10] loss:4.383: 100%|████████████████████████████████████████████████████| 16/16 [03:06<00:00, 11.66s/it] 【EPOCH: 】1 训练损失为183.66039514541626 训练精度为34.66% train epoch[2/10] loss:6.681: 100%|████████████████████████████████████████████████████| 16/16 [03:05<00:00, 11.59s/it] 【EPOCH: 】2 训练损失为128.7272868156433 训练精度为46.49% train epoch[3/10] loss:12.032: 100%|███████████████████████████████████████████████████| 16/16 [02:54<00:00, 10.93s/it] 【EPOCH: 】3 训练损失为152.00664615631104 训练精度为46.69% train epoch[4/10] loss:7.229: 100%|████████████████████████████████████████████████████| 16/16 [03:07<00:00, 11.71s/it] 【EPOCH: 】4 训练损失为145.03388357162476 训练精度为44.28% train epoch[5/10] loss:12.364: 100%|███████████████████████████████████████████████████| 16/16 [02:59<00:00, 11.25s/it] 【EPOCH: 】5 训练损失为125.43753290176392 训练精度为51.10% train epoch[6/10] loss:7.263: 100%|████████████████████████████████████████████████████| 16/16 [02:48<00:00, 10.51s/it] 【EPOCH: 】6 训练损失为131.48169898986816 训练精度为49.09% train epoch[7/10] loss:13.049: 100%|███████████████████████████████████████████████████| 16/16 [02:54<00:00, 10.91s/it] 【EPOCH: 】7 训练损失为141.09557104110718 训练精度为50.10% train epoch[8/10] loss:4.843: 100%|████████████████████████████████████████████████████| 16/16 [02:55<00:00, 10.96s/it] 【EPOCH: 】8 训练损失为133.9519591331482 训练精度为54.50% train epoch[9/10] loss:13.381: 100%|███████████████████████████████████████████████████| 16/16 [02:42<00:00, 10.18s/it] 【EPOCH: 】9 训练损失为134.1738579273224 训练精度为56.11% train epoch[10/10] loss:7.710: 100%|███████████████████████████████████████████████████| 16/16 [02:57<00:00, 11.09s/it] 【EPOCH: 】10 训练损失为116.83155393600464 训练精度为60.12%
❞
六、完整源码
【PyTorch深度学习项目实战100例】—— 基于vgg16进行迁移学习服装分类 | 第12例_咕 嘟的博客-CSDN博客_pytorchvgg16迁移学习