在阅读 yolov5 的说明文档时，发现有一行关于训练模型的命令：
python train.py --data coco.yaml --epochs 300 --weights '' --cfg yolov5n.yaml --batch-size 128
于是试着在本机运行了一下。运行过程中会下载 COCO 训练数据集，我很好奇数据集中到底标注了些什么，于是写了下面这段代码：
import json
import os
import cv2
import yaml
# Root of the local COCO dataset: ~/workspace/datasets/coco.
# os.path.join uses the platform's own separator, so these paths are the
# same as before on Windows and are also valid on POSIX systems (the old
# hard-coded "\" separators only worked on Windows).
user_directory = os.path.expanduser('~')
path = os.path.join(user_directory, "workspace", "datasets", "coco")
# COCO val2017 instance annotations; read by get_label_info().
instances_path = os.path.join(path, "annotations", "instances_val2017.json")
def get_label_info(file_path=None):
    """Load category metadata from a COCO instances annotation file.

    Args:
        file_path: Path to a COCO ``instances_*.json`` file. Defaults to the
            module-level ``instances_path`` (the val2017 annotations).

    Returns:
        dict mapping each category ``id`` to a ``(supercategory, name)``
        tuple. Empty if the file has no ``categories`` entry.
    """
    if file_path is None:
        file_path = instances_path
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's default text encoding (e.g. cp936 on Chinese Windows).
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # "or []" keeps a missing/null "categories" from raising TypeError.
    return {
        categ.get('id'): (categ.get('supercategory'), categ.get('name'))
        for categ in data.get('categories') or []
    }
def get_label_dict(file_url=r"./data/coco.yaml"):
    """Load the class-index -> class-name mapping from a YOLOv5 dataset YAML.

    Args:
        file_url: Path to the dataset YAML. Defaults to ``./data/coco.yaml``
            relative to the current working directory; see
            https://github.com/ultralytics/yolov5/blob/master/data/coco.yaml

    Returns:
        dict mapping class index to class name, suitable for ``.get(index)``.
        The YAML ``names`` entry may be a mapping (``{0: 'person', ...}``) or
        a plain list (``['person', ...]``, the form used by older YOLOv5
        dataset files); a list is converted to ``{0: ..., 1: ...}``.
        Returns an empty dict if the file is empty or has no ``names``.
    """
    with open(file_url, 'r', encoding='utf8') as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty document.
    names = (data or {}).get('names') or {}
    if isinstance(names, (list, tuple)):
        names = dict(enumerate(names))
    return names
# Lookup used by show() to turn a box's class id into a display name.
# Loaded once at import time via get_label_dict(), which reads
# ./data/coco.yaml relative to the current working directory.
# NOTE(review): because this runs on import, importing the module fails if
# that YAML file cannot be opened.
categories_dict = get_label_dict()
def show(img_path, position_list, color=(0, 255, 0), thickness=1):
    """Display an image with its YOLO-format boxes and class labels drawn.

    Args:
        img_path: Path to the image file.
        position_list: Iterable of ``(class_id, x_center, y_center, width,
            height)`` tuples whose coordinates are fractions of the image
            size, as returned by ``get_position``.
        color: OpenCV (BGR) colour for boxes and label text.
        thickness: Line thickness for boxes and label text.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.

    Opens a window titled with the image's file name, blocks until a key is
    pressed, then closes all OpenCV windows.
    """
    image = cv2.imread(img_path)
    # cv2.imread returns None rather than raising when the file is missing or
    # cannot be decoded; fail clearly instead of on ``image.shape``.
    if image is None:
        raise FileNotFoundError(f"could not read image: {img_path}")
    img_h, img_w = image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.3
    for class_id, x, y, w, h in position_list:
        # Convert the normalized centre/size box into pixel corner points.
        top_left = (int((x - w / 2) * img_w), int((y - h / 2) * img_h))
        bottom_right = (int((x + w / 2) * img_w), int((y + h / 2) * img_h))
        cv2.rectangle(image, top_left, bottom_right, color, thickness)
        # Fall back to the numeric id when the class has no known name.
        name = categories_dict.get(class_id)
        text = str(class_id) if name is None else str(name)
        # The label is anchored at the box centre.
        org = (int(x * img_w), int(y * img_h))
        cv2.putText(image, text, org, font, font_scale, color, thickness)
    cv2.imshow(os.path.basename(img_path), image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def get_position(file_path):
    """Parse a YOLO detection label file.

    Each non-blank line holds ``class_id x_center y_center width height``
    separated by whitespace.

    Args:
        file_path: Path to the ``.txt`` label file.

    Returns:
        list of ``(int, float, float, float, float)`` tuples, one per
        non-blank line, in file order.

    Raises:
        ValueError: if a non-blank line does not have exactly five fields or
            a field is not numeric.
    """
    position_list = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            # split() with no argument ignores repeated spaces, tabs and
            # leading/trailing whitespace; split(" ") produced empty fields
            # for those and broke the five-way unpacking below.
            fields = line.split()
            if not fields:
                continue
            t, x, y, w, h = fields
            position_list.append((int(t), float(x), float(y), float(w), float(h)))
    return position_list
if __name__ == "__main__":
    # Walk the val2017 images and show each one with its YOLO labels.
    # To inspect a single image instead, call e.g.:
    #   show(os.path.join(path, "images", "val2017", "000000000139.jpg"),
    #        get_position(os.path.join(path, "labels", "val2017", "000000000139.txt")))
    folder_path = os.path.join(path, "images", "val2017")
    label_folder = os.path.join(path, "labels", "val2017")
    # os.listdir order is arbitrary; sort so images appear in a stable order.
    for file_name in sorted(os.listdir(folder_path)):
        img_path = os.path.join(folder_path, file_name)
        if not os.path.isfile(img_path):
            continue
        # The label file has the image's base name with a .txt extension and
        # lives in its own directory, so it must be joined with a separator.
        label_path = os.path.join(label_folder, os.path.splitext(file_name)[0] + ".txt")
        # Some images may have no label file (presumably those without any
        # annotated objects); show them without boxes instead of aborting.
        positions = get_position(label_path) if os.path.isfile(label_path) else []
        show(img_path, positions)