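Problem:
When training a model, TensorBoard is the usual way to watch how the loss evolves. When training on a remote server, however, TensorBoard is not always convenient. If all you want is a quick look at how the loss changes, what can you do?
Idea: The training log already records the loss. Read the loss values in the order they were logged and plot them. The key step is extracting the loss from each line, which regular expressions handle well. Other variables can be extracted the same way.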
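Each log line also carries a timestamp, so "in chronological order" can be made explicit. The following is a minimal sketch that assumes every total-loss line has the `[INFO] [YYYY-MM-DD HH:MM:SS,mmm] step: ..., loss: ...` shape of the dbnet fragment below. `LINE_RE` and `loss_in_time_order` are hypothetical names, not part of dbnet. Sorting only matters if lines can arrive out of order, for example when several log files are concatenated.

```python
import re
from datetime import datetime

# Hypothetical pattern for lines such as:
# [INFO] [2022-03-28 19:48:12,280] step: 468000, epoch: 313, loss: 0.651640, lr: 0.002897
# \b stops "loss: " from matching inside "bce_loss: " etc.
LINE_RE = re.compile(
    r'\[INFO\] \[(?P<time>[^\]]+)\] step: \d+, .*?\bloss: (?P<loss>[\d.]+)'
)


def loss_in_time_order(lines):
    """Return (timestamp, total_loss) pairs sorted by their log timestamp."""
    points = []
    for line in lines:
        m = LINE_RE.search(line)
        if m:
            # Timestamp format in the log: "2022-03-28 19:48:12,280" (comma before milliseconds)
            ts = datetime.strptime(m.group('time'), '%Y-%m-%d %H:%M:%S,%f')
            points.append((ts, float(m.group('loss'))))
    points.sort(key=lambda p: p[0])
    return points
```

Matplotlib accepts `datetime` objects as x values, so the timestamps can be plotted directly if you want wall-clock time on the x-axis instead of the record index.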
Solution:
For example, a dbnet training log contains the following fragment:
[INFO] [2022-03-28 19:48:12,280] step: 468000, epoch: 313, loss: 0.651640, lr: 0.002897
[INFO] [2022-03-28 19:48:12,285] bce_loss: 0.074000
[INFO] [2022-03-28 19:48:12,285] thresh_loss: 0.063169
[INFO] [2022-03-28 19:48:12,286] l1_loss: 0.021847
[INFO] [2022-03-28 19:52:10,484] step: 468450, epoch: 313, loss: 1.067670, lr: 0.002897
[INFO] [2022-03-28 19:52:10,485] bce_loss: 0.124831
[INFO] [2022-03-28 19:52:10,485] thresh_loss: 0.095843
[INFO] [2022-03-28 19:52:10,486] l1_loss: 0.034767
[INFO] [2022-03-28 19:52:29,176] Training epoch 314
The regular expressions for extracting loss, bce_loss, thresh_loss, and l1_loss are, respectively:
re.findall('.*loss: (.*), lr.*', row)
re.findall('.*bce_loss: (.*).*', row)
re.findall('.*thresh_loss: (.*).*', row)
re.findall('.*l1_loss: (.*).*', row)
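A quick way to check these patterns is to run them on lines copied from the fragment above. This small sketch also shows why the total-loss pattern does not pick up the component losses: it requires `, lr` right after the value. Because `.` does not match a newline, the trailing `\n` that `readlines()` keeps never ends up in the captured group.

```python
import re

# Two lines copied from the dbnet log fragment above (with the newline readlines() keeps)
step_line = '[INFO] [2022-03-28 19:48:12,280] step: 468000, epoch: 313, loss: 0.651640, lr: 0.002897\n'
bce_line = '[INFO] [2022-03-28 19:48:12,285] bce_loss: 0.074000\n'

# The same patterns as listed above
total = re.findall('.*loss: (.*), lr.*', step_line)
bce = re.findall('.*bce_loss: (.*).*', bce_line)
# The total-loss pattern needs ", lr" after the value, so it finds nothing on a bce_loss line
total_on_bce = re.findall('.*loss: (.*), lr.*', bce_line)

print(total)         # ['0.651640']
print(bce)           # ['0.074000']
print(total_on_bce)  # []
```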
Code:
import re
import matplotlib.pyplot as plt

# Read the whole training log, one entry per line
with open('output.log', 'r') as f:
    log = f.readlines()

steps = []
total_loss = []
bce_loss = []
thresh_loss = []
l1_loss = []
for row in log:
    if 'INFO' in row:
        # Total loss, e.g. "step: 468000, epoch: 313, loss: 0.651640, lr: 0.002897"
        tmp_total_loss = re.findall('.*loss: (.*), lr.*', row)
        if tmp_total_loss:
            total_loss.append(float(tmp_total_loss[0]))
        # Step number from the same line (dbnet writes "step: N", so an "iters: N, eta" pattern never matches)
        tmp_steps = re.findall(r'.*step: (\d+), epoch.*', row)
        if tmp_steps:
            steps.append(int(tmp_steps[0]))
    if 'bce_loss: ' in row:
        tmp_bce_loss = re.findall('.*bce_loss: (.*).*', row)
        if tmp_bce_loss:
            bce_loss.append(float(tmp_bce_loss[0].split(',')[0]))
    if 'thresh_loss: ' in row:
        tmp_thresh_loss = re.findall('.*thresh_loss: (.*).*', row)
        if tmp_thresh_loss:
            thresh_loss.append(float(tmp_thresh_loss[0].split(',')[0]))
    elif 'l1_loss: ' in row:
        tmp_l1_loss = re.findall('.*l1_loss: (.*).*', row)
        if tmp_l1_loss:
            l1_loss.append(float(tmp_l1_loss[0].split(',')[0]))

# Change the starting point of the plot: the loss fluctuates heavily early in training
_cut = 10
steps = steps[_cut:]
total_loss = total_loss[_cut:]
bce_loss = bce_loss[_cut:]
thresh_loss = thresh_loss[_cut:]
l1_loss = l1_loss[_cut:]
# Sanity check: the lists should have (nearly) the same length
print(len(total_loss))
print(len(bce_loss))
print(len(thresh_loss))
print(len(l1_loss))

# A single figure; each component loss is plotted against its record index (1, 2, 3, ...)
plt.figure(figsize=(20, 10))
y = range(1, len(bce_loss) + 1)
z = range(1, len(thresh_loss) + 1)
w = range(1, len(l1_loss) + 1)
plt.plot(w, l1_loss, color='m', label='l1_loss')
plt.plot(z, thresh_loss, color='g', label='thresh_loss')
# plt.plot(z, thresh_loss, color='coral')
plt.plot(y, bce_loss, color='dodgerblue', label='bce_loss')
# plt.plot(steps, total_loss, color='red', label='loss')
plt.legend()
# plt.show()
plt.savefig('train_loss.png', bbox_inches='tight')
# plt.savefig('total_loss.png', bbox_inches='tight')
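Takeaway: Don't depend too heavily on the tools themselves; it pays to understand how they work. Simple features like this one can be implemented yourself instead of relying on lots of wrapper packages.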