pytorch 模型model 的一些常用属性和函数说明

7,182 阅读1分钟

 

首先创建一个简单的网络,用来举例说明后来的例子。

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6,8,kernel_size=3,padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self,x):
        x = self.conv1(1)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.relu(x)

        return x

net = Net()

net.parameters(),可以得到net这个具体的模型中的参数:

for para in net.parameters():
    print(para)
    print(para.shape)
    print()


'''
输出为:

Parameter containing:
tensor([[[[ 0.0794,  0.1070,  0.0415],
          [ 0.0037, -0.0850,  0.0919],
          [ 0.0039,  0.0899, -0.1446]],

         [[-0.0642,  0.0251, -0.1055],
          [ 0.1085,  0.0627,  0.0388],
          [-0.0878, -0.1305,  0.1335]],

         [[-0.0907,  0.0113,  0.1400],
          [ 0.0051, -0.0605, -0.1085],
          [ 0.0544, -0.0649,  0.0847]]],


        [[[-0.1319,  0.0152, -0.0736],
          [ 0.1796,  0.0857,  0.1668],
          [ 0.0586, -0.1508, -0.1571]],

         [[-0.1053,  0.0372,  0.1596],
          [ 0.1509,  0.1125, -0.1773],
          [ 0.0960, -0.0507,  0.0569]],

         [[-0.0640,  0.0070, -0.1253],
          [-0.1739,  0.0552,  0.1892],
          [ 0.1232, -0.0811,  0.1263]]],


        [[[ 0.0483, -0.1212, -0.0870],
          [-0.0915,  0.0072,  0.1581],
          [ 0.1184, -0.0907,  0.1109]],

         [[-0.0024,  0.0980, -0.1080],
          [-0.0311, -0.1013, -0.0581],
          [ 0.1855, -0.0202,  0.0950]],

         [[-0.1640, -0.0848,  0.0254],
          [ 0.0318,  0.0538,  0.0277],
          [ 0.0641,  0.0298,  0.0352]]],


        [[[-0.0955,  0.0569, -0.0565],
          [-0.1186, -0.0177,  0.0604],
          [ 0.0305, -0.0398, -0.1165]],

         [[-0.1532,  0.0179,  0.0317],
          [ 0.0910,  0.1470, -0.1013],
          [-0.0165,  0.0095, -0.0887]],

         [[-0.0314,  0.1790, -0.1142],
          [ 0.1710, -0.1628,  0.1342],
          [-0.0781,  0.0194, -0.0568]]],


        [[[-0.1903, -0.1659, -0.1797],
          [ 0.1109,  0.0686,  0.1767],
          [-0.0777, -0.0341, -0.1549]],

         [[ 0.0615, -0.1309, -0.1492],
          [ 0.1291, -0.1705,  0.1749],
          [ 0.0173, -0.1587,  0.0072]],

         [[-0.1669, -0.0803,  0.0378],
          [ 0.1880, -0.0338,  0.1056],
          [-0.0171,  0.0892, -0.0090]]],


        [[[-0.1615,  0.1901, -0.1313],
          [-0.0775, -0.0043, -0.0902],
          [-0.0786,  0.0501,  0.0921]],

         [[ 0.1332,  0.1698,  0.1657],
          [ 0.0244,  0.0792, -0.1830],
          [ 0.0519, -0.1610,  0.0821]],

         [[-0.1437,  0.0229, -0.0810],
          [-0.1200,  0.1311,  0.0776],
          [ 0.0772, -0.0238, -0.0981]]]], requires_grad=True)
torch.Size([6, 3, 3, 3])

Parameter containing:
tensor([ 0.0599, -0.1511, -0.0591,  0.1000,  0.1050,  0.0743],
       requires_grad=True)
torch.Size([6])

Parameter containing:
tensor([1., 1., 1., 1., 1., 1.], requires_grad=True)
torch.Size([6])

Parameter containing:
tensor([0., 0., 0., 0., 0., 0.], requires_grad=True)
torch.Size([6])

Parameter containing:
tensor([[[[-1.0620e-01,  5.6997e-02, -7.9542e-03],
          [-6.6638e-02, -1.0529e-02,  1.3376e-01],
          [ 7.1680e-02,  1.3388e-01,  1.2293e-01]],

         [[ 9.2092e-02,  2.4215e-02, -1.2708e-01],
          [ 1.9943e-03, -8.7654e-02,  1.0564e-01],
          [-1.2967e-01, -1.2077e-01, -4.4365e-02]],

         [[ 9.9798e-04, -7.9709e-02,  2.7571e-02],
          [-1.4309e-02,  1.1243e-01, -1.1661e-01],
          [ 7.5213e-02,  7.6132e-02,  1.4844e-02]],

         [[ 1.2713e-01, -7.3697e-02,  9.4301e-02],
          [ 7.7325e-02,  9.6845e-02, -1.0990e-01],
          [ 6.2486e-02,  1.0107e-01,  3.0378e-02]],

         [[-1.0599e-01,  2.7444e-02, -8.8193e-02],
          [-1.0384e-01,  1.2580e-01,  4.1619e-02],
          [ 1.3596e-01, -1.2098e-01,  8.2317e-02]],

         [[-1.0979e-01,  9.2484e-02, -5.2828e-03],
          [ 7.7915e-02,  6.0981e-02,  9.0634e-02],
          [ 8.3001e-02,  7.1535e-02, -1.6206e-02]]],


        [[[ 1.1561e-01, -2.1935e-02, -8.5694e-03],
          [-4.9740e-03, -2.1594e-02,  9.7255e-02],
          [ 1.2904e-01,  7.2028e-02,  9.6564e-02]],

         [[-7.6498e-02, -1.2666e-01, -3.2563e-02],
          [ 9.0076e-02, -8.3288e-02,  1.1785e-01],
          [-4.3596e-02,  3.6950e-03, -5.0087e-02]],

         [[-2.9787e-02, -5.2824e-02, -9.9231e-02],
          [ 9.1963e-02,  7.7965e-02,  1.1397e-01],
          [ 1.3667e-02,  1.1007e-01, -4.1288e-02]],

         [[ 9.4790e-02, -6.8296e-02, -4.3310e-02],
          [-6.3128e-02,  2.3350e-02, -6.3908e-02],
          [-1.2005e-01, -6.2899e-02, -7.2392e-02]],

         [[-1.1934e-01, -4.5716e-02, -5.7582e-02],
          [ 8.1211e-06,  9.6752e-02, -4.1839e-02],
          [ 9.9383e-02, -4.9952e-02, -4.1875e-02]],

         [[ 1.0271e-01, -9.7970e-02, -2.5481e-02],
          [ 1.2039e-01,  1.7195e-02, -2.2504e-02],
          [ 6.3394e-02, -1.0446e-02,  9.7013e-02]]],


        [[[-6.2230e-02, -8.0188e-02, -4.3593e-02],
          [ 9.6622e-02,  7.5777e-02,  1.9751e-02],
          [ 4.6756e-02,  8.1505e-02,  2.1734e-02]],

         [[-4.0420e-02, -4.7027e-02,  2.7860e-02],
          [-4.5530e-04,  1.0848e-01, -9.7263e-02],
          [ 4.0441e-02, -2.3740e-03, -1.1751e-01]],

         [[-1.0342e-01,  1.4509e-02,  3.5800e-02],
          [-7.3109e-02, -4.4676e-02,  1.1477e-01],
          [ 1.0436e-01, -1.1468e-01,  1.1279e-01]],

         [[ 1.2757e-01, -5.4175e-02,  3.9229e-02],
          [ 1.2238e-01, -4.1751e-02,  1.0329e-02],
          [ 1.1175e-01, -1.3469e-01,  9.0738e-02]],

         [[-1.2890e-01,  1.0985e-01, -3.5065e-02],
          [-1.0353e-02, -1.1117e-01, -1.0932e-01],
          [ 2.3825e-02, -5.1328e-02,  1.0952e-01]],

         [[-1.2119e-01, -1.1721e-01,  3.9911e-02],
          [-9.3294e-02,  3.6181e-02, -9.2453e-02],
          [-1.0519e-01,  5.3727e-02,  4.4648e-03]]],


        ...,


        [[[ 6.6163e-02, -1.0531e-01, -1.0589e-01],
          [ 7.9671e-02, -3.3005e-02, -1.0760e-01],
          [ 1.4868e-02,  1.4420e-02, -9.6573e-02]],

         [[ 2.2414e-02, -1.5715e-02,  2.4232e-02],
          [ 2.3479e-02, -8.7212e-02, -1.8911e-02],
          [ 9.3712e-02,  1.0342e-01,  5.4269e-02]],

         [[-9.8044e-02,  7.1834e-02, -1.0760e-01],
          [-9.7597e-02,  9.9367e-02, -9.9010e-02],
          [ 2.6155e-02, -1.3208e-01,  1.0316e-02]],

         [[ 7.7097e-02,  1.0838e-01,  2.7527e-02],
          [-4.3391e-02,  1.3416e-01, -1.1440e-01],
          [-3.8224e-02, -2.7650e-03, -5.9436e-03]],

         [[ 6.5886e-02,  1.1016e-02, -1.0989e-01],
          [ 4.2206e-02, -9.2878e-02,  7.4586e-02],
          [ 1.1299e-01, -1.1260e-01, -7.2581e-02]],

         [[ 8.6093e-03,  3.0288e-02,  7.8243e-02],
          [-6.7512e-03, -8.5671e-02,  8.3012e-02],
          [-2.4528e-02,  1.7389e-02,  2.0112e-02]]],


        [[[ 3.9985e-02,  6.4231e-03,  1.3579e-01],
          [ 8.8007e-02, -1.8449e-02,  2.9483e-02],
          [-5.8890e-02,  3.1275e-02,  1.1129e-01]],

         [[ 9.9826e-02, -1.0343e-01,  1.7781e-02],
          [-1.5528e-02, -1.2074e-01, -5.4819e-02],
          [-8.1487e-02,  3.7535e-02, -6.7128e-02]],

         [[-2.2612e-02, -4.7612e-02, -1.3335e-01],
          [ 3.7972e-02, -1.2762e-01,  5.4009e-02],
          [ 9.0579e-02,  5.4727e-02, -9.1461e-02]],

         [[ 8.0858e-02,  1.4411e-03, -1.2739e-01],
          [ 1.0097e-01,  8.3857e-02, -8.0914e-02],
          [-1.9743e-02,  1.1509e-01,  8.2933e-02]],

         [[-3.0184e-02,  1.0409e-01,  2.2486e-02],
          [-7.8506e-02, -7.7744e-02, -2.8042e-02],
          [-3.3265e-02,  9.1861e-02,  4.7874e-02]],

         [[ 3.1688e-02,  1.2607e-01,  8.8575e-02],
          [ 1.0217e-01,  2.8618e-02,  8.4546e-02],
          [ 2.8103e-02,  1.2679e-01,  2.4444e-02]]],


        [[[ 7.9484e-02, -1.1017e-02, -2.9063e-02],
          [ 5.4235e-02,  1.1226e-01, -1.0663e-01],
          [ 9.8365e-02, -2.1643e-02,  6.3686e-02]],

         [[ 3.0368e-03,  1.2335e-03,  1.3460e-02],
          [-5.6941e-02, -9.9266e-02,  3.3269e-02],
          [ 8.6997e-02,  1.1879e-01, -1.2027e-02]],

         [[ 3.4441e-02,  1.3346e-01,  1.4495e-03],
          [ 6.1219e-02,  8.4678e-02, -4.3233e-02],
          [ 1.3061e-01, -1.1880e-01, -1.2782e-01]],

         [[ 3.4226e-02,  7.5535e-02, -7.4717e-02],
          [ 8.2468e-03, -9.3862e-02, -5.3166e-02],
          [ 1.3202e-01,  7.6724e-02,  6.3903e-02]],

         [[-5.8022e-02, -7.8344e-02, -4.7197e-02],
          [ 3.7977e-02,  8.6118e-02,  1.1670e-02],
          [-1.3180e-01, -3.9207e-02,  1.3028e-01]],

         [[-5.4157e-03, -7.3742e-02,  4.5027e-02],
          [-2.8969e-02, -2.3086e-02, -3.3792e-02],
          [ 7.5957e-02,  3.4847e-02,  1.3248e-01]]]], requires_grad=True)
torch.Size([32, 6, 3, 3])

Parameter containing:
tensor([-0.0801,  0.0075,  0.0469, -0.0886,  0.0583, -0.0399, -0.0551,  0.0094,
        -0.0457,  0.1121,  0.0496, -0.0684,  0.1093,  0.0834, -0.0910, -0.1112,
        -0.0711, -0.0641, -0.0981,  0.0356,  0.1234, -0.0284,  0.0813,  0.0188,
        -0.0063, -0.0851,  0.1308, -0.0041, -0.0926, -0.0906,  0.1180,  0.0142],
       requires_grad=True)
torch.Size([32])



'''

net.named_parameters()会返回两部分内容,分别是模型中的属性名称和对应的参数值:

for name, para in net.named_parameters():
    print(name)
    print(para)
    print(para.shape)
    print()

'''
输出例子:
conv1.weight
torch.Size([6, 3, 3, 3])
Parameter containing:
tensor([[[[ 0.0215,  0.1517,  0.1218],
          [ 0.1887, -0.0702,  0.1366],
          [-0.0947,  0.0794, -0.1096]],

         [[-0.0045,  0.0683, -0.0814],
          [ 0.0367, -0.0305, -0.1630],
          [ 0.0413,  0.0197,  0.1726]],

         [[ 0.0212,  0.1100,  0.0536],
          [ 0.1513,  0.0163,  0.1070],
          [-0.1378, -0.1698,  0.1431]]],

'''

输出的conv1.weight,正好对应着最上面定义的self.conv1 = nn.Conv2d(3,6,kernal_size=1,padding=1)中的conv1中的weight参数。

指定参数的更新方式:

net = Net()
ignored_params = list( map(id, net.conv1.parameters()) )
base_params = filter ( lambda p: id(p) not in ignored_params, net.parameters() )
optimizer = torch.optim.SGD( [
    {'params':base_params},
    {'params':net.conv1.parameters(), 'lr':1e-3}
                             ], lr = 1e-2, momentum=0.9 )

在{}中可以对某些参数指定更新方式,没有设置的更新细节,可以在()后面统一规定