PyTorch 系列 | 数据加载和预处理教程

2,339 阅读8分钟

原题 | DATA LOADING AND PROCESSING TUTORIAL

作者 | Sasank Chilamkurthy

译者 | kbsc13("算法猿的成长"公众号作者)

原文 | pytorch.org/tutorials/b…

声明 | 翻译是出于交流学习的目的,欢迎转载,但请保留本文出于,请勿用作商业或者非法用途

简介

本文教程主要是介绍如何加载、预处理并对数据进行增强的方法。

首先需要确保安装以下几个 python 库:

  • scikit-image :处理图片数据
  • pandas :处理 csv 文件

导入模块代码如下:

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

本次教程采用的是一个人脸姿势数据集,其图片如下所示:

每张人脸都是有 68 个人脸关键点,它是由 dlib 生成的,具体实现可以查看其官网介绍:

blog.dlib.net/2014/08/rea…

数据集下载地址:

download.pytorch.org/tutorial/fa…

数据集中的 csv 文件的格式如下所示,图片名字和每个关键点的坐标 x, y

image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... 84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

数据集下载解压缩后放到文件夹 data/faces 中,然后我们先快速打开 face_landmarks.csv 文件,查看文件内容,即标注信息,代码如下所示:

landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))

输出如下所示:

接着写一个辅助函数来显示人脸图片及其关键点,代码如下所示:

def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()

输出如下所示:

Dataset 类

torch.utils.data.Dataset 是表示一个数据集的抽象类,在自定义自己的数据集的时候需要继承 Dataset 类别,并重写下方这些方法:

  • __len__ :调用 len(dataset) 时可以返回数据集的数量;
  • __getitem__:获取数据,可以实现索引访问,即 dataset[i] 可以访问第 i 个样本数据

接下来将给我们的人脸关键点数据集自定义一个类别,在 __init__ 方法中将读取数据集的信息,并在 __getitem__ 方法调用获取的数据集,这主要是基于内存的考虑,这种做法不需要将所有数据一次读取存储在内存中,可以在需要读取数据的时候才读取加载到内存里。

数据集的样本将用一个字典表示:{'image': image, 'landmarks': landmarks},另外还有一个可选参数 transform 用于预处理读取的样本数据,下一节将介绍这个 transform 的用处。

自定义函数的代码如下所示:

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): 带有标注信息的 csv 文件路径
            root_dir (string): 图片所在文件夹
            transform (callable, optional): 可选的用于预处理图片的方法
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        # 读取图片
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # 读取关键点并转换为 numpy 数组
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

接下来是一个简单的例子来使用上述我们自定义的数据集类,例子中将读取前 4 个样本并展示:

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()
# 读取前 4 张图片并展示
for i in range(len(face_dataset)):
    sample = face_dataset[i]

    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break

输出结果如下所示:

Transforms

从上述例子输出的结构可以看到一个问题,图片的大小并不一致,但大多数神经网络都需要输入图片的大小固定。因此,接下来是给出一些预处理的代码,主要是下面三种预处理方法:

  • Rescale :调整图片大小
  • RandomCrop:随机裁剪图片,这是一种数据增强的方法
  • ToTensor:将 numpy 格式的图片转换为 pytorch 的数据格式 tensors ,这里需要交换坐标。

这几种方法都将写成可调用的类,而不是简单的函数,这样就不需要每次都传递参数。因此,我们需要实现 __call__ 方法,以及有必要的话,__init__ 方法也是要实现的,然后就可以如下所示一样调用这些方法:

tsfm = Transform(params)
transformed_sample = tsfm(sample)

Rescale 方法的实现代码如下:

class Rescale(object):
    """将图片调整为给定的大小.

    Args:
        output_size (tuple or int): 期望输出的图片大小. 如果是 tuple 类型,输出图片大小就是给定的 output_size;
                                    如果是 int 类型,则图片最短边将匹配给的大小,然后调整最大边以保持相同的比例。
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        # 判断给定大小的形式,tuple 还是 int 类型
        if isinstance(self.output_size, int):
            # int 类型,给定大小作为最短边,最大边长根据原来尺寸比例进行调整
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # 根据调整前后的尺寸比例,调整关键点的坐标位置,并且 x 对应 w,y 对应 h
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}

RandomCrop 的代码实现:

class RandomCrop(object):
    """给定图片,随机裁剪其任意一个和给定大小一样大的区域.

    Args:
        output_size (tuple or int): 期望裁剪的图片大小。如果是 int,将得到一个正方形大小的图片.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        # 随机选择裁剪区域的左上角,即起点,(left, top),范围是由原始大小-输出大小
        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]
        # 调整关键点坐标,平移选择的裁剪起点
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}

ToTensor 的方法实现:

class ToTensor(object):
    """将 ndarrays 转换为 tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # 调整坐标尺寸,numpy 的维度是 H x W x C,而 torch 的图片维度是 C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

组合使用预处理方法

接下来就是介绍使用上述自定义的预处理方法的例子。

假设我们希望将图片的最短边长调整为 256,然后随机裁剪一个 224*224 大小的图片区域,也就是我们需要组合调用 RescaleRandomCrop 预处理方法。

torchvision.transforms.Compose 是一个可以实现组合调用欲处理方法的类,实现代码如下所示:

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# 对图片数据调用上述 3 种形式预处理方法,即单独使用 Rescale,RandomCrop,组合使用 Rescale和 RandomCrop
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()

输出结构:

迭代整个数据集

现在我们已经定义好一个处理数据集的类,3种预处理数据的类,那么可以将它们整合在一起,实现加载并预处理数据的流程,流程如下所示:

  • 首先根据图片路径读取图片
  • 对图片都调用预处理的方法
  • 预处理方法也可以实现数据增强

实现的代码如下所示:

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break

输出结果:

上述只是一个简单的处理过程,实际上处理和加载数据的时候,我们一般还对数据做以下的处理:

  • 将数据按给定大小分成一批一批数据
  • 打乱数据排列顺序
  • 采用 multiprocessing 来并行加载数据

torch.utils.data.DataLoader 是一个可以实现上述操作的迭代器。其需要的参数如下代码所示,其中一个参数 collate_fn 是用于指定如何对数据进行分批的操作,但也可以采用默认函数。

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)


# 辅助函数,用于展示一个 batch 的数据
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

输出结果:

torchvision

最后介绍 torchvision 这个库,它提供了一些常见的数据集和预处理方法,采用这个库就可以不需要自定义类,它比较常用的方法是 ImageFolder ,它假定图片的保存路径如下所示:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

这里的 antsbees 等等都是类别标签,此外对 PIL.Image 的预处理方法,如 RandomHorizontalFlipScale 都包含在 torchvision 中,一个使用例子如下所示:

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

小结

本教程主要介绍如何对自己的数据集自定义一个类来加载,以及预处理的方法,同时最后也介绍了 PyTorch 中的 torchvisiontorch.utils.data.DataLoader 方法。

本文的代码上传至 Github:

github.com/ccc013/Deep…

另外,还有用 dlib 生成人脸关键点的代码:

github.com/ccc013/Deep…

此外,也可以公众号后台回复“PyTorch”获取本次教程的数据集和代码。


欢迎关注我的微信公众号--算法猿的成长,或者扫描下方的二维码,大家一起交流,学习和进步!