持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第2天,点击查看活动详情
1 Dataset类
Dataset 抽象类,所有数据集需要继承这个类。 提供一种方式去获取数据及其label、数据集长度。
子类需要重写__getitem__(idx)方法
__getitem__(idx)返回索引idx对应的img和label__len()__返回数据集长度
2 Dataloader类
Dataloader类 对数据进行打包,为后面的网络提供不同的数据形式
常用参数: dataset: Dataset类, 决定数据从哪读取以及如何读取 bathsize: 批大小 num_works: 是否多进程读取机制 shuffle: 每个epoch是否乱序 drop_last: 当样本数不能被batchsize整除时, 是否舍弃最后一批数据
3 TensorBoards使用
安装:pip install tensorboard
from torch.utils.tensorboard import SummaryWriter
SummaryWriter类: 向指定文件夹写入事件文件,该文件可被tensorboard解析
writer=SummaryWriter("logs") #向logs文件夹写入事件文件
writer.add_images() #add_image的输入格式必须为numpy.array 或者其他,所以用numpy读取,image_array.shape=(512, 768, 3),因此设置dataformats='HWC'
writer.add_scalar() #添加名称、数据、step
witer.close() #关闭
add_images(self, tag, img_tensor, global_step=None)方法:通常用来观测训练结果add_scalar()方法:通常用来绘制loss
查看: 运行程序之后,在控制台输入tensorboard --logdir=logs --port=6067
4 Transforms使用
对图片进行变换 常用类
- ToTensor类型变换
- Normalize归一化
- Resize 改变尺寸
- 输入:PIL Image
- 输出:PIL Image
- Compose
- RandomCrop() 随机裁剪