极智开发 | 剖析 darknet entry_index 指针偏移逻辑

1,045 阅读3分钟

  一起养成写作习惯!这是我参与「掘金日新计划 · 4 月更文挑战」的第21天,点击查看活动详情

欢迎关注我的公众号 [极智视界],获取我的更多笔记分享

  大家好,我是极智视界。本文分析一下 darknet entry_index 输出指针偏移逻辑。

   Yolo 在目标检测任务中使用广泛,之前也写过一篇关于 yolo 中 route 算子的文章《【模型推理】谈谈 darknet yolo 的 route 算子》,有兴趣的同学可以查阅。

   这里分析一下 yolo 层后处理中 entry_index 的指针偏移逻辑,还是挺有精华的。

1、yolo 层输出数据排布

  以 yolov4 为例,输出为三个 yolo 分支(可能截图比较小,你能看到三个头就行):

   在 cfg 里你还能看到 yolo 层的一些信息如下:

   对于推理来说我们只需要关心 mask、anchors、classes 再加上 nms 阈值和置信度阈值就可以了。那么 yolo 层的输出数据是怎么排布的呢?首先三个 yolo 层的数据肯定是切开独立的,拿其中一个来说:

  (1)数据按四维 [N, C, H, W] 来说,N 为 batch,C 为 3 * (5 + classes)、H / W 为 feature_map 高和宽。需要解释一下 C,C = 3 * (1 + 4 + classes),其中 1 表示置信度,4 为检测框位置信息,classes 为类别数,即每个类别给出一个检测得分,乘 3 表示每个格子有 3 个锚框。这样就形成了 yolo 层输出的四维数据排布;

   (2)darknet 里会用一维动态数组来存放 yolo 层的输出数据,这里就涉及到怎么将四维数据转换为一维数据的问题。darknet 里是这么做的,假设四维数据为 [N, C, H, W] ,每个维度对应的索引为 [n, c, h, w],那么展开就是 n*C*H*W + c*H*W + h*W + w,按这样的逻辑存放到 *output 中。

2、entry_index 实现逻辑

   先来看一下 entry_index 函数的实现代码:

static int entry_index(layer l, int batch, int location, int entry)
{
    int n =   location / (l.w*l.h);
    int loc = location % (l.w*l.h);
    return batch*l.outputs + n*l.w*l.h*(4+l.classes+1) + entry*l.w*l.h + loc;
}

   这个指针偏移操作代码实现很简单,也挺讲究,充分反映了上面说的 yolo 层输出的数据排布。你可以把函数的 return 部分对应到上面我们说的 四维映射到一维的过程,即 [n, c, h, w] -> n*C*H*W + c*H*W + h*W + w,开始奇妙之旅,对照一下,n*C*H*W -> batch*l.outputsc*H*W -> n*(4+1+l.calsses)h*W -> l.w*l.hw -> loc,这样是不是很清晰,但你如果仔细点应该可以发现少了个 entry,这个下面再说。

3、yolo 输出处理逻辑

  来看一个函数 yolo_num_detections,这个函数的作用是统计三个 yolo 分支输出的检测框的数量:

int yolo_num_detections(layer l, float thresh)
{
    int i, n;
    int count = 0;
    for(n = 0; n < l.n; ++n){
        for (i = 0; i < l.w*l.h; ++i) {
            int obj_index  = entry_index(l, 0, n*l.w*l.h + i, 4);
            if(l.output[obj_index] > thresh){
                ++count;
            }
        }
    }
    return count;
}

  这里有两个循环,外循环是 0 ~ n,n 是锚框的数量,为 3;内循环为 0 ~ h * w。来看:

int obj_index  = entry_index(l, 0, n*l.w*l.h + i, 4);

   再结合上面的 entry_index 的实现,entry = 4,所以在第一个外循环 n = 0 时,index 是基于 4 * w * h 的基础上做 0 ~ w * h 间递增,乘 4 的语义是跳过框的数据,然后可以尽情取置信度了,如下图:

   取了检测框的置信度后再与我们外部设置的 thresh 比较进行检测框的筛选:

if(l.output[obj_index] > thresh){
    ++count;
}

  如此走一遍内循环就对 h * w 每个格子里走了一遍,并进行了第一个锚框出来的检测框的筛选工作,然后再走外循环,接着进行第 2 个、第 3 个锚框出来的检测框的筛选。


   通过以上的分析,应该能对 yolo 层的输出结构和处理方式更加清晰了。


 【公众号传送】

《【web】antd 全局数据共享示例》


logo_show.gif