yolo中标签计算及展示锚框位置

93 阅读1分钟

阅读 YOLOv5 的说明文档时,发现有一行关于训练模型的命令:

python train.py --data coco.yaml --epochs 300 --weights '' --cfg yolov5n.yaml --batch-size 128

image.png

于是试着在本机运行了一下。运行过程中会自动下载 COCO 训练数据集,我很好奇数据集中到底标注了什么,就写了下面这些代码,读取标签文件并把标注框画在图片上:

import json
import os

import cv2
import yaml

# Dataset root: <home>/workspace/datasets/coco.
# Built with os.path.join so the separators are right on every OS; hard-coded
# backslashes only produce a valid path on Windows.
user_directory = os.path.expanduser('~')
path = os.path.join(user_directory, "workspace", "datasets", "coco")
# COCO val2017 instance annotations (JSON), read by get_label_info().
instances_path = os.path.join(path, "annotations", "instances_val2017.json")


def get_label_info():
    """Map each category id to its ``(supercategory, name)`` pair.

    Loads the JSON file at the module-level ``instances_path`` and iterates
    the entries stored under its ``categories`` key.

    Returns:
        dict: ``{id: (supercategory, name)}`` with one item per category entry.
    """
    with open(instances_path, 'r') as fh:
        payload = json.load(fh)
    return {
        entry.get('id'): (entry.get('supercategory'), entry.get('name'))
        for entry in payload.get('categories')
    }


def get_label_dict():
    """Load the class-name table from the YOLOv5 dataset config.

    Reads ``./data/coco.yaml`` (relative to the current working directory)
    with ``yaml.safe_load``.

    Returns:
        The value stored under the ``names`` key of the YAML document, or
        ``None`` when that key is absent.
    """
    # Upstream copy: https://github.com/ultralytics/yolov5/blob/master/data/coco.yaml
    file_url = r"./data/coco.yaml"
    with open(file_url, 'r', encoding='utf8') as f:
        data = yaml.safe_load(f)
    return data.get('names')


# Class lookup table, loaded once at import time from ./data/coco.yaml.
# show() calls .get(class_id) on it, so `names` is presumably a mapping of
# class id -> name — TODO confirm against the coco.yaml version in use.
categories_dict = get_label_dict()


def show(img_path, position_list, color=(0, 255, 0), thickness=1):
    """Draw YOLO-format boxes and class labels on an image and display it.

    Args:
        img_path: Path of the image file, loaded with ``cv2.imread``.
        position_list: Iterable of ``(class_id, x, y, w, h)`` tuples, where
            ``x, y`` is the box centre and ``w, h`` the box size, all given
            as fractions of the image width / height.
        color: Colour tuple handed to OpenCV for both the rectangle and the
            text (OpenCV uses BGR channel order).
        thickness: Line thickness used for the rectangle and the text.

    Raises:
        OSError: If OpenCV could not load ``img_path``.

    The call blocks until a key is pressed in the window, then closes all
    OpenCV windows.
    """
    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the path instead of an opaque error on `.shape`.
        raise OSError(f"cv2.imread could not load image: {img_path}")

    # Only height and width are needed; the channel count is ignored.
    img_h, img_w = image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.3

    for class_id, x, y, w, h in position_list:
        # Convert the normalised centre/size box to pixel corner coordinates.
        start_point = (int((x - w / 2) * img_w), int((y - h / 2) * img_h))
        end_point = (int((x + w / 2) * img_w), int((y + h / 2) * img_h))
        cv2.rectangle(image, start_point, end_point, color, thickness)

        # Label with the class name when known, otherwise the numeric id.
        categ = categories_dict.get(class_id)
        text = str(class_id) if categ is None else str(categ)
        # The text is anchored at the box centre.
        org = (int(x * img_w), int(y * img_h))
        cv2.putText(image, text, org, font, font_scale, color, thickness)

    cv2.imshow(os.path.basename(img_path), image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def get_position(file_path):
    """Parse a YOLO label file into a list of box tuples.

    Every non-blank line must contain exactly five whitespace-separated
    fields: ``class_id x y w h``. Blank and whitespace-only lines are
    skipped, and any run of spaces or tabs counts as one separator.

    Args:
        file_path: Path of the ``.txt`` label file.

    Returns:
        list[tuple[int, float, float, float, float]]: One
        ``(class_id, x, y, w, h)`` tuple per label line, in file order.

    Raises:
        ValueError: If a line does not have exactly five fields or a field
            cannot be converted to a number.
    """
    position_list = []
    with open(file_path, "r") as f:
        for line in f:
            # split() with no argument drops leading/trailing whitespace and
            # collapses repeated separators, unlike split(" ").
            fields = line.split()
            if not fields:
                continue
            t, x, y, w, h = fields
            position_list.append((int(t), float(x), float(y), float(w), float(h)))
    return position_list


if __name__ == "__main__":
    # Single-image example:
    # label_path = os.path.join(path, "labels", "val2017", "000000000139.txt")
    # img_path = os.path.join(path, "images", "val2017", "000000000139.jpg")
    # show(img_path, get_position(label_path))
    folder_path = os.path.join(path, "images", "val2017")
    labels_dir = os.path.join(path, "labels", "val2017")
    # Sorted so images are shown in a stable, predictable order.
    for item in sorted(os.listdir(folder_path)):
        img_path = os.path.join(folder_path, item)
        if not os.path.isfile(img_path):
            continue
        # The label file shares the image's stem:
        # images/val2017/000000000139.jpg -> labels/val2017/000000000139.txt
        stem = os.path.splitext(item)[0]
        label_path = os.path.join(labels_dir, f"{stem}.txt")
        # Presumably images without annotations have no label file; show
        # them without boxes rather than crashing on the missing file.
        positions = get_position(label_path) if os.path.isfile(label_path) else []
        show(img_path, positions)