1.数据集
首先根据老师的建议,将每张图片的异常细胞进行裁剪,大小为640*640,一个细胞裁剪十张图片,每一张图片都包含该异常细胞,且坐标随机。
最终数据集中图片有十万多张,xml标签与图片一一对应,表明细胞坐标及类型。
2.yolov11
# 文件:CNN/train_yolo11_4090.py
# 用法:
# pip install ultralytics
# python train_yolo11_4090.py --model yolo11m --imgsz 640 --epochs 100 --device 0
# 备注:
# --model 可选 yolo11n/s/m/l/x(4090建议从 yolo11m 或 yolo11l 开始)
# --batch 默认 -1 自动找最大batch;OOM就改小点
# 已适度降低医疗场景不友好的强增强(mosaic/mixup),保留基础颜色与缩放
import os
import argparse
from ultralytics import YOLO
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("--data", type=str, default="abnormal_cells_4cls.yaml",
help="data.yaml 相对路径(相对于本脚本目录)")
ap.add_argument("--model", type=str, default="yolo11m",
help="yolo11n/s/m/l/x(不要带.pt)")
ap.add_argument("--epochs", type=int, default=100)
ap.add_argument("--batch", type=int, default=-1, # -1 = 自动找到最大batch(Ultralytics支持)
help="-1 表示自动最大batch")
ap.add_argument("--imgsz", type=int, default=640)
ap.add_argument("--device", type=str, default="0")
ap.add_argument("--workers", type=int, default=8)
ap.add_argument("--project", type=str, default="runs_detect")
ap.add_argument("--name", type=str, default="yolo11_4cls_4090")
ap.add_argument("--resume", action="store_true")
return ap.parse_args()
def main():
args = parse_args()
this_dir = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(this_dir, args.data)
model_ckpt = f"{args.model}.pt" # 例如 yolo11m.pt
# 加载预训练模型
model = YOLO(model_ckpt)
# 4090 友好设置与轻度医学增强
overrides = dict(
data=data_path,
epochs=args.epochs,
imgsz=args.imgsz,
batch=args.batch, # -1 自动最大;或手动如 64/48/32 …
device=args.device,
workers=args.workers,
project=args.project,
name=args.name,
exist_ok=True,
resume=args.resume,
amp=True, # 混合精度,提升吞吐
cos_lr=True, # 余弦退火
patience=30, # 早停
optimizer="auto",
seed=42,
cache="ram", # 若内存够,缓存数据到内存加速
# 数据增强(适度,避免过强几何扰动)
hsv_h=0.01, hsv_s=0.4, hsv_v=0.4,
degrees=0.0, translate=0.05, scale=0.6, shear=0.0, perspective=0.0,
flipud=0.0, fliplr=0.5, # 细胞上下翻转通常无意义,左右翻转保留
mosaic=0.1, mixup=0.0, copy_paste=0.0,
)
results = model.train(**overrides)
# 用最优权重再验证一次
best = os.path.join(results.save_dir, "weights", "best.pt")
print(f"验证最优权重: {best}")
model_best = YOLO(best)
model_best.val(data=data_path, imgsz=args.imgsz, batch=max(16, (args.batch if args.batch > 0 else 32)),
device=args.device)
# 可选:导出 ONNX/TensorRT
# model_best.export(format="onnx", dynamic=True)
# model_best.export(format="engine") # 需要TensorRT
if __name__ == "__main__":
main()
3.vig
- 全称 Vision GNN(ViG,论文 “An Image is Worth Graph of Nodes”)
- 思想:把一张图切成小块(patch/tokens),每个小块是图上的“节点”,节点之间按特征相近性建 k 近邻图(k-NN),再用图卷积做信息传递。这样既能学到“邻域关系”,又避免 ViT 全局注意力 O(N^2) 的高成本。