faster rcnn 源码(1)——GeneralizedRCNN

721 阅读3分钟

faster rcnn 继承于GeneralizedRCNN

对于 GeneralizedRCNN 类,其中有4个重要的接口:

  1. transform : 主要是标准化和把图片缩放到固定大小,后续说明
  2. backbone :一般是VGG、ResNet、MobileNet 等网络
  3. rpn:通过rpn生成proposals 和 proposal_losses
  4. roi_heads:roi pooling + 分类
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Args:
        backbone (nn.Module):
        rpn (nn.Module):
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model
    """

    def __init__(self, backbone, rpn, roi_heads, transform):
        super(GeneralizedRCNN, self).__init__()
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_head

前向传播

def forward(self, images, targets=None):
    # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    """
    Args:
        images (list[Tensor]): images to be processed
        targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

    Returns:
        result (list[BoxList] or dict[Tensor]): the output from the model.
            During training, it returns a dict[Tensor] which contains the losses.
            During testing, it returns list[BoxList] contains additional fields
            like `scores`, `labels` and `mask` (for Mask R-CNN models).

    """

images 参数:(list[Tensor(C,H,W)*Batch_size])

targets 可选参数:传递gt,是一个列表,列表中的每个元素都是一个字典,字典中包含了与图像中的真实目标相关的信息。"boxes":一个张量(Tensor),包含了真实目标框的坐标信息。通常是一个形状为 [N, 4] 的张量,其中 N 是目标框的数量,每行表示一个目标框的坐标信息,通常是左上角和右下角的坐标。 其他键值对:可能还包含其他与目标相关的信息,比如类别标签、分割掩码等。

输出结果 :(list[BoxList] or dict[Tensor])

在训练过程中,模型返回一个字典,其中包含了损失信息, Dict[str, Tensor],键是各种损失名称 loss_name,值是损失张量。

在测试过程中,模型返回一个列表,其中包含了检测结果的信息。每个元素是一个字典,表示一张图像的检测结果。这个字典包的key为:检测框的置信度 scores、类别标签 labels、分割掩码 mask 等。因此,每个元素的类型是 List[Dict[str, Tensor]]

original_image_sizes: List[Tuple[int, int]] = []
for img in images:
    val = img.shape[-2:]
    assert len(val) == 2
    original_image_sizes.append((val[0], val[1])) # 记录变换前original_images_sizes
images, targets = self.transform(images, targets)
# transfrom的定义为class GeneralizedRCNNTransform(nn.Module),对images和target都进行resize等操作 
# 这里transform返回的images是ImageList类型(Tensors:tensor,image_sizes:List)

这里 transform 主要包括标准化和将图像缩放到固定大小

需要说明的是,把缩放后的图像输入网络,那么网络输出的检测框也是在缩放后的图像上的。但是实际中我们需要的是在原始图像的检测框,为了对应起来,所以需要记录变换前original_images_sizes。

进入主要网络流程

image.png

features = self.backbone(images.tensors)

将transform 后的图像进入backbone(一般包括VGG,ResNet,MobileNet等网络) 提取特征

if isinstance(features, torch.Tensor):
    features = OrderedDict([('0', features)])

类型检查和转化,如果 featurestorch.Tensor 类型的实例,将 features 变量转换为一个字典类型的对象,其中包含一个键值对,键是字符串 '0',值是原始的 features 变量。

proposals, proposal_losses = self.rpn(images, features, targets)

通过 rpn 模块生成proposals和 proposal_losses

proposals 是 rpn 生成的 bbox ,type:List[tensor(n,4)*Batch_size],并且按置信度降序排序 proposal_losses 是 rpn 阶段的loss,包括置信度 loss 和 bbox回归 loss

roi_heads包括roi pooling + 分类

detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)

接着进入 roi_heads 模块,对候选区域(proposals)进行进一步处理,以产生最终的检测结果。

detctions : 每个元素代表一张输入图像进 roi_heads 处理后的检测结果。每个检测结果通常包含以下信息boxes,labels,scores.返回List[dict[Tensor]*batch_size],key为boxes,labels,scores

detector_losses :roi_head阶段的loss包括:class分类 loss 和 bbox回归 loss

detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

后续处理,经 postprocess 模块(进行 NMS,同时将 box 通过 original_images_size映射回原图,即transform阶段Resize的逆操作)

losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)

if torch.jit.is_scripting():
    if not self._has_warned:
        warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
        self._has_warned = True
    return losses, detections
else:
    return self.eager_outputs(losses, detections)

根据根据train和test阶段不同返回不同值,训练阶段返回 losses,测试阶段返回 detections