Pytorch加载数据

186 阅读2分钟

本文已参与[新人创作礼]活动,一起开启掘金创作之路。

Pytorch导入数据主要依靠torch.utils.data.DataLoader和torch.utils.data.Dataset这两个类来完成。
Dataset:获取数据及其标签以及数据数量。
Dataloader:为后面的网络提供不同的数据形式。

加载自己数据的过程

1.第一步,先重写Dataset。重写Dataset要重写三个函数。分别是__init__,_ getitem__和 __ len_。

先定义一个自己需要的类,再重写函数。示例代码如下:

#这里的代码格式有一点问题,他们的空格没有用好,需要注意一下。
from torch.utils.data import Dataset
import os
class MyData(Dataset):#然后重写三个函数
 def __init__(self,root):#root图片路径,根据路径得到数据列表
 	 imgs=os.listdir(root)
 	 self.imgs=[os.path.join(root,k) for k in imgs]
def __len__(self):
    return len(self.imgs) #返回列表长度
def __getitem__(self, index):#根据index找到对应图片并打开
     img_path = self.imgs[index]
     pil_img = Image.open(img_path)
     pil_img = np.asarray(pil_img)
     data = torch.from_numpy(pil_img)
     return data
 if __name__ == '__main__':
        trian_dataset=MyData('test')
        print(dataSet[0])            

2.使用DataLoader迭代数据主要实现以下功能:

  1. 批处理数据(Batching the data)
  2. 打乱数据(Shuffling the data)
  3. 使用多线程 multiprocessing 并行加载数据,默认单线程

示例代码如下:

train_loader = dataloader.DataLoader(
  dataset=train_dataset,
  batch_size=128, # batch_size可以理解每次加载一个包,每个包中含有128张图片
  shuffle=False

综上所述使用pytorch 加载数据一共用到了两个类:torch.utils.data.DataLoader和torch.utils.data.Dataset。这两个类分别实现了不同的功能,使用Dataset类构建自己的数据加载类来加载数据,在上述代码中MyData类中__init__实现初始化功能,_ _len__函数实现确定需要加载多少数据, _ getitem _函数实现加载一个数据。完成Dataset类重写之后,配合使用DataLoader类进行迭代加载数据,一次加载需要的batchsize的数据,打乱顺序,还有设置多线程等等以实现快速加载等功能。