05-torchvision中数据集的使用

227 阅读2分钟

torchvision中数据集的使用

Pytorch官网

官方文档:

05-1.png

  • PyTorch: 核心模块
  • torchaudio: 语音模块
  • torchtext: 文本模块
  • torchvision: 视觉模块
  • ......

torchvision中的模块

05-2.png

torchvision.datasets

写代码时,指定这些数据集并给出相应的参数,pytorch可自动下载和使用这些标准数据集

Datasets — Torchvision 0.16 documentation (pytorch.org)

COCO数据集,常用于目标检测或语义分割(Image detection or segmentation)中

05-3.png

MNIST(手写文字)数据集,常用于图像分类(Image classification)中,一般教科书中的常用数据集

the MNIST database of handwirtten digits

MNIST

05-4.png

05-5.png

CIFAR10数据集:常用于物体识别

CIFAR-10 and CIFAR-100 datasets (toronto.edu)

05-6.png

torchvision.io

输入输出(不常用)

torchvision.ops

torchvision提供的一些少见的特殊操作(不常用)

torchvision.transforms

之前的章节里讲过了

torchvision.utils

之前的tensorboard来自于此

torch.models

提供一些常见的神经网络,有的已经预训练好了,比较常用(于毕设和科研)

  • Classification
  • Semantic Segmentation

05-7.png

torchvision.datasets的使用

以CIFAR10为例:

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains exactly 1000 randomly-selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5000 images from each class.

Parameters:

  • root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.(数据集所在的位置)
  • train (bool , optional) – If True, creates dataset from training set, otherwise creates from test set.(训练集 or 测试集)
  • transform (callable*,* optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
  • target_transform (callable*,* optional) – A function/transform that takes in the target and transforms it.
  • download (bool , optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.(是否需要从网上自动下载,下载链接可以在torchvision.datasets.CIFAR10文件中查看)

示例

import torchvision
from torch.utils.tensorboard import SummaryWriter
​
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
    # 因为数据集中的图片比较小(32*32),故不进行裁剪操作
])
​
# CIFAR10 常用于物体识别
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=False)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=False)
​
# print(test_set[0])  # (<PIL.Image.Image image mode=RGB size=32x32 at 0x1DF45FAC710>, 3)
# print(test_set.classes)  # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
#
# img, target = test_set[0]
# print(img,target)  # <PIL.Image.Image image mode=RGB size=32x32 at 0x1D64D04C3D0> 3
#
# print(test_set.classes[target])  # cat
# # img.show()
​
print(test_set[0])
​
writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image(tag="test_set",img_tensor=img,global_step=i)
​
writer.close()

05-8.png