说明
在计算机视觉领域,图像分割一直是一项极具挑战性的任务。传统方法往往需要大量标注数据、长时间模型训练,以及针对不同物体类别的反复微调。但今天,这一切将被彻底改变——INSID3 横空出世,它以一种全新的 In-Context(上下文学习) 范式,让你仅需提供一张参考图像和其掩码,就能在任意目标图像中精准分割出同款物体,完全无需任何训练!
更令人兴奋的是,通过 ONNX Runtime 的加持,我们可以将 INSID3 高效地部署到生产环境中,实现跨平台、跨语言的高性能推理。本文将从零开始,带你了解 INSID3 的原理、优势,并手把手教你如何使用 Python + ONNX Runtime 快速部署 INSID3,让你轻松拥有一次标注、处处分割的“超能力”。
INSID3(In-context Segmentation with a Non‑trainable DINOv3 encoder)是一种无需训练的分割模型,它的核心思想非常简单:
利用一个冻结的大型视觉编码器(DINOv3)提取图像特征,然后通过巧妙的特征匹配和聚类机制,将参考物体的语义“传递”到目标图像上。
换句话说,你只需要在参考图上勾画或选择一个物体,INSID3 就能自动在其他图像中找到同样的物体,并生成精确的分割掩码。整个过程没有训练、没有梯度下降、不依赖特定类别,真正做到了开箱即用、万物皆可分割。
效果
模型信息
Model Properties
-------------------------
---------------------------------------------------------------
Inputs
-------------------------
name:imgs
tensor:Float[-1, 3, 1024, 1024]
---------------------------------------------------------------
Outputs
-------------------------
name:f_norm
tensor:Float[-1, 1024, 64, 64]
name:f_debias
tensor:Float[-1, 1024, 64, 64]
---------------------------------------------------------------
代码
class INSID3App:
    """Tkinter GUI for training-free in-context segmentation with INSID3.

    Workflow: load a reference image, click a polygon around an object,
    rasterize it into a 1024x1024 binary mask, optionally cache the encoded
    reference features, then segment the same object in a target image and
    save the predicted mask.
    """

    def __init__(self, root, onnx_path="insid3_encoder_nobatch.onnx"):
        """Set up application state and build the widget tree.

        Args:
            root: Tk root window.
            onnx_path: path to the exported INSID3 encoder ONNX file.
        """
        self.root = root
        self.root.title("INSID3 交互式分割")
        self.onnx_path = onnx_path
        self.ref_orig = None       # reference image, PIL RGB at original size
        self.tgt_orig = None       # target image, PIL RGB at original size
        self.ref_disp = None       # letterboxed reference for on-canvas display
        self.tgt_disp = None       # letterboxed target for on-canvas display
        self.mask_np = None        # uint8 (1024, 1024) mask, values 0/255
        self.mask_pil = None       # same mask as a PIL 'L' image
        self.result = None         # last inference output (PIL mask image)
        self.display_size = 500    # side length of every preview canvas
        self.points = []           # polygon vertices in display coordinates
        self.cached_model = None   # OnnxINSID3 holding cached reference features
        self._build_ui()

    def _build_ui(self):
        """Create the toolbar, the three preview canvases and the status bar."""
        toolbar = tk.Frame(self.root)
        toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
        for text, cmd in [("加载参考图", self.load_ref), ("加载目标图", self.load_tgt),
                          ("清除多边形", self.clear_poly), ("生成掩码", self.gen_mask),
                          ("缓存参考", self.cache_ref), ("推理", self.run_inference),
                          ("保存结果", self.save_result)]:
            tk.Button(toolbar, text=text, command=cmd).pack(side=tk.LEFT, padx=2)
        self.info_var = tk.StringVar(value="选点: 0")
        tk.Label(toolbar, textvariable=self.info_var).pack(side=tk.LEFT, padx=10)
        cf = tk.Frame(self.root)
        cf.pack(fill=tk.BOTH, expand=True)
        # Left canvas: reference image with interactive polygon editing.
        self.cv_ref = tk.Canvas(cf, width=self.display_size, height=self.display_size,
                                bg='gray', cursor="cross")
        self.cv_ref.grid(row=0, column=0, padx=5)
        tk.Label(cf, text="参考图 (左键加点 / 右键删点)").grid(row=1, column=0)
        # Middle canvas: target image preview.
        self.cv_tgt = tk.Canvas(cf, width=self.display_size, height=self.display_size, bg='gray')
        self.cv_tgt.grid(row=0, column=1, padx=5)
        tk.Label(cf, text="目标图").grid(row=1, column=1)
        # Right canvas: segmentation result overlay.
        self.cv_res = tk.Canvas(cf, width=self.display_size, height=self.display_size, bg='gray')
        self.cv_res.grid(row=0, column=2, padx=5)
        tk.Label(cf, text="分割结果").grid(row=1, column=2)
        self.cv_ref.bind("<Button-1>", self.add_point)   # left click adds a vertex
        self.cv_ref.bind("<Button-3>", self.rm_last)     # right click removes the last one
        self.prog = ttk.Progressbar(self.root, mode='indeterminate')
        self.prog.pack(fill=tk.X)
        self.stvar = tk.StringVar(value="就绪")
        tk.Label(self.root, textvariable=self.stvar, bd=1, relief=tk.SUNKEN, anchor=tk.W).pack(fill=tk.X)

    def load_ref(self):
        """Pick a reference image; resets polygon, mask and cached features."""
        path = filedialog.askopenfilename(filetypes=[("Image", "*.jpg *.jpeg *.png *.bmp")])
        if not path:
            return
        self.ref_orig = Image.open(path).convert("RGB")
        self.ref_disp, _ = letterbox_image(self.ref_orig, self.display_size)
        # A new reference invalidates everything derived from the old one.
        self.points.clear()
        self.mask_np = None
        self.cached_model = None
        self._redraw()
        self.stvar.set(f"参考: {path}")

    def _redraw(self):
        """Repaint the reference canvas: image, mask overlay, polygon points."""
        if self.ref_disp is None:
            return
        base = self.ref_disp.copy().convert("RGBA")
        if self.mask_np is not None and self.mask_np.sum() > 0:
            base = self._overlay(base, self.mask_np)
        draw = ImageDraw.Draw(base)
        for x, y in self.points:
            r = 4
            draw.ellipse((x - r, y - r, x + r, y + r), fill='green')
        if len(self.points) > 1:
            draw.line(self.points, fill='yellow', width=2)
        # Keep a reference on self so Tk does not garbage-collect the photo.
        self.ref_tk = ImageTk.PhotoImage(base)
        self.cv_ref.delete("all")
        self.cv_ref.create_image(0, 0, anchor=tk.NW, image=self.ref_tk)
        self.info_var.set(f"选点:{len(self.points)}")

    def _overlay(self, base, mask_1024):
        """Return *base* with a translucent red highlight where the mask is set.

        Args:
            base: RGBA PIL image at display size.
            mask_1024: uint8 (1024, 1024) mask with values 0/255.
        """
        small = Image.fromarray(mask_1024).resize((self.display_size, self.display_size), Image.NEAREST)
        # Vectorized overlay: one numpy pass instead of ~250k getpixel/putpixel calls.
        sel = np.asarray(small) > 128
        ov_arr = np.zeros((self.display_size, self.display_size, 4), dtype=np.uint8)
        ov_arr[sel] = (255, 0, 0, 100)
        return Image.alpha_composite(base, Image.fromarray(ov_arr, "RGBA"))

    def load_tgt(self):
        """Pick a target image and show its letterboxed preview."""
        path = filedialog.askopenfilename(filetypes=[("Image", "*.jpg *.jpeg *.png *.bmp")])
        if not path:
            return
        self.tgt_orig = Image.open(path).convert("RGB")
        self.tgt_disp, _ = letterbox_image(self.tgt_orig, self.display_size)
        self.tgt_tk = ImageTk.PhotoImage(self.tgt_disp)
        self.cv_tgt.delete("all")
        self.cv_tgt.create_image(0, 0, anchor=tk.NW, image=self.tgt_tk)
        self.stvar.set(f"目标: {path}")

    def add_point(self, event):
        """Append a polygon vertex at the click position (display coords)."""
        self.points.append((event.x, event.y))
        self._redraw()

    def rm_last(self, event):
        """Remove the most recently added vertex, if any."""
        if self.points:
            self.points.pop()
            self._redraw()

    def clear_poly(self):
        """Discard every polygon vertex."""
        self.points.clear()
        self._redraw()

    def gen_mask(self):
        """Rasterize the clicked polygon into a 1024x1024 binary mask."""
        if len(self.points) < 3:
            messagebox.showwarning("警告", "至少需要3个顶点")
            return
        # Canvases are square, so one scale factor maps display -> model space.
        scale = 1024 / self.display_size
        pts = [(int(x * scale), int(y * scale)) for x, y in self.points]
        mask = np.zeros((1024, 1024), dtype=np.uint8)
        cv2.fillPoly(mask, [np.array(pts, dtype=np.int32)], 255)
        self.mask_np = mask
        self.mask_pil = Image.fromarray(mask, 'L')
        self._redraw()
        self.stvar.set("掩码已生成")

    def cache_ref(self):
        """Encode the reference once so subsequent inferences skip that step."""
        if self.ref_orig is None or self.mask_pil is None:
            messagebox.showwarning("缺少数据", "请先加载参考图并生成掩码")
            return
        try:
            self.cached_model = OnnxINSID3(self.onnx_path, device='cuda')
            self.cached_model.encode_reference(self.ref_orig, self.mask_pil)
            self.stvar.set("参考特征已缓存,可连续推理")
        except Exception as e:
            messagebox.showerror("缓存失败", str(e))

    def run_inference(self):
        """Validate inputs, then run segmentation on a worker thread."""
        if self.ref_orig is None or self.mask_pil is None or self.tgt_orig is None:
            messagebox.showwarning("缺少数据", "请加载参考图、掩码和目标图")
            return
        self.cv_res.delete("all")
        self.prog.start()
        self.stvar.set("推理中...")
        # Daemon worker keeps the Tk event loop responsive during inference.
        threading.Thread(target=self._inf, daemon=True).start()

    def _inf(self):
        """Worker thread: segment the target; marshal UI updates via root.after."""
        start = time.time()
        try:
            if self.cached_model and self.cached_model._cached_ref_norm is not None:
                # Fast path: reference features were already encoded by cache_ref.
                result_pil = self.cached_model.segment_with_cache(self.tgt_orig)
            else:
                model = OnnxINSID3(self.onnx_path, device='cuda')
                model.set_reference(self.ref_orig, self.mask_pil)
                model.set_target(self.tgt_orig)
                result_pil = model.segment()
            self.result = result_pil
            elapsed = time.time() - start
            self.root.after(0, self._show)
            self.root.after(0, lambda e=elapsed: self.stvar.set(f"完成,耗时 {e:.2f}s"))
        except Exception as e:
            # Bind e as a default arg: the except variable is cleared on block exit.
            self.root.after(0, lambda e=e: messagebox.showerror("推理错误", str(e)))
        finally:
            self.root.after(0, self.prog.stop)

    def _show(self):
        """Paint the target with a translucent red overlay of the result mask."""
        if self.result is None:
            return
        # NOTE(review): plain resize ignores aspect ratio while load_tgt letterboxes,
        # so this preview may look stretched vs the middle canvas — confirm intent.
        tgt_disp = self.tgt_orig.resize((self.display_size, self.display_size))
        res_disp = self.result.resize((self.display_size, self.display_size))
        base = tgt_disp.convert("RGBA")
        # Vectorized overlay; assumes the result mask is single-channel ('L').
        sel = np.asarray(res_disp) > 128
        ov_arr = np.zeros((self.display_size, self.display_size, 4), dtype=np.uint8)
        ov_arr[sel] = (255, 0, 0, 100)
        base = Image.alpha_composite(base, Image.fromarray(ov_arr, "RGBA"))
        self.res_tk = ImageTk.PhotoImage(base)
        self.cv_res.delete("all")
        self.cv_res.create_image(0, 0, anchor=tk.NW, image=self.res_tk)

    def save_result(self):
        """Save the last result mask to a user-chosen PNG path."""
        if self.result is None:
            messagebox.showwarning("无结果", "请先推理")
            return
        path = filedialog.asksaveasfilename(defaultextension=".png", filetypes=[("PNG", "*.png")])
        if path:
            self.result.save(path)
            messagebox.showinfo("保存成功", f"已保存至 {path}")
if __name__ == "__main__":
    root = tk.Tk()
    # Use the onnx-simplifier export; the raw export was "insid3_encoder_nobatch.onnx".
    app = INSID3App(root, onnx_path="insid3_encoder_nobatch_simplified.onnx")
    root.mainloop()