
# 使用 `torch.utils.data.DataLoader`
尽管 `Dataset` 类提供了一种清晰的方式来抽象对单个数据样本的访问,但对大型数据集逐个样本进行迭代对于训练深度学习模型通常效率不高。训练通常受益于分批处理数据。在这种情况下,`torch.utils.data.DataLoader` 在高效数据处理中发挥着核心作用。
`DataLoader` 封装了一个 `Dataset`(无论是内置的还是你自定义的实现)并提供了对其的迭代接口。它的主要职责是:
1. **分批处理**:将单个样本分组为小批量。
2. **打乱数据**:在每个周期随机打乱数据,以防止模型学习到样本的顺序并提高泛化能力。
3. **并行数据加载**:使用多个子进程并发加载数据,这可能会显著加快数据处理流程。
### 基本用法和迭代
创建 `DataLoader` 很简单。你主要需要提供 `Dataset` 实例并指定所需的 `batch_size`。
```scala 3
import torch
import torch.utils.data.Dataset
import torch.utils.data.DataLoader
class DummyDataset extends Dataset:
def __init__(self, num_samples=100):
val num_samples = num_samples
val features = torch.randn(num_samples, 10) // 示例:100 个样本,10 个特征
val labels = torch.randint(0, 2, (num_samples,)) // 示例:100 个二元标签
def __len__(self):
return num_samples
def __getitem__(self, idx):
return self.features[idx], self.labels[idx]
// 实例化数据集
val dataset = DummyDataset(num_samples=105)
// 实例化 DataLoader
// batch_size: 每个批次的样本数量
// shuffle: 设置为 True 以在每个周期打乱数据(对训练很重要)
val train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=true)
// 迭代 DataLoader
println(s"Dataset size: ${len(dataset)}")
println(s"DataLoader batch size: ${train_loader.batch_size}")
for epoch <- Range(1): // 一个周期的示例
println(s"\n--- Epoch ${epoch+1} ---")
for i, batch in enumerate(train_loader):
// DataLoader 产生批次。每个 'batch' 通常是元组或列表
// 包含特征和标签的张量。
val features, labels = batch
println(s"Batch ${i+1}: Features shape=${features.shape}, Labels shape=${labels.shape}")
// 在这里你通常会执行训练步骤:
model.train()
optimizer.zero_grad()
outputs = model(features)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
```
```java
package featurestore;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.JavaRandomDataLoader;
import org.bytedeco.pytorch.DataLoaderOptions;
import org.bytedeco.pytorch.global.torch;
import java.util.Iterator;
public class DummyDataset extends JavaDataset {
private final long numSamples;
private final Tensor features;
private final Tensor labels;
public DummyDataset(long numSamples) {
super();
this.numSamples = numSamples;
this.features = torch.randn(
new long[]{numSamples, 10},
new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Float))
);
this.labels = torch.randint(
0, 2,
new long[]{numSamples},
new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Long))
);
}
public DummyDataset() {
this(100);
}
@Override
public SizeTOptional size() {
return new SizeTOptional(new SizeTPointer(numSamples));
}
@Override
public Example get(long index) {
if (index < 0 || index >= numSamples) {
throw new IndexOutOfBoundsException("索引超出范围: " + index);
}
Tensor feature = features.index_select(0, torch.tensor(new LongPointer(index)));
Tensor label = labels.index_select(0, torch.tensor(new LongPointer(index)));
return new Example(feature, label);
}
public static void main(String[] args) {
DummyDataset dataset = new DummyDataset(105);
DataLoaderOptions loaderOptions = new DataLoaderOptions();
loaderOptions.batch_size().put(32);
loaderOptions.enforce_ordering().put(false);
RandomSampler sampler = new RandomSampler(dataset.size().get());
JavaRandomDataLoader trainLoader = new JavaRandomDataLoader(dataset,sampler, loaderOptions);
System.out.println("Dataset size: " + dataset.size().get());
System.out.println("DataLoader batch size: " + loaderOptions.batch_size().get());
int numEpochs = 1;
for (int epoch = 0; epoch < numEpochs; epoch++) {
System.out.println("\n--- Epoch " + (epoch + 1) + " ---");
var beginBatch = trainLoader.begin();
var endBatch = trainLoader.end();
while(beginBatch != endBatch){
var batch = beginBatch.access();
beginBatch.increment();
}
}
}
}
private static String tensorShapeToString(SizeTPointer sizes) {
StringBuilder sb = new StringBuilder("[");
for (int i = 0; i < sizes.limit(); i++) {
sb.append(sizes.get(i));
if (i < sizes.limit() - 1) {
sb.append(", ");
}
}
sb.append("]");
return sb.toString();
}
}
```
```java
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
public class DummyDataset extends JavaDataset {
private final Tensor features;
private final Tensor labels;
private final long numSamples;
public DummyDataset(long numSamples) {
this.numSamples = numSamples;
this.features = torch.randn(new long[]{numSamples, 10});
this.labels = torch.randint(0, 2, new long[]{numSamples});
}
@Override
public Example get(long index) {
Tensor feature = features.index(new TensorIndexVector(new TensorIndex(index)));
Tensor label = labels.index(new TensorIndexVector(new TensorIndex(index)));
return new Example(feature, label);
}
@Override
public SizeTOptional size() {
return new SizeTOptional(numSamples);
}
}
package torch;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import static org.bytedeco.pytorch.global.torch.*;
public class DummyDatasetApp {
public static void main(String[] args) {
long numSamples = 105;
long batchSize = 32;
DummyDataset dataset = new DummyDataset(numSamples);
DataLoaderOptions options = new DataLoaderOptions(batchSize);
RandomSampler sampler = new RandomSampler(numSamples);
JavaRandomDataLoader trainLoader = new JavaRandomDataLoader(dataset, sampler, options);
System.out.println("Dataset size: " + dataset.size());
System.out.println("DataLoader batch size: " + batchSize);
for (int epoch = 0; epoch < 1; epoch++) {
System.out.println("\n--- Epoch " + (epoch + 1) + " ---");
int step = 0;
ExampleVectorIterator it = trainLoader.begin();
while (!it.equals(trainLoader.end())) {
ExampleVector batch = it.access();
Tensor features = stackExampleData(batch);
Tensor labels = stackExampleTarget(batch);
System.out.printf("Batch %d: Features shape=%s, Labels shape=%s%n",
++step,
features.sizes().get(0) + "x" + features.sizes().get(1),
labels.sizes().get(0));
it.increment();
}
}
}
private static Tensor stackExampleData(ExampleVector batch) {
TensorVector tensors = new TensorVector();
for (long i = 0; i < batch.size(); i++) {
tensors.push_back(batch.get(i).data());
}
return torch.stack(tensors);
}
private static Tensor stackExampleTarget(ExampleVector batch) {
TensorVector tensors = new TensorVector();
for (long i = 0; i < batch.size(); i++) {
tensors.push_back(batch.get(i).target());
}
return torch.stack(tensors);
}
}
```
```java
import org.bytedeco.pytorch.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
import static org.bytedeco.pytorch.global.torch.*;
public class DummyDataset extends ChunkDataset implements AutoCloseable {
static { Loader.load(); }
private Tensor features;
private Tensor labels;
private long numSamples;
private long currentIndex = 0;
public DummyDataset(Pointer p) { super(p); }
public DummyDataset(long numSamples) {
super((Pointer)null);
this.numSamples = numSamples;
this.features = randn(new long[]{numSamples, 10}, new TensorOptions().dtype(new ScalarTypeOptional(kFloat())));
this.labels = randint(0, 2, new long[]{numSamples}, new TensorOptions().dtype(new ScalarTypeOptional(kLong())));
}
@Override
public @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long batch_size) {
if (currentIndex >= numSamples) {
return new ExampleVectorOptional();
}
long actualBatchSize = Math.min(batch_size, numSamples - currentIndex);
Tensor batchFeatures = features.narrow(0, currentIndex, actualBatchSize);
Tensor batchLabels = labels.narrow(0, currentIndex, actualBatchSize);
currentIndex += actualBatchSize;
ExampleVector vec = new ExampleVector();
vec.push_back(new Example(batchFeatures, batchLabels));
return new ExampleVectorOptional(vec);
}
@Override
public void reset() {
this.currentIndex = 0;
}
@Override
public @ByVal SizeTOptional size() {
return new SizeTOptional(numSamples);
}
}
```
```java
public static void main(String[] args) {
long numSamples = 105;
DummyDataset dataset = new DummyDataset(numSamples);
long batchSize = 32;
DataLoaderOptions options = new DataLoaderOptions(batchSize);
ChunkRandomDataLoader trainLoader = new ChunkRandomDataLoader(dataset, options);
System.out.println("Dataset size: " + dataset.size().get());
for (int epoch = 0; epoch < 1; epoch++) {
System.out.println("\n--- Epoch " + (epoch + 1) + " ---");
int batchIdx = 0;
ExampleIterator iter = trainLoader.begin();
ExampleIterator end = trainLoader.end();
while (!iter.equals(end)) {
Example batch = iter.access();
Tensor features = batch.data();
Tensor labels = batch.target();
System.out.printf("Batch %d: Features shape=%s, Labels shape=%s%n",
(++batchIdx),
java.util.Arrays.toString(features.sizes().get()),
java.util.Arrays.toString(labels.sizes().get()));
iter.increment();
}
trainLoader.retainReference()
}
}
```
```java
public class DummyMapDataset extends ChunkMapDataset {
private long numSamples;
Tensor features = randn(new long[]{32, 10}, new TensorOptions().dtype(new ScalarTypeOptional(kFloat())));
Tensor labels = randint(0, 2, new long[]{32}, new TensorOptions().dtype(new ScalarTypeOptional(kLong())));
public DummyMapDataset(long numSamples) {
super((Pointer) null);
this.numSamples = numSamples;
}
@Override
public ExampleOptional get_batch_example(@Cast("size_t") long batch_index) {
Example example = new Example(features, labels);
return new ExampleOptional(example);
}
@Override
public @Cast("size_t") long size() {
return numSamples;
}
}
```
运行此代码将展示 `DataLoader` 如何产生数据批次。注意每个批次打印的形状反映了 `batch_size`(除了可能的最后一个批次)。
### 批次策略和 `drop_last`
默认情况下,如果 `Dataset` 中的总样本数不能被 `batch_size` 完全整除,最后一个批次将包含剩余样本,因此会更小。
在我们有 105 个样本、批次大小为 32 的示例中:
- 批次 1: 32 个样本
- 批次 2: 32 个样本
- 批次 3: 32 个样本
- 批次 4: 9 个样本 (105 - 3*32 = 9)
有时,拥有可变批次大小,尤其是非常小的最后一个批次,可能会影响某些训练动态或特定层的要求(例如训练期间的 BatchNorm 层,尽管 PyTorch 处理得相当好)。如果你希望所有批次都具有精确的 `batch_size`,并丢弃较小的最后一个批次,你可以在创建 `DataLoader` 时设置 `drop_last=True`:
```scala 3
// 如果数据集大小不能被批次大小整除,则丢弃最后一个不完整的批次
val train_loader_drop_last = DataLoader(dataset=dataset, batch_size=32, shuffle=true, drop_last=true)
println("\n--- DataLoader with drop_last=True ---")
for i <- Range(train_loader_drop_last.size):
val batch = train_loader_drop_last(i)
val features, labels = batch
println(s"Batch ${i+1}: Features shape=${features.shape}, Labels shape=${labels.shape}")
# 预期输出:只有 3 个大小为 32 的批次。最后的 9 个样本被丢弃。
```
### 打乱数据以获得更好的训练
在训练期间强烈建议设置 `shuffle=True`。它告诉 `DataLoader` 在为每个周期创建批次之前重新打乱数据集的索引。这确保模型每次都以不同的顺序查看数据,减少过度拟合数据呈现顺序的风险并提高模型健壮性。对于验证或测试,通常禁用打乱数据(`shuffle=False`),以确保不同运行之间评估指标的一致性。
### 使用 `num_workers` 并行加载数据
数据加载和预处理(应用变换)有时可能会成为瓶颈,特别是当变换很复杂或数据读取涉及大量 I/O 操作时。CPU 可能会花费大量时间准备下一个批次,而 GPU 则空闲等待数据。
`DataLoader` 允许你通过使用多个工作进程并行加载数据来缓解这个问题。你可以使用 `num_workers` 参数来指定工作进程的数量:
```scala 3
// 使用 4 个工作进程加载数据
// num_workers > 0 启用多进程数据加载
// 一个常见的起始点是 num_workers = 4 * num_gpus,但最优值取决于
// 系统(CPU 核心数、磁盘速度)和批次大小。通常需要通过实验确定。
val fast_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=true, num_workers=4)
// 迭代看起来相同,但数据加载发生在后台进程中
for i <- Range(fast_loader.size):
val batch = fast_loader(i)
val features, labels = batch
println(s"Batch ${i+1}: Features shape=${features.shape}, Labels shape=${labels.shape}")
```
```java
package torch;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import static org.bytedeco.pytorch.global.torch.*;
public class FastDataLoaderApp {
public static void main(String[] args) {
long numSamples = 105;
long batchSize = 32;
long numWorkers = 4;
DummyDataset dataset = new DummyDataset(numSamples);
DataLoaderOptions options = new DataLoaderOptions(batchSize);
options.workers().put(numWorkers);
RandomSampler sampler = new RandomSampler(numSamples);
JavaRandomDataLoader fastLoader = new JavaRandomDataLoader(dataset, sampler, options);
System.out.println("Dataset size: " + dataset.size());
System.out.println("DataLoader batch size: " + batchSize);
System.out.println("Number of workers: " + numWorkers);
int step = 0;
ExampleVectorIterator it = fastLoader.begin();
while (!it.equals(fastLoader.end())) {
ExampleVector batch = it.access();
Tensor features = stackExampleData(batch);
Tensor labels = stackExampleTarget(batch);
System.out.printf("Batch %d: Features shape=%s, Labels shape=%s%n",
++step,
features.sizes().get(0) + "x" + features.sizes().get(1),
labels.sizes().get(0));
it.increment();
}
}
private static Tensor stackExampleData(ExampleVector batch) {
TensorVector tensors = new TensorVector();
for (long i = 0; i < batch.size(); i++) {
tensors.push_back(batch.get(i).data());
}
return torch.stack(tensors);
}
private static Tensor stackExampleTarget(ExampleVector batch) {
TensorVector tensors = new TensorVector();
for (long i = 0; i < batch.size(); i++) {
tensors.push_back(batch.get(i).target());
}
return torch.stack(tensors);
}
}
```
当 `num_workers > 0` 时,`DataLoader` 会生成指定数量的 Python 进程。每个工作进程独立加载一个批次。这使得后续批次的数据加载和变换可以并行发生,而主进程则对当前批次执行模型训练步骤,通常通过更有效地利用 GPU 来显著提高速度。
请注意,增加 `num_workers` 也会增加 CPU 使用率和内存消耗,因为每个工作进程都会加载数据。将其设置过高有时会导致资源争用和收益递减,甚至减慢速度。它通常是根据你的具体硬件和数据集进行调整的参数。
并行加载 (如果 num_workers > 0)Dataset 对象(YourCustomDataset)DataLoader(批次大小, 打乱数据,工作进程数)封装工作进程 1分配索引工作进程 N分配索引批次(特征, 标签)获取并准备获取并准备训练循环(GPU/CPU)提供请求下一个
> `DataLoader` 的数据加载流程。`DataLoader` 封装了 `Dataset`,并且在 `num_workers > 0` 的情况下,使用工作进程获取并整理样本成批次,这些批次随后被训练循环使用。
### 使用 `pin_memory` 优化 GPU 传输
在 GPU 上训练时,由 `DataLoader` 加载的数据(位于标准 CPU 内存中)需要传输到 GPU 的内存中。这种传输需要时间。你通常可以通过在 `DataLoader` 中设置 `pin_memory=True` 来稍微加快速度。
```scala 3
val gpu_optimized_loader = DataLoader(dataset=dataset,
batch_size=32,
shuffle=true,
num_workers=4,
pin_memory=true)
for i <- Range(gpu_optimized_loader.size):
val batch = gpu_optimized_loader(i)
val features, labels = batch
features = features.to('cuda') // 传输变得更快
labels = labels.to('cuda')
// ...其余训练步骤...
```
```java
package featurestore;
package featurestore;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
public class GPUOptimizedDataLoaderExample {
public static void main(String[] args) {
DummyDataset dataset = new DummyDataset(105);
DataLoaderOptions loaderOptions = new DataLoaderOptions();
loaderOptions.batch_size().put(32);
loaderOptions.enforce_ordering().put(false);
loaderOptions.workers().put(4);
torch.pinned_memory_or_default(new BoolOptional(true))
BoolOptional updatedDefault = torch.pinned_memory_or_default();
RandomSampler sampler = new RandomSampler(dataset.size().get());
JavaRandomDataLoader gpuOptimizedLoader = new JavaRandomDataLoader(dataset, sampler, loaderOptions);
boolean hasCUDA = torch.cuda_is_available();
Device cudaDevice = hasCUDA ? new Device(torch.DeviceType.CUDA) : null;
System.out.println("GPU可用状态: " + hasCUDA + " | 固定内存已启用: " + tensorOptions.has_pin_memory().get());
var batchIter = gpuOptimizedLoader.begin();
var iterEnd = gpuOptimizedLoader.end();
int batchIdx = 0;
while (!batchIter.equals(iterEnd)) {
batchIdx++;
var batch = batchIter.access();
Tensor features = batch.data();
Tensor labels = batch.target();
if (hasCUDA) {
features = features.to(cudaDevice, torch.ScalarType.Float);
labels = labels.to(cudaDevice, torch.ScalarType.Long);
System.out.printf("批次 %d: 特征/标签已迁移到CUDA | 特征形状: %s%n",
batchIdx, tensorShape(features.sizes().vec().get()));
}
features.close();
labels.close();
batch.close();
batchIter.increment();
}
gpuOptimizedLoader.close();
sampler.close();
dataset.features.close();
dataset.labels.close();
}
private static String tensorShape(SizeTPointer sizes) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < sizes.limit(); i++) {
sb.append(sizes.get(i)).append(i < sizes.limit()-1 ? "x" : "");
}
return sb.toString();
}
static class DummyDataset extends JavaDataset {
private final long numSamples;
private final Tensor features;
private final Tensor labels;
public DummyDataset(long numSamples) {
super();
this.numSamples = numSamples;
this.features = torch.randn(new long[]{numSamples, 10},
new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Float)));
this.labels = torch.randint(0, 2, new long[]{numSamples},
new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Long)));
}
@Override
public SizeTOptional size() {
return new SizeTOptional(new SizeTPointer(numSamples));
}
@Override
public Example get(long index) {
if (index < 0 || index >= numSamples) throw new IndexOutOfBoundsException();
Tensor feature = features.index_select(0, torch.tensor(new LongPointer(index)));
Tensor label = labels.index_select(0, torch.tensor(new LongPointer(index)));
return new Example(feature, label);
}
}
}
```
设置 `pin_memory=True` 指示 `DataLoader` 在 CPU 端将张量分配到“固定”(页面锁定)内存中。从固定 CPU 内存到 GPU 内存的传输通常比从标准可分页 CPU 内存的传输快。这在与 `num_workers > 0` 一起使用时最有效。请注意,使用固定内存会消耗更多的 CPU 内存。
总之,`DataLoader` 是 PyTorch 中一个核心的工具,它简化并优化了向模型提供数据的过程。通过处理批次化、打乱数据和并行加载,它使你能够专注于模型架构和训练逻辑,同时确保你的数据处理流程高效且可扩展。
# 自定义DataLoader用法
“尽管默认的`DataLoader`提供了方便的批处理和随机排列功能,但许多应用需要更精细地控制数据如何抽样和整理成批次。PyTorch通过自定义采样器和`collate`函数提供灵活性,让你能根据具体需求调整数据加载过程,例如处理不平衡数据集或使用可变大小的输入。”
### 使用采样器控制样本选择
`DataLoader`使用`sampler`对象来决定从`Dataset`中抽取索引的顺序。默认情况下,如果`shuffle=True`,它使用`RandomSampler`;如果`shuffle=False`,则使用`SequentialSampler`。但是,你可以通过`sampler`参数显式传入自己的采样器实例(请注意:如果你提供了`sampler`,则必须将`shuffle`设为`False`,因为随机排列是由采样器本身定义的)。
PyTorch在`torch.utils.data`中提供了几种内置采样器:
- `SequentialSampler`:按顺序采样元素,总是以相同的顺序。
- `RandomSampler`:随机采样元素。如果`replacement=True`,则进行有放回采样。
- `SubsetRandomSampler`:从给定索引列表中随机采样元素。它适用于创建验证集划分,而无需修改原始数据集。
- `WeightedRandomSampler`:根据给定概率(权重)从`[0,..,len(weights)-1]`中采样元素。这对于处理不平衡数据集特别有用,例如你想对少数类进行过采样或对多数类进行欠采样。
**示例:对不平衡数据使用`WeightedRandomSampler`**
设想一个分类数据集,其中类别“0”有900个样本,类别“1”有100个样本。简单的随机采样会导致批次严重偏向类别“0”。我们可以使用`WeightedRandomSampler`来提高类别“1”样本被选中的概率。
```scala 3
import torch
import torch.utils.data.{Dataset, DataLoader, WeightedRandomSampler}
val class_counts = torch.bincount(torch.tensor(targets))
val num_samples = targets.size
val sample_weights = torch.tensor([1.0 / class_counts(t) for t in targets])
val sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples, replacement=True)
val dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
```
你也可以创建完全自定义的采样策略,通过继承`torch.utils.data.Sampler`并实现`__iter__`和`__len__`方法。
### 使用`collate_fn`自定义批次创建
一旦采样器为一个批次提供了索引列表,`DataLoader`会使用`dataset[index]`从`Dataset`中获取对应的样本。然后,它需要将这些单独的样本组装成一个批次。这个组装过程由`collate_fn`参数处理。
默认的`collate_fn`在许多标准情况下都能很好地工作。它会尝试:
- 将NumPy数组和Python数字转换为PyTorch张量。
- 保留数据结构(例如,如果你的`Dataset.__getitem__`返回一个字典,则整理后的批次将是一个字典,其中每个值是对应项目的批次)。
- 沿新维度(批次维度)堆叠张量。
但是,如果你的样本具有不同的大小(例如,不同长度的序列)或包含它不知道如何堆叠的数据类型,默认的`collate_fn`可能会失败或产生不理想的结果。
在这种情况下,你可以为`DataLoader`的`collate_fn`参数提供一个自定义函数。这个函数接收一个样本列表(其中每个样本是`Dataset.__getitem__`的输出),并负责以所需格式返回整理后的批次。
**示例:填充可变长度序列**
一个常见情况是涉及长度不同的序列(例如NLP中的句子)。默认的`collate`函数不能直接将它们堆叠成一个张量。自定义的`collate_fn`可以将每个批次中的序列填充到该批次中的最大长度。
```scala 3
import torch
import torch.utils.data.{Dataset, DataLoader}
import torch.nn.utils.rnn.{pad_sequence}
class VariableSequenceDataset extends Dataset[(torch.Tensor, Int)]:
def __init__(self, data):
// data是一个张量列表,例如 [torch.randn(5), torch.randn(8), ...]
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
// 为简单起见,假设每个项目也有一个标签(例如其长度)
val sequence = self.data(idx)
val label = len(sequence)
return sequence, label
// 自定义collate函数
def pad_collate(batch):
// batch是一个元组列表:[(序列1, 标签1), (序列2, 标签2), ...]
// 按序列长度对批次元素进行排序(可选,但通常为了RNN效率而进行)
// batch.sort(key=lambda x: len(x[0]), reverse=True) // 对于填充不是严格必需的
// 分离序列和标签
val sequences = batch.map(_._1)
val labels = batch.map(_._2)
// 将序列填充到批次中最长序列的长度
// `batch_first=True` 使输出形状变为 (batch_size, 最大序列长度, 特征)
val padded_sequences = pad_sequence(sequences, batch_first=true, padding_value=0.0)
// 堆叠标签(假设它们是简单的标量)
val labels = torch.tensor(labels)
return padded_sequences, labels
// 创建数据集和dataloader
val sequences = [torch.randn(torch.randint(5, 15, (1,)).item()) for _ in range(100)]
val dataset = VariableSequenceDataset(sequences)
// 创建dataloader
val dataloader = DataLoader(dataset, batch_size=4, collate_fn=pad_collate)
// 遍历dataloader
// for padded_batch, label_batch in dataloader:
// // padded_batch 形状:如果序列是一维的,则为 (4, 该批次中的最大长度, 1)
// // label_batch 形状:(4,)
// // 模型处理...
// pass
```
```java
package torch;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
public class VariableSequenceDataset extends JavaDataset {
private final TensorVector data;
public VariableSequenceDataset(TensorVector data) {
this.data = data;
}
@Override
public Example get(long index) {
Tensor sequence = data.get(index);
Tensor label = torch.tensor(sequence.size(0));
return new Example(sequence, label);
}
@Override
public SizeTOptional size() {
return new SizeTOptional(data.size());
}
}
```
```java
package torch;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import java.util.Random;
import static org.bytedeco.pytorch.global.torch.*;
public class VariableSequenceApp {
public static void main(String[] args) {
int totalSamples = 100;
int batchSize = 4;
Random random = new Random();
TensorVector sequences = new TensorVector();
for (int i = 0; i < totalSamples; i++) {
long length = random.nextInt(10) + 5;
sequences.push_back(torch.randn(new long[]{length}));
}
VariableSequenceDataset dataset = new VariableSequenceDataset(sequences);
DataLoaderOptions options = new DataLoaderOptions(batchSize);
JavaRandomDataLoader dataLoader = new JavaRandomDataLoader(dataset, new RandomSampler(totalSamples), options);
ExampleVectorIterator it = dataLoader.begin();
int step = 0;
while (!it.equals(dataLoader.end())) {
ExampleVector batch = it.access();
TensorVector seqVec = new TensorVector();
TensorVector labelVec = new TensorVector();
for (long i = 0; i < batch.size(); i++) {
seqVec.push_back(batch.get(i).data());
labelVec.push_back(batch.get(i).target());
}
Tensor paddedSequences = torch.pad_sequence(seqVec, true, 0.0, new BytePointer ("right"));
Tensor labels = torch.stack(labelVec);
System.out.printf("Batch %d: Padded Shape=%s, Labels Shape=%s%n",
++step,
paddedSequences.sizes().get(0) + "x" + paddedSequences.sizes().get(1),
labels.sizes().get(0));
it.increment();
}
}
}
```
这个自定义`collate_fn`使用`torch.nn.utils.rnn.pad_sequence`来处理填充,确保批次中的所有序列长度相同,使它们适合RNN等模型处理。
### 其他DataLoader自定义参数
除了`sampler`和`collate_fn`,其他参数也提供性能和行为调整:
- `num_workers` (整数,可选):指定用于数据加载的子进程数量。将其设置为正整数可启用多进程数据加载,这可以显著加快数据获取速度,尤其当数据加载涉及磁盘I/O或CPU上的复杂预处理时。一个常见的起始设置是将其设为可用CPU核心的数量。默认值:`0`(数据加载在主进程中进行)。
- `pin_memory` (布尔值,可选):如果为`True`,`DataLoader`在返回张量之前会将其复制到CUDA固定内存中。固定内存可以加快从CPU到GPU的数据传输。这仅在你使用GPU进行训练时才有效。默认值:`False`。
- `drop_last` (布尔值,可选):如果为`True`,当数据集大小不能被批次大小整除时,将丢弃最后一个不完整的批次。如果为`False`(默认值),则最后一个批次可能小于`batch_size`。
通过理解和使用采样器、自定义`collate`函数以及其他`DataLoader`参数,你可以对数据管道获得精确的控制,从而能高效处理各种数据类型和结构,解决数据集不平衡问题,并优化数据加载性能以加快模型训练。
# 动手实践:构建数据处理流程
构建一个完整的数据处理流程涉及从原始数据(为简化起见,将合成数据)开始,最终得到准备好输入模型的数据批次。这个过程需要创建一个自定义 `Dataset`,定义数据转换,并将所有内容封装在一个 `DataLoader` 中。
### 构建合成数据集
假设我们有一个数据集,包含特征向量和对应的二元分类标签(0 或 1)。在本次练习中,我们将使用 PyTorch 张量直接生成这些数据。这避免了文件输入/输出的复杂性,让我们能够专注于数据处理机制。
```scala 3
import torch
import torch.utils.data as data
from torchvision import transforms
val num_samples = 100
val num_features = 10
val features = torch.randn(num_samples, num_features)
val labels = torch.randint(0, 2, (num_samples,))
println(s"Shape of features: ${features.shape}")
println(s"Shape of labels: ${labels.shape}")
println(s"First 5 features:\n${features[:5]}")
println(s"First 5 labels:\n${labels[:5]}")
```
这使我们得到了两个张量:`features` 包含 100 个样本,每个样本有 10 个特征;`labels` 包含对应的 100 个标签。
### 创建自定义 `Dataset`
现在,我们需要使用 PyTorch 的 `Dataset` 类来组织这些数据。我们将创建一个自定义类,它继承自 `torch.utils.data.Dataset` 并实现两个重要方法:
1. `__len__(self)`: 返回数据集中样本的总数。
2. `__getitem__(self, idx)`: 返回指定索引 `idx` 处的样本(特征和标签)。
我们还将添加一个 `__init__` 方法来存储数据并可选地接受转换。
```scala 3
class SyntheticDataset extends data.Dataset:
"""一个用于合成特征和标签的自定义数据集。"""
def __init__(features, labels, transform=None):
"""
参数:
features (Tensor): 包含特征数据的张量。
labels (Tensor): 包含标签的张量。
transform (callable, optional): 可选的样本转换。
"""
// 基本检查,确保特征和标签具有相同的样本数量
assert features.shape[0] == labels.shape[0], \
"特征和标签的样本数量必须一致"
// 存储特征和标签
val features = features
val labels = labels
// 存储转换
val transform = transform
def __len__(self):
"""返回样本总数。"""
return features.shape[0]
def __getitem__(idx):
"""
根据给定索引获取特征向量和标签。
参数:
idx (int): 要获取的样本索引。
返回:
tuple: (特征, 标签),其中 feature 是特征向量,label 是对应的标签。
"""
// 获取原始特征和标签
val feature_sample = features(idx)
val label_sample = labels(idx)
// 创建一个样本字典(或元组)
val sample = Map('feature -> feature_sample, 'label -> label_sample)
// 如果存在转换,则应用转换
if transform:
sample = transform(sample)
// 返回可能已转换的样本
// 常见做法是分别返回特征和标签
return sample('feature), sample('label)
// 暂时实例化不带转换的数据集
val raw_dataset = SyntheticDataset(features, labels)
// 测试获取一个样本
val sample_idx = 0
val feature_sample, label_sample = raw_dataset(sample_idx)
println(f"\nSample {sample_idx} - Feature: {feature_sample}")
println(f"Sample {sample_idx} - Label: {label_sample}")
println(f"Dataset length: {raw_dataset.size()}") // 输出: 100
```
```java
package torch;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
public class SyntheticDataset extends JavaDataset {
private final Tensor features;
private final Tensor labels;
public SyntheticDataset(Tensor features, Tensor labels) {
if (features.size(0) != labels.size(0)) {
throw new IllegalArgumentException("特征和标签的样本数量必须一致");
}
this.features = features;
this.labels = labels;
}
@Override
public Example get(long index) {
Tensor featureSample = features.select(0, index);
Tensor labelSample = labels.select(0, index);
return new Example(featureSample, labelSample);
}
@Override
public SizeTOptional size() {
return new SizeTOptional(features.size(0));
}
}
```
```java
package torch;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import static org.bytedeco.pytorch.global.torch.*;
public class SyntheticApp {
public static void main(String[] args) {
long numSamples = 100;
long numFeatures = 10;
Tensor features = torch.randn(new long[]{numSamples, numFeatures});
Tensor labels = torch.randint(0, 2, new long[]{numSamples});
System.out.println("Shape of features: " + features.sizes().get(0) + "x" + features.sizes().get(1));
System.out.println("Shape of labels: " + labels.sizes().get(0));
SyntheticDataset rawDataset = new SyntheticDataset(features, labels);
long sampleIdx = 0;
Example sample = rawDataset.get(sampleIdx);
Tensor featureSample = sample.data();
Tensor labelSample = sample.target();
System.out.println("\nSample " + sampleIdx + " - Feature: " + featureSample);
System.out.println("Sample " + sampleIdx + " - Label: " + labelSample);
System.out.println("Dataset length: " + rawDataset.size().get());
}
}
```
此时,`raw_dataset` 包含我们的数据并知道如何提供单个样本。
### 定义数据转换
通常,原始数据不适合直接输入神经网络。我们可能需要规范化特征、转换数据类型或应用数据增强(特别是对于图像)。`torchvision.transforms` 提供了方便的工具。即使我们的数据不是图像,我们也可以定义自定义转换或使用对张量进行操作的现有转换。
让我们定义一个简单转换流程:
1. 将特征张量转换为 `torch.float32`(模型输入的良好做法)。
2. 将标签张量转换为 `torch.long`(损失函数如 `CrossEntropyLoss` 常需要的)。
3. 对特征应用规范化(减去均值,除以标准差)。我们将在此示例中从合成数据集中计算这些统计量。
由于 `torchvision.transforms` 主要为图像(PIL 图像或张量)设计,将其直接应用于像我们 `sample` 这样的字典需要一些封装。我们将为此创建自定义可调用类或 lambda 函数。
```scala 3
val feature_mean = features.mean(dim=0)
val feature_std = features.std(dim=0)
feature_std(feature_std == 0) = 1.0
class ToTensorAndType:
"""将特征转换为 FloatTensor,将标签转换为 LongTensor。"""
def __call__(self, sample):
feature, label = sample['feature'], sample['label']
return {'feature': feature.float(), 'label': label.long()}
class NormalizeFeatures:
"""规范化特征张量。"""
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, sample):
feature, label = sample['feature'], sample['label']
# 应用规范化: (张量 - 均值) / 标准差
normalized_feature = (feature - self.mean) / self.std
return {'feature': normalized_feature, 'label': label}
val data_transforms = transforms.Compose(
ToTensorAndType(),
NormalizeFeatures(mean=feature_mean, std=feature_std)
)
val transformed_dataset = SyntheticDataset(features, labels, transform=data_transforms)
val sample_idx = 0
val transformed_feature, transformed_label = transformed_dataset(sample_idx)
println(f"\n--- 转换后的样本 {sample_idx} ---")
println(f"原始特征:\n{features(sample_idx)}")
println(f"转换后特征:\n{transformed_feature}")
println(f"原始标签: {labels(sample_idx)} (dtype={labels.dtype})")
println(f"转换后标签: {transformed_label} (dtype={transformed_label.dtype})")
println(f"转换后特征均值: {transformed_feature.mean():.4f}")
```
```java
package torch;
import org.bytedeco.pytorch.*;
import java.util.function.Function;
public class FeatureLabelDataset extends JavaDataset {
private final Tensor features;
private final Tensor labels;
private final Function<Example, Example> transform;
public FeatureLabelDataset(Tensor features, Tensor labels, Function<Example, Example> transform) {
if (features.size(0) != labels.size(0)) {
throw new IllegalArgumentException("特征和标签的样本数量必须一致");
}
this.features = features;
this.labels = labels;
this.transform = transform;
}
@Override
public Example get(long index) {
Tensor featureSample = features.select(0, index);
Tensor labelSample = labels.select(0, index);
Example example = new Example(featureSample, labelSample);
return (transform != null) ? transform.apply(example) : example;
}
@Override
public SizeTOptional size() {
return new SizeTOptional(features.size(0));
}
}
```
```java
package torch;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import java.util.function.Function;
import static org.bytedeco.pytorch.global.torch.*;
public class FeatureLabelApp {
public static void main(String[] args) {
long numSamples = 100;
long numFeatures = 10;
Tensor features = torch.randn(new long[]{numSamples, numFeatures});
Tensor labels = torch.randint(0, 2, new long[]{numSamples});
Tensor featureMean = features.mean(new long[]{0}, false, new ScalarTypeOptional(ScalarType.Float));
Tensor featureStd = features.std(new long[]{0}, false);
featureStd.masked_fill_(featureStd.eq(new Scalar(0.0)), new Scalar(1.0));
Function<Example, Example> toType = e ->
new Example(e.data().to(torch.kFloat()), e.target().to(torch.kInt()));
Function<Example, Example> normalize = e -> {
Tensor normData = e.data().sub(featureMean).div(featureStd);
return new Example(normData, e.target());
};
Function<Example, Example> pipeline = toType.andThen(normalize);
FeatureLabelDataset dataset = new FeatureLabelDataset(features, labels, pipeline);
long sampleIdx = 0;
Example result = dataset.get(sampleIdx);
System.out.println("--- 转换后的样本信息 ---");
System.out.println("特征形状: " + result.data().sizes().get(0));
System.out.println("特征数据类型: " + result.data().scalar_type());
System.out.println("标签数据类型: " + result.target().scalar_type());
System.out.printf("转换后特征均值: %.4f%n", result.data().mean().item_float());
}
}
```
注意到由于规范化,特征值已发生改变,并且特征和标签的数据类型现在分别是 `torch.float32` 和 `torch.int64` (LongTensor)。
### 使用 `DataLoader`
最后一步是使用 `DataLoader`。它接收我们的 `Dataset` 实例,并处理批处理、数据混洗以及可能的并行数据加载。
```scala 3
val batch_size = 16
val shuffle_data = true
val num_workers = 0
val data_loader = data.DataLoader(
transformed_dataset,
batch_size=batch_size,
shuffle=shuffle_data,
num_workers=num_workers
)
println(f"\n--- 迭代 DataLoader (batch_size={batch_size}) ---")
val feature_batch, label_batch = next(iter(data_loader))
val feature_batch, label_batch = next(iter(data_loader))
println(f"Type of feature_batch: {type(feature_batch)}")
println(f"Shape of feature_batch: {feature_batch.shape}")
println(f"Shape of label_batch: {label_batch.shape}")
println(f"Data type of feature_batch: {feature_batch.dtype}")
println(f"Data type of label_batch: {label_batch.dtype}")
```
```java
package torch;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import java.util.function.Function;
import static org.bytedeco.pytorch.global.torch.*;
public class FeatureLabelAppV2 {
public static void main(String[] args) {
long numSamples = 100;
long numFeatures = 10;
Tensor features = torch.randn(new long[]{numSamples, numFeatures});
Tensor labels = torch.randint(0, 2, new long[]{numSamples});
Tensor featureMean = features.mean(new long[]{0}, false, new ScalarTypeOptional(ScalarType.Float));
Tensor featureStd = features.std(new long[]{0}, false);
featureStd.masked_fill_(featureStd.eq(new Scalar(0.0)), new Scalar(1.0));
Function<Example, Example> pipeline = e -> {
Tensor castedData = e.data().to(torch.kFloat());
Tensor castedTarget = e.target().to(torch.kLong());
Tensor normData = castedData.sub(featureMean).div(featureStd);
return new Example(normData, castedTarget);
};
FeatureLabelDataset dataset = new FeatureLabelDataset(features, labels, pipeline);
long batchSize = 16;
DataLoaderOptions options = new DataLoaderOptions(batchSize);
int numWorkers = System.getProperty("os.name").toLowerCase().contains("win") ? 0 : 2;
options.workers().put(numWorkers);
RandomSampler sampler = new RandomSampler(numSamples);
JavaRandomDataLoader dataLoader = new JavaRandomDataLoader(dataset, sampler, options);
System.out.println("\n--- 迭代 DataLoader (batch_size=" + batchSize + ") ---");
long batchIdx = 0;
ExampleVectorIterator it = dataLoader.begin();
if (!it.equals(dataLoader.end())) {
ExampleVector batch = it.access();
Tensor featureBatch = stackData(batch);
Tensor labelBatch = stackTarget(batch);
System.out.println("Batch 1 特征形状: " + featureBatch.sizes().get(0) + "x" + featureBatch.sizes().get(1));
System.out.println("Batch 1 标签形状: " + labelBatch.sizes().get(0));
System.out.println("特征数据类型: " + featureBatch.scalar_type());
System.out.println("标签数据类型: " + labelBatch.scalar_type());
it.increment();
batchIdx++;
}
while (!it.equals(dataLoader.end())) {
ExampleVector batch = it.access();
Tensor featureBatch = stackData(batch);
Tensor labelBatch = stackTarget(batch);
if (batchIdx < 3) {
System.out.printf("Batch %d: Features shape=%dx%d%n",
batchIdx + 1, featureBatch.size(0), featureBatch.size(1));
}
it.increment();
batchIdx++;
}
System.out.println("总批次数: " + batchIdx);
}
public static Tensor stackData(ExampleVector batch) {
TensorVector tv = new TensorVector();
for (long i = 0; i < batch.size(); i++) {
tv.push_back(batch.get(i).data());
}
return torch.stack(tv);
}
public static Tensor stackTarget(ExampleVector batch) {
TensorVector tv = new TensorVector();
for (long i = 0; i < batch.size(); i++) {
tv.push_back(batch.get(i).target());
}
return torch.stack(tv);
}
}
```
`DataLoader` 产生批次,其中第一个维度对应于 `batch_size`。我们的特征批次形状为 `[16, 10]`,标签批次形状为 `[16]`。数据类型反映了我们应用的转换。
### 数据处理流程可视化
我们可以可视化刚刚创建的流程:
原始数据(特征, 标签张量)自定义数据集(SyntheticDataset)__init__数据转换(ToTensor, 规范化)__getitem__ 应用DataLoader(批处理, 混洗)输入数据集模型输入(批处理和已处理数据)迭代获取批次

> 此图表展示了从原始张量到自定义 `Dataset` 的演进过程,在数据获取时应用转换(`__getitem__`),最后使用 `DataLoader` 生成适合模型训练的混洗批次。
你现在已成功使用 PyTorch 的核心数据工具构建了数据处理流程。你创建了一个 `Dataset` 来封装你的数据,应用了必要的 `transforms`,并使用 `DataLoader` 高效地生成批次。这种结构化方法是 PyTorch 项目中处理数据的基础,确保你的模型接收正确格式的数据并促进高效训练。这个流程现在已准备好集成到我们将在下一章构建的训练循环中。