Pytorch基础之数据集加载(上)

1,042 阅读2分钟
参与拿奖:本文已参与「新人创作礼」活动,一起开启掘金创作之路

1.学习准备

1.1 需要安装的包

opencv-python
torch

1.2 数据准备

import os
# 获取数据指定目录下的所有文件名
path=os.getcwd()
filename=os.listdir(path+"\data\01 data")
# 打开一张图片
from torch.utils.data import Dataset
from PIL import Image #读取图片的库
import torch

img_path=path+"\data\01 data\train\ants\1030023514_aad5c608f9.jpg"
img=Image.open(img_path) #打开图片
img.show() #展示图片

image.png

使用类继承DataSet,定义相关函数

import os
from torch.utils.data import Dataset
from PIL import Image #读取图片的库
import torch



# 使用类进行封装
class myData(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path=os.listdir(self.path)
        
        
    def __getitem__(self,idx):
        img_name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        img=Image.open(img_item_path)
        label=self.label_dir
        return img,label
    
    def __len__(self):
        return len(self.img_path)
        

打印查看结果

if __name__=="__main__":
    import os
    root_dir=os.getcwd()+"\data\01 data\train\" #训练集
    label_dir=['bees','ants']
    path=os.path.join(root_dir,label_dir[0])
    print("bees path ",path)
    path=os.path.join(root_dir,label_dir[1])
    print("ants path ",path)
    img_path=os.listdir(path)
    idx=0
    img_name=img_path[idx]
    print(img_name)
    Image.open(path+"\"+img_path[0]).show() #查看图片
    
    # 使用myData
    data=myData(root_dir,label_dir[0]) #获取数据 
    print(len(data))#打印长度
    img,label=data[0]
    img.show()
    print(label)
    
    # 获取训练集
    d_train_ant=myData(root_dir,label_dir[0])
    d_train_bee=myData(root_dir,label_dir[1])
    
    # 打印长度
    print(len(d_train_ant),len(d_train_bee))

image.png

2.使用Tensorboard进行图像变换

2.1 导入相关包

from torch.utils.tensorboard import SummaryWriter
import os
writer=SummaryWriter("logs")

# writer.add_image()
# 显示函数曲线
for i in range(100):
    writer.add_scalar('y=2x',i*2,i)

writer.close()
执行上述程序后会生成一个logs文件夹。
在终端运行命令:tensorboard --logdir=logs (可以·加--port=5555 用来设置不同端口)。
运行后打开生成的链接即可。

image.png

image.png

注意:执行上述命令时,必须在logs目录文件夹外(等号右边必须为logs目录的存放位置)。

2.2 图片转换

import os
path=os.getcwd()
img_path=path+"\data\01 data\train\ants\1030023514_aad5c608f9.jpg"
img_PIL=Image.open(img_path)
print("img type",type(img_PIL))

# 类型转换
import numpy as np
img_array=np.array(img_PIL) #转换为数组类型

writer=SummaryWriter("logs")

writer.add_image("test",img_array,1,dataformats="HWC")

writer.close()

3.transformers工具

3.1 多种转换操作


from torchvision import transforms

import os
path=os.getcwd()
img_path=path+"\data\01 data\train\ants\1030023514_aad5c608f9.jpg"
img_PIL=Image.open(img_path) #读取图片


writer=SummaryWriter("logs") #定义存放目录

# 转换为tensor
tensor_trans=transforms.ToTensor() #实例化
tensor_img=tensor_trans(img_PIL) #将图片转换为tensor数据类型
writer.add_image("Tensor_img",tensor_img) #添加到tensorboard

# 标准化
trans_norm=transforms.Normalize([0.5 for i in range(3)],[0.5 for i in range(3)])
img_norm=trans_norm(tensor_img)
writer.add_image("Normalize",img_norm)

# 重新定义大小
trans_size=transforms.Resize((512,512))
img_resize=trans_size(tensor_img)
writer.add_image("Resize",img_resize,0)


# Compose
trans_size1=transforms.Resize(512)
trans_compose=transforms.Compose([trans_size1,tensor_trans])
img_resize1=trans_compose(img_PIL)
writer.add_image("Compose",img_resize1,1)

# RandomCrop
trans_random=transforms.RandomCrop(200)
trans_compose1=transforms.Compose([trans_random,tensor_trans])
for i in range(10):
    img_crop=trans_compose1(img_PIL)
    writer.add_image("RandomCrop",img_crop,i)
    

writer.close()


tensor_img.shape,type(tensor_img)

image.png

3.2 使用cv2查看图片

import cv2  #需要安装opencv-python
cv_img=cv2.imread(img_path) #读取图片
cv_img.shape,type(cv_img) 

image.png

3.3 关于python函数中使用__call__关键字

class Person:
    def __call__(self,name):
        print("__call__"+" hello "+name)
    def hello(self,name):
        print("hello "+name)
        
person=Person()
person("QinHsiu") #使用call创建的函数直接调用
person.hello("QinHsiu") #使用定义的函数,需要添加.符号

image.png 可知__call__会在调用类的时候直接运行,而普通函数需要显示调用。

参考资料

[1] b站课程链接

[2] 手敲代码链接