Parsing the CIFAR-10 Dataset with Convolutional Neural Networks in Paddle
An Analysis of Convolutional Neural Networks
This project walks through several landmark convolutional neural networks, using the CIFAR-10 dataset throughout.
The original project is by Mr. Lu Ping. The commentary proceeds top-down: the earlier sections are annotated in detail, while later sections may leave repeated patterns unannotated.
If you have any questions, feel free to leave a comment or message me directly.
Case 1: The AlexNet network
The AlexNet model was developed by Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton, and won the 2012 ImageNet challenge. Compared with LeNet, AlexNet has more layers, uses ReLU activation layers, and introduces Dropout in the fully connected layers to prevent overfitting. The model is analyzed in detail below.
Link to the original project
import paddle
import paddle.nn.functional as F
import numpy as np
from paddle.vision.transforms import Compose, Resize, Transpose, Normalize
# Prepare the data
t = Compose([Resize(size=227), Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], data_format='HWC'), Transpose()])  # resize to 227, scale pixels to [-1, 1], then HWC -> CHW
cifar10_train = paddle.vision.datasets.cifar.Cifar10(mode='train', transform=t, backend='cv2')
cifar10_test = paddle.vision.datasets.cifar.Cifar10(mode="test", transform=t, backend='cv2')
Cache file /home/aistudio/.cache/paddle/dataset/cifar/cifar-10-python.tar.gz not found, downloading https://dataset.bj.bcebos.com/cifar/cifar-10-python.tar.gz
Begin to download
Download finished
for i in cifar10_train:
    print(type(i))
    print(i[0].shape)
    print(i)
    break
<class 'tuple'>
(3, 227, 227)
(array([[[ 0.39607844,  0.39607844,  0.39607844, ...,  0.29411766,
           0.29411766,  0.29411766],
         ...,
         [-0.2784314 , -0.2784314 , -0.2784314 , ..., -0.39607844,
          -0.39607844, -0.39607844]]], dtype=float32), array(0))
(output truncated: each sample is a normalized 3x227x227 float32 image array plus an integer label)
Analyzing the AlexNet convolutional network
The convolution operation: Conv2D
paddle.nn.Conv2D
Relevant parameters:
in_channels (int) - number of channels of the input image.
out_channels (int) - number of channels produced by the convolution.
kernel_size (int|list|tuple) - size of the convolution kernel. Either a single integer, or a tuple/list of two integers giving the kernel height and width; a single integer means the height and width are both that value.
stride (int|list|tuple, optional) - stride. Either a single integer, or a tuple/list of two integers giving the stride along the height and width; a single integer means both strides equal that value. Default: 1.
padding (int|list|tuple|str, optional) - padding size.
Take paddle.nn.Conv2D(3, 96, 11, 4, 0) as an example:
3: number of input channels
96: number of output channels
11: kernel size (fw = fh = 11)
4: stride (s = 4)
0: padding (p = 0)
Input size: 3 * 227 * 227 (xw = xh = 227)
Applying the output-size formula:
new_w = (227 + 2*0 - 11)/4 + 1 = 55
new_h is likewise 55
Output size = output channels * new_w * new_h = 96 * 55 * 55
Pooling
paddle.nn.MaxPool2D (max pooling)
Key parameters:
kernel_size (int|list|tuple): pooling kernel size. If a tuple or list, it must contain two integers, (pool_size_Height, pool_size_Width). If a single integer, the kernel is square; e.g. pool_size=2 gives a 2x2 kernel.
stride (int|list|tuple): pooling stride. If a tuple or list, it contains two integers, (pool_stride_Height, pool_stride_Width). If a single integer, the stride along H and W both equal that value. Defaults to kernel_size.
padding (str|int|list|tuple): pooling padding.
Output size: w = h = (55 + 2*0 - 3)/2 + 1 = 27
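Both formulas above are the same floor-division rule, so they can be checked with one small helper. This is a plain-Python sketch; the function name conv_out_size is mine, not a Paddle API:

```python
def conv_out_size(x, kernel, stride=1, padding=0):
    """Output spatial size of a conv or pooling layer (floor division)."""
    return (x + 2 * padding - kernel) // stride + 1

# Conv2D(3, 96, 11, 4, 0) applied to a 227x227 input
print(conv_out_size(227, kernel=11, stride=4, padding=0))  # 55
# MaxPool2D(kernel_size=3, stride=2) applied to the 55x55 feature map
print(conv_out_size(55, kernel=3, stride=2))               # 27
```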
# Build the model
class AlexNetModel(paddle.nn.Layer):
    def __init__(self):
        super(AlexNetModel, self).__init__()
        self.conv_pool1 = paddle.nn.Sequential(            # input m*3*227*227
            paddle.nn.Conv2D(3, 96, 11, 4, 0),             # L1, output m*96*55*55
            paddle.nn.ReLU(),                              # L2, output m*96*55*55
            paddle.nn.MaxPool2D(kernel_size=3, stride=2))  # L3, output m*96*27*27
        self.conv_pool2 = paddle.nn.Sequential(
            paddle.nn.Conv2D(96, 256, 5, 1, 2),            # L4, output m*256*27*27
            paddle.nn.ReLU(),                              # L5, output m*256*27*27
            paddle.nn.MaxPool2D(3, 2))                     # L6, output m*256*13*13
        self.conv_pool3 = paddle.nn.Sequential(
            paddle.nn.Conv2D(256, 384, 3, 1, 1),           # L7, output m*384*13*13
            paddle.nn.ReLU())                              # L8, output m*384*13*13
        self.conv_pool4 = paddle.nn.Sequential(
            paddle.nn.Conv2D(384, 384, 3, 1, 1),           # L9, output m*384*13*13
            paddle.nn.ReLU())                              # L10, output m*384*13*13
        self.conv_pool5 = paddle.nn.Sequential(
            paddle.nn.Conv2D(384, 256, 3, 1, 1),           # L11, output m*256*13*13
            paddle.nn.ReLU(),                              # L12, output m*256*13*13
            paddle.nn.MaxPool2D(3, 2))                     # L13, output m*256*6*6
        self.full_conn = paddle.nn.Sequential(
            paddle.nn.Linear(256*6*6, 4096),               # L14, output m*4096
            paddle.nn.ReLU(),                              # L15, output m*4096
            paddle.nn.Dropout(0.5),                        # L16, output m*4096
            paddle.nn.Linear(4096, 4096),                  # L17, output m*4096
            paddle.nn.ReLU(),                              # L18, output m*4096
            paddle.nn.Dropout(0.5),                        # L19, output m*4096
            paddle.nn.Linear(4096, 10))                    # L20, output m*10
        self.flatten = paddle.nn.Flatten()

    def forward(self, x):  # forward pass
        x = self.conv_pool1(x)
        x = self.conv_pool2(x)
        x = self.conv_pool3(x)
        x = self.conv_pool4(x)
        x = self.conv_pool5(x)
        x = self.flatten(x)
        x = self.full_conn(x)
        return x
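The per-layer size comments in AlexNetModel can be verified without running Paddle at all, by applying the output-size formula layer by layer. A minimal sketch (the helper name out_size is mine):

```python
def out_size(x, k, s=1, p=0):
    return (x + 2 * p - k) // s + 1

s = 227
s = out_size(s, 11, 4, 0)  # conv_pool1 conv -> 55
s = out_size(s, 3, 2)      # conv_pool1 pool -> 27
s = out_size(s, 5, 1, 2)   # conv_pool2 conv -> 27
s = out_size(s, 3, 2)      # conv_pool2 pool -> 13
s = out_size(s, 3, 1, 1)   # conv_pool3      -> 13
s = out_size(s, 3, 1, 1)   # conv_pool4      -> 13
s = out_size(s, 3, 1, 1)   # conv_pool5 conv -> 13
s = out_size(s, 3, 2)      # conv_pool5 pool -> 6
print(256 * s * s)         # 9216, the Flatten size feeding Linear(256*6*6, 4096)
```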
epoch_num = 2
batch_size = 256
learning_rate = 0.0001
val_acc_history = []
val_loss_history = []
model = AlexNetModel()
paddle.summary(model,(1,3,227,227))
---------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===========================================================================
Conv2D-1 [[1, 3, 227, 227]] [1, 96, 55, 55] 34,944
ReLU-1 [[1, 96, 55, 55]] [1, 96, 55, 55] 0
MaxPool2D-1 [[1, 96, 55, 55]] [1, 96, 27, 27] 0
Conv2D-2 [[1, 96, 27, 27]] [1, 256, 27, 27] 614,656
ReLU-2 [[1, 256, 27, 27]] [1, 256, 27, 27] 0
MaxPool2D-2 [[1, 256, 27, 27]] [1, 256, 13, 13] 0
Conv2D-3 [[1, 256, 13, 13]] [1, 384, 13, 13] 885,120
ReLU-3 [[1, 384, 13, 13]] [1, 384, 13, 13] 0
Conv2D-4 [[1, 384, 13, 13]] [1, 384, 13, 13] 1,327,488
ReLU-4 [[1, 384, 13, 13]] [1, 384, 13, 13] 0
Conv2D-5 [[1, 384, 13, 13]] [1, 256, 13, 13] 884,992
ReLU-5 [[1, 256, 13, 13]] [1, 256, 13, 13] 0
MaxPool2D-3 [[1, 256, 13, 13]] [1, 256, 6, 6] 0
Flatten-1 [[1, 256, 6, 6]] [1, 9216] 0
Linear-1 [[1, 9216]] [1, 4096] 37,752,832
ReLU-6 [[1, 4096]] [1, 4096] 0
Dropout-1 [[1, 4096]] [1, 4096] 0
Linear-2 [[1, 4096]] [1, 4096] 16,781,312
ReLU-7 [[1, 4096]] [1, 4096] 0
Dropout-2 [[1, 4096]] [1, 4096] 0
Linear-3 [[1, 4096]] [1, 10] 40,970
===========================================================================
Total params: 58,322,314
Trainable params: 58,322,314
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.59
Forward/backward pass size (MB): 11.11
Params size (MB): 222.48
Estimated Total Size (MB): 234.18
---------------------------------------------------------------------------
{'total_params': 58322314, 'trainable_params': 58322314}
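The parameter counts in the summary can be reproduced by hand: a conv layer holds (in_channels * k^2 + 1) * out_channels parameters (the +1 is the bias), and a linear layer holds in*out + out. A quick plain-Python check (the helper names are mine):

```python
def conv_params(c_in, c_out, k):
    # each output channel has a c_in*k*k kernel plus one bias
    return (c_in * k * k + 1) * c_out

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

total = (conv_params(3, 96, 11)       # 34,944
         + conv_params(96, 256, 5)    # 614,656
         + conv_params(256, 384, 3)   # 885,120
         + conv_params(384, 384, 3)   # 1,327,488
         + conv_params(384, 256, 3)   # 884,992
         + linear_params(9216, 4096)  # 37,752,832
         + linear_params(4096, 4096)  # 16,781,312
         + linear_params(4096, 10))   # 40,970
print(total)  # 58322314, matching the summary
```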
def train(model):
    # switch to training mode
    model.train()
    opt = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=model.parameters())  # Adam optimizer
    train_loader = paddle.io.DataLoader(cifar10_train, shuffle=True, batch_size=batch_size)  # shuffle the training set
    valid_loader = paddle.io.DataLoader(cifar10_test, batch_size=batch_size)
    for epoch in range(epoch_num):  # loop over epochs
        for batch_id, data in enumerate(train_loader()):  # unpack the training batch
            x_data = paddle.cast(data[0], 'float32')  # cast the data type
            y_data = paddle.cast(data[1], 'int64')
            y_data = paddle.reshape(y_data, (-1, 1))  # reshape the labels to (batch, 1)
            y_predict = model(x_data)  # run the model
            loss = F.cross_entropy(y_predict, y_data)  # compute the loss
            loss.backward()  # backpropagate the loss
            opt.step()
            opt.clear_grad()
        print("Epoch: {}; loss: {}".format(epoch, loss.numpy()))

        # after each epoch, validate the model on the test set
        model.eval()
        accuracies = []
        losses = []
        # same processing as for the training set
        for batch_id, data in enumerate(valid_loader()):
            x_data = paddle.cast(data[0], 'float32')
            y_data = paddle.cast(data[1], 'int64')
            y_data = paddle.reshape(y_data, (-1, 1))
            y_predict = model(x_data)
            loss = F.cross_entropy(y_predict, y_data)
            acc = paddle.metric.accuracy(y_predict, y_data)  # batch accuracy
            accuracies.append(np.mean(acc.numpy()))
            losses.append(np.mean(loss.numpy()))
        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("Eval accuracy: {}; loss: {}".format(avg_acc, avg_loss))
        val_acc_history.append(avg_acc)
        val_loss_history.append(avg_loss)
        model.train()

model = AlexNetModel()
train(model)
Epoch: 0; loss: [1.242394]
Eval accuracy: 0.5498046875; loss: 1.2228338718414307
Epoch: 1; loss: [0.9661963]
Eval accuracy: 0.648632824420929; loss: 0.9971475601196289
Case 2: The GoogLeNet convolutional network
Link to the original project
The GoogLeNet model is a convolutional neural network developed by a team at Google, and it won the 2014 ImageNet challenge. Compared with AlexNet, GoogLeNet is much deeper, with 87 layers in total. Although the structure is more complex, the parameter count is smaller: GoogLeNet has 5,941,552 parameters, only about 1/10 of AlexNet's. This is mainly due to GoogLeNet's innovative use of the Inception module. The model is analyzed in detail below.
import paddle
import paddle.nn.functional as F
import numpy as np
from paddle.vision.transforms import Compose, Resize, Transpose, Normalize
t = Compose([Resize(size=96), Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], data_format='HWC'), Transpose()])  # data transforms: resize, normalize, HWC -> CHW
cifar10_train = paddle.vision.datasets.cifar.Cifar10(mode='train', transform=t, backend='cv2')
cifar10_test = paddle.vision.datasets.cifar.Cifar10(mode="test", transform=t, backend='cv2')
# Build the model (Inception block)
class Inception(paddle.nn.Layer):
    def __init__(self, in_channels, c1, c2, c3, c4):
        super(Inception, self).__init__()
        # route 1: 1x1 conv
        self.route1x1_1 = paddle.nn.Conv2D(in_channels, c1, kernel_size=1)
        # route 2: 1x1 conv, then 3x3 conv
        self.route1x1_2 = paddle.nn.Conv2D(in_channels, c2[0], kernel_size=1)
        self.route3x3_2 = paddle.nn.Conv2D(c2[0], c2[1], kernel_size=3, padding=1)
        # route 3: 1x1 conv, then 5x5 conv
        self.route1x1_3 = paddle.nn.Conv2D(in_channels, c3[0], kernel_size=1)
        self.route5x5_3 = paddle.nn.Conv2D(c3[0], c3[1], kernel_size=5, padding=2)
        # route 4: 3x3 max pool, then 1x1 conv
        self.route3x3_4 = paddle.nn.MaxPool2D(kernel_size=3, stride=1, padding=1)
        self.route1x1_4 = paddle.nn.Conv2D(in_channels, c4, kernel_size=1)

    def forward(self, x):
        route1 = F.relu(self.route1x1_1(x))
        route2 = F.relu(self.route3x3_2(F.relu(self.route1x1_2(x))))
        route3 = F.relu(self.route5x5_3(F.relu(self.route1x1_3(x))))
        route4 = F.relu(self.route1x1_4(self.route3x3_4(x)))
        out = [route1, route2, route3, route4]
        return paddle.concat(out, axis=1)  # concatenate along the channel dimension (axis=1)

# Build the BasicConv2d layer: conv + batch norm + ReLU
def BasicConv2d(in_channels, out_channels, kernel, stride=1, padding=0):
    layer = paddle.nn.Sequential(
        paddle.nn.Conv2D(in_channels, out_channels, kernel, stride, padding),
        paddle.nn.BatchNorm2D(out_channels, epsilon=1e-3),
        paddle.nn.ReLU())
    return layer
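All four routes in the Inception block preserve the spatial size (the 3x3 and 5x5 convs use padding 1 and 2, and the 3x3 max pool uses stride 1 with padding 1), so concatenating along axis=1 simply adds the channel counts: c1 + c2[1] + c3[1] + c4. A quick check for the first Inception block used below, Inception(192, 64, (96, 128), (16, 32), 32) (the helper name is mine):

```python
def inception_out_channels(c1, c2, c3, c4):
    # channels after concatenating the four routes along axis 1
    return c1 + c2[1] + c3[1] + c4

print(inception_out_channels(64, (96, 128), (16, 32), 32))  # 256
```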
# Assemble the network
class GoogLeNet(paddle.nn.Layer):
    def __init__(self, in_channel, num_classes):
        super(GoogLeNet, self).__init__()
        self.b1 = paddle.nn.Sequential(
            BasicConv2d(in_channel, out_channels=64, kernel=7, stride=2, padding=3),
            paddle.nn.MaxPool2D(3, 2))
        self.b2 = paddle.nn.Sequential(
            BasicConv2d(64, 64, kernel=1),
            BasicConv2d(64, 192, kernel=3, padding=1),
            paddle.nn.MaxPool2D(3, 2))
        self.b3 = paddle.nn.Sequential(
            Inception(192, 64, (96, 128), (16, 32), 32),
            Inception(256, 128, (128, 192), (32, 96), 64),
            paddle.nn.MaxPool2D(3, 2))
        self.b4 = paddle.nn.Sequential(
            Inception(480, 192, (96, 208), (16, 48), 64),
            Inception(512, 160, (112, 224), (24, 64), 64),
            Inception(512, 128, (128, 256), (24, 64), 64),
            Inception(512, 112, (144, 288), (32, 64), 64),
            Inception(528, 256, (160, 320), (32, 128), 128),
            paddle.nn.MaxPool2D(3, 2))
        self.b5 = paddle.nn.Sequential(
            Inception(832, 256, (160, 320), (32, 128), 128),
            Inception(832, 384, (182, 384), (48, 128), 128),  # note: the canonical GoogLeNet uses (192, 384) here
            paddle.nn.AvgPool2D(2))
        self.flatten = paddle.nn.Flatten()
        self.b6 = paddle.nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        x = self.b4(x)
        x = self.b5(x)
        x = self.flatten(x)
        x = self.b6(x)
        return x
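Chaining the blocks, the channel width entering each stage can be traced by hand; this plain-Python sketch mirrors the constructor arguments above (the helper name inc is mine):

```python
def inc(c1, c2, c3, c4):
    # Inception output channels: the four routes concatenate along axis 1
    return c1 + c2[1] + c3[1] + c4

c = 192                                   # after b2
c = inc(64, (96, 128), (16, 32), 32)      # b3, block 1 -> 256
c = inc(128, (128, 192), (32, 96), 64)    # b3, block 2 -> 480
c = inc(192, (96, 208), (16, 48), 64)     # b4, block 1 -> 512
c = inc(160, (112, 224), (24, 64), 64)    # b4, block 2 -> 512
c = inc(128, (128, 256), (24, 64), 64)    # b4, block 3 -> 512
c = inc(112, (144, 288), (32, 64), 64)    # b4, block 4 -> 528
c = inc(256, (160, 320), (32, 128), 128)  # b4, block 5 -> 832
c = inc(256, (160, 320), (32, 128), 128)  # b5, block 1 -> 832
c = inc(384, (182, 384), (48, 128), 128)  # b5, block 2 -> 1024
print(c)  # 1024, the input width of the final Linear(1024, num_classes)
```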
model = GoogLeNet(3, 10)
paddle.summary(model,(256, 3, 96, 96))
---------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===========================================================================
Conv2D-187 [[256, 3, 96, 96]] [256, 64, 48, 48] 9,472
BatchNorm2D-10 [[256, 64, 48, 48]] [256, 64, 48, 48] 256
ReLU-31 [[256, 64, 48, 48]] [256, 64, 48, 48] 0
MaxPool2D-49 [[256, 64, 48, 48]] [256, 64, 23, 23] 0
Conv2D-188 [[256, 64, 23, 23]] [256, 64, 23, 23] 4,160
BatchNorm2D-11 [[256, 64, 23, 23]] [256, 64, 23, 23] 256
ReLU-32 [[256, 64, 23, 23]] [256, 64, 23, 23] 0
Conv2D-189 [[256, 64, 23, 23]] [256, 192, 23, 23] 110,784
BatchNorm2D-12 [[256, 192, 23, 23]] [256, 192, 23, 23] 768
ReLU-33 [[256, 192, 23, 23]] [256, 192, 23, 23] 0
MaxPool2D-50 [[256, 192, 23, 23]] [256, 192, 11, 11] 0
Conv2D-190 [[256, 192, 11, 11]] [256, 64, 11, 11] 12,352
Conv2D-191 [[256, 192, 11, 11]] [256, 96, 11, 11] 18,528
Conv2D-192 [[256, 96, 11, 11]] [256, 128, 11, 11] 110,720
Conv2D-193 [[256, 192, 11, 11]] [256, 16, 11, 11] 3,088
Conv2D-194 [[256, 16, 11, 11]] [256, 32, 11, 11] 12,832
MaxPool2D-51 [[256, 192, 11, 11]] [256, 192, 11, 11] 0
Conv2D-195 [[256, 192, 11, 11]] [256, 32, 11, 11] 6,176
Inception-28 [[256, 192, 11, 11]] [256, 256, 11, 11] 0
Conv2D-196 [[256, 256, 11, 11]] [256, 128, 11, 11] 32,896
Conv2D-197 [[256, 256, 11, 11]] [256, 128, 11, 11] 32,896
Conv2D-198 [[256, 128, 11, 11]] [256, 192, 11, 11] 221,376
Conv2D-199 [[256, 256, 11, 11]] [256, 32, 11, 11] 8,224
Conv2D-200 [[256, 32, 11, 11]] [256, 96, 11, 11] 76,896
MaxPool2D-52 [[256, 256, 11, 11]] [256, 256, 11, 11] 0
Conv2D-201 [[256, 256, 11, 11]] [256, 64, 11, 11] 16,448
Inception-29 [[256, 256, 11, 11]] [256, 480, 11, 11] 0
MaxPool2D-53 [[256, 480, 11, 11]] [256, 480, 5, 5] 0
Conv2D-202 [[256, 480, 5, 5]] [256, 192, 5, 5] 92,352
Conv2D-203 [[256, 480, 5, 5]] [256, 96, 5, 5] 46,176
Conv2D-204 [[256, 96, 5, 5]] [256, 208, 5, 5] 179,920
Conv2D-205 [[256, 480, 5, 5]] [256, 16, 5, 5] 7,696
Conv2D-206 [[256, 16, 5, 5]] [256, 48, 5, 5] 19,248
MaxPool2D-54 [[256, 480, 5, 5]] [256, 480, 5, 5] 0
Conv2D-207 [[256, 480, 5, 5]] [256, 64, 5, 5] 30,784
Inception-30 [[256, 480, 5, 5]] [256, 512, 5, 5] 0
Conv2D-208 [[256, 512, 5, 5]] [256, 160, 5, 5] 82,080
Conv2D-209 [[256, 512, 5, 5]] [256, 112, 5, 5] 57,456
Conv2D-210 [[256, 112, 5, 5]] [256, 224, 5, 5] 226,016
Conv2D-211 [[256, 512, 5, 5]] [256, 24, 5, 5] 12,312
Conv2D-212 [[256, 24, 5, 5]] [256, 64, 5, 5] 38,464
MaxPool2D-55 [[256, 512, 5, 5]] [256, 512, 5, 5] 0
Conv2D-213 [[256, 512, 5, 5]] [256, 64, 5, 5] 32,832
Inception-31 [[256, 512, 5, 5]] [256, 512, 5, 5] 0
Conv2D-214 [[256, 512, 5, 5]] [256, 128, 5, 5] 65,664
Conv2D-215 [[256, 512, 5, 5]] [256, 128, 5, 5] 65,664
Conv2D-216 [[256, 128, 5, 5]] [256, 256, 5, 5] 295,168
Conv2D-217 [[256, 512, 5, 5]] [256, 24, 5, 5] 12,312
Conv2D-218 [[256, 24, 5, 5]] [256, 64, 5, 5] 38,464
MaxPool2D-56 [[256, 512, 5, 5]] [256, 512, 5, 5] 0
Conv2D-219 [[256, 512, 5, 5]] [256, 64, 5, 5] 32,832
Inception-32 [[256, 512, 5, 5]] [256, 512, 5, 5] 0
Conv2D-220 [[256, 512, 5, 5]] [256, 112, 5, 5] 57,456
Conv2D-221 [[256, 512, 5, 5]] [256, 144, 5, 5] 73,872
Conv2D-222 [[256, 144, 5, 5]] [256, 288, 5, 5] 373,536
Conv2D-223 [[256, 512, 5, 5]] [256, 32, 5, 5] 16,416
Conv2D-224 [[256, 32, 5, 5]] [256, 64, 5, 5] 51,264
MaxPool2D-57 [[256, 512, 5, 5]] [256, 512, 5, 5] 0
Conv2D-225 [[256, 512, 5, 5]] [256, 64, 5, 5] 32,832
Inception-33 [[256, 512, 5, 5]] [256, 528, 5, 5] 0
Conv2D-226 [[256, 528, 5, 5]] [256, 256, 5, 5] 135,424
Conv2D-227 [[256, 528, 5, 5]] [256, 160, 5, 5] 84,640
Conv2D-228 [[256, 160, 5, 5]] [256, 320, 5, 5] 461,120
Conv2D-229 [[256, 528, 5, 5]] [256, 32, 5, 5] 16,928
Conv2D-230 [[256, 32, 5, 5]] [256, 128, 5, 5] 102,528
MaxPool2D-58 [[256, 528, 5, 5]] [256, 528, 5, 5] 0
Conv2D-231 [[256, 528, 5, 5]] [256, 128, 5, 5] 67,712
Inception-34 [[256, 528, 5, 5]] [256, 832, 5, 5] 0
MaxPool2D-59 [[256, 832, 5, 5]] [256, 832, 2, 2] 0
Conv2D-232 [[256, 832, 2, 2]] [256, 256, 2, 2] 213,248
Conv2D-233 [[256, 832, 2, 2]] [256, 160, 2, 2] 133,280
Conv2D-234 [[256, 160, 2, 2]] [256, 320, 2, 2] 461,120
Conv2D-235 [[256, 832, 2, 2]] [256, 32, 2, 2] 26,656
Conv2D-236 [[256, 32, 2, 2]] [256, 128, 2, 2] 102,528
MaxPool2D-60 [[256, 832, 2, 2]] [256, 832, 2, 2] 0
Conv2D-237 [[256, 832, 2, 2]] [256, 128, 2, 2] 106,624
Inception-35 [[256, 832, 2, 2]] [256, 832, 2, 2] 0
Conv2D-238 [[256, 832, 2, 2]] [256, 384, 2, 2] 319,872
Conv2D-239 [[256, 832, 2, 2]] [256, 182, 2, 2] 151,606
Conv2D-240 [[256, 182, 2, 2]] [256, 384, 2, 2] 629,376
Conv2D-241 [[256, 832, 2, 2]] [256, 48, 2, 2] 39,984
Conv2D-242 [[256, 48, 2, 2]] [256, 128, 2, 2] 153,728
MaxPool2D-61 [[256, 832, 2, 2]] [256, 832, 2, 2] 0
Conv2D-243 [[256, 832, 2, 2]] [256, 128, 2, 2] 106,624
Inception-36 [[256, 832, 2, 2]] [256, 1024, 2, 2] 0
AvgPool2D-4 [[256, 1024, 2, 2]] [256, 1024, 1, 1] 0
Flatten-1955 [[256, 1024, 1, 1]] [256, 1024] 0
Linear-13 [[256, 1024]] [256, 10] 10,250
===========================================================================
Total params: 5,942,192
Trainable params: 5,940,912
Non-trainable params: 1,280
---------------------------------------------------------------------------
Input size (MB): 27.00
Forward/backward pass size (MB): 2810.82
Params size (MB): 22.67
Estimated Total Size (MB): 2860.48
---------------------------------------------------------------------------
{'total_params': 5942192, 'trainable_params': 5940912}
epoch_num = 2          # number of training epochs
batch_size = 256       # batch size
learning_rate = 0.001  # learning rate
val_acc_history = []
val_loss_history = []
train_loader = paddle.io.DataLoader(cifar10_train, shuffle=True, batch_size=batch_size)
for i in train_loader:
    print(i)
    break
[Tensor(shape=[256, 3, 96, 96], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,
       [[[[-0.78039217, -0.78039217, -0.78039217, ..., -0.81176472, -0.81960785, -0.81960785],
          ...,
          [ 0.14509805,  0.14509805,  0.16078432, ..., -0.33333334, -0.37254903, -0.37254903]]]]), Tensor(shape=[256], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,
       [2, 4, 3, 3, 4, 7, 2, 8, 2, 5, ..., 7, 0, 3])]
(output truncated: one batch is a normalized 256x3x96x96 image tensor and a 256-element label tensor)
def train(model):
    # switch to training mode
    model.train()
    opt = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=model.parameters())
    train_loader = paddle.io.DataLoader(cifar10_train, shuffle=True, batch_size=batch_size)
    valid_loader = paddle.io.DataLoader(cifar10_test, batch_size=batch_size)
    for epoch in range(epoch_num):
        for batch_id, data in enumerate(train_loader()):
            x_data = paddle.cast(data[0], 'float32')
            y_data = paddle.cast(data[1], 'int64')
            y_data = paddle.reshape(y_data, (-1, 1))
            y_predict = model(x_data)
            loss = F.cross_entropy(y_predict, y_data)
            loss.backward()
            opt.step()
            opt.clear_grad()
        print("Epoch: {}; loss: {}".format(epoch, loss.numpy()))

        # switch to evaluation mode
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(valid_loader()):
            x_data = paddle.cast(data[0], 'float32')
            y_data = paddle.cast(data[1], 'int64')
            y_data = paddle.reshape(y_data, (-1, 1))
            y_predict = model(x_data)
            loss = F.cross_entropy(y_predict, y_data)
            acc = paddle.metric.accuracy(y_predict, y_data)
            accuracies.append(np.mean(acc.numpy()))
            losses.append(np.mean(loss.numpy()))
        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("Eval accuracy: {}; loss: {}".format(avg_acc, avg_loss))
        val_acc_history.append(avg_acc)
        val_loss_history.append(avg_loss)
        model.train()

model = GoogLeNet(3, 10)
train(model)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:636: UserWarning: When training, we now always track global mean and variance.
"When training, we now always track global mean and variance.")
Epoch: 0; loss: [1.4179909]
Eval accuracy: 0.4229492247104645; loss: 1.5987226963043213
Epoch: 1; loss: [1.2787999]
Eval accuracy: 0.5985351800918579; loss: 1.0844180583953857
Case 3: The residual network (ResNet) model
Link to the original project
The residual network (ResNet) model was developed by Kaiming He and his collaborators, and won the ImageNet ILSVRC-2015 classification challenge. ResNet introduces the residual block, which effectively mitigates the vanishing- and exploding-gradient problems that come with increasing the number of layers.
import paddle
import paddle.nn.functional as F
import numpy as np
from paddle.vision.transforms import Compose, Resize, Transpose, Normalize
# Prepare the data
t = Compose([Resize(size=96), Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], data_format='HWC'), Transpose()])  # data transforms: resize, normalize, HWC -> CHW
cifar10_train = paddle.vision.datasets.cifar.Cifar10(mode='train', transform=t, backend='cv2')
cifar10_test = paddle.vision.datasets.cifar.Cifar10(mode="test", transform=t, backend='cv2')
# Build the model
class Residual(paddle.nn.Layer):
    def __init__(self, in_channel, out_channel, use_conv1x1=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = paddle.nn.Conv2D(in_channel, out_channel, kernel_size=3, padding=1, stride=stride)
        self.conv2 = paddle.nn.Conv2D(out_channel, out_channel, kernel_size=3, padding=1)
        if use_conv1x1:  # use a 1x1 conv to match the shortcut's shape
            self.conv3 = paddle.nn.Conv2D(in_channel, out_channel, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.batchNorm1 = paddle.nn.BatchNorm2D(out_channel)
        self.batchNorm2 = paddle.nn.BatchNorm2D(out_channel)

    def forward(self, x):
        y = F.relu(self.batchNorm1(self.conv1(x)))
        y = self.batchNorm2(self.conv2(y))
        if self.conv3:
            x = self.conv3(x)
        out = F.relu(y + x)  # the key line: add the shortcut to the residual branch
        return out
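The 1x1 shortcut conv exists so that the addition y + x stays shape-compatible: with stride=2, conv1 halves the spatial size and may change the channel count, so the shortcut must do the same. A quick check with the output-size formula (the helper name is mine):

```python
def out_size(x, k, s=1, p=0):
    return (x + 2 * p - k) // s + 1

# main path: 3x3 conv, stride 2, padding 1, on a 24x24 feature map
print(out_size(24, k=3, s=2, p=1))  # 12
# shortcut: 1x1 conv, stride 2, padding 0
print(out_size(24, k=1, s=2, p=0))  # 12, so y + x is well-defined
```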
def ResNetBlock(in_channel, out_channel, num_layers, is_first=False):
    if is_first:
        assert in_channel == out_channel
    block_list = []
    for i in range(num_layers):
        if i == 0 and not is_first:
            # the first unit of each later stage halves H and W and changes the channel count
            block_list.append(Residual(in_channel, out_channel, use_conv1x1=True, stride=2))
        else:
            block_list.append(Residual(out_channel, out_channel))
    resNetBlock = paddle.nn.Sequential(*block_list)  # * unpacks the list into individual layers
    return resNetBlock
class ResNetModel(paddle.nn.Layer):
def __init__(self):
super(ResNetModel, self).__init__()
self.b1 = paddle.nn.Sequential(
paddle.nn.Conv2D(3, 64, kernel_size=7, stride=2, padding=3),
paddle.nn.BatchNorm2D(64),
paddle.nn.ReLU(),
paddle.nn.MaxPool2D(kernel_size=3, stride=2, padding=1))
self.b2 = ResNetBlock(64, 64, 2, is_first=True)
self.b3 = ResNetBlock(64, 128, 2)
self.b4 = ResNetBlock(128, 256, 2)
self.b5 = ResNetBlock(256, 512, 2)
self.AvgPool = paddle.nn.AvgPool2D(2)
self.flatten = paddle.nn.Flatten()
self.Linear = paddle.nn.Linear(512, 10)
def forward(self, x):
x = self.b1(x)
x = self.b2(x)
x = self.b3(x)
x = self.b4(x)
x = self.b5(x)
x = self.AvgPool(x)
x = self.flatten(x)
x = self.Linear(x)
return x
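This model is essentially ResNet-18: a stem plus four stages of two residual units each, then average pooling and a linear classifier. For a 96x96 input, the spatial sizes can again be traced in plain Python (the helper name is mine):

```python
def out_size(x, k, s=1, p=0):
    return (x + 2 * p - k) // s + 1

s = out_size(96, 7, 2, 3)     # b1 conv, 7x7 stride 2 pad 3 -> 48
s = out_size(s, 3, 2, 1)      # b1 max pool                 -> 24
# b2 keeps 24x24; b3, b4 and b5 each begin with a stride-2 residual unit
for _ in range(3):
    s = out_size(s, 3, 2, 1)  # -> 12, 6, 3
s = out_size(s, 2, 2)         # AvgPool2D(2)                -> 1
print(512 * s * s)            # 512, the input width of Linear(512, 10)
```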
model = ResNetModel()
paddle.summary(model,(256, 3, 96, 96))
---------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===========================================================================
Conv2D-301 [[256, 3, 96, 96]] [256, 64, 48, 48] 9,472
BatchNorm2D-16 [[256, 64, 48, 48]] [256, 64, 48, 48] 256
ReLU-37 [[256, 64, 48, 48]] [256, 64, 48, 48] 0
MaxPool2D-75 [[256, 64, 48, 48]] [256, 64, 24, 24] 0
Conv2D-302 [[256, 64, 24, 24]] [256, 64, 24, 24] 36,928
BatchNorm2D-17 [[256, 64, 24, 24]] [256, 64, 24, 24] 256
Conv2D-303 [[256, 64, 24, 24]] [256, 64, 24, 24] 36,928
BatchNorm2D-18 [[256, 64, 24, 24]] [256, 64, 24, 24] 256
Residual-1 [[256, 64, 24, 24]] [256, 64, 24, 24] 0
Conv2D-304 [[256, 64, 24, 24]] [256, 64, 24, 24] 36,928
BatchNorm2D-19 [[256, 64, 24, 24]] [256, 64, 24, 24] 256
Conv2D-305 [[256, 64, 24, 24]] [256, 64, 24, 24] 36,928
BatchNorm2D-20 [[256, 64, 24, 24]] [256, 64, 24, 24] 256
Residual-2 [[256, 64, 24, 24]] [256, 64, 24, 24] 0
Conv2D-306 [[256, 64, 24, 24]] [256, 128, 12, 12] 73,856
BatchNorm2D-21 [[256, 128, 12, 12]] [256, 128, 12, 12] 512
Conv2D-307 [[256, 128, 12, 12]] [256, 128, 12, 12] 147,584
BatchNorm2D-22 [[256, 128, 12, 12]] [256, 128, 12, 12] 512
Conv2D-308 [[256, 64, 24, 24]] [256, 128, 12, 12] 8,320
Residual-3 [[256, 64, 24, 24]] [256, 128, 12, 12] 0
Conv2D-309 [[256, 128, 12, 12]] [256, 128, 12, 12] 147,584
BatchNorm2D-23 [[256, 128, 12, 12]] [256, 128, 12, 12] 512
Conv2D-310 [[256, 128, 12, 12]] [256, 128, 12, 12] 147,584
BatchNorm2D-24 [[256, 128, 12, 12]] [256, 128, 12, 12] 512
Residual-4 [[256, 128, 12, 12]] [256, 128, 12, 12] 0
Conv2D-311 [[256, 128, 12, 12]] [256, 256, 6, 6] 295,168
BatchNorm2D-25 [[256, 256, 6, 6]] [256, 256, 6, 6] 1,024
Conv2D-312 [[256, 256, 6, 6]] [256, 256, 6, 6] 590,080
BatchNorm2D-26 [[256, 256, 6, 6]] [256, 256, 6, 6] 1,024
Conv2D-313 [[256, 128, 12, 12]] [256, 256, 6, 6] 33,024
Residual-5 [[256, 128, 12, 12]] [256, 256, 6, 6] 0
Conv2D-314 [[256, 256, 6, 6]] [256, 256, 6, 6] 590,080
BatchNorm2D-27 [[256, 256, 6, 6]] [256, 256, 6, 6] 1,024
Conv2D-315 [[256, 256, 6, 6]] [256, 256, 6, 6] 590,080
BatchNorm2D-28 [[256, 256, 6, 6]] [256, 256, 6, 6] 1,024
Residual-6 [[256, 256, 6, 6]] [256, 256, 6, 6] 0
Conv2D-316 [[256, 256, 6, 6]] [256, 512, 3, 3] 1,180,160
BatchNorm2D-29 [[256, 512, 3, 3]] [256, 512, 3, 3] 2,048
Conv2D-317 [[256, 512, 3, 3]] [256, 512, 3, 3] 2,359,808
BatchNorm2D-30 [[256, 512, 3, 3]] [256, 512, 3, 3] 2,048
Conv2D-318 [[256, 256, 6, 6]] [256, 512, 3, 3] 131,584
Residual-7 [[256, 256, 6, 6]] [256, 512, 3, 3] 0
Conv2D-319 [[256, 512, 3, 3]] [256, 512, 3, 3] 2,359,808
BatchNorm2D-31 [[256, 512, 3, 3]] [256, 512, 3, 3] 2,048
Conv2D-320 [[256, 512, 3, 3]] [256, 512, 3, 3] 2,359,808
BatchNorm2D-32 [[256, 512, 3, 3]] [256, 512, 3, 3] 2,048
Residual-8 [[256, 512, 3, 3]] [256, 512, 3, 3] 0
AvgPool2D-6 [[256, 512, 3, 3]] [256, 512, 1, 1] 0
Flatten-2430 [[256, 512, 1, 1]] [256, 512] 0
Linear-15 [[256, 512]] [256, 10] 5,130
===========================================================================
Total params: 11,192,458
Trainable params: 11,176,842
Non-trainable params: 15,616
---------------------------------------------------------------------------
Input size (MB): 27.00
Forward/backward pass size (MB): 2351.02
Params size (MB): 42.70
Estimated Total Size (MB): 2420.72
---------------------------------------------------------------------------
{'total_params': 11192458, 'trainable_params': 11176842}
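The parameter counts in the summary can be verified by hand: a `Conv2D` layer has `out_channels * in_channels * k * k` weights plus `out_channels` biases, and a `Linear` layer has `in_features * out_features` weights plus `out_features` biases. A quick check against a few rows of the table (a standalone sanity check):

```python
def conv2d_params(in_c, out_c, k):
    """Weights plus per-output-channel bias of a Conv2D layer."""
    return out_c * in_c * k * k + out_c

def linear_params(in_f, out_f):
    """Weights plus bias of a fully connected layer."""
    return in_f * out_f + out_f

print(conv2d_params(3, 64, 7))    # 9472, matches Conv2D-301 (the 7x7 stem)
print(conv2d_params(64, 64, 3))   # 36928, matches the 3x3 convs in b2
print(conv2d_params(64, 128, 1))  # 8320, the 1x1 shortcut Conv2D-308
print(linear_params(512, 10))     # 5130, matches Linear-15
```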
epoch_num = 2          # number of training epochs
batch_size = 512       # mini-batch size
learning_rate = 0.001  # learning rate
val_acc_history = []
val_loss_history = []
train_loader = paddle.io.DataLoader(cifar10_train, shuffle=True, batch_size=batch_size)
for i in train_loader:
    print(i)
    break
[Tensor(shape=[256, 3, 96, 96], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,
       [[[[ 0.47450981,  0.47450981,  0.26274511, ...,  0.74901962,  0.74117649,  0.74117649],
          [ 0.47450981,  0.47450981,  0.26274511, ...,  0.74901962,  0.74117649,  0.74117649],
          ...,
          [ 0.79607844,  0.79607844,  0.79607844, ...,  0.79607844,  0.80392158,  0.80392158]],
         ...]]]
 (output truncated)
 Tensor(shape=[256], dtype=int64, place=CUDAPinnedPlace, stop_gradient=True,
       [2, 9, 6, 3, 3, 1, 2, 5, 8, 2, ...])]
def train(model):
    # switch to training mode
    model.train()
    # optimizer
    opt = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=model.parameters())
    # mini-batch data loaders
    train_loader = paddle.io.DataLoader(cifar10_train, shuffle=True, batch_size=batch_size)
    valid_loader = paddle.io.DataLoader(cifar10_test, batch_size=batch_size)
    for epoch in range(epoch_num):
        for batch_id, data in enumerate(train_loader()):
            x_data = paddle.cast(data[0], 'float32')
            y_data = paddle.cast(data[1], 'int64')
            y_data = paddle.reshape(y_data, (-1, 1))
            y_predict = model(x_data)
            loss = F.cross_entropy(y_predict, y_data)
            loss.backward()
            opt.step()
            opt.clear_grad()
        print("Epoch: {}; loss: {}".format(epoch, loss.numpy()))
        # switch to evaluation mode
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(valid_loader()):
            x_data = paddle.cast(data[0], 'float32')
            y_data = paddle.cast(data[1], 'int64')
            y_data = paddle.reshape(y_data, (-1, 1))
            y_predict = model(x_data)
            loss = F.cross_entropy(y_predict, y_data)
            acc = paddle.metric.accuracy(y_predict, y_data)
            accuracies.append(np.mean(acc.numpy()))
            losses.append(np.mean(loss.numpy()))
        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("Validation accuracy: {}; loss: {}".format(avg_acc, avg_loss))
        val_acc_history.append(avg_acc)
        val_loss_history.append(avg_loss)
        model.train()
model = ResNetModel()
train(model)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:636: UserWarning: When training, we now always track global mean and variance.
"When training, we now always track global mean and variance.")
Epoch: 0; loss: [1.2771044]
Validation accuracy: 0.503693699836731; loss: 1.350608229637146
Epoch: 1; loss: [0.9957652]
Validation accuracy: 0.5834788680076599; loss: 1.1765481233596802
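The accuracy reported above is what `paddle.metric.accuracy` computes by default: top-1 accuracy, i.e. the fraction of samples whose highest-scoring class matches the label. It can be reproduced in NumPy to make the evaluation step concrete (a minimal sketch with made-up logits):

```python
import numpy as np

def top1_accuracy(logits, labels):
    """Fraction of rows whose argmax matches the label (top-1 accuracy)."""
    predictions = np.argmax(logits, axis=1)
    return float(np.mean(predictions == labels.reshape(-1)))

# Toy batch: 4 samples, 10 classes
logits = np.zeros((4, 10))
logits[0, 3] = 1.0  # predicts class 3
logits[1, 7] = 1.0  # predicts class 7
logits[2, 1] = 1.0  # predicts class 1 (label is 2, so this one is wrong)
logits[3, 0] = 1.0  # predicts class 0
labels = np.array([3, 7, 2, 0])
print(top1_accuracy(logits, labels))  # 0.75
```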
About the author
Author: 三岁 (Sansui)
Background: self-taught Python; now active in the Paddle community, hoping to start from the basics and learn Paddle together with everyone
CSDN: blog.csdn.net/weixin_4562…
I've reached the "plastic" tier on AI Studio with 7 badges lit up; come follow me! aistudio.baidu.com/aistudio/pe…
Known as the worst coder in the PaddlePaddle community; let's keep working at it together!
Remember: everything Sansui makes is a masterpiece (shameless series)