李沐的深度学习课 课程笔记5 图像增广

阅读量 109 · 预计阅读时长 2 分钟

左右翻转

import torchvision
import torch
import numpy as np
from PIL import Image
from d2l import torch as d2l


def apply(img_, aug, num_rows=2, num_cols=4, scale=2):
    """Run ``aug`` on ``img_`` once per grid cell and display the results.

    ``aug`` is called ``num_rows * num_cols`` times on the same input, so a
    random augmentation yields a different sample in every cell.

    Args:
        img_: Input image handed unchanged to ``aug`` on every call.
        aug: Callable taking the image and returning the augmented image.
        num_rows: Number of rows in the displayed grid.
        num_cols: Number of columns in the displayed grid.
        scale: Forwarded to ``pic_display``, which uses it as the gap in
            pixels between neighbouring cells.
    """
    augmented = []
    for _ in range(num_rows * num_cols):
        augmented.append(aug(img_))
    # Alternative viewer from the d2l package:
    # d2l.show_images(augmented, num_rows, num_cols, scale=scale)
    pic_display(augmented, num_rows, num_cols, scale)


def pic_display(images, num_rows=2, num_cols=4, scale=2):
    """Tile PIL images into a ``num_rows`` x ``num_cols`` grid and show it.

    Images are placed row by row. Every grid cell has the size of the first
    image, and ``scale`` is the gap in pixels between neighbouring cells.
    An empty ``images`` sequence is a no-op.

    Args:
        images: Sequence of PIL images to lay out.
        num_rows: Number of rows in the grid.
        num_cols: Number of columns in the grid.
        scale: Gap between adjacent cells, in pixels.
    """
    if not images:
        # images[0] below would raise IndexError; there is nothing to show.
        return

    # Cell size is taken from the first image.
    cell_width, cell_height = images[0].size
    # Size of the whole canvas: all cells plus the gaps between them.
    total_width = cell_width * num_cols + scale * (num_cols - 1)
    total_height = cell_height * num_rows + scale * (num_rows - 1)
    canvas = Image.new('RGB', (total_width, total_height))

    for index, image in enumerate(images):
        # Derive the cell from the index so every image lands on the grid,
        # even when an individual image is narrower or wider than the cell.
        row, col = divmod(index, num_cols)
        canvas.paste(image, (col * (cell_width + scale), row * (cell_height + scale)))

    canvas.show()  # open the combined image in the default viewer
    # canvas.save('combined_image.jpg')  # alternatively, write it to disk


def to_jpg(tensor_):
    """Convert an image tensor into a PIL image.

    Despite the name, nothing is encoded or written as JPEG here; the caller
    receives an in-memory PIL image (call ``.save(...)`` on it to write a file).

    Args:
        tensor_: ``torch.Tensor`` of shape (C, H, W) whose values are expected
            to lie in [0, 1].

    Returns:
        A PIL image built from 8-bit pixel values. A single-channel tensor
        yields a 2-D (grayscale) image.
    """
    # detach()/cpu() let tensors that require grad or live on another device
    # be converted as well; permute turns (C, H, W) into (H, W, C).
    array = tensor_.detach().cpu().permute(1, 2, 0).numpy()
    # Round to the nearest level instead of truncating (k / 255 * 255 can land
    # just below k in float32) and clip so out-of-range values saturate rather
    # than wrap around in the uint8 cast.
    pixels = np.clip(np.rint(array * 255), 0, 255).astype(np.uint8)
    if pixels.shape[2] == 1:
        # Image.fromarray expects (H, W) for single-channel data, not (H, W, 1).
        pixels = pixels[:, :, 0]
    # pixels_image.save('output.jpg') would write the result to disk.
    return Image.fromarray(pixels)


def load_cifar10(is_train, batch_size):
    """Create a ``DataLoader`` over CIFAR-10 stored under ``../data``.

    The dataset is downloaded if needed (``download=True``). The training
    split gets a random vertical flip before ``ToTensor``; the test split is
    only converted with ``ToTensor``. Shuffling is enabled for training only.

    NOTE(review): training uses ``RandomVerticalFlip``; confirm this is
    intended rather than ``RandomHorizontalFlip``.

    Args:
        is_train: Selects the training split (and augmentation/shuffling)
            when true, the test split otherwise.
        batch_size: Number of samples per batch.

    Returns:
        The ``torch.utils.data.DataLoader`` for the selected split.
    """
    transforms = torchvision.transforms
    steps = [transforms.RandomVerticalFlip()] if is_train else []
    steps.append(transforms.ToTensor())
    pipeline = transforms.Compose(steps)

    cifar = torchvision.datasets.CIFAR10(
        root="../data",
        train=is_train,
        transform=pipeline,
        download=True,
    )
    return torch.utils.data.DataLoader(
        cifar,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=d2l.get_dataloader_workers(),
    )


if __name__ == '__main__':
    # Demo script: uncomment one of the apply(...) calls below to preview the
    # corresponding augmentation on the sample image.
    img = Image.open('./img/cat2.jpg')
    # Horizontal (left-right) flip
    # apply(img, torchvision.transforms.RandomHorizontalFlip())
    # Vertical (up-down) flip
    # apply(img, torchvision.transforms.RandomVerticalFlip())
    # Random crop: area scale in (0.1, 1), aspect ratio in (0.5, 2), output size 200x200
    shape_aug = torchvision.transforms.RandomResizedCrop((200, 200), scale=(0.1, 1), ratio=(0.5, 2))
    # apply(img, shape_aug)
    # Randomly change the brightness
    # apply(img, torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0))
    # Randomly change the hue
    # apply(img, torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5))
    # Randomly change brightness, contrast, saturation and hue at the same time
    color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
    # apply(img, color_aug)
    # Combine several augmentation methods
    augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
    # apply(img, augs)

    # Training with image augmentation: preview the first augmented batch.
    batch_size = 128
    num_cols = 16
    train_loader = load_cifar10(True, batch_size)
    for images, _labels in train_loader:
        # Iterate over the tensors actually present in the batch instead of
        # indexing with range(batch_size).
        previews = [to_jpg(image) for image in images]
        # Ceiling division: enough rows for the whole batch (8 rows for 128 images).
        num_rows = -(-len(previews) // num_cols)
        pic_display(previews, num_rows, num_cols)
        break  # only the first batch is displayed

上下翻转

随机裁剪

随机更改图像的亮度

随机更改图像的色调

结合多种图像增广方法

使用图像增广进行训练