如何查看pytorch模型结构

2,808 阅读2分钟

事情的起因是这样的:这段时间需要查看一下网络内部的具体工作方式,所以需要重新使用cifar10训练一个网络。正好前两天读了一下resnet在pytorch中的官方实现,发现其中用了一个AdaptiveAvgPool2d导致网络输入可以被视为任意的,因此我直接用torchvision.models定义的resnet18进行训练。开始一切都很顺利,但是过拟合严重,正确率到55%左右,loss再降但是acc就不升高了。我以为是我的参数没对,找到【Pytorch】ResNet-18实现Cifar-10图像分类进行测试。真香,人家代码就是好。但是我进一步发现不对啊,我的代码换成他的参数,还是不对(实测效果好不少,但是训练不稳定,而且收敛速度远远慢于它)。难道是模型有问题? 换了模型,真香too,人家的模型真好用。不过官方给我的resnet18难道有错???然后在滚滚暖风(容易困)下我开始了模型内部的掉坑,爬出坑,把坑填上土的工作(其实我现在还没填上,只知道我掉到哪个坑了)。 最开始我通过:

for name, value in model.named_parameters():
    print('name: {0},\t grad: {1}'.format(name, value.requires_grad))

查看每一层的情况。为什么用这段代码?直接打印出网络,输出太多,不好看(其实是暖风吹得想睡觉,懒得看)。

粗略一看,两个网络一样啊,层数相同,结构一摸一样,就是每一层的名字不一样。其中几步不一样的地方回去看源码,结构定义也一摸一样。所以深深陷入了夜宵,哦不,沉思。仔细看第二和第三层不一样,但是我觉得这点问题会影响这么大?不应该吧。我在这方面的直觉一向很准。(其实这半年以来所有的实验结果我就没有一个事先预测正确的)。今天开完会回来,又想了想,肯定不是我的直觉错误(蜜汁自信),是结构什么地方我看落了。故查资料,看到了一个工具。

from torchsummary import summary
summary(model, (3, 32, 32))

没装这个工具的话,pip很快就能装好。打印一看,我就说机智如我(瞎猫碰上了死耗子)果然模型结构不对。

明显发现前面几层的通道数不对,而且官方的resnet在开始使用了较少的参数,保证对224**224图片测试的时候显存不会占用过多。但是这些措施在cifar10上就只能起反作用,可以看到后面的卷积基本是11的feath_map。后面的层就基本没用了。 看一下怎么修改了,如果全改的话我开始瞎造的很多轮子就都用不了了。