《pytorch》

208 阅读2分钟

dir() 和 help()

dir()函数可以进入到一个工具箱,并且打开,看看里边都有什么分隔区,都有哪些工具。

help()函数可以查看具体某个工具的使用方法

比如:看看torch里都有哪些工具 dir(torch)

dir(torch.cuda)

dir(torch.cuda.is_available())

这时就是一个可以用的函数了,于是:help(torch.cuda.is_available)查看具体的使用方法。

Dataset类

表示数据集,用于获取每一个数据和对应的标签,以及知道一共有多少数据。

使用

先在控制台看看它是什么,怎么用:

from torch.utils.data import Dataset
help(Dataset)

意思是Dataset是一个类,所有的数据集都是他的子类,所有子类都应该重写__getitem__ 方法,用于获取数据和标签。

获取数据及标签

首先把下载好的数据集放到项目里:

标签就是文件名ants。

在控制台:

from PIL import Image    
先得到图片的路径,由于在win里,要多加一个\转义
img_path="D:\\myPycharm\\hello\\dataset\\train\\ants\\0013035.jpg"
用这个路径得到图片img
img=Image.open(img_path)
就可以使用img的属性了:
img.size
(768, 512)

img.tile
[('jpeg', (0, 0, 768, 512), 0, ('RGB', ''))]
展示出这张图片
img.show()

所以要先拿到每一张图片的路径。怎么拿?先拿到所有图片路径的列表,然后根据索引idx,找到每一张图片的路径。这里需要import os

在控制台演示:

dir_path="dataset\\train\\ants"
import os
img_path_list=os.listdir(dir_path)

可以看到这个list里的每一项就是每张图片的地址。

python文件

具体在python文件里怎么写呢?

要想拿到每一张图片地址,先拿到所有图片所在的目录,然后根据目录地址拿到所有图片的路径列表,然后根据索引idx找到具体每一个图片。

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path=os.listdir(self.path)

    def __getitem__(self,idx):
        img_name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        img=Image.open(img_item_path)
        label=self.label_dir
        return img,label

    def __len__(self):
        return len(self.img_path)

root_dir="dataset\\train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ants_label_dir)
bees_dataset=MyData(root_dir,bees_label_dir)

img,label=ants_dataset[0]
img.show()

train_dataset=ants_dataset+bees_dataset

TensorBoard的使用

1.安装

在终端:

pip install tensorboard

2.

from torch.utils.tensorboard import SummaryWriter

SummaryWriter是一个类,用于像一个目录写入事件文件。用的时候可以传一个目录。如:writer=SummaryWriter("logs")创建一个实例

主要会用到实例的三种方法:

writer.add_image()
writer.add_scalar()
writer.close()

add_scalar() 用于给summary 添加数据

需要三个参数:tag可以认为是标题,第二个可以认为是纵坐标,然后是横坐标

from torch.utils.tensorboard import SummaryWriter

writer=SummaryWriter("logs")

#writer.add_image()
#y=2x
for i in range(100):
    writer.add_scalar("y=2x",2*i,i)
writer.close()

运行后,项目会出现一个logs目录,里边会生成一个文件,怎么看呢?在终端运行:tensorboard --logdir=logs,然后点击链接会跳转到一个网页,

可以看到结果