pytorch加载数据:Dataset类 & Dataloader类

217 阅读1分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第2天,点击查看活动详情

1 Dataset类

Dataset 抽象类,所有数据集需要继承这个类。 提供一种方式去获取数据及其label、数据集长度。

子类需要重写__getitem__(idx)方法

  • __getitem__(idx) 返回索引idx对应的img和label
  • __len()__ 返回数据集长度

2 Dataloader类

Dataloader类 对数据进行打包,为后面的网络提供不同的数据形式

常用参数: dataset: Dataset类, 决定数据从哪读取以及如何读取 bathsize: 批大小 num_works: 是否多进程读取机制 shuffle: 每个epoch是否乱序 drop_last: 当样本数不能被batchsize整除时, 是否舍弃最后一批数据

3 TensorBoards使用

安装:pip install tensorboard

from torch.utils.tensorboard import SummaryWriter

SummaryWriter类: 向指定文件夹写入事件文件,该文件可被tensorboard解析

writer=SummaryWriter("logs") #向logs文件夹写入事件文件
writer.add_images() #add_image的输入格式必须为numpy.array 或者其他,所以用numpy读取,image_array.shape=(512, 768, 3),因此设置dataformats='HWC'
writer.add_scalar() #添加名称、数据、step
witer.close() #关闭
  • add_images(self, tag, img_tensor, global_step=None)方法:通常用来观测训练结果
  • add_scalar()方法:通常用来绘制loss

查看: 运行程序之后,在控制台输入tensorboard --logdir=logs --port=6067

4 Transforms使用

对图片进行变换 常用类

  • ToTensor类型变换
  • Normalize归一化
  • Resize 改变尺寸
    • 输入:PIL Image
    • 输出:PIL Image
  • Compose
  • RandomCrop() 随机裁剪