基于LeNet模型实现MNIST手写体数据集的训练

279 阅读9分钟

LeNet卷积神经网络是LeCun于1998年提出,是卷积神经网络的开篇之作。通过共享卷积核减少了网络的参数,另外还提出了局部感受野和池化(下采样)。

 

 

一、知识点解读

1.1 共享卷积核参数

关于共享卷积核参数的优点,可以用一个对比图来展示:

  • 进行parameter sharing

    图1:参数共享情况下的卷积计算过程

    概括为:一个边长为m的卷积核在图像上扫描,区域内对位相乘再累加对应输出后的一个像素点,以此进行特征提取。

     

  • 不进行parameter sharing

    图2:参数不共享情况下的卷积计算过程

    表示卷积核的参数数量与图像像素矩阵的大小保持一致,当channels很大时,参数的数量会特别大。

    由此可以看出:利用卷积核参数的共享,可以避免参数数量在加大深度时出现激增现象。

 

 

1.2 感受野

感受野(Receptive Field)是指输出特征图中的1个像素点映射到原始输入图片的区域大小。为了更形象的理解感受野的含义,我用word画了下面这幅图,代表对于同一张输入图像,使用连续两层3x3卷积核与单层5x5卷积核进行特征提取的效果:

图3:感受野

比如对于一张5x5的原始图像

  • 如果使用3x3的卷积核进行计算会输出一个3x3的输出特征图,这个输出特征图上的每个像素点映射到原始图片是3x3的区域,所以它的感受野是3;如果对这个3x3的特征图再用一次3x3的卷积核进行扫描,会再输出一个1x1的输出特征图,这个特征图的上像素点映射到原始图片(不是上一级图片)是5x5的区域,所以它的感受野是5.
  • 如果直接使用5x5的卷积核进行计算会输出一个11的输出特征图,这个输出特征图上的每个像素点映射到原始图片是5x5的区域,所以它的感受野是5.

 

由此可知:对于这幅原始图片,连续使用两次3x3的卷积核和单独使用一次5x5的卷积核,其特征提取能力是一样的。那么,该如何选择呢?这个时候我们就要考虑他们的待训练参数数量和卷积计算量了。

设输入特征图宽度、高度均为n,卷积步长为1.

  • 对于待训练参数:

    • 前者:9+9=18
    • 后者:25
  • 对于计算次数:

    • 前者:
    9(n3+1)2+9((n3+1)3+1)2=18n2108n2+1809(n-3+1)^2 + 9((n-3+1)-3+1)^2 = 18n^2 - 108n^2 +180
    • 后者
      25(n5+1)2=25n2200n2+40025(n-5+1)^2 = 25n^2 - 200n^2 +400
  • 对于卷积后的Feature Map:

    卷积前后它们满足下面的关系:

    W2=(W1F+2P)/S+1H2=(H1F+2P)/S+1W_2 = (W_1 - F + 2P)/S + 1\\ H_2 = (H_1 - F + 2P)/S + 1

    其中 W2W_2 , 是卷积后 Feature Map 的宽度;W1W_1 是卷积前图像的宽度;FF 是 filter 的宽度;PP 是 Zero Padding 数量,Zero Padding 是指在原始图像周围补几圈 00,如果 PP 的值是 11,那么就补 1 圈 0;SS 是步幅;H2H_2 卷积后 Feature Map 的高度;H1H_1 是卷积前图像的宽度。

 

计算次数两个公式联立计算可知:当n>10时,两层3x3卷积核比一层5x5卷积核性能要好,这也就是为什么现在的神经网络在卷积计算中常使用多个小卷积核堆叠来替换单个大卷积核的原因。

 

 

1.3 池化

池化操作用于减少卷积神经网络中特征数据量。池化的主要方法有最大池化均值池化,最大池化可以提取图片纹理,均值池化可以保留背景特征。

设池化核为2x2,步长为2,两种池化方法的结果如图4所示(图片来自课堂ppt):

图4:两种池化方法

可以容易的看出:最大池化表示取区域中的像素最大值作为池化结果;平均池化表示取区域中的像素平均值作为池化结果。

 

Tensorflow给出了对应的池化函数MaxPool2D和AveragePooling2D. 具体如下:

#最大池化
tf.keras.layers.MaxPool2D(
    pool_size=池化核尺寸          #正方形可直接写核长整数,或元组形式(宽,高)
    strides=池化步长              #用整数,或元组形式(纵向h,横向w),默认为pool_size
    padding='valid' or 'same'     #全零填充是"same" , 不使用是"valid"(默认)
)

#平均池化
tf.keras.layers.AveragePooling2D(
    pool_size=池化核尺寸          #正方形可直接写核长整数,或元组形式(宽,高)
    strides=池化步长              #用整数,或元组形式(纵向h,横向w),默认为pool_size
    padding='valid' or 'same'     #全零填充是"same" , 不使用是"valid"(默认)
)

#实例化
model = tf.keras.models.Sequential([
    Conv2D(filter=6,kernel_size=(5,5),padding='same') ,     # C
    BatchNormalization(),                                   # B
    Activation('relu'),                                     # A
    MaxPool2D(pool_size=(2,2),strides=2,padding='same'),    # P
    Dropout(0,2)                                            # D
]) 

 

 

 

二、网络结构

LeNet网络结构如图5所示(图片来自课堂ppt):

图5:LeNet模型

从图中可以发现:LeNet一共有五层网络(只计算卷积层和全连接层,其余操作认为是卷积计算层的附属)。包含两层卷积和连续的三层全连接。

 

卷积,就是特征提取器CBAPD。(其中D是舍弃Dropout)具体表示为:

关于卷积的具体细节,将会在稍后的文章中给出,这里直接用我已经掌握好的知识。

 

 

C1表示第一层卷积,对齐分析:

  • C:(16个5*5的卷积核,步长为1,不使用全零填充)
  • B:(None)LeNet时代,还没有BN操作,所以此项为空
  • A:(sigmoid)LeNet时代,sigmoid是主流的激活函数
  • P:(用2*2的池化核,步长为2,进行最大池化,不使用全零填充)
  • D:(None)LeNet时代,还没有Dropout,所以此项为空

 

C3表示第二层卷积,对齐分析:

  • C:(6个5*5的卷积核,步长为1,不使用全零填充)
  • B:(None)
  • A:(sigmoid)
  • P:(用2*2的池化核,步长为2,进行最大池化,不使用全零填充)
  • D:(None)

 

随后,Flatten拉直,后接三层全连接网络。

神经元分别是120、84、10(因为是10分类问题),前两层全连接使用sigmoid激活函数,最后一层使用softmax使输出符合概率分布。

 

各层详细解读,可参考文章:【经典CNN模型LeNet解读 - 于晨晨的文章 - 知乎】

 

对着CBAPD,网络结构用代码应描述为:

class Mnist_LeNet(Model):
    def __init__(self):
        super(Mnist_LeNet, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5),activation='sigmoid')    #第一层是6个5*5的卷积核,使用sigmoid激活函数
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2)                        #选择最大池化方法,池化核是2*2的尺寸,池化步长是2

        self.c2 = Conv2D(filters=16, kernel_size=(5, 5),activation='sigmoid')   #第一层是16个5*5的卷积核,使用sigmoid激活函数
        self.p2 = MaxPool2D(pool_size=(2, 2), strides=2)                        #选择最大池化方法,池化核是2*2的尺寸,池化步长是2

        self.flatten = Flatten()                            #拉直
        self.f1 = Dense(120, activation='sigmoid')          #连续三层全连接网络,最后一层设置10个神经元(10分类)
        self.f2 = Dense(84, activation='sigmoid')
        self.f3 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.c1(x)
        x = self.p1(x)

        x = self.c2(x)
        x = self.p2(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.f2(x)
        y = self.f3(x)
        return y

model = Mnist_LeNet()    #实例化

 

 

 

三、代码实现

基于简单的全连接网络实现 MNIST 手写体数据集的训练与识别的基础上重写Mnist_Base类,其余代码相同。

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model
import os
import numpy as np
np.set_printoptions(threshold=np.inf)  #设置print输出格式,通过np.inf使完全输出,不允许用省略号代替
from matplotlib import pyplot as plt




#  一:导入数据集,设定训练集和测试集的特征和标签
mnist = tf.keras.datasets.mnist                         #下载手写数字数据集
(x_train,y_train),(x_test,y_test) = mnist.load_data()   #指定训练集和测试集的输入特征和标签
x_train , x_test = x_train/255.0 , x_test/255.0         #对输入网络的特征进行归一化。全部转化为0到1之间的数,数值变小有利于神经网络的吸收


x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 注意:卷积计算要求输入的图片必须是4个维度的,第0个维度表示一次喂入几个batch,第1、2、3个维度分别表示输入图片的分辨率和通道数。
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)     #       而mnist数据集是单通道灰度图片,加上batch,才3个维度,即(batch数,row,col)  , 而Conv2D要求有四个参数,即(batch数,row,col,通道数),所以需要再加一个。  如果数据集是三通道彩色图(如cifar10数据集),则不需要这个步骤





#  二:搭建网络结构
class Mnist_LeNet(Model):
    def __init__(self):
        super(Mnist_LeNet, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5),activation='sigmoid')    #第一层是6个5*5的卷积核,使用sigmoid激活函数
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2)                        #选择最大池化方法,池化核是2*2的尺寸,池化步长是2

        self.c2 = Conv2D(filters=16, kernel_size=(5, 5),activation='sigmoid')   #第一层是16个5*5的卷积核,使用sigmoid激活函数
        self.p2 = MaxPool2D(pool_size=(2, 2), strides=2)                        #选择最大池化方法,池化核是2*2的尺寸,池化步长是2

        self.flatten = Flatten()                            #拉直
        self.f1 = Dense(120, activation='sigmoid')          #连续三层全连接网络,最后一层设置10个神经元(10分类)
        self.f2 = Dense(84, activation='sigmoid')
        self.f3 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.c1(x)
        x = self.p1(x)

        x = self.c2(x)
        x = self.p2(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.f2(x)
        y = self.f3(x)
        return y

model = Mnist_LeNet()    #实例化








#  三:配置训练方法
model.compile(
    optimizer='adam',    #优化器选择adam
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),  #损失函数选择SparseCategoricalCrossentropy,因为前面已经保证输出满足概率分布,所以这里from_logits=False
    metrics=['sparse_categorical_accuracy']   #数据集中标签是数值,输出结果y是概率分布,所以衡量方法选择sparse_categorical_accuracy
)


#  (断点续训,存取模型参数。  在下次训练时,从之前获取的最优的参数开始,提高了准确率)
    #读取模型
checkpoint_save_path='./checkpoint/mnist.ckpt'
if os.path.exists(checkpoint_save_path+'.index'):   #生成ckpt文件时会自动生成索引文件,所以拿它的索引文件来判断
    print('------加载已有模型------')
    model.load_weights(checkpoint_save_path)       #如果存在模型,则直接读取


    #保存模型
cp_callback = tf.keras.callbacks.ModelCheckpoint(   #使用tf给出的回调函数来保存模型参数
    filepath=checkpoint_save_path,
    save_weights_only=True,    #是否只保留模型参数
    save_best_only=True        #是否只保留最优结果
)






#  四:执行训练过程
history = model.fit(
    x_train,y_train,
    batch_size=32,                          #每次喂入网络32组数据
    epochs=3,                               #数据集迭代10次
    validation_data=(x_test,y_test),
    validation_freq=1,                      #每迭代一次训练集执行一次测试集的评测
    callbacks=[cp_callback])                #加入回调选项,返回给history。(如果不用断点续训,则不用写 “history=” 和 “callbacks=[cp_callback]” )


#通过写入到txt文本的方式查看断点续训时保存的参数
file = open('./weights.txt','w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()








#  五:打印网络结构和参数信息
model.summary()











#  六:展示acc和loss曲线 (断点续训中history里已经保存好了)
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1,2,1)
plt.plot(acc,label='Training Accuracy')
plt.plot(val_acc,label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1,2,2)
plt.plot(loss,label='Training Loss')
plt.plot(val_loss,label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.show()

 

 

 

四、运行结果

图6:训练结果

 

 

 

www.bilibili.com/video/BV1B7… zhuanlan.zhihu.com/p/41736894 ieeexplore.ieee.org/document/72…