cityscapes.py阅读笔记

1,638 阅读3分钟
def twoTrainSeg(args, root=Path.db_root_dir('cityscapes')):
    # images_base是原始训练图像所在的目录。
    images_base = os.path.join(root, 'leftImg8bit', 'train')
    # train_files存放所有图像的‘路径+名称’。
    train_files = [os.path.join(looproot, filename) for looproot, _, filenames in os.walk(images_base)
                   for filename in filenames if filename.endswith('.png')]
    # number_images为所有图像总数。
    number_images = len(train_files)
    # permuted_indices_ls是0到number_images-1的随机排列。
    permuted_indices_ls = np.random.permutation(number_images)
    # indices_1和indices_2分别是permuted_indices_ls的前半部分和后半部分。
    indices_1 = permuted_indices_ls[: int(0.5 * number_images) + 1]
    indices_2 = permuted_indices_ls[int(0.5 * number_images):]
    # batch normalization要求分割后的两个训练集都包含偶数个数据。
    if len(indices_1) % 2 != 0 or len(indices_2) % 2 != 0:
        raise Exception('indices lists need to be even numbers for batch norm')
    # 返回分别包装成两个CityscapesSegmentation类的的训练集。
    return CityscapesSegmentation(args, split='train', indices_for_split=indices_1), CityscapesSegmentation(args,
                                                                                                            split='train',
                                                                                                            indices_for_split=indices_2)

ATTENTION: 注意这里之所以命名这个类为CityscapesSegmentation,而不是Cityscapes,是因为——

train数据集对应一个CityscapesSegmentation,val数据集对应一个CityscapesSegmentation,test数据集也对应一个CityscapesSegmentation。

也即,cityscapes数据集其实被分成了三个独立的数据集。它们之间主要用self.split来区分。

# 继承自pytorch的dataset,需要实现_len_()和_getitem_()函数,如图1。
class CityscapesSegmentation(data.Dataset):
    # 数据集里包含的类别数,cityscapes里有19类
    NUM_CLASSES = 19
    # 分别列出cityscapes里都有哪19个类
    CLASSES = [
        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
        'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
        'truck', 'bus', 'train', 'motorcycle', 'bicycle'
    ]
    
    # 初始化cityscapesSegmentation类
    def __init__(self, args, root=Path.db_root_dir('cityscapes'), split="train", indices_for_split=None):
        self.root = root        # 数据所在目录
        self.split = split      # 分离方法
        self.args = args        # 其他参数
        self.files = {}
        self.mean = (0.485, 0.456, 0.406)   # 数据的平均值
        self.std = (0.229, 0.224, 0.225)    # 数据的方差值
        self.crop = self.args.crop_size     # 剪裁大小
        
        # 哈哈哈,其实这里就是在取训练数据啦。
        # images_base:原图所在目录。
        # annotations_base: label图所在目录。
        if split.startswith('re'):
            self.images_base = os.path.join(self.root, 'leftImg8bit', self.split[2:])
            self.annotations_base = os.path.join(self.root, 'gtFine', self.split[2:])
        else:
            self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
            self.annotations_base = os.path.join(self.root, 'gtFine', self.split)
        
        # 获取image_base目录(包括所有子目录)中,所有以.png结尾的文件;
        # 返回的结果是所有文件的‘路径+文件名’组成的list列表。
        self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png')

        # indices_for_split 索引的排序值,相当于以某种方式打乱原来的训练数据集(shuffle?)。
        if indices_for_split is not None:
            self.files[split] = np.array(self.files[split])[indices_for_split].tolist()

        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] # 16个。
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] # 19个。
        
        # 在cityscapes数据集中包含的所有类CLASSES中,加一个'unlabelled'类。共20个类。
        self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
                            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
                            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                            'motorcycle', 'bicycle']
        
        # 设置ignore_index为255,这个变量是用来干嘛的,目前还不知道。
        self.ignore_index = 255
        # 将valid_classes中19个数字为key,[0,1,2,...,18]中19个数字为value。一一对应组成字典。
        self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES)))
        
        # 如果self.files[split]为空,即split文件下没有任何数据,则报错。
        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))
        
        # 根据split的不同,转换数据
        self.transform = self.get_transform()

图1

函数__len__(), getitem(), encode_segmap(), recursive_glob()

    # 返回该数据集里包含的所有元素个数。
    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        # 获取第index个图像的路径。rstrip()函数用来删除末尾空格。
        img_path = self.files[self.split][index].rstrip()
        # 获取第index个图像对应的label图像的路径。
        lbl_path = os.path.join(self.annotations_base,
                                img_path.split(os.sep)[-2],
                                os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')
        
        # 打开原始图像,转换为RGB格式。
        _img = Image.open(img_path).convert('RGB')
        # 获取对应的label图像,用np.unint8格式打开,保存为格式np.array。
        # 之所以要有这一步,是为了下面一行的encode_segmap做准备。
        _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        # 转换label图像中,每一个像素的类别,便于训练。
        _tmp = self.encode_segmap(_tmp)
        # 将array转换成image图像。
        _target = Image.fromarray(_tmp)

        sample = {'image': _img, 'label': _target}
        return self.transform(sample)

    def encode_segmap(self, mask):
        # Put all void classes to zero
        # 将属于void_classes的像素类别转成ignore_index。
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        # 将属于valid_classes的像素类别转成class_map对应的元素值。
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask
    
    # 提取rootdir目录及其递归子目录下,所有以suffix结尾的文件。
    def recursive_glob(self, rootdir='.', suffix=''):
        """Performs recursive glob with given suffix and rootdir
            :param rootdir is the root directory
            :param suffix is the suffix to be searched
        """
        return [os.path.join(looproot, filename)
                for looproot, _, filenames in os.walk(rootdir)
                for filename in filenames if filename.endswith(suffix)]
    
    # 数据预处理。
    def get_transform(self):
        if self.split == 'train':
            return tr.transform_tr(self.args, self.mean, self.std)
        elif self.split == 'val':
            return tr.transform_val(self.args, self.mean, self.std)
        elif self.split == 'test':
            return tr.transform_ts(self.args, self.mean, self.std)
        elif self.split == 'retrain':
            return tr.transform_retr(self.args, self.mean, self.std)
        elif self.split == 'reval':
            return tr.transform_reval(self.args, self.mean, self.std)