pytorch学习笔记(二)

106 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第4天,点击查看活动详情

上一篇文章,学习笔记(一)中通过几个简单的例子,简略介绍了Pytorch的一些基本用法,如同numpy一样,pytorch的使用也是非常便捷,由于可以使用gpu进行加速计算,因此在很大程度上,其应用的场景不亚于numpy。

本篇文章,先从构建一个简单的网络开始,慢慢的举例说明pytorch在深度学习中的强大作用。

首先,设计一个网络,假设这个网络有一个卷积层1,在卷积层1后面跟了一个relu激活函数,然后接着就是一次最大池化;然后一个卷积层2,同样,在后面跟了一个relu激活函数和一个最大池化。然后就是一个全连接1加一个relu,一个全连接2加一个relu,最后一个全连接进行输出。

整个网络的结构就是如此,下面我们自己新建一个Net类,在这个Net类中实现上述的网络结构,代码如下

class Net(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        '''
        定义卷积、全连接层等几个操作,无顺序可言
        :param in_channels:
        :param middle_channels:
        :param out_channels:
        '''
        super(Net, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 5)
        # self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 5)
        # self.bn2 = nn.BatchNorm2d(out_channels)

        self.fc1 = nn.Linear(out_channels*5*5, 120)
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, 10)

    def forward(self, x):
        '''
        定义网络的结构顺序,完成前向的传播,这里的顺序就是整个网络结构的顺序
        :param x:
        :return:
        '''
        # 卷积层后跟一个池化层
        x = F.max_pool2d(self.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(self.relu(self.conv2(x)), 2)
        # view中给定第一个参数为-1表示,让系统自动计算前面的维度大小

        x = x.view(-1, self.feature_size(x))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def feature_size(self, x):
        '''
        计算特征图的数量
        :param x:
        :return:
        '''
        numer = x.size()#[0:]
        num_feature = 1
        for num in numer:
            num_feature *= num
        return num_feature

其中init函数中通过给网络层起别名的形式定义了几个网络结构,forward的作用是构成前向传播,也就是网络从头到尾的连接计算过程,从头部开始,经过一个输入,然后中间的部分都成为隐藏层,直到最后通过一个全连接层进行输出。

首先,对类进行实例化,可以查看整个网络的结构

image.png

然后,对网络中的参数进行一个赋值操作,in_channels=1,middle_channels=16, out_channels=32。此时查看一下网络的结构输入与输出的大小,如下

image.png

最后,我们可以使用mnist手写数据集进行验证,查看我们的网络预测结果是否准确。注意,mnist数据集的大小为28x28,而网络的输入为32x32,因此在进行测试的时候注意数据大小的转换。

下面是我测试的一组结果

image.png

image.png

由于我们现在设计的网络相对较为简单,因此预测的结果不准确是正常的,本文的主要目的就是了解一下如何构建自己的网络结构,在后面的博客中我会继续设计新的复杂的网络结构,然后通过这个简单到复杂的过程,慢慢的掌握其中的要点。