携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第18天,点击查看活动详情
非极大值抑制(Non Max Suppression)
在预测过程中,会产生许多预测边界框,首先会按置信度概率进行降序排序,然后将所有的候选框和第一个进行比较
主要用于清除多余边界框,因为非极大值抑制的存在限制了速度,在当下流行模型中已经逐渐取消对其使用,不过在早期目标检测模型中还是存在的。 计算边界框之间 IoU 然后找到那些彼此 IoU 比较的边界框,在这些边界框中仅是保留置信度比较高的边界框,非废弃其他边界框。关于对比边界框 IoU 是按类别进行,也就是说我们不能跨类别取计算边界框之间 IoU,也就是我们只会计算同属于一个类别或者说一个目标的多个边界框之间。
def non_max_suppression(bboxes,iou_threshold,prob_threshold,box_format="corners"):
# preds = [[类别],[置信度],[x1,y1,x2,y2]]
assert type(bboxes) == list
# 先用边界框置信度进行筛选,置信度不达标就会在一轮筛选中淘汰
bboxes = [box for box in bboxes if box[1] > prob_threshold]
# 对边界框按置信度进行降序排列
bboxes = sorted(bboxes, key=lambda x:x[1], reverse=True)
bboxes_after_nms = []
# 遍历 bboxes
while bboxes:
#从bboxes 列表中随机抽取一个边界框
chosen_box = bboxes.pop(0)
#选择与 chosen_box 类别相同的边界框,这个语句是理解,
bboxes = [box for box in bboxes if box[0] != chosen_box[0] or intersection_over_union(
torch.tensor(chosen_box[2:]),torch.tensor(box[2:]),box_format=box_format) < iou_threshold]
bboxes_after_nms.append(chosen_box)
return bboxes_after_nms
chosen_box = bboxes.pop(0) 先从列表开始位置取出第一元素,这是因为 bboxes 已经进行降序排序了,也就是第一个置信度是最大的边界框。
注意是不放回的取出,然后开始遍历列表,保留那些和选中元素类别不相同的元素,保留和选中框 IoU 小于阈值的边界框。
在开始定义一些工具类前,为了演示效果,我们先做一些准备,也就是准备数据、准备模型,工具类主要用于在
- 准备数据
- 准备模型
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self,img,bboxes):
for t in self.transforms:
img, bboxes = t(img),bboxes
return img, bboxes
import torchvision.transforms as transforms
# 对于图像进行变换(448,448)并且将其转换为 tensor
transform = Compose([transforms.Resize((448,448)),transforms.ToTensor(),])
# 定义数据集
train_dataset = COCO128Dataset(img_paths,label_paths,transform=transform)
from torch.utils.data import DataLoader
# 定义加载数据集
batch_size = 2
num_workers = 0
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
drop_last=True
)
# 引入模型
split_size=7
num_boxes=2
num_classes=80
model = Yolov1(split_size=split_size,
num_boxes=num_boxes,
num_classes=num_classes)
for batch_idx,(x,labels) in enumerate(train_loader):
preds = model(x)
bboxes = cellboxes_to_boxes(preds)
预测结果转换为 cellboxes(convert_cellboxes)
"""
preds: 是模型预测的结果
batch_size: 批量大小
S: 网格尺寸
C: 类别数
将预测结果转换为 cellboxes
"""
def convert_cellboxes(preds,batch_size=2,S=7,C=80):
# [2, 4410] batch size x S*S*(2*5 + C=80)
preds = preds.to("cpu")
# 将预测结果输出 reshape (m,S,S,(C + B*(1+4)=10))
preds = preds.reshape(batch_size,S,S,(10 + C))
# 2, 7, 7, 90
#print(preds.shape)
# 获取 4 个坐标点,[m,S,S,4]
bboxes1 = preds[...,C+1:C+5]
bboxes2 = preds[...,C+6:C+10]
# 获取置信度,也就是每个 bbox 包含物体 [m,S,S,2]
# 这里多出一个维度就是为了 cat
scores = torch.cat(
(preds[...,C].unsqueeze(0),preds[...,C+5].unsqueeze(0)),dim=0
)
#计算两个 bboxes 中置信度较大索引保留,1 表示保留第二个预测边界框 0 表示保留第一个预测边界框
best_box = scores.argmax(0).unsqueeze(-1)
#保留置信度较大边界框对应数据
best_boxes = bboxes1 * (1 - best_box) + best_box * bboxes2
# 就是为每一个位置给出 grid 0,1,2,3,4,5,6 标号
"""
batch_size x 7 x7 x 1
"""
cell_indices = torch.arange(S).repeat(batch_size,S,1).unsqueeze(-1)
# 3.5 表示位于 4 网格中心 除以 S 这样就得到 x 相对于宽度(w)的比率
x = 1 / S*(best_boxes[...,:1] + cell_indices)
# 交换 x 和 y 轴,这样一来就将索引转换到列上,
# 想要更好理解 tensor 变换对于深度学习是多么重要
y = 1 / S*(best_boxes[...,1:2] + cell_indices.permute(0,2,1,3))
# 对于宽高直接除以 S 即可
w_y = 1 / S * best_boxes[...,2:4]
# 在最后一个维度将 x,y w_y 进行拼接
converted_bboxes = torch.cat((x,y,w_y),dim=-1)
# 选取 80 类别预测分数中最高值位置对应索引,然后为了进行拼接增加一个维度
pred_class = preds[...,:C].argmax(-1).unsqueeze(-1)
# 提取类别
best_confidence = torch.max(preds[...,C],preds[...,C+5]).unsqueeze(-1)
converted_preds = torch.cat((pred_class,best_confidence,converted_bboxes),dim=-1)
return converted_preds
预测结果转换为 boxes(cellboxes_to_boxes)
def cellboxes_to_boxes(out,S=7):
#[2,7,7,6] 分别为 batch size,7x7 网格以及 6 个指标网格给出预测类别数、置信度得分和位置信息
#print(convert_cellboxes(out).shape)
# 2 49 6 将网格展平
converted_pred= convert_cellboxes(out).reshape(out.shape[0],S*S,-1)
# 将类别分类转换为 long
converted_pred[...,0] = converted_pred[...,0].long()
#
all_bboxes = []
# 为 batch size 获取到每一个样本的所有网格都是对应一个边界框
for ex_idx in range(out.shape[0]):
bboxes = []
for bbox_idx in range(S*S):
bboxes.append(x.item() for x in converted_pred[ex_idx,bbox_idx,:])
all_bboxes.append(bboxes)
return all_bboxes