第P1周:Pytorch实现mnist手写数字识别

1 阅读2分钟

前言

  • 任务:P1周(实现mnist手写数字识别)
  • 要求:
    1.了解pytorch,并使用pytorch构建一个深度学习程序;2.了解什么是深度学习
  • 拔高(可选):学习提到的函数方法

语言环境
Python3.8
PyTorch 版本: 1.12.1+cpu
Torchvision 版本: 0.13.1+cpu

代码运行过程
1.设置硬件设备
image.png 2.下载数据

image.png 3.数据加载

个人总结:训练集中shuffle为True为了打乱顺序,防止将特征与顺序进行联系

image.png

加载模型并打印

个人总结

  • 模型主要分为两个部分,一个是参数的设置,第二个是定义计算流程的顺序
  • 卷积核=输出通道,每一个卷积核负责捕捉一种局部特征。32和64是经验参数,不能太大也不能太小,经典本案例使用的32和64是经典逐层翻倍设计
  • 全连接层将前面得到的额特征一维化,可用于后续特征加权计算

image.png image.png

编写训练函数

个人总结

  • 函数先设置初始随机,在一次次迭代修正和优化特征权重
  • 训练函数包含三个核心:1.梯度清零2.反向传播计算梯度3.更新权重。目的让模型越来越准

image.png

编写测试函数

  • 测试函数不需要权重更新,只需要前向预测、统计loss和准确率

image.png

正式训练

image.png image.png

可视化

image.png


CNN模型定义主要分为两块:1.参数定义;2.计算流程定义
在应用中需要分别定义训练函数和测试函数
训练函数目的:根据CNN模型得到的特征进行最优权重计算
测试函数目的:誉为进行预测即损失韩式和准确性的计算。