数据加载与预处理

306 阅读2分钟

数据加载与预处理

1. Dataset与DataLoader使用

1.1 数据处理流程架构

graph LR
    A[原始数据] --> B(Dataset类)
    B --> C{DataLoader}
    C --> D[批量数据]
    style A fill:#9f9,stroke:#333
    style D fill:#f99,stroke:#333

1.2 内置Dataset使用示例

from torchvision import datasets
from torch.utils.data import DataLoader

# MNIST数据集加载
train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

test_data = datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor()
)

# DataLoader配置
train_loader = DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

1.3 DataLoader核心参数解析

参数类型说明推荐值
batch_sizeint单批数据量32-256
shufflebool是否打乱数据True(训练集)
num_workersint加载线程数CPU核心数×2
pin_memorybool锁页内存加速GPU训练时启用
drop_lastbool丢弃不足批次大数据集启用

2. 自定义数据集与数据增强

2.1 自定义Dataset类模板

from torch.utils.data import Dataset
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, img_dir, label_file, transform=None):
        self.img_dir = Path(img_dir)
        self.labels = pd.read_csv(label_file)
        self.transform = transform

    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        img_path = self.img_dir / self.labels.iloc[idx, 0]
        image = Image.open(img_path).convert('RGB')
        label = self.labels.iloc[idx, 1]
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

2.2 数据增强策略

2.2.1 常用图像变换组合
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
2.2.2 自定义变换示例(马赛克增强)
class MosaicAugmentation:
    def __init__(self, size=224):
        self.size = size
        
    def __call__(self, images):
        """接收4张图像生成马赛克"""
        mosaic = Image.new('RGB', (2*self.size, 2*self.size))
        # 左上角
        mosaic.paste(images[0].resize((self.size, self.size)), (0, 0))
        # 右上角
        mosaic.paste(images[1].resize((self.size, self.size)), (self.size, 0))
        # 左下角
        mosaic.paste(images[2].resize((self.size, self.size)), (0, self.size))
        # 右下角
        mosaic.paste(images[3].resize((self.size, self.size)), 
                    (self.size, self.size))
        return mosaic

3. 多线程加载与批处理优化

3.1 并行加载原理

graph TD
    A[主进程] --> B[Worker1]
    A --> C[Worker2]
    A --> D[Worker3]
    A --> E[Worker4]
    B --> F[数据队列]
    C --> F
    D --> F
    E --> F
    F --> G[模型训练]

3.2 性能优化实验

3.2.1 不同num_workers对比
import time

def benchmark_loader(loader):
    start = time.time()
    for _ in loader: pass
    return time.time() - start

workers = [0, 2, 4, 8, 16]
times = []
for w in workers:
    loader = DataLoader(dataset, num_workers=w, batch_size=64)
    times.append(benchmark_loader(loader))

plt.plot(workers, times)
plt.xlabel('Number of workers')
plt.ylabel('Loading time (s)')
3.2.2 最佳实践总结
  • num_workers:设置为可用CPU核心数的2倍
  • pin_memory:当使用GPU时设置为True
  • prefetch_factor:增大预取批次(默认2)
  • persistent_workers:减少进程频繁创建销毁

3.3 内存优化技巧

3.3.1 共享内存策略
# 使用共享内存加速多进程
shared_transform = transforms.Lambda(lambda x: x)
shared_dataset = CustomDataset(transform=shared_transform)

# 在多个Loader间共享
loader1 = DataLoader(shared_dataset, num_workers=4)
loader2 = DataLoader(shared_dataset, num_workers=4)
3.3.2 分页锁定内存
# 自动启用条件
loader = DataLoader(..., pin_memory=True)

# 手动管理
tensor = tensor.pin_memory()
tensor = tensor.to(device, non_blocking=True)

附录:预处理数学原理

标准化计算公式

对每个通道进行归一化: xnormalized(c)=x(c)μ(c)σ(c)x_{\text{normalized}}^{(c)} = \frac{x^{(c)} - \mu^{(c)}}{\sigma^{(c)}} 其中:

  • μ(c)\mu^{(c)} 为通道c的均值
  • σ(c)\sigma^{(c)} 为通道c的标准差

随机裁剪算法

���定输入尺寸(H,W)(H, W)和目标尺寸(h,w)(h, w)

  1. 生成随机位置(i,j)(i, j),满足: 0iHh0 \leq i \leq H - h 0jWw0 \leq j \leq W - w
  2. 裁剪区域: x=input[:,i:i+h,j:j+w]x = \text{input}[:, i:i+h, j:j+w]

高级技巧:自定义缓存机制

from torch.utils.data import Dataset

class CachedDataset(Dataset):
    def __init__(self, base_dataset, cache_size=1024):
        self.base = base_dataset
        self.cache = [None] * cache_size
        self.cache_idx = set()

    def __getitem__(self, idx):
        if idx in self.cache_idx:
            return self.cache[idx % len(self.cache)]
        else:
            data = self.base[idx]
            self.cache[idx % len(self.cache)] = data
            self.cache_idx.add(idx)
            return data

说明:本文所有代码均基于PyTorch 2.1实现,数据增强示例需配合OpenCV或PIL库使用。建议使用torchdata库实现更复杂的数据流水线。下一章将深入讲解模型训练全流程! 🚀