为什么使用检查点
人工智能中进行模型微调和调用模型,一个代码能跑几个小时,身为一个懒得配环境的懒狗,都是直接使用colab已经配好的环境,那么问题来了,在跑代码过程中一旦出现断网的情况,可能直接导致程序停止运行,就算你只差一分钟就跑完了,也只能认命重新开始,这样整的话可能一天时间都不够跑完一个模型。因此我们引入检查点(checkpoints)这个概念。
什么是检查点
不设置检查点的程序,就是运行完直接产出一个结果。设置检查点的程序,就是在程序运行过程中,在多个节点保存程序当前的状态和中间结果,如果程序中断或出错,允许程序从上一个检查点开始,而不是从头开始。
怎么用检查点
- 保存检查点:在代码适当位置使用文件操作函数保存当前状态到一个文件
- 加载检查点:在程序开始时,检查是否存在检查点文件,如果存在则加载该检查点,并从检查点继续执行。
import os
import pickle
checkpoint_file = '/content/checkpoint.pkl'
# 检查是否存在检查点
if os.path.exists(checkpoint_file):
# 加载检查点
with open(checkpoint_file, 'rb') as f:
checkpoint_data = pickle.load(f)
else:
checkpoint_data = None
# ... 运行你的代码 ...
# 保存检查点
with open(checkpoint_file, 'wb') as f:
pickle.dump(checkpoint_data, f)
实际案例
我是调用模型,判断5000多张图片是否有坑洼,在这里用来pytorch提供的内置检查点功能。
import torch
import pandas as pd
#保存检查点函数
def save_checkpoint(image_names, flags, checkpoint_file):
torch.save({
'image_names': image_names,
'flags': flags
}, checkpoint_file)
#加载检查点函数
def load_checkpoint(checkpoint_file):
if os.path.exists(checkpoint_file):
checkpoint = torch.load(checkpoint_file)
return checkpoint['image_names'], checkpoint['flags']
else:
return [], []
import os
from ultralytics import YOLO
# 加载模型
model = YOLO('/content/drive/MyDrive/roaddefect3(高低+翻).pt')
# 加载图片文件夹
image_folder = '/content/drive/MyDrive/testdata_V2'
checkpoint_file = '/content/drive/MyDrive/checkpoint.pth'
# 加载检查点 (如果存在)
image_names, flags = load_checkpoint(checkpoint_file)
# 获取剩余的图像名称
remaining_image_names = [name for name in os.listdir(image_folder) if name.endswith('.jpg') and name not in image_names]
# 迭代处理图像
for image_name in remaining_image_names:
image_path = os.path.join(image_folder, image_name)
# 执行推理
results = model(image_path)
# 确定 flag 值
flag = 1 if results[0] else 0
# 存储图像名称和标志
image_names.append(image_name)
flags.append(flag)
# 每处理100张图像,保存一个检查点
if len(image_names) % 100 == 0:
save_checkpoint(image_names, flags, checkpoint_file)
# 保存最终检查点
save_checkpoint(image_names, flags, checkpoint_file)
# 保存图片名称以及对应的Flag值
df = pd.DataFrame({'Image Name': image_names, 'Flag': flags})
# 将结果保存为excel文件
df.to_excel('/content/drive/MyDrive/output3.xlsx', index=False)