TCT模型训练

104 阅读2分钟

1.数据集

首先根据老师的建议,将每张图片的异常细胞进行裁剪,大小为640*640,一个细胞裁剪十张图片,每一张图片都包含该异常细胞,且坐标随机。

最终数据集中图片有十万多张,xml标签与图片一一对应,表明细胞坐标及类型。

2.yolov11

# 文件:CNN/train_yolo11_4090.py
# 用法:
#   pip install ultralytics
#   python train_yolo11_4090.py --model yolo11m --imgsz 640 --epochs 100 --device 0
# 备注:
#   --model 可选 yolo11n/s/m/l/x(4090建议从 yolo11m 或 yolo11l 开始)
#   --batch 默认 -1 自动找最大batch;OOM就改小点
#   已适度降低医疗场景不友好的强增强(mosaic/mixup),保留基础颜色与缩放

import os
import argparse
from ultralytics import YOLO

def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--data", type=str, default="abnormal_cells_4cls.yaml",
                    help="data.yaml 相对路径(相对于本脚本目录)")
    ap.add_argument("--model", type=str, default="yolo11m",
                    help="yolo11n/s/m/l/x(不要带.pt)")
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--batch", type=int, default=-1,  # -1 = 自动找到最大batch(Ultralytics支持)
                    help="-1 表示自动最大batch")
    ap.add_argument("--imgsz", type=int, default=640)
    ap.add_argument("--device", type=str, default="0")
    ap.add_argument("--workers", type=int, default=8)
    ap.add_argument("--project", type=str, default="runs_detect")
    ap.add_argument("--name", type=str, default="yolo11_4cls_4090")
    ap.add_argument("--resume", action="store_true")
    return ap.parse_args()

def main():
    args = parse_args()
    this_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(this_dir, args.data)
    model_ckpt = f"{args.model}.pt"  # 例如 yolo11m.pt

    # 加载预训练模型
    model = YOLO(model_ckpt)

    # 4090 友好设置与轻度医学增强
    overrides = dict(
        data=data_path,
        epochs=args.epochs,
        imgsz=args.imgsz,
        batch=args.batch,        # -1 自动最大;或手动如 64/48/32 …
        device=args.device,
        workers=args.workers,
        project=args.project,
        name=args.name,
        exist_ok=True,
        resume=args.resume,

        amp=True,                # 混合精度,提升吞吐
        cos_lr=True,             # 余弦退火
        patience=30,             # 早停
        optimizer="auto",
        seed=42,
        cache="ram",             # 若内存够,缓存数据到内存加速
        # 数据增强(适度,避免过强几何扰动)
        hsv_h=0.01, hsv_s=0.4, hsv_v=0.4,
        degrees=0.0, translate=0.05, scale=0.6, shear=0.0, perspective=0.0,
        flipud=0.0, fliplr=0.5,  # 细胞上下翻转通常无意义,左右翻转保留
        mosaic=0.1, mixup=0.0, copy_paste=0.0,
    )

    results = model.train(**overrides)

    # 用最优权重再验证一次
    best = os.path.join(results.save_dir, "weights", "best.pt")
    print(f"验证最优权重: {best}")
    model_best = YOLO(best)
    model_best.val(data=data_path, imgsz=args.imgsz, batch=max(16, (args.batch if args.batch > 0 else 32)),
                   device=args.device)

    # 可选:导出 ONNX/TensorRT
    # model_best.export(format="onnx", dynamic=True)
    # model_best.export(format="engine")  # 需要TensorRT

if __name__ == "__main__":
    main()

image.png image.png

image.png

3.vig

  • 全称 Vision GNN(ViG,论文 “An Image is Worth Graph of Nodes”)
  • 思想:把一张图切成小块(patch/tokens),每个小块是图上的“节点”,节点之间按特征相近性建 k 近邻图(k-NN),再用图卷积做信息传递。这样既能学到“邻域关系”,又避免 ViT 全局注意力 O(N^2) 的高成本。

4.融合

image.png