告别繁琐训练!用 C# + ONNX Runtime 轻松部署 INSID3,开启零样本图像分割新时代

0 阅读12分钟

说明

官网地址:github.com/visinf/INSI…

在计算机视觉领域,图像分割一直是一项极具挑战性的任务。传统方法往往需要大量标注数据、长时间模型训练,以及针对不同物体类别的反复微调。但今天,这一切将被彻底改变——INSID3 横空出世,它以一种全新的 In-Context(上下文学习) 范式,让你仅需提供一张参考图像和其掩码,就能在任意目标图像中精准分割出同款物体,完全无需任何训练!

更令人兴奋的是,通过 ONNX Runtime 的加持,我们可以将 INSID3 高效地部署 到生产环境中,实现跨平台、跨语言的高性能推理。本文将从零开始,带你了解 INSID3 的原理、优势,并手把手教你如何使用 Python + ONNX Runtime 快速部署 INSID3,让你轻松拥有一次标注、处处分割的“超能力”。

INSID3(In-context Segmentation with a Non‑trainable DINOv3 encoder)是一种无需训练的分割模型,它的核心思想非常简单:

利用一个冻结的大型视觉编码器(DINOv3)提取图像特征,然后通过巧妙的特征匹配和聚类机制,将参考物体的语义“传递”到目标图像上。

换句话说,你只需要在参考图上勾画或选择一个物体,INSID3 就能自动在其他图像中找到同样的物体,并生成精确的分割掩码。整个过程没有训练、没有梯度下降、不依赖特定类别,真正做到了开箱即用、万物皆可分割。

效果

C# CPU版本速度有点慢,需要提速可以上GPU!
图片
图片

模型信息

Model Properties
-------------------------
---------------------------------------------------------------

Inputs
-------------------------
name:imgs
tensor:Float[-1, 3, 1024, 1024]
---------------------------------------------------------------

Outputs
-------------------------
name:f_norm
tensor:Float[-1, 1024, 64, 64]
name:f_debias
tensor:Float[-1, 1024, 64, 64]
---------------------------------------------------------------

项目

图片

代码

using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OpenCvSharp;
using System;
using System.Collections.Generic;
using System.Drawing;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using System.Windows.Forms;

namespace Onnx_Demo
{
    public partial class Form1 : Form
    {
        public Form1()
        {
            InitializeComponent();
        }

        string fileFilter = "*.*|*.bmp;*.jpg;*.jpeg;*.tiff;*.tiff;*.png";
        string image_path = "";
        string startupPath;
        string model_path;
        Mat refImage;
        Mat tgtImage;
        Mat refMask;
        List<System.Drawing.Point> polyPoints = new List<System.Drawing.Point>();
        SessionOptions options;
        InferenceSession onnx_session;

        bool log = false;

        private Rectangle imageRect = Rectangle.Empty;

        // INSID3 参数
        private const int ModelSize = 1024;
        private const int FeatC = 1024, FeatH = 64, FeatW = 64;
        private readonly float[] mean = { 0.485f, 0.456f, 0.406f };
        private readonly float[] std = { 0.229f, 0.224f, 0.225f };

        public float CandidateSigma = 0.5f;
        public float PixelSimSigma = 0.5f;
        public float ClusterScoreRatio = 0.6f;

        private void Form1_Load(object sender, EventArgs e)
        {
            startupPath = Application.StartupPath;
            model_path = "model/insid3_encoder.onnx";

            options = new SessionOptions();
            options.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
            options.AppendExecutionProvider_CPU(0);
            options.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
            options.IntraOpNumThreads = Environment.ProcessorCount;
            options.EnableMemoryPattern = true;

            onnx_session = new InferenceSession(model_path, options);

            pictureBox1.MouseDown += PictureBox1_MouseDown;
            pictureBox1.Paint += PictureBox1_Paint;
        }

        private void PictureBox1_MouseDown(object sender, MouseEventArgs e)
        {
            if (refImage == null || imageRect == Rectangle.Empty) return;
            if (!imageRect.Contains(e.Location)) return;

            if (e.Button == MouseButtons.Left)
            {
                polyPoints.Add(e.Location);
                pictureBox1.Invalidate();
            }
            else if (e.Button == MouseButtons.Right)
            {
                if (polyPoints.Count == 0) return;
                float minDist = float.MaxValue; int idx = -1;
                for (int i = 0; i < polyPoints.Count; i++)
                {
                    float d = (polyPoints[i].X - e.X) * (polyPoints[i].X - e.X) + (polyPoints[i].Y - e.Y) * (polyPoints[i].Y - e.Y);
                    if (d < minDist) { minDist = d; idx = i; }
                }
                if (idx >= 0) { polyPoints.RemoveAt(idx); pictureBox1.Invalidate(); }
            }
        }

        private void PictureBox1_Paint(object sender, PaintEventArgs e)
        {
            // 绘制 mask 叠加层(只绘制有效图像区域,忽略 letterbox 黑边)
            if (refMask != null && !refMask.Empty() && imageRect != Rectangle.Empty)
            {
                // 计算 letterbox 参数(与推理时一致)
                float scale = ModelSize / (float)Math.Max(refImage.Width, refImage.Height);
                int nw = (int)(refImage.Width * scale);
                int nh = (int)(refImage.Height * scale);
                int dx = (ModelSize - nw) / 2;
                int dy = (ModelSize - nh) / 2;

                // 从完整 1024x1024 mask 中裁剪出图像区域
                using (Mat roi = new Mat(refMask, new Rect(dx, dy, nw, nh)))
                using (Mat resized = new Mat())
                {
                    Cv2.Resize(roi, resized, new OpenCvSharp.Size(imageRect.Width, imageRect.Height), 0, 0, InterpolationFlags.Nearest);
                    Cv2.Threshold(resized, resized, 128, 255, ThresholdTypes.Binary);
                    OpenCvSharp.Point[][] contours; HierarchyIndex[] hier;
                    Cv2.FindContours(resized, out contours, out hier, RetrievalModes.External, ContourApproximationModes.ApproxSimple);
                    using (Pen p = new Pen(Color.Red, 2))
                        foreach (var c in contours)
                            if (c.Length >= 2)
                            {
                                var pts = c.Select(pt => new System.Drawing.Point(imageRect.X + pt.X, imageRect.Y + pt.Y)).ToArray();
                                e.Graphics.DrawPolygon(p, pts);
                            }
                }
            }

            // 用户多边形(控件坐标)
            if (polyPoints.Count == 0) return;
            using (Pen p = new Pen(Color.Green, 3))
            using (Brush b = new SolidBrush(Color.Lime))
            {
                foreach (var pt in polyPoints) e.Graphics.FillEllipse(b, pt.X - 4, pt.Y - 4, 8, 8);
                if (polyPoints.Count > 1) e.Graphics.DrawLines(p, polyPoints.ToArray());
            }
        }

        private void button1_Click(object sender, EventArgs e)
        {
            using (OpenFileDialog ofd = new OpenFileDialog()) { ofd.Filter = fileFilter; if (ofd.ShowDialog() != DialogResult.OK) return; image_path = ofd.FileName; }
            refImage = Cv2.ImRead(image_path);

            // 计算保持比例的显示区域
            float scale = Math.Min((float)pictureBox1.Width / refImage.Width, (float)pictureBox1.Height / refImage.Height);
            int dispW = (int)(refImage.Width * scale);
            int dispH = (int)(refImage.Height * scale);
            imageRect = new Rectangle(
                (pictureBox1.Width - dispW) / 2,
                (pictureBox1.Height - dispH) / 2,
                dispW,
                dispH);

            // 绘制不拉伸的图片
            Bitmap bmp = new Bitmap(pictureBox1.Width, pictureBox1.Height);
            using (Graphics g = Graphics.FromImage(bmp))
            {
                g.Clear(SystemColors.Control);
                using (Mat disp = new Mat())
                {
                    Cv2.Resize(refImage, disp, new OpenCvSharp.Size(dispW, dispH));
                    using (var ms = disp.ToMemoryStream())
                    using (var img = System.Drawing.Image.FromStream(ms))
                    {
                        g.DrawImage(img, imageRect);
                    }
                }
            }
            pictureBox1.Image = bmp;

            polyPoints.Clear(); refMask = null; textBox1.Text = "";
        }

        private void button2_Click(object sender, EventArgs e)
        {
            using (OpenFileDialog ofd = new OpenFileDialog())
            {
                ofd.Filter = fileFilter;
                if (ofd.ShowDialog() != DialogResult.OK) return;
                tgtImage = Cv2.ImRead(ofd.FileName);
                pictureBox2.Image = new Bitmap(tgtImage.ToMemoryStream());
            }
            textBox1.Text = "目标图像已加载";
        }

        private void button4_Click(object sender, EventArgs e)
        {
            if (refImage == null) { MessageBox.Show("请先加载参考图像"); return; }
            using (OpenFileDialog ofd = new OpenFileDialog())
            {
                ofd.Filter = "Mask|*.png;*.jpg;*.bmp";
                if (ofd.ShowDialog() != DialogResult.OK) return;
                using (Mat m = Cv2.ImRead(ofd.FileName, ImreadModes.Grayscale))
                {
                    // 外部 mask 必须与 letterbox 对齐,不能直接 resize
                    // 这里按原图尺寸处理,再转为 letterbox 后的 mask
                    refMask = LetterboxMask(m, refImage.Width, refImage.Height);
                }
            }
            pictureBox1.Invalidate();
        }

        private void button5_Click(object sender, EventArgs e)
        {
            if (refImage == null || polyPoints.Count < 3) { MessageBox.Show("至少需要3个顶点"); return; }
            if (imageRect == Rectangle.Empty) return;

            float displayScale = (float)imageRect.Width / refImage.Width; // 显示缩放比例

            // 1. 映射到原始图像坐标
            var ptsOnImage = polyPoints.Select(p => new PointF(
                (p.X - imageRect.X) / displayScale,
                (p.Y - imageRect.Y) / displayScale)).ToList();

            // 2. 计算 letterbox 参数(与推理完全一致)
            float modelScale = ModelSize / (float)Math.Max(refImage.Width, refImage.Height);
            int nw = (int)(refImage.Width * modelScale);
            int nh = (int)(refImage.Height * modelScale);
            int dx = (ModelSize - nw) / 2;
            int dy = (ModelSize - nh) / 2;

            // 3. 映射到 1024x1024 空间(加上填充偏移)
            var modelPts = ptsOnImage.Select(p => new OpenCvSharp.Point(
                (int)(p.X * modelScale + dx),
                (int)(p.Y * modelScale + dy))).ToList();

            // 4. 生成掩码
            refMask = new Mat(ModelSize, ModelSize, MatType.CV_8UC1, Scalar.All(0));
            Cv2.FillPoly(refMask, new List<List<OpenCvSharp.Point>> { modelPts }, Scalar.All(255));
            pictureBox1.Invalidate();
        }

        private Mat LetterboxMask(Mat mask, int imgW, int imgH)
        {
            float scale = ModelSize / (float)Math.Max(imgW, imgH);
            int nw = (int)(imgW * scale);
            int nh = (int)(imgH * scale);
            int dx = (ModelSize - nw) / 2;
            int dy = (ModelSize - nh) / 2;

            using (Mat resized = new Mat())
            using (Mat canvas = new Mat(ModelSize, ModelSize, MatType.CV_8UC1, Scalar.All(0)))
            {
                Cv2.Resize(mask, resized, new OpenCvSharp.Size(nw, nh));
                resized.CopyTo(canvas[new Rect(dx, dy, nw, nh)]);
                return canvas.Clone();
            }
        }

        private async void button6_Click(object sender, EventArgs e)
        {
            if (refImage == null || refMask == null || tgtImage == null) { MessageBox.Show("请加载参考图、掩码、目标图"); return; }
            button6.Enabled = false;
            textBox1.Clear();
            textBox1.Text = "开始推理……\n";
            Application.DoEvents();

            Cv2.ImWrite("refImage.jpg", refImage);
            Cv2.ImWrite("refMask.jpg", refMask);

            Mat resultMask = null;
            TimeSpan preT = default, infT = default, postT = default, totalT = default;

            await Task.Run(() =>
            {
                var swT = System.Diagnostics.Stopwatch.StartNew();
                var sw = System.Diagnostics.Stopwatch.StartNew();

                float[] refIn = LetterboxAndNormalize(refImage, out _, out _, out _, out _, out int refOw, out int refOh);
                float[] tgtIn = LetterboxAndNormalize(tgtImage, out int dxT, out int dyT, out int nwT, out int nhT, out int tgtOw, out int tgtOh);

                PrintTensorStats("refIn_normalized", refIn, new[] { 3, ModelSize, ModelSize });
                PrintTensorStats("tgtIn_normalized", tgtIn, new[] { 3, ModelSize, ModelSize });
                SaveArrayToFile("refIn", refIn);
                SaveArrayToFile("tgtIn", tgtIn);

                float[] combined = new float[2 * 3 * ModelSize * ModelSize];
                Buffer.BlockCopy(refIn, 0, combined, 0, refIn.Length * sizeof(float));
                Buffer.BlockCopy(tgtIn, 0, combined, 3 * ModelSize * ModelSize * sizeof(float), tgtIn.Length * sizeof(float));
                preT = sw.Elapsed;

                sw.Restart();
                var inputTensor = new DenseTensor<float>(combined, new[] { 2, 3, ModelSize, ModelSize });
                var inputs = new List<NamedOnnxValue> { NamedOnnxValue.CreateFromTensor("imgs", inputTensor) };

                using (var results = onnx_session.Run(inputs))
                {
                    var fNorm = results[0].AsTensor<float>();
                    var fDebias = results[1].AsTensor<float>();

                    PrintTensorStats("fNorm_output", fNorm);
                    PrintTensorStats("fDebias_output", fDebias);
                    SaveArrayToFile("fNorm", fNorm.ToArray());
                    SaveArrayToFile("fDebias", fDebias.ToArray());

                    SafeLog($"[Shape] fNorm: [{string.Join(",", fNorm.Dimensions.ToArray())}]");
                    SafeLog($"[Shape] fDebias: [{string.Join(",", fDebias.Dimensions.ToArray())}]");
                    SafeLog($"fNorm_output mean={fNorm.ToArray().Average():F4}, std={Math.Sqrt(fNorm.ToArray().Select(v => v * v).Average()):F4}");
                    SafeLog($"fDebias_output mean={fDebias.ToArray().Average():F4}, std={Math.Sqrt(fDebias.ToArray().Select(v => v * v).Average()):F4}");

                    infT = sw.Elapsed;
                    sw.Restart();

                    float[][] refNormNC = ToFlatFeaturesNC(fNorm, 0);
                    float[][] refDebNC = ToFlatFeaturesNC(fDebias, 0);
                    float[][] tgtNormNC = ToFlatFeaturesNC(fNorm, 1);
                    float[][] tgtDebNC = ToFlatFeaturesNC(fDebias, 1);

                    PrintTensorStats("refDebNC_flat", refDebNC);
                    PrintTensorStats("tgtDebNC_flat", tgtDebNC);

                    float[] refMask64 = ResizeMaskFlat(refMask, FeatH, FeatW);
                    PrintTensorStats("refMask_64x64", refMask64, new[] { FeatH, FeatW });

                    float[] refProto = ComputePrototypeFlat(refDebNC, refMask64);
                    PrintTensorStats("refProto", refProto, new[] { FeatC });
                    SaveArrayToFile("refProto", refProto);

                    int N = FeatH * FeatW;
                    float[] simFwd = new float[N];
                    for (int i = 0; i < N; i++)
                    {
                        float dot = 0; for (int c = 0; c < FeatC; c++) dot += tgtDebNC[i][c] * refProto[c];
                        simFwd[i] = dot;
                    }
                    PrintTensorStats("simFwd", simFwd, new[] { FeatH, FeatW });
                    SaveArrayToFile("simFwd", simFwd);

                    bool[] candidate = LocateCandidatesFlat(refDebNC, refMask64, tgtDebNC, simFwd);
                    PrintTensorStats("candidate_mask", candidate.Select(b => b ? 1f : 0f).ToArray(), new[] { FeatH, FeatW });

                    int[] clusterLabels = ClusterFeatures(tgtNormNC, 0.6f);
                    PrintTensorStats("clusterLabels", clusterLabels.Select(l => (float)l).ToArray(), new[] { FeatH, FeatW });

                    bool[] final64 = SeedAndAggregateFlat(candidate, clusterLabels, tgtNormNC, tgtDebNC, refProto, simFwd);
                    PrintTensorStats("final64_mask", final64.Select(b => b ? 1f : 0f).ToArray(), new[] { FeatH, FeatW });
                    SaveArrayToFile("final64", final64.Select(b => b ? 1f : 0f).ToArray());

                    //bool[] final64 = SeedAndAggregateFlat(candidate, clusterLabels, tgtNormNC, tgtDebNC, refProto, simFwd);
                    //final64 = CleanMask64(final64); // 🔑 新增:形态学去噪
                    //PrintTensorStats("final64_mask", final64.Select(b => b ? 1f : 0f).ToArray(), new[] { FeatH, FeatW });

                    float[] mask1024 = BilinearUpsampleFlat(final64, FeatH, FeatW, ModelSize, ModelSize);
                    PrintTensorStats("mask1024_upsampled", mask1024, new[] { ModelSize, ModelSize });
                    SaveArrayToFile("mask1024", mask1024);

                    resultMask = RecoverOriginalMask(mask1024, dxT, dyT, nwT, nhT, tgtOw, tgtOh);

                    Cv2.ImWrite("resultMask.jpg", resultMask);

                    byte[] maskBytes;
                    resultMask.GetArray(out maskBytes);
                    float[] maskFloats = Array.ConvertAll(maskBytes, b => b / 255f);
                    PrintTensorStats("resultMask_final", maskFloats, new[] { resultMask.Height, resultMask.Width });
                }
                postT = sw.Elapsed; totalT = swT.Elapsed;
            });

            if (resultMask != null)
            {
                using (Mat ov = OverlayMask(tgtImage, resultMask))
                using (var ms = ov.ToMemoryStream()) { pictureBox2.Image = new Bitmap(ms); }
                resultMask.Dispose();
            }
            textBox1.Text = $"\n=== 耗时: 前处理:{preT.TotalMilliseconds:F0}ms | 推理:{infT.TotalMilliseconds:F0}ms | 后处理:{postT.TotalMilliseconds:F0}ms | 总:{totalT.TotalMilliseconds:F0}ms ===\n";
            button6.Enabled = true;
        }

        private void button3_Click(object sender, EventArgs e)
        {
            if (pictureBox2.Image == null) return;
            using (SaveFileDialog s = new SaveFileDialog()) { s.Filter = "PNG|*.png"; if (s.ShowDialog() == DialogResult.OK) pictureBox2.Image.Save(s.FileName); }
        }

        private float[] LetterboxAndNormalize(Mat src, out int dx, out int dy, out int nw, out int nh, out int ow, out int oh)
        {
            ow = src.Width; oh = src.Height;
            float scale = ModelSize / (float)Math.Max(ow, oh);
            nw = (int)(ow * scale); nh = (int)(oh * scale);
            dx = (ModelSize - nw) / 2; dy = (ModelSize - nh) / 2;

            using (Mat resized = new Mat())
            using (Mat canvas = new Mat(ModelSize, ModelSize, MatType.CV_8UC3, Scalar.All(0)))
            using (Mat rgb = new Mat())
            {
                Cv2.Resize(src, resized, new OpenCvSharp.Size(nw, nh));
                resized.CopyTo(canvas[new Rect(dx, dy, nw, nh)]);
                Cv2.CvtColor(canvas, rgb, ColorConversionCodes.BGR2RGB);

                float[] outF = new float[3 * ModelSize * ModelSize];
                for (int c = 0; c < 3; c++)
                {
                    int chOffset = c * ModelSize * ModelSize;
                    for (int y = 0; y < ModelSize; y++)
                    {
                        int rowOffset = y * ModelSize;
                        for (int x = 0; x < ModelSize; x++)
                        {
                            float val = rgb.At<Vec3b>(y, x)[c] / 255.0f;
                            outF[chOffset + rowOffset + x] = (val - mean[c]) / std[c];
                        }
                    }
                }
                return outF;
            }
        }

        private float[][] ToFlatFeaturesNC(Tensor<float> t, int b)
        {
            var dims = t.Dimensions.ToArray();
            if (dims.Length != 4 || dims[0] < 2 || dims[1] != FeatC || dims[2] != FeatH || dims[3] != FeatW)
            {
                SafeLog($"Tensor维度异常: [{string.Join(",", dims)}],预期 [N,1024,64,64]");
            }

            int C = dims[1], H = dims[2], W = dims[3];
            int N = H * W;
            float[][] f = new float[N][];
            float[] flat = t.ToArray();

            int batchOffset = b * C * H * W;
            for (int i = 0; i < N; i++)
            {
                int y = i / W, x = i % W;
                f[i] = new float[C];
                int spatialOffset = y * W + x;
                for (int c = 0; c < C; c++)
                {
                    f[i][c] = flat[batchOffset + c * H * W + spatialOffset];
                }
            }
            return f;
        }

        private float[] ResizeMaskFlat(Mat m, int h, int w)
        {
            using (Mat r = new Mat())
            {
                Cv2.Resize(m, r, new OpenCvSharp.Size(w, h), 0, 0, InterpolationFlags.Nearest);
                float[] res = new float[h * w];
                for (int y = 0; y < h; y++) for (int x = 0; x < w; x++) res[y * w + x] = r.At<byte>(y, x) > 128 ? 1f : 0f;
                return res;
            }
        }

        private float[] ComputePrototypeFlat(float[][] debNC, float[] mask)
        {
            int N = mask.Length;
            float[] p = new float[FeatC]; int cnt = 0;
            for (int i = 0; i < N; i++)
            {
                if (mask[i] > 0) { for (int c = 0; c < FeatC; c++) p[c] += debNC[i][c]; cnt++; }
            }
            if (cnt == 0) return p;
            float inv = 1f / cnt; for (int c = 0; c < FeatC; c++) p[c] *= inv;
            float norm = (float)Math.Sqrt(p.Sum(v => v * v));
            if (norm > 0) for (int c = 0; c < FeatC; c++) p[c] /= norm;
            return p;
        }

        private bool[] LocateCandidatesFlat(float[][] refDeb, float[] refMask, float[][] tgtDeb, float[] simFwd)
        {
            int N = FeatH * FeatW;

            //自适应阈值替代固定分位数
            float mean = 0f; foreach (var v in simFwd) mean += v; mean /= N;
            float var = 0f; foreach (var v in simFwd) var += (v - mean) * (v - mean); var /= N;
            float std = (float)Math.Sqrt(var);

            // 仅保留显著高于背景均值 + 0.5倍标准差的区域 (约 top 25%)
            float thr = mean + CandidateSigma * std;

            bool[] fwd = new bool[N];
            for (int i = 0; i < N; i++) fwd[i] = simFwd[i] > thr;

            var fgIdx = new List<int>();
            for (int i = 0; i < N; i++) if (refMask[i] > 0) fgIdx.Add(i);

            int[] votes = new int[N];
            for (int my = 0; my < N; my++)
            {
                int best = -1float mx = float.MinValue;
                foreach (int ry in fgIdx)
                {
                    float dot = 0; for (int c = 0; c < FeatC; c++) dot += refDeb[ry][c] * tgtDeb[my][c];
                    if (dot > mx) { mx = dot; best = ry; }
                }
                if (best >= 0) votes[my] = 1;
            }

            bool[] cand = new bool[N];
            for (int i = 0; i < N; i++) cand[i] = fwd[i] && votes[i] >= 1;
            return cand;
        }

        private bool[] SeedAndAggregateFlat(bool[] candidate, int[] labels,
            float[][] tgtNormNC, float[][] tgtDebNC, float[] refProto, float[] simFwd)
        {
            int N = FeatH * FeatW;
            bool[] matched = new bool[N]; int mCnt = 0;
            for (int i = 0; i < N; i++) if (candidate[i] && labels[i] >= 0) { matched[i] = true; mCnt++; }
            if (mCnt == 0) return new bool[N];

            var lblToK = new Dictionary<int, int>(); int K = 0;
            for (int i = 0; i < N; i++) if (matched[i] && !lblToK.ContainsKey(labels[i])) lblToK[labels[i]] = K++;
            if (K == 0) return new bool[N];

            int[] mCounts = new int[K]float[][] debSum = new float[K][];
            for (int k = 0; k < K; k++) debSum[k] = new float[FeatC];
            for (int i = 0; i < N; i++) if (matched[i])
                {
                    int k = lblToK[labels[i]]; mCounts[k]++;
                    for (int c = 0; c < FeatC; c++) debSum[k][c] += tgtDebNC[i][c];
                }
            float[][] debProtos = new float[K][];
            for (int k = 0; k < K; k++)
            {
                debProtos[k] = new float[FeatC];
                if (mCounts[k] > 0) { float inv = 1f / mCounts[k]; for (int c = 0; c < FeatC; c++) debProtos[k][c] = debSum[k][c] * inv; }
            }

            int[] vCounts = new int[K];
            for (int i = 0; i < N; i++) if (labels[i] >= 0 && lblToK.TryGetValue(labels[i], out int k)) vCounts[k]++;

            float[] crossSim = new float[K];
            for (int k = 0; k < K; k++) { float d = 0; for (int c = 0; c < FeatC; c++) d += debProtos[k][c] * refProto[c]; crossSim[k] = d; }
            int seedK = 0float mx = float.MinValue;
            for (int k = 0; k < K; k++) if (crossSim[k] > mx) { mx = crossSim[k]; seedK = k; }

            float[][] normSum = new float[K][]; for (int k = 0; k < K; k++) normSum[k] = new float[FeatC];
            for (int i = 0; i < N; i++) if (labels[i] >= 0 && lblToK.TryGetValue(labels[i], out int k))
                    for (int c = 0; c < FeatC; c++) normSum[k][c] += tgtNormNC[i][c];
            float[][] normProtos = new float[K][];
            for (int k = 0; k < K; k++)
            {
                normProtos[k] = new float[FeatC];
                if (vCounts[k] > 0) { float inv = 1f / vCounts[k]; for (int c = 0; c < FeatC; c++) normProtos[k][c] = normSum[k][c] * inv; }
                float norm = (float)Math.Sqrt(normProtos[k].Sum(v => v * v)) + 1e-8f;
                for (int c = 0; c < FeatC; c++) normProtos[k][c] /= norm;
            }

            float[] intraSim = new float[K]float[] seedP = normProtos[seedK];
            for (int k = 0; k < K; k++) { float d = 0; for (int c = 0; c < FeatC; c++) d += seedP[c] * normProtos[k][c]; intraSim[k] = d; }

            float[] crossSum = new float[K];
            for (int i = 0; i < N; i++) if (labels[i] >= 0 && lblToK.TryGetValue(labels[i], out int k)) crossSum[k] += simFwd[i];
            float[] crossScore = new float[K];
            for (int k = 0; k < K; k++) crossScore[k] = vCounts[k] > 0 ? crossSum[k] / vCounts[k] : 0f;

            float[] areaW = new float[K];
            for (int k = 0; k < K; k++) areaW[k] = vCounts[k] > 0 ? (float)mCounts[k] / vCounts[k] : 0f;
            areaW[seedK] = 1.0f;

            float[] combined = new float[K];
            for (int k = 0; k < K; k++) combined[k] = crossScore[k] * intraSim[k] * areaW[k];

            SafeLog($"[Seed] K={K} | seedK={seedK} | comb={combined[seedK]:F3}");

            bool[] final = new bool[N];

            //针对高位分布的 simFwd 收紧阈值
            float sum = 0, sumSq = 0;
            foreach (var v in simFwd) { sum += v; sumSq += v * v; }
            float meanSim = sum / N;
            float stdSim = (float)Math.Sqrt(sumSq / N - meanSim * meanSim);

            // 像素阈值:均值 + 0.7倍标准差 (硬切 ~0.29 以上,精准拦截中等背景)
            float pixelThresh = meanSim + PixelSimSigma * stdSim;
            // 簇阈值:最高得分的 (只保留强相关主簇)
            float bestComb = float.MinValue; for (int k = 0; k < K; k++) if (combined[k] > bestComb) bestComb = combined[k];
            float clusterThresh = bestComb * ClusterScoreRatio;

            for (int i = 0; i < N; i++)
            {
                if (labels[i] >= 0 && lblToK.TryGetValue(labels[i], out int k))
                {
                    if (combined[k] > clusterThresh && simFwd[i] > pixelThresh)
                        final[i] = true;
                }
            }

            for (int i = 0; i < N; i++)
            {
                if (labels[i] >= 0 && lblToK.TryGetValue(labels[i], out int k))
                {
                    if (combined[k] > clusterThresh && simFwd[i] > pixelThresh)
                        final[i] = true;
                }
            }

            // 连通域过滤,切除远处孤立小块
            // 找 seedK 簇中任意一个点作为种子(用于连通域搜索)
            //int seedPoint = -1;
            //for (int i = 0; i < N; i++)
            //{
            //    if (labels[i] == seedK && final[i]) { seedPoint = i; break; }
            //}
            //if (seedPoint >= 0)
            //{
            //    final = KeepLargestConnectedComponent(final, seedPoint);
            //    SafeLog($"[Connectivity] Kept component with seed@{seedPoint}");
            //}

            return final;
        }


        // 连通域过滤:只保留包含种子点的最大连通区域,切除远处孤立小块
        private bool[] KeepLargestConnectedComponent(bool[] mask, int seedIdx)
        {
            int H = FeatH, W = FeatW, N = H * W;
            if (!mask[seedIdx]) return new bool[N]; // 种子点本身被过滤则直接返回空

            // BFS 找种子点所在的连通域
            bool[] visited = new bool[N];
            Queue<int> queue = new Queue<int>();
            queue.Enqueue(seedIdx);
            visited[seedIdx] = true;
            List<int> component = new List<int>();

            int[] dx = { -1, 1, 0, 0 }, dy = { 0, 0, -1, 1 }; // 4邻域

            while (queue.Count > 0)
            {
                int cur = queue.Dequeue();
                component.Add(cur);
                int cy = cur / W, cx = cur % W;

                for (int d = 0; d < 4; d++)
                {
                    int ny = cy + dy[d], nx = cx + dx[d];
                    if (ny < 0 || ny >= H || nx < 0 || nx >= W) continue;
                    int ni = ny * W + nx;
                    if (mask[ni] && !visited[ni])
                    {
                        visited[ni] = true;
                        queue.Enqueue(ni);
                    }
                }
            }

            // 生成新掩码:仅保留该连通域
            bool[] result = new bool[N];
            foreach (int idx in component) result[idx] = true;
            return result;
        }

        private int[] ClusterFeatures(float[][] featsNC, float tau)
        {
            int N = featsNC.Length;
            const int K = 6;
            int[] labels = new int[N];
            float[][] centroids = new float[K][];

            float[][] normFeats = new float[N][];
            for (int i = 0; i < N; i++)
            {
                float s = 0;
                for (int c = 0; c < FeatC; c++) s += featsNC[i][c] * featsNC[i][c];
                float inv = 1f / (float)Math.Sqrt(s) + 1e-8f;
                normFeats[i] = new float[FeatC];
                for (int c = 0; c < FeatC; c++) normFeats[i][c] = featsNC[i][c] * inv;
            }

            System.Random rand = new System.Random(42);
            centroids[0] = normFeats[rand.Next(N)];
            for (int k = 1; k < K; k++)
            {
                float[] dists = new float[N];
                float sum = 0;
                for (int i = 0; i < N; i++)
                {
                    float minD = float.MaxValue;
                    for (int c = 0; c < k; c++)
                    {
                        float d = 0;
                        var ci = centroids[c]; var fi = normFeats[i];
                        for (int j = 0; j < FeatC; j++) { float v = fi[j] - ci[j]; d += v * v; }
                        if (d < minD) minD = d;
                    }
                    dists[i] = minD; sum += minD;
                }
                float r = (float)rand.NextDouble() * sum;
                float cum = 0;
                for (int i = 0; i < N; i++) { cum += dists[i]; if (cum >= r) { centroids[k] = normFeats[i]; break; } }
            }

            for (int iter = 0; iter < 10; iter++)
            {
                for (int i = 0; i < N; i++)
                {
                    int bestK = 0float minD = float.MaxValue;
                    for (int k = 0; k < K; k++)
                    {
                        float d = 0;
                        var ck = centroids[k]; var fi = normFeats[i];
                        for (int c = 0; c < FeatC; c++) { float v = fi[c] - ck[c]; d += v * v; }
                        if (d < minD) { minD = d; bestK = k; }
                    }
                    labels[i] = bestK;
                }

                int[] counts = new int[K];
                float[][] sums = new float[K][];
                for (int k = 0; k < K; k++) sums[k] = new float[FeatC];
                for (int i = 0; i < N; i++)
                {
                    int k = labels[i]; counts[k]++;
                    for (int c = 0; c < FeatC; c++) sums[k][c] += normFeats[i][c];
                }
                for (int k = 0; k < K; k++)
                {
                    if (counts[k] == 0) continue;
                    float inv = 1f / counts[k];
                    for (int c = 0; c < FeatC; c++) centroids[k][c] = sums[k][c] * inv;
                }
            }

            int[] cnt = new int[K];
            for (int i = 0; i < N; i++) if (labels[i] >= 0) cnt[labels[i]]++;
            int curLbl = 0;
            int[] map = new int[K];
            for (int i = 0; i < K; i++) map[i] = -1;
            for (int i = 0; i < N; i++)
            {
                int l = labels[i];
                if (cnt[l] < 30) { labels[i] = -1; continue; }
                if (map[l] == -1) map[l] = curLbl++;
                labels[i] = map[l];
            }
            return labels;
        }

        private float[] BilinearUpsampleFlat(bool[] mask, int sh, int sw, int dh, int dw)
        {
            float[] res = new float[dh * dw];
            for (int ty = 0; ty < dh; ty++) for (int tx = 0; tx < dw; tx++)
                {
                    float sx = (tx + 0.5f) * sw / dw - 0.5f, sy = (ty + 0.5f) * sh / dh - 0.5f;
                    int x0 = (int)Math.Floor(sx), x1 = Math.Min(x0 + 1, sw - 1);
                    int y0 = (int)Math.Floor(sy), y1 = Math.Min(y0 + 1, sh - 1);
                    x0 = Math.Max(x0, 0); y0 = Math.Max(y0, 0);
                    float dx = sx - x0, dy = sy - y0;
                    float v00 = mask[y0 * sw + x0] ? 1f : 0f, v10 = mask[y0 * sw + x1] ? 1f : 0f;
                    float v01 = mask[y1 * sw + x0] ? 1f : 0f, v11 = mask[y1 * sw + x1] ? 1f : 0f;
                    res[ty * dw + tx] = (1 - dx) * (1 - dy) * v00 + dx * (1 - dy) * v10 + (1 - dx) * dy * v01 + dx * dy * v11;
                }
            return res;
        }

        private Mat RecoverOriginalMask(float[] m1024, int dx, int dy, int nw, int nh, int ow, int oh)
        {
            using (Mat full = new Mat(ModelSize, ModelSize, MatType.CV_8UC1))
            {
                for (int y = 0; y < ModelSize; y++) for (int x = 0; x < ModelSize; x++)
                        full.At<byte>(y, x) = m1024[y * ModelSize + x] > 0.5f ? (byte)255 : (byte)0;
                using (Mat cr = new Mat(full, new Rect(dx, dy, nw, nh)))
                using (Mat outM = new Mat()) { Cv2.Resize(cr, outM, new OpenCvSharp.Size(ow, oh), 0, 0, InterpolationFlags.Nearest); return outM.Clone(); }
            }
        }

        private Mat OverlayMask(Mat src, Mat mask)
        {
            Mat ov = src.Clone();
            for (int y = 0; y < ov.Height; y++) for (int x = 0; x < ov.Width; x++)
                    if (mask.At<byte>(y, x) > 128) { var c = ov.At<Vec3b>(y, x); c[2] = (byte)Math.Min(255, c[2] + 100); ov.Set(y, x, c); }
            return ov;
        }

        private float Percentile(float[] arr, float p)
        {
            var s = (float[])arr.Clone(); Array.Sort(s); int idx = (int)Math.Round(p * (s.Length - 1));
            return s[Math.Max(0, Math.Min(idx, s.Length - 1))];
        }

        // 形态学清理:去除 64x64 掩码中的孤立噪声点与毛刺
        private bool[] CleanMask64(bool[] mask64)
        {
            using (Mat m = new Mat(FeatH, FeatW, MatType.CV_8UC1))
            {
                for (int i = 0; i < mask64.Length; i++)
                    m.At<byte>(i / FeatW, i % FeatW) = mask64[i] ? (byte)255 : (byte)0;

                using (var kernel = Cv2.GetStructuringElement(MorphShapes.Ellipse, new OpenCvSharp.Size(3, 3)))
                {
                    // Open: 腐蚀+膨胀 → 消除 <3x3 的孤立噪点
                    Cv2.MorphologyEx(m, m, MorphTypes.Open, kernel);
                    // Close: 膨胀+腐蚀 → 平滑边缘,填充目标内部微小空洞
                    Cv2.MorphologyEx(m, m, MorphTypes.Close, kernel);
                }

                bool[] cleaned = new bool[mask64.Length];
                for (int i = 0; i < mask64.Length; i++)
                    cleaned[i] = m.At<byte>(i / FeatW, i % FeatW) > 128;
                return cleaned;
            }
        }

        private void PrintTensorStats(string label, float[] data, int[] shape = null)
        {
            if (!log) return;

            if (data == null || data.Length == 0) { SafeLog($"{label}: NULL\n"); return; }
            float sum = 0, min = data[0], max = data[0], sumSq = 0;
            for (int i = 0; i < data.Length; i++)
            {
                float v = data[i]; sum += v; sumSq += v * v;
                if (v < min) min = v; if (v > max) max = v;
            }
            float mean = sum / data.Length;
            float std = (float)Math.Sqrt(sumSq / data.Length - mean * mean);
            string preview = string.Join(", ", data.Take(10).Select(v => v.ToString("F4")));
            string shapeStr = shape != null ? $"[{string.Join(",", shape)}]" : $"[{data.Length}]";
            SafeLog($"{label} {shapeStr}: mean={mean:F4}, std={std:F4}, min={min:F4}, max={max:F4}, preview=[{preview}...]\n");
        }

        private void PrintTensorStats(string label, float[][] data2D)
        {
            if (!log) return;

            if (data2D == null || data2D.Length == 0) { SafeLog($"{label}: NULL\n"); return; }
            var flat = data2D.SelectMany(row => row).ToArray();
            PrintTensorStats(label, flat, new[] { data2D.Length, data2D[0]?.Length ?? 0 });
        }

        private void PrintTensorStats(string label, Tensor<float> tensor, int batchIdx = 0)
        {
            if (!log) return;

            var shape = tensor.Dimensions.ToArray();
            var flat = tensor.ToArray();
            if (shape.Length == 4)
            {
                int C = shape[1], H = shape[2], W = shape[3];
                float[] batchData = new float[C * H * W];
                int offset = batchIdx * C * H * W;
                Array.Copy(flat, offset, batchData, 0, batchData.Length);
                PrintTensorStats(label, batchData, new[] { C, H, W });
            }
            else
            {
                PrintTensorStats(label, flat, shape);
            }
        }

        private void SaveArrayToFile(string label, float[] data, string folder = "debug_output")
        {
            if (!log) return;

            Directory.CreateDirectory(folder);
            string path = Path.Combine(folder, $"{label}_{DateTime.Now:HHmmssfff}.txt");
            using (var sw = new StreamWriter(path))
            {
                sw.WriteLine($"# {label} length={data.Length}");
                foreach (var v in data) sw.WriteLine(v.ToString("F8"));
            }
            SafeLog($"[DEBUG] Saved {label} to {path}\n");
        }

        private void SafeLog(string msg)
        {
            if (!log) return;

            if (this.InvokeRequired)
            {
                this.Invoke(new Action<string>(m => textBox1.AppendText(m + Environment.NewLine)), msg);
            }
            else
            {
                textBox1.AppendText(msg + Environment.NewLine);
            }
        }
    }
}