PyTorch入门备忘-3- 图片数据集处理 - torchvision&transform

1,035 阅读7分钟

Torchvision

机器学习与深度学习,数据集组织、加载处理、按需按批装载、送入模型训练,不论是图片、文字还是音视频,流程基本上一致。

具体图片处理的大部分实现transform包上,实际使用时需要加入业务场景才能丰满起来。

当我们静下心来,花时间去接触AI相关的知识与工具,我们会深刻的感觉到技术真的只是一个工具,是场景将它丰富了起来。—— 笔者个人观点

简单介绍

Torchvision 是 PyTorch 的一个独立子库,它服务于PyTorch深度学习框架的,主要用于计算机视觉任务,包括图像处理、数据加载、数据增强、预训练模型等。

核心包如下:

  • torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
  • torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  • torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
  • torchvision.utils: 其他的一些有用的方法。

官网文档入口: pytorch.org/docs/stable…

读取数据集

可以从网上得到数据集,再用torchvision加载数据并处理,也可以从自建的数据集上加载并。

自建过程建 PyTorch入门备忘-1-Dataset自建及Jupyter与Pycharm简易入门

本文用torchvision数据集来演示读取的过程,内部会使用transform对数据进行变型。

数据集准备 - CIFAR10

此次代码中要用到的数据集,见附件有介绍与中文的参数。

用代码下载数据集-CIFAR10

通过py代码


# 使用CIFAR10数据集
# 训练集
# 如果下载比较慢,可以将控制台打印的下载链接放到专门的下载工具中下载
# 首先下载的是一个压缩包,会自动解压

train_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=True, download=True)


# 测试集
test_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=False, download=True)


运行代码,控制台显示如下信息

50000 -- 说明有5w张训练数据
10000 --说明有1w张测试数据

会自动下载数据集到torchvision_dataset文件夹
已下载就不会继续下载,控制台会出输Files already downloaded and verified字样

image.png

操作数据

torchvision.datasets和transform的联合使用

  1. 下载数据集
  2. 装载图片
  3. 图片处理
  4. 图片展示

import torchvision.datasets
from torch.utils.tensorboard import SummaryWriter

# 将图片数据都转为tensor类型
# 可以对数据集做任何transforms范围内的操作,该例子只针对数据做toTensor
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor
])

# 使用CIFAR10数据集
train_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset",  train=True, transform=dataset_transform, download=True)

# 测试集
test_set = torchvision.datasets.CIFAR10(root="./torchvision_dataset", train=False, transform=dataset_transform, download=True)


# 用tensorboard显示前10张图片
# 运行tensorboard  --logdir=p10
writer = SummaryWriter('p10')
for i in range(0):
    img, target = test_set[i]
    writer.add_image("test_set", img,i)

writer.close()

Torchvision.transform

Transforms是torchvision模块下面的一个子模块,在Dataset中很常用可以方便地对图像进行各种变换操作。该模块中包含大量用户数据类型转化的类型和方法,比如统一size,每一个图像数据进行类的转化等。

""" transforms.ToTensor 转化:PIL Image或numpy.ndarray(H * W * C) 转到 tensor 的数据类型

主要方法 call(self, pic) 参数:pic - Image或numpy的图像对象 返回值 : 返回tensor类型的图片 """

Tensor图像结构:

Tensor是PyTorch中最基本的数据结构,你可以将其视为多维数组或者矩阵。PyTorch tensor和NumPy array非常相似,但是tensor可以在GPU上运算,而NumPy array则只能在CPU上运算。

可对图像直接操作,代码如下:

导入

import torch

创建一个未初始化的5x3矩阵

# 创建一个未初始化的5x3矩阵
x = torch.empty(5, 3)
print(x)

用.backward() 计算梯度

# 因为out包含一个标量,out.backward()等价于out.backward(torch.tensor(1.))
out.backward()

# 打印梯度 d(out)/dx
print(x.grad)

常用方法

ToTensor

  • 作用: PIL Image或numpy.ndarray(H * W * C) 转到 tensor 的数据类型
  • 输入: PIL Image.open()
  • 输出: 类型 ToTensor

Normalize

  • 作用: 根据均值与标准差归一化tensor类图片
  • 输入: tensor类型图片的均值与标准差
  • 输出: 归一化后的图片数据
  • 计算公式: (Input[channel] - mean[channel]) / std[channel] 举例
  Input[channel] - mean[channel]) / std[channel= (input - 0.5)/0.5= 2 * input  - 1结论 input像素值[0-1]  -->  result[-1,1]

Resize

  • 作用: 将(PIL Image or Tensor)调整为给定的大小。
  • 输入
    • size (sequence or int):如果size是(h, w)这样的序列,则输出size将与此匹配。如果size为int,图像的较小边缘将匹配此数字。即,如果高度>宽度,那么图像将被重新缩放为(size*高度/宽度,size)
    • ...
  • 输出: 变型后的PIL Image or Tensor

Compose

  • 作用:把几个tranforms组合在一起使用,相当于一个组合器,可以对输入图片一次进行多个transforms的操作。比如 compose负责把ToTensor和resize组合起来,一步到位实现PIL图形到resize后的tensor图形的转换

RandomCrop

  • 作用:随机裁剪
  • 输入
    • size: size可以是tuple也可以是Integer
    • ...
  • 输出:裁后图片

代码


# 导入
"""
Torchvision 是 PyTorch 的一个独立子库,主要用于计算机视觉任务,包括图像处理、数据加载、数据增强、预训练模型等。
Torchvision 提供了各种经典的计算机视觉数据集的加载器,如CIFAR-10、ImageNet,以及用于数据预处理和数据增强的工具,可以帮助用户更轻松地进行图像分类、目标检测、图像分割等任务。
"""
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image

"""
用ToTensor将PIL图片转为Tensor图片
"""
# 绝对路径 D:\workspace\python\learn_torch\data\train\ants\0013035.jpg
img_path = "data/train/bees/16838648_415acd9e3f.jpg"
img = Image.open(img_path)
trans_toTensor = transforms.ToTensor()
img_tensor = trans_toTensor(img)

writer = SummaryWriter("logs")
writer.add_image("tensor_img", img_tensor)

"""
2. 用Normalize实现Tensor图片归一化
"""
trans_norm_0 = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5, ])
img_norm_0 = trans_norm_0(img_tensor)
writer.add_image("Normalize", img_norm_0, 1)

trans_norm = transforms.Normalize([6, 3, 2], [9, 3, 5])
img_norm = trans_norm(img_tensor)
writer.add_image("Normalize", img_norm, 2)
# 运行,测蔗
# tensorboard  --logdir=logs


"""
2. Resize-等比例缩放
"""
trans_resize = transforms.Resize((512, 512))
# img_PIl --> img_resize PIL
img_resize = trans_resize(img)
# image PIl ---> toTensor --> 转为 tensor
img_resize = trans_toTensor(img_resize)
# print(img_resize)
writer.add_image("Resize", img_resize, 1)

"""
transforms.Compose
# trans_toTensor: 输入
# trans_resize_2: 输出
"""
trans_resize_2 = transforms.Resize(100)
trans_compose = transforms.Compose([trans_resize_2, trans_toTensor])
img_resize_2 = trans_compose(img)
writer.add_image("Resize", img_resize_2, 0)

"""
transforms.RandomCrop:随机裁剪
"""
trans_random = transforms.RandomCrop((150, 500))
trans_compose_2 = transforms.Compose([trans_random, trans_toTensor])
for i in range(10):
    img_crop = trans_compose_2(img)
    writer.add_image("RandomCrop", img_crop, i)
    
writer.close()

启动tensorboard查看结果

image.png

相关知识

Torchvision

官网: pytorch.org/docs/stable…

Torchvision.dataset

文档入口

image.png

数据集

CIFAR10

CIFAR10由10个不同标签的图像组成。其中包括卡车、青蛙、船、汽车、鹿等常见图像。还有一个CIFAR100版本,有 100 个不同的类别

CIFAR10/CIFAR100一般用于物价识别,其广泛用于机器学习领域的计算机视觉算法基准测试。详情 官网地址 包名 torchvision.datasets.FashionMNIST()

image.png

包名 torchvision.datasets.CIFAR10()

参数说明:

  • root: 数据集根路径,可以是相对路径
  • train: = ture 训练集,否则为测试集
  • transform: 对数据集进行的transform操作
  • target_transform: 训练后的目标数据集执行指定的transform操作
  • download:=true 自动下载数据集,false不会下载

COOC

目前有超过 100,000 个日常物品,如人、瓶子、文具、书籍等。 广泛用于目标检测,语义分割和图像描述

MNIST

MNIST 常用的入门级数据集,手写文字数据集

包名 torchvision.datasets.MNIST()

文档:pytorch.org/vision/stab…

Fashion MNIST

该数据集与 MNIST 类似,但该数据集不是手写数字,而是 T 恤、裤子、包等服装项目。

包名 torchvision.datasets.FashionMNIST()

torchvision.models

提供神经网络常见的神经网络,有一些神经网络已经预训练好了。

image.png

torchvision.transform

图像处理与变形等

transforms.CenterCrop 对图片中心进行裁剪 
transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
transforms.FiveCrop  对图像四个角和中心进行裁剪得到五分图像
transforms.Grayscale  对图像进行灰度变换
transforms.Pad  使用固定值进行像素填充
transforms.RandomAffine  随机仿射变换 
transforms.RandomCrop  随机区域裁剪
transforms.RandomHorizontalFlip  随机水平翻转
transforms.RandomRotation  随机旋转
transforms.RandomVerticalFlip  随机垂直翻转

文档入口

image.png

torchvision.utils

提供一些常用的工具,比如tensorboard

from torch.utils.tensorboard import SummaryWriter

文档入口

image.png



记录于:2012/11/10 _山海



[参考]

www.bilibili.com/video/BV1hE…

pytorch.org/vision/stab…