def twoTrainSeg(args, root=Path.db_root_dir('cityscapes')):
# images_base是原始训练图像所在的目录。
images_base = os.path.join(root, 'leftImg8bit', 'train')
# train_files存放所有图像的‘路径+名称’。
train_files = [os.path.join(looproot, filename) for looproot, _, filenames in os.walk(images_base)
for filename in filenames if filename.endswith('.png')]
# number_images为所有图像总数。
number_images = len(train_files)
# permuted_indices_ls是0到number_images-1的随机排列。
permuted_indices_ls = np.random.permutation(number_images)
# indices_1和indices_2分别是permuted_indices_ls的前半部分和后半部分。
indices_1 = permuted_indices_ls[: int(0.5 * number_images) + 1]
indices_2 = permuted_indices_ls[int(0.5 * number_images):]
# batch normalization要求分割后的两个训练集都包含偶数个数据。
if len(indices_1) % 2 != 0 or len(indices_2) % 2 != 0:
raise Exception('indices lists need to be even numbers for batch norm')
# 返回分别包装成两个CityscapesSegmentation类的的训练集。
return CityscapesSegmentation(args, split='train', indices_for_split=indices_1), CityscapesSegmentation(args,
split='train',
indices_for_split=indices_2)
ATTENTION: 注意这里之所以命名这个类为CityscapesSegmentation,而不是Cityscapes,是因为——
train数据集对应一个CityscapesSegmentation,val数据集对应一个CityscapesSegmentation,test数据集也对应一个CityscapesSegmentation。
也即,cityscapes数据集其实被分成了三个独立的数据集。它们之间主要用self.split来区分。
# 继承自pytorch的dataset,需要实现_len_()和_getitem_()函数,如图1。
class CityscapesSegmentation(data.Dataset):
# 数据集里包含的类别数,cityscapes里有19类
NUM_CLASSES = 19
# 分别列出cityscapes里都有哪19个类
CLASSES = [
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
'truck', 'bus', 'train', 'motorcycle', 'bicycle'
]
# 初始化cityscapesSegmentation类
def __init__(self, args, root=Path.db_root_dir('cityscapes'), split="train", indices_for_split=None):
self.root = root # 数据所在目录
self.split = split # 分离方法
self.args = args # 其他参数
self.files = {}
self.mean = (0.485, 0.456, 0.406) # 数据的平均值
self.std = (0.229, 0.224, 0.225) # 数据的方差值
self.crop = self.args.crop_size # 剪裁大小
# 哈哈哈,其实这里就是在取训练数据啦。
# images_base:原图所在目录。
# annotations_base: label图所在目录。
if split.startswith('re'):
self.images_base = os.path.join(self.root, 'leftImg8bit', self.split[2:])
self.annotations_base = os.path.join(self.root, 'gtFine', self.split[2:])
else:
self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
self.annotations_base = os.path.join(self.root, 'gtFine', self.split)
# 获取image_base目录(包括所有子目录)中,所有以.png结尾的文件;
# 返回的结果是所有文件的‘路径+文件名’组成的list列表。
self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png')
# indices_for_split 索引的排序值,相当于以某种方式打乱原来的训练数据集(shuffle?)。
if indices_for_split is not None:
self.files[split] = np.array(self.files[split])[indices_for_split].tolist()
self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] # 16个。
self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] # 19个。
# 在cityscapes数据集中包含的所有类CLASSES中,加一个'unlabelled'类。共20个类。
self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
'motorcycle', 'bicycle']
# 设置ignore_index为255,这个变量是用来干嘛的,目前还不知道。
self.ignore_index = 255
# 将valid_classes中19个数字为key,[0,1,2,...,18]中19个数字为value。一一对应组成字典。
self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES)))
# 如果self.files[split]为空,即split文件下没有任何数据,则报错。
if not self.files[split]:
raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
print("Found %d %s images" % (len(self.files[split]), split))
# 根据split的不同,转换数据
self.transform = self.get_transform()

函数__len__(), getitem(), encode_segmap(), recursive_glob()
# 返回该数据集里包含的所有元素个数。
def __len__(self):
return len(self.files[self.split])
def __getitem__(self, index):
# 获取第index个图像的路径。rstrip()函数用来删除末尾空格。
img_path = self.files[self.split][index].rstrip()
# 获取第index个图像对应的label图像的路径。
lbl_path = os.path.join(self.annotations_base,
img_path.split(os.sep)[-2],
os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')
# 打开原始图像,转换为RGB格式。
_img = Image.open(img_path).convert('RGB')
# 获取对应的label图像,用np.unint8格式打开,保存为格式np.array。
# 之所以要有这一步,是为了下面一行的encode_segmap做准备。
_tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
# 转换label图像中,每一个像素的类别,便于训练。
_tmp = self.encode_segmap(_tmp)
# 将array转换成image图像。
_target = Image.fromarray(_tmp)
sample = {'image': _img, 'label': _target}
return self.transform(sample)
def encode_segmap(self, mask):
# Put all void classes to zero
# 将属于void_classes的像素类别转成ignore_index。
for _voidc in self.void_classes:
mask[mask == _voidc] = self.ignore_index
# 将属于valid_classes的像素类别转成class_map对应的元素值。
for _validc in self.valid_classes:
mask[mask == _validc] = self.class_map[_validc]
return mask
# 提取rootdir目录及其递归子目录下,所有以suffix结尾的文件。
def recursive_glob(self, rootdir='.', suffix=''):
"""Performs recursive glob with given suffix and rootdir
:param rootdir is the root directory
:param suffix is the suffix to be searched
"""
return [os.path.join(looproot, filename)
for looproot, _, filenames in os.walk(rootdir)
for filename in filenames if filename.endswith(suffix)]
# 数据预处理。
def get_transform(self):
if self.split == 'train':
return tr.transform_tr(self.args, self.mean, self.std)
elif self.split == 'val':
return tr.transform_val(self.args, self.mean, self.std)
elif self.split == 'test':
return tr.transform_ts(self.args, self.mean, self.std)
elif self.split == 'retrain':
return tr.transform_retr(self.args, self.mean, self.std)
elif self.split == 'reval':
return tr.transform_reval(self.args, self.mean, self.std)