faster rcnn 继承于GeneralizedRCNN
对于 GeneralizedRCNN 类,其中有4个重要的接口:
- transform : 主要是标准化和把图片缩放到固定大小,后续说明
- backbone :一般是VGG、ResNet、MobileNet 等网络
- rpn:通过rpn生成proposals 和 proposal_losses
- roi_heads:roi pooling + 分类
class GeneralizedRCNN(nn.Module):
"""
Main class for Generalized R-CNN.
Args:
backbone (nn.Module):
rpn (nn.Module):
roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
detections / masks from it.
transform (nn.Module): performs the data transformation from the inputs to feed into
the model
"""
def __init__(self, backbone, rpn, roi_heads, transform):
super(GeneralizedRCNN, self).__init__()
self.transform = transform
self.backbone = backbone
self.rpn = rpn
self.roi_heads = roi_head
前向传播
def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
"""
Args:
images (list[Tensor]): images to be processed
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During testing, it returns list[BoxList] contains additional fields
like `scores`, `labels` and `mask` (for Mask R-CNN models).
"""
images 参数:(list[Tensor(C,H,W)*Batch_size])
targets 可选参数:传递gt,是一个列表,列表中的每个元素都是一个字典,字典中包含了与图像中的真实目标相关的信息。"boxes":一个张量(Tensor),包含了真实目标框的坐标信息。通常是一个形状为 [N, 4] 的张量,其中 N 是目标框的数量,每行表示一个目标框的坐标信息,通常是左上角和右下角的坐标。 其他键值对:可能还包含其他与目标相关的信息,比如类别标签、分割掩码等。
输出结果 :(list[BoxList] or dict[Tensor])
在训练过程中,模型返回一个字典,其中包含了损失信息, Dict[str, Tensor],键是各种损失名称
loss_name,值是损失张量。
在测试过程中,模型返回一个列表,其中包含了检测结果的信息。每个元素是一个字典,表示一张图像的检测结果。这个字典包的key为:检测框的置信度 scores、类别标签 labels、分割掩码 mask 等。因此,每个元素的类型是 List[Dict[str, Tensor]]。
original_image_sizes: List[Tuple[int, int]] = []
for img in images:
val = img.shape[-2:]
assert len(val) == 2
original_image_sizes.append((val[0], val[1])) # 记录变换前original_images_sizes
images, targets = self.transform(images, targets)
# transfrom的定义为class GeneralizedRCNNTransform(nn.Module),对images和target都进行resize等操作
# 这里transform返回的images是ImageList类型(Tensors:tensor,image_sizes:List)
这里 transform 主要包括标准化和将图像缩放到固定大小
需要说明的是,把缩放后的图像输入网络,那么网络输出的检测框也是在缩放后的图像上的。但是实际中我们需要的是在原始图像的检测框,为了对应起来,所以需要记录变换前original_images_sizes。
进入主要网络流程
features = self.backbone(images.tensors)
将transform 后的图像进入backbone(一般包括VGG,ResNet,MobileNet等网络) 提取特征
if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)])
类型检查和转化,如果 features 是 torch.Tensor 类型的实例,将 features 变量转换为一个字典类型的对象,其中包含一个键值对,键是字符串 '0',值是原始的 features 变量。
proposals, proposal_losses = self.rpn(images, features, targets)
通过 rpn 模块生成proposals和 proposal_losses
proposals 是 rpn 生成的 bbox ,type:List[tensor(n,4)*Batch_size],并且按置信度降序排序
proposal_losses 是 rpn 阶段的loss,包括置信度 loss 和 bbox回归 loss
roi_heads包括roi pooling + 分类
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
接着进入 roi_heads 模块,对候选区域(proposals)进行进一步处理,以产生最终的检测结果。
detctions : 每个元素代表一张输入图像进 roi_heads 处理后的检测结果。每个检测结果通常包含以下信息boxes,labels,scores.返回List[dict[Tensor]*batch_size],key为boxes,labels,scores
detector_losses :roi_head阶段的loss包括:class分类 loss 和 bbox回归 loss
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
后续处理,经 postprocess 模块(进行 NMS,同时将 box 通过 original_images_size映射回原图,即transform阶段Resize的逆操作)
losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
if torch.jit.is_scripting():
if not self._has_warned:
warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
self._has_warned = True
return losses, detections
else:
return self.eager_outputs(losses, detections)
根据根据train和test阶段不同返回不同值,训练阶段返回 losses,测试阶段返回 detections