Faster R-CNN 之初入篇

1,110 阅读3分钟

该系列文章通过 图片小M在 Faster R-CNN 中的奇妙之旅来解构 Faster R-CNN 网络结构

初入篇代码见: pytorch-tutorial/framework.py

有一个原始部落被称为图之国,那里生活着各色各样的图片,每一张图片在出生之时,身上不同的口袋里装有不同的物品,而每一张图片一生的宿命就是弄清楚哪只口袋里装着东西,装了什么。

在这个部落里居住着许多智慧的建筑师,他们建造着一个又一个的迷宫。而这些迷宫有的结构复杂,有的构建巧妙,参差不齐,短则几天,长则一生可能都走不完。而每一张走完迷宫的图片,便可以完成这一生的使命。

其中有一位建筑师RHRS历时数年创建了名为Faster R-CNN的迷宫,该迷宫构造玄妙,成为当时最为著名的迷宫之一。

这一天,图片小M 来到了这个迷宫之前,敲开了大门,只见在迷宫入口的玄关墙上,刻着如下的文字

欲过此迷宫者,必先历下之五关,中难易皆存,愿君保持思考、保持耐心方可。

关卡一: Transform 对图片进行预处理,以使其满足进入Faster R-CNN最基本的要求
关卡二: Backbone 获取图片的特征图
关卡三: RPN 给出图片含有检测目标的可能区域
关卡四: ROI Pooling & Flatten 将所有的可能区域进行处理,以使之满足后序处理的要求
关卡五: 分类与回归 对获取到的可能区域进行分类得知其是何物,进行回归得知其在何处

最终经过 Post Process方可 将得到的类别与位置在原图上标志出来

文字的下方该给出了整个迷宫的设计图,图下文字注明,各关卡详细设计图皆由途中悉数给出。

image.png

正在小M思忖的同时,一个卷轴掉了下来,打开看来,迷宫构建的基础结构被完整的写在了下面,如:

class FasterRCNNBase(nn.Module):
    '''
    transform: 即关卡一,用于对图片做预处理
    backbone: 即关卡二,用于提取图片的特征(Feature Map)
    rpn: 即关卡三,生成图片中可能含有目标的区域(Region Proposal)
    roi_heads: 即关卡四、五的集成,用于从获取最后的含有目标的区域
    '''
    def __init__(self, transform, backbone, rpn, roi_heads):
        super(FasterRCNNBase, self).__init__()
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        
    def forward(self, images: List[Tensor], targets=None):
    
        raw_image_shape: List[Tuple[int, int]] = []
        for image in images:
            raw_image_shape.append((image.shape[1], image.shape[2]))

        '''第一关卡'''
        images, targets = self.transform(images, targets)  # 对图像进行预处理
        
        '''第二关卡'''
        features = self.backbone(images.tensors) # 将预处理后的图像输入到 backbone 得到特征图

        '''第三关卡'''
        # 将特征图输入到RPN中,得到 region proposals
        proposals, proposal_losses = self.rpn(images, features, targets) 
        
        '''第四、五关卡'''
        # 将 特征图、可能含有目标的区域、预处理后的图像、targets 传入 roi_heads 得到 检测目标
        detections, detector_losses = self.roi_heads(
            features, proposals, images, targets
        )
        
        ''' 最终,将预测的bboxes还原到原始图像尺度上 '''
        detections = self.transform.post_process(
            detections, images, raw_image_shape
        )

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        return losses, detections

看完了文字说明、设计图纸以及构建框架之后,小M非常有信心能走完整个迷宫。

带着卷轴,小M 推开了 第一关: Transform 的大门,那么 Transform 中究竟有什么呢?且听下回分解……