跟李沐学AI随记-14-目标检测和数据集

2 阅读3分钟

目标检测定义:图像里有多个我们感兴趣的目标,我们不仅想知道它们的类别,还想得到它们在图像中的具体位置

边界框:(bounding box)来描述对象的空间位置。 边界框是矩形的,由矩形左上角的以及右下角的x和y坐标决定。 另一种常用的边界框表示方法是边界框中心的(x,y)轴坐标以及框的宽度和高度。

两种常用方法的代码实现:

def box_corner_to_center(boxes):
    """从(左上,右下)转换到(中间,宽度,高度)"""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    boxes = torch.stack((cx, cy, w, h), axis=-1)
    return boxes


def box_center_to_corner(boxes):
    """从(中间,宽度,高度)转换到(左上,右下)"""
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    boxes = torch.stack((x1, y1, x2, y2), axis=-1)
    return boxes

绘制方框:用plt库

# bbox是边界框的英文缩写
dog_bbox, cat_bbox = [60.0, 45.0, 378.0, 516.0], [400.0, 112.0, 655.0, 493.0]

def bbox_to_rect(bbox, color):
    # 将边界框(左上x,左上y,右下x,右下y)格式转换成matplotlib格式:
    # ((左上x,左上y),宽,高)
    return d2l.plt.Rectangle(
        xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],
        fill=False, edgecolor=color, linewidth=2)
fig = d2l.plt.imshow(img)
# 改变颜色
fig.axes.add_patch(bbox_to_rect(dog_bbox, 'blue'))
fig.axes.add_patch(bbox_to_rect(cat_bbox, 'red'))

目标检测数据集:

  • 每行表示一个物体
    • 图片文件名+物体类别+边缘框
  • 用COCO数据集

下面的代码演示用了banana数据集,防止出现版权问题(老师说原本是皮卡丘的,怕有版权问题才换了,很可惜)

import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l

# @save
d2l.DATA_HUB['banana-detection'] = (
    d2l.DATA_URL + 'banana-detection.zip',
    '5de26c8fce5ccdea9f91267273464dc968d20d72')


def read_data_bananas(is_train=True):
    """读取香蕉检测数据集中的图像和标签"""
    # 下载并解压文件
    data_dir = d2l.download_extract('banana-detection')
    # 转换成csv文件
    csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
    else 'bananas_val', 'label.csv')
    csv_data = pd.read_csv(csv_fname)
    # 将 'img_name' 列设置为索引,得到一个 DataFrame 对象。
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []
    for img_name, target in csv_data.iterrows():
        # 把文件读到内存去
        images.append(torchvision.io.read_image(
            os.path.join(data_dir, 'bananas_train' if is_train else
            'bananas_val', 'images', f'{img_name}')))
        # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),
        # 其中所有图像都具有相同的香蕉类(索引为0)
        targets.append(list(target))
        # unsqueeze(1) 是将 targets 的维度增加一维,具体地,在索引 1 处增加一个维度。这意味着,如果 targets 是一个形状为 (N,) 的张量,
        # 经过 unsqueeze(1) 操作后,它将变成一个形状为 (N, 1) 的张量,其中 N 是样本数量。
        # 图片大小都是256的,这里用于归一化
    return images, torch.tensor(targets).unsqueeze(1) / 256


# 创建dataset实例
class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉检测数据集的自定义数据集"""

    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (f' training examples' if
                                                   is_train else f' validation examples'))

    def __getitem__(self, idx):
        return self.features[idx].float(), self.labels[idx]

    def __len__(self):
        return len(self.features)


# 数据加载器实例。正常情况下返回的是五维,第一维为类别,后面四维是boundingbox
def load_data_bananas(batch_size):
    """加载香蕉检测数据集"""
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                                             batch_size, shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                                           batch_size)
    return train_iter, val_iter


batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)

# 迭代一次进行展示
batch = next(iter(train_iter))
print(batch[0].shape, batch[1].shape)

# 演示
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
# 展示前十张图像
for ax, label in zip(axes, batch[1][0:10]):
    # 拿后四维--boundingbox。边框的存储要小心,是用0~1归一化后的存,还是按实际像素位置存。
    d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])

图像中没有绘制出来boundingbox--小问题细节没去纠了,主要是感受流程

image.png