Image Classification Dataset | Linear Neural Networks | Dive into Deep Learning


1. Does reducing the batch_size (e.g., down to 1) affect read performance?

import math
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

# Fashion-MNIST training set, read as tensors
mnist_train = torchvision.datasets.FashionMNIST(
    root='../data', train=True, transform=transforms.ToTensor(), download=True)

# Batch sizes 1..2048 and worker counts 1..12, both as powers of two
batch_sizes = [2 ** x for x in range(int(math.log2(2048)) + 1)]
dataloader_workers = [min(12, 2 ** x) for x in range(int(math.log2(12)) + 2)]

# Time one full pass over the training set for each (workers, batch size) pair
for w in dataloader_workers[::-1]:
    for b in batch_sizes[::-1]:
        train_iter = data.DataLoader(mnist_train, batch_size=b, shuffle=True, num_workers=w)
        timer = d2l.Timer()  # d2l.Timer starts timing on construction
        for X, y in train_iter:
            continue
        print(f'{w} workers, batch size {b}: {timer.stop():.4f} sec')
12 workers, batch size 2048: 0.5365 sec
12 workers, batch size 1024: 0.4758 sec
12 workers, batch size 512: 0.4671 sec
12 workers, batch size 256: 0.4725 sec
12 workers, batch size 128: 0.5044 sec
12 workers, batch size 64: 0.5802 sec
12 workers, batch size 32: 0.8488 sec
12 workers, batch size 16: 1.4198 sec
12 workers, batch size 8: 2.5134 sec
12 workers, batch size 4: 4.8664 sec
12 workers, batch size 2: 9.6087 sec
12 workers, batch size 1: 19.0353 sec
8 workers, batch size 2048: 0.4897 sec
8 workers, batch size 1024: 0.4825 sec
8 workers, batch size 512: 0.4926 sec
8 workers, batch size 256: 0.5208 sec
8 workers, batch size 128: 0.5546 sec
8 workers, batch size 64: 0.5693 sec
8 workers, batch size 32: 0.9099 sec
8 workers, batch size 16: 1.3812 sec
8 workers, batch size 8: 2.5243 sec
8 workers, batch size 4: 4.8211 sec
8 workers, batch size 2: 9.5129 sec
8 workers, batch size 1: 18.8237 sec
4 workers, batch size 2048: 0.4750 sec
4 workers, batch size 1024: 0.4612 sec
4 workers, batch size 512: 0.4992 sec
4 workers, batch size 256: 0.5076 sec
4 workers, batch size 128: 0.6466 sec
4 workers, batch size 64: 0.7256 sec
4 workers, batch size 32: 0.9570 sec
4 workers, batch size 16: 1.3206 sec
4 workers, batch size 8: 2.4388 sec
4 workers, batch size 4: 4.5724 sec
4 workers, batch size 2: 9.0541 sec
4 workers, batch size 1: 17.9737 sec
2 workers, batch size 2048: 0.7912 sec
2 workers, batch size 1024: 0.7947 sec
2 workers, batch size 512: 0.8041 sec
2 workers, batch size 256: 0.8216 sec
2 workers, batch size 128: 1.0138 sec
2 workers, batch size 64: 1.1117 sec
2 workers, batch size 32: 1.2738 sec
2 workers, batch size 16: 1.7465 sec
2 workers, batch size 8: 2.7226 sec
2 workers, batch size 4: 4.9756 sec
2 workers, batch size 2: 8.6975 sec
2 workers, batch size 1: 17.0822 sec
1 workers, batch size 2048: 1.4968 sec
1 workers, batch size 1024: 1.4890 sec
1 workers, batch size 512: 1.4888 sec
1 workers, batch size 256: 1.5537 sec
1 workers, batch size 128: 1.6774 sec
1 workers, batch size 64: 2.0337 sec
1 workers, batch size 32: 2.3070 sec
1 workers, batch size 16: 3.1514 sec
1 workers, batch size 8: 5.0449 sec
1 workers, batch size 4: 7.6472 sec
1 workers, batch size 2: 13.5933 sec
1 workers, batch size 1: 26.5768 sec
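So yes, clearly. With the worker count fixed, read time grows steeply as the batch size shrinks: at batch_size = 1 a full pass takes 17 to 27 sec, versus roughly 0.5 sec at large batch sizes, i.e. around 40 times slower. This is presumably because the per-batch overhead (scheduling, inter-process communication, collation) is paid 60,000 times instead of a few dozen.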

2. The performance of the data iterator is very important. Is the current implementation fast enough? Explore various options for improving it.

From the experiment above, num_workers = 4 with batch_size = 1024 gives the best read time (0.4612 sec for a full pass), so on this machine the implementation is already reasonably fast.
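Beyond tuning num_workers and batch_size, a few standard torch.utils.data.DataLoader arguments can squeeze out more throughput. A minimal sketch, assuming training happens on a GPU (pin_memory brings no benefit on CPU-only runs; the 1024/4 values are taken from the experiment above):

train_iter = data.DataLoader(
    mnist_train, batch_size=1024, shuffle=True, num_workers=4,
    pin_memory=True,          # page-locked host memory speeds up CPU-to-GPU copies
    persistent_workers=True,  # keep workers alive across epochs (saves respawn cost)
    prefetch_factor=2)        # batches each worker loads ahead of time (the default)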

3. Consult the framework's online API documentation. Which other datasets are available?

torchvision.datasets

All the datasets have an almost identical API. They all take two common arguments: transform and target_transform, which transform the input and the target respectively. You can also create your own datasets using the provided base classes. (A usage sketch follows the list below.)

Image classification

Caltech101(root[, target_type, transform, ...]): Caltech 101 Dataset.
Caltech256(root[, transform, ...]): Caltech 256 Dataset.
CelebA(root[, split, target_type, ...]): Large-scale CelebFaces Attributes (CelebA) Dataset.
CIFAR10(root[, train, transform, ...]): CIFAR10 Dataset.
CIFAR100(root[, train, transform, ...]): CIFAR100 Dataset.
Country211(root[, split, transform, ...]): The Country211 Data Set from OpenAI.
DTD(root[, split, partition, transform, ...]): Describable Textures Dataset (DTD).
EMNIST(root, split, **kwargs): EMNIST Dataset.
EuroSAT(root[, transform, target_transform, ...]): RGB version of the EuroSAT Dataset.
FakeData([size, image_size, num_classes, ...]): A fake dataset that returns randomly generated PIL images.
FashionMNIST(root[, train, transform, ...]): Fashion-MNIST Dataset.

...
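Since all of these datasets share the same API, any entry above can replace Fashion-MNIST in the pipeline from question 1. A minimal sketch with CIFAR10 (the root path and ToTensor transform are just illustrative choices):

cifar_train = torchvision.datasets.CIFAR10(
    root='../data', train=True, download=True,
    transform=transforms.ToTensor(),  # applied to each input image
    target_transform=None)            # optional hook to remap the integer label
X, y = cifar_train[0]
print(X.shape, y)  # torch.Size([3, 32, 32]) and an integer class label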