携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第26天,点击查看活动详情
RPN 网络训练
RPN 网络训练是 Faster RCNN 训练的第一步。主要思想是先使用模型将 RPN 网络进行初始化,然后对其进行训练。在训练中用到的的主要函数是 Train_rpn 函数。
首先是参数设置如下:
这里比较重要的是将 cfg.TRAIN.PROPOSAL_METHOD 参数设置为 gt,在下文会讲到这么设置的原因。在对基本参数进行设置以后,接下来就是本次训练的基础–获取 imdb 和 roidb 格式的训练数据。
首先介绍一下 imdb 和 roidb 是什么。imdb 是一个图片数据库类,内含数据库 的名字;Roidb 也叫 roi 数据库,其实就是目标检测包围盒。如下表所示,主要包括以下几个类成员:
获取训练数据主要用到 get_roidb() 函数以用于返回 roidb 数据对象。首先它会 在 cache 路径下找到以扩展名 pkl 结尾的缓存,这个文件是通过 cPickle 工具将 roidb 序列化存储的。如果该文件存在,那么它会先读取这里的内容,以提高效率。否则它将调用 _load_pascal_annotation 这个私有函数加载 roidb 中的数据,并将其 保存在缓存文件中,返回 roidb。一会下文中会讲到这一点。
如下图所示,get_roidb() 函数中需要用到的重要训练数据就是 imdb,而在 get_imdb 函数中调用了 pascal_voc 函数来对 imdb 数据进行创建。所以可见 pascal_voc 的重要性。pascal_voc 主要是用来组织输入的图片数据,主要设置了数据集的路径,图片名称的索引等,但是并不储存实际的图片信息。实际上,pascal_voc 类是 imdb 类的一个子类;当 imdb 数据己经获得后,get_roidb() 函数紧 接着向 set_proposal_method() 函数请求设置产生建议区域的方法,实际也是向 imdb 中添加 roidb 数据,这就用到了 set_proposal_method() 这个函数。
在 set_proposal_method() 函数中,整体流程是用 eval() 方法进行解析数据,使其有效,然后再将其传入 roidb_handler 中。如下图所示,首先用到了之前的 train_rpn() 函数,因为它里面设置了 cfg.TRAIN.PROPOSAL_METHOD= 'gt'(默认值是 selective search,先前用于 Fast RCNN 的),也就是请求解析数据的方法,这也是为什么前文参数中这样设置的原因。
接下来开始通过 train_rpn 中设置的 cfg.TRAIN.PROPOSAL_METHOD 参数对 gt_roidb 函数进行请求:在这个函数里使用了_load_pascal_annotation() 方法,作用是通过解析 XML 文件获得 ground truth 的 roi,在 _load_pascal_annotation 函数中,会根据每个图像的索引,到 Annotations 这个文件夹下去找相应的 XML 标注数据,然后加载所有的 bounding box 对象,并去除所有的“复杂”对象。这个时候就从 imdb 获得了最初的 roidb 格式的数据,但这还不是训练时的 roidb 数据。
在获得了 roidb 格式的数据之后,接下来的工作就是获取最终的训练数据:
如下图所示,在得到最初的 roidb 格式的数据后,继续回到 get_roidb() 函数 中,通过 get_training_roidb() 函数来得到最终用于训练的 roidb 数据,如下图所示,在 get_training_roidb() 函数中先根据 cfg.TRAIN.USE_FLIPPED 参数来判断是否需要对 roi 进行水平镜像翻转,然后通过调用 append_flipped_images() 方法来添加镜像 roi,这样做的原因在于其可以提髙最终网络的训练结果,将镜像 roi 添加完毕后继续回到 get_training_roidb() 函数中,最后经过 prepare_roidb() 函数向 roidb 中再添加一些额外的信息诸如图片路径,宽,高等信息就可以用来进行训练了。