YOLOv3的源代码精度理解(十) 对辅助文件进行解读

555 阅读19分钟

生成训练验证测试txt文件

  • voc_annotation.py
#--------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode selects what this script computes when run:
#   0 -- the whole labelling pipeline: the txt files under VOCdevkit/VOC2007/ImageSets
#        plus the training files 2007_train.txt and 2007_val.txt
#   1 -- only the txt files under VOCdevkit/VOC2007/ImageSets
#   2 -- only the training files 2007_train.txt and 2007_val.txt
#--------------------------------------------------------------------------------------------------------------------------------#

# Which set of files to generate (see the table above).
annotation_mode     = 0
#-------------------------------------------------------------------#
#   Must be edited: used to write the object information into
#   2007_train.txt / 2007_val.txt.
#   Keep it identical to the classes_path used for training/prediction.
#   If the generated 2007_train.txt contains no object information,
#   the classes were not configured correctly.
#   Only used when annotation_mode is 0 or 2.
#-------------------------------------------------------------------#

# Path of the file that lists the class names.
classes_path        = 'model_data/voc_classes.txt'
#--------------------------------------------------------------------------------------------------------------------------------#
#   trainval_percent: ratio of (train set + val set) to test set; default (train+val):test = 9:1
#   train_percent: ratio of train set to val set inside (train set + val set); default train:val = 9:1
#   Only used when annotation_mode is 0 or 1.
#--------------------------------------------------------------------------------------------------------------------------------#

# trainval is the combined training + validation split;
# train is the training part of it.
trainval_percent    = 0.9
train_percent       = 0.9
#-------------------------------------------------------#
#   Folder that contains the VOC dataset,
#   by default 'VOCdevkit' under the project root.
#-------------------------------------------------------#

# Root folder of the dataset.
VOCdevkit_path  = 'VOCdevkit'

# (year, subset) pairs iterated over in __main__ below.
VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]

# Read classes_path (set above) to obtain the list of class names.
classes, _ = get_classes(classes_path)

def convert_annotation(year, image_id, list_file):
    """Append ' x1,y1,x2,y2,class_id' groups from one VOC XML to list_file.

    Objects whose class is not in the global `classes` list, and objects
    flagged as difficult, are skipped.
    """
    xml_path = os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id))
    # Context manager guarantees the XML file is closed (the original leaked the handle).
    with open(xml_path, encoding='utf-8') as in_file:
        tree = ET.parse(in_file)
    root = tree.getroot()

    for obj in root.iter('object'):
        difficult = 0
        if obj.find('difficult') is not None:
            difficult = obj.find('difficult').text
        cls = obj.find('name').text

        # Drop classes outside our class list as well as boxes marked difficult.
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        # Coordinates are stored as floats in some VOC xmls; truncate to int.
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)),
             int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
        
if __name__ == "__main__":
    random.seed(0)
    # 在上面的模式选择0或者1的时候
    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        
        # 获取xml文件的文件路径
        xmlfilepath     = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
        
        # 选择我们要保存的文件的路径
        saveBasePath    = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
        
        # 将xml文件夹下的所有文件都列出来,形成一个列表
        temp_xml        = os.listdir(xmlfilepath)
        total_xml       = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)
		
        # 按照上面我们我们定义的比例进行数据划分
        num     = len(total_xml)  
        list    = range(num)
        # trainval 占整个数据的90%,train占trainval的90%
        tv      = int(num*trainval_percent)  
        tr      = int(tv*train_percent)
        # 我们将随机抽取数据集的90%当成是trainval的数据集,这个地方是将下标放到一个列表中
        trainval= random.sample(list,tv)
        # 我们将随机抽取trainval数据集的90%当成是train的数据集
        train   = random.sample(trainval,tr)  
        
        print("train and val size",tv)
        print("train size",tr)
        # 打开四个文件,将按比例分好的数据放到文件中
        ftrainval   = open(os.path.join(saveBasePath,'trainval.txt'), 'w')  
        ftest       = open(os.path.join(saveBasePath,'test.txt'), 'w')  
        ftrain      = open(os.path.join(saveBasePath,'train.txt'), 'w')  
        fval        = open(os.path.join(saveBasePath,'val.txt'), 'w')  
        
        # 我们对下标列表进行循环
        for i in list:
            # 将文件名去除.xml写入文件
            name=total_xml[i][:-4]+'\n'
            
            # 在trainval文件中的
            if i in trainval: 
                # 写入trainval文件
                ftrainval.write(name)
                # 在train中的写入train文件
                if i in train:  
                    ftrain.write(name) 
                # 不在的写入val文件
                else:  
                    fval.write(name)  
           # 不在trainval中的10%的文件写入test
            else:  
                ftest.write(name)  
        
        # 关闭文件
        ftrainval.close()  
        ftrain.close()  
        fval.close()  
        ftest.close()
        print("Generate txt in ImageSets done.")
	
    # 当时模式0或者2的时候
    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        # 我们在更目录下创建两个文件(year一般包括两个2007 or 2012,我们这里只有2007)
        for year, image_set in VOCdevkit_sets:
            # 从我们上面构建好的train.txt和val.txt的获取image的id列表
            image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
            # 打开文件
            list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
            
            # 循环上面的image_id的列表
            for image_id in image_ids:
                # 写入文件的路径信息
                list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))
				# 将xml中的数据转化成x1,y1,x2,y2,class_num多组的格式
                convert_annotation(year, image_id, list_file)
                # 记录一条数据加上一个换行符
                list_file.write('\n')
            # 关闭文件
            list_file.close()
        print("Generate 2007_train.txt and 2007_val.txt for train done.")

划分得到的各个txt文件就不逐一展示了,只展示2007_train.txt的效果

image.png

查看网络的层次结构和参数数量

  • summary.py
import torch
from torchsummary import summary

from nets.yolo import YoloBody

if __name__ == "__main__":
    # 需要使用device来指定网络在GPU还是CPU运行
    device  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    m       = YoloBody([[6, 7, 8], [3, 4, 5], [0, 1, 2]], 80).to(device)
    summary(m, input_size=(3, 416, 416))

网络的结构如下

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 32, 416, 416]             864
       BatchNorm2d-2         [-1, 32, 416, 416]              64
         LeakyReLU-3         [-1, 32, 416, 416]               0
            Conv2d-4         [-1, 64, 208, 208]          18,432
       BatchNorm2d-5         [-1, 64, 208, 208]             128
         LeakyReLU-6         [-1, 64, 208, 208]               0
            Conv2d-7         [-1, 32, 208, 208]           2,048
       BatchNorm2d-8         [-1, 32, 208, 208]              64
         LeakyReLU-9         [-1, 32, 208, 208]               0
           Conv2d-10         [-1, 64, 208, 208]          18,432
      BatchNorm2d-11         [-1, 64, 208, 208]             128
        LeakyReLU-12         [-1, 64, 208, 208]               0
       BasicBlock-13         [-1, 64, 208, 208]               0
           Conv2d-14        [-1, 128, 104, 104]          73,728
      BatchNorm2d-15        [-1, 128, 104, 104]             256
        LeakyReLU-16        [-1, 128, 104, 104]               0
           Conv2d-17         [-1, 64, 104, 104]           8,192
      BatchNorm2d-18         [-1, 64, 104, 104]             128
        LeakyReLU-19         [-1, 64, 104, 104]               0
           Conv2d-20        [-1, 128, 104, 104]          73,728
      BatchNorm2d-21        [-1, 128, 104, 104]             256
        LeakyReLU-22        [-1, 128, 104, 104]               0
       BasicBlock-23        [-1, 128, 104, 104]               0
           Conv2d-24         [-1, 64, 104, 104]           8,192
      BatchNorm2d-25         [-1, 64, 104, 104]             128
        LeakyReLU-26         [-1, 64, 104, 104]               0
           Conv2d-27        [-1, 128, 104, 104]          73,728
      BatchNorm2d-28        [-1, 128, 104, 104]             256
        LeakyReLU-29        [-1, 128, 104, 104]               0
       BasicBlock-30        [-1, 128, 104, 104]               0
           Conv2d-31          [-1, 256, 52, 52]         294,912
      BatchNorm2d-32          [-1, 256, 52, 52]             512
        LeakyReLU-33          [-1, 256, 52, 52]               0
           Conv2d-34          [-1, 128, 52, 52]          32,768
      BatchNorm2d-35          [-1, 128, 52, 52]             256
        LeakyReLU-36          [-1, 128, 52, 52]               0
           Conv2d-37          [-1, 256, 52, 52]         294,912
      BatchNorm2d-38          [-1, 256, 52, 52]             512
        LeakyReLU-39          [-1, 256, 52, 52]               0
       BasicBlock-40          [-1, 256, 52, 52]               0
           Conv2d-41          [-1, 128, 52, 52]          32,768
      BatchNorm2d-42          [-1, 128, 52, 52]             256
        LeakyReLU-43          [-1, 128, 52, 52]               0
           Conv2d-44          [-1, 256, 52, 52]         294,912
      BatchNorm2d-45          [-1, 256, 52, 52]             512
        LeakyReLU-46          [-1, 256, 52, 52]               0
       BasicBlock-47          [-1, 256, 52, 52]               0
           Conv2d-48          [-1, 128, 52, 52]          32,768
      BatchNorm2d-49          [-1, 128, 52, 52]             256
        LeakyReLU-50          [-1, 128, 52, 52]               0
           Conv2d-51          [-1, 256, 52, 52]         294,912
      BatchNorm2d-52          [-1, 256, 52, 52]             512
        LeakyReLU-53          [-1, 256, 52, 52]               0
       BasicBlock-54          [-1, 256, 52, 52]               0
           Conv2d-55          [-1, 128, 52, 52]          32,768
      BatchNorm2d-56          [-1, 128, 52, 52]             256
        LeakyReLU-57          [-1, 128, 52, 52]               0
           Conv2d-58          [-1, 256, 52, 52]         294,912
      BatchNorm2d-59          [-1, 256, 52, 52]             512
        LeakyReLU-60          [-1, 256, 52, 52]               0
       BasicBlock-61          [-1, 256, 52, 52]               0
           Conv2d-62          [-1, 128, 52, 52]          32,768
      BatchNorm2d-63          [-1, 128, 52, 52]             256
        LeakyReLU-64          [-1, 128, 52, 52]               0
           Conv2d-65          [-1, 256, 52, 52]         294,912
      BatchNorm2d-66          [-1, 256, 52, 52]             512
        LeakyReLU-67          [-1, 256, 52, 52]               0
       BasicBlock-68          [-1, 256, 52, 52]               0
           Conv2d-69          [-1, 128, 52, 52]          32,768
      BatchNorm2d-70          [-1, 128, 52, 52]             256
        LeakyReLU-71          [-1, 128, 52, 52]               0
           Conv2d-72          [-1, 256, 52, 52]         294,912
      BatchNorm2d-73          [-1, 256, 52, 52]             512
        LeakyReLU-74          [-1, 256, 52, 52]               0
       BasicBlock-75          [-1, 256, 52, 52]               0
           Conv2d-76          [-1, 128, 52, 52]          32,768
      BatchNorm2d-77          [-1, 128, 52, 52]             256
        LeakyReLU-78          [-1, 128, 52, 52]               0
           Conv2d-79          [-1, 256, 52, 52]         294,912
      BatchNorm2d-80          [-1, 256, 52, 52]             512
        LeakyReLU-81          [-1, 256, 52, 52]               0
       BasicBlock-82          [-1, 256, 52, 52]               0
           Conv2d-83          [-1, 128, 52, 52]          32,768
      BatchNorm2d-84          [-1, 128, 52, 52]             256
        LeakyReLU-85          [-1, 128, 52, 52]               0
           Conv2d-86          [-1, 256, 52, 52]         294,912
      BatchNorm2d-87          [-1, 256, 52, 52]             512
        LeakyReLU-88          [-1, 256, 52, 52]               0
       BasicBlock-89          [-1, 256, 52, 52]               0
           Conv2d-90          [-1, 512, 26, 26]       1,179,648
      BatchNorm2d-91          [-1, 512, 26, 26]           1,024
        LeakyReLU-92          [-1, 512, 26, 26]               0
           Conv2d-93          [-1, 256, 26, 26]         131,072
      BatchNorm2d-94          [-1, 256, 26, 26]             512
        LeakyReLU-95          [-1, 256, 26, 26]               0
           Conv2d-96          [-1, 512, 26, 26]       1,179,648
      BatchNorm2d-97          [-1, 512, 26, 26]           1,024
        LeakyReLU-98          [-1, 512, 26, 26]               0
       BasicBlock-99          [-1, 512, 26, 26]               0
          Conv2d-100          [-1, 256, 26, 26]         131,072
     BatchNorm2d-101          [-1, 256, 26, 26]             512
       LeakyReLU-102          [-1, 256, 26, 26]               0
          Conv2d-103          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-104          [-1, 512, 26, 26]           1,024
       LeakyReLU-105          [-1, 512, 26, 26]               0
      BasicBlock-106          [-1, 512, 26, 26]               0
          Conv2d-107          [-1, 256, 26, 26]         131,072
     BatchNorm2d-108          [-1, 256, 26, 26]             512
       LeakyReLU-109          [-1, 256, 26, 26]               0
          Conv2d-110          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-111          [-1, 512, 26, 26]           1,024
       LeakyReLU-112          [-1, 512, 26, 26]               0
      BasicBlock-113          [-1, 512, 26, 26]               0
          Conv2d-114          [-1, 256, 26, 26]         131,072
     BatchNorm2d-115          [-1, 256, 26, 26]             512
       LeakyReLU-116          [-1, 256, 26, 26]               0
          Conv2d-117          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-118          [-1, 512, 26, 26]           1,024
       LeakyReLU-119          [-1, 512, 26, 26]               0
      BasicBlock-120          [-1, 512, 26, 26]               0
          Conv2d-121          [-1, 256, 26, 26]         131,072
     BatchNorm2d-122          [-1, 256, 26, 26]             512
       LeakyReLU-123          [-1, 256, 26, 26]               0
          Conv2d-124          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-125          [-1, 512, 26, 26]           1,024
       LeakyReLU-126          [-1, 512, 26, 26]               0
      BasicBlock-127          [-1, 512, 26, 26]               0
          Conv2d-128          [-1, 256, 26, 26]         131,072
     BatchNorm2d-129          [-1, 256, 26, 26]             512
       LeakyReLU-130          [-1, 256, 26, 26]               0
          Conv2d-131          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-132          [-1, 512, 26, 26]           1,024
       LeakyReLU-133          [-1, 512, 26, 26]               0
      BasicBlock-134          [-1, 512, 26, 26]               0
          Conv2d-135          [-1, 256, 26, 26]         131,072
     BatchNorm2d-136          [-1, 256, 26, 26]             512
       LeakyReLU-137          [-1, 256, 26, 26]               0
          Conv2d-138          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-139          [-1, 512, 26, 26]           1,024
       LeakyReLU-140          [-1, 512, 26, 26]               0
      BasicBlock-141          [-1, 512, 26, 26]               0
          Conv2d-142          [-1, 256, 26, 26]         131,072
     BatchNorm2d-143          [-1, 256, 26, 26]             512
       LeakyReLU-144          [-1, 256, 26, 26]               0
          Conv2d-145          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-146          [-1, 512, 26, 26]           1,024
       LeakyReLU-147          [-1, 512, 26, 26]               0
      BasicBlock-148          [-1, 512, 26, 26]               0
          Conv2d-149         [-1, 1024, 13, 13]       4,718,592
     BatchNorm2d-150         [-1, 1024, 13, 13]           2,048
       LeakyReLU-151         [-1, 1024, 13, 13]               0
          Conv2d-152          [-1, 512, 13, 13]         524,288
     BatchNorm2d-153          [-1, 512, 13, 13]           1,024
       LeakyReLU-154          [-1, 512, 13, 13]               0
          Conv2d-155         [-1, 1024, 13, 13]       4,718,592
     BatchNorm2d-156         [-1, 1024, 13, 13]           2,048
       LeakyReLU-157         [-1, 1024, 13, 13]               0
      BasicBlock-158         [-1, 1024, 13, 13]               0
          Conv2d-159          [-1, 512, 13, 13]         524,288
     BatchNorm2d-160          [-1, 512, 13, 13]           1,024
       LeakyReLU-161          [-1, 512, 13, 13]               0
          Conv2d-162         [-1, 1024, 13, 13]       4,718,592
     BatchNorm2d-163         [-1, 1024, 13, 13]           2,048
       LeakyReLU-164         [-1, 1024, 13, 13]               0
      BasicBlock-165         [-1, 1024, 13, 13]               0
          Conv2d-166          [-1, 512, 13, 13]         524,288
     BatchNorm2d-167          [-1, 512, 13, 13]           1,024
       LeakyReLU-168          [-1, 512, 13, 13]               0
          Conv2d-169         [-1, 1024, 13, 13]       4,718,592
     BatchNorm2d-170         [-1, 1024, 13, 13]           2,048
       LeakyReLU-171         [-1, 1024, 13, 13]               0
      BasicBlock-172         [-1, 1024, 13, 13]               0
          Conv2d-173          [-1, 512, 13, 13]         524,288
     BatchNorm2d-174          [-1, 512, 13, 13]           1,024
       LeakyReLU-175          [-1, 512, 13, 13]               0
          Conv2d-176         [-1, 1024, 13, 13]       4,718,592
     BatchNorm2d-177         [-1, 1024, 13, 13]           2,048
       LeakyReLU-178         [-1, 1024, 13, 13]               0
      BasicBlock-179         [-1, 1024, 13, 13]               0
         DarkNet-180  [[-1, 256, 52, 52], [-1, 512, 26, 26], [-1, 1024, 13, 13]]               0
          Conv2d-181          [-1, 512, 13, 13]         524,288
     BatchNorm2d-182          [-1, 512, 13, 13]           1,024
       LeakyReLU-183          [-1, 512, 13, 13]               0
          Conv2d-184         [-1, 1024, 13, 13]       4,718,592
     BatchNorm2d-185         [-1, 1024, 13, 13]           2,048
       LeakyReLU-186         [-1, 1024, 13, 13]               0
          Conv2d-187          [-1, 512, 13, 13]         524,288
     BatchNorm2d-188          [-1, 512, 13, 13]           1,024
       LeakyReLU-189          [-1, 512, 13, 13]               0
          Conv2d-190         [-1, 1024, 13, 13]       4,718,592
     BatchNorm2d-191         [-1, 1024, 13, 13]           2,048
       LeakyReLU-192         [-1, 1024, 13, 13]               0
          Conv2d-193          [-1, 512, 13, 13]         524,288
     BatchNorm2d-194          [-1, 512, 13, 13]           1,024
       LeakyReLU-195          [-1, 512, 13, 13]               0
          Conv2d-196         [-1, 1024, 13, 13]       4,718,592
     BatchNorm2d-197         [-1, 1024, 13, 13]           2,048
       LeakyReLU-198         [-1, 1024, 13, 13]               0
          Conv2d-199          [-1, 255, 13, 13]         261,375
          Conv2d-200          [-1, 256, 13, 13]         131,072
     BatchNorm2d-201          [-1, 256, 13, 13]             512
       LeakyReLU-202          [-1, 256, 13, 13]               0
        Upsample-203          [-1, 256, 26, 26]               0
          Conv2d-204          [-1, 256, 26, 26]         196,608
     BatchNorm2d-205          [-1, 256, 26, 26]             512
       LeakyReLU-206          [-1, 256, 26, 26]               0
          Conv2d-207          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-208          [-1, 512, 26, 26]           1,024
       LeakyReLU-209          [-1, 512, 26, 26]               0
          Conv2d-210          [-1, 256, 26, 26]         131,072
     BatchNorm2d-211          [-1, 256, 26, 26]             512
       LeakyReLU-212          [-1, 256, 26, 26]               0
          Conv2d-213          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-214          [-1, 512, 26, 26]           1,024
       LeakyReLU-215          [-1, 512, 26, 26]               0
          Conv2d-216          [-1, 256, 26, 26]         131,072
     BatchNorm2d-217          [-1, 256, 26, 26]             512
       LeakyReLU-218          [-1, 256, 26, 26]               0
          Conv2d-219          [-1, 512, 26, 26]       1,179,648
     BatchNorm2d-220          [-1, 512, 26, 26]           1,024
       LeakyReLU-221          [-1, 512, 26, 26]               0
          Conv2d-222          [-1, 255, 26, 26]         130,815
          Conv2d-223          [-1, 128, 26, 26]          32,768
     BatchNorm2d-224          [-1, 128, 26, 26]             256
       LeakyReLU-225          [-1, 128, 26, 26]               0
        Upsample-226          [-1, 128, 52, 52]               0
          Conv2d-227          [-1, 128, 52, 52]          49,152
     BatchNorm2d-228          [-1, 128, 52, 52]             256
       LeakyReLU-229          [-1, 128, 52, 52]               0
          Conv2d-230          [-1, 256, 52, 52]         294,912
     BatchNorm2d-231          [-1, 256, 52, 52]             512
       LeakyReLU-232          [-1, 256, 52, 52]               0
          Conv2d-233          [-1, 128, 52, 52]          32,768
     BatchNorm2d-234          [-1, 128, 52, 52]             256
       LeakyReLU-235          [-1, 128, 52, 52]               0
          Conv2d-236          [-1, 256, 52, 52]         294,912
     BatchNorm2d-237          [-1, 256, 52, 52]             512
       LeakyReLU-238          [-1, 256, 52, 52]               0
          Conv2d-239          [-1, 128, 52, 52]          32,768
     BatchNorm2d-240          [-1, 128, 52, 52]             256
       LeakyReLU-241          [-1, 128, 52, 52]               0
          Conv2d-242          [-1, 256, 52, 52]         294,912
     BatchNorm2d-243          [-1, 256, 52, 52]             512
       LeakyReLU-244          [-1, 256, 52, 52]               0
          Conv2d-245          [-1, 255, 52, 52]          65,535
================================================================
Total params: 61,949,149
Trainable params: 61,949,149
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.98
Forward/backward pass size (MB): 998.13
Params size (MB): 236.32
Estimated Total Size (MB): 1236.43
----------------------------------------------------------------

我们发现参数总量在6200万左右,评估总的内存消耗1.23G,还是比较复杂的网络。

聚类得到框的大小

  • 在yolov2中,改进点之一就是使用聚类的方法从训练数据中自动得到先验框(anchor)的大小
  • kmeans_for_anchors.py
def cas_iou(box, cluster):
    """IoU between one (w, h) box and every (w, h) row of `cluster`.

    Boxes are treated as if aligned at the origin, so the overlap
    is simply min(width) * min(height).
    """
    inter_w = np.minimum(cluster[:, 0], box[0])
    inter_h = np.minimum(cluster[:, 1], box[1])
    intersection = inter_w * inter_h

    box_area = box[0] * box[1]
    cluster_area = cluster[:, 0] * cluster[:, 1]

    return intersection / (box_area + cluster_area - intersection)

def avg_iou(box, cluster):
    """Mean, over all boxes, of each box's best IoU against any cluster centre."""
    best_ious = [cas_iou(box[i], cluster).max() for i in range(box.shape[0])]
    return np.mean(best_ious)

def kmeans(box, k):
    """Cluster box sizes into k anchors using (1 - IoU) as the distance.

    Parameters
    ----------
    box : ndarray of shape (n, 2), normalized (width, height) pairs.
    k   : number of clusters (anchors) to produce.

    Returns
    -------
    cluster : ndarray (k, 2) -- cluster centres (per-dimension median of members).
    near    : ndarray (n,)  -- index of the nearest centre for each box.
    """
    #-------------------------------------------------------------#
    #   Total number of boxes.
    #-------------------------------------------------------------#
    row = box.shape[0]

    #-------------------------------------------------------------#
    #   Distance of every box to every cluster centre.
    #-------------------------------------------------------------#
    distance = np.empty((row, k))

    #-------------------------------------------------------------#
    #   Assignment from the previous iteration (convergence check).
    #-------------------------------------------------------------#
    last_clu = np.zeros((row, ))

    # BUGFIX: the original called np.random.seed() here, which re-seeds from OS
    # entropy and silently defeats any seed the caller set (e.g. np.random.seed(0)
    # in __main__). Removed so the caller stays in control of reproducibility.

    #-------------------------------------------------------------#
    #   Pick k distinct boxes as the initial cluster centres.
    #-------------------------------------------------------------#
    cluster = box[np.random.choice(row, k, replace=False)]

    iteration = 0   # renamed from `iter` to stop shadowing the builtin
    while True:
        #-------------------------------------------------------------#
        #   Distance metric: 1 - IoU between each box and each centre.
        #-------------------------------------------------------------#
        for i in range(row):
            distance[i] = 1 - cas_iou(box[i], cluster)

        #-------------------------------------------------------------#
        #   Assign every box to its closest centre.
        #-------------------------------------------------------------#
        near = np.argmin(distance, axis=1)

        # Converged: assignments did not change since the last iteration.
        if (last_clu == near).all():
            break

        #-------------------------------------------------------------#
        #   Move each centre to the median of its member boxes.
        #   Empty clusters are skipped: np.median over an empty slice
        #   would turn the centre into NaN.
        #-------------------------------------------------------------#
        for j in range(k):
            members = box[near == j]
            if len(members) > 0:
                cluster[j] = np.median(members, axis=0)

        last_clu = near
        if iteration % 5 == 0:
            print('iter: {:d}. avg_iou:{:.2f}'.format(iteration, avg_iou(box, cluster)))
        iteration += 1
    # Return the cluster centres and each box's nearest-centre index.
    return cluster, near

def load_data(path):
    """Collect one normalized (width, height) pair per object from every VOC xml under `path`.

    Returns an ndarray of shape (n, 2) with widths/heights divided by the image size.
    """
    data = []
    #-------------------------------------------------------------#
    #   Walk every xml annotation file, with a tqdm progress bar.
    #-------------------------------------------------------------#
    for xml_file in tqdm(glob.glob('{}/*xml'.format(path))):
        tree = ET.parse(xml_file)
        height = int(tree.findtext('./size/height'))
        width = int(tree.findtext('./size/width'))
        # Skip annotations with a degenerate image size.
        if height <= 0 or width <= 0:
            continue

        #-------------------------------------------------------------#
        #   Record each object's size, normalized by the image size.
        #-------------------------------------------------------------#
        for obj in tree.iter('object'):
            xmin = np.float64(int(float(obj.findtext('bndbox/xmin'))) / width)
            ymin = np.float64(int(float(obj.findtext('bndbox/ymin'))) / height)
            xmax = np.float64(int(float(obj.findtext('bndbox/xmax'))) / width)
            ymax = np.float64(int(float(obj.findtext('bndbox/ymax'))) / height)
            # Normalized box width and height.
            data.append([xmax - xmin, ymax - ymin])
    return np.array(data)

if __name__ == '__main__':
    np.random.seed(0)
    #-------------------------------------------------------------#
    #   Running this script clusters the boxes found in
    #   './VOCdevkit/VOC2007/Annotations' and writes yolo_anchors.txt.
    #-------------------------------------------------------------#

    # Network input size and the number of anchors to produce.
    input_shape = [416, 416]
    anchors_num = 9

    #-------------------------------------------------------------#
    #   Folder containing the VOC xml annotation files.
    #-------------------------------------------------------------#
    path        = 'VOCdevkit/VOC2007/Annotations'

    #-------------------------------------------------------------#
    #   Load every xml; each box is stored as normalized (width, height).
    #-------------------------------------------------------------#
    print('Load xmls.')
    data = load_data(path)
    print('Load xmls done.')

    #-------------------------------------------------------------#
    #   Run k-means (IoU distance) on the normalized boxes.
    #-------------------------------------------------------------#
    print('K-means boxes.')
    cluster, near   = kmeans(data, anchors_num)
    print('K-means boxes done.')

    # Scale the normalized sizes back to the network input resolution (416x416).
    data            = data * np.array([input_shape[1], input_shape[0]])
    cluster         = cluster * np.array([input_shape[1], input_shape[0]])

    #-------------------------------------------------------------#
    #   Plot each box as a (w, h) point, coloured by its cluster,
    #   with the cluster centres drawn as black crosses.
    #-------------------------------------------------------------#
    for j in range(anchors_num):
        plt.scatter(data[near == j][:, 0], data[near == j][:, 1])
        plt.scatter(cluster[j][0], cluster[j][1], marker='x', c='black')
    plt.savefig("kmeans_for_anchors.jpg")
    plt.show()
    print('Save kmeans_for_anchors.jpg in root dir.')

    # Sort the anchors by area (ascending) before writing them out.
    cluster = cluster[np.argsort(cluster[:, 0] * cluster[:, 1])]
    print('avg_ratio:{:.2f}'.format(avg_iou(data, cluster)))
    print(cluster)

    # Write the sorted anchors as 'w,h, w,h, ...'; `with` guarantees the file
    # is closed even on error (the original open/close pair could leak it).
    with open("yolo_anchors.txt", 'w') as f:
        f.write(", ".join("%d,%d" % (w, h) for w, h in cluster))

运行完,我们能得到的如下的图像

image.png