YOLOv3的源代码精度理解(八) LossHistory

203 阅读1分钟

代码主要是参考bubbliiing的github YOLOv3的代码:github.com/bubbliiiing…

对于源代码的解读

训练部分

callbacks.py文件

  • 主要可以对loss的可视化部分的代码进行学习,方便我们观察指标的变化,提前终止训练或者更换参数后继续训练
class LossHistory():
    # 初始化部分
    def __init__(self, log_dir, model, input_shape):
        # 获取时间字符串
        time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
        # 日志的路径
        self.log_dir    = os.path.join(log_dir, "loss_" + str(time_str))
        # 训练loss列表
        self.losses     = []
        # 验证loss列表
        self.val_loss   = []
        
        os.makedirs(self.log_dir)
        
        # SummaryWriter实例化,传入的是log的路径
        self.writer     = SummaryWriter(self.log_dir)
        try:
            # 生成一个假的输入,将其加入到SummaryWriter的add_graph中(why)
            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except:
            pass

    def append_loss(self, epoch, loss, val_loss):
        # 若不存在log路径,直接进行构建
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        # 将传入的loss加入我们的训练损失列表,传入的val_loss加入我们的验证损失列表
        self.losses.append(loss)
        self.val_loss.append(val_loss)

        # 打开epoch_loss、epoch_val_loss文件,对数据进行存储
        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")
        
        # 将我们所需要的数据保存在文件里面供可视化使用。 这里是Scalar类型,所以使用writer.add_scalar()
        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        
        # 进行loss的绘制
        self.loss_plot()

    def loss_plot(self):
        # 计算迭代次数
        iters = range(len(self.losses))
        
        # 开启画布
        plt.figure()
        # 绘制训练和验证的loss
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
        try:
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15
            
            # 进行图像平滑,scipy.signal.savgol_filter进行滤波
            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except:
            pass
        
        # 绘制网格
        plt.grid(True)
        # x轴标签
        plt.xlabel('Epoch')
        # y轴标签
        plt.ylabel('Loss')
        # 图例
        plt.legend(loc="upper right")
        # 保存成图片
        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")