本文以 PASCAL VOC2012 数据集为例子进行说明。(下载地址:PASCAL VOC2012)
Pytorch 自定义数据集见文档:TorchVision Object Detection Finetuning Tutorial
本文将以PASCAL VOC为基础自定义一个数据集VOCDataset
,并随机选取五张图片给将其对应的标注转化为矩形框画在图片上。
本文详细代码见:pytorch-tutorial/01-common/custom_dataset at main · simo-an/pytorch-tutorial (github.com)
定义一些工具类
定义类别数据,共有20中目标类别
class_dict = {
"aeroplane": 1,
"bicycle": 2,
"bird": 3,
"boat": 4,
"bottle": 5,
"bus": 6,
"car": 7,
"cat": 8,
"chair": 9,
"cow": 10,
"diningtable": 11,
"dog": 12,
"horse": 13,
"motorbike": 14,
"person": 15,
"pottedplant": 16,
"sheep": 17,
"sofa": 18,
"train": 19,
"tvmonitor": 20
}
将xml 转化为 json
def parse_xml_to_dict(xml):
if len(xml) == 0:
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = parse_xml_to_dict(child)
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result:
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
在图片上画出矩形框(参考代码: vision/utils.py at main · pytorch/vision (github.com))
def draw_bounding_boxes(
image,
boxes: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
img_to_draw = Image.fromarray(image)
img_boxes = boxes.to(torch.int64).tolist()
draw = ImageDraw.Draw(img_to_draw)
class_map = [k for k, v in class_dict.items()]
for i, bbox in enumerate(img_boxes):
draw.rectangle(bbox, width=2, outline='red')
margin = 2
draw.text((bbox[0] + margin, bbox[1] + margin),
class_map[labels[i] - 1], fill='red')
return np.array(img_to_draw)
生成自定义数据集
一些需要导入的基本库
import torch
import utils
from torch.utils.data import Dataset
from PIL import Image
from os import path
from lxml import etree
按照文档要求,在VOCDataset
中实现三个方法__len__
、__getitem__
、以及get_height_and_width
。
初始化 VOCDataset 类
构造函数定义如下
'''
voc_root: voc 数据集的根目录
year: 哪一个年份的数据集
transforms: 数据预处理
text_name: train.txt or val.txt 该txt文件在数据集的 VOCdevkit\VOC2012\ImageSets\Main 文件夹下
'''
def __init__(self, voc_root, year='2012', transforms=None, text_name='train.txt'):
在构造函数中,我们主要完成以下三个功能
- 设置图片路径
image_root
和标注路径anno_root
- 设置此次要训练的样本所有标注文件路径列表
xml_list
- 设置要检测的目标类别信息
class_dict
设置图片路径image_root
和标注路径anno_root
# 设置数据集、图片、标注的根目录
self.root = path.join(voc_root, 'VOCdevkit', f'VOC{year}')
self.image_root = path.join(self.root, 'JPEGImages')
self.anno_root = path.join(self.root, 'Annotations')
设置此次要训练的样本所有标注文件路径列表xml_list
# 根据 text_name 拿到对应的标注xml文件路径
text_path = path.join(self.root, 'ImageSets','Main', text_name)
# 读取txt文件的每一行并生成xml标注文件路径存放在xml_list中
with open(text_path) as file_reader:
self.xml_list = [
path.join(self.anno_root, f'{line.strip()}.xml')
for line in file_reader.readlines() if len(line.strip()) > 0
]
设置要检测的目标类别信息class_dict
self.class_dict = utils.class_dict
一般使用 0 来表示当前类别是背景
获取所有样例条数
def __len__(self):
return len(self.xml_list)
样本的条数即标注文件列表长度
根据索引获取指定样本
函数定义如下
def __getitem__(self, idx):
传入的即为样本的索引值,其取值范围为 0 ~ len(xml_list)
获取指定样本需要分为如下两大步
- 获取图片
- 获取图片信息(标注信息、索引、区域面积等)
获取图片
首先我们需要根据索引拿到对应标注信息,并将其转化为json
格式
定义一个获取json
格式的annotation
的方法
def get_annotation(self, idx):
xml_path = self.xml_list[idx]
assert path.exists(xml_path), f'file {xml_path} not found'
xml_reader = open(xml_path)
xml_text = xml_reader.read()
xml = etree.fromstring(xml_text)
annotation = utils.parse_xml_to_dict(xml)['annotation']
获取annotation
annotation = self.get_annotation(idx)
然后我们就可以从annotation
中拿到文件名称并获取到文件
image_path = path.join(self.image_root, annotation['filename'])
image = Image.open(image_path)
获取图片信息
声明需要获取的所有信息
# 生成 target
target = {
'boxes': [], # 标注的左上、右下坐标(xmin, ymin, xmax, ymax)
'labels': [],# 标注类别
'image_id': [], # 图片索引
'area': [], # 含有目标区域的面积 (xmax-xmin) * (ymax-ymin)
'iscrowd': [], # 是不是一堆密集的东西在一起
}
便利所有的object
for obj in annotation['object']:
bndbox = obj['bndbox']
xmin = float(bndbox['xmin'])
ymin = float(bndbox['ymin'])
xmax = float(bndbox['xmax'])
ymax = float(bndbox['ymax'])
target['boxes'].append([xmin, ymin, xmax, ymax]) # 设置有目标的坐标信息
target['labels'].append(self.class_dict[obj['name']]) # 获取对应的标签
target['area'].append((xmax - xmin) * (ymax - ymin)) # 计算面积
# 使用 difficult(当前目标是否难以识别) 字段来设置 iscrowd
if 'difficult' in obj:
target['iscrowd'].append(int(obj['difficult']))
else:
target['iscrowd'].append(0)
将所有信息转化为Tensor
# Convert to tensor
target['boxes'] = torch.as_tensor(target['boxes'])
target['labels'] = torch.as_tensor(target['labels'])
target['iscrowd'] = torch.as_tensor(target['iscrowd'])
target['area'] = torch.as_tensor(target['area'])
target['image_id'] = torch.tensor([idx])
如果有设置数据预处理器,则在返回数据前调用
if self.transforms is not None:
image = self.transforms(image)
返回图片以及对应的信息
return image, target
根据索引获取当前图片的宽高
在标注信息里面含有图片宽高信息,所以可以很容易获取到
def get_height_and_width(self, idx):
annotation = annotation = self.get_annotation(idx)
# 从 annotation 中取出宽高并返回
width = int(annotation['size']['width'])
height = int(annotation['size']['height'])
return height, width
以上我们就完成了数据集的定义,下面我们将使用实例代码来使用这个数据集
使用自定义数据集并画上标注框
导入一些基本库
import os
import random
import utils
import main
import matplotlib.pyplot as plt
import numpy as np
定义transformer
,将数据转化为Tensor
data_transform = ts.Compose([ts.ToTensor()])
由于ToTensor
会将数据标准化,为了代码简洁,这里不使用
拿到数据集并将目标框以及类别画出来
train_data_set = VOCDataset(os.getcwd(), '2012', None, 'train.txt')
for index in random.sample(range(0, len(train_data_set)), k=5):
image, target = train_data_set[index]
image = utils.draw_bounding_boxes(
np.array(image),
target['boxes'],
target['labels'],
)
plt.imshow(image)
plt.show()
这样就完成了整个流程了!
运行与测试
可见运行结果正确!