携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第9天,点击查看活动详情
一:Pytorch实现经典模型AlexNet模型
要求:
使用pytorch实现经典的分类模型AlexNet,这里主要因为没有GPU环境,而其完整参数达到了6000万个,所以如ppt要求,在该模型的基础架构上,修改卷积核的大小以及卷积操作的步长等来模拟实现。
实验设计:
实验过程: 注:这里主要介绍一下AlexNet模型的定义,其中因为参数量过大,以及图片的输入大小变为了64*64,所以对于每层的卷积核大小以及步长等做了相关变化。 1.1AlexNet模型定义
1. # 定义神经网络
2. class ALexNet(nn.Module): # 训练 ALexNet
3. '''''
4. 五层卷积,三层全连接 (输入图片大小是 C x H x W ---> 3 * 64 * 64)
5. 这里因为图片大小是64*64,所以这里重新改变了各层的步长、卷积核大小等
6. '''
7. def __init__(self):
8. super(ALexNet, self).__init__()
9. # 五个卷积层
10. self.conv1 = nn.Sequential( # 输入 3 * 64 * 64 输出 6*16*16
11. nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1), # (64-3+2)/1+1 = 64
12. nn.ReLU(),
13. nn.MaxPool2d(kernel_size=4, stride=4, padding=0) # (64-4)/4+1 = 16
14. )
15. self.conv2 = nn.Sequential( # 输入 6 * 16 * 16 输出 16*8*8
16. nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3, stride=1, padding=1), # (16-3+2)/1+1 = 16
17. nn.ReLU(),
18. nn.MaxPool2d(kernel_size=2, stride=2, padding=0) # (16-2)/2+1 = 8
19. )
20. self.conv3 = nn.Sequential( # 输入 16 * 8 * 8 输出 32 * 8 * 8
21. nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1), # (8-3+2)/1+1 = 8
22. nn.ReLU()
23. )
24. self.conv4 = nn.Sequential( # 输入 32 * 8 * 8 输出 64 * 8 * 8
25. nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), # (8-3+2)/1+1 = 8
26. nn.ReLU()
27. )
28. self.conv5 = nn.Sequential( # 输入 64 * 8 * 8 输出 128 * 1 * 1
29. nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),# (8-3+2)/1+1 = 8
30. nn.ReLU(),
31. nn.MaxPool2d(kernel_size=4, stride=4, padding=0) # (8-4)/4 + 1 = 2
32. )
33. # 最后一层卷积层,输出 128 * 2 * 2
34. # 全连接层
35. self.dense = nn.Sequential(
36. nn.Linear(512, 120),
37. nn.ReLU(),
38. nn.Linear(120, 84),
39. nn.ReLU(),
40. nn.Linear(84, 3)
41. )
42.
43. def forward(self, x):
44. x = self.conv1(x)
45. x = self.conv2(x)
46. x = self.conv3(x)
47. x = self.conv4(x)
48. x = self.conv5(x)
49. x = x.view(-1, 512)
50. x = self.dense(x)
51. return x
注:主要包括5层卷积层和3层全连接层,其卷积层的卷积核的大小、步长等。
完整代码及数据集下载,见:download.csdn.net/download/qq…