# 左右翻转 (horizontal flip)
import torchvision
import torch
import numpy as np
from PIL import Image
from d2l import torch as d2l
def apply(img_, aug, num_rows=2, num_cols=4, scale=2):
    """Preview a (random) transform by applying it many times to one image.

    ``aug`` is called ``num_rows * num_cols`` times on the same ``img_`` and
    every result is shown in one grid, so a random transform yields a
    different sample in each cell.

    Args:
        img_: Input image handed to every ``aug`` call.
        aug: Callable transform applied to ``img_``.
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.
        scale: Forwarded unchanged to ``pic_display`` (used there as the
            pixel gap between cells).
    """
    cell_count = num_rows * num_cols
    samples = [aug(img_) for _ in range(cell_count)]
    pic_display(samples, num_rows=num_rows, num_cols=num_cols, scale=scale)
def pic_display(images, num_rows=2, num_cols=4, scale=2, *, show=True):
    """Paste ``images`` into a ``num_rows`` x ``num_cols`` grid and display it.

    Images are placed row-major (left to right, then top to bottom). Every
    cell is as large as the largest input image, so images of differing sizes
    are never clipped at the canvas edge; smaller images sit in the top-left
    corner of their cell. Images beyond ``num_rows * num_cols`` land outside
    the canvas and are not visible.

    Args:
        images: Non-empty sequence of PIL images.
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.
        scale: Gap, in pixels, between adjacent cells (despite the name, no
            resizing happens).
        show: If True (default), open the grid with ``Image.show()``.

    Returns:
        The composed RGB ``PIL.Image.Image``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("pic_display() needs at least one image")
    gap = scale
    # Size every cell to the largest image so the canvas fits all of them;
    # for equally sized images this matches sizing from the first image.
    cell_w = max(im.size[0] for im in images)
    cell_h = max(im.size[1] for im in images)
    total_width = cell_w * num_cols + gap * (num_cols - 1)
    total_height = cell_h * num_rows + gap * (num_rows - 1)
    new_im = Image.new('RGB', (total_width, total_height))
    for i, img in enumerate(images):
        # Position comes from the cell index, not a running cursor, so one
        # odd-sized image cannot shift the rest of the row.
        row, col = divmod(i, num_cols)
        new_im.paste(img, (col * (cell_w + gap), row * (cell_h + gap)))
    if show:
        new_im.show()
    return new_im
def to_jpg(tensor_):
    """Convert a float image tensor into an 8-bit PIL image.

    Despite the name, nothing is encoded as JPEG; the return value is an
    in-memory ``PIL.Image.Image``.

    Args:
        tensor_: Channels-first image tensor, ``(C, H, W)`` — the
            ``permute(1, 2, 0)`` below assumes this layout. Values are
            expected in ``[0, 1]`` (as produced by ``transforms.ToTensor``);
            anything outside that range is clipped.

    Returns:
        ``PIL.Image.Image`` built from the ``(H, W, C)`` uint8 pixels; a
        single-channel tensor becomes a 2-D (grayscale) image.
    """
    # detach()/cpu(): .numpy() refuses tensors that require grad or live on a
    # GPU. No in-place ops follow, so the caller's tensor is never modified.
    hwc = tensor_.detach().cpu().permute(1, 2, 0).numpy()
    # Clip before the cast so out-of-range values saturate instead of wrapping
    # around, and round so k/255*255 maps back to k instead of truncating
    # e.g. 36.99999 down to 36.
    pixels = np.clip(hwc * 255, 0, 255).round().astype(np.uint8)
    if pixels.shape[-1] == 1:
        # Image.fromarray rejects (H, W, 1); drop the channel axis instead.
        pixels = pixels[..., 0]
    return Image.fromarray(pixels)
def load_cifar10(is_train, batch_size, *, augs=None, root="../data"):
    """Build a DataLoader over CIFAR-10.

    Args:
        is_train: If True, load the training split, shuffled and augmented;
            otherwise the test split, unshuffled, with only ``ToTensor``.
        batch_size: Samples per batch.
        augs: Optional PIL-level transform applied to training images before
            ``ToTensor``. Defaults to ``RandomVerticalFlip()``, the
            augmentation this function has always used. Ignored when
            ``is_train`` is False.
        root: Dataset directory; the dataset is fetched there if missing
            (``download=True``).

    Returns:
        ``torch.utils.data.DataLoader`` over the chosen split.
    """
    steps = []
    if is_train:
        # NOTE(review): vertical flips are not label-preserving for many
        # CIFAR-10 classes (upside-down cars/animals); confirm the default is
        # intended, or pass e.g. RandomHorizontalFlip() via ``augs``.
        steps.append(torchvision.transforms.RandomVerticalFlip() if augs is None else augs)
    steps.append(torchvision.transforms.ToTensor())
    _augs = torchvision.transforms.Compose(steps)
    dataset = torchvision.datasets.CIFAR10(root=root, train=is_train, transform=_augs, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=is_train, num_workers=d2l.get_dataloader_workers())
    return dataloader
if __name__ == '__main__':
    # Augmentation pipelines from the earlier sections. NOTE: these (and
    # ``img``) are not used below; call e.g. ``apply(img, augs)`` to preview.
    img = Image.open('./img/cat2.jpg')
    shape_aug = torchvision.transforms.RandomResizedCrop((200, 200), scale=(0.1, 1), ratio=(0.5, 2))
    color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
    augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
    batch_size = 128
    train_loader = load_cifar10(True, batch_size)
    # Only the first augmented training batch is previewed.
    images, _ = next(iter(train_loader))
    # Iterate the actual batch rather than range(batch_size), which would
    # fail on a short final batch.
    images_ = [to_jpg(t) for t in images]
    num_cols = 16
    # Derive the row count from the batch instead of hard-coding 8 (which
    # only matched batch_size=128).
    num_rows = (len(images_) + num_cols - 1) // num_cols
    pic_display(images_, num_rows, num_cols)
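# 上下翻转 (vertical flip)
# 随机裁剪 (random crop)
# 随机更改图像的亮度 (randomly change image brightness)
# 随机更改图像的色调 (randomly change image hue)
# 结合多种图像增广方法 (combine multiple image augmentation methods)
# 使用图像增广进行训练 (train with image augmentation)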