1. Displaying handwritten digits
- Import the required packages
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
- Load the dataset
# load MNIST
# E:\PyCharm\workspace\GAN\data
root = 'E:/PyCharm/workspace/GAN/data/'
transform = transforms.Compose([
transforms.ToTensor()
])
dataset = datasets.MNIST(root=root, train=False, download=False, transform=transform)
loader = DataLoader(dataset, batch_size=64, shuffle=True)
- Select one batch of batch_size handwritten digit images from the loader
for i, (real_image, real_image_label) in enumerate(loader):
    break
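The for/break idiom above works; a minimal alternative sketch, using the same loader defined above, is to build an iterator explicitly and take its first batch:
# equivalent way to grab a single batch from the DataLoader
real_image, real_image_label = next(iter(loader))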
- Display the images
real_image = torchvision.utils.make_grid(real_image)
real_image = real_image.numpy()
plt.imshow(np.transpose(real_image,(1,2,0)))
<matplotlib.image.AxesImage at 0x1e2762ddcc8>
2. Classifying handwritten digits
1) Training the classification model
import torch
import torchvision
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import optim
# set the hyperparameters
batch_size = 128
# learning_rate = 0.01
learning_rate = 1e-3
epochsize = 30
# define the LeNet5 architecture
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # convolutional layers
        self.conv_layer = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),  # torch.Size([128, 6, 24, 24])
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),                              # torch.Size([128, 6, 12, 12])
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0), # torch.Size([128, 16, 8, 8])
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)                               # torch.Size([128, 16, 4, 4])
        )
        # fully connected layers
        self.fullconn_layer = nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        output = self.conv_layer(x)           # output: torch.Size([batch_size, 16, 4, 4])
        output = output.view(x.size(0), -1)   # output: torch.Size([batch_size, 16 * 4 * 4])
        output = self.fullconn_layer(output)  # output: torch.Size([batch_size, 10])
        return output
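# Optional sanity check: feed a dummy batch through an untrained LeNet5 to confirm
# the feature-map sizes noted in the comments above (illustrative only).
check_model = LeNet5()
dummy = torch.randn(2, 1, 28, 28)            # two fake 1x28x28 grayscale images
print(check_model.conv_layer(dummy).shape)   # expected: torch.Size([2, 16, 4, 4])
print(check_model(dummy).shape)              # expected: torch.Size([2, 10])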
# download the training set
mnist_traindata = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=True, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.1307],
std=[0.1307])
]), download=True)
mnist_train = DataLoader(mnist_traindata, batch_size=batch_size, shuffle=True)
# download the test set
mnist_testdata = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.1307],
std=[0.1307])
]), download=True)
mnist_test = DataLoader(mnist_testdata, batch_size=batch_size, shuffle=True)
# inspect the data shapes
real_image, label = next(iter(mnist_train))
print('real_image:', real_image.shape, 'label:', label.shape)
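# with batch_size = 128 this should print:
# real_image: torch.Size([128, 1, 28, 28]) label: torch.Size([128])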
# use the GPU for acceleration
device = torch.device('cuda')
# instantiate the model
model = LeNet5().to(device)
# In general, classification tasks use CrossEntropyLoss, while regression tasks use MSELoss
criteon = nn.CrossEntropyLoss().to(device)
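# Note: nn.CrossEntropyLoss combines log-softmax with negative log-likelihood, so it takes the
# model's raw logits of shape [batch_size, 10] together with integer class labels of shape
# [batch_size]; no one-hot encoding or explicit softmax is needed before computing the loss.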
# define the optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# training loop
for epoch in range(epochsize):

    # training phase
    model.train()
    # total_num = 0
    for batchidx, (image, imagelabel) in enumerate(mnist_train):
        # image.shape: [batch_size, 1, 28, 28]
        image, imagelabel = image.to(device), imagelabel.to(device)

        # category.shape: [batch_size, 10]
        category = model(image)

        # category:   [batch_size, 10]
        # imagelabel: [batch_size]
        # compute the loss
        loss = criteon(category, imagelabel)

        # backpropagation and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # total_num += image.size(0)
        # print('total_num:', total_num)

    print(epoch, 'loss:', loss.item())

    # evaluation phase
    model.eval()
    # no computation graph is built here
    with torch.no_grad():
        total_connect = 0  # total number of correct predictions
        total_num = 0      # total number of samples tested
        for (image, imagelabel) in mnist_test:
            # image.shape: [batch_size, 1, 28, 28]
            image, imagelabel = image.to(device), imagelabel.to(device)

            # category.shape: [batch_size, 10]
            category = model(image)

            # index of the largest logit is the predicted class
            pred = category.argmax(dim=1)
            # _, pred = category.max(dim=1)

            # count the correct predictions in this batch
            total_connect += torch.eq(pred, imagelabel).detach().float().sum().item()
            total_num += image.size(0)

        # accuracy on the test set after this epoch
        acc = total_connect / total_num
        print('epoch:', epoch, 'test_acc:', acc)

# save the trained network parameters
torch.save(model.state_dict(), 'LeNet5_mnist.mdl')
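torch.save above stores only the model weights. If training ever needed to be resumed, a fuller checkpoint could be written instead; the sketch below is only an illustration (the file name LeNet5_mnist_checkpoint.pth is made up) and is not part of the original workflow.
# optional: bundle everything needed to resume training into one checkpoint file
checkpoint = {
    'epoch': epochsize,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
}
torch.save(checkpoint, 'LeNet5_mnist_checkpoint.pth')  # hypothetical file name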
The final accuracy reaches roughly 99%. Next we verify the trained model; note that the model definition and preprocessing must be consistent with those used during training.
2) Testing the classification model
# load the training set (used here to grab sample images for prediction)
mnist_traindata = datasets.MNIST('E:/学习/机器学习/数据集/MNIST', train=True, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.1307],
std=[0.1307])
]), download=True)
mnist_train = DataLoader(mnist_traindata, batch_size=batch_size, shuffle=True)
# fetch one batch of images and labels
image, label = next(iter(mnist_train))
# load the trained parameters
import torch.nn.functional as F  # needed for F.softmax below
device = torch.device('cuda')
net = LeNet5().to(device)
net.load_state_dict(torch.load('E:/PyCharm/workspace/Cifar10/LeNet5_mnist.mdl'))
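One small addition worth considering here (it is not in the original code, and makes no practical difference for this LeNet5 since it contains no dropout or batch-norm layers): switch the network to evaluation mode before running inference.
net.eval()  # disable training-only behaviour; good practice before making predictions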
The values of the label tensor are:
tensor([0, 6, 8, 6, 5, 1, 7, 4, 3, 6, 9, 5, 7, 1, 2, 2, 0, 4, 1, 9, 1, 6, 8, 9,
4, 9, 1, 4, 0, 2, 0, 2, 4, 6, 8, 3, 1, 6, 7, 0, 4, 8, 5, 1, 9, 7, 8, 6,
8, 1, 5, 2, 6, 8, 1, 9, 1, 1, 3, 7, 3, 1, 8, 9, 4, 5, 7, 0, 7, 3, 8, 5,
4, 8, 3, 1, 3, 6, 5, 2, 6, 9, 4, 0, 1, 3, 2, 7, 0, 6, 8, 7, 6, 4, 0, 9,
1, 5, 6, 9, 3, 3, 4, 6, 8, 6, 1, 8, 0, 5, 4, 7, 8, 6, 8, 3, 4, 9, 3, 1,
6, 9, 2, 0, 3, 3, 7, 3])
# display one full batch of images
real_image = torchvision.utils.make_grid(image)
real_image = real_image.numpy()
plt.imshow(np.transpose(real_image,(1,2,0)))
# display the first image
photo = torchvision.utils.make_grid(image[0])
photo = photo.numpy()
plt.imshow(np.transpose(photo,(1,2,0)))
# use the first image for verification
test = test = image[0]       # torch.Size([1, 28, 28])
test = test.unsqueeze(0)     # torch.Size([1, 1, 28, 28])
test = test.to(device)
pred = net(test) # torch.Size([1, 10])
result = F.softmax(pred, dim=1)  # convert logits to probabilities
result.max(dim=1)
torch.return_types.max(
values=tensor([1.0000], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([0], device='cuda:0'))
The image clearly shows the digit "0", and the network's output is correct.
# display the third image
photo = torchvision.utils.make_grid(image[2])
photo = photo.numpy()
plt.imshow(np.transpose(photo,(1,2,0)))
# use the third image for verification
test = image[2]              # torch.Size([1, 28, 28])
test = test.unsqueeze(0)     # torch.Size([1, 1, 28, 28])
test = test.to(device)
pred = net(test) # torch.Size([1, 10])
result = F.softmax(pred, dim=1)  # convert logits to probabilities
result.max(dim=1)
torch.return_types.max(
values=tensor([1.], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([8], device='cuda:0'))
The image clearly shows the digit "8", and the network's output is correct.
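Rather than checking images one by one, the whole batch fetched earlier can be classified in a single pass. The sketch below reuses the image, label, net, and device variables from this section and is only an illustrative extension of the checks above.
# classify the entire batch at once and compare predictions with the ground-truth labels
with torch.no_grad():
    batch_pred = net(image.to(device)).argmax(dim=1).cpu()  # [batch_size] predicted digits
print('predicted:', batch_pred[:10])
print('ground truth:', label[:10])
print('batch accuracy:', torch.eq(batch_pred, label).float().mean().item())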