Pytorch入门进行迁移学习实现自行车分类识别:获取数据集与准备数据

121 阅读3分钟

前言

迁移学习是一种机器学习方法,它利用已经训练好的模型在新任务上进行训练,从而提高模型的性能和泛化能力。在本文中,我们将使用PyTorch实现一个基于预训练模型的迁移学习模型,用于单车分类识别。

项目概述

我们的目标是创建一个能够识别不同类型自行车的图像分类模型。为实现这一目标,我们首先需要获取一个包含大量自行车图片的数据集。由于公开可用的数据集可能不完全满足特定需求,我们决定使用爬虫技术从互联网上抓取自行车图片。

数据集爬取

使用Python的requests库构建网络爬虫,从网页中提取图片。

def get_images_from_baidu(keyword, page_num, save_dir):

    header = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/78.0.3904.108 Safari/537.36'}

    url = 'https://image.baidu.com/search/acjson?'

    for pn in range(0, 30 * page_num, 30):

        param = {'tn': 'resultjson_com',

                 'logid': '7603311155072595725',

                 'ipn': 'rj',

                 'ct': 201326592,

                 'is': '',

                 'fp': 'result',

                 'queryWord': keyword,

                 'cl': 2,

                 'lm': -1,

                 'ie': 'utf-8',

                 'oe': 'utf-8',

                 'adpicid': '',

                 'st': -1,

                 'z': '',

                 'ic': '',

                 'hd': '',

                 'latest': '',

                 'copyright': '',

                 'word': keyword,

                 's': '',

                 'se': '',

                 'tab': '',

                 'width': '',

                 'height': '',

                 'face': 0,

                 'istype': 2,

                 'qc': '',

                 'nc': '1',

                 'fr': '',

                 'expermode': '',

                 'force': '',

                 'cg': '',    # 这个参数没公开,但是不可少

                 'pn': pn,    # 显示:30-60-90

                 'rn': '30',  # 每页显示 30 条

                 'gsm': '1e',

                 '1618827096642': ''

        request = requests.get(url=url, headers=header, params=param)

        if request.status_code == 200:

            print('Request success.')

        request.encoding = 'utf-8'

        html = request.text

        image_url_list = re.findall('"thumbURL":"(.*?)",', html, re.S)

        if not os.path.exists(save_dir):

            os.makedirs(save_dir)

        for image_url in image_url_list:

            image_data = requests.get(url=image_url, headers=header).content

            with open(os.path.join(save_dir, f'{n:06}.jpg'), 'wb') as fp:

                fp.write(image_data)

我们将单车分成5类:hello(哈罗单车), meituan(美团单车), qingju(青桔单车),zijiadanche(自用单车),qita(其它单车)

5.jpg 9.jpg 78.jpg 45.jpg 6.jpg

数据清理

获取到图片后,需要对数据进行清洗,去除不相关或质量不高的图片进行过滤。

对数据进行标注

针对于不同的类别,标注数据集。

train_name_file = open("data/biycle/train.txt", "w")

test_name_file = open("data/biycle/test.txt", "w")

train_label_file = open("data/biycle/train_label.txt", "w")

test_label_file = open("data/biycle/test_label.txt", "w")

    path = root_path + name 

    file_names = get_filenames_in_folder(path)

    print(len(file_names))

    for path in file_names:

            test_name_file.write(path + '\n')    

            test_label_file.write(str(i) + '\n')

            train_name_file.write(path + '\n')

            train_label_file.write(str(i) + '\n')

编写数据加载模块

class CustomImageDataset(Dataset):

    def __init__(self, data_path, model, transform=None, target_transform=None):

        self.data_path = data_path

        self.model = model

        self.img_labels = []

        self.image_lists =[]

        self.transform = transform

        self.target_transform = target_transform

        self.obtain_label_image()

        return len(self.img_labels)

    def __getitem__(self, idx):

        #print(self.image_lists[idx])

        image = cv2.imread(self.image_lists[idx])

        image =cv2.resize(image, (32,32))

        label = self.img_labels[idx]

        if self.transform:

            image = self.transform(image)

        if self.target_transform:

            label = self.target_transform(label)

        return image, label

    def obtain_label_image(self):

        if(self.model == "train"):

            folder_path = self.data_path + 'train.txt'

            with open(folder_path, 'r') as file:

                # 逐行读取文件内容

                for line in file:

                    self.image_lists.append(line.strip())

            file_path = self.data_path + 'train_label.txt'  # 替换为实际文件路径

            with open(file_path, 'r') as file:

                # 逐行读取文件内容

                for line in file:

                    # 处理每一行的数据,例如打印或存储

                    self.img_labels.append(int(line.strip()))  # 使用strip()方法去除行末的换行符

        if (self.model == "test"):

            folder_path = self.data_path + 'test.txt'

            with open(folder_path, 'r') as file:

                # 逐行读取文件内容

                for line in file:

                    self.image_lists.append(line.strip())

            file_path = self.data_path + 'test_label.txt'  # 替换为实际文件路径

            with open(file_path, 'r') as file:

                # 逐行读取文件内容

                for line in file:

                    # 处理每一行的数据,例如打印或存储

                    self.img_labels.append(int(line.strip()))  # 使用strip()方法去除行末的换行符

总结

当我们没有数据集的时候,使用爬虫技术获取数据集,我们能快速获取自己的数据集,构建自己的数据加载模块,使得我们能够胜任不同类型的数据加载。

关注我的公众号auto_driver_ai(Ai fighting), 第一时间获取更新内容。

本文使用 文章同步助手 同步