本文已参与「新人创作礼」活动,一起开启掘金创作之路。
首先先放代码:
analyze_feature_map.py
import torch
from backbone import MobileNetV3_large
import matplotlib.pyplot as plt
import numpy as np
# create model
model = MobileNetV3_large()
# load model weights
model_weight_path = 'MobileNetV3_Large.mdl'
model.load_state_dict(torch.load(model_weight_path))
# print(model)
# 获取模型的所以状态
modelstate = model.state_dict()
# 获取模型的全部关键字
weights_keys = model.state_dict().keys()
print(weights_keys)
for key in weights_keys:
# remove num_batches_tracked para(in bn)
if "num_batches_tracked" in key:
continue
# 获取对应字典中的数值参数
# [kernel_number, kernel_channel, kernel_height, kernel_width]
weight_t = model.state_dict()[key].numpy()
# print('weight_t.size', weight_t.size)
# read a kernel information
# 根据切片操作还可以对特定的卷积核进行状态查看
# k = weight_t[0, :, :, :]
# calculate mean, std, min, max
weight_mean = weight_t.mean()
weight_std = weight_t.std(ddof=1)
weight_min = weight_t.min()
weight_max = weight_t.max()
print("mean is {}, std is {}, min is {}, max is {}"
.format(weight_mean, weight_std, weight_max, weight_min))
# plot hist image
plt.close()
# 将卷积核的权重展开成一个一维的向量
weight_vec = np.reshape(weight_t, [-1])
# hist绘制直方图,bins将min-max区间平分成50等份,再统计每一个小份之间的数量
plt.hist(weight_vec, bins=50)
plt.title(key)
# 保存图像的方法
# plt.savefig('outputs.jpg')
plt.show()
在调试过程中,可以查看全部的参数的
具体思路就是通过model.state_dict()来获得整个模型中的全部参数,然后得出的结果是一个字典,然后可以在这个字典中的关键字来获取每一个卷积层的具体的权重参数或者是偏置参数,注意BN层是没有bias参数的,只有weight参数。
然后得到具体卷积层的参数之后,就可以查看卷积核的数量,通道,长宽等具体参数了,排列顺序为:
[kernel_number, kernel_channel, kernel_height, kernel_width]
得到整层参数之后,可以借助matplotlib来绘制等等