虽然有些过时,但如果自己动手实现一遍 YOLOv1 势必会有所收获(3)—加载数据集

401 阅读6分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第5天,点击查看活动详情

前言

今天是周末,难得有时间出去转一转,每天忙碌并不是自己没有可控的时间,而是最近想在短时间让自己有所提升,以顺应这个时代。所以不舍得时间去感受一下生活。现在 8 点多了,忽然今天还有一篇日更任务还没有完成,打起精神打开我的老 mac ,带上耳机,稍微调整,指尖就开始在键盘上飞舞。今天是跟大家分享的是如何读取数据集。数据集这里选择的 COCO 128 ,所以选择这个数据集主要是为了演示如何加载数据集到模型。

category_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
        'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
        'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
        'hair drier', 'toothbrush'] 
# coco128 数据集

# 选择一个样作为研究对象
sample = 1

# 图像格式
image_formate = "jpg"
# 标注格式
label_formate = "txt"
# 项目目录
project_path = os.getcwd()
# 数据库目录
dataset_path = "dataset\\coco128\\"
# 图像和标注存放位置
img_dir = "images\\train2017"
label_dir = "labels\\train2017"
# 图像文件列表
img_paths = glob.glob(os.path.join(project_path,dataset_path,img_dir,"*.jpg"))
# 标注文件列表
label_paths = glob.glob(os.path.join(project_path,dataset_path,label_dir,"*.txt"))

简单了解一下 coco128 数据集

img_path = img_paths[sample]
im = Image.open(img_path)
plt.imshow(im)
plt.show()

301.png

定义数据集类,这个类继承 torch.utils.data.Dataset,需要实现两个方法分别为 __len__ 返回数据集大小,也就是样本的数量,__getitem__ 根据索引返回输入索引的样本和标注。

class COCO128Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        image_paths,#存放图像的目录
        label_paths,#存放标签的目录
        S = 7, #网格数量
        B = 2, #每一个网格(grid cell)产生边界框(bbox) 数量
        C = 80,#类别数目
        transform=None #图像数据变换
        ):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.S = S
        self.B = B
        self.C = C
        self.transform = transform
        #对图像列表和标注列表进行排序

模型输出预测数据格式为 batch_size x grid_size x grid_size x (B x (1 + 4) + class_number), 所以需要指定 S x S 网格,每个网格输出 2 两个边界框(bbox)的置信度和中心坐标和宽高,以及类别数

len 方法的实现

    def __len__(self):
        return len(self.image_paths)

getitem 方法的实现

    def __getitem__(self,index):
        #校验图像文件和标注文件是否一致
        
        assert os.path.basename(self.image_paths[index]).split('.')[0] == os.path.basename(self.label_paths[index]).split('.')[0]
        
        boxes = []
        with open(self.label_paths[index]) as f:
            for label in f.readlines():
                class_label, x, y, width, height = [
                    float(x) if float(x) != int(float(x)) else int(x)
                    for x in label.replace("\n","").split()
                ]
                boxes.append([class_label, x, y, width, height])

在数据集中每一个图像文件,都会有一个对应同名的后缀 txt 的文件,在 txt 文件中,每一行都是一个目标,每一行包括该目标类别信息,类别索引,以及中心点和宽高信息,这些数值都是以宽和高的比率,这些数值之间用空格间隔,接下来我们要做的是逐行读取目标信息。

        #获取样本图像路径
        img_path = self.image_paths[index]
        
        #使用 PIL 来读取图像,PIL 读取图像不会改变图像RGB通道顺序
        image = Image.open(img_path)
        
        #转换为 tensor 好处是随后在变换(transform)时候        
        boxes = torch.tensor(boxes)

读取图像,并且将边界框和图像都转为 tensor 以备用,因为在 torch 提供 transform 需要输入 tensor 格式。

        if self.transform:
            # 将图像和标注一并输入到 transform 当对于图像进行调整(旋转,缩放等几何变换)时
            # 标注值也会随之变换             
            image, boxes = self.transform(image,boxes)
            
        label_matrix = torch.zeros((self.S,self.S,self.C + 5 ))

解析边界框为标注

        for box in boxes:
            class_label, x, y, width, height = box.tolist()
            #将类别为整数类型             
            class_label = int(class_label)
            # i 和 j 表示哪一个 grid cell 中有元素 
            # 将图像划分为 S*S 个网格,这里 x 和 y 相对于图像 w 和 h 的比率,如果将 w 和 h 转换为 S 和 S
            # 那么 i 和 j 和 S 做乘法就会得到相对于 S*S 大小网格的相对位置,对其进行取整也就是得到了
            # x,y 中心点落在了那个位置
            i, j = int(self.S * y),int(self.S * x)
            # 然后 x 和 y 在网格中相对于网格左上角的位置            
            x_cell,y_cell = self.S * x - j, self.S * y - i
            # 目标物体             
            width_cell, height_cell = (width*self.S,height*self.S)
            # 判断如果 i,j 位置是否已经存在目标,如果已经存在则跳过这一步,因为在 YOLOv1 中
            # 每一个网格只能负责预测一个目标,不然就是
            if label_matrix[i,j,self.C] == 0:
                label_matrix[i,j,self.C] = 1
                box_coordinates = torch.tensor([x_cell,y_cell,width_cell,height_cell])
                label_matrix[i,j,(self.C+1):(self.C+5)] = box_coordinates
                label_matrix[i,j,class_label] = 1
        return image,label_matrix

在 YOLO 中如果目标的中心哪一个网格,那么就由这个网格来负责预测这个目标,所以我们应该考虑如何定位到目标对应网格位置,目标中心点位置相对于左上角点的偏移位置,我们知道从 box 拿到 x 和 y 相对于图像宽和高的比率,现在图像映射到 S×SS \times S 的网格,每个网格大小为 1×11 \times 1 i, j = int(self.S * y),int(self.S * x) 就可以得到目标所在网格, 然后计算 x_cell、y_cell 以及 width_cell 和 height_cell

S = 7
def preview_image_with_grid(im,size=(224,224),S=7):
    im = im.copy()
    resize_im = im.resize((224,224))
    step = 224//S
    x = range(0,224,step)
    for i in x:
        draw = ImageDraw.Draw(resize_im)
        draw.line((0,i,224,i),fill=(0,0,255))
        draw.line((i,0,i,224),fill=(0,0,255))
        
    return resize_im

resize_im = preview_image_with_grid(im)
plt.imshow(resize_im)
plt.show()

302.png

ground_truth_boxes = []
with open(label_paths[sample]) as f:
    for label in f.readlines():
        class_label, x, y, w, h = [
            float(x) if float(x) != int(float(x)) else int(x)
            for x in label.replace("\n","").split()
        ]
        ground_truth_boxes.append([class_label,x,y,w,h])
ground_truth_boxes
[[23, 0.770336, 0.489695, 0.335891, 0.697559],
 [23, 0.185977, 0.901608, 0.206297, 0.129554]]
x,y = ground_truth_boxes[0][1:3]
x,y #(0.770336, 0.489695)
i,j = int(7*y),int(7*x)
i,j #(3, 5)
cell_grid_rect = ((32*5, 32*3),(32*6, 32*4))
resize_im_copy = resize_im.copy()
resize_im_copy_draw = ImageDraw.Draw(resize_im_copy)
resize_im_copy_draw.rectangle(cell_grid_rect,outline="yellow",width=5)

303.png

完整代码

class COCO128Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        image_paths,#存放图像的目录
        label_paths,#存放标签的目录
        S = 7, # 网格数量
        B = 2, #每一个网格(grid cell)产生边界框(bbox) 数量
        C = 80,#类别数目
        transform=None #图像数据变换
        ):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.S = S
        self.B = B
        self.C = C
        self.transform = transform
        #对图像列表和标注列表进行排序
        
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self,index):
        #校验图像文件和标注文件是否一致
        assert os.path.basename(self.image_paths[index]).split('.')[0] == os.path.basename(self.label_paths[index]).split('.')[0]
        boxes = []
        #逐行读取标注文件,每行表示图像中一个目标
        with open(self.label_paths[index]) as f:
            for label in f.readlines():
                class_label, x, y, width, height = [
                    float(x) if float(x) != int(float(x)) else int(x)
                    for x in label.replace("\n","").split()
                ]
                boxes.append([class_label, x, y, width, height])
        #获取样本图像路径
        img_path = self.image_paths[index]
        
        #使用 PIL 来读取图像,PIL 读取图像不会改变图像RGB通道顺序
        image = Image.open(img_path)
        
        #转换为 tensor 好处是随后在变换(transform)时候        
        boxes = torch.tensor(boxes)
        
        if self.transform:
            # 将图像和标注一并输入到 transform 当对于图像进行调整(旋转,缩放等几何变换)时
            # 标注值也会随之变换             
            image, boxes = self.transform(image,boxes)
            
        label_matrix = torch.zeros((self.S,self.S,self.C + 5 ))
        
        for box in boxes:
            class_label, x, y, width, height = box.tolist()
            #将类别为整数类型             
            class_label = int(class_label)
            # i 和 j 表示哪一个 grid cell 中有元素 
            # 将图像划分为 S*S 个网格,这里 x 和 y 相对于图像 w 和 h 的比率,如果将 w 和 h 转换为 S 和 S
            # 那么 i 和 j 和 S 做乘法就会得到相对于 S*S 大小网格的相对位置,对其进行取整也就是得到了
            # x,y 中心点落在了那个位置
            i, j = int(self.S * y),int(self.S * x)
            # 然后 x 和 y 在网格中相对于网格左上角的位置            
            x_cell,y_cell = self.S * x - j, self.S * y - i
            # 目标物体             
            width_cell, height_cell = (width*self.S,height*self.S)
            # 判断如果 i,j 位置是否已经存在目标,如果已经存在则跳过这一步,因为在 YOLOv1 中
            # 每一个网格只能负责预测一个目标,不然就是
            if label_matrix[i,j,self.C] == 0:
                label_matrix[i,j,self.C] = 1
                box_coordinates = torch.tensor([x_cell,y_cell,width_cell,height_cell])
                label_matrix[i,j,(self.C+1):(self.C+5)] = box_coordinates
                label_matrix[i,j,class_label] = 1
        return image,label_matrix

数据集