
既然你已经了解如何使用 `torch.nn` 构建模型以及如何用 Autograd 计算梯度,下一步就是有效地为这些模型提供数据。处理大型数据集、进行必要的预处理以及在不耗尽内存的情况下分批加载数据,是深度学习工作中常见的问题。
本章将讲解 PyTorch 管理数据流程的方案:即 `torch.utils.data` 模块。你将学习如何:
- 使用 `Dataset` 类来组织数据。
- 使用预设数据集,例如 `torchvision` 中提供的。
- 使用 `torchvision.transforms` 进行数据转换和增强。
- 使用 `DataLoader` 类高效地分批加载数据、打乱数据,并可能并行加载。
学完本章后,你将能够为你的 PyTorch 项目构建高效的数据流程。
训练深度学习模型需要处理大量数据。虽然使用 `torch.nn` 构建模型和使用 Autograd 计算梯度是基本步骤,但一个实际问题随之而来:如何在训练期间高效地将数据输入这些模型?
如果尝试手动处理数据加载,会遇到以下挑战:
1. **内存限制**: 现代数据集,特别是在计算机视觉或自然语言处理等方面,可能非常庞大,常常超出可用内存(RAM),更不用说 GPU 上的显存(VRAM)。一次性将整个数据集加载到内存中通常是不可行的。想象一下,尝试将整个 ImageNet 数据集(超过 1400 万张图像,数百 GB)直接加载到计算机的 RAM 中——对于大多数系统来说,这根本无法容纳。
2. **I/O 瓶颈**: 从磁盘读取数据的速度比 CPU 或 GPU 上的计算慢几个数量级。如果模型需要数据时你逐个加载数据样本,你速度极快的 GPU 将大部分时间处于空闲状态,等待下一批数据到来。这种顺序磁盘读取会成为一个主要瓶颈,极大地减缓训练过程。
3. **低效的预处理**: 数据很少以神经网络所需的精确格式存在。它通常需要预处理步骤,例如归一化、调整大小、数据类型转换或数据增强(随机修改样本以提高模型泛化能力)。与主要训练过程同步地逐个样本执行这些转换,会增加进一步的延迟。
4. **洗牌需求**: 为确保模型泛化能力并防止与数据顺序相关的偏差,标准做法是在每个训练周期前对数据集进行洗牌。实现高效的洗牌,特别是对于无法完全放入内存的数据集,会增加复杂性。
5. **批处理**: 神经网络通常在数据的小批量上进行训练,而不是单个样本。分批处理数据可以获得更稳定的梯度估计,并更好地使用 GPU 的并行处理能力。手动创建这些批次,确保它们的格式正确,以及处理最后一个可能较小的批次,都需要仔细编写代码。
6. **并行处理**: 为克服 I/O 瓶颈,高效的数据加载管道通常使用多个工作进程并行加载和预处理数据,在 GPU 忙于处理当前批次时准备未来的批次。正确实现这种并行,管理进程,并确保数据完整性是一项复杂的工程任务。
为每个项目从头解决所有这些问题会非常耗时且容易出错。你每次都相当于在重建一个重要的基础设施部分。
简化版朴素加载PyTorch DataLoader 方法大型数据集(磁盘)加载 + 处理样本 (CPU)慢速读取将样本移至 GPUI/O 和 CPU 限制GPU 常空闲在样本上训练 (GPU)等待下一批减缓训练大型数据集(磁盘)DataLoader(并行工作器,批处理, 洗牌)已准备的批次(RAM)预取将批次移至 GPU在批次上训练(GPU 忙碌)请求下一批

> 比较了导致瓶颈的朴素顺序数据加载方法与 PyTorch 数据工具提供的并行批处理方法。
认识到这些常见且重要的挑战,PyTorch 提供了 `torch.utils.data` 模块。这个模块提供专用工具,专门用于构建高效、灵活和并行的数据加载管道。它封装了洗牌、批处理、内存管理和并行加载的复杂性,让你能专注于定义数据集结构和所需的转换。
通过使用 PyTorch 的 `Dataset` 和 `DataLoader` 类(我们将在后续章节中介绍),你将获得:
- **效率**: 优化数据获取和预处理,通常在 CPU 核心间并行执行,确保 GPU 获得充足数据。
- **内存管理**: 通过仅在需要时将必要的批次加载到内存中来处理大型数据集。
- **灵活性**: 轻松集成自定义数据源和复杂的预处理/数据增强步骤。
- **简洁性**: 用于与数据集交互和创建数据迭代器的标准化 API。
这些工具是使用 PyTorch 构建实际深度学习应用的基本组成部分。让我们从 `Dataset` 类开始,了解它们如何工作。
高效加载和处理数据对训练深度学习模型非常重要。PyTorch 提供了一种通过其 `torch.utils.data.Dataset` 抽象类来处理数据集的标准化方式。可以把 `Dataset` 看作一个约定:它定义了访问数据的标准接口,无论数据是存在内存中、磁盘上,还是需要即时生成。
其核心是,`torch.utils.data.Dataset` 是一个表示数据集的抽象类。你在 PyTorch 中创建的任何自定义数据集都应该继承自这个类。为什么要使用这种结构?它确保了不同的数据集,无论是内置的还是自定义的,都能向其他 PyTorch 组件提供一致的 API,最值得一提的是 `DataLoader`,我们稍后会讲到。这种标准化简化了在相同训练代码中替换数据集或使用不同数据源的过程。
要创建自己的自定义数据集,你需要继承 `torch.utils.data.Dataset` 并重写两个必要的方法:
1. `__len__(self)`: 这个方法应该返回数据集中样本的总数。`DataLoader` 使用它来确定数据集的大小。
2. `__getitem__(self, idx)`: 这个方法负责根据给定索引 `idx` 从数据集中加载并返回一个样本。这是实际数据加载逻辑所在的地方(例如,读取图像文件、从 CSV 获取一行数据、访问列表中的元素)。`DataLoader` 会重复调用此方法来构建批次。
让我们用一个简单的例子来说明这一点。假设你的特征和对应的标签存储在 Python 列表或 NumPy 数组中。
```scala 3
import torch.*
import torch.utils.data.Dataset
import numpy as np
class SimpleCustomDataset extends Dataset:
"""一个带有特征和标签的简单数据集示例。"""
def __init__(features, labels):
"""
参数:
features (列表或 np.array): 特征的列表或数组。
labels (列表或 np.array): 标签的列表或数组。
"""
assert len(features) == len(labels), "特征和标签的长度必须相同。"
self.features = features
self.labels = labels
def __len__(self):
"""返回样本总数。"""
return len(self.features)
def __getitem__(self, idx):
"""
生成一个数据样本。
参数:
idx (int): 元素的索引。
返回:
tuple: 给定索引对应的 (特征, 标签)。
"""
//获取给定索引的特征和标签
feature = self.features[idx]
label = self.labels[idx]
//通常,你会在Dataloader中将数据转换为 PyTorch 张量
// 我们假设特征/标签可能还不是张量
sample = (torch.tensor(feature, dtype=torch.float32),
torch.tensor(label, dtype=torch.long)) // 假设是分类标签
return sample
// --- 示例用法 ---
// 样本数据(请替换为你的实际数据)
val num_samples = 100
val num_features = 10
val features_data = np.random.randn(num_samples, num_features)
val labels_data = np.random.randint(0, 5, size=num_samples) // 示例:5 个类别
// 创建自定义数据集实例
val my_dataset = SimpleCustomDataset(features_data, labels_data)
// 访问数据集属性和元素
println(s"数据集大小: ${len(my_dataset)}")
// 获取第一个样本
val first_sample = my_dataset[0]
val feature_sample = first_sample._1
val label_sample = first_sample._2
println(s"\n第一个样本特征:\n$feature_sample")
println(s"第一个样本形状: ${feature_sample.shape}")
println(s"第一个样本标签: $label_sample")
// 获取第十个样本
val tenth_sample = my_dataset[9]
val tenth_feature_sample = tenth_sample._1
val tenth_label_sample = tenth_sample._2
println(s"\n第十个样本特征:\n$tenth_feature_sample")
println(s"第十个样本形状: ${tenth_feature_sample.shape}")
println(s"第十个样本标签: $tenth_label_sample")
```
```java
/**
* 自定义数据集类,继承JavaDataset(对应Python的torch.utils.data.Dataset)
* 实现特征和标签的存储、长度获取、按索引取样本逻辑
*/
public class SimpleCustomDataset extends JavaDataset {
// 存储特征和标签的原始数据(浮点型特征,长整型标签)
private float[][] features
private long[] labels
private long numSamples
private long numFeatures
/**
* 构造方法(对应Python的__init__)
* @param features 特征数组 [样本数, 特征数]
* @param labels 标签数组 [样本数]
*/
public SimpleCustomDataset(float[][] features, long[] labels) {
super()
// 基本检查:特征和标签长度必须相同
if (features.length != labels.length) {
throw new IllegalArgumentException("特征和标签的长度必须相同。")
}
this.features = features
this.labels = labels
this.numSamples = features.length
this.numFeatures = features.length > 0 ? features[0].length : 0
}
/**
* 返回样本总数(对应Python的__len__)
*/
@Override
public SizeTOptional size() {
return new SizeTOptional(new SizeTPointer(numSamples))
}
/**
* 按索引获取样本(对应Python的__getitem__)
* @param index 样本索引(size_t类型)
* @return Example 包含特征张量和标签张量的样本
*/
@Override
public Example get(long index) {
// 边界检查:索引不能超出样本范围
if (index < 0 || index >= numSamples) {
throw new IndexOutOfBoundsException("索引超出数据集范围: " + index)
}
// 1. 获取指定索引的特征和标签
float[] featureData = features[(int)index]
long labelData = labels[(int)index]
// 2. 将特征转换为PyTorch Float32张量(对应torch.float32)
Tensor featureTensor = torch.from_blob(
new FloatPointer(featureData), // 特征数据指针
new Long[](numFeatures), // 特征形状 [numFeatures]
new TensorOptions().dtype(new ScalarTypeOptional(ScalarType.Float)) // 数据类型
)
// 3. 将标签转换为PyTorch Long张量(对应torch.long)
Tensor labelTensor = torch.tensor(
new LongPointer(labelData), // 标签数据指针
new TensorOptions().dtype(new ScalarTypeOptional(ScalarType.Long)) // 数据类型
)
// 4. 封装为Example返回(对应Python的(特征, 标签)元组)
return new Example(featureTensor, labelTensor)
}
// ========== 示例用法 ==========
public static void main(String[] args) {
// 1. 生成示例数据(对应原代码的np.random)
int numSamples = 100
int numFeatures = 10
Random random = new Random()
// 生成特征数据:numSamples x numFeatures 的随机浮点数
float[][] featuresData = new float[numSamples][numFeatures]
for (int i = 0
for (int j = 0
// 模拟np.random.randn(正态分布随机数)
featuresData[i][j] = (float) random.nextGaussian()
}
}
// 生成标签数据:0-4的随机整数(5个类别)
long[] labelsData = new long[numSamples]
for (int i = 0
labelsData[i] = random.nextInt(5)
}
// 2. 创建自定义数据集实例
SimpleCustomDataset myDataset = new SimpleCustomDataset(featuresData, labelsData)
// 3. 打印数据集大小
System.out.println("数据集大小: " + myDataset.size().get())
// 4. 获取第一个样本
Example firstSample = myDataset.get(0)
Tensor firstFeature = firstSample.data()
Tensor firstLabel = firstSample.target()
System.out.println("\n第一个样本特征:\n" + tensorToFloatString(firstFeature))
System.out.println("第一个样本形状: " + tensorShapeToString(firstFeature.sizes()))
System.out.println("第一个样本标签: " + firstLabel.item().toLong())
// 5. 获取第十个样本(索引9)
Example tenthSample = myDataset.get(9)
Tensor tenthFeature = tenthSample.data()
Tensor tenthLabel = tenthSample.target()
System.out.println("\n第十个样本特征:\n" + tensorToFloatString(tenthFeature))
System.out.println("第十个样本形状: " + tensorShapeToString(tenthFeature.sizes().vec().get()))
System.out.println("第十个样本标签: " + tenthLabel.item().toLong())
// 6. 释放张量资源
firstFeature.close()
firstLabel.close()
tenthFeature.close()
tenthLabel.close()
}
/**
* 辅助方法:将Float32张量转换为易读的字符串
*/
private static String tensorToFloatString(Tensor tensor) {
float[] data = new float[(int) tensor.numel()]
tensor.data().get(torch.tensor(data))
StringBuilder sb = new StringBuilder("[")
for (int i = 0
sb.append(String.format("%.4f", data[i]))
if (i < data.length - 1) {
sb.append(", ")
}
}
sb.append("]")
return sb.toString()
}
/**
* 辅助方法:将张量形状(SizeTPointer)转换为字符串
*/
private static String tensorShapeToString(SizeTPointer sizes) {
StringBuilder sb = new StringBuilder("[")
for (int i = 0
sb.append(sizes.get(i))
if (i < sizes.limit() - 1) {
sb.append(", ")
}
}
sb.append("]")
return sb.toString()
}
}
```
```java
import org.bytedeco.pytorch.*
import org.bytedeco.javacpp.*
import org.bytedeco.javacpp.annotation.*
import static org.bytedeco.pytorch.global.torch.*
/**
* 实现 SimpleCustomDataset
* 继承 ChunkDataset 以获得与 LibTorch DataLoader 兼容的能力
*/
public class SimpleCustomDataset extends ChunkDataset {
static { Loader.load()
private float[][] features
private int[] labels
private long totalSamples
/**
* 指针构造函数(JavaCPP 内部要求)
*/
public SimpleCustomDataset(Pointer p) {
super(p)
}
/**
* Java 层主构造函数
*/
public SimpleCustomDataset(float[][] features, int[] labels) {
// 调用父类 Pointer 构造,这里暂时传 null,因为我们要重写逻辑
super((Pointer)null)
if (features.length != labels.length) {
throw new IllegalArgumentException("特征和标签长度不匹配")
}
this.features = features
this.labels = labels
this.totalSamples = features.length
}
/**
* 核心方法:重写 get_batch
* 对应 Python 的 __getitem__,但在 C++ 体系中是以 Batch 为单位请求的
*/
@Override
public @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long batch_size) {
// 如果没有数据了,返回空的 Optional (结束信号)
if (totalSamples == 0) return new ExampleVectorOptional()
// 创建一个 ExampleVector (std::vector<torch::data::Example<>>)
ExampleVector examples = new ExampleVector()
// 这里的逻辑通常需要一个计数器(stateful)
// 简单起见,这里演示单次请求返回 batch_size 个样本
for (int i = 0
// 获取数据并转为 Tensor
Tensor f = tensor(features[i % (int)totalSamples])
Tensor l = tensor(labels[i % (int)totalSamples])
// 封装为 Example 并存入 Vector
examples.push_back(new Example(f, l))
}
return new ExampleVectorOptional(examples)
}
/**
* 重写 size()
*/
@Override
public @ByVal SizeTOptional size() {
return new SizeTOptional(totalSamples)
}
/**
* 重写 reset()
* 在每个 epoch 开始时被调用
*/
@Override
public void reset() {
// 重置索引指针等逻辑
}
public static void main(String[] args) {
// 示例:在训练循环中使用
SimpleCustomDataset myDataset = new SimpleCustomDataset(features, labels)
// 模拟一次获取数据
ExampleVectorOptional batchOpt = myDataset.get_batch(32)
if (batchOpt.has_value()) {
ExampleVector batch = batchOpt.value()
// 处理 batch...
System.out.println("获取到样本数: " + batch.size())
}
}
}
```
在此示例中:
- `__init__` 方法存储在实例化时传入的特征和标签数据。
- `__len__` 简单地返回特征列表的长度(这与标签列表的长度相同)。
- `__getitem__` 接受一个索引 `idx`,获取对应的特征和标签,将它们转换为 PyTorch 张量,并以元组形式返回。这种转换为张量的操作在 `__getitem__` 中很常见。
自定义 `Dataset` 的真正作用体现在处理那些不能直接在内存中获取的数据时。例如,你的图像文件路径和标签可能存储在一个 CSV 文件中。
```scala 3
import torch.*
import torch.utils.data.Dataset
from PIL import Image // Python 图像库,用于图像加载
import pandas as pd
import os
class ImageFilelistDataset extends Dataset:
"""用于从 CSV 文件加载图像路径和标签的数据集。"""
def __init__(self, csv_file, root_dir, transform=None):
"""
参数:
csv_file (字符串): 包含标注的 CSV 文件路径。
假设列有:'image_path', 'label'
root_dir (字符串): 包含所有图像的目录。
transform (可调用, 可选): 可选的数据变换,用于对样本进行处理。
应用于样本。
"""
val annotations = pd.read_csv(csv_file)
val root_dir = root_dir
val transform = transform // 我们稍后会讨论数据变换
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
// 从 CSV 获取相对于 root_dir 的图像路径
val img_rel_path = annotations.iloc[idx, 0] // 假设第一列是路径
val img_full_path = os.path.join(root_dir, img_rel_path)
//使用 PIL 加载图像
try:
image = Image.open(img_full_path).convert('RGB')
except FileNotFoundError:
println(f"错误:未在 {img_full_path} 找到图像")
// 适当处理错误,例如返回 None 或抛出异常
// 为简单起见,这里我们将返回 None,并依赖 DataLoader 的 collate_fn
// 来处理它(或稍后过滤)。一个更好的方法
// 可能是事先清理 CSV 文件。
return None, None
// 从 CSV 获取标签
val label = annotations.iloc[idx, 1] // 假设第二列是标签
val label_tensor = torch.tensor(int(label), dtype=torch.long)
// 如果有,应用数据变换
if transform then
image = transform(image) // 数据变换通常会将 PIL 图像转换为张量
// 如果没有提供将图像转换为张量的数据变换,则手动转换
if not isinstance(image, torch.Tensor):
image = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1) / 255.0
return image, label
// --- 示例用法(需要实际图像和 CSV)---
// 假设你拥有:
// 1. 文件夹 'data/images/' 包含图像文件(例如,cat1.jpg, dog1.png)
// 2. CSV 文件 'data/annotations.csv',内容如下:
// image_path,label
// images/cat1.jpg,0
// images/dog1.png,1
// ...
// 3. 确保 CSV 文件中没有空行或额外空格
// 4. 图像路径列中的值应与实际图像文件匹配(不区分大小写)
// 访问方式类似:
// print(f"图像数据集大小: {len(image_dataset)}")
// if len(image_dataset) > 0:
// img, lbl = image_dataset[0]
// if img is not None:
// println(f"第一个图像形状: {img.shape}") // 形状取决于数据变换
// println(f"第一个图像标签: {lbl}")
```
```java
import org.bytedeco.javacpp.LongPointer
import org.bytedeco.javacpp.SizeTPointer
import org.bytedeco.pytorch.*
import org.bytedeco.pytorch.global.torch
import org.apache.commons.io.FilenameUtils
import com.opencsv.CSVReader
import com.opencsv.exceptions.CsvException
import javax.imageio.ImageIO
import java.awt.image.BufferedImage
import java.io.File
import java.io.FileReader
import java.io.IOException
import java.nio.file.Files
import java.nio.file.Paths
import java.util.ArrayList
import java.util.List
import java.util.function.Function
/**
* 从CSV加载图像路径和标签的自定义数据集,继承JavaDataset
* 对应Python的ImageFilelistDataset,适配Java图像加载和CSV解析逻辑
*/
public class ImageFilelistDataset extends JavaDataset {
// 存储CSV标注:每行[图像相对路径, 标签]
private List<String[]> annotations
// 图像根目录
private String rootDir
// 图像变换函数(可选,对应Python的transform)
private Function<BufferedImage, Tensor> transform
/**
* 构造方法(对应Python的__init__)
* @param csvFile CSV文件路径(列:image_path, label)
* @param rootDir 图像根目录
* @param transform 图像变换函数(可选)
*/
public ImageFilelistDataset(String csvFile, String rootDir, Function<BufferedImage, Tensor> transform) {
super()
this.rootDir = rootDir
this.transform = transform
// 读取CSV标注文件
try (CSVReader csvReader = new CSVReader(new FileReader(csvFile))) {
// 读取所有行(跳过表头)
annotations = csvReader.readAll()
if (!annotations.isEmpty()) {
annotations.remove(0)
}
if (annotations.isEmpty()) {
throw new IllegalArgumentException("CSV文件无有效标注数据:" + csvFile)
}
} catch (IOException | CsvException e) {
throw new RuntimeException("读取CSV文件失败:" + csvFile, e)
}
}
/**
* 返回数据集大小(对应Python的__len__)
*/
@Override
public SizeTOptional size() {
return new SizeTOptional(new SizeTPointer(annotations.size()))
}
/**
* 按索引获取图像样本(对应Python的__getitem__)
* @param index 样本索引
* @return Example 包含图像张量和标签张量的样本(失败时返回空Example)
*/
@Override
public Example get(long index) {
// 1. 边界检查
if (index < 0 || index >= annotations.size()) {
throw new IndexOutOfBoundsException("索引超出数据集范围: " + index)
}
// 2. 获取CSV中的图像路径和标签
String[] row = annotations.get((int) index)
if (row.length < 2) {
System.err.println("错误:CSV行数据不完整,索引=" + index)
return new Example(torch.empty(), torch.empty())
}
String imgRelPath = row[0].trim()
String labelStr = row[1].trim()
String imgFullPath = Paths.get(rootDir, imgRelPath).toString()
// 3. 加载图像
BufferedImage image = null
try {
image = ImageIO.read(new File(imgFullPath))
if (image == null) {
throw new IOException("图像无法解码:" + imgFullPath)
}
// 转换为RGB格式(确保3通道)
image = convertToRGB(image)
} catch (IOException e) {
System.err.println("错误:未在 " + imgFullPath + " 找到图像或加载失败")
return new Example(torch.empty(), torch.empty())
}
// 4. 处理标签(转换为Long张量)
long label
try {
label = Long.parseLong(labelStr)
} catch (NumberFormatException e) {
System.err.println("错误:标签格式无效,索引=" + index + ",标签值=" + labelStr)
return new Example(torch.empty(), torch.empty())
}
Tensor labelTensor = torch.tensor(
new LongPointer(label),
new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Long))
)
// 5. 应用图像变换(可选)
Tensor imageTensor
if (transform != null) {
// 使用自定义变换处理图像
imageTensor = transform.apply(image)
} else {
// 无变换时,手动转换为张量(HWC -> CHW,归一化到0-1)
imageTensor = convertImageToTensor(image)
}
// 6. 返回封装的Example(图像张量 + 标签张量)
return new Example(imageTensor, labelTensor)
}
/**
* 辅助方法:将BufferedImage转换为RGB格式(确保3通道)
*/
private BufferedImage convertToRGB(BufferedImage image) {
if (image.getType() == BufferedImage.TYPE_3BYTE_BGR || image.getType() == BufferedImage.TYPE_INT_RGB) {
return image
}
// 创建RGB格式的新图像
BufferedImage rgbImage = new BufferedImage(
image.getWidth(), image.getHeight(),
BufferedImage.TYPE_INT_RGB
)
rgbImage.getGraphics().drawImage(image, 0, 0, null)
return rgbImage
}
/**
* 辅助方法:将BufferedImage转换为PyTorch张量(CHW格式,归一化到0-1)
* 对应Python:torch.tensor(np.array(image), dtype=torch.float32).permute(2,0,1)/255.0
*/
private Tensor convertImageToTensor(BufferedImage image) {
int width = image.getWidth()
int height = image.getHeight()
float[] pixelData = new float[3 * height * width]
// 遍历图像像素,提取RGB值并归一化
int idx = 0
for (int c = 0
for (int h = 0
for (int w = 0
int pixel = image.getRGB(w, h)
float value
switch (c) {
case 0: value = ((pixel >> 16) & 0xFF) / 255.0f
case 1: value = ((pixel >> 8) & 0xFF) / 255.0f
case 2: value = (pixel & 0xFF) / 255.0f
default: value = 0.0f
}
pixelData[idx++] = value
}
}
}
// 创建PyTorch张量(形状:3 x H x W,Float32类型)
return torch.from_blob(
new org.bytedeco.javacpp.FloatPointer(pixelData),
new long[]{3, height, width},
new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Float))
)
}
// ========== 示例用法 ==========
public static void main(String[] args) {
// 1. 配置路径(替换为你的实际路径)
String csvFile = "data/annotations.csv"
String rootDir = "data/"
// 2. 可选:定义图像变换(示例:调整大小后转张量)
Function<BufferedImage, Tensor> resizeTransform = img -> {
// 这里可添加调整大小、裁剪等变换逻辑,示例直接调用基础转换
return new ImageFilelistDataset(null, null, null).convertImageToTensor(img)
}
// 3. 创建数据集实例
ImageFilelistDataset imageDataset = new ImageFilelistDataset(csvFile, rootDir, resizeTransform)
// 4. 打印数据集大小
System.out.println("图像数据集大小: " + imageDataset.size().get())
// 5. 读取第一个样本
Example firstSample = imageDataset.get(0)
Tensor firstImage = firstSample.data()
Tensor firstLabel = firstSample.target()
if (!firstImage.is_empty()) {
System.out.println("\n第一个样本图像形状: " + tensorShapeToString(firstImage.sizes().vec().get()))
System.out.println("第一个样本标签: " + firstLabel.item().toLong())
} else {
System.out.println("\n第一个样本加载失败")
}
// 6. 释放张量资源
firstImage.close()
firstLabel.close()
}
/**
* 辅助方法:将张量形状转换为易读字符串
*/
private static String tensorShapeToString(SizeTPointer sizes) {
StringBuilder sb = new StringBuilder("[")
for (int i = 0
sb.append(sizes.get(i))
if (i < sizes.limit() - 1) {
sb.append(", ")
}
}
sb.append("]")
return sb.toString()
}
}
```
在这个 `ImageFilelistDataset` 示例中:
- `__init__` 使用 pandas 读取 CSV 文件,并存储文件路径和根目录。它还接受一个可选的 `transform` 参数(我们很快会看到它的用法)。
- `__len__` 返回 CSV 文件中的行数。
- `__getitem__` 构建完整的图像路径,使用 PIL 加载图像,获取标签,应用任何指定的数据变换,确保图像是一个张量,并返回图像张量和标签张量。
请注意,`Dataset` 本身只定义了 *如何* 获取单个项目。它不会一次性将整个数据集加载到内存中(除非你的 `__init__` 明确这样做,但这对于大型数据集通常是避免的)。它也不处理批处理、打乱或并行加载。`DataLoader` 便是为此而生,它直接建立在 `Dataset` 提供的结构之上。通过实现 `__len__` 和 `__getitem__`,你为 `DataLoader` 高效访问数据样本提供了必要的结构。
尽管创建自定义 `Dataset` 类为您的特定数据提供了最大的灵活性,但许多深度学习任务,特别是在研究和基准测试中,使用标准化数据集。手动准备这些数据集涉及下载、解压、组织文件和编写解析逻辑,这可能既耗时又容易出错。
幸运的是,PyTorch 提供了配套库,可以简化常见领域的数据处理过程。对于计算机视觉,`torchvision` 包是一个不可或缺的工具。它不仅包含流行的数据集,还包含预训练模型和常用的图像转换函数。本节主要介绍如何访问和使用 `torchvision.datasets` 提供的数据集。
`torchvision.datasets` 模块提供了对许多广泛使用的计算机视觉数据集的便捷访问,例如 MNIST、Fashion MNIST、CIFAR 10/100、ImageNet、COCO 等。使用这些数据集很简单。通常,您会从 `torchvision.datasets` 导入特定的数据集类并实例化它。
让我们看一个使用 CIFAR 10 数据集的例子,它包含 60,000 张 32x32 彩色图像,分为 10 个类别。
```scala 3
import torchvision
import torchvision.transforms as transforms
// 定义一个简单的转换,将图像转换为 PyTorch 张量
val transform = transforms.Compose([transforms.ToTensor()])
// 加载训练数据集
// root: 数据将被存储/查找的目录
// train=True: 指定训练集
// download=True: 如果本地未找到数据,则下载
// transform: 将定义的转换应用于每张图像
val train_dataset = torchvision.datasets.CIFAR10(root='./data',
train=true,
download=true,
transform=transform)
// 加载测试数据集
// root: 数据将被存储/查找的目录
// train=False: 指定测试集
// download=True: 如果本地未找到数据,则下载
// transform: 将定义的转换应用于每张图像
val test_dataset = torchvision.datasets.CIFAR10(root='./data',
train=false,
download=true,
transform=transform)
// 打印数据集大小
println(f"CIFAR-10 training dataset size: {len(train_dataset)}")
println(f"CIFAR-10 test dataset size: {len(test_dataset)}")
// 访问单个数据点(图像、标签)
val img = train_dataset[0]._1
println(f"Image shape: {img.shape}") // 通常输出:torch.Size([3, 32, 32])
println(f"Label: {train_dataset[0]._2}") // 输出:表示类别的整数
```
当您首次运行此代码时,`torchvision` 会检查指定的 `root` 目录(在本例中为 `./data`)。如果 CIFAR 10 数据不存在,设置 `download=True` 会指示 `torchvision` 自动将数据集下载并解压到该目录中。后续运行将发现数据已存在于本地并跳过下载。
注意 `transform` 参数。您可以在此处指定数据预处理步骤,这些步骤在数据样本加载后但在 `__getitem__` 返回之前应用于每个样本。我们使用了 `transforms.ToTensor()`,它将 PIL 图像格式(`torchvision` 数据集常用)转换为 PyTorch 张量。数据转换将在下一节中进行更详细的介绍。
重要的是,`torchvision.datasets` 返回的对象(如上文的 `train_dataset` 和 `test_dataset`)是继承自 `torch.utils.data.Dataset` 的类实例。这意味着它们实现了必需的 `__len__` 和 `__getitem__` 方法,使其与 PyTorch 的 `DataLoader` 完全兼容。
- `len(train_dataset)` 返回数据集中样本的总数。
- `train_dataset[i]` 返回第 i*i* 个样本,通常是一个元组 `(data, target)`,其中 `data` 是预处理后的输入(例如,图像张量),`target` 是对应的标签或标注。
以下是 CIFAR-10 训练集中类别分布的简单可视化:

> CIFAR-10 数据集是平衡的,每个类别恰好有 5,000 张训练图像。
尽管 `torchvision` 最为完善,但其他领域也存在类似的库:
- **`torchaudio`**: 为音频处理任务提供数据集(如 SpeechCommands、LJSpeech 等)、模型和转换功能。
- **`torchtext`**: 为自然语言处理提供数据集(如 IMDb 情感分析、WikiText 语言建模)、分词器和词汇工具。注意:`torchtext` 经历了重大的 API 变更,因此请查阅其文档以了解当前的使用模式。
使用这些库遵循相似的原则:导入所需的数据集类,实例化它(通常带有下载和预处理选项),然后将生成的 `Dataset` 对象与 `DataLoader` 一起使用。
依靠这些内置数据集可以显著加快开发和实验速度,使您能够专注于模型架构和训练,而不是数据获取和准备,尤其是在使用标准基准时。请记住,这些数据集对象直接与本章后面讨论的 `DataLoader` 集成,从而实现高效的批处理和洗牌。
原始数据,如图像或文本,很少直接以完美适合神经网络输入的格式出现。模型通常需要特定大小和分布的数值张量。此外,为了提高模型的泛化能力并防止过拟合,通常的做法是通过对现有数据应用随机修改来人工扩充训练数据集。这就是数据变换的作用。
PyTorch,特别是通过用于计算机视觉任务的 `torchvision` 库,提供了一个便捷的模块 `torchvision.transforms`,它包含多种常用操作,这些操作可以链式组合以创建数据处理流程。这些变换主要有两个作用:
1. **预处理:** 标准化数据格式、比例和大小。
2. **数据增强:** 对训练数据应用随机改动以增加其多样性。
让我们看一些基础变换。
这些变换通常应用于所有数据集划分(训练集、验证集和测试集),以确保一致性。
- **`transforms.ToTensor()`**:这通常是对使用PIL(Python图像库)或NumPy等库加载的图像数据最先应用的变换之一。它将PIL图像或NumPy数组(格式为高 x 宽 x 通道)转换为PyTorch `FloatTensor`(格式为通道 x 高 x 宽)。重要的是,它还将像素值从 [0, 255] 范围缩放到 [0.0, 1.0]。这种转换为张量和标准化范围对于模型输入是必需的。
- **`transforms.Resize(size)`**:将输入图像调整到给定 `size`。如果 `size` 是整数,图像的较短边将匹配此数字,同时保持宽高比。如果 `size` 是像 `(h, w)` 这样的序列,它会将图像调整为精确的高度 `h` 和宽度 `w`。这很重要,因为许多神经网络需要固定大小的输入。
- **`transforms.CenterCrop(size)`**:将图像的中心部分裁剪到给定 `size`。这通常在调整大小后使用,以确保最终图像尺寸精确,同时聚焦于中心区域。
- **`transforms.Normalize(mean, std)`**:使用为每个通道提供的均值和标准差对张量图像进行标准化。应用的操作是: 输出=(输入−均值)/标准差输出=(输入−均值)/标准差 标准化有助于稳定训练,并通过确保输入特征具有相似的比例(通常围绕零居中)来促进更快的收敛。`mean` 和 `std` 通常是值序列,每个输入通道对应一个值(例如,RGB图像有3个值)。来自ImageNet等大型数据集的预计算值通常用作默认值:`mean=[0.485, 0.456, 0.406]` 和 `std=[0.229, 0.224, 0.225]`。
这些变换引入随机性,并且通常*仅*应用于训练数据集。这有助于模型学习对输入中的微小变化保持不变,从而降低过拟合倾向。
- **`transforms.RandomHorizontalFlip(p=0.5)`**:以给定概率 `p`(默认为0.5,表示50%的机会)随机水平翻转图像。
- **`transforms.RandomRotation(degrees)`**:通过从 `(-degrees, +degrees)` 中均匀选择的随机角度旋转图像,或者如果 `degrees` 是序列 `(min, max)`,则在特定范围内旋转。
- **`transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)`**:随机改变图像的亮度、对比度、饱和度和色调。你可以指定每个属性的抖动范围。例如,`brightness=0.2` 意味着随机选择一个介于 `[max(0, 1 - 0.2), 1 + 0.2]` 之间的亮度因子。
- **`transforms.RandomResizedCrop(size)`**:裁剪图像的随机部分并将其调整到所需的 `size`。这是一种常用增强技术,特别适用于训练Inception网络等图像分类模型。
你很少只应用一个变换。PyTorch 通过 `transforms.Compose` 方便地将多个变换链式组合起来。它接收一个变换对象列表并按顺序应用它们。
下面是为训练数据创建处理流程的示例,包括调整大小、增强、转换为张量和标准化:
```scala 3
import torchvision.transforms as transforms
// 训练数据的变换流程示例
val train_transform = transforms.Compose(
transforms.Resize(256), // 将较短边调整为256
transforms.RandomCrop(224), // 随机裁剪224x224的区域
transforms.RandomHorizontalFlip(), // 随机水平翻转
transforms.ToTensor(), // 将PIL图像转换为张量(0-1范围)
transforms.Normalize(mean=Seq(0.485, 0.456, 0.406), // 使用ImageNet统计数据进行标准化
std=Seq(0.229, 0.224, 0.225))
)
// 验证/测试数据的变换流程示例(无增强)
val test_transform = transforms.Compose(
transforms.Resize(256), // 将较短边调整为256
transforms.CenterCrop(224), // 中心裁剪到224x224
transforms.ToTensor(), // 将PIL图像转换为张量(0-1范围)
transforms.Normalize(mean=Seq(0.485, 0.456, 0.406), // 使用ImageNet统计数据进行标准化
std=Seq(0.229, 0.224, 0.225))
)
println("训练变换:")
println(train_transform)
println("\n测试变换:")
println(test_transform)
```
正如上一节关于 `Dataset` 对象所述,这些组合变换通常在实例化 `Dataset` 时作为参数(通常命名为 `transform` 或 `target_transform`)传入。对于 `torchvision.datasets` 中的内置数据集,这直接简单:
```scala 3
// 假设您已安装 torchvision
import torchvision.datasets as datasets
import java.nio.file.Path
// torchvision's ImageFolder 的使用示例
val train_data_path = Path("path/to/your/train_images")
val test_data_path = Path("path/to/your/test_images")
val train_dataset = datasets.ImageFolder(root=train_data_path, transform=train_transform)
val test_dataset = datasets.ImageFolder(root=test_data_path, transform=test_transform)
// 当您从 train_dataset 访问一个项时,train_transform 会被应用
val sample_image = train_dataset(0)._1 // sample_image 现在是一个经过变换的张量
val sample_label = train_dataset(0)._2 // sample_label 是一个整数标签
```
对于自定义 `Dataset` 类,您通常会在 `__init__` 方法中接受变换对象,并在返回样本之前在 `__getitem__` 方法中应用它。
```scala 3
import torch.utils.data.Dataset
import PIL.Image
class CustomImageDataset extends Dataset[(Image, Int)]:
def __init__(self, image_paths, labels, transform=None):
val image_paths = image_paths
val labels = labels
val transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
val image_path = image_paths(idx)
val label = labels(idx)
val image = Image.open(image_path).convert("RGB") // 加载图像
if transform then
image = transform(image) // 应用变换
return image, label
//使用方法
val custom_train_dataset = CustomImageDataset(train_paths, train_labels, transform=train_transform)
val custom_test_dataset = CustomImageDataset(test_paths, test_labels, transform=test_transform)
```
```java
package featurestore
import org.bytedeco.javacpp.LongPointer
import org.bytedeco.javacpp.SizeTPointer
import org.bytedeco.pytorch.*
import org.bytedeco.pytorch.global.torch
import javax.imageio.ImageIO
import java.awt.image.BufferedImage
import java.io.File
import java.io.IOException
import java.util.List
import java.util.function.Function
/**
* 简化版自定义图像数据集,继承JavaDataset
* 对应Scala的CustomImageDataset[(Image, Int)],极简实现核心逻辑
*/
public class CustomImageDataset extends JavaDataset {
// 图像路径列表、标签列表
private final List<String> imagePaths
private final List<Integer> labels
// 图像变换函数(可选)
private final Function<BufferedImage, Tensor> transform
/**
* 构造方法(对应Scala的__init__)
* @param imagePaths 图像路径列表
* @param labels 标签列表(Int类型)
* @param transform 图像变换函数(可选)
*/
public CustomImageDataset(List<String> imagePaths, List<Integer> labels, Function<BufferedImage, Tensor> transform) {
super()
// 基础校验:路径和标签长度一致
if (imagePaths.size() != labels.size()) {
throw new IllegalArgumentException("图像路径列表和标签列表长度必须相同")
}
this.imagePaths = imagePaths
this.labels = labels
this.transform = transform
}
/**
* 返回数据集大小(对应Scala的__len__)
*/
@Override
public SizeTOptional size() {
return new SizeTOptional(new SizeTPointer(imagePaths.size()))
}
/**
* 按索引获取样本(对应Scala的__getitem__)
* @param index 样本索引(size_t类型)
* @return Example 封装图像张量和标签张量的样本
*/
@Override
public Example get(long index) {
// 1. 边界校验
int idx = (int) index
if (idx < 0 || idx >= imagePaths.size()) {
throw new IndexOutOfBoundsException("索引超出范围: " + index)
}
// 2. 获取路径和标签
String imagePath = imagePaths.get(idx)
int label = labels.get(idx)
// 3. 加载图像并转换为RGB
BufferedImage image
try {
image = ImageIO.read(new File(imagePath))
if (image == null) {
throw new IOException("图像文件无法解码: " + imagePath)
}
// 确保RGB格式(对应Scala的convert("RGB"))
image = convertToRGB(image)
} catch (IOException e) {
throw new RuntimeException("加载图像失败: " + imagePath, e)
}
// 4. 应用变换(对应Scala的if transform then image = transform(image))
Tensor imageTensor
if (transform != null) {
imageTensor = transform.apply(image)
} else {
// 无变换时,默认转换为CHW格式的Float32张量(归一化到0-1)
imageTensor = convertImageToTensor(image)
}
// 5. 转换标签为Long张量(对应Scala的Int标签)
Tensor labelTensor = torch.tensor(
new LongPointer(label),
new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Long))
)
// 6. 返回封装的样本(对应Scala的(image, label)元组)
return new Example(imageTensor, labelTensor)
}
// ------------------- 辅助方法 -------------------
/**
* 将BufferedImage转换为RGB格式(确保3通道)
*/
private BufferedImage convertToRGB(BufferedImage image) {
if (image.getType() == BufferedImage.TYPE_INT_RGB || image.getType() == BufferedImage.TYPE_3BYTE_BGR) {
return image
}
BufferedImage rgbImage = new BufferedImage(
image.getWidth(), image.getHeight(),
BufferedImage.TYPE_INT_RGB
)
rgbImage.getGraphics().drawImage(image, 0, 0, null)
return rgbImage
}
/**
* 无自定义变换时,将图像转为PyTorch标准张量(CHW + 归一化)
*/
private Tensor convertImageToTensor(BufferedImage image) {
int h = image.getHeight()
int w = image.getWidth()
float[] data = new float[3 * h * w]
int idx = 0
// 遍历像素,提取RGB值并归一化到[0,1]
for (int c = 0
for (int y = 0
for (int x = 0
int pixel = image.getRGB(x, y)
float val = switch (c) {
case 0 -> ((pixel >> 16) & 0xFF) / 255.0f
case 1 -> ((pixel >> 8) & 0xFF) / 255.0f
case 2 -> (pixel & 0xFF) / 255.0f
default -> 0.0f
}
data[idx++] = val
}
}
}
// 创建Float32张量,形状为[3, H, W]
return torch.from_blob(
new org.bytedeco.javacpp.FloatPointer(data),
new long[]{3, h, w},
new TensorOptions().dtype(new ScalarTypeOptional(torch.ScalarType.Float))
)
}
// ------------------- 示例使用方法 -------------------
public static void main(String[] args) {
// 1. 模拟训练/测试数据(替换为你的实际路径和标签)
List<String> trainPaths = List.of("data/train/cat1.jpg", "data/train/dog1.jpg")
List<Integer> trainLabels = List.of(0, 1)
List<String> testPaths = List.of("data/test/cat2.jpg", "data/test/dog2.jpg")
List<Integer> testLabels = List.of(0, 1)
// 2. 定义训练/测试变换(示例:空变换,可自定义Resize/Normalize等)
Function<BufferedImage, Tensor> trainTransform = null
Function<BufferedImage, Tensor> testTransform = null
// 3. 创建数据集实例(对应Scala的custom_train_dataset/custom_test_dataset)
CustomImageDataset customTrainDataset = new CustomImageDataset(trainPaths, trainLabels, trainTransform)
CustomImageDataset customTestDataset = new CustomImageDataset(testPaths, testLabels, testTransform)
// 4. 验证数据集
System.out.println("训练数据集大小: " + customTrainDataset.size().get())
System.out.println("测试数据集大小: " + customTestDataset.size().get())
// 5. 获取第一个训练样本
Example firstTrainSample = customTrainDataset.get(0)
System.out.println("第一个训练样本图像形状: " + tensorShape(firstTrainSample.data().sizes().vec().get()))
System.out.println("第一个训练样本标签: " + firstTrainSample.target().item().toLong())
// 释放资源
firstTrainSample.data().close()
firstTrainSample.target().close()
}
/**
* 辅助方法:打印张量形状
*/
private static String tensorShape(SizeTPointer sizes) {
StringBuilder sb = new StringBuilder("[")
for (int i = 0
sb.append(sizes.get(i))
if (i < sizes.limit() - 1) sb.append(", ")
}
sb.append("]")
return sb.toString()
}
}
```
```java
import org.bytedeco.pytorch.*
import org.bytedeco.javacpp.*
import org.bytedeco.javacpp.annotation.*
import java.io.File
import java.util.List
import static org.bytedeco.pytorch.global.torch.*
/**
* 转译自 Python CustomImageDataset
* 继承 ChunkDataset 以适配 LibTorch 的数据加载管线
*/
public class CustomImageDataset extends ChunkDataset {
public interface ImageTransform {
Tensor apply(Tensor input)
}
static { Loader.load()
private List<String> imagePaths
private int[] labels
private long currentIndex = 0
private long totalSamples
// 假设 transform 是一个简单的缩放/归一化接口
private ImageTransform transform
public CustomImageDataset(Pointer p) { super(p)
public CustomImageDataset(List<String> imagePaths, int[] labels, ImageTransform transform) {
super((Pointer)null)
this.imagePaths = imagePaths
this.labels = labels
this.transform = transform
this.totalSamples = imagePaths.size()
}
/**
* 实现 __getitem__ 逻辑的批处理版本
*/
@Override
public @ByVal ExampleVectorOptional get_batch(@Cast("size_t") long batch_size) {
// 1. 检查是否读完(每个 Epoch 结束信号)
if (currentIndex >= totalSamples) {
return new ExampleVectorOptional()
}
ExampleVector examples = new ExampleVector()
long limit = Math.min(currentIndex + batch_size, totalSamples)
for (long i = currentIndex
// 2. 加载图像 (对应 PIL.Image.open)
// 这里建议使用 OpenCV 或 Java 的 ImageIO 读取
Tensor imgTensor = loadImageAsTensor(imagePaths.get((int)i))
// 3. 应用 Transform
if (transform != null) {
imgTensor = transform.apply(imgTensor)
}
// 4. 包装 Label
Tensor labelTensor = tensor(labels[(int)i], new TensorOptions().dtype(new ScalarTypeOptional(kInt())))
examples.push_back(new Example(imgTensor, labelTensor))
}
currentIndex = limit
return new ExampleVectorOptional(examples)
}
@Override
public void reset() {
// 每个 Epoch 开始时重置指针
this.currentIndex = 0
}
@Override
public @ByVal SizeTOptional size() {
return new SizeTOptional(totalSamples)
}
/**
* 辅助方法:图像转 Tensor (简化版)
* 实际开发中可以使用 opencv_imgcodecs.imread 并转换颜色空间
*/
private Tensor loadImageAsTensor(String path) {
// 伪代码:加载图片并归一化到 [0, 1] 范围的 [C, H, W] Tensor
// return OpenCVUtils.toTensor(path)
return zeros(new long[]{3, 224, 224},new TensorOptions().dtype(new ScalarTypeOptional(kFloat())))
}
}
```
通过定义适当的变换并将其集成到您的 `Dataset` 中,您可以确保输入到模型的数据格式正确,并且对于训练数据而言,得到充分增强。这为下一步使用 `DataLoader` 有效地按批加载这些已处理的数据做好了准备。