【Java深度学习】PyTorch On Java 系列课程 第五章 11 :数据集高级 Dataloader[PyTorch Java课程

0 阅读29分钟

SCR-20260220-ufky.png

# 使用 `torch.utils.data.DataLoader`



尽管 `Dataset` 类提供了一种清晰的方式来抽象对单个数据样本的访问,但对大型数据集逐个样本进行迭代对于训练深度学习模型通常效率不高。训练通常受益于分批处理数据。在这种情况下,`torch.utils.data.DataLoader` 在高效数据处理中发挥着核心作用。

`DataLoader` 封装了一个 `Dataset`(无论是内置的还是你自定义的实现)并提供了对其的迭代接口。它的主要职责是:

1. **分批处理**:将单个样本分组为小批量。
2. **打乱数据**:在每个周期随机打乱数据,以防止模型学习到样本的顺序并提高泛化能力。
3. **并行数据加载**:使用多个子进程并发加载数据,这可能会显著加快数据处理流程。

### 基本用法和迭代

创建 `DataLoader` 很简单。你主要需要提供 `Dataset` 实例并指定所需的 `batch_size`。

```scala 3
import torch
import torch.utils.data.Dataset
import torch.utils.data.DataLoader

// 假设 'YourCustomDataset' 已如前所示定义
// 或者使用内置数据集,例如 datasets.MNIST
// 为了演示,我们创建一个简单的虚拟数据集:
class DummyDataset extends Dataset:
    def __init__(self, num_samples=100):
        val num_samples = num_samples
        val features = torch.randn(num_samples, 10) // 示例:100 个样本,10 个特征
        val labels = torch.randint(0, 2, (num_samples,)) // 示例:100 个二元标签

    def __len__(self):
        return num_samples

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

// 实例化数据集
val dataset = DummyDataset(num_samples=105)

// 实例化 DataLoader
// batch_size: 每个批次的样本数量
// shuffle: 设置为 True 以在每个周期打乱数据(对训练很重要)
val train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=true)

// 迭代 DataLoader
println(s"Dataset size: ${len(dataset)}")
println(s"DataLoader batch size: ${train_loader.batch_size}")

for epoch <- Range(1): // 一个周期的示例
    println(s"\n--- Epoch ${epoch+1} ---")
    for i, batch in enumerate(train_loader):
        // DataLoader 产生批次。每个 'batch' 通常是元组或列表
        // 包含特征和标签的张量。
        val features, labels = batch
        println(s"Batch ${i+1}: Features shape=${features.shape}, Labels shape=${labels.shape}")
        // 在这里你通常会执行训练步骤:
        model.train()
        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
```

```java
package featurestore;

import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.JavaRandomDataLoader;
import org.bytedeco.pytorch.DataLoaderOptions;
import org.bytedeco.pytorch.global.torch;

import java.util.Iterator;

/**
 * 虚拟数据集(DummyDataset),继承JavaDataset
 * 模拟随机特征和二元标签,对应Python的DummyDataset
 */
public class DummyDataset extends JavaDataset {
    private final long numSamples; // 样本总数
    private final Tensor features; // 特征张量 [numSamples, 10]
    private final Tensor labels;   // 标签张量 [numSamples](二元标签 0/1)

    /**
     * 构造方法(对应Python的__init__)
     * @param numSamples 样本数量,默认100
     */
    public DummyDataset(long numSamples) {
        super();
        this.numSamples = numSamples;

        // 生成随机特征:numSamples x 10 的正态分布张量(对应torch.randn)
        this.features = torch.randn(
                new long[]{numSamples, 10},
                new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Float))
        );

        // 生成二元标签:0-1的随机整数(对应torch.randint(0,2, (numSamples,)))
        this.labels = torch.randint(
                0, 2,
                new long[]{numSamples},
                new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Long))
        );
    }

    // 无参构造,默认100个样本
    public DummyDataset() {
        this(100);
    }

    /**
     * 返回样本总数(对应Python的__len__)
     */
    @Override
    public SizeTOptional size() {
        return new SizeTOptional(new SizeTPointer(numSamples));
    }

    /**
     * 按索引获取样本(对应Python的__getitem__)
     * @param index 样本索引
     * @return Example 包含特征和标签张量的样本
     */
    @Override
    public Example get(long index) {
        // 边界检查
        if (index < 0 || index >= numSamples) {
            throw new IndexOutOfBoundsException("索引超出范围: " + index);
        }

        // 提取单个样本的特征和标签(对应features[idx], labels[idx])
        Tensor feature = features.index_select(0, torch.tensor(new LongPointer(index)));
        Tensor label = labels.index_select(0, torch.tensor(new LongPointer(index)));

        return new Example(feature, label);
    }

    // ------------------- 训练迭代示例 -------------------
    public static void main(String[] args) {
        // 1. 实例化虚拟数据集(105个样本,对应Python的num_samples=105)
        DummyDataset dataset = new DummyDataset(105);

        // 2. 配置DataLoader参数(对应Python的DataLoader)
        DataLoaderOptions loaderOptions = new DataLoaderOptions();
        loaderOptions.batch_size().put(32); // 批次大小32
        loaderOptions.enforce_ordering().put(false); // 允许乱序加载(对应shuffle=true) // 打乱数据(训练必备)
        
        RandomSampler sampler = new RandomSampler(dataset.size().get());

        // 3. 实例化DataLoader
        JavaRandomDataLoader trainLoader = new JavaRandomDataLoader(dataset,sampler, loaderOptions);

        // 打印数据集和DataLoader信息
        System.out.println("Dataset size: " + dataset.size().get());
        System.out.println("DataLoader batch size: " + loaderOptions.batch_size().get());

        // 4. 模拟1个训练周期(对应Python的for epoch <- Range(1))
        int numEpochs = 1;
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            System.out.println("\n--- Epoch " + (epoch + 1) + " ---");

            var beginBatch = trainLoader.begin();
            var endBatch = trainLoader.end();
            while(beginBatch != endBatch){
                var batch = beginBatch.access();
                
            
                
                beginBatch.increment();
            }
            // 迭代DataLoader获取批次数据
//            Iterator<IValue[]> batchIterator = trainLoader.iterator();
//            int batchIdx = 0;
//            while (batchIterator.hasNext()) {
//                batchIdx++;
//                IValue[] batch = batchIterator.next();
//
//                // 提取批次的特征和标签张量(对应Python的features, labels = batch)
//                Tensor batchFeatures = batch[0].toTensor();
//                Tensor batchLabels = batch[1].toTensor();
//
//                // 打印批次信息(特征形状、标签形状)
//                System.out.printf(
//                        "Batch %d: Features shape=%s, Labels shape=%s%n",
//                        batchIdx,
//                        tensorShapeToString(batchFeatures.sizes()),
//                        tensorShapeToString(batchLabels.sizes())
//                );

                // ------------------- 模拟训练步骤(核心逻辑) -------------------
                // 注:这里仅演示流程,需替换为你实际的model/optimizer/criterion
                /*
                // 1. 模型设为训练模式
                model.train();

                // 2. 梯度归零
                optimizer.zero_grad();

                // 3. 前向传播
                Tensor outputs = model.forward(IValue.from(batchFeatures)).toTensor();

                // 4. 计算损失
                Tensor loss = criterion.forward(outputs, batchLabels).toTensor();

                // 5. 反向传播
                loss.backward(torch.tensor(new float[]{1.0f}, new long[]{}));

                // 6. 更新权重
                optimizer.step();
                */

                // 释放当前批次张量资源
//                batchFeatures.close();
//                batchLabels.close();
            }
        }

        // 释放数据集张量资源
//        dataset.features.close();
//        dataset.labels.close();
//        trainLoader.close();
    }

    /**
     * 辅助方法:将张量形状(SizeTPointer)转为易读字符串
     */
    private static String tensorShapeToString(SizeTPointer sizes) {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < sizes.limit(); i++) {
            sb.append(sizes.get(i));
            if (i < sizes.limit() - 1) {
                sb.append(", ");
            }
        }
        sb.append("]");
        return sb.toString();
    }
}


```

```java
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

public class DummyDataset extends JavaDataset {
    private final Tensor features;
    private final Tensor labels;
    private final long numSamples;

    public DummyDataset(long numSamples) {
        this.numSamples = numSamples;
        // 相当于 torch.randn(num_samples, 10)
        this.features = torch.randn(new long[]{numSamples, 10});
        // 相当于 torch.randint(0, 2, (num_samples,))
        this.labels = torch.randint(0, 2, new long[]{numSamples});
    }

    // 对应 Python 的 __getitem__
    @Override
    public Example get(long index) {
        // 使用 index() 获取切片
//        Tensor feature = features.select(0, index);
//        Tensor label = labels.select(0, index);
        Tensor feature = features.index(new TensorIndexVector(new TensorIndex(index)));
        Tensor label = labels.index(new TensorIndexVector(new TensorIndex(index)));
        return new Example(feature, label);
    }

    // 对应 Python 的 __len__
    @Override
    public SizeTOptional size() {
        return new SizeTOptional(numSamples);
    }
}

package torch;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import static org.bytedeco.pytorch.global.torch.*;

public class DummyDatasetApp {
    public static void main(String[] args) {
        long numSamples = 105;
        long batchSize = 32;

        // 1. 初始化 (沿用之前的 DummyDataset)
        DummyDataset dataset = new DummyDataset(numSamples);
        DataLoaderOptions options = new DataLoaderOptions(batchSize);
        RandomSampler sampler = new RandomSampler(numSamples);
        JavaRandomDataLoader trainLoader = new JavaRandomDataLoader(dataset, sampler, options);

        // 2. 打印基本信息
        System.out.println("Dataset size: " + dataset.size());
        System.out.println("DataLoader batch size: " + batchSize);

        // 3. 模拟 Epoch 迭代
        for (int epoch = 0; epoch < 1; epoch++) {
            System.out.println("\n--- Epoch " + (epoch + 1) + " ---");

            int step = 0;
            // 获取初始迭代器
            ExampleVectorIterator it = trainLoader.begin();

            // 循环遍历批次
            while (!it.equals(trainLoader.end())) {
                // 获取当前批次 (ExampleVector 包含多个 Example)
                ExampleVector batch = it.access();
                // 聚合数据:将 List<Example> 转换为 Batch Tensor [B, 10]
                Tensor features = stackExampleData(batch);
                Tensor labels = stackExampleTarget(batch);

                System.out.printf("Batch %d: Features shape=%s, Labels shape=%s%n",
                        ++step,
                        features.sizes().get(0) + "x" + features.sizes().get(1),
                        labels.sizes().get(0));

                /* 
                // 训练步骤占位:
                model.train();
                optimizer.zero_grad();
                Tensor outputs = model.forward(features);
                Tensor loss = criterion.forward(outputs, labels);
                loss.backward();
                optimizer.step();
                */

                it.increment(); // 移动到下一批次
            }
        }
    }

    // 将批次内的 Tensor 堆叠在一起 (对应 Python DataLoader 的默认 Collate 功能)
    private static Tensor stackExampleData(ExampleVector batch) {
        TensorVector tensors = new TensorVector();
        for (long i = 0; i < batch.size(); i++) {
            tensors.push_back(batch.get(i).data());
        }
        return torch.stack(tensors);
    }

    private static Tensor stackExampleTarget(ExampleVector batch) {
        TensorVector tensors = new TensorVector();
        for (long i = 0; i < batch.size(); i++) {
            tensors.push_back(batch.get(i).target());
        }
        return torch.stack(tensors);
    }
}



```

```java
import org.bytedeco.pytorch.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
import static org.bytedeco.pytorch.global.torch.*;

public class DummyDataset extends ChunkDataset implements AutoCloseable {
    static { Loader.load(); }

    private Tensor features;
    private Tensor labels;
    private long numSamples;
    private long currentIndex = 0;

    public DummyDataset(Pointer p) { super(p); }

    public DummyDataset(long numSamples) {
        super((Pointer)null);
        this.numSamples = numSamples;
        // 对应 torch.randn(num_samples, 10)
        this.features = randn(new long[]{numSamples, 10}, new TensorOptions().dtype(new ScalarTypeOptional(kFloat())));
        // 对应 torch.randint(0, 2, (num_samples,))
        this.labels = randint(0, 2, new long[]{numSamples}, new TensorOptions().dtype(new ScalarTypeOptional(kLong())));
    }

    @Override
    public @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long batch_size) {
        if (currentIndex >= numSamples) {
            return new ExampleVectorOptional(); // 返回 empty 表示 epoch 结束
        }

        long actualBatchSize = Math.min(batch_size, numSamples - currentIndex);

        // 使用 narrow 获取切片: narrow(维度, 起始, 长度)
        Tensor batchFeatures = features.narrow(0, currentIndex, actualBatchSize);
        Tensor batchLabels = labels.narrow(0, currentIndex, actualBatchSize);

        // 更新索引
        currentIndex += actualBatchSize;

        // 注意:由于 ChunkDataset 预期返回的是 vector<Example>
        // 在这里我们把这一个 batch 的 Tensor 包装成一个 ExampleVector
        ExampleVector vec = new ExampleVector();
        vec.push_back(new Example(batchFeatures, batchLabels));

        return new ExampleVectorOptional(vec);
    }

    @Override
    public void reset() {
        this.currentIndex = 0;
    }

    @Override
    public @ByVal SizeTOptional size() {
        return new SizeTOptional(numSamples);
    }
}


```


```java

    public static void main(String[] args) {
        // 1. 实例化数据集
        long numSamples = 105;
        DummyDataset dataset = new DummyDataset(numSamples);

        // 2. 实例化 DataLoader 配置
        // 注意:ChunkDataset 已经是 Stateful 的,所以使用 StatefulDataLoader
        long batchSize = 32;
        DataLoaderOptions options = new DataLoaderOptions(batchSize);

        // 在 JavaCPP 中,ChunkDataset 对应的加载器通常通过以下方式构造
        // 假设已经映射了对应的 ChunkRandomDataLoader
        ChunkRandomDataLoader trainLoader = new ChunkRandomDataLoader(dataset, options);

        System.out.println("Dataset size: " + dataset.size().get());

        // 3. 迭代 DataLoader (对应 Python 的 for i, batch in enumerate(train_loader))
        for (int epoch = 0; epoch < 1; epoch++) {
            System.out.println("\n--- Epoch " + (epoch + 1) + " ---");

            int batchIdx = 0;
            // 获取迭代器
            ExampleIterator iter = trainLoader.begin();
            ExampleIterator end = trainLoader.end();

            while (!iter.equals(end)) {
                // 获取当前 batch (这里 batch 是一个 Example 对象)
                Example batch = iter.access();

                Tensor features = batch.data();
                Tensor labels = batch.target();

                System.out.printf("Batch %d: Features shape=%s, Labels shape=%s%n",
                        (++batchIdx),
                        java.util.Arrays.toString(features.sizes().get()),
                        java.util.Arrays.toString(labels.sizes().get()));

                // --- 训练步骤示例 ---
                // model.train();
                // optimizer.zero_grad();
                // Tensor outputs = model.forward(features);
                // Tensor loss = criterion.forward(outputs, labels);
                // loss.backward();
                // optimizer.step();

                iter.increment(); // 移动到下一个 batch
            }

            trainLoader.retainReference()
//            trainLoader.reset(); // 每个 epoch 结束手动或自动 reset
        }
    }
```

```java

public class DummyMapDataset extends ChunkMapDataset {
    private long numSamples;
    Tensor features = randn(new long[]{32, 10},  new TensorOptions().dtype(new ScalarTypeOptional(kFloat())));
    Tensor labels = randint(0, 2, new long[]{32}, new TensorOptions().dtype(new ScalarTypeOptional(kLong())));
   
    public DummyMapDataset(long numSamples) {
        super((Pointer) null);
        this.numSamples = numSamples;
        // 注意:在实际 LibTorch 中,ChunkDataset 通常需要初始化底层 C++ 指针
        // 如果 ChunkMapDataset 没有提供默认 allocate,这里可能需要调用 native allocate
    }

    // 实现 ChunkDataset 要求的核心方法:获取特定的数据块
    // 注意:ChunkDataset 并不是简单的 getitem(idx),而是 get_batch(batch_index)
    @Override
    public ExampleOptional get_batch_example(@Cast("size_t") long batch_index) {
        // 1. 创建 Tensor 数据(模拟数据加载)

        // 2. 封装成 Example 对象
        Example example = new Example(features, labels);
        return new ExampleOptional(example);
    }

    @Override
    public @Cast("size_t") long size() {
        return numSamples;
    }
}

```

运行此代码将展示 `DataLoader` 如何产生数据批次。注意每个批次打印的形状反映了 `batch_size`(除了可能的最后一个批次)。

### 批次策略和 `drop_last`

默认情况下,如果 `Dataset` 中的总样本数不能被 `batch_size` 完全整除,最后一个批次将包含剩余样本,因此会更小。

在我们有 105 个样本、批次大小为 32 的示例中:

- 批次 1: 32 个样本
- 批次 2: 32 个样本
- 批次 3: 32 个样本
- 批次 4: 9 个样本 (105 - 3*32 = 9)

有时,拥有可变批次大小,尤其是非常小的最后一个批次,可能会影响某些训练动态或特定层的要求(例如训练期间的 BatchNorm 层,尽管 PyTorch 处理得相当好)。如果你希望所有批次都具有精确的 `batch_size`,并丢弃较小的最后一个批次,你可以在创建 `DataLoader` 时设置 `drop_last=True`:

```scala 3
// 如果数据集大小不能被批次大小整除,则丢弃最后一个不完整的批次
val train_loader_drop_last = DataLoader(dataset=dataset, batch_size=32, shuffle=true, drop_last=true)

println("\n--- DataLoader with drop_last=True ---")
for i <- Range(train_loader_drop_last.size):
    val batch = train_loader_drop_last(i)
    val features, labels = batch
    println(s"Batch  ${i+1}: Features shape=${features.shape}, Labels shape=${labels.shape}")

# 预期输出:只有 3 个大小为 32 的批次。最后的 9 个样本被丢弃。
```

### 打乱数据以获得更好的训练

在训练期间强烈建议设置 `shuffle=True`。它告诉 `DataLoader` 在为每个周期创建批次之前重新打乱数据集的索引。这确保模型每次都以不同的顺序查看数据,减少过度拟合数据呈现顺序的风险并提高模型健壮性。对于验证或测试,通常禁用打乱数据(`shuffle=False`),以确保不同运行之间评估指标的一致性。

### 使用 `num_workers` 并行加载数据

数据加载和预处理(应用变换)有时可能会成为瓶颈,特别是当变换很复杂或数据读取涉及大量 I/O 操作时。CPU 可能会花费大量时间准备下一个批次,而 GPU 则空闲等待数据。

`DataLoader` 允许你通过使用多个工作进程并行加载数据来缓解这个问题。你可以使用 `num_workers` 参数来指定工作进程的数量:

```scala 3
// 使用 4 个工作进程加载数据
// num_workers > 0 启用多进程数据加载
// 一个常见的起始点是 num_workers = 4 * num_gpus,但最优值取决于
// 系统(CPU 核心数、磁盘速度)和批次大小。通常需要通过实验确定。
val fast_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=true, num_workers=4)

// 迭代看起来相同,但数据加载发生在后台进程中
for i <- Range(fast_loader.size):
    val batch = fast_loader(i)
    val features, labels = batch
    println(s"Batch  ${i+1}: Features shape=${features.shape}, Labels shape=${labels.shape}")
```

```java
package torch;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

import static org.bytedeco.pytorch.global.torch.*;

public class FastDataLoaderApp {
    public static void main(String[] args) {
        long numSamples = 105;
        long batchSize = 32;
        long numWorkers = 4; // 对应 num_workers=4

        // 1. 初始化数据集
        DummyDataset dataset = new DummyDataset(numSamples);

        // 2. 配置 DataLoader 选项,设置并行工作线程数
        DataLoaderOptions options = new DataLoaderOptions(batchSize);
        options.workers().put(numWorkers); // 启用多线程加载

        // 3. 实例化 DataLoader (带随机打乱)
        RandomSampler sampler = new RandomSampler(numSamples);
        JavaRandomDataLoader fastLoader = new JavaRandomDataLoader(dataset, sampler, options);
        System.out.println("Dataset size: " + dataset.size());
        System.out.println("DataLoader batch size: " + batchSize);
        System.out.println("Number of workers: " + numWorkers);
        // 4. 迭代数据
        int step = 0;
        ExampleVectorIterator it = fastLoader.begin();

        while (!it.equals(fastLoader.end())) {
            // 数据在后台线程被预取到 ExampleVector 中
            ExampleVector batch = it.access(); // 获取当前批次数据

            // 聚合批次数据
            Tensor features = stackExampleData(batch);
            Tensor labels = stackExampleTarget(batch);

            System.out.printf("Batch %d: Features shape=%s, Labels shape=%s%n",
                    ++step,
                    features.sizes().get(0) + "x" + features.sizes().get(1),
                    labels.sizes().get(0));

            it.increment();
        }
    }

    private static Tensor stackExampleData(ExampleVector batch) {
        TensorVector tensors = new TensorVector();
        for (long i = 0; i < batch.size(); i++) {
            tensors.push_back(batch.get(i).data());
        }
        return torch.stack(tensors);
    }

    private static Tensor stackExampleTarget(ExampleVector batch) {
        TensorVector tensors = new TensorVector();
        for (long i = 0; i < batch.size(); i++) {
            tensors.push_back(batch.get(i).target());
        }
        return torch.stack(tensors);
    }
}


```

当 `num_workers > 0` 时,`DataLoader` 会生成指定数量的 Python 进程。每个工作进程独立加载一个批次。这使得后续批次的数据加载和变换可以并行发生,而主进程则对当前批次执行模型训练步骤,通常通过更有效地利用 GPU 来显著提高速度。

请注意,增加 `num_workers` 也会增加 CPU 使用率和内存消耗,因为每个工作进程都会加载数据。将其设置过高有时会导致资源争用和收益递减,甚至减慢速度。它通常是根据你的具体硬件和数据集进行调整的参数。

并行加载 (如果 num_workers > 0)Dataset 对象(YourCustomDataset)DataLoader(批次大小, 打乱数据,工作进程数)封装工作进程 1分配索引工作进程 N分配索引批次(特征, 标签)获取并准备获取并准备训练循环(GPU/CPU)提供请求下一个

> `DataLoader` 的数据加载流程。`DataLoader` 封装了 `Dataset`,并且在 `num_workers > 0` 的情况下,使用工作进程获取并整理样本成批次,这些批次随后被训练循环使用。

### 使用 `pin_memory` 优化 GPU 传输

在 GPU 上训练时,由 `DataLoader` 加载的数据(位于标准 CPU 内存中)需要传输到 GPU 的内存中。这种传输需要时间。你通常可以通过在 `DataLoader` 中设置 `pin_memory=True` 来稍微加快速度。

```scala 3
// 启用固定内存以加快 CPU 到 GPU 的传输
val gpu_optimized_loader = DataLoader(dataset=dataset,
                                  batch_size=32,
                                  shuffle=true,
                                  num_workers=4,
                                  pin_memory=true)

// 在训练循环内部(假设你有 GPU)
for i <- Range(gpu_optimized_loader.size):
    val batch = gpu_optimized_loader(i)
    val features, labels = batch
    features = features.to('cuda') // 传输变得更快
    labels = labels.to('cuda')
    // ...其余训练步骤...
```

```java
package featurestore;

package featurestore;

import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

/**
 * 核心逻辑:GPU优化的DataLoader + CUDA张量迁移
 * 完全对齐Python代码的pin_memory/num_workers/shuffle配置 + GPU传输逻辑
 */
public class GPUOptimizedDataLoaderExample {
    public static void main(String[] args) {
        // 1. 初始化虚拟数据集(复用你之前的DummyDataset)
        DummyDataset dataset = new DummyDataset(105);

        // 2. 配置GPU优化的DataLoader参数(对应Python的gpu_optimized_loader)
        DataLoaderOptions loaderOptions = new DataLoaderOptions();
        loaderOptions.batch_size().put(32);          // batch_size=32
        loaderOptions.enforce_ordering().put(false); // shuffle=true(打乱数据)
        loaderOptions.workers().put(4);          // num_workers=4(多线程加载)
//        loaderOptions.pin_memory().put(true);        // pin_memory=true(固定内存加速CPU→GPU)
        torch.pinned_memory_or_default(new BoolOptional(true))
        BoolOptional updatedDefault = torch.pinned_memory_or_default();
        // 随机采样器(配合shuffle=true)
        RandomSampler sampler = new RandomSampler(dataset.size().get());

        // 3. 实例化GPU优化的DataLoader(对应gpu_optimized_loader)
        JavaRandomDataLoader gpuOptimizedLoader = new JavaRandomDataLoader(dataset, sampler, loaderOptions);

        // 4. 判断GPU是否可用(对应Python的"假设你有GPU")
        boolean hasCUDA = torch.cuda_is_available();
        Device cudaDevice = hasCUDA ? new Device(torch.DeviceType.CUDA) : null;
        System.out.println("GPU可用状态: " + hasCUDA + " | 固定内存已启用: " + tensorOptions.has_pin_memory().get());

        // 5. 训练循环:迭代DataLoader并迁移张量到CUDA(对应Python的for循环)
        var batchIter = gpuOptimizedLoader.begin();
        var iterEnd = gpuOptimizedLoader.end();
        int batchIdx = 0;

        // 对应Python的 for i <- Range(gpu_optimized_loader.size)
        while (!batchIter.equals(iterEnd)) {
            batchIdx++;
            // 对应Python的 val batch = gpu_optimized_loader(i)
            var batch = batchIter.access();

            // 对应Python的 val features, labels = batch
            Tensor features = batch.data();
            Tensor labels = batch.target();

            // 对应Python的 features.to('cuda') / labels.to('cuda')
            if (hasCUDA) {
                features = features.to(cudaDevice, torch.ScalarType.Float); // 张量迁移到GPU,传输更快
                labels = labels.to(cudaDevice, torch.ScalarType.Long);
                System.out.printf("批次 %d: 特征/标签已迁移到CUDA | 特征形状: %s%n",
                        batchIdx, tensorShape(features.sizes().vec().get()));
            }

            // ------------------- 其余训练步骤(与你之前一致) -------------------
            /*
            model.train();
            optimizer.zero_grad();
            Tensor outputs = model.forward(IValue.from(features)).toTensor();
            Tensor loss = criterion.forward(outputs, labels).toTensor();
            loss.backward(torch.tensor(new float[]{1.0f}, new long[]{}));
            optimizer.step();
            */

            // 释放资源
            features.close();
            labels.close();
            batch.close();
            batchIter.increment();
        }

        // 释放DataLoader相关资源
        gpuOptimizedLoader.close();
        sampler.close();
        dataset.features.close();
        dataset.labels.close();
    }

    // 辅助方法:打印张量形状
    private static String tensorShape(SizeTPointer sizes) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < sizes.limit(); i++) {
            sb.append(sizes.get(i)).append(i < sizes.limit()-1 ? "x" : "");
        }
        return sb.toString();
    }

    // 复用你之前定义的DummyDataset(继承JavaDataset)
    static class DummyDataset extends JavaDataset {
        private final long numSamples;
        private final Tensor features;
        private final Tensor labels;

        public DummyDataset(long numSamples) {
            super();
            this.numSamples = numSamples;
            this.features = torch.randn(new long[]{numSamples, 10},
                    new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Float)));
            this.labels = torch.randint(0, 2, new long[]{numSamples},
                    new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Long)));
        }

        @Override
        public SizeTOptional size() {
            return new SizeTOptional(new SizeTPointer(numSamples));
        }

        @Override
        public Example get(long index) {
            if (index < 0 || index >= numSamples) throw new IndexOutOfBoundsException();
            Tensor feature = features.index_select(0, torch.tensor(new LongPointer(index)));
            Tensor label = labels.index_select(0, torch.tensor(new LongPointer(index)));
            return new Example(feature, label);
        }
    }
}


```

设置 `pin_memory=True` 指示 `DataLoader` 在 CPU 端将张量分配到“固定”(页面锁定)内存中。从固定 CPU 内存到 GPU 内存的传输通常比从标准可分页 CPU 内存的传输快。这在与 `num_workers > 0` 一起使用时最有效。请注意,使用固定内存会消耗更多的 CPU 内存。

总之,`DataLoader` 是 PyTorch 中一个核心的工具,它简化并优化了向模型提供数据的过程。通过处理批次化、打乱数据和并行加载,它使你能够专注于模型架构和训练逻辑,同时确保你的数据处理流程高效且可扩展。



# 自定义DataLoader用法



“尽管默认的`DataLoader`提供了方便的批处理和随机排列功能,但许多应用需要更精细地控制数据如何抽样和整理成批次。PyTorch通过自定义采样器和`collate`函数提供灵活性,让你能根据具体需求调整数据加载过程,例如处理不平衡数据集或使用可变大小的输入。”

### 使用采样器控制样本选择

`DataLoader`使用`sampler`对象来决定从`Dataset`中抽取索引的顺序。默认情况下,如果`shuffle=True`,它使用`RandomSampler`;如果`shuffle=False`,则使用`SequentialSampler`。但是,你可以通过`sampler`参数显式传入自己的采样器实例(请注意:如果你提供了`sampler`,则必须将`shuffle`设为`False`,因为随机排列是由采样器本身定义的)。

PyTorch在`torch.utils.data`中提供了几种内置采样器:

- `SequentialSampler`:按顺序采样元素,总是以相同的顺序。
- `RandomSampler`:随机采样元素。如果`replacement=True`,则进行有放回采样。
- `SubsetRandomSampler`:从给定索引列表中随机采样元素。它适用于创建验证集划分,而无需修改原始数据集。
- `WeightedRandomSampler`:根据给定概率(权重)从`[0,..,len(weights)-1]`中采样元素。这对于处理不平衡数据集特别有用,例如你想对少数类进行过采样或对多数类进行欠采样。

**示例:对不平衡数据使用`WeightedRandomSampler`**

设想一个分类数据集,其中类别“0”有900个样本,类别“1”有100个样本。简单的随机采样会导致批次严重偏向类别“0”。我们可以使用`WeightedRandomSampler`来提高类别“1”样本被选中的概率。

```scala 3
import torch
import torch.utils.data.{Dataset, DataLoader, WeightedRandomSampler}

// 假设“dataset”是你的torch.utils.data.Dataset实例
// 假设“targets”是一个列表或张量,包含每个样本的类别标签
// e.g., targets = [0, 0, 1, 0, ..., 1, 0]

// 为每个样本计算权重
val class_counts = torch.bincount(torch.tensor(targets)) // 各类别的计数:例如 [900, 100]
val num_samples = targets.size // 总样本数:1000

// 每个样本的权重是 1 / (其所属类别的样本数)
val sample_weights = torch.tensor([1.0 / class_counts(t) for t in targets])

// 创建采样器
val sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples, replacement=True)

// 使用自定义采样器创建DataLoader
// 注意:使用采样器时,shuffle必须为False
val dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

// 现在,从这个dataloader中抽取的批次将随着时间推移,在类别表示上更加平衡。
// for batch_features, batch_labels in dataloader:
//     // 训练步骤...
//     pass
```

你也可以创建完全自定义的采样策略,通过继承`torch.utils.data.Sampler`并实现`__iter__`和`__len__`方法。

### 使用`collate_fn`自定义批次创建

一旦采样器为一个批次提供了索引列表,`DataLoader`会使用`dataset[index]`从`Dataset`中获取对应的样本。然后,它需要将这些单独的样本组装成一个批次。这个组装过程由`collate_fn`参数处理。

默认的`collate_fn`在许多标准情况下都能很好地工作。它会尝试:

- 将NumPy数组和Python数字转换为PyTorch张量。
- 保留数据结构(例如,如果你的`Dataset.__getitem__`返回一个字典,则整理后的批次将是一个字典,其中每个值是对应项目的批次)。
- 沿新维度(批次维度)堆叠张量。

但是,如果你的样本具有不同的大小(例如,不同长度的序列)或包含它不知道如何堆叠的数据类型,默认的`collate_fn`可能会失败或产生不理想的结果。

在这种情况下,你可以为`DataLoader`的`collate_fn`参数提供一个自定义函数。这个函数接收一个样本列表(其中每个样本是`Dataset.__getitem__`的输出),并负责以所需格式返回整理后的批次。

**示例:填充可变长度序列**

一个常见情况是涉及长度不同的序列(例如NLP中的句子)。默认的`collate`函数不能直接将它们堆叠成一个张量。自定义的`collate_fn`可以将每个批次中的序列填充到该批次中的最大长度。

```scala 3
import torch
import torch.utils.data.{Dataset, DataLoader}
import torch.nn.utils.rnn.{pad_sequence}

// 返回可变长度张量的示例Dataset
class VariableSequenceDataset extends Dataset[(torch.Tensor, Int)]:
    def __init__(self, data):
        // data是一个张量列表,例如 [torch.randn(5), torch.randn(8), ...]
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        // 为简单起见,假设每个项目也有一个标签(例如其长度)
        val sequence = self.data(idx)
        val label = len(sequence)
        return sequence, label

// 自定义collate函数
def pad_collate(batch):
    // batch是一个元组列表:[(序列1, 标签1), (序列2, 标签2), ...]
    // 按序列长度对批次元素进行排序(可选,但通常为了RNN效率而进行)
    // batch.sort(key=lambda x: len(x[0]), reverse=True) // 对于填充不是严格必需的

    // 分离序列和标签
    val sequences = batch.map(_._1)
    val labels = batch.map(_._2)

    // 将序列填充到批次中最长序列的长度
    // `batch_first=True` 使输出形状变为 (batch_size, 最大序列长度, 特征)
    val padded_sequences = pad_sequence(sequences, batch_first=true, padding_value=0.0)

    // 堆叠标签(假设它们是简单的标量)
    val labels = torch.tensor(labels)

    return padded_sequences, labels

// 创建数据集和dataloader
val sequences = [torch.randn(torch.randint(5, 15, (1,)).item()) for _ in range(100)]
val dataset = VariableSequenceDataset(sequences)

// 创建dataloader
val dataloader = DataLoader(dataset, batch_size=4, collate_fn=pad_collate)

// 遍历dataloader
// for padded_batch, label_batch in dataloader:
//     // padded_batch 形状:如果序列是一维的,则为 (4, 该批次中的最大长度, 1)
//     // label_batch 形状:(4,)
//     // 模型处理...
//     pass
```

```java
package torch;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

public class VariableSequenceDataset extends JavaDataset {
    private final TensorVector data;

    public VariableSequenceDataset(TensorVector data) {
        this.data = data;
    }

    @Override
    public Example get(long index) {
        Tensor sequence = data.get(index);
        // 将长度作为标签 (Scalar Tensor)
        Tensor label = torch.tensor(sequence.size(0));
        return new Example(sequence, label);
    }

    @Override
    public SizeTOptional size() {
        return new SizeTOptional(data.size());
    }
}


```

```java
package torch;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import java.util.Random;

import static org.bytedeco.pytorch.global.torch.*;

public class VariableSequenceApp {
    public static void main(String[] args) {
        int totalSamples = 100;
        int batchSize = 4;
        Random random = new Random();

        // 1. 生成模拟数据:不同长度的一维张量列表
        TensorVector sequences = new TensorVector();
        for (int i = 0; i < totalSamples; i++) {
            long length = random.nextInt(10) + 5; // 长度在 5 到 15 之间
            sequences.push_back(torch.randn(new long[]{length}));
        }

        // 2. 初始化数据集和加载器
        VariableSequenceDataset dataset = new VariableSequenceDataset(sequences);
        DataLoaderOptions options = new DataLoaderOptions(batchSize);
        JavaRandomDataLoader dataLoader = new JavaRandomDataLoader(dataset, new RandomSampler(totalSamples), options);

        // 3. 迭代并执行手动 Collate (Padding)
        ExampleVectorIterator it = dataLoader.begin();
        int step = 0;

        while (!it.equals(dataLoader.end())) {
            ExampleVector batch = it.access(); // 获取当前批次数据

            // 分离序列和标签
            TensorVector seqVec = new TensorVector();
            TensorVector labelVec = new TensorVector();
            for (long i = 0; i < batch.size(); i++) {
                seqVec.push_back(batch.get(i).data());
                labelVec.push_back(batch.get(i).target());
            }

            // 执行填充逻辑:对应 Python 的 pad_sequence
            // batch_first=true, padding_value=0.0
            Tensor paddedSequences = torch.pad_sequence(seqVec, true, 0.0, new BytePointer ("right"));
//            Tensor paddedSequences = torch.pad_sequence(seqVec);
            Tensor labels = torch.stack(labelVec);

            System.out.printf("Batch %d: Padded Shape=%s, Labels Shape=%s%n",
                    ++step,
                    paddedSequences.sizes().get(0) + "x" + paddedSequences.sizes().get(1),
                    labels.sizes().get(0));

            it.increment();
        }
    }
}


```

这个自定义`collate_fn`使用`torch.nn.utils.rnn.pad_sequence`来处理填充,确保批次中的所有序列长度相同,使它们适合RNN等模型处理。

### 其他DataLoader自定义参数

除了`sampler`和`collate_fn`,其他参数也提供性能和行为调整:

- `num_workers` (整数,可选):指定用于数据加载的子进程数量。将其设置为正整数可启用多进程数据加载,这可以显著加快数据获取速度,尤其当数据加载涉及磁盘I/O或CPU上的复杂预处理时。一个常见的起始设置是将其设为可用CPU核心的数量。默认值:`0`(数据加载在主进程中进行)。
- `pin_memory` (布尔值,可选):如果为`True`,`DataLoader`在返回张量之前会将其复制到CUDA固定内存中。固定内存可以加快从CPU到GPU的数据传输。这仅在你使用GPU进行训练时才有效。默认值:`False`。
- `drop_last` (布尔值,可选):如果为`True`,当数据集大小不能被批次大小整除时,将丢弃最后一个不完整的批次。如果为`False`(默认值),则最后一个批次可能小于`batch_size`。

通过理解和使用采样器、自定义`collate`函数以及其他`DataLoader`参数,你可以对数据管道获得精确的控制,从而能高效处理各种数据类型和结构,解决数据集不平衡问题,并优化数据加载性能以加快模型训练。



# 动手实践:构建数据处理流程



构建一个完整的数据处理流程涉及从原始数据(为简化起见,将合成数据)开始,最终得到准备好输入模型的数据批次。这个过程需要创建一个自定义 `Dataset`,定义数据转换,并将所有内容封装在一个 `DataLoader` 中。

### 构建合成数据集

假设我们有一个数据集,包含特征向量和对应的二元分类标签(01)。在本次练习中,我们将使用 PyTorch 张量直接生成这些数据。这避免了文件输入/输出的复杂性,让我们能够专注于数据处理机制。

```scala 3
import torch
import torch.utils.data as data
from torchvision import transforms

// 生成合成数据
val num_samples = 100
val num_features = 10

// 创建随机特征向量(例如,传感器读数)
val features = torch.randn(num_samples, num_features)

// 创建随机二元标签(0 或 1)
val labels = torch.randint(0, 2, (num_samples,))

// 打印数据形状和前5个样本
println(s"Shape of features: ${features.shape}") // 输出: torch.Size([100, 10])
println(s"Shape of labels: ${labels.shape}")   // 输出: torch.Size([100])
println(s"First 5 features:\n${features[:5]}")
println(s"First 5 labels:\n${labels[:5]}")
```

这使我们得到了两个张量:`features` 包含 100 个样本,每个样本有 10 个特征;`labels` 包含对应的 100 个标签。

### 创建自定义 `Dataset`

现在,我们需要使用 PyTorch 的 `Dataset` 类来组织这些数据。我们将创建一个自定义类,它继承自 `torch.utils.data.Dataset` 并实现两个重要方法:

1. `__len__(self)`: 返回数据集中样本的总数。
2. `__getitem__(self, idx)`: 返回指定索引 `idx` 处的样本(特征和标签)。

我们还将添加一个 `__init__` 方法来存储数据并可选地接受转换。

```scala 3
class SyntheticDataset extends data.Dataset:
    """一个用于合成特征和标签的自定义数据集。"""

    def __init__(features, labels, transform=None):
        """
        参数:
            features (Tensor): 包含特征数据的张量。
            labels (Tensor): 包含标签的张量。
            transform (callable, optional): 可选的样本转换。
        """
        // 基本检查,确保特征和标签具有相同的样本数量
        assert features.shape[0] == labels.shape[0], \
            "特征和标签的样本数量必须一致"

        // 存储特征和标签
        val features = features
        val labels = labels
        // 存储转换
        val transform = transform

    def __len__(self):
        """返回样本总数。"""
        return features.shape[0]

    def __getitem__(idx):
        """
        根据给定索引获取特征向量和标签。

        参数:
            idx (int): 要获取的样本索引。

        返回:
            tuple: (特征, 标签),其中 feature 是特征向量,label 是对应的标签。
        """
        // 获取原始特征和标签
        val feature_sample = features(idx)
        val label_sample = labels(idx)

        // 创建一个样本字典(或元组)
        val sample = Map('feature -> feature_sample, 'label -> label_sample)

        // 如果存在转换,则应用转换
        if transform:
            sample = transform(sample)

        // 返回可能已转换的样本
        // 常见做法是分别返回特征和标签
        return sample('feature), sample('label)

// 暂时实例化不带转换的数据集
val raw_dataset = SyntheticDataset(features, labels)

// 测试获取一个样本
val sample_idx = 0
val feature_sample, label_sample = raw_dataset(sample_idx)
println(f"\nSample {sample_idx} - Feature: {feature_sample}")
println(f"Sample {sample_idx} - Label: {label_sample}")
println(f"Dataset length: {raw_dataset.size()}") // 输出: 100
```

```java
package torch;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

public class SyntheticDataset extends JavaDataset {
    private final Tensor features;
    private final Tensor labels;

    public SyntheticDataset(Tensor features, Tensor labels) {
        // 基本检查:确保特征和标签的样本数量一致
        if (features.size(0) != labels.size(0)) {
            throw new IllegalArgumentException("特征和标签的样本数量必须一致");
        }
        this.features = features;
        this.labels = labels;
    }

    @Override
    public Example get(long index) {
        // 获取指定索引的特征张量和标签张量
        // select(dim, index) 用于获取特定维度的切片
        Tensor featureSample = features.select(0, index);
        Tensor labelSample = labels.select(0, index);

        // 返回 Example 对象,包含特征数据和目标标签
        return new Example(featureSample, labelSample);
    }

    @Override
    public SizeTOptional size() {
        // 返回数据集的总样本数
        return new SizeTOptional(features.size(0));
    }
}


```

```java
package torch;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import static org.bytedeco.pytorch.global.torch.*;

public class SyntheticApp {
    public static void main(String[] args) {
        long numSamples = 100;
        long numFeatures = 10;

        // 1. 创建随机特征向量 (100, 10)
        Tensor features = torch.randn(new long[]{numSamples, numFeatures});

        // 2. 创建随机二元标签 (100,),范围 [0, 2)
        Tensor labels = torch.randint(0, 2, new long[]{numSamples});

        System.out.println("Shape of features: " + features.sizes().get(0) + "x" + features.sizes().get(1));
        System.out.println("Shape of labels: " + labels.sizes().get(0));

        // 3. 实例化自定义数据集
        SyntheticDataset rawDataset = new SyntheticDataset(features, labels);

        // 4. 测试获取一个样本
        long sampleIdx = 0;
        Example sample = rawDataset.get(sampleIdx);

        Tensor featureSample = sample.data();
        Tensor labelSample = sample.target();

        System.out.println("\nSample " + sampleIdx + " - Feature: " + featureSample);
        System.out.println("Sample " + sampleIdx + " - Label: " + labelSample);
        System.out.println("Dataset length: " + rawDataset.size().get());
    }
}


```
此时,`raw_dataset` 包含我们的数据并知道如何提供单个样本。

### 定义数据转换

通常,原始数据不适合直接输入神经网络。我们可能需要规范化特征、转换数据类型或应用数据增强(特别是对于图像)。`torchvision.transforms` 提供了方便的工具。即使我们的数据不是图像,我们也可以定义自定义转换或使用对张量进行操作的现有转换。

让我们定义一个简单转换流程:

1. 将特征张量转换为 `torch.float32`(模型输入的良好做法)。
2. 将标签张量转换为 `torch.long`(损失函数如 `CrossEntropyLoss` 常需要的)。
3. 对特征应用规范化(减去均值,除以标准差)。我们将在此示例中从合成数据集中计算这些统计量。

由于 `torchvision.transforms` 主要为图像(PIL 图像或张量)设计,将其直接应用于像我们 `sample` 这样的字典需要一些封装。我们将为此创建自定义可调用类或 lambda 函数。

```scala 3
// 计算用于规范化的均值和标准差(跨数据集)
val feature_mean = features.mean(dim=0)
val feature_std = features.std(dim=0)

// 如果任何特征的标准差为零,则避免除以零
feature_std(feature_std == 0) = 1.0

// 为字典样本格式定义自定义转换类/函数
class ToTensorAndType:
    """将特征转换为 FloatTensor,将标签转换为 LongTensor。"""
    def __call__(self, sample):
        feature, label = sample['feature'], sample['label']
        return {'feature': feature.float(), 'label': label.long()}

class NormalizeFeatures:
    """规范化特征张量。"""
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        feature, label = sample['feature'], sample['label']
        # 应用规范化: (张量 - 均值) / 标准差
        normalized_feature = (feature - self.mean) / self.std
        return {'feature': normalized_feature, 'label': label}

// 组合转换
val data_transforms = transforms.Compose(
    ToTensorAndType(),
    NormalizeFeatures(mean=feature_mean, std=feature_std)
)

// 实例化带转换的数据集
val transformed_dataset = SyntheticDataset(features, labels, transform=data_transforms)

// 测试获取一个转换后的样本
val sample_idx = 0
val transformed_feature, transformed_label = transformed_dataset(sample_idx)

println(f"\n--- 转换后的样本 {sample_idx} ---")
println(f"原始特征:\n{features(sample_idx)}")
println(f"转换后特征:\n{transformed_feature}")
println(f"原始标签: {labels(sample_idx)} (dtype={labels.dtype})")
println(f"转换后标签: {transformed_label} (dtype={transformed_label.dtype})")

// 验证规范化(对于第一个样本的特征,均值应接近 0,标准差应接近 1)
println(f"转换后特征均值: {transformed_feature.mean():.4f}") // 应用数据集范围的规范化后应接近 0
```


```java
package torch;

import org.bytedeco.pytorch.*;
import java.util.function.Function;

public class FeatureLabelDataset extends JavaDataset {
    private final Tensor features;
    private final Tensor labels;
    private final Function<Example, Example> transform;

    public FeatureLabelDataset(Tensor features, Tensor labels, Function<Example, Example> transform) {
        if (features.size(0) != labels.size(0)) {
            throw new IllegalArgumentException("特征和标签的样本数量必须一致");
        }
        this.features = features;
        this.labels = labels;
        this.transform = transform;
    }

    @Override
    public Example get(long index) {
        // 使用 select(0, index) 获取第 index 个样本
        Tensor featureSample = features.select(0, index);
        Tensor labelSample = labels.select(0, index);
        Example example = new Example(featureSample, labelSample);

        return (transform != null) ? transform.apply(example) : example;
    }

    @Override
    public SizeTOptional size() {
        return new SizeTOptional(features.size(0));
    }
}


```


```java
package torch;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import java.util.function.Function;
import static org.bytedeco.pytorch.global.torch.*;

public class FeatureLabelApp {
    public static void main(String[] args) {
        long numSamples = 100;
        long numFeatures = 10;

        // 1. 生成原始数据
        Tensor features = torch.randn(new long[]{numSamples, numFeatures});
        Tensor labels = torch.randint(0, 2, new long[]{numSamples});

        // 2. 计算均值和标准差用于规范化
        Tensor featureMean = features.mean(new long[]{0}, false, new ScalarTypeOptional(ScalarType.Float));
        Tensor featureStd = features.std(new long[]{0}, false);

        // 处理标准差为 0 的情况
        featureStd.masked_fill_(featureStd.eq(new Scalar(0.0)), new Scalar(1.0));

        // 3. 定义转换流水线 (Transforms Pipeline)
        // 转换类型为 Float 和 Long
        Function<Example, Example> toType = e ->
                new Example(e.data().to(torch.kFloat()), e.target().to(torch.kInt()));

        // 规范化特征: (x - mean) / std
        Function<Example, Example> normalize = e -> {
            Tensor normData = e.data().sub(featureMean).div(featureStd);
            return new Example(normData, e.target());
        };

        // 组合转换
        Function<Example, Example> pipeline = toType.andThen(normalize);

        // 4. 实例化重命名后的数据集
        FeatureLabelDataset dataset = new FeatureLabelDataset(features, labels, pipeline);

        // 5. 验证转换结果
        long sampleIdx = 0;
        Example result = dataset.get(sampleIdx);

        System.out.println("--- 转换后的样本信息 ---");
        System.out.println("特征形状: " + result.data().sizes().get(0));
        System.out.println("特征数据类型: " + result.data().scalar_type());
        System.out.println("标签数据类型: " + result.target().scalar_type());
        System.out.printf("转换后特征均值: %.4f%n", result.data().mean().item_float());
    }
}



```

注意到由于规范化,特征值已发生改变,并且特征和标签的数据类型现在分别是 `torch.float32` 和 `torch.int64` (LongTensor)。

### 使用 `DataLoader`

最后一步是使用 `DataLoader`。它接收我们的 `Dataset` 实例,并处理批处理、数据混洗以及可能的并行数据加载。

```scala 3
// 创建 DataLoader
val batch_size = 16 // 以 16 个样本的批次处理数据
val shuffle_data = true // 在每个 epoch 开始时打乱数据
val num_workers = 0 // 用于数据加载的子进程数量。0 表示数据加载在主进程中进行。

// 在非 Windows 平台上,通常可以将 num_workers 设置为 > 0 进行并行加载
// import os
// if os.name != 'nt': // 检查是否非 Windows
//     num_workers = 2

// 创建 DataLoader
val data_loader = data.DataLoader(
    transformed_dataset,
    batch_size=batch_size,
    shuffle=shuffle_data,
    num_workers=num_workers
)
// 迭代 DataLoader 以获取批次数据
println(f"\n--- 迭代 DataLoader (batch_size={batch_size}) ---")

// 获取一个批次
val feature_batch, label_batch = next(iter(data_loader))

// 获取一个批次
val feature_batch, label_batch = next(iter(data_loader))

// 打印批次信息
println(f"Type of feature_batch: {type(feature_batch)}")
println(f"Shape of feature_batch: {feature_batch.shape}") // 输出: torch.Size([16, 10])
println(f"Shape of label_batch: {label_batch.shape}")   // 输出: torch.Size([16])
println(f"Data type of feature_batch: {feature_batch.dtype}") // 输出: torch.float32
println(f"Data type of label_batch: {label_batch.dtype}")   // 输出: torch.int64

// 你可以像这样遍历所有批次(例如,在一个训练 epoch 中)
// println("\n遍历几个批次:")
// for i, (batch_features, batch_labels) in enumerate(data_loader):
//     if i >= 3: // 显示前 3 个批次
//         break
//     println(f"Batch {i+1}: Features shape={batch_features.shape}, Labels shape={batch_labels.shape}")
//     // 在真实的训练循环中,你会将 batch_features 输入到模型
//     // 并将 batch_labels 作为目标进行损失计算和反向传播
```

```java
package torch;

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import java.util.function.Function;
import static org.bytedeco.pytorch.global.torch.*;

public class FeatureLabelAppV2 {
    public static void main(String[] args) {
        // --- 前文代码回顾(数据集准备) ---
        long numSamples = 100;
        long numFeatures = 10;
        Tensor features = torch.randn(new long[]{numSamples, numFeatures});
        Tensor labels = torch.randint(0, 2, new long[]{numSamples});
        // 2. 计算均值和标准差用于规范化
        Tensor featureMean = features.mean(new long[]{0}, false, new ScalarTypeOptional(ScalarType.Float));
        Tensor featureStd = features.std(new long[]{0}, false);

        // 处理标准差为 0 的情况
        featureStd.masked_fill_(featureStd.eq(new Scalar(0.0)), new Scalar(1.0));

        Function<Example, Example> pipeline = e -> {
            Tensor castedData = e.data().to(torch.kFloat());
            Tensor castedTarget = e.target().to(torch.kLong());
            Tensor normData = castedData.sub(featureMean).div(featureStd);
            return new Example(normData, castedTarget);
        };

        FeatureLabelDataset dataset = new FeatureLabelDataset(features, labels, pipeline);

        // --- 新增:DataLoader 配置 ---
        long batchSize = 16;
        // 注意:JavaDataset 在底层通常与 SequentialSampler 或 RandomSampler 配合使用
        // 这里通过 options 设置批处理大小
        DataLoaderOptions options = new DataLoaderOptions(batchSize);

        // 检测 OS 设置线程数 (macOS 可设置为 > 0)
        int numWorkers = System.getProperty("os.name").toLowerCase().contains("win") ? 0 : 2;
        options.workers().put(numWorkers);

        // 创建 DataLoader
        // 注意:shuffle=true 通常需要通过传递不同类型的 Sampler 实现,这里演示默认加载
//        DataLoader dataLoader = torch.make_data_loader(dataset, options);
        RandomSampler sampler = new RandomSampler(numSamples);
        JavaRandomDataLoader dataLoader = new JavaRandomDataLoader(dataset, sampler, options);
        System.out.println("\n--- 迭代 DataLoader (batch_size=" + batchSize + ") ---");

        // 获取迭代器
        long batchIdx = 0;
        ExampleVectorIterator it = dataLoader.begin();

        if (!it.equals(dataLoader.end())) {
            // 获取第一个批次 (next 操作)
            ExampleVector batch = it.access();
//            Tensor featureBatch = batch.data();
//            Tensor labelBatch = batch.target();

            Tensor featureBatch = stackData(batch);
            Tensor labelBatch = stackTarget(batch);
            System.out.println("Batch 1 特征形状: " + featureBatch.sizes().get(0) + "x" + featureBatch.sizes().get(1));
            System.out.println("Batch 1 标签形状: " + labelBatch.sizes().get(0));
            System.out.println("特征数据类型: " + featureBatch.scalar_type());
            System.out.println("标签数据类型: " + labelBatch.scalar_type());

            it.increment();
            batchIdx++;
        }

        // 循环迭代剩余批次
        while (!it.equals(dataLoader.end())) {
            ExampleVector batch = it.access();
            Tensor featureBatch = stackData(batch);
            Tensor labelBatch = stackTarget(batch);
            if (batchIdx < 3) {
                System.out.printf("Batch %d: Features shape=%dx%d%n",
                        batchIdx + 1, featureBatch.size(0), featureBatch.size(1));
            }
            it.increment();
            batchIdx++;
        }

        System.out.println("总批次数: " + batchIdx);
    }

    // 将 Vector 中的特征张量堆叠为一个批次张量
    public static Tensor stackData(ExampleVector batch) {
        TensorVector tv = new TensorVector();
        for (long i = 0; i < batch.size(); i++) {
            tv.push_back(batch.get(i).data());
        }
        return torch.stack(tv);
    }

    // 将 Vector 中的标签张量堆叠为一个批次张量
    public static Tensor stackTarget(ExampleVector batch) {
        TensorVector tv = new TensorVector();
        for (long i = 0; i < batch.size(); i++) {
            tv.push_back(batch.get(i).target());
        }
        return torch.stack(tv);
    }
}



```

`DataLoader` 产生批次,其中第一个维度对应于 `batch_size`。我们的特征批次形状为 `[16, 10]`,标签批次形状为 `[16]`。数据类型反映了我们应用的转换。

### 数据处理流程可视化

我们可以可视化刚刚创建的流程:

原始数据(特征, 标签张量)自定义数据集(SyntheticDataset)__init__数据转换(ToTensor, 规范化)__getitem__ 应用DataLoader(批处理, 混洗)输入数据集模型输入(批处理和已处理数据)迭代获取批次

![image-20251014151217356](C:\Users\hai71\AppData\Roaming\Typora\typora-user-images\image-20251014151217356.png)

> 此图表展示了从原始张量到自定义 `Dataset` 的演进过程,在数据获取时应用转换(`__getitem__`),最后使用 `DataLoader` 生成适合模型训练的混洗批次。

你现在已成功使用 PyTorch 的核心数据工具构建了数据处理流程。你创建了一个 `Dataset` 来封装你的数据,应用了必要的 `transforms`,并使用 `DataLoader` 高效地生成批次。这种结构化方法是 PyTorch 项目中处理数据的基础,确保你的模型接收正确格式的数据并促进高效训练。这个流程现在已准备好集成到我们将在下一章构建的训练循环中。