【大模型微调-AI修图提示词生成】

0 阅读6分钟

1.背景

当前的大模型修图能力足以媲美人工修图,但目前用户存在的一个痛点就是无法去有效的设置修图提示词,导致无法让AI修图让自己满意的图。本文章通过收集“图片+文本(修图提示词)”对大模型进行微调,微调后的大模型可以根据用户上传的写真照输出对应的修图提示词,用户再根据输出的提示词去豆包或Nano(香蕉)进行修图。之前一直是在做nlp单模态的LoRA微调,该应用需要联系图片和文本就需要采用到多模态的微调,对于想了解多模态微调的同学也可以参考下面一篇论文,先了解一下多模态模型的实现原理。图片+文本就好比赋予模型眼睛和嘴巴,可以实现很多有趣的应用,欢迎和大家一起讨论

项目地址: github.com/23jisuper/w…

效果展示:www.bilibili.com/video/BV1N8…

论文: proceedings.mlr.press/v139/radfor…

在这里插入图片描述


2.wxy_AI_Drawing项目介绍

(1).项目结构
wxy_AI_Drawing/
├── images/                 # 项目图片资源
│   └── xy.png
├── model/
│   └── xy_model/           # 本地模型目录
├── config.py               # 配置 dataclass
├── modeling.py             # 模型加载 + LoRA 注入
├── collator.py             # 多模态数据 collator
├── messages.py             # 训练样本与 messages 构造
├── io_utils.py             # JSONL / CSV 工具
├── train.py                # 训练入口
├── infer.py                # 推理入口
├── convert_data.py         # CSV → JSONL 转换
├── requirements.txt
└── README.md

(1). config.py —— 配置中心

config.py 使用 Python 的 dataclass 将所有配置集中管理,避免在代码中散落魔法数字。

配置类职责关键字段
DataConfig数据与提示词data_roottrain_jsonlsystem_messagedefault_instruction
ModelConfig模型与预处理model_id_or_pathmin_pixelsmax_pixelstrust_remote_code
LoRAConfigLoRA 超参数r(秩)、alpha(缩放因子)、dropout
TrainConfig训练超参数num_train_epochslearning_rategradient_checkpointing

设计要点

  • system_messagedefault_instruction 明确规定了「图像描述助手」的角色和输出风格,使模型输出更贴合「AI 绘画提示词」场景。
  • min_pixels / max_pixels 控制图像分辨率,影响显存与细节的平衡:256×28²384×28² 是常见折中配置。

(2). messages.py —— 消息与样本构造

TrainSample:训练样本的不可变结构,包含 image(相对路径)、instructionresponse

build_prompt_messages:构造「用户输入」的 messages,格式为:

[system] + [user: image + instruction]

build_full_messages:在 prompt 基础上追加 assistant 的 response,用于训练:

[system] + [user: image + instruction] + [assistant: response]

关键点:Qwen3.5 的 chat 格式要求 content[{"type": "image", "image": path}, {"type": "text", "text": "..."}] 这种混合结构,messages.py 负责把 TrainSample 转成这种格式,供后续 process_vision_info 和 collator 使用。


(3). io_utils.py —— 数据读写与格式转换
函数作用
load_jsonl一次性加载 JSONL 为 list[dict]
iter_jsonl流式读取 JSONL,节省内存
write_jsonlIterable[dict] 写入 JSONL
csv_to_jsonl将 CSV/TSV 转为训练用 JSONL

csv_to_jsonl` 的设计

  • 使用 csv.Sniffer 自动识别分隔符(逗号、制表符、分号)。
  • 默认列名:image_pathtext,可配置。
  • 会校验图片是否存在,缺失时抛出 FileNotFoundError

这样可以把「图片路径 + 描述」的 CSV 快速转为项目统一使用的 JSONL 格式。


(4). collator.py —— 多模态数据 collator

QwenVLDataCollator 是训练时的核心数据组装模块,负责把一个 batch 的 TrainSample 转成模型可接受的输入。

流程概览

  1. 构造 messages:对每个样本调用 build_prompt_messagesbuild_full_messages
  2. 视觉信息:通过 qwen_vl_utils.process_vision_info 处理 messages 中的图片/视频,得到 image_inputsvideo_inputs
  3. 文本编码:用 processor.apply_chat_template 对 prompt 和 full 文本分别编码。
  4. 统一调用 processor:对 textimagesvideos 做 batch 处理。
  5. Labels 与 mask:用 prompt_inputs 的 attention 长度得到 prompt 长度,将 prompt 部分的 labels 置为 -100,只对 assistant 部分计算 loss。

关键点:只对 assistant 的 token 计算 loss,避免模型学习 prompt 和 system 的重复,是标准的 SFT 做法。


(5). modeling.py —— 模型加载与 LoRA 注入

pick_dtype:根据 CUDA 可用性和 bf16 支持,选择 bfloat16 / float16 / float32

load_qwen3_5_vl

  • 根据 ModelConfig 加载 AutoProcessorQwen3_5ForConditionalGeneration
  • 使用 device_map="auto" 自动分配设备。
  • 关闭 use_cache 以配合梯度检查点。

apply_lora

  • 使用 PEFT 的 LoraConfig,对 q_projk_projv_projo_projgate_projup_projdown_proj 注入 LoRA。
  • 调用 enable_input_require_grads() 以支持梯度检查点。

(6). train.py —— 训练入口

训练流程

  1. 解析命令行参数,构建 DataConfigModelConfigLoRAConfigTrainConfig
  2. 加载模型与 processor,冻结 vision encodermodel.model.visual),仅微调语言部分。
  3. 注入 LoRA。
  4. 加载 JSONL,构建 datasets.Dataset,使用 QwenVLDataCollator
  5. 调用 HuggingFace Trainer 训练。
  6. 保存 LoRA 到 output_dir/adapter,同时保存 processor。

冻结 vision encoder 的原因:视觉编码器通常已在大规模数据上预训练,微调时只更新语言模型部分,能减少显存、提升稳定性,并降低过拟合风险。


(7). infer.py —— 推理入口

generate_caption:封装单次推理逻辑:

  1. 构造 messages(含 image + text)。
  2. 调用 process_vision_infoprocessor 得到输入。
  3. 使用 model.generate 生成。
  4. 解码并只返回新生成部分。

main:支持两种模式:

  • 仅 base 模型--model 指向本地或远程模型,processor 也从同一路径加载。
  • base + LoRA:先加载 base,再 PeftModel.from_pretrained 加载 adapter,processor 从 adapter 目录加载(训练时已保存)。

(8). convert_data.py —— 数据转换工具

命令行封装 csv_to_jsonl,用于将 CSV/TSV 转为 JSONL 训练数据,便于快速接入自有数据。


3.多模态微调

自己也是第一次微调多模态模型,记录一下过程中自以为需要重要理解的点,如有错误欢迎指正

(1).输入维度对比
维度纯 NLP 微调多模态微调(本项目)
输入类型纯文本 token文本 + 图像(像素)
输入表示离散 token embedding文本 embedding + 图像 patch embedding
序列长度主要由文本长度决定文本 + 图像 token 数

(2).模型结构差异

纯 NLP 模型
输入 → Token Embedding → Transformer 层 → 输出

多模态模型(如 Qwen3.5-VL)
输入 → [Vision Encoder] → 图像 patch embedding
→ [Text Tokenizer] → 文本 embedding
→ 多模态融合(图像与文本拼接为统一序列)
→ Transformer 层 → 输出 多模态模型需要额外处理视觉编码和图像-文本对齐,结构更复杂。


(3).数据格式差异

纯 NLP
每条样本通常是 {"input": "...", "output": "..."} 或类似格式,全部为文本。

多模态
每条样本必须包含图像路径和对应文本,例如:

{"image": "images/1.png", "instruction": "...", "response": "..."}

训练时需要:

  • 加载图片。
  • process_vision_info 解析 messages 中的图像。
  • 用 processor 将图像和文本一起编码为 batch。

(4).小结

多模态微调在输入、模型结构、数据格式、训练流程上都比纯 NLP 更复杂,但能实现「看图说话」类能力。本项目通过冻结 vision、LoRA、梯度检查点等手段,在保持显存可控的前提下,完成多模态图像描述微调,适合作为多模态微调的入门示例。最后,希望有共同爱好大模型的可以一起提供思路开发有价值的模型应用产品(自己前端知识比较弱😳)