Using Datasets in torchvision


1. Using the CIFAR-10 dataset bundled with torchvision

Dataset page: CIFAR-10 and CIFAR-100 datasets

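import torchvision

# Download CIFAR-10 into the root directory (skipped if already present)
train_set = torchvision.datasets.CIFAR10(root="pytorchstu/dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="pytorchstu/dataset", train=False, download=True)

print(test_set[0])        # a (PIL image, class index) tuple
print(test_set.classes)   # list of the 10 class names
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])  # class name for this sample
img.show()  # needs an image viewer on the machine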

Output screenshot (the server has no image viewer installed, so img.show() raises an error):

image.png
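
Since img.show() depends on an image viewer, one possible workaround on a headless server is to save the image to a file and copy it elsewhere for viewing. The following is only a minimal sketch: it uses a synthetic 32x32 RGB PIL image as a stand-in for a CIFAR-10 sample (so nothing is downloaded), and the output file name sample.png is a hypothetical choice, not something from the original post.

```python
import os
import tempfile

from PIL import Image

# Stand-in for `img, target = test_set[0]`: without a transform,
# CIFAR-10 samples are 32x32 RGB PIL images.
img = Image.new("RGB", (32, 32), color=(120, 60, 200))

# Save to disk instead of calling img.show(), which needs an image viewer.
out_path = os.path.join(tempfile.gettempdir(), "sample.png")  # hypothetical path
img.save(out_path)

# Reload the file to confirm it was written correctly.
reloaded = Image.open(out_path)
print(reloaded.size, reloaded.mode)
```

With a real dataset, the same img.save(...) call should work directly on the img returned by test_set[0].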

2. Using datasets together with transforms

import torchvision
from torch.utils.tensorboard import SummaryWriter

# Convert each PIL image to a tensor when it is loaded
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(root="pytorchstu/dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="pytorchstu/dataset", train=False, transform=dataset_transform, download=True)
print(test_set[0])  # now a (tensor, class index) tuple

# Log the first 10 test images to TensorBoard (log directory "cifar-10")
writer = SummaryWriter("cifar-10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()