Samurai:Zero Shot 目标跟踪神器

1,357 阅读8分钟

简介

SAMURAI 是新型视觉目标追踪模型,它是 Segment Anything Model 2 (SAM 2)  的改进版本。虽然 SAM 2 在图像分割任务上表现出色,但在处理复杂场景下的视频目标追踪时存在不足。SAMURAI 通过引入时间运动信息和改进的记忆选择机制,克服了这些不足,实现了无需重新训练或微调即可进行高效、准确的实时目标追踪。

预览

image.png

环境准备

yangchris11/samurai: Official repository of "SAMURAI: Adapting Segment Anything Model for Zero-Shot Visual Tracking with Motion-Aware Memory"

第一步,我们要去官网下载源码,并确保 Python、pytorch 都已安装完毕,笔者就是因为 pytorch 版本过老环境准备失败多次。(python>=3.10, as well as torch>=2.3.1 and torchvision>=0.18.1

第二步,我们开始安装各项依赖并下载模型

cd sam2
pip install -e .
pip install -e ".[notebooks]"
pip install matplotlib==3.7 tikzplotlib jpeg4py opencv-python lmdb pandas scipy loguru
cd ..
cd checkpoints && \
./download_ckpts.sh && \
cd ..

第三步,安装完后,便可以进行推理运行。注意txt是框选的坐标,模型会自动跟踪框选坐标中的目标。

python scripts/demo.py --video_path <your_video.mp4> --txt_path <path_to_first_frame_bbox.txt>

源码分析

import argparse
import os
import os.path as osp
import numpy as np
import cv2
import torch
import gc
import sys
sys.path.append("./sam2")
from sam2.build_sam import build_sam2_video_predictor

color = [(255, 0, 0)]

def load_txt(gt_path):
    with open(gt_path, 'r') as f:
        gt = f.readlines()
    prompts = {}
    for fid, line in enumerate(gt):
        x, y, w, h = map(float, line.split(','))
        x, y, w, h = int(x), int(y), int(w), int(h)
        prompts[fid] = ((x, y, x + w, y + h), 0)
    return prompts

def determine_model_cfg(model_path):
    if "large" in model_path:
        return "configs/samurai/sam2.1_hiera_l.yaml"
    elif "base_plus" in model_path:
        return "configs/samurai/sam2.1_hiera_b+.yaml"
    elif "small" in model_path:
        return "configs/samurai/sam2.1_hiera_s.yaml"
    elif "tiny" in model_path:
        return "configs/samurai/sam2.1_hiera_t.yaml"
    else:
        raise ValueError("Unknown model size in path!")

def prepare_frames_or_path(video_path):
    if video_path.endswith(".mp4") or osp.isdir(video_path):
        return video_path
    else:
        raise ValueError("Invalid video_path format. Should be .mp4 or a directory of jpg frames.")

def main(args):
    model_cfg = determine_model_cfg(args.model_path)
    predictor = build_sam2_video_predictor(model_cfg, args.model_path, device="cuda:0")
    frames_or_path = prepare_frames_or_path(args.video_path)
    prompts = load_txt(args.txt_path)

    if args.save_to_video:
        if osp.isdir(args.video_path):
            frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith(".jpg")])
            loaded_frames = [cv2.imread(frame_path) for frame_path in frames]
            height, width = loaded_frames[0].shape[:2]
        else:
            cap = cv2.VideoCapture(args.video_path)
            loaded_frames = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                loaded_frames.append(frame)
            cap.release()
            height, width = loaded_frames[0].shape[:2]

            if len(loaded_frames) == 0:
                raise ValueError("No frames were loaded from the video.")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(args.video_output_path, fourcc, 30, (width, height))

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)
        bbox, track_label = prompts[0]
        _, _, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)

        for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
            mask_to_vis = {}
            bbox_to_vis = {}

            for obj_id, mask in zip(object_ids, masks):
                mask = mask[0].cpu().numpy()
                mask = mask > 0.0
                non_zero_indices = np.argwhere(mask)
                if len(non_zero_indices) == 0:
                    bbox = [0, 0, 0, 0]
                else:
                    y_min, x_min = non_zero_indices.min(axis=0).tolist()
                    y_max, x_max = non_zero_indices.max(axis=0).tolist()
                    bbox = [x_min, y_min, x_max - x_min, y_max - y_min]
                bbox_to_vis[obj_id] = bbox
                mask_to_vis[obj_id] = mask

            if args.save_to_video:
                img = loaded_frames[frame_idx]
                for obj_id, mask in mask_to_vis.items():
                    mask_img = np.zeros((height, width, 3), np.uint8)
                    mask_img[mask] = color[(obj_id + 1) % len(color)]
                    img = cv2.addWeighted(img, 1, mask_img, 0.2, 0)

                for obj_id, bbox in bbox_to_vis.items():
                    cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), color[obj_id % len(color)], 2)

                out.write(img)

        if args.save_to_video:
            out.release()

    del predictor, state
    gc.collect()
    torch.clear_autocast_cache()
    torch.cuda.empty_cache()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", required=True, help="Input video path or directory of frames.")
    parser.add_argument("--txt_path", required=True, help="Path to ground truth text file.")
    parser.add_argument("--model_path", default="sam2/checkpoints/sam2.1_hiera_base_plus.pt", help="Path to the model checkpoint.")
    parser.add_argument("--video_output_path", default="demo.mp4", help="Path to save the output video.")
    parser.add_argument("--save_to_video", default=True, help="Save results to a video.")
    args = parser.parse_args()
    main(args)

这段代码看起来是一个集视频处理、目标跟踪、掩码生成于一体的AI处理流水线。


1. 引入模块:召唤江湖中的“神器”

import argparse
import os
import os.path as osp
import numpy as np
import cv2
import torch
import gc
import sys
sys.path.append("./sam2")
from sam2.build_sam import build_sam2_video_predictor

这里是召唤“神器”的合集,风格各异:

  • argparse:专门用来让玩家(工程师)输入指令参数,类似“武林秘籍的目录”。
  • osos.path:文件系统的门派高手,帮你查遍本地文件路径。
  • numpy:数学大拿,辅助各种矩阵操作。
  • cv2:江湖人称“OpenCV”,视觉处理的扛把子。
  • torch:深度学习界的内功心法。
  • gc:清理垃圾的环保小分队。
  • sys:负责扩展“门派地盘”(添加路径)。

最后,重点来了:build_sam2_video_predictor,这是本代码的大魔王,负责用“武士刀”(SAM)在视频中劈出目标的踪迹。


2. 读取坐标:目标的藏宝图

def load_txt(gt_path):
    with open(gt_path, 'r') as f:
        gt = f.readlines()
    prompts = {}
    for fid, line in enumerate(gt):
        x, y, w, h = map(float, line.split(','))
        x, y, w, h = int(x), int(y), int(w), int(h)
        prompts[fid] = ((x, y, x + w, y + h), 0)
    return prompts

这个函数是用来读取“藏宝图”的,也就是目标框的坐标信息。
过程翻译如下:

  • 打开一个“藏宝图”(txt文件),一行一行读。
  • 每行都像是“藏宝点”的坐标,格式是:x, y, w, h,分别是左上角和宽高。
  • 转换成整数后存进一个字典,fid是帧号,值是框的坐标。

总结:这是在告诉武士“要在哪找目标”。


3. 配置模型:挑选合适的武士刀

def determine_model_cfg(model_path):
    if "large" in model_path:
        return "configs/samurai/sam2.1_hiera_l.yaml"
    elif "base_plus" in model_path:
        return "configs/samurai/sam2.1_hiera_b+.yaml"
    elif "small" in model_path:
        return "configs/samurai/sam2.1_hiera_s.yaml"
    elif "tiny" in model_path:
        return "configs/samurai/sam2.1_hiera_t.yaml"
    else:
        raise ValueError("Unknown model size in path!")

这里是选刀的过程,程序根据model_path里的关键词(如largesmall)挑选对应的配置文件。
如果你的路径里没有这些关键词,直接抛个错误:“你这武士刀是哪来的?!”


4. 视频预处理:确认战场

def prepare_frames_or_path(video_path):
    if video_path.endswith(".mp4") or osp.isdir(video_path):
        return video_path
    else:
        raise ValueError("Invalid video_path format. Should be .mp4 or a directory of jpg frames.")

这个函数检查战场环境:

  • 如果是个.mp4文件:战场是视频。
  • 如果是个文件夹:战场是一堆图像帧。
  • 如果都不是,直接拒绝开战:“这不是个合法战场!”

5. 主函数:武士大显身手

def main(args):

主函数是整个武士操作的核心,包含以下几步:

(1) 准备武器和战场

model_cfg = determine_model_cfg(args.model_path)
predictor = build_sam2_video_predictor(model_cfg, args.model_path, device="cuda:0")
frames_or_path = prepare_frames_or_path(args.video_path)
prompts = load_txt(args.txt_path)

这就像是开战前的准备工作:

  • 挑选武器:根据args.model_path,选择合适的模型配置文件(大刀还是小刀?)。
  • 召唤武士:通过build_sam2_video_predictor召唤模型,这个武士在战场中负责追踪敌人(目标)。
  • 确定战场:视频或图像帧——到底要在哪开打?prepare_frames_or_path帮你搞定。
  • 解读藏宝图:从txt_path读取目标框坐标,告诉武士“敌人初始位置在哪”。

(2) 准备输出:保存战斗结果

if args.save_to_video:
    if osp.isdir(args.video_path):
        frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith(".jpg")])
        loaded_frames = [cv2.imread(frame_path) for frame_path in frames]
        height, width = loaded_frames[0].shape[:2]
    else:
        cap = cv2.VideoCapture(args.video_path)
        loaded_frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            loaded_frames.append(frame)
        cap.release()
        height, width = loaded_frames[0].shape[:2]

        if len(loaded_frames) == 0:
            raise ValueError("No frames were loaded from the video.")

这段代码在做什么呢?给战斗录个视频,方便回头看看战绩:

  1. 如果战场是图像帧
    • 把文件夹里的.jpg图像按顺序加载。
    • 确认战场的大小(高度和宽度)。
  2. 如果战场是视频
    • cv2.VideoCapture逐帧读取视频。
    • 确保加载到的帧数不为空,否则直接扔个错误:“视频里啥都没有,这仗还打个毛!”。

然后根据读取的第一帧,获取战场的尺寸(heightwidth),为后续输出视频做准备。


(3) 开始战斗:用武士刀劈出结果

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(args.video_output_path, fourcc, 30, (width, height))

这里是“战斗记录仪”的设置:

  • 定义视频编码格式(mp4v),确保输出结果是个.mp4视频。
  • 创建一个cv2.VideoWriter对象,用来逐帧保存战斗的“录像”。

(4) 初始化武士状态:准备开打
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
    state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)
    bbox, track_label = prompts[0]
    _, _, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)

这段代码是武士战斗的开场仪式:

  1. 省电模式torch.inference_mode()torch.autocast让模型推理更高效,节省计算资源。
  2. 初始化战斗状态predictor.init_state告诉武士“战场在哪里”,是否需要把视频临时存放到CPU。
  3. 给敌人标记起点:从prompts中取出第一个帧的目标框(bbox),并告诉武士“这是敌人的初始位置”。
  4. 生成初始掩码:调用predictor.add_new_points_or_box,用目标框在第一帧生成目标的分割掩码。
(5) 逐帧追击敌人

接下来是战斗的核心逻辑——逐帧追踪目标,同时生成每帧的分割掩码和目标框:

for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
    mask_to_vis = {}
    bbox_to_vis = {}
  • predictor.propagate_in_video(state) 是武士的追踪大招,它会逐帧传播目标的状态,返回:
    • frame_idx: 当前帧的索引。
    • object_ids: 当前帧中所有被追踪的目标 ID。
    • masks: 这些目标对应的分割掩码。
  • 每一帧会初始化两个字典:
    • mask_to_vis: 用来保存可视化的掩码。
    • bbox_to_vis: 用来保存可视化的目标框。
(6) 处理掩码和生成目标框
    for obj_id, mask in zip(object_ids, masks):
        mask = mask[0].cpu().numpy()
        mask = mask > 0.0
        non_zero_indices = np.argwhere(mask)
        if len(non_zero_indices) == 0:
            bbox = [0, 0, 0, 0]
        else:
            y_min, x_min = non_zero_indices.min(axis=0).tolist()
            y_max, x_max = non_zero_indices.max(axis=0).tolist()
            bbox = [x_min, y_min, x_max - x_min, y_max - y_min]
        bbox_to_vis[obj_id] = bbox
        mask_to_vis[obj_id] = mask

这段代码是处理每帧的掩码和生成目标框的逻辑:

  1. 提取掩码

    • mask[0].cpu().numpy():将掩码从 GPU 中取出并转换为 NumPy 格式。
    • mask > 0.0:将掩码转化为二值化布尔数组,表示哪些像素属于目标。
    • np.argwhere(mask):找出掩码中非零(即目标像素)的位置索引。
  2. 生成目标框

    • 如果掩码中没有任何目标像素(len(non_zero_indices) == 0),目标框设置为 [0, 0, 0, 0],表示当前帧没有检测到目标。

总结

分析完源码,大家就可以应用到诸如监控视频回放系统中,但要注意目前该模型是不支持实时推理的。