数据加载与预处理
1. Dataset与DataLoader使用
1.1 数据处理流程架构
graph LR
A[原始数据] --> B(Dataset类)
B --> C{DataLoader}
C --> D[批量数据]
style A fill:#9f9,stroke:#333
style D fill:#f99,stroke:#333
1.2 内置Dataset使用示例
from torchvision import datasets
from torch.utils.data import DataLoader
# MNIST数据集加载
train_data = datasets.MNIST(
root='data',
train=True,
transform=transforms.ToTensor(),
download=True
)
test_data = datasets.MNIST(
root='data',
train=False,
transform=transforms.ToTensor()
)
# DataLoader配置
train_loader = DataLoader(
dataset=train_data,
batch_size=64,
shuffle=True,
num_workers=4,
pin_memory=True
)
1.3 DataLoader核心参数解析
| 参数 | 类型 | 说明 | 推荐值 |
|---|---|---|---|
batch_size | int | 单批数据量 | 32-256 |
shuffle | bool | 是否打乱数据 | True(训练集) |
num_workers | int | 加载线程数 | CPU核心数×2 |
pin_memory | bool | 锁页内存加速 | GPU训练时启用 |
drop_last | bool | 丢弃不足批次 | 大数据集启用 |
2. 自定义数据集与数据增强
2.1 自定义Dataset类模板
from torch.utils.data import Dataset
from PIL import Image
class CustomDataset(Dataset):
def __init__(self, img_dir, label_file, transform=None):
self.img_dir = Path(img_dir)
self.labels = pd.read_csv(label_file)
self.transform = transform
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
img_path = self.img_dir / self.labels.iloc[idx, 0]
image = Image.open(img_path).convert('RGB')
label = self.labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
return image, label
2.2 数据增强策略
2.2.1 常用图像变换组合
from torchvision import transforms
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.RandomRotation(30),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
2.2.2 自定义变换示例(马赛克增强)
class MosaicAugmentation:
def __init__(self, size=224):
self.size = size
def __call__(self, images):
"""接收4张图像生成马赛克"""
mosaic = Image.new('RGB', (2*self.size, 2*self.size))
# 左上角
mosaic.paste(images[0].resize((self.size, self.size)), (0, 0))
# 右上角
mosaic.paste(images[1].resize((self.size, self.size)), (self.size, 0))
# 左下角
mosaic.paste(images[2].resize((self.size, self.size)), (0, self.size))
# 右下角
mosaic.paste(images[3].resize((self.size, self.size)),
(self.size, self.size))
return mosaic
3. 多线程加载与批处理优化
3.1 并行加载原理
graph TD
A[主进程] --> B[Worker1]
A --> C[Worker2]
A --> D[Worker3]
A --> E[Worker4]
B --> F[数据队列]
C --> F
D --> F
E --> F
F --> G[模型训练]
3.2 性能优化实验
3.2.1 不同num_workers对比
import time
def benchmark_loader(loader):
start = time.time()
for _ in loader: pass
return time.time() - start
workers = [0, 2, 4, 8, 16]
times = []
for w in workers:
loader = DataLoader(dataset, num_workers=w, batch_size=64)
times.append(benchmark_loader(loader))
plt.plot(workers, times)
plt.xlabel('Number of workers')
plt.ylabel('Loading time (s)')
3.2.2 最佳实践总结
- num_workers:设置为可用CPU核心数的2倍
- pin_memory:当使用GPU时设置为True
- prefetch_factor:增大预取批次(默认2)
- persistent_workers:减少进程频繁创建销毁
3.3 内存优化技巧
3.3.1 共享内存策略
# 使用共享内存加速多进程
shared_transform = transforms.Lambda(lambda x: x)
shared_dataset = CustomDataset(transform=shared_transform)
# 在多个Loader间共享
loader1 = DataLoader(shared_dataset, num_workers=4)
loader2 = DataLoader(shared_dataset, num_workers=4)
3.3.2 分页锁定内存
# 自动启用条件
loader = DataLoader(..., pin_memory=True)
# 手动管理
tensor = tensor.pin_memory()
tensor = tensor.to(device, non_blocking=True)
附录:预处理数学原理
标准化计算公式
对每个通道进行归一化: 其中:
- 为通道c的均值
- 为通道c的标准差
随机裁剪算法
���定输入尺寸和目标尺寸:
- 生成随机位置,满足:
- 裁剪区域:
高级技巧:自定义缓存机制
from torch.utils.data import Dataset
class CachedDataset(Dataset):
def __init__(self, base_dataset, cache_size=1024):
self.base = base_dataset
self.cache = [None] * cache_size
self.cache_idx = set()
def __getitem__(self, idx):
if idx in self.cache_idx:
return self.cache[idx % len(self.cache)]
else:
data = self.base[idx]
self.cache[idx % len(self.cache)] = data
self.cache_idx.add(idx)
return data
说明:本文所有代码均基于PyTorch 2.1实现,数据增强示例需配合OpenCV或PIL库使用。建议使用torchdata库实现更复杂的数据流水线。下一章将深入讲解模型训练全流程! 🚀