bf16权重合并lora出现无法忽略的精度损失

310 阅读2分钟

最近在微调 Qwen VL 模型,使用 peft 库的 lora 进行微调。为了更高的推理效率,就把 lora 合并到了基底模型。但合并过后的模型输出效果非常差。

不应该啊,没合并前还好好的。从原理上 lora 合并与不合并的输出是等价的。难道是精度问题?

思考

基于 float16 或 bfloat16 的半精度训练现在已经非常成熟。像是 Qwen 这样的大模型基本是使用 bfloat16 精度存储的模型权重。

要说 lora 合并带来的损失,我能想到的就只有精度上的损失了。接下来用个小一点的模型实验一番。

先说结论:基底模型权重是 float32 才能忽略 lora 合并的损失

实验

使用以下代码进行实验:

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(0)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.float32,
)

input = torch.tensor([[1, 2, 3, 4, 5]])

# lora model
config = LoraConfig(
    r=8,
    init_lora_weights=False,
)
model_new = get_peft_model(
    model,
    config,
    adapter_name="adapter1",
)
# adapter1 output
model_new.set_adapter("adapter1")
output_adapter1 = model_new(input).logits

model_new.merge_and_unload(safe_merge=True, adapter_names=["adapter1"])
output_base = model_new.get_base_model()(input).logits

diff = output_adapter1 - output_base
print(torch.sum(torch.abs(diff)))

代码用 lora 输出与模型合并输出进行对比,输出结果为 0.2223。鉴于输出维度为 [1, 5, 50272],这个误差几乎能够忽略不计。

现在将基底模型设置为 bfloat16,再进行测试。

···

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.bfloat16,
)

···

结果猛然上升到 2416.,误差上升了万倍。

如果是 float16 呢?毕竟 bfloat16 是用动态换精度的格式,换用 float16 应该会好些。

···

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.float16,
)

···

其结果为 302.7500。仍然很高。

如果,对于 bfloat16 的基底模型,先转换为 float32 再合并 lora,结果会不会有所改善?

···

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.bfloat16,
)

···

model = model.float()
model_new.merge_and_unload(safe_merge=True, adapter_names=["adapter1"])
model = model.bfloat16()

···

输出结果 2416.。没有改善。

碎碎念

网上只能搜到用 QLoRA 微调后合并权重会出现精度损失的信息,提出的解决方法是把基底模型上采样到 float32 再进行合并。QLoRA 毕竟是把基底模型量化到了 Int4,就是没想到 bfloat16 也会出现问题。

半精度训练或推理能减少显存占用、充分利用上 tensor core 加速,获得种种好处的同时模型仍然能够保持较高准确度,不使用半精度不太现实。为了更好的效果,要么全量微调,要么只能不合并 lora 去推理了。