代码主要是参考bubbliiing的github YOLOv3的代码:github.com/bubbliiiing…
对于源代码的解读
训练部分
callbacks.py文件
- 主要可以对loss的可视化部分的代码进行学习,方便我们观察指标的变化,提前终止训练或者更换参数后继续训练
class LossHistory():
# 初始化部分
def __init__(self, log_dir, model, input_shape):
# 获取时间字符串
time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
# 日志的路径
self.log_dir = os.path.join(log_dir, "loss_" + str(time_str))
# 训练loss列表
self.losses = []
# 验证loss列表
self.val_loss = []
os.makedirs(self.log_dir)
# SummaryWriter实例化,传入的是log的路径
self.writer = SummaryWriter(self.log_dir)
try:
# 生成一个假的输入,将其加入到SummaryWriter的add_graph中(why)
dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
self.writer.add_graph(model, dummy_input)
except:
pass
def append_loss(self, epoch, loss, val_loss):
# 若不存在log路径,直接进行构建
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
# 将传入的loss加入我们的训练损失列表,传入的val_loss加入我们的验证损失列表
self.losses.append(loss)
self.val_loss.append(val_loss)
# 打开epoch_loss、epoch_val_loss文件,对数据进行存储
with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
f.write(str(loss))
f.write("\n")
with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
f.write(str(val_loss))
f.write("\n")
# 将我们所需要的数据保存在文件里面供可视化使用。 这里是Scalar类型,所以使用writer.add_scalar()
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)
# 进行loss的绘制
self.loss_plot()
def loss_plot(self):
# 计算迭代次数
iters = range(len(self.losses))
# 开启画布
plt.figure()
# 绘制训练和验证的loss
plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
try:
if len(self.losses) < 25:
num = 5
else:
num = 15
# 进行图像平滑,scipy.signal.savgol_filter进行滤波
plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
except:
pass
# 绘制网格
plt.grid(True)
# x轴标签
plt.xlabel('Epoch')
# y轴标签
plt.ylabel('Loss')
# 图例
plt.legend(loc="upper right")
# 保存成图片
plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
plt.cla()
plt.close("all")