YOLOv3的源代码精度理解(一) 预测

345 阅读6分钟

代码主要是参考bubbliiing的github YOLOv3的代码:github.com/bubbliiiing…

前景知识部分

image.png

计算机需要下面四个参数就能确定猫的位置了

image.png

为了防止失真的问题,我们使用的是填充技术

划分成不同的大小的特征图13 * 13 、26 * 26、52 * 52

物体的中心点落在了那个网格中,就由哪个网格负责预测这个物体

image.png

总结:

image.png

环境配置的部分

主要安装两个库:torch = 1.2.0torchvision = 0.4.0
这两个库的安装使用官网的方式并没有安装成功,所以使用了download.pytorch.org/whl/torch_s… 这个网站提供的whl文件,进行离线版安装

修正理解:对于我们安装的wheel文件,我们可以放在任意的部分,在我们激活环境之后我们可以直接使用pip 进行安装,即可安装到我们虚拟环境中,而不用将其放在script中。

对于源代码的解读

预测部分

predict.py文件

其实在这个地方预测的过程中,我们可以针对图片视频摄像头FPS或者文件夹(批量图片)进行测试
使用mode进行区分:
predict 是图片 我们需要输入图片的路径

  • ./img/street.jpg video 是视频或者摄像头
  • video_path = 0的时候,是摄像头
  • video_path = "真实路径的时候" 是对视频进行检测
  • 测试中一个比较好玩的点是
    • 原始视频的时长是16秒
    • 我们在其中有个地方是设置fps, 我分别设置了1,10,25,我发现设置1的时候,保存的时长是4分40秒左右,当时设置时10的时候,保存的时长是50秒左右,当我们使用25的是时候,保存时长是20秒左右
    • 但是我们的预测的花费的时间确是差不多的
    • 总结:我们在进行预测的过程其实是针对视频中的每一帧都当成是一张图片,在每一张图片中进行一次算法预测,所以预测的时间比较长,最终保留视频的时长确实是和fps的大小是有关系的,就是我们每隔多少的帧数进行抽取一张图片保留为最终的视频,隔的位置越远,最终视频的时长就越短
    • 我们使用fps = capture.get(cv2.CAP_PROP_FPS)video_fps = fps获取视频的fps然后在保存视频的时候使用这个fps,最终我们得到的视频的时长和我们原始的视频的时长基本一样,验证了猜想。 dir_predict 是针对批量化的图片进行处理
  • 输入的是图片文件夹的路径 ./img/ fps 就是计算单位时间能够检测的图片的数量

代码的核心部分

图片

yolo.detect_image(image, crop = crop)

视频

yolo.detect_image(frame)

#-----------------------------------------------------------------------#
#   predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
#   整合到了一个py文件中,通过指定mode进行模式的修改。
#-----------------------------------------------------------------------#
import time

import cv2
import numpy as np
from PIL import Image

from yolo import YOLO

if __name__ == "__main__":
    yolo = YOLO()
    #----------------------------------------------------------------------------------------------------------#
    #   mode用于指定测试的模式:
    #   'predict'表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
    #   'video'表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
    #   'fps'表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
    #   'dir_predict'表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
    #----------------------------------------------------------------------------------------------------------#
    mode = "predict"
    #-------------------------------------------------------------------------#
    #   crop指定了是否在单张图片预测后对目标进行截取
    #   crop仅在mode='predict'时有效
    #-------------------------------------------------------------------------#
    crop            = False
    #----------------------------------------------------------------------------------------------------------#
    #   video_path用于指定视频的路径,当video_path=0时表示检测摄像头
    #   想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
    #   video_save_path表示视频保存的路径,当video_save_path=""时表示不保存
    #   想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
    #   video_fps用于保存的视频的fps
    #   video_path、video_save_path和video_fps仅在mode='video'时有效
    #   保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
    #----------------------------------------------------------------------------------------------------------#
    
    # ctrl + c 并不能终止退出,需要强杀。
    video_path      = 0
    video_save_path = ""
    video_fps       = 25.0
    
    #-------------------------------------------------------------------------#
    #   test_interval用于指定测量fps的时候,图片检测的次数
    #   理论上test_interval越大,fps越准确。
    #-------------------------------------------------------------------------#
    test_interval   = 100
    
    #-------------------------------------------------------------------------#
    #   dir_origin_path指定了用于检测的图片的文件夹路径
    #   dir_save_path指定了检测完图片的保存路径
    #   dir_origin_path和dir_save_path仅在mode='dir_predict'时有效
    #-------------------------------------------------------------------------#
    dir_origin_path = "img/"
    dir_save_path   = "img_out/"

    if mode == "predict":
        '''
        1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 
        2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
        3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
        在原图上利用矩阵的方式进行截取。
        4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断,
        比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
        '''
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except:
                print('Open Error! Try again!')
                continue
            else:
                # 核心:将图片送入网络进行预测,得到最终预测的图片
                r_image = yolo.detect_image(image, crop = crop)
                r_image.show()

    elif mode == "video":
        # 检测的是视频,我们就能根据video_path判断是检测摄像头还是MP4的视频
        capture = cv2.VideoCapture(video_path)
        
        # 保存视频
        if video_save_path!="":
            # 输出的配置参数
            fourcc  = cv2.VideoWriter_fourcc(*'XVID')
            size    = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            
            # 我们构建一个输出对象
            out     = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        # ref bool,看是否读取视频成功;frame 读取的帧,图片数据
        ref, frame = capture.read()
        if not ref:
            raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")

        fps = 0.0
        # 需要对视频流循环获取帧数据
        while(True):
            t1 = time.time()
            # 读取某一帧
            ref, frame = capture.read()
            
            # 当检测不到东西,直接跳出循环
            if not ref:
                break
            
            # 格式转变,BGRtoRGB
            frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
            
            # 转变成Image
            frame = Image.fromarray(np.uint8(frame))
            
            # 进行检测
            frame = np.array(yolo.detect_image(frame))
            
            # RGBtoBGR满足opencv显示格式
            frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
            
            fps  = ( fps + (1./(time.time()-t1)) ) / 2
            print("fps= %.2f"%(fps))
            
            # 在原有的帧上进行文本的绘制
            frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            
            cv2.imshow("video",frame)
            c= cv2.waitKey(1) & 0xff 
            if video_save_path!="":
                # 对于每一帧图片我们都写入文件中
                out.write(frame)
            
            # 27代表ctrl + c 当我们按下ctrl + c的时候 释放video对象
            if c==27:
                capture.release()
                break

        print("Video Detection Done!")
        # 当视频检测结束的时候,才会释放capture
        capture.release()
        # 保存视频
        if video_save_path!="":
            print("Save processed video to the path :" + video_save_path)
            out.release()
        # 最终销毁窗口对象即可
        cv2.destroyAllWindows()
        
    elif mode == "fps":
        # 就是进行test_interval多次的循环,计算单位时间能够检测的图片的数量
        img = Image.open('img/street.jpg')
        tact_time = yolo.get_FPS(img, test_interval)
        print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')

    elif mode == "dir_predict":
        import os

        from tqdm import tqdm
        # 我们获取图片文件夹下的所有文件名称
        img_names = os.listdir(dir_origin_path)
        for img_name in tqdm(img_names):
        
            # 图片检测类型在一下指定的格式范围内,否则直接跳过
            if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
                image_path  = os.path.join(dir_origin_path, img_name)
                image       = Image.open(image_path)
                # 检测,和上面一样
                r_image     = yolo.detect_image(image)
                if not os.path.exists(dir_save_path):
                    os.makedirs(dir_save_path)
                # 指定保存的图片的类型为png
                r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)

    else:
        # 检测类型应该在指定的范围内,否则引发exception
        raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.")