如何在PyTorch中使用自定义图像数据集(附代码)

281 阅读7分钟

把这个项目带入生活

许多初学者在尝试使用PyTorch的自定义数据集时可能会遇到一些困难。在之前探讨了如何 策划一个自定义的图像数据集(通过网络刮取)后,本文将作为一个指南,介绍如何加载和标记自定义数据集以用于PyTorch。

创建一个自定义数据集

本节借用了关于策划数据集的文章中的代码。这里的目标是为一个模型策划一个自定义数据集,该模型将区分男士运动鞋/训练鞋和男士靴子。

为了简洁起见,我将不详细介绍代码的作用,而是提供一个快速的总结,因为我相信你一定读过前一篇文章。如果你没有,不用担心:这里又有链接。你也可以简单地运行这些代码块,你就可以为下一节做好准备了:

#  article dependencies
import cv2
import numpy as np
import os
import requests
from bs4 import BeautifulSoup
from urllib.request import urlopen
from urllib.request import Request
import time
from torch.utils.data import Dataset
import torch
from torchvision import transforms
from tqdm import tqdm

WebScraper类

下面的类包含了一些方法,这些方法将帮助我们通过使用beautifulsoup库来解析html,使用感兴趣的标签和属性来提取图片的src链接,最后从网页上下载/抓取感兴趣的图片来策划一个自定义的数据集。这些方法被相应地命名:

class WebScraper():
    def __init__(self, headers, tag: str, attribute: dict,
                src_attribute: str, filepath: str, count=0):
      self.headers = headers
      self.tag = tag
      self.attribute = attribute
      self.src_attribute = src_attribute
      self.filepath = filepath
      self.count = count
      self.bs = []
      self.interest = []

    def __str__(self):
      display = f"""      CLASS ATTRIBUTES
      headers: headers used so as to mimic requests coming from web browsers.
      tag: html tags intended for scraping.
      attribute: attributes of the html tags of interest.
      filepath: path ending with filenames to use when scraping images.
      count: numerical suffix to differentiate files in the same folder.
      bs: a list of each page's beautifulsoup elements.
      interest: a list of each page's image links."""
      return display

    def __repr__(self):
      display = f"""      CLASS ATTRIBUTES
      headers: {self.headers}
      tag: {self.tag}
      attribute: {self.attribute}
      filepath: {self.filepath}
      count: {self.count}
      bs: {self.bs}
      interest: {self.interest}"""
      return display

    def parse_html(self, url):
      """
      This method requests the webpage from the server and
      returns a beautifulsoup element
      """
      try:
        request = Request(url, headers=self.headers)
        html = urlopen(request)
        bs = BeautifulSoup(html.read(), 'html.parser')
        self.bs.append(bs)
      except Exception as e:
        print(f'problem with webpage\n{e}')
      pass

    def extract_src(self):
      """
      This method extracts tags of interest from the webpage's
      html
      """
      #  extracting tag of interest
      interest = self.bs[-1].find_all(self.tag, attrs=self.attribute)
      interest = [listing[self.src_attribute] for listing in interest]
      self.interest.append(interest)
      pass
    
    def scrape_images(self):
      """
      This method grabs images located in the src links and
      saves them as required
      """
      for link in tqdm(self.interest[-1]):
        try:
          with open(f'{self.filepath}_{self.count}.jpg', 'wb') as f:
            response = requests.get(link)
            image = response.content
            f.write(image)
            self.count+=1
            time.sleep(0.4)
        except Exception as e:
          print(f'problem with image\n{e}')
          time.sleep(0.4)
      pass

刮削功能

为了使用我们的网络搜刮器迭代多个页面,我们需要将其封装在一个函数中,使其能够做到这一点。下面的函数就是为此而写的,因为它包含了格式化为f-string的感兴趣的url,这将使url中包含的页面引用被迭代。

def my_scraper(scraper, page_range: list):
    """
    This function wraps around the web scraper class allowing it to scrape
    multiple pages. The argument page_range takes both a list of two elements
    to define a range of pages or a list of one element to define a single page.
    """
    if len(page_range) > 1:
      for i in range(page_range[0], page_range[1] + 1):
        scraper.parse_html(url=f'https://www.jumia.com.ng/mlp-fashion-deals/mens-athletic-shoes/?page={i}#catalog-listing')
        scraper.extract_src()
        scraper.scrape_images()
        print(f'\npage {i} done.')
      print('All Done!')
    else:
      scraper.parse_html(url=f'https://www.jumia.com.ng/mlp-fashion-deals/mens-athletic-shoes/?page={page_range[0]}#catalog-listing')
      scraper.extract_src()
      scraper.scrape_images()
      print('\nAll Done!')
    pass

创建目录

由于我们的目标是策划一个男鞋的数据集,所以我们需要创建目录来达到这个效果。为了整洁起见,我们在根目录下创建一个名为 "鞋子 "的父目录,这个目录包含两个子目录,分别名为 "运动鞋 "和 "靴子",它们将存放相应的图片。

#  create directories to hold images
os.mkdir('shoes')
os.mkdir('shoes/athletic')
os.mkdir('shoes/boots')

搜集图片

把这个项目带入生活

首先,我们需要为我们的网络搜刮器定义一个适当的头。头部有助于掩盖搜刮器,因为它模拟了来自实际网络浏览器的请求。之后,我们可以使用我们定义的头、我们想要提取图片的标签(img)、感兴趣的标签属性(class: img)、保存图片链接的属性(data-src)、以文件名结尾的感兴趣的文件路径,以及文件名中包含的计数前缀的起点,来实例化一个运动鞋图片的搜刮器。然后,我们可以将运动型的搜刮器传递给my_scraper函数,因为它已经包含了与运动鞋有关的URL。

headers = {'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.11 (KHTML, like Gecko) Chrome/23.0.1271.64 Safari/537.11',
          'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
          'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.3',
          'Accept-Encoding': 'none',
          'Accept-Language': 'en-US,en;q=0.8',
          'Connection': 'keep-alive'}
#  scrape athletic shoe images
athletic_scraper = WebScraper(headers=headers, tag='img', attribute = {'class':'img'},
                              src_attribute='data-src', filepath='shoes/athletic/atl', count=0)
                        
my_scraper(scraper=athletic_scraper, page_range=[1, 3])

为了搜刮靴子的图片,复制下面注释中的两个urls并替换my_scraper函数中的当前urls。靴子搜刮器的实例化方式与运动鞋搜刮器相同,并提供给my_scraper函数,以便搜刮靴子图片:

#  replace the urls in the my scraper function with the urls below
#  first url:
#  f'https://www.jumia.com.ng/mlp-fashion-deals/mens-boots/?page={i}#catalog-listing'
#  second url:
#  f'https://www.jumia.com.ng/mlp-fashion-deals/mens-boots/?page={page_range[0]}#catalog-listing'
#  rerun my_scraper function code cell

#  scrape boot images
boot_scraper = WebScraper(headers=headers, tag='img', attribute = {'class':'img'},
                          src_attribute='data-src', filepath='shoes/boots/boot', count=0)
                        
my_scraper(scraper=boot_scraper, page_range=[1, 3])

当所有这些代码单元按顺序运行后,在当前工作目录中应该创建一个名为 "shoes "的父目录。这个父目录应该包含两个子目录,分别名为 "运动员 "和 "靴子",它们将保存属于这两个类别的图像。

加载和标记图像

现在我们已经有了我们的自定义数据集,我们需要对其组成的图像产生数组表示(加载),标记数组,然后将其转换为张量,以便在PyTorch中使用。归档这将需要我们定义一个类来完成所有这些过程。下面定义的类完成了前两个步骤,它读取灰度图像,将其调整为100×100像素,然后按需要进行标注(运动鞋=[1, 0] ,靴子=[0, 1] )。*注意:从我的角度来看,我的工作目录是根目录,所以我在下面的Python类中相应地定义了文件路径,你应该根据自己的工作目录来定义文件路径:

#  defining class to load and label data
class LoadShoeData():
    """
    This class loads in data from each directory in numpy array format then saves
    loaded dataset
    """
    def __init__(self):
        self.athletic = 'shoes/athletic'
        self.boots = 'shoes/boots'
        self.labels = {self.athletic: np.eye(2, 2)[0], self.boots: np.eye(2, 2)[1]}
        self.img_size = 100
        self.dataset = []
        self.athletic_count = 0
        self.boots_count = 0

    def create_dataset(self):
        """
        This method reads images as grayscale from directories,
        resizes them and labels them as required.
        """

        #  reading from directory
        for key in self.labels:
          print(key)

          #  looping through all files in the directory
          for img_file in tqdm(os.listdir(key)):
            try:
              #  deriving image path
              path = os.path.join(key, img_file)

              #  reading image
              image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
              image = cv2.resize(image, (self.img_size, self.img_size))

              #  appending image and class label to list
              self.dataset.append([image, self.labels[key]])

              #  incrementing counter
              if key == self.athletic:
                self.athletic_count+=1
              elif key == self.boots:
                self.boots_count+=1

            except Exception as e:
              pass

        #  shuffling array of images
        np.random.shuffle(self.dataset)

        #  printing to screen
        print(f'\nathletic shoe images: {self.athletic_count}')
        print(f'boot images: {self.boots_count}')
        print(f'total: {self.athletic_count + self.boots_count}')
        print('All done!')
        return np.array(self.dataset, dtype='object')
#  load data
data = LoadShoeData()

dataset = data.create_dataset()

运行上面的代码单元应该返回一个包含自定义数据集中所有图像的NumPy数组。该数组的每个元素都是一个自己的数组,其中包含一个图像和它的标签。

创建一个PyTorch数据集

在生成了自定义数据集中所有图像和标签的数组表示后,现在是时候创建一个PyTorch数据集了。要做到这一点,我们需要定义一个继承自PyTorch数据集类的类,如下图所示:

#  extending Dataset class
class ShoeDataset(Dataset):
    def __init__(self, custom_dataset, transforms=None):
        self.custom_dataset = custom_dataset
        self.transforms = transforms

    def __len__(self):
        return len(self.custom_dataset)
    
    def __getitem__(self, idx):
        #  extracting image from index and scaling
        image = self.custom_dataset[idx][0]
        #  extracting label from index
        label = torch.tensor(self.custom_dataset[idx][1])
        #  applying transforms if transforms are supplied
        if self.transforms:
          image = self.transforms(image)
        return (image, label)

基本上,定义了两个重要的方法__len__()__getitem__()__len__() 方法返回自定义数据集的长度,而__getitem__() 方法通过索引从自定义数据集中抓取图像及其标签,如果有的话,应用转换,并返回一个元组,然后可以被PyTorch使用:

#  creating an instance of the dataset class
dataset = ShoeDataset(dataset, transforms=transforms.ToTensor())

当上面的代码单元运行时,数据集对象就变成了PyTorch的数据集,现在可以用于构建深度学习模型。

结束语

在这篇文章中,我们看了一下在PyTorch中使用自定义数据集的情况。我们通过网络搜刮策划了一个自定义数据集,对其进行了加载和标记,并从中创建了一个PyTorch数据集。

在这篇文章中,Python类的知识被发挥得淋漓尽致。大多数定义为类的过程都可以用普通函数来完成(除了PyTorch数据集类),但作为个人偏好,我选择了我所做的。在你的编程旅程中,你可以自由地尝试复制这段代码,做最适合你的事情。