卷积神经网络中间特征图可视化

228 阅读1分钟

注意:如果网络中有拉直层,请先暂时注释掉,否则输出是一个一维向量,可视化没有意义。


代码:
import torch 
from torch import nn
from torchkeras import summary
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

def create():
    net = nn.Sequential()
    net.add_module("conv1",nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3))
    net.add_module("relu",nn.ReLU())
    net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))

    net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
    net.add_module("relu",nn.ReLU())
    net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))

    net.add_module("conv3",nn.Conv2d(in_channels=64,out_channels=32,kernel_size = 5))
    net.add_module("relu",nn.ReLU())
    net.add_module("pool3",nn.MaxPool2d(kernel_size = 2,stride = 2))

    net.add_module("conv4",nn.Conv2d(in_channels=32,out_channels=16,kernel_size = 5))
    net.add_module("relu",nn.ReLU())


    return net






def read_img():
    

    image_name = 'test/111.jpg'
    image = Image.open(image_name)


    import torchvision.transforms as transforms
    transform = transforms.Compose([
        transforms.ToTensor()
    ])


    image = transform(image)
    image = image.unsqueeze(0)

    return image


def get_row_col(num_pic):
    squr = num_pic ** 0.5
    row = round(squr)
    col = row + 1 if squr - row > 0 else row
    return row, col


def visualize_feature_map(out):
    feature_map = np.squeeze(out, axis=0)

    num_pic = feature_map.shape[2]
    row, col = get_row_col(num_pic)
 
    plt.figure()
    feature_map_combination = []
    for i in range(0, num_pic):
        feature_map_split = feature_map[:, :, i]
        feature_map_combination.append(feature_map_split)
        plt.subplot(row, col, i + 1)
        plt.imshow(feature_map_split.data.cpu().numpy())
        plt.axis('off')
        plt.title('feature_map_{}'.format(i+1),fontsize=7)
        plt.tight_layout()  # 会自动调整子图参数,使之填充整个图像区域
 
    plt.savefig('feature_map.png')
    plt.show()
 
    # 各个特征图按1:1 叠加
    feature_map_sum = sum(ele for ele in feature_map_combination)
    plt.imshow(feature_map_sum.data.cpu().numpy())
    plt.savefig("feature_map_sum.png")



if __name__ == "__main__":
    image = read_img()
    net = create()
    out = net(image)
    out = out.squeeze()
    out = out.permute(1,2,0).contiguous() # 按各维度的索引进行排列
    visualize_feature_map(out)



输入图片:


运行结果:


16张叠成一张: