Pytorch中的Grad-CAM:使用前向和后向钩子

383 阅读10分钟

我注意到了这个叫做Grad-CAM的技术,它能够检查卷积神经网络如何预测其输出。例如,在一个分类器中,你可以深入了解你的神经网络是如何使用输入来进行预测的。这一切都始于描述它的原始论文。在这篇文章中,我们将使用Pytorch库来实现它,你可以应用于任何卷积神经网络,不需要改变你已有的神经网络模块中的任何东西。

image.png 我在Medium上读到一篇名为《在PyTorch中实现Grad-CAM》的论文,作者是Stepan Ulyanin,他启发我写了类似的东西,但方式略有不同。Stepan提出了一种方法,即你需要重写你的模型的前向函数,以便计算Grad-CAM。感谢Pytorch,我们可以在不改变前向函数的情况下通过注册前向和后向钩子来做同样的事情。我希望这篇文章能对Stepan写的那篇精彩的文章有一点贡献。

让我们深入了解一下!

1.加载并检查预训练的模型

为了演示Grad-CAM的实现,我将使用来自Kaggle的胸部X射线数据集和我做的一个预训练的分类器,能够将X射线分类为有或没有肺炎。

model_path = "your/model/path/"# instantiate your modelmodel = XRayClassifier() # load your model. Here we're loading on CPU since we're not going to do # large amounts of inferencemodel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # put it in evaluation mode for inferencemodel.eval()

接下来,让我们检查一下模型的架构。由于我们对了解输入图像的哪些方面有助于预测感兴趣,我们需要确定最后一个卷积层,特别是其激活函数。这一层代表了模型为了对其输入进行分类而学习的最复杂的特征。因此,它最能帮助我们理解模型的行为。

import torchimport torch.nn as nnimport torch.nn.functional as F# hyperparametersnc = 3 # number of channelsnf = 64 # number of features to begin withdropout = 0.2device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# setup a resnet block and its forward functionclass ResNetBlock(nn.Module):    def __init__(self, in_channels, out_channels, stride=1):        super(ResNetBlock, self).__init__()        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)        self.bn1 = nn.BatchNorm2d(out_channels)        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)        self.bn2 = nn.BatchNorm2d(out_channels)                self.shortcut = nn.Sequential()        if stride != 1 or in_channels != out_channels:            self.shortcut = nn.Sequential(                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),                nn.BatchNorm2d(out_channels)            )            def forward(self, x):        out = F.relu(self.bn1(self.conv1(x)))        out = self.bn2(self.conv2(out))        out += self.shortcut(x)        out = F.relu(out)        return out# setup the final model structureclass XRayClassifier(nn.Module):    def __init__(self, nc=nc, nf=nf, dropout=dropout):        super(XRayClassifier, self).__init__()        self.resnet_blocks = nn.Sequential(            ResNetBlock(nc,   nf,    stride=2), # (B, C, H, W) -> (B, NF, H/2, W/2), i.e., (64,64,128,128)            ResNetBlock(nf,   nf*2,  stride=2), # (64,128,64,64)            ResNetBlock(nf*2, nf*4,  stride=2), # (64,256,32,32)            ResNetBlock(nf*4, nf*8,  stride=2), # (64,512,16,16)            ResNetBlock(nf*8, nf*16, stride=2), # (64,1024,8,8)        )        self.classifier = nn.Sequential(            nn.Conv2d(nf*16, 1, 8, 1, 0, bias=False),            nn.Dropout(p=dropout),            nn.Sigmoid(),        )    def forward(self, input):        output = self.resnet_blocks(input.to(device))        output = self.classifier(output)        return output

这个模型被准备用来接收256x256的3个通道。因此,它的输入预计会有[批量大小,3,256,256]的形状。每个ResNet块以ReLU激活函数结束。对于我们的目标,我们需要选择最后的ResNet块。

XRayClassifier(  (resnet_blocks): Sequential(    (0): ResNetBlock(      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )    (1): ResNetBlock(      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )    (2): ResNetBlock(      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )    (3): ResNetBlock(      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )    (4): ResNetBlock(      (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)      (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)      (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (shortcut): Sequential(        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      )    )  )  (classifier): Sequential(    (0): Conv2d(1024, 1, kernel_size=(8, 8), stride=(1, 1), bias=False)    (1): Dropout(p=0.2, inplace=False)    (2): Sigmoid()  ))

在Pytorch中,我们可以使用模型的属性来做这个选择,非常容易。

model.resnet_blocks[-1]#ResNetBlock(#  (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)#  (bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)#  (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)#  (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)#  (shortcut): Sequential(#    (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)#    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)#  )#)

2.Pytorch注册钩子的方法

Pytorch有很多处理钩子的函数,基本上是允许你在前向或后向传递过程中处理流经模型的信息。你可以用它来检查中间的梯度值,对特定层的输出进行修改,等等。

在这里,我们将重点讨论nn.Module类的两个方法。让我们仔细看一下它们。

2.1. register_full_backward_hook(hook,prepend=False)**

这个方法在模块上注册了一个后向钩子,这意味着钩子函数将在*backward()*方法被调用时运行。

后向钩子函数接收模块本身的输入,相对于层的输入的梯度,以及相对于层的输出的梯度。

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

它返回一个torch.utils.hooks.RemovableHandle, 这允许你以后删除钩子。因此,把它分配给一个变量是很有用的。我们稍后会回到这个问题上。

2.2. register_forward_hook(hook,*,prepend=False,with_kwargs=False)**

这和前面的很相似,只是钩子函数在前向传递中运行,也就是说,当感兴趣的层处理其输入并返回其输出时。

钩子函数有一个稍微不同的签名。它让你能够访问该层的输出:

hook(module, args, output) -> None or modified output

它还返回一个Torch.utils.hooks.RemovableHandle。

3.在你的模型中添加后向和前向钩子

首先,我们需要定义我们的后向和前向钩子函数。为了计算Grad-CAM,我们需要相对于最后一个卷积层输出的梯度,并且我们需要它的激活,即该层激活函数的输出。因此,我们的钩子函数将只在推理和后向传递中为我们提取这些值。

# defines two global scope variables to store our gradients and activationsgradients = Noneactivations = Nonedef backward_hook(module, grad_input, grad_output):  global gradients # refers to the variable in the global scope  print('Backward hook running...')  gradients = grad_output  # In this case, we expect it to be torch.Size([batch size, 1024, 8, 8])  print(f'Gradients size: {gradients[0].size()}')   # We need the 0 index because the tensor containing the gradients comes  # inside a one element tuple.def forward_hook(module, args, output):  global activations # refers to the variable in the global scope  print('Forward hook running...')  activations = output  # In this case, we expect it to be torch.Size([batch size, 1024, 8, 8])  print(f'Activations size: {activations.size()}')

在定义了我们的钩子函数和将存储激活和梯度的变量后,我们需要在感兴趣的层中实际注册钩子:

backward_hook = model.resnet_blocks[-1].register_full_backward_hook(backward_hook, prepend=False)forward_hook = model.resnet_blocks[-1].register_forward_hook(forward_hook, prepend=False)

4.检索我们需要的梯度和激活值

现在我们已经为我们的模型设置了钩子,让我们加载一张图片,我们将对其进行Grad-CAM计算。

from PIL import Imageimg_path = "/your/image/path/"image = Image.open(img_path).convert('RGB')

这就是我们要使用的图像

我们需要对其进行预处理,以准备将其送入模型进行推理。

from torchvision import transformsfrom torchvision.transforms import ToTensorimage_size = 256transform = transforms.Compose([                               transforms.Resize(image_size, antialias=True),                               transforms.CenterCrop(image_size),                               transforms.ToTensor(),                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),                           ])img_tensor = transform(image) # stores the tensor that represents the image

现在我们需要使用这个图像张量作为输入进行前向处理。而且我们还需要做后向传递,以使我们的后向钩子发挥作用。

# since we're feeding only one image, it is a 3d tensor (3, 256, 256). # we need to unsqueeze to it has 4 dimensions (1, 3, 256, 256) as # the model expects it to.model(img_tensor.unsqueeze(0)).backward()# here we did the forward and the backward pass in one line.

我们的钩子函数返回如下:

Forward hook running...Activations size: torch.Size([1, 1024, 8, 8])Backward hook running...Gradients size: torch.Size([1, 1024, 8, 8])

现在我们可以使用梯度激活 变量来计算我们的热图了

5.计算Grad-CAM

为了计算Grad-CAM,我们将使用原始的论文方程和Stepan Ulyanin对它们的实现。

# pool the gradients across the channelspooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3])
import torch.nn.functional as Fimport matplotlib.pyplot as plt# weight the channels by corresponding gradientsfor i in range(activations.size()[1]):    activations[:, i, :, :] *= pooled_gradients[i]# average the channels of the activationsheatmap = torch.mean(activations, dim=1).squeeze()# relu on top of the heatmapheatmap = F.relu(heatmap)# normalize the heatmapheatmap /= torch.max(heatmap)# draw the heatmapplt.matshow(heatmap.detach())

下面是我们的热图

值得注意的是,我们通过前向钩子获得的激活包含1024个特征图,这些特征图捕捉了输入图像的不同方面,每个特征图的空间分辨率为8x8。

另一方面,我们通过后向钩获得的梯度代表了每个特征图对最终预测的重要性。通过计算梯度和激活的元素之积,我们得到了一个加权的特征图之和,突出了图像中最相关的部分。

最后,通过计算加权特征图的全局平均值,我们得到一个单一的热图,表明图像中对模型预测最重要的区域。这种被称为Grad-CAM的技术为模型的决策过程提供了一个可视化的解释,可以帮助我们解释和调试模型的行为。

6.结合原始图像和热图

下面的代码将图像绘制在另一个图像上。

from torchvision.transforms.functional import to_pil_imagefrom matplotlib import colormapsimport numpy as npimport PIL# Create a figure and plot the first imagefig, ax = plt.subplots()ax.axis('off') # removes the axis markers# First plot the original imageax.imshow(to_pil_image(img_tensor, mode='RGB'))# Resize the heatmap to the same size as the input image and defines# a resample algorithm for increasing image resolution# we need heatmap.detach() because it can't be converted to numpy array while# requiring gradientsoverlay = to_pil_image(heatmap.detach(), mode='F')                      .resize((256,256), resample=PIL.Image.BICUBIC)# Apply any colormap you wantcmap = colormaps['jet']overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8)# Plot the heatmap on the same axes, # but with alpha < 1 (this defines the transparency of the heatmap)ax.imshow(overlay, alpha=0.4, interpolation='nearest', extent=extent)# Show the plotplt.show()

这就是结果。由于它是一个正常的X射线,该模型看的最多的是正常X射线中预期的正常结构。

在另一个例子中,我们有一个肺炎的X射线。而Grad-CAM正确地显示了医生为确定是否有肺炎而必须查看的胸部X光片的区域。

最后,要从你的模型中删除钩子,你只需要调用每个句柄中的*remove()*方法。

backward_hook.remove()forward_hook.remove()

总结

我希望这篇文章有助于澄清Grad-CAM是如何工作的,如何使用Pytorch实现它,如何在不改变原始模型前向功能的情况下通过使用前向和后向钩子来实现它。

我想感谢Stepan Ulyanin的文章,感谢他帮助我更好地理解Grad-CAM。我希望我也能对你有所贡献。