本文介绍如何在训练网络的时候进行分布式的数据加载,先看代码如下:
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
num_workers=args.num_workers,
sampler=train_sampler
)
上面的这段代码主要用于初始化训练数据集和数据加载器,是深度学习模型训练中的关键步骤。下面逐行解释其作用。
初始化训练数据集
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
PretrainDataset:这是一个自定义的数据集类,专门用于处理预训练任务的数据。它继承自 PyTorch 的Dataset类。args.data_path:表示训练数据存储的路径。数据可以是文本文件、JSON 文件或其他格式,具体取决于数据集的实现。tokenizer:是分词器,用于将文本数据分词并转换为模型可以接受的输入格式。max_length=lm_config.max_seq_len:指定了序列的最大长度。分词后的序列会根据这个长度进行截断或填充。lm_config.max_seq_len是模型配置中的一个参数,用于定义模型能够处理的最大序列长度。
创建数据采样器
train_sampler = DistributedSampler(train_ds) if ddp else None
DistributedSampler:是 PyTorch 中用于分布式训练的采样器。在分布式训练中,每个进程会处理数据集的不同子集。DistributedSampler会确保每个进程采样的数据互不重叠且覆盖整个数据集。train_ds:是初始化好的PretrainDataset,即训练数据集。它包含所有训练数据,DistributedSampler会根据数据集的大小和分布式训练的配置来划分数据。ddp:是一个布尔值,表示是否启用分布式数据并行模式。如果ddp为True,则使用DistributedSampler;否则,不使用特定的采样器。
创建数据加载器
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
num_workers=args.num_workers,
sampler=train_sampler
)
DataLoader:是 PyTorch 中用于加载数据的工具,它提供了很多功能,如批量加载、多线程加载、打乱数据等。train_ds:是训练数据集,数据加载器从这里获取数据。batch_size=args.batch_size:指定了每个批次的样本数量。args.batch_size是通过命令行参数或配置文件传递的超参数。pin_memory=True:用于加速数据从 CPU 到 GPU 的传输。如果设置为True,数据会被加载到锁定的内存区域(pinned memory),这样可以加快数据传输速度。drop_last=False:决定了如果数据集大小不能被批次大小整除时,最后一个批次是否被丢弃。如果设置为True,最后一个不完整的批次会被丢弃;如果设置为False,则保留最后一个批次。num_workers=args.num_workers:指定了加载数据时使用的子进程数量。args.num_workers是通过命令行参数或配置文件传递的超参数,增加这个值可以加快数据加载速度,但也会增加内存占用。sampler=train_sampler:指定了使用的采样器。在分布式训练中,train_sampler是DistributedSampler,它会确保每个进程采样的数据互不重叠。在非分布式训练中,train_sampler是None,此时DataLoader会使用默认的采样器,通常会随机打乱数据。
总结
这段代码的主要作用是初始化训练数据集和数据加载器,为模型训练提供数据支持。具体来说:
- 初始化了一个自定义的
PretrainDataset对象,用于处理预训练任务的数据。 - 根据是否启用分布式训练,决定是否使用pytorch 内置的
DistributedSampler对数据进行采样。 - 使用pytorch 内置的
DataLoader对象,用于批量加载数据,并配置了相关的参数以优化数据加载过程。