注意:如果网络中有展平层(Flatten,即"拉直层"),请先暂时将其注释掉;否则网络输出是一个一维向量,无法按通道画出特征图,可视化没有意义。
代码:
import torch
from torch import nn
from torchkeras import summary
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
def create():
    """Build the small convolutional network whose feature maps are visualized.

    The network has four convolution stages. Every convolution is followed by
    its own ReLU, and the first three stages end with a 2x2 max-pool.

    Returns:
        nn.Sequential: the freshly initialised (untrained) network. It takes
        a batch of 3-channel images and produces 16 output channels.
    """
    net = nn.Sequential()
    # Each module needs a unique name: ``add_module`` with a name that is
    # already registered replaces that entry instead of appending a new one,
    # so reusing "relu" would leave a single ReLU (after conv1) in the net.
    net.add_module("conv1", nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3))
    net.add_module("relu1", nn.ReLU())
    net.add_module("pool1", nn.MaxPool2d(kernel_size=2, stride=2))
    net.add_module("conv2", nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5))
    net.add_module("relu2", nn.ReLU())
    net.add_module("pool2", nn.MaxPool2d(kernel_size=2, stride=2))
    net.add_module("conv3", nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5))
    net.add_module("relu3", nn.ReLU())
    net.add_module("pool3", nn.MaxPool2d(kernel_size=2, stride=2))
    net.add_module("conv4", nn.Conv2d(in_channels=32, out_channels=16, kernel_size=5))
    net.add_module("relu4", nn.ReLU())
    return net
def read_img(image_name='test/111.jpg'):
    """Load an image file as a batched 3-channel float tensor.

    Args:
        image_name: path of the image to open. Defaults to the sample path
            this script was originally hard-coded to use.

    Returns:
        torch.Tensor: the ``ToTensor`` conversion of the image with a leading
        batch dimension of size 1 added, i.e. shape (1, 3, H, W).
    """
    # Imported locally (as the original did) so torchvision is only required
    # when an image is actually read.
    import torchvision.transforms as transforms

    # The context manager closes the file deterministically. ``convert``
    # loads the pixel data and returns a new image, so the result stays
    # usable after the file is closed. Forcing RGB keeps grayscale / RGBA /
    # CMYK files compatible with a network whose first layer has 3 inputs.
    with Image.open(image_name) as img:
        image = img.convert("RGB")

    image = transforms.ToTensor()(image)
    # Add the batch axis expected by nn.Conv2d.
    return image.unsqueeze(0)
def get_row_col(num_pic):
    """Choose a near-square subplot grid that can hold ``num_pic`` images.

    Args:
        num_pic: number of images to lay out.

    Returns:
        tuple: ``(row, col)`` where ``row`` is the rounded square root of
        ``num_pic`` and ``col`` is one larger only when the square root was
        rounded down.
    """
    root = num_pic ** 0.5
    n_rows = round(root)
    # Rounding down loses capacity, so make up for it with an extra column.
    if root - n_rows > 0:
        n_cols = n_rows + 1
    else:
        n_cols = n_rows
    return n_rows, n_cols
def visualize_feature_map(out):
    """Plot every channel of a feature map, then the sum of all channels.

    The per-channel grid is saved to ``feature_map.png`` and shown; the
    channel-wise sum is then drawn on a separate figure and saved to
    ``feature_map_sum.png``.

    Args:
        out: feature map in channels-last layout, shape (H, W, C) or
            (1, H, W, C). May be a ``torch.Tensor`` (on any device, with or
            without gradient history) or anything ``np.asarray`` accepts.

    Raises:
        ValueError: if the input is not 3-D once a leading batch axis of
            size 1 has been removed (e.g. the output of a flatten layer).
    """
    if isinstance(out, torch.Tensor):
        # detach() drops autograd history so the tensor can become an array.
        feature_map = out.detach().cpu().numpy()
    else:
        feature_map = np.asarray(out)

    # Strip only an explicit batch axis. Squeezing axis 0 unconditionally
    # would throw away the height axis of an (1, W, C) feature map.
    if feature_map.ndim == 4 and feature_map.shape[0] == 1:
        feature_map = feature_map[0]
    if feature_map.ndim != 3:
        raise ValueError(
            "expected a feature map of shape (H, W, C) or (1, H, W, C), "
            "got shape {}".format(feature_map.shape)
        )

    num_pic = feature_map.shape[2]
    row, col = get_row_col(num_pic)

    plt.figure()
    for i in range(num_pic):
        plt.subplot(row, col, i + 1)
        plt.imshow(feature_map[:, :, i])
        plt.axis('off')
        plt.title('feature_map_{}'.format(i + 1), fontsize=7)
    # Automatically adjusts subplot parameters so the grid fills the figure.
    plt.tight_layout()
    plt.savefig('feature_map.png')
    plt.show()

    # Overlay all feature maps with equal (1:1) weight. A fresh figure is
    # required: if show() did not close the grid (non-interactive backends),
    # imshow would otherwise draw over the last subplot of the grid figure.
    plt.figure()
    plt.imshow(feature_map.sum(axis=2))
    plt.savefig("feature_map_sum.png")
if __name__ == "__main__":
    # Script entry point: run one image through the network and plot the
    # feature maps produced by its last layer.
    image = read_img()
    net = create()
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        out = net(image)
    # Remove just the batch axis. A bare squeeze() would also drop any
    # spatial axis of size 1 and make the 3-axis permute below fail.
    out = out.squeeze(0)
    # Reorder (C, H, W) -> (H, W, C): the plotting code indexes channels last.
    out = out.permute(1, 2, 0).contiguous()
    visualize_feature_map(out)