【Pytorch】浅析Dataset、DataLoader、Sampler

1,708 阅读4分钟

两个引子: 最近学习RNN(循环神经网络),在进行Pytorch的代码实现时,发现文本数据的读取与图像数据的读取有较大的区别,通过上网查阅资料与文档并简要阅读原码,对Pytorch中Dataset、DataLoader、Sampler三个类进行了浅要分析。 另外,如何组织数据是一个很重要的问题,在SGD(随机梯度下降)的过程中,batch的大小对训练的速率有较大的影响,因此如何对数据集进行采样,也是一个重要的问题。

三者间的关系

废话不多说,先用一张图来解释他们间的关系

Pytorch官网文档对于这三个类相关介绍的第一句话就是:

At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class.

那么我们先来讨论一下DataLoader的作用,DataLoader用于我们构建的Dataset与Network之间通信,DataLoader要从Dataset中获取数据,就有索引,这个索引就由Sampler提供

Dataset

Dataset类是这三个类中最基本的一个类,根据Pytorch对Dataset类的描述,Dataset代表了一个从索引到数据的一个映射关系,即给Dataset提供一个索引,它能给我们返回对应的数据,因此这个类中最重要的一个方法就是__getitem__,我们需要重写这个方法来实现上面的功能;此外,我们还要重写__len__方法,用于返回数据集的大小,后续的Sampler与DataLoader需要获取数据集长度时会从这里取,这种风格的Dataset被称为Map-style datasets(映射风格)

class myDataset(torch.utils.data.Dataset):
	def __init__(self):
		pass
		
	def __getitem__(self, idx):
		pass
		
	def __len__(self):
		pass
		

除此之外,还有另一种风格的Datasets,称为Iterable-style datasets(迭代器风格),这种Datasets与Map-style不同的是,它实现的是__iter__方法,而非__getitem__,这种风格的Dataset获取数据不再需要通过索引,而是通过迭代的方式,它适用于一些这种类型的数据集特别适合于随机读取昂贵甚至不可能的情况,以及批量大小取决于所取数据的情况

class MyIterableDataset(torch.utils.data.IterableDataset):  
    def __init__(self, start, end):  
        super(MyIterableDataset).__init__()  
         assert end > start
         self.start = start  
         self.end = end  
  
     def __iter__(self):  
         return iter(range(self.start, self.end))
		
	def __len__(self):
		pass
		

需要注意的是:在多进程下使用IterableDatasets时,即设置num_worker=N{\rm num\_worker}=N时,IterableDatasets会被复制NN份,在每个进程中都进行一次遍历 多个进程在被调用时,可以通过torch.utils.data.get_worker_info()来访问每个进程的信息 这种情况可以通过在__iter__方法中加入以下语句来避免

class MyIterableDataset_Multworker(torch.utils.data.IterableDataset):  
     def __init__(self, start, end):  
         super(MyIterableDataset_Multworker).__init__()  
         assert end > start
         self.start = start
         self.end = end
           
     def __iter__(self):
	     # 获取多进程信息
         worker_info = torch.utils.data.get_worker_info()
         # 判断是否多线程
         if worker_info is None:
             iter_start = self.start
             iter_end = self.end
         else:
	         # 若为多线程,则先给每个进程分配一段负责的空间,per_worker就是这个worker负责的空间的首元素索引
             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
             worker_id = worker_info.id
             iter_start = self.start + worker_id * per_worker
             iter_end = min(iter_start + per_worker, self.end)
         return iter(range(iter_start, iter_end))

	def __len__(self):
		pass
		

Sampler

Sampler的功能就是为DataLoader访问Dataset提供索引,每个继承Sampler的子类都要重写__iter__方法,也可以选择重写__len__方法,若是DataLoader需要计算自身长度时必须提供

class mySampler(torch.utils.data.Sampler):
	def __init__(self, data_source):
		super(mySampler).__init__()
		self.data_source = data_source
		
	def __iter__(self):
		pass

	def __len__(self):
		pass
		

Pytorch还提供了内置实现的部分Sampler

  • torch.utils.data.SequentialSampler : 顺序采样样本,始终按照同一个顺序
  • torch.utils.data.RandomSampler: 可指定有无放回地,进行随机采样样本元素
  • torch.utils.data.SubsetRandomSampler: 无放回地按照给定的索引列表采样样本元素
  • ......

DataLoader

终于到了大头的DataLoader啦,DataLoader可以读取分批的数据也可以读取非分批的数据

  • batch_size:每个批次的大小
  • drop_last:若最后剩下的数据不够一个批次大小,是否丢弃这些数据
  • batch_sampler:对于Map-style datasets我们可以自定义sampler来决定怎么取数据
  • collate_fn:根据Sampler提供的索引在获取数据后,经过collate_fn()函数处理成一个batch 默认情况下,DataLoader会帮我们自动分批次,若要手动设置分批次的方式,此时我们将batch_sizebatch_sampler参数设为None,就可以禁用自动批处理,此时从Datasets中获取的每个样本传入collate_fn处理成batch
  • num_worker:该参数将开启多线程数据加载,用于加速数据加载,可以用torch.utils.data.get_worker_info()来访问每个进程的信息
  • pin_memory:该参数适用于GPU加速,在有GPU的情况下建议开启

参考