Generating the train/validation/test txt files
- voc_annotation.py
import os
import random
import xml.etree.ElementTree as ET

from utils.utils import get_classes  # repo helper that reads one class name per line

#--------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode specifies what this script computes when it runs
#   annotation_mode = 0: the whole labelling pipeline, i.e. the txt files in VOCdevkit/VOC2007/ImageSets
#                        plus the 2007_train.txt and 2007_val.txt used for training
#   annotation_mode = 1: only the txt files in VOCdevkit/VOC2007/ImageSets
#   annotation_mode = 2: only the 2007_train.txt and 2007_val.txt used for training
#--------------------------------------------------------------------------------------------------------------------------------#
# The mode selects which of the above files get generated
annotation_mode = 0
#-------------------------------------------------------------------#
#   Must be modified: supplies the class information written into 2007_train.txt and 2007_val.txt
#   Keep it identical to the classes_path used for training and prediction
#   If the generated 2007_train.txt contains no object information,
#   the classes were not set correctly
#   Only takes effect when annotation_mode is 0 or 2
#-------------------------------------------------------------------#
# Path to the class-names file (one class name per line, e.g. the 20 VOC categories)
classes_path = 'model_data/voc_classes.txt'
#--------------------------------------------------------------------------------------------------------------------------------#
#   trainval_percent sets the ratio of (train set + val set) to test set; by default (train+val):test = 9:1
#   train_percent sets the ratio of train set to val set within (train set + val set); by default train:val = 9:1
#   Only takes effect when annotation_mode is 0 or 1
#--------------------------------------------------------------------------------------------------------------------------------#
# trainval is the combined train+val pool; train is the training portion.
# For example, with 1000 images: trainval = 900, train = 810, val = 90, test = 100.
trainval_percent = 0.9
train_percent = 0.9
#-------------------------------------------------------#
#   Points to the folder containing the VOC dataset
#   Defaults to the VOC dataset in the root directory
#-------------------------------------------------------#
# Root folder of the data
VOCdevkit_path = 'VOCdevkit'
# Used by the generation loop below
VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
# Read the file at classes_path above to obtain the list of class names
classes, _ = get_classes(classes_path)
def convert_annotation(year, image_id, list_file):
    # Convert the data in the xml into groups of x1,y1,x2,y2,class_id
    in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
    # Parse the xml
    tree = ET.parse(in_file)
    root = tree.getroot()
    for obj in root.iter('object'):
        difficult = 0
        if obj.find('difficult') != None:
            difficult = obj.find('difficult').text
        cls = obj.find('name').text
        # Note: objects whose class is not in our class list, or that are marked difficult, are skipped
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
if __name__ == "__main__":
    random.seed(0)
    # Runs when the mode selected above is 0 or 1
    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        # Path to the xml files
        xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
        # Path where the split files are saved
        saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
        # List every file in the xml folder and keep only the .xml files
        temp_xml = os.listdir(xmlfilepath)
        total_xml = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)
        # Split the data according to the ratios defined above
        num = len(total_xml)
        indices = range(num)
        # trainval takes 90% of all data; train takes 90% of trainval
        tv = int(num * trainval_percent)
        tr = int(tv * train_percent)
        # Randomly sample 90% of the dataset as the trainval indices
        trainval = random.sample(indices, tv)
        # Randomly sample 90% of the trainval indices as the train indices
        train = random.sample(trainval, tr)
        print("train and val size", tv)
        print("train size", tr)
        # Open four files and write the split indices into them
        ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
        ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
        ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
        fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')
        # Loop over every index
        for i in indices:
            # Strip the .xml suffix from the file name before writing
            name = total_xml[i][:-4] + '\n'
            # Indices that fall in trainval
            if i in trainval:
                # are written to the trainval file
                ftrainval.write(name)
                # those also in train go to the train file
                if i in train:
                    ftrain.write(name)
                # the rest go to the val file
                else:
                    fval.write(name)
            # the remaining 10% not in trainval go to the test file
            else:
                ftest.write(name)
        # Close the files
        ftrainval.close()
        ftrain.close()
        fval.close()
        ftest.close()
        print("Generate txt in ImageSets done.")
    # Runs when the mode is 0 or 2
    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        # Create two files in the root directory (year is usually 2007 or 2012; here we only have 2007)
        for year, image_set in VOCdevkit_sets:
            # Get the list of image ids from the train.txt and val.txt built above
            image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
            # Open the output file
            list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
            # Loop over the image ids
            for image_id in image_ids:
                # Write the image path
                list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))
                # Convert the xml data into groups of x1,y1,x2,y2,class_id
                convert_annotation(year, image_id, list_file)
                # Finish the record with a newline
                list_file.write('\n')
            # Close the file
            list_file.close()
        print("Generate 2007_train.txt and 2007_val.txt for train done.")
The individual split files are not shown here; we only show what 2007_train.txt looks like.
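Each line of 2007_train.txt is an absolute image path followed by space-separated x1,y1,x2,y2,class_id groups, one per object. A minimal parsing sketch (the sample line below is illustrative, not taken from the real dataset):
line = "/data/VOCdevkit/VOC2007/JPEGImages/000001.jpg 48,240,195,371,11 8,12,352,498,14"  # illustrative
parts = line.split()
image_path = parts[0]
# every remaining token is "x1,y1,x2,y2,class_id"
boxes = [list(map(int, token.split(','))) for token in parts[1:]]
print(image_path)
print(boxes)  # [[48, 240, 195, 371, 11], [8, 12, 352, 498, 14]]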
Inspecting the network's layer structure and parameter count
- summary.py
import torch
from torchsummary import summary

from nets.yolo import YoloBody

if __name__ == "__main__":
    # Use device to choose whether the network runs on GPU or CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    m = YoloBody([[6, 7, 8], [3, 4, 5], [0, 1, 2]], 80).to(device)
    summary(m, input_size=(3, 416, 416))
The network structure is as follows:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 32, 416, 416] 864
BatchNorm2d-2 [-1, 32, 416, 416] 64
LeakyReLU-3 [-1, 32, 416, 416] 0
Conv2d-4 [-1, 64, 208, 208] 18,432
BatchNorm2d-5 [-1, 64, 208, 208] 128
LeakyReLU-6 [-1, 64, 208, 208] 0
Conv2d-7 [-1, 32, 208, 208] 2,048
BatchNorm2d-8 [-1, 32, 208, 208] 64
LeakyReLU-9 [-1, 32, 208, 208] 0
Conv2d-10 [-1, 64, 208, 208] 18,432
BatchNorm2d-11 [-1, 64, 208, 208] 128
LeakyReLU-12 [-1, 64, 208, 208] 0
BasicBlock-13 [-1, 64, 208, 208] 0
Conv2d-14 [-1, 128, 104, 104] 73,728
BatchNorm2d-15 [-1, 128, 104, 104] 256
LeakyReLU-16 [-1, 128, 104, 104] 0
Conv2d-17 [-1, 64, 104, 104] 8,192
BatchNorm2d-18 [-1, 64, 104, 104] 128
LeakyReLU-19 [-1, 64, 104, 104] 0
Conv2d-20 [-1, 128, 104, 104] 73,728
BatchNorm2d-21 [-1, 128, 104, 104] 256
LeakyReLU-22 [-1, 128, 104, 104] 0
BasicBlock-23 [-1, 128, 104, 104] 0
Conv2d-24 [-1, 64, 104, 104] 8,192
BatchNorm2d-25 [-1, 64, 104, 104] 128
LeakyReLU-26 [-1, 64, 104, 104] 0
Conv2d-27 [-1, 128, 104, 104] 73,728
BatchNorm2d-28 [-1, 128, 104, 104] 256
LeakyReLU-29 [-1, 128, 104, 104] 0
BasicBlock-30 [-1, 128, 104, 104] 0
Conv2d-31 [-1, 256, 52, 52] 294,912
BatchNorm2d-32 [-1, 256, 52, 52] 512
LeakyReLU-33 [-1, 256, 52, 52] 0
Conv2d-34 [-1, 128, 52, 52] 32,768
BatchNorm2d-35 [-1, 128, 52, 52] 256
LeakyReLU-36 [-1, 128, 52, 52] 0
Conv2d-37 [-1, 256, 52, 52] 294,912
BatchNorm2d-38 [-1, 256, 52, 52] 512
LeakyReLU-39 [-1, 256, 52, 52] 0
BasicBlock-40 [-1, 256, 52, 52] 0
Conv2d-41 [-1, 128, 52, 52] 32,768
BatchNorm2d-42 [-1, 128, 52, 52] 256
LeakyReLU-43 [-1, 128, 52, 52] 0
Conv2d-44 [-1, 256, 52, 52] 294,912
BatchNorm2d-45 [-1, 256, 52, 52] 512
LeakyReLU-46 [-1, 256, 52, 52] 0
BasicBlock-47 [-1, 256, 52, 52] 0
Conv2d-48 [-1, 128, 52, 52] 32,768
BatchNorm2d-49 [-1, 128, 52, 52] 256
LeakyReLU-50 [-1, 128, 52, 52] 0
Conv2d-51 [-1, 256, 52, 52] 294,912
BatchNorm2d-52 [-1, 256, 52, 52] 512
LeakyReLU-53 [-1, 256, 52, 52] 0
BasicBlock-54 [-1, 256, 52, 52] 0
Conv2d-55 [-1, 128, 52, 52] 32,768
BatchNorm2d-56 [-1, 128, 52, 52] 256
LeakyReLU-57 [-1, 128, 52, 52] 0
Conv2d-58 [-1, 256, 52, 52] 294,912
BatchNorm2d-59 [-1, 256, 52, 52] 512
LeakyReLU-60 [-1, 256, 52, 52] 0
BasicBlock-61 [-1, 256, 52, 52] 0
Conv2d-62 [-1, 128, 52, 52] 32,768
BatchNorm2d-63 [-1, 128, 52, 52] 256
LeakyReLU-64 [-1, 128, 52, 52] 0
Conv2d-65 [-1, 256, 52, 52] 294,912
BatchNorm2d-66 [-1, 256, 52, 52] 512
LeakyReLU-67 [-1, 256, 52, 52] 0
BasicBlock-68 [-1, 256, 52, 52] 0
Conv2d-69 [-1, 128, 52, 52] 32,768
BatchNorm2d-70 [-1, 128, 52, 52] 256
LeakyReLU-71 [-1, 128, 52, 52] 0
Conv2d-72 [-1, 256, 52, 52] 294,912
BatchNorm2d-73 [-1, 256, 52, 52] 512
LeakyReLU-74 [-1, 256, 52, 52] 0
BasicBlock-75 [-1, 256, 52, 52] 0
Conv2d-76 [-1, 128, 52, 52] 32,768
BatchNorm2d-77 [-1, 128, 52, 52] 256
LeakyReLU-78 [-1, 128, 52, 52] 0
Conv2d-79 [-1, 256, 52, 52] 294,912
BatchNorm2d-80 [-1, 256, 52, 52] 512
LeakyReLU-81 [-1, 256, 52, 52] 0
BasicBlock-82 [-1, 256, 52, 52] 0
Conv2d-83 [-1, 128, 52, 52] 32,768
BatchNorm2d-84 [-1, 128, 52, 52] 256
LeakyReLU-85 [-1, 128, 52, 52] 0
Conv2d-86 [-1, 256, 52, 52] 294,912
BatchNorm2d-87 [-1, 256, 52, 52] 512
LeakyReLU-88 [-1, 256, 52, 52] 0
BasicBlock-89 [-1, 256, 52, 52] 0
Conv2d-90 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-91 [-1, 512, 26, 26] 1,024
LeakyReLU-92 [-1, 512, 26, 26] 0
Conv2d-93 [-1, 256, 26, 26] 131,072
BatchNorm2d-94 [-1, 256, 26, 26] 512
LeakyReLU-95 [-1, 256, 26, 26] 0
Conv2d-96 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-97 [-1, 512, 26, 26] 1,024
LeakyReLU-98 [-1, 512, 26, 26] 0
BasicBlock-99 [-1, 512, 26, 26] 0
Conv2d-100 [-1, 256, 26, 26] 131,072
BatchNorm2d-101 [-1, 256, 26, 26] 512
LeakyReLU-102 [-1, 256, 26, 26] 0
Conv2d-103 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-104 [-1, 512, 26, 26] 1,024
LeakyReLU-105 [-1, 512, 26, 26] 0
BasicBlock-106 [-1, 512, 26, 26] 0
Conv2d-107 [-1, 256, 26, 26] 131,072
BatchNorm2d-108 [-1, 256, 26, 26] 512
LeakyReLU-109 [-1, 256, 26, 26] 0
Conv2d-110 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-111 [-1, 512, 26, 26] 1,024
LeakyReLU-112 [-1, 512, 26, 26] 0
BasicBlock-113 [-1, 512, 26, 26] 0
Conv2d-114 [-1, 256, 26, 26] 131,072
BatchNorm2d-115 [-1, 256, 26, 26] 512
LeakyReLU-116 [-1, 256, 26, 26] 0
Conv2d-117 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-118 [-1, 512, 26, 26] 1,024
LeakyReLU-119 [-1, 512, 26, 26] 0
BasicBlock-120 [-1, 512, 26, 26] 0
Conv2d-121 [-1, 256, 26, 26] 131,072
BatchNorm2d-122 [-1, 256, 26, 26] 512
LeakyReLU-123 [-1, 256, 26, 26] 0
Conv2d-124 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-125 [-1, 512, 26, 26] 1,024
LeakyReLU-126 [-1, 512, 26, 26] 0
BasicBlock-127 [-1, 512, 26, 26] 0
Conv2d-128 [-1, 256, 26, 26] 131,072
BatchNorm2d-129 [-1, 256, 26, 26] 512
LeakyReLU-130 [-1, 256, 26, 26] 0
Conv2d-131 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-132 [-1, 512, 26, 26] 1,024
LeakyReLU-133 [-1, 512, 26, 26] 0
BasicBlock-134 [-1, 512, 26, 26] 0
Conv2d-135 [-1, 256, 26, 26] 131,072
BatchNorm2d-136 [-1, 256, 26, 26] 512
LeakyReLU-137 [-1, 256, 26, 26] 0
Conv2d-138 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-139 [-1, 512, 26, 26] 1,024
LeakyReLU-140 [-1, 512, 26, 26] 0
BasicBlock-141 [-1, 512, 26, 26] 0
Conv2d-142 [-1, 256, 26, 26] 131,072
BatchNorm2d-143 [-1, 256, 26, 26] 512
LeakyReLU-144 [-1, 256, 26, 26] 0
Conv2d-145 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-146 [-1, 512, 26, 26] 1,024
LeakyReLU-147 [-1, 512, 26, 26] 0
BasicBlock-148 [-1, 512, 26, 26] 0
Conv2d-149 [-1, 1024, 13, 13] 4,718,592
BatchNorm2d-150 [-1, 1024, 13, 13] 2,048
LeakyReLU-151 [-1, 1024, 13, 13] 0
Conv2d-152 [-1, 512, 13, 13] 524,288
BatchNorm2d-153 [-1, 512, 13, 13] 1,024
LeakyReLU-154 [-1, 512, 13, 13] 0
Conv2d-155 [-1, 1024, 13, 13] 4,718,592
BatchNorm2d-156 [-1, 1024, 13, 13] 2,048
LeakyReLU-157 [-1, 1024, 13, 13] 0
BasicBlock-158 [-1, 1024, 13, 13] 0
Conv2d-159 [-1, 512, 13, 13] 524,288
BatchNorm2d-160 [-1, 512, 13, 13] 1,024
LeakyReLU-161 [-1, 512, 13, 13] 0
Conv2d-162 [-1, 1024, 13, 13] 4,718,592
BatchNorm2d-163 [-1, 1024, 13, 13] 2,048
LeakyReLU-164 [-1, 1024, 13, 13] 0
BasicBlock-165 [-1, 1024, 13, 13] 0
Conv2d-166 [-1, 512, 13, 13] 524,288
BatchNorm2d-167 [-1, 512, 13, 13] 1,024
LeakyReLU-168 [-1, 512, 13, 13] 0
Conv2d-169 [-1, 1024, 13, 13] 4,718,592
BatchNorm2d-170 [-1, 1024, 13, 13] 2,048
LeakyReLU-171 [-1, 1024, 13, 13] 0
BasicBlock-172 [-1, 1024, 13, 13] 0
Conv2d-173 [-1, 512, 13, 13] 524,288
BatchNorm2d-174 [-1, 512, 13, 13] 1,024
LeakyReLU-175 [-1, 512, 13, 13] 0
Conv2d-176 [-1, 1024, 13, 13] 4,718,592
BatchNorm2d-177 [-1, 1024, 13, 13] 2,048
LeakyReLU-178 [-1, 1024, 13, 13] 0
BasicBlock-179 [-1, 1024, 13, 13] 0
DarkNet-180 [[-1, 256, 52, 52], [-1, 512, 26, 26], [-1, 1024, 13, 13]] 0
Conv2d-181 [-1, 512, 13, 13] 524,288
BatchNorm2d-182 [-1, 512, 13, 13] 1,024
LeakyReLU-183 [-1, 512, 13, 13] 0
Conv2d-184 [-1, 1024, 13, 13] 4,718,592
BatchNorm2d-185 [-1, 1024, 13, 13] 2,048
LeakyReLU-186 [-1, 1024, 13, 13] 0
Conv2d-187 [-1, 512, 13, 13] 524,288
BatchNorm2d-188 [-1, 512, 13, 13] 1,024
LeakyReLU-189 [-1, 512, 13, 13] 0
Conv2d-190 [-1, 1024, 13, 13] 4,718,592
BatchNorm2d-191 [-1, 1024, 13, 13] 2,048
LeakyReLU-192 [-1, 1024, 13, 13] 0
Conv2d-193 [-1, 512, 13, 13] 524,288
BatchNorm2d-194 [-1, 512, 13, 13] 1,024
LeakyReLU-195 [-1, 512, 13, 13] 0
Conv2d-196 [-1, 1024, 13, 13] 4,718,592
BatchNorm2d-197 [-1, 1024, 13, 13] 2,048
LeakyReLU-198 [-1, 1024, 13, 13] 0
Conv2d-199 [-1, 255, 13, 13] 261,375
Conv2d-200 [-1, 256, 13, 13] 131,072
BatchNorm2d-201 [-1, 256, 13, 13] 512
LeakyReLU-202 [-1, 256, 13, 13] 0
Upsample-203 [-1, 256, 26, 26] 0
Conv2d-204 [-1, 256, 26, 26] 196,608
BatchNorm2d-205 [-1, 256, 26, 26] 512
LeakyReLU-206 [-1, 256, 26, 26] 0
Conv2d-207 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-208 [-1, 512, 26, 26] 1,024
LeakyReLU-209 [-1, 512, 26, 26] 0
Conv2d-210 [-1, 256, 26, 26] 131,072
BatchNorm2d-211 [-1, 256, 26, 26] 512
LeakyReLU-212 [-1, 256, 26, 26] 0
Conv2d-213 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-214 [-1, 512, 26, 26] 1,024
LeakyReLU-215 [-1, 512, 26, 26] 0
Conv2d-216 [-1, 256, 26, 26] 131,072
BatchNorm2d-217 [-1, 256, 26, 26] 512
LeakyReLU-218 [-1, 256, 26, 26] 0
Conv2d-219 [-1, 512, 26, 26] 1,179,648
BatchNorm2d-220 [-1, 512, 26, 26] 1,024
LeakyReLU-221 [-1, 512, 26, 26] 0
Conv2d-222 [-1, 255, 26, 26] 130,815
Conv2d-223 [-1, 128, 26, 26] 32,768
BatchNorm2d-224 [-1, 128, 26, 26] 256
LeakyReLU-225 [-1, 128, 26, 26] 0
Upsample-226 [-1, 128, 52, 52] 0
Conv2d-227 [-1, 128, 52, 52] 49,152
BatchNorm2d-228 [-1, 128, 52, 52] 256
LeakyReLU-229 [-1, 128, 52, 52] 0
Conv2d-230 [-1, 256, 52, 52] 294,912
BatchNorm2d-231 [-1, 256, 52, 52] 512
LeakyReLU-232 [-1, 256, 52, 52] 0
Conv2d-233 [-1, 128, 52, 52] 32,768
BatchNorm2d-234 [-1, 128, 52, 52] 256
LeakyReLU-235 [-1, 128, 52, 52] 0
Conv2d-236 [-1, 256, 52, 52] 294,912
BatchNorm2d-237 [-1, 256, 52, 52] 512
LeakyReLU-238 [-1, 256, 52, 52] 0
Conv2d-239 [-1, 128, 52, 52] 32,768
BatchNorm2d-240 [-1, 128, 52, 52] 256
LeakyReLU-241 [-1, 128, 52, 52] 0
Conv2d-242 [-1, 256, 52, 52] 294,912
BatchNorm2d-243 [-1, 256, 52, 52] 512
LeakyReLU-244 [-1, 256, 52, 52] 0
Conv2d-245 [-1, 255, 52, 52] 65,535
================================================================
Total params: 61,949,149
Trainable params: 61,949,149
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.98
Forward/backward pass size (MB): 998.13
Params size (MB): 236.32
Estimated Total Size (MB): 1236.43
----------------------------------------------------------------
The total parameter count is about 62 million, and the estimated total memory footprint is about 1.2 GB, so this is a fairly heavy network.
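As a quick cross-check that does not rely on torchsummary, the parameter count can be summed directly from the model; a minimal sketch, assuming the same YoloBody constructor as above:
from nets.yolo import YoloBody

# Sum parameter counts directly from the model (no torchsummary needed)
m = YoloBody([[6, 7, 8], [3, 4, 5], [0, 1, 2]], 80)
total = sum(p.numel() for p in m.parameters())
trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(total, trainable)  # both should come out at roughly 61.9 million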
Obtaining the anchor box sizes by clustering
- In YOLOv2, one of the improvements was to obtain the prior (anchor) box sizes by k-means clustering rather than picking them by hand
- kmeans_for_anchors.py
import glob
import xml.etree.ElementTree as ET

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

def cas_iou(box, cluster):
    # Compute the IoU between one box and every cluster center; since only
    # widths and heights are clustered, boxes are aligned at a common corner
    x = np.minimum(cluster[:, 0], box[0])
    y = np.minimum(cluster[:, 1], box[1])
    intersection = x * y
    area1 = box[0] * box[1]
    area2 = cluster[:, 0] * cluster[:, 1]
    iou = intersection / (area1 + area2 - intersection)
    return iou
def avg_iou(box, cluster):
    # Average, over all boxes, of each box's best IoU against the cluster centers
    return np.mean([np.max(cas_iou(box[i], cluster)) for i in range(box.shape[0])])
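# Quick sanity check of cas_iou with illustrative values (not from the dataset):
#   cas_iou(np.array([0.5, 0.5]), np.array([[0.5, 0.5], [0.25, 0.5]]))
#   -> array([1.0, 0.5]); the first center matches the box exactly, the second
#   covers half its area (intersection 0.125 / union 0.25).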
def kmeans(box, k):
    #-------------------------------------------------------------#
    #   Total number of boxes
    #-------------------------------------------------------------#
    row = box.shape[0]
    #-------------------------------------------------------------#
    #   Distance from each box to each cluster center
    #-------------------------------------------------------------#
    distance = np.empty((row, k))
    #-------------------------------------------------------------#
    #   Cluster assignment from the previous iteration
    #-------------------------------------------------------------#
    last_clu = np.zeros((row, ))
    np.random.seed()
    #-------------------------------------------------------------#
    #   Randomly pick k boxes as the initial cluster centers
    #-------------------------------------------------------------#
    cluster = box[np.random.choice(row, k, replace=False)]
    iteration = 0
    while True:
        #-------------------------------------------------------------#
        #   Compute the distance between every box and every cluster center,
        #   using 1 - IoU of the width/height pairs as the distance metric
        #-------------------------------------------------------------#
        for i in range(row):
            distance[i] = 1 - cas_iou(box[i], cluster)
        #-------------------------------------------------------------#
        #   Assign each box to its nearest cluster center
        #-------------------------------------------------------------#
        # Once the assignments stop changing, the loop exits early
        near = np.argmin(distance, axis=1)
        if (last_clu == near).all():
            break
        #-------------------------------------------------------------#
        #   Recompute each cluster center as the median of its boxes
        #-------------------------------------------------------------#
        for j in range(k):
            cluster[j] = np.median(box[near == j], axis=0)
        last_clu = near
        if iteration % 5 == 0:
            print('iter: {:d}. avg_iou:{:.2f}'.format(iteration, avg_iou(box, cluster)))
        iteration += 1
    # Return the cluster centers and each box's nearest-cluster index
    return cluster, near
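# Example call with illustrative shapes (not part of the original script):
#   boxes = np.random.rand(100, 2)     # 100 normalized width,height pairs
#   centers, near = kmeans(boxes, 9)   # centers: (9, 2), near: (100,)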
def load_data(path):
    data = []
    #-------------------------------------------------------------#
    #   Look for boxes in every xml
    #-------------------------------------------------------------#
    # Loop over the annotations to get each image's size and each box's corners
    for xml_file in tqdm(glob.glob('{}/*xml'.format(path))):
        tree = ET.parse(xml_file)
        height = int(tree.findtext('./size/height'))
        width = int(tree.findtext('./size/width'))
        if height <= 0 or width <= 0:
            continue
        #-------------------------------------------------------------#
        #   Get the width and height of every object
        #-------------------------------------------------------------#
        for obj in tree.iter('object'):
            # Normalize the coordinates by the image size
            xmin = int(float(obj.findtext('bndbox/xmin'))) / width
            ymin = int(float(obj.findtext('bndbox/ymin'))) / height
            xmax = int(float(obj.findtext('bndbox/xmax'))) / width
            ymax = int(float(obj.findtext('bndbox/ymax'))) / height
            xmin = np.float64(xmin)
            ymin = np.float64(ymin)
            xmax = np.float64(xmax)
            ymax = np.float64(ymax)
            # Store the normalized width and height of the box
            data.append([xmax - xmin, ymax - ymin])
    return np.array(data)
if __name__ == '__main__':
    np.random.seed(0)
    #-------------------------------------------------------------#
    #   Running this script processes the xmls under './VOCdevkit/VOC2007/Annotations'
    #   and generates yolo_anchors.txt
    #-------------------------------------------------------------#
    # Input image size and the number of anchors
    input_shape = [416, 416]
    anchors_num = 9
    #-------------------------------------------------------------#
    #   Load the dataset; VOC-style xmls can be used
    #-------------------------------------------------------------#
    # Path to the annotation xml files
    path = 'VOCdevkit/VOC2007/Annotations'
    #-------------------------------------------------------------#
    #   Load all the xmls
    #   Stored as normalized width,height pairs
    #-------------------------------------------------------------#
    print('Load xmls.')
    # Load the data
    data = load_data(path)
    print('Load xmls done.')
    #-------------------------------------------------------------#
    #   Run the k-means algorithm
    #-------------------------------------------------------------#
    print('K-means boxes.')
    # Cluster the normalized width/height pairs collected above
    cluster, near = kmeans(data, anchors_num)
    print('K-means boxes done.')
    # Scale the normalized widths/heights back to the input image size (416, 416)
    data = data * np.array([input_shape[1], input_shape[0]])
    # Scale the cluster centers back to pixel coordinates as well
    cluster = cluster * np.array([input_shape[1], input_shape[0]])
    #-------------------------------------------------------------#
    #   Plot
    #-------------------------------------------------------------#
    # Scatter-plot the boxes with width and height as the two axes
    for j in range(anchors_num):
        # Plot all boxes assigned to the current cluster
        plt.scatter(data[near == j][:, 0], data[near == j][:, 1])
        # Plot the cluster center
        plt.scatter(cluster[j][0], cluster[j][1], marker='x', c='black')
    # Save and show the figure
    plt.savefig("kmeans_for_anchors.jpg")
    plt.show()
    print('Save kmeans_for_anchors.jpg in root dir.')
    # Sort the cluster centers by area
    cluster = cluster[np.argsort(cluster[:, 0] * cluster[:, 1])]
    print('avg_ratio:{:.2f}'.format(avg_iou(data, cluster)))
    print(cluster)
    # Write the sorted cluster centers to file
    f = open("yolo_anchors.txt", 'w')
    row = np.shape(cluster)[0]
    for i in range(row):
        if i == 0:
            x_y = "%d,%d" % (cluster[i][0], cluster[i][1])
        else:
            x_y = ", %d,%d" % (cluster[i][0], cluster[i][1])
        f.write(x_y)
    f.close()
After the run finishes, we obtain the following image:
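The generated yolo_anchors.txt stores the nine anchors on a single line as comma-separated width,height pairs. A minimal sketch of reading them back, assuming the format written by the loop above:
import numpy as np

# Read yolo_anchors.txt back into a (9, 2) array of width,height pairs,
# matching the single-line comma-separated format written above
with open("yolo_anchors.txt") as f:
    anchors = np.array([float(v) for v in f.readline().split(',')]).reshape(-1, 2)
print(anchors.shape)  # (9, 2)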