1. Using torchvision's built-in CIFAR-10 dataset
Dataset page: CIFAR-10 and CIFAR-100 datasets
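```python
import torchvision

# Download CIFAR-10 into pytorchstu/dataset (the download is skipped if the files are already there)
train_set = torchvision.datasets.CIFAR10(root="pytorchstu/dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="pytorchstu/dataset", train=False, download=True)

print(test_set[0])               # each sample is an (image, target) tuple; without a transform the image is a PIL Image
print(test_set.classes)          # the 10 class names, indexed by target
img, target = test_set[0]
print(img)
print(target)                    # integer class index
print(test_set.classes[target])  # class name for that index
img.show()                       # opens the image in an external viewer
```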
Output screenshot (the server has no image viewer installed, so img.show() raises an error):
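img.show() needs a GUI image viewer. On a headless server, one workaround is to write the sample to disk and open the file on another machine. The sketch below assumes Pillow, which torchvision already uses to return these images. `save_sample` and its output directory are hypothetical names, not part of torchvision.

```python
from pathlib import Path
from typing import Sequence

from PIL import Image


def save_sample(img: Image.Image, target: int, classes: Sequence[str],
                out_dir: str, index: int) -> Path:
    """Hypothetical helper: save one dataset sample as a PNG instead of calling img.show()."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    # The file name carries the sample index and its class name, e.g. "0_cat.png"
    path = out / f"{index}_{classes[target]}.png"
    img.save(path)  # Pillow infers the PNG format from the file extension
    return path
```

For example, after the code above, `save_sample(img, target, test_set.classes, "samples", 0)` would write the first test image to `samples/0_<class name>.png`.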
2. Using the dataset with transforms
```python
import torchvision
from torch.utils.tensorboard import SummaryWriter

# Convert every PIL image to a tensor as it is loaded
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="pytorchstu/dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="pytorchstu/dataset", train=False, transform=dataset_transform, download=True)
print(test_set[0])  # now (3 x 32 x 32 float tensor, target)

# Log the first 10 test images to TensorBoard under the tag "test_set"
writer = SummaryWriter("cifar-10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)  # i is the global step, shown as a slider in TensorBoard
writer.close()
```
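ToTensor turns each 32×32 PIL image into the 3×32×32 float tensor that `add_image` accepts, since `add_image` defaults to `dataformats="CHW"`. The NumPy sketch below assumes an 8-bit RGB input and shows roughly what that conversion does. `to_tensor_like` is a hypothetical name, not a torchvision function.

```python
import numpy as np


def to_tensor_like(hwc: np.ndarray) -> np.ndarray:
    """Approximate torchvision.transforms.ToTensor for an 8-bit image:
    H x W x C uint8 in [0, 255]  ->  C x H x W float32 in [0.0, 1.0]."""
    if hwc.dtype != np.uint8 or hwc.ndim != 3:
        raise ValueError("expected an H x W x C uint8 array")
    chw = hwc.transpose(2, 0, 1)           # move channels first: (C, H, W)
    return chw.astype(np.float32) / 255.0  # rescale pixel values to [0, 1]
```

On a real sample from the first section, `to_tensor_like(np.asarray(img))` should hold the same values as `torchvision.transforms.ToTensor()(img)`. The difference is that it is a NumPy array rather than a torch tensor.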