Python OnnxRuntime 部署 INSID3

3 阅读4分钟

说明

官网地址:github.com/visinf/INSI…

在计算机视觉领域,图像分割一直是一项极具挑战性的任务。传统方法往往需要大量标注数据、长时间模型训练,以及针对不同物体类别的反复微调。但今天,这一切将被彻底改变——INSID3 横空出世,它以一种全新的 In-Context(上下文学习) 范式,让你仅需提供一张参考图像和其掩码,就能在任意目标图像中精准分割出同款物体,完全无需任何训练!

更令人兴奋的是,通过 ONNX Runtime 的加持,我们可以将 INSID3 高效地部署 到生产环境中,实现跨平台、跨语言的高性能推理。本文将从零开始,带你了解 INSID3 的原理、优势,并手把手教你如何使用 Python + ONNX Runtime 快速部署 INSID3,让你轻松拥有一次标注、处处分割的“超能力”。

INSID3(In-context Segmentation with a Non‑trainable DINOv3 encoder)是一种无需训练的分割模型,它的核心思想非常简单:

利用一个冻结的大型视觉编码器(DINOv3)提取图像特征,然后通过巧妙的特征匹配和聚类机制,将参考物体的语义“传递”到目标图像上。

换句话说,你只需要在参考图上勾画或选择一个物体,INSID3 就能自动在其他图像中找到同样的物体,并生成精确的分割掩码。整个过程没有训练、没有梯度下降、不依赖特定类别,真正做到了开箱即用、万物皆可分割。

效果

图片

模型信息

Model Properties
-------------------------
---------------------------------------------------------------

Inputs
-------------------------
name:imgs
tensor:Float[-1, 3, 1024, 1024]
---------------------------------------------------------------

Outputs
-------------------------
name:f_norm
tensor:Float[-1, 1024, 64, 64]
name:f_debias
tensor:Float[-1, 1024, 64, 64]
---------------------------------------------------------------

代码

class INSID3App:
    def __init__(self, root, onnx_path="insid3_encoder_nobatch.onnx"):
        self.root = root
        self.root.title("INSID3 交互式分割")
        self.onnx_path = onnx_path
        self.ref_orig = None
        self.tgt_orig = None
        self.ref_disp = None
        self.tgt_disp = None
        self.mask_np = None
        self.mask_pil = None
        self.result = None
        self.display_size = 500
        self.points = []
        self.cached_model = None
        self._build_ui()

    def _build_ui(self):
        toolbar = tk.Frame(self.root)
        toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
        for text, cmd in [("加载参考图", self.load_ref), ("加载目标图", self.load_tgt),
                          ("清除多边形", self.clear_poly), ("生成掩码", self.gen_mask),
                          ("缓存参考", self.cache_ref), ("推理", self.run_inference),
                          ("保存结果", self.save_result)]:
            tk.Button(toolbar, text=text, command=cmd).pack(side=tk.LEFT, padx=2)
        self.info_var = tk.StringVar(value="选点: 0")
        tk.Label(toolbar, textvariable=self.info_var).pack(side=tk.LEFT, padx=10)

        cf = tk.Frame(self.root); cf.pack(fill=tk.BOTH, expand=True)
        self.cv_ref = tk.Canvas(cf, width=self.display_size, height=self.display_size,
                                bg='gray', cursor="cross")
        self.cv_ref.grid(row=0, column=0, padx=5)
        tk.Label(cf, text="参考图 (左键加点 / 右键删点)").grid(row=1, column=0)
        self.cv_tgt = tk.Canvas(cf, width=self.display_size, height=self.display_size, bg='gray')
        self.cv_tgt.grid(row=0, column=1, padx=5)
        tk.Label(cf, text="目标图").grid(row=1, column=1)
        self.cv_res = tk.Canvas(cf, width=self.display_size, height=self.display_size, bg='gray')
        self.cv_res.grid(row=0, column=2, padx=5)
        tk.Label(cf, text="分割结果").grid(row=1, column=2)

        self.cv_ref.bind("<Button-1>", self.add_point)
        self.cv_ref.bind("<Button-3>", self.rm_last)
        self.prog = ttk.Progressbar(self.root, mode='indeterminate')
        self.prog.pack(fill=tk.X)
        self.stvar = tk.StringVar(value="就绪")
        tk.Label(self.root, textvariable=self.stvar, bd=1, relief=tk.SUNKEN, anchor=tk.W).pack(fill=tk.X)

    def load_ref(self):
        path = filedialog.askopenfilename(filetypes=[("Image""*.jpg *.jpeg *.png *.bmp")])
        if not path: return
        self.ref_orig = Image.open(path).convert("RGB")
        self.ref_disp, _ = letterbox_image(self.ref_orig, self.display_size)
        self.points.clear(); self.mask_np = None; self.cached_model = None
        self._redraw(); self.stvar.set(f"参考: {path}")

    def _redraw(self):
        if self.ref_disp is Nonereturn
        base = self.ref_disp.copy().convert("RGBA")
        if self.mask_np is not None and self.mask_np.sum() > 0:
            base = self._overlay(base, self.mask_np)
        draw = ImageDraw.Draw(base)
        for x,y in self.points:
            r = 4; draw.ellipse((x-r, y-r, x+r, y+r), fill='green')
        if len(self.points) > 1:
            draw.line(self.points, fill='yellow', width=2)
        self.ref_tk = ImageTk.PhotoImage(base)
        self.cv_ref.delete("all"); self.cv_ref.create_image(0,0,anchor=tk.NW, image=self.ref_tk)
        self.info_var.set(f"选点:{len(self.points)}")

    def _overlay(self, base, mask_1024):
        small = Image.fromarray(mask_1024).resize((self.display_size, self.display_size), Image.NEAREST)
        ov = Image.new("RGBA", (self.display_size, self.display_size), (0,0,0,0))
        for y in range(self.display_size):
            for x in range(self.display_size):
                if small.getpixel((x,y)) > 128: ov.putpixel((x,y), (255,0,0,100))
        return Image.alpha_composite(base, ov)

    def load_tgt(self):
        path = filedialog.askopenfilename(filetypes=[("Image""*.jpg *.jpeg *.png *.bmp")])
        if not path: return
        self.tgt_orig = Image.open(path).convert("RGB")
        self.tgt_disp, _ = letterbox_image(self.tgt_orig, self.display_size)
        self.tgt_tk = ImageTk.PhotoImage(self.tgt_disp)
        self.cv_tgt.delete("all"); self.cv_tgt.create_image(0,0,anchor=tk.NW, image=self.tgt_tk)
        self.stvar.set(f"目标: {path}")

    def add_point(self, event): self.points.append((event.x, event.y)); self._redraw()
    def rm_last(self, event):
        if self.points: self.points.pop(); self._redraw()
    def clear_poly(self): self.points.clear(); self._redraw()

    def gen_mask(self):
        if len(self.points) < 3: messagebox.showwarning("警告""至少需要3个顶点"); return
        scale_x = 1024 / self.display_size
        scale_y = 1024 / self.display_size
        pts = [(int(x*scale_x), int(y*scale_y)) for x,y in self.points]
        mask = np.zeros((1024,1024), dtype=np.uint8)
        cv2.fillPoly(mask, [np.array(pts, dtype=np.int32)], 255)
        self.mask_np = mask
        self.mask_pil = Image.fromarray(mask, 'L')
        self._redraw(); self.stvar.set("掩码已生成")

    def cache_ref(self):
        if self.ref_orig is None or self.mask_pil is None:
            messagebox.showwarning("缺少数据""请先加载参考图并生成掩码"); return
        try:
            self.cached_model = OnnxINSID3(self.onnx_path, device='cuda')
            self.cached_model.encode_reference(self.ref_orig, self.mask_pil)
            self.stvar.set("参考特征已缓存,可连续推理")
        except Exception as e: messagebox.showerror("缓存失败", str(e))

    def run_inference(self):
        if self.ref_orig is None or self.mask_pil is None or self.tgt_orig is None:
            messagebox.showwarning("缺少数据""请加载参考图、掩码和目标图"); return
        self.cv_res.delete("all"); self.prog.start(); self.stvar.set("推理中...")
        threading.Thread(target=self._inf, daemon=True).start()

    def _inf(self):
        start = time.time()
        try:
            if self.cached_model and self.cached_model._cached_ref_norm is not None:
                result_pil = self.cached_model.segment_with_cache(self.tgt_orig)
            else:
                model = OnnxINSID3(self.onnx_path, device='cuda')
                model.set_reference(self.ref_orig, self.mask_pil)
                model.set_target(self.tgt_orig)
                result_pil = model.segment()
            self.result = result_pil
            elapsed = time.time() - start
            self.root.after(0, self._show)
            self.root.after(0, lambda e=elapsed: self.stvar.set(f"完成,耗时 {e:.2f}s"))
        except Exception as e:
            self.root.after(0, lambda e=e: messagebox.showerror("推理错误", str(e)))
        finally:
            self.root.after(0, self.prog.stop)

    def _show(self):
        if self.result is Nonereturn
        tgt_disp = self.tgt_orig.resize((self.display_size, self.display_size))
        res_disp = self.result.resize((self.display_size, self.display_size))
        base = tgt_disp.convert("RGBA")
        ov = Image.new("RGBA", (self.display_size, self.display_size), (0,0,0,0))
        arr = np.array(res_disp) > 128
        for y in range(self.display_size):
            for x in range(self.display_size):
                if arr[y,x]: ov.putpixel((x,y), (255,0,0,100))
        base = Image.alpha_composite(base, ov)
        self.res_tk = ImageTk.PhotoImage(base)
        self.cv_res.delete("all"); self.cv_res.create_image(0,0,anchor=tk.NW, image=self.res_tk)

    def save_result(self):
        if self.result is None: messagebox.showwarning("无结果""请先推理"); return
        path = filedialog.asksaveasfilename(defaultextension=".png", filetypes=[("PNG""*.png")])
        if path: self.result.save(path); messagebox.showinfo("保存成功", f"已保存至 {path}")

if __name__ == "__main__":
    root = tk.Tk()
    # app = INSID3App(root, onnx_path="insid3_encoder_nobatch.onnx")
    app = INSID3App(root, onnx_path="insid3_encoder_nobatch_simplified.onnx")

    root.mainloop()