torchvision中数据集的使用
Pytorch官网
官方文档:
- PyTorch: 核心模块
- torchaudio: 语音模块
- torchtext: 文本模块
- torchvision: 视觉模块
- ......
torchvision中的模块
torchvision.datasets
写代码时,指定这些数据集并给出相应的参数,pytorch可自动下载和使用这些标准数据集
Datasets — Torchvision 0.16 documentation (pytorch.org)
COCO数据集,常用于目标检测或语义分割(Image detection or segmentation)中
MNIST(手写文字)数据集,常用于图像分类(Image classification)中,一般教科书中的常用数据集
the MNIST database of handwirtten digits
CIFAR10数据集:常用于物体识别
CIFAR-10 and CIFAR-100 datasets (toronto.edu)
torchvision.io
输入输出(不常用)
torchvision.ops
torchvision提供的一些少见的特殊操作(不常用)
torchvision.transforms
之前的章节里讲过了
torchvision.utils
之前的tensorboard来自于此
torch.models
提供一些常见的神经网络,有的已经预训练好了,比较常用(于毕设和科研)
- Classification
- Semantic Segmentation
torchvision.datasets的使用
以CIFAR10为例:
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains exactly 1000 randomly-selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5000 images from each class.
Parameters:
- root (string) – Root directory of dataset where directory
cifar-10-batches-pyexists or will be saved to if download is set to True.(数据集所在的位置)- train (bool , optional) – If True, creates dataset from training set, otherwise creates from test set.(训练集 or 测试集)
- transform (callable*,* optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g,
transforms.RandomCrop- target_transform (callable*,* optional) – A function/transform that takes in the target and transforms it.
- download (bool , optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.(是否需要从网上自动下载,下载链接可以在torchvision.datasets.CIFAR10文件中查看)
示例
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
# 因为数据集中的图片比较小(32*32),故不进行裁剪操作
])
# CIFAR10 常用于物体识别
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=False)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=False)
# print(test_set[0]) # (<PIL.Image.Image image mode=RGB size=32x32 at 0x1DF45FAC710>, 3)
# print(test_set.classes) # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
#
# img, target = test_set[0]
# print(img,target) # <PIL.Image.Image image mode=RGB size=32x32 at 0x1D64D04C3D0> 3
#
# print(test_set.classes[target]) # cat
# # img.show()
print(test_set[0])
writer = SummaryWriter("logs")
for i in range(10):
img, target = test_set[i]
writer.add_image(tag="test_set",img_tensor=img,global_step=i)
writer.close()