简介
SAMURAI 是新型视觉目标追踪模型,它是 Segment Anything Model 2 (SAM 2) 的改进版本。虽然 SAM 2 在图像分割任务上表现出色,但在处理复杂场景下的视频目标追踪时存在不足。SAMURAI 通过引入时间运动信息和改进的记忆选择机制,克服了这些不足,实现了无需重新训练或微调即可进行高效、准确的实时目标追踪。
预览
环境准备
第一步,我们要去官网下载源码,并确保 Python、pytorch 都已安装完毕,笔者就是因为 pytorch 版本过老环境准备失败多次。(python>=3.10, as well as torch>=2.3.1 and torchvision>=0.18.1)
第二步,我们开始安装各项依赖并下载模型
cd sam2
pip install -e .
pip install -e ".[notebooks]"
pip install matplotlib==3.7 tikzplotlib jpeg4py opencv-python lmdb pandas scipy loguru
cd ..
cd checkpoints && \
./download_ckpts.sh && \
cd ..
第三步,安装完后,便可以进行推理运行。注意txt是框选的坐标,模型会自动跟踪框选坐标中的目标。
python scripts/demo.py --video_path <your_video.mp4> --txt_path <path_to_first_frame_bbox.txt>
源码分析
import argparse
import os
import os.path as osp
import numpy as np
import cv2
import torch
import gc
import sys
sys.path.append("./sam2")
from sam2.build_sam import build_sam2_video_predictor
color = [(255, 0, 0)]
def load_txt(gt_path):
with open(gt_path, 'r') as f:
gt = f.readlines()
prompts = {}
for fid, line in enumerate(gt):
x, y, w, h = map(float, line.split(','))
x, y, w, h = int(x), int(y), int(w), int(h)
prompts[fid] = ((x, y, x + w, y + h), 0)
return prompts
def determine_model_cfg(model_path):
if "large" in model_path:
return "configs/samurai/sam2.1_hiera_l.yaml"
elif "base_plus" in model_path:
return "configs/samurai/sam2.1_hiera_b+.yaml"
elif "small" in model_path:
return "configs/samurai/sam2.1_hiera_s.yaml"
elif "tiny" in model_path:
return "configs/samurai/sam2.1_hiera_t.yaml"
else:
raise ValueError("Unknown model size in path!")
def prepare_frames_or_path(video_path):
if video_path.endswith(".mp4") or osp.isdir(video_path):
return video_path
else:
raise ValueError("Invalid video_path format. Should be .mp4 or a directory of jpg frames.")
def main(args):
model_cfg = determine_model_cfg(args.model_path)
predictor = build_sam2_video_predictor(model_cfg, args.model_path, device="cuda:0")
frames_or_path = prepare_frames_or_path(args.video_path)
prompts = load_txt(args.txt_path)
if args.save_to_video:
if osp.isdir(args.video_path):
frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith(".jpg")])
loaded_frames = [cv2.imread(frame_path) for frame_path in frames]
height, width = loaded_frames[0].shape[:2]
else:
cap = cv2.VideoCapture(args.video_path)
loaded_frames = []
while True:
ret, frame = cap.read()
if not ret:
break
loaded_frames.append(frame)
cap.release()
height, width = loaded_frames[0].shape[:2]
if len(loaded_frames) == 0:
raise ValueError("No frames were loaded from the video.")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(args.video_output_path, fourcc, 30, (width, height))
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)
bbox, track_label = prompts[0]
_, _, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
mask_to_vis = {}
bbox_to_vis = {}
for obj_id, mask in zip(object_ids, masks):
mask = mask[0].cpu().numpy()
mask = mask > 0.0
non_zero_indices = np.argwhere(mask)
if len(non_zero_indices) == 0:
bbox = [0, 0, 0, 0]
else:
y_min, x_min = non_zero_indices.min(axis=0).tolist()
y_max, x_max = non_zero_indices.max(axis=0).tolist()
bbox = [x_min, y_min, x_max - x_min, y_max - y_min]
bbox_to_vis[obj_id] = bbox
mask_to_vis[obj_id] = mask
if args.save_to_video:
img = loaded_frames[frame_idx]
for obj_id, mask in mask_to_vis.items():
mask_img = np.zeros((height, width, 3), np.uint8)
mask_img[mask] = color[(obj_id + 1) % len(color)]
img = cv2.addWeighted(img, 1, mask_img, 0.2, 0)
for obj_id, bbox in bbox_to_vis.items():
cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), color[obj_id % len(color)], 2)
out.write(img)
if args.save_to_video:
out.release()
del predictor, state
gc.collect()
torch.clear_autocast_cache()
torch.cuda.empty_cache()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--video_path", required=True, help="Input video path or directory of frames.")
parser.add_argument("--txt_path", required=True, help="Path to ground truth text file.")
parser.add_argument("--model_path", default="sam2/checkpoints/sam2.1_hiera_base_plus.pt", help="Path to the model checkpoint.")
parser.add_argument("--video_output_path", default="demo.mp4", help="Path to save the output video.")
parser.add_argument("--save_to_video", default=True, help="Save results to a video.")
args = parser.parse_args()
main(args)
这段代码看起来是一个集视频处理、目标跟踪、掩码生成于一体的AI处理流水线。
1. 引入模块:召唤江湖中的“神器”
import argparse
import os
import os.path as osp
import numpy as np
import cv2
import torch
import gc
import sys
sys.path.append("./sam2")
from sam2.build_sam import build_sam2_video_predictor
这里是召唤“神器”的合集,风格各异:
argparse:专门用来让玩家(工程师)输入指令参数,类似“武林秘籍的目录”。os和os.path:文件系统的门派高手,帮你查遍本地文件路径。numpy:数学大拿,辅助各种矩阵操作。cv2:江湖人称“OpenCV”,视觉处理的扛把子。torch:深度学习界的内功心法。gc:清理垃圾的环保小分队。sys:负责扩展“门派地盘”(添加路径)。
最后,重点来了:build_sam2_video_predictor,这是本代码的大魔王,负责用“武士刀”(SAM)在视频中劈出目标的踪迹。
2. 读取坐标:目标的藏宝图
def load_txt(gt_path):
with open(gt_path, 'r') as f:
gt = f.readlines()
prompts = {}
for fid, line in enumerate(gt):
x, y, w, h = map(float, line.split(','))
x, y, w, h = int(x), int(y), int(w), int(h)
prompts[fid] = ((x, y, x + w, y + h), 0)
return prompts
这个函数是用来读取“藏宝图”的,也就是目标框的坐标信息。
过程翻译如下:
- 打开一个“藏宝图”(txt文件),一行一行读。
- 每行都像是“藏宝点”的坐标,格式是:
x, y, w, h,分别是左上角和宽高。 - 转换成整数后存进一个字典,
fid是帧号,值是框的坐标。
总结:这是在告诉武士“要在哪找目标”。
3. 配置模型:挑选合适的武士刀
def determine_model_cfg(model_path):
if "large" in model_path:
return "configs/samurai/sam2.1_hiera_l.yaml"
elif "base_plus" in model_path:
return "configs/samurai/sam2.1_hiera_b+.yaml"
elif "small" in model_path:
return "configs/samurai/sam2.1_hiera_s.yaml"
elif "tiny" in model_path:
return "configs/samurai/sam2.1_hiera_t.yaml"
else:
raise ValueError("Unknown model size in path!")
这里是选刀的过程,程序根据model_path里的关键词(如large、small)挑选对应的配置文件。
如果你的路径里没有这些关键词,直接抛个错误:“你这武士刀是哪来的?!”
4. 视频预处理:确认战场
def prepare_frames_or_path(video_path):
if video_path.endswith(".mp4") or osp.isdir(video_path):
return video_path
else:
raise ValueError("Invalid video_path format. Should be .mp4 or a directory of jpg frames.")
这个函数检查战场环境:
- 如果是个
.mp4文件:战场是视频。 - 如果是个文件夹:战场是一堆图像帧。
- 如果都不是,直接拒绝开战:“这不是个合法战场!”
5. 主函数:武士大显身手
def main(args):
主函数是整个武士操作的核心,包含以下几步:
(1) 准备武器和战场
model_cfg = determine_model_cfg(args.model_path)
predictor = build_sam2_video_predictor(model_cfg, args.model_path, device="cuda:0")
frames_or_path = prepare_frames_or_path(args.video_path)
prompts = load_txt(args.txt_path)
这就像是开战前的准备工作:
- 挑选武器:根据
args.model_path,选择合适的模型配置文件(大刀还是小刀?)。 - 召唤武士:通过
build_sam2_video_predictor召唤模型,这个武士在战场中负责追踪敌人(目标)。 - 确定战场:视频或图像帧——到底要在哪开打?
prepare_frames_or_path帮你搞定。 - 解读藏宝图:从
txt_path读取目标框坐标,告诉武士“敌人初始位置在哪”。
(2) 准备输出:保存战斗结果
if args.save_to_video:
if osp.isdir(args.video_path):
frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith(".jpg")])
loaded_frames = [cv2.imread(frame_path) for frame_path in frames]
height, width = loaded_frames[0].shape[:2]
else:
cap = cv2.VideoCapture(args.video_path)
loaded_frames = []
while True:
ret, frame = cap.read()
if not ret:
break
loaded_frames.append(frame)
cap.release()
height, width = loaded_frames[0].shape[:2]
if len(loaded_frames) == 0:
raise ValueError("No frames were loaded from the video.")
这段代码在做什么呢?给战斗录个视频,方便回头看看战绩:
- 如果战场是图像帧:
- 把文件夹里的
.jpg图像按顺序加载。 - 确认战场的大小(高度和宽度)。
- 把文件夹里的
- 如果战场是视频:
- 用
cv2.VideoCapture逐帧读取视频。 - 确保加载到的帧数不为空,否则直接扔个错误:“视频里啥都没有,这仗还打个毛!”。
- 用
然后根据读取的第一帧,获取战场的尺寸(height和width),为后续输出视频做准备。
(3) 开始战斗:用武士刀劈出结果
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(args.video_output_path, fourcc, 30, (width, height))
这里是“战斗记录仪”的设置:
- 定义视频编码格式(
mp4v),确保输出结果是个.mp4视频。 - 创建一个
cv2.VideoWriter对象,用来逐帧保存战斗的“录像”。
(4) 初始化武士状态:准备开打
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)
bbox, track_label = prompts[0]
_, _, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)
这段代码是武士战斗的开场仪式:
- 省电模式:
torch.inference_mode()和torch.autocast让模型推理更高效,节省计算资源。 - 初始化战斗状态:
predictor.init_state告诉武士“战场在哪里”,是否需要把视频临时存放到CPU。 - 给敌人标记起点:从
prompts中取出第一个帧的目标框(bbox),并告诉武士“这是敌人的初始位置”。 - 生成初始掩码:调用
predictor.add_new_points_or_box,用目标框在第一帧生成目标的分割掩码。
(5) 逐帧追击敌人
接下来是战斗的核心逻辑——逐帧追踪目标,同时生成每帧的分割掩码和目标框:
for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
mask_to_vis = {}
bbox_to_vis = {}
predictor.propagate_in_video(state)是武士的追踪大招,它会逐帧传播目标的状态,返回:frame_idx: 当前帧的索引。object_ids: 当前帧中所有被追踪的目标 ID。masks: 这些目标对应的分割掩码。
- 每一帧会初始化两个字典:
mask_to_vis: 用来保存可视化的掩码。bbox_to_vis: 用来保存可视化的目标框。
(6) 处理掩码和生成目标框
for obj_id, mask in zip(object_ids, masks):
mask = mask[0].cpu().numpy()
mask = mask > 0.0
non_zero_indices = np.argwhere(mask)
if len(non_zero_indices) == 0:
bbox = [0, 0, 0, 0]
else:
y_min, x_min = non_zero_indices.min(axis=0).tolist()
y_max, x_max = non_zero_indices.max(axis=0).tolist()
bbox = [x_min, y_min, x_max - x_min, y_max - y_min]
bbox_to_vis[obj_id] = bbox
mask_to_vis[obj_id] = mask
这段代码是处理每帧的掩码和生成目标框的逻辑:
-
提取掩码:
mask[0].cpu().numpy():将掩码从 GPU 中取出并转换为 NumPy 格式。mask > 0.0:将掩码转化为二值化布尔数组,表示哪些像素属于目标。np.argwhere(mask):找出掩码中非零(即目标像素)的位置索引。
-
生成目标框:
- 如果掩码中没有任何目标像素(
len(non_zero_indices) == 0),目标框设置为 [0, 0, 0, 0],表示当前帧没有检测到目标。
- 如果掩码中没有任何目标像素(
总结
分析完源码,大家就可以应用到诸如监控视频回放系统中,但要注意目前该模型是不支持实时推理的。