数据集类
pytorch自带数据集
- torchvision 提供图像数据
- torchtext 提供文本数据
import torch
import torchvision
# 下载数据集
# torchvision.datasets 提供数据集下载方式,把MNIST换成其他名字,MNIST是灰度图/IMDB电影评论文本数据
# torchvision.datasets.DatasetFolder 加载其他数据集
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True)
加载本地数据集
data_path = r"../../../" # r表该句为字符串
# 完成数据类
class MyDataset(Dataset):
def __init__(slef):
self.lines = open(data_path).readlines()
def __getitem__(self, index):
return self.lines[index]
def __len__(self):
return len(self.lines)
my_dataset = MyDataset()
数据加载器
在对数据量非常大的数据进行训练时,往往会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时对数据进行预处理
# 数据加载器--处理数据
# 批处理数据batch 打乱数据shuffing
from torch.utils.data import DataLoader
batch_size = 128 # 并行处理样本数量,会将128份样本压缩成一份来进行处理
'''
shuffle 是否提前打乱数据
num_workers 加载数据的线程数
'''
train_set = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=2)
print('afterLoader_train_set len=', len(train_set))
print('afterLoader_train_set type=', type(train_set))
准备数据
图片数据集的格式是PIL.Image.Image,要通过模型进行训练,要对数据进行处理,使用torchvision.transforms将数据对象转化成符合要求的Tensor,然后对数据进行标准化/正则化