如何在Colab中使用检查点

284 阅读2分钟

为什么使用检查点

人工智能中进行模型微调和调用模型,一个代码能跑几个小时,身为一个懒得配环境的懒狗,都是直接使用colab已经配好的环境,那么问题来了,在跑代码过程中一旦出现断网的情况,可能直接导致程序停止运行,就算你只差一分钟就跑完了,也只能认命重新开始,这样整的话可能一天时间都不够跑完一个模型。因此我们引入检查点(checkpoints)这个概念。

什么是检查点

不设置检查点的程序,就是运行完直接产出一个结果。设置检查点的程序,就是在程序运行过程中,在多个节点保存程序当前的状态和中间结果,如果程序中断或出错,允许程序从上一个检查点开始,而不是从头开始。

怎么用检查点

  1. 保存检查点:在代码适当位置使用文件操作函数保存当前状态到一个文件
  2. 加载检查点:在程序开始时,检查是否存在检查点文件,如果存在则加载该检查点,并从检查点继续执行。
import os  
import pickle  
  
checkpoint_file = '/content/checkpoint.pkl'  
  
# 检查是否存在检查点  
if os.path.exists(checkpoint_file):  
    # 加载检查点  
    with open(checkpoint_file, 'rb'as f:  
        checkpoint_data = pickle.load(f)  
else:  
    checkpoint_data = None  
  
# ... 运行你的代码 ...  
  
# 保存检查点  
with open(checkpoint_file, 'wb'as f:  
    pickle.dump(checkpoint_data, f)  

实际案例

我是调用模型,判断5000多张图片是否有坑洼,在这里用来pytorch提供的内置检查点功能。

import torch
import pandas as pd
#保存检查点函数
def save_checkpoint(image_names, flags, checkpoint_file):
    torch.save({
        'image_names': image_names,
        'flags': flags
    }, checkpoint_file)
#加载检查点函数
def load_checkpoint(checkpoint_file):
    if os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file)
        return checkpoint['image_names'], checkpoint['flags']
    else:
        return [], []
import os

from ultralytics import YOLO
# 加载模型

model = YOLO('/content/drive/MyDrive/roaddefect3(高低+翻).pt') 

# 加载图片文件夹

image_folder = '/content/drive/MyDrive/testdata_V2'
checkpoint_file = '/content/drive/MyDrive/checkpoint.pth' 

# 加载检查点 (如果存在)

image_names, flags = load_checkpoint(checkpoint_file)
# 获取剩余的图像名称
remaining_image_names = [name for name in os.listdir(image_folder) if name.endswith('.jpg') and name not in image_names]

# 迭代处理图像

for image_name in remaining_image_names:

    image_path = os.path.join(image_folder, image_name)
    # 执行推理

    results = model(image_path)

    # 确定 flag 值

    flag = 1 if results[0] else 0

    # 存储图像名称和标志

    image_names.append(image_name)

    flags.append(flag)

    # 每处理100张图像,保存一个检查点

    if len(image_names) % 100 == 0:

        save_checkpoint(image_names, flags, checkpoint_file)


# 保存最终检查点

save_checkpoint(image_names, flags, checkpoint_file)

  

# 保存图片名称以及对应的Flag值

df = pd.DataFrame({'Image Name': image_names, 'Flag': flags})

# 将结果保存为excel文件

df.to_excel('/content/drive/MyDrive/output3.xlsx', index=False)