Qwen2.5思维链微调代码实操 + 多卡Lora微调完整代码

495 阅读15分钟

作者:情感机器团队-陈少宏

邮箱:shaohon_chen@115lab.club

最近对于Scaling Law的讨论异常火热。包括ilya大神自己都下场演讲关于大模型数据规模碰壁的问题(参考:机器之心官网发文)。直觉上,现在大模型思维的过程更像是人对一件事情直觉的反应,而不是多步思考和迭代思考的过程。正如下图ilya的PPT中的一张图,10层神经网络可以干人在0.1秒干的事情。而现在大模型上十亿的参数也可能只是解决人经过一分钟思考的回答。像OpenAI o1或者强化对齐可能是通往AGI的方法之一。刚好趁这个机会尝试一下一直没有进行的思维链微调。下面简单介绍一下思维链技术,并且使用通义千问进行CoT数据微调并且简单测试一下。

ilya-ppt.png

网上关于思维链微调的实操比较少,甚至对于Qwen的指令微调高质量的文章都不多,许多细节都描述的不清楚,希望这篇文章能够进一步帮助到读者微调Qwen时能够关注到一些细节。

这里感谢魔乐社区赞助了昇腾NPU进行微调。尝试了下国产卡做微调的效果还是非常不错!本篇教程专门做了openMind Library的适配,兼容昇腾NPU

友情链接:

1 思维链技术介绍

思维链技术(Chain of Thought,也简称为CoT),最早由Json Wei等人在Chain-of-Thought Prompting Elicits Reasoning in Large Language Models文章提出。简单来说就是通过提示词让模型能够将一个复杂的问题分步思考。比如举个文章中提到的例子(见下图),一个数学问题是:

食堂有 23 个苹果。如果他们用掉了 20 个来做午餐,又买了 6 个,现在他们有多少个苹果?

对于一个人类,他的思考步骤是:

  1. 食堂有23个苹果,用了20个,所以是23-20=3

  2. 又买了6个,所以是3+6=9

  3. 共有9个苹果

当然这个思维过程还能猜的更碎。比如上面的过程中第一个实际上蕴涵了“因为食堂有23个苹果,用了20个”、“所以是23-20=3”两个步骤。对于进行了“指令微调”的模型来说,更倾向于简短的回答入,比如直接回答“他现在有XX个苹果”,而且对于一个需要多步计算的数学题往往是错误的。CoT技术的主要目标就是通过提示词让模型一步一步来,像上面的思考步骤那样要求模型不仅回答问题,同时还将问题的生成过程写出来。

cot-paper.png

Json Wei的这篇文章的工作是在提示词上做的(文中分了few-shot和zero-shot两种方式,简单来说就是给样例和不给样例),用学术些的话来说就是“上下文学习”。这篇文章的实验部分证明了CoT确实能有效提升LLM的推理能力,尤其是数学任务。当然很多人一下就想到了,我能否用微调的方式直接将这种“一步步思考”的能力直接微调到模型中呢?实际上Json Wei大神也很快想到了,所以在紧接着下一篇Scaling Instruction-Finetuned Language Models、Google的FLAN​数据集改进版FLAN PaLM中直接引入了CoT数据集,这篇Json Wei虽然不是一作,但也是参与者之一。下图展示了FLAN PaLM所使用微调数据集,和对CoT部分的消融实验,很明显CoT使得模型能够具备更强的zero-shot能力。

(建议读一下原文,CoT部分在第四章节,我个人觉得文章中的消融实验不足以直接证明CoT数据对于模型的性能提升,但是目前看来大家在构建finetune数据时还是倾向于将CoT数据构建进去,起码能够显著提升模型使用CoT提示词的能力)。

flanpalm.png

flanpalm-cot.png

2 Qwen2.5思维链微调实操

2.1 CoT数据集准备

本文使用的是早年收集的一个CoT英文数据集,从质量判断是利用GPT4+提示词生成的,因为时间太久了实在是忘了出处无法附上链接,在这里跟原作者抱歉了orz。现在网上已经有许多高质量的CoT数据集,包括一些数据集专门区分了思考中间过程与返回结果,以及一些将思维步骤使用Special Token进行分隔的。

本文考虑到仅仅是练手项目,选用了最基础的数据集,并且使用经过指令微调的Qwen2.5-7B将英文数据集进一步翻译成了中文+整理成了标准的Instruction-Input-Output样式,以方便对Qwen模型的微调和练手学习。注意这个数据集仅仅是练手,LLM翻译的质量并不高,很容易被要翻译的内容带偏了。下图展示了部分数据集和样式:

{
    "instruction": "这家宠物店一共有多少只啮齿动物?宠物店里有6个笼子装着啮齿动物。其中3个笼子里各有10只仓鼠,另外3个笼子里各有5只豚鼠。那么这家宠物店一共养了多少只啮齿动物呢?",
    "input": "",
    "output": "让我们一步步来思考。\n\n有3个笼子,每个笼子里有10只豚鼠。\n这意味着总共有30只豚鼠。\n\n还有3个笼子,每个笼子里有5只豚兔。\n这意味着总共有15只豚兔。\n\n现在我们可以把30和15加起来得到答案:\n\n总共有45只啮齿动物。\n因此,答案(阿拉伯数字)是45。"
}

可以使用如下链接直接下载测试数据集

2.2 环境安装

2.2.1 昇腾NPU + openMind Library环境安装

使用昇腾NPU的话推荐在魔乐社区中找模型,里面能找到完成NPU适配的模型。魔乐社区使用的是openMind Library工具包,这个包支持在Nvidia GPU和Ascend NPU上运行,使用起来和transfomers接口一致。如果说做NPU迁移的话非常推荐使用。

魔乐社区的模型分为MindSpore支持和Pytorch-NPU支持,这里主要看本地装什么环境,考虑到新手学习的话推荐使用Pytorch-NPU,和Pytorch逻辑基本一致。

modelers.png

2.2.2 驱动安装&验证

首先得确定有NPU卡和NPU相关驱动,驱动是8.0.RC3.beta1,具体可以参考软件安装-CANN商用版8.0.RC3开发文档-昇腾社区

安装好后的验证方法是运行下面的命令,该命令作用与nvidia-smi类似,这里是查看NPU的状态和性能

npu-smi info

可以看到如下信息的话就表示驱动已经安装完成了,左侧是安装成功后运行代码后的结果,右侧是每一部分的含义。

npu-info.png

2.2.3 openMind环境搭建

openMind环境安装比较简单,这边列出所需用到的全部安装命令:

# 下载PyTorch安装包
wget https://download.pytorch.org/whl/cpu/torch-2.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
# 下载torch_npu插件包
wget https://gitee.com/ascend/pytorch/releases/download/v6.0.rc3-pytorch2.4.0/torch_npu-2.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
# 安装命令
pip3 install torch-2.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
pip3 install torch_npu-2.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
# 安装openMind Library
pip install openmind[pt]
pip install transformers accelerate datasets peft   # 部分场景会用到hf几个包,干脆全装了
# 安装SwanLab
pip install swanlab

2.2.4 Nvidia GPU + Transformers环境安装

这个流程比较简单,首先也是得确保Nvidia驱动存在,验证命令:

nvida-smi

如果没显示同样需要先安装cuda环境,这里贴上CUDA官方安装链接

网上有大量cuda安装安装教程,这里笔者就不赘述了。同样放出transformers环境安装的全部命令:​

pip install torch
pip install transformers accelerate datasets peft
# 安装SwanLab
pip install swanlab

2.2.5 关于提示词模版构建(大坑)

这里需要强调一下,在使用Qwen2.5的Instruct模型微调时,为了保障效果建议严格按照模型自身的Instruct的提示词模版构建。HF Transformers在4.3几的版本开始支持Chat Templates。Qwen2.5关于Instruct和Chat的提示词模版被直接写到了tokenziers的设置保存中,这导致了很多人在原始代码中找不到instruct提示词格式的构造。很多教程在教微调的时候还用的是Qwen1的老提示词模版或者自己构建的提示词模版,这会严重影响使用已经微调的模型做进一步微调时的效果。建议针对模型微调时一定要仔细检查提示词模版的实现部分。尽量使用模型已经定义好的格式和结构。

可以在Qwen的HF项目中找到提示词模版,点击HF Qwen查看chat_template设置。chat_template默认使用的是一种前端模版语言jinja,并不好看懂,笔者把Qwen2.5的提示词模版格式化后粘贴在下文:

{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0]['role'] == 'system' %}
        {{- messages[0]['content'] }}
    {%- else %}
        {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
    {%- endif %}
    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0]['role'] == 'system' %}
        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
    {%- else %}
        {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- for message in messages %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {{- '<|im_start|>' + message.role }}
        {%- if message.content %}
            {{- '\n' + message.content }}
        {%- endif %}
        {%- for tool_call in message.tool_calls %}
            {%- if tool_call.function is defined %}
                {%- set tool_call = tool_call.function %}
            {%- endif %}
            {{- '\n<tool_call>\n{"name": "' }}
            {{- tool_call.name }}
            {{- '", "arguments": ' }}
            {{- tool_call.arguments | tojson }}
            {{- '}\n</tool_call>' }}
        {%- endfor %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- message.content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}

可以看到超级长,因为定义了好几种情况,包括是否有system prompt。以及针对function tools怎么处理等等等等。如果读不懂(我感觉大多数搞deep learning的除了做LLM Finetune也很小有机会去学一个前端语言)我建议用大模型给你逐行解释下,这里附上jinja的官方文档

这里笔者简单提供我所使用的Qwen2.5简化版python模版(下脚本),去除了Function Calling和多轮对话的部分。并且只包含对Instruct和Inputs的处理部分,以及Assitants的生成头。这分为带inputs的版本和不带inputs的版本。我自己专门测试了使用此模版构造的提示词长度上和使用Qwen带chat_template的tokenziers完全一致。你只需要将outputs部分增加一个\n<|im_end|>\n即可直接拼接成finetune LLM模型的targets部分。

PROMPT_DICT = {
    "prompt_no_input": """<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n""",
    "prompt_input": """<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n""",
}

如果你直接偷懒使用chat_template来tokenizer仅带outputs部分的数据。你会发现由于Qwen的chat template处理机制,实际上生成的outputs部分会默认带上system prompts。导致最后训练阶段会出现奇怪的内容。Qwen的tokenizers针对未增加system角色的对话输入会自动加上如下提示词:

system:You are Qwen, created by Alibaba Cloud. You are a helpful assistant.

更神奇的是,这个system prompt居然是个英文的。Qwen可是个中文模型。。。这个system prompt的出现会影响后续的模型微调效果。

2.3 可视化工具配置(SwanLab使用教程)

swanlab.png

SwanLab可以将微调的许多关键参数自动记录下来并且能够再现可视化查看训练。能够在线或者离线保存+查看训练日志。SwanLab(有可能是唯一的)同时支持记录NVIDIA GPU和华为N腾NPU设备的日志记录工具。最新版本已经支持对NPU的内存使用、功率、温度等进行记录。甚至还有黑夜模式,方便苦逼研究生大晚上搞科研。:)

ascend-swanlab.png

关于SwanLab的使用方法可以参考SwanLab官方文档-快速开始

对于Huggingface Transformers或者支持华为昇腾NPU的openMind Library,可以使用SwanLab Integration轻松完成实验数据记录:

...
from swanlab.integration.huggingface import SwanLabCallback
swanlab_call = SwanLabCallback( #
    "Ascend_finetune_v2",
    experiment_name=os.path.basename(os.path.normpath(training_args.output_dir)),
    config=asdict(data_args)
    | asdict(model_args)
    | asdict(training_args)
    | asdict(lora_config),
    public=True,
)
trainer = openmind.Trainer( # 使用hf transformers的话则是把openmind替换为transformers
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    callbacks=[swanlab_call],   # callback加入进去即可
    **data_module,
)
...

使用后不仅能进行多图表对比,更重要的是把一大堆的huggingface transformers的训练超参数全部记录下来了,简直调参党福音。

swanlab-hyp.png

2.4 微调代码(多卡,支持华为Ascend卡)

下面附上完整的微调代码。在项目目录下创建finetune.py文件,并将如下代码粘贴进文件中。

多卡训练的话可以使用torchrun,这里附上一个启动多卡的bash脚本,在当前目录下创建finetune.sh,并且粘贴如下脚本:

import copy
import os
import io
import json
import logging
from dataclasses import dataclass, field, asdict
from typing import Dict, Optional, Sequence

import torch
from torch.utils.data import Dataset
try:
    import openmind as tf_module
except:
    import transformers as tf_module
import transformers

from peft import LoraConfig, get_peft_model
from swanlab.integration.huggingface import SwanLabCallback

IGNORE_INDEX = -100

PROMPT_DICT = {
    "prompt_no_input": """<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n""",
    "prompt_input": """<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n""",
}


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(
        default="./weights/Qwen/Qwen2.5-7B-Instruct"
    )


@dataclass
class DataArguments:
    data_path: str = field(
        default="./data/cot_train_cn.jsonl",
        metadata={"help": "Path to the training data."},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )


def _tokenize_fn(strings: Sequence[str], tokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def jload(f, mode="r", jsonl=True):
    if not isinstance(f, io.IOBase):
        with open(f, mode=mode, encoding="utf-8") as f:
            if jsonl:
                # Parse JSON Lines
                return [json.loads(line) for line in f if line.strip()]
            else:
                # Parse standard JSON
                return json.load(f)
    else:
        if jsonl:
            return [json.loads(line) for line in f if line.strip()]
        else:
            return json.load(f)


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = (
            PROMPT_DICT["prompt_input"],
            PROMPT_DICT["prompt_no_input"],
        )
        sources = [
            (
                prompt_input.format_map(example)
                if example.get("input", "") != ""
                else prompt_no_input.format_map(example)
            )
            for example in list_data_dict
        ]
        targets = [
            f"{example['output']}\n{tokenizer.eos_token}\n"
            for example in list_data_dict
        ]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        try:
            self.input_ids = data_dict["input_ids"]
        except KeyError as e:
            raise KeyError("input_ids is invalid") from e
        try:
            self.labels = data_dict["labels"]
        except KeyError as e:
            raise KeyError("labels is invalid") from e

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: object

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple(
            [instance[key] for instance in instances] for key in ("input_ids", "labels")
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def make_supervised_data_module(tokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(
        tokenizer=tokenizer, data_path=data_args.data_path
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(
        train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
    )


def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = tf_module.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=True,
    )

    # 定义LoRA配置
    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.1,
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    tokenizer = tf_module.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
    )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    swanlab_call = SwanLabCallback(
        "Ascend_finetune_v2",
        experiment_name=os.path.basename(os.path.normpath(training_args.output_dir)),
        config=asdict(data_args)
        | asdict(model_args)
        | asdict(training_args)
        | asdict(lora_config),
        public=True,
    )

    trainer = tf_module.Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        callbacks=[swanlab_call],
        **data_module,
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()

开启多卡训练的方式如下:

bash finetune.sh <使用的GPU/NPU数量>

如果提示登录swanlab,可以在官网完成注册后,使用获取API KEY找到对应的登陆密钥并粘贴,这样将能够使用云上看版随时查看训练过程与结果。

2.5 微调效果(附上Gradio代码)

本来准备了Ceval的测试结果,结果不知道为什么Ascend服务器连不上了,等过段时间更新下教程文档。

swanlab-loss.png

这里放出使用CoT数据微调qwen-7b-instruct、qwen-0.5b-instruct和使用qwen-7b-instruct(8NPU)的loss结果。可以看到使用8个NPU能带来更好的训练loss表现和稳定性,哪怕在使用同样迭代数据量的情况下,8个NPU依然能带来更好的loss结果。可能更大的loss有助于模型稳定下降。

最后展现下使用gradio完成的官方Qwen2.5-7B-Instruct、基于Qwen2.5-7B在中文alpaca数据集上指令微调、以及cot微调后的模型回复对比。可以看到CoT微调后模型确实具备了“step by step”的回复模式。​

demo.png

当然许多读者注意到了官方模型也展现出了“step by step”的回答模式,这主要是因为现在较新的模型在finetune数据集甚至pretrain数据集中就会预先加入CoT数据,所以模型在进行问答、尤其是数学题问答时,会展现出“步骤分解”的现象。笔者后续会尝试在较早期的demo中更新微调的​

附上启用gradio的demo测试代码:

使用pip install gradio安装依赖包

import gradio as gr

from openmind import AutoModelForCausalLM, pipeline
from peft import PeftModel

TOTAL_GPU_NUMS = 8
TOKENIZE_PATH = "~/weightsweights/Qwen/Qwen2.5-7B-Instruct"
MODEL_LIST = {
    "office_qwen7b": "~/weights/Qwen/Qwen2.5-7B-Instruct",  # 官方模型
    "alpaca_qwen7b_lora": "./projects/qwen_finietune_cot/output/qwen25-7B-alpaca",  # 7b+alpaca
    "cot_qwen7b_lora": "./projects/qwen_finietune_cot/output/qwen25-7Bi-cot",  # cot微调
}

model_names = MODEL_LIST.keys()
pipes = dict()
for i, model_name in enumerate(model_names):
    save_path = MODEL_LIST[model_name]
    model = AutoModelForCausalLM.from_pretrained(save_path)
    if model_name[:-5] == "_lora":
        model = PeftModel.from_pretrained(model, save_path)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=TOKENIZE_PATH,
        framework="pt",
        device=f"npu:{i%TOTAL_GPU_NUMS}",
    )
    pipes[model_name] = pipe


def generate_response(instruct_text, input_text):
    messages = [
        {
            "role": "system",
            "content": instruct_text,
        },
        {
            "role": "user",
            "content": input_text,
        },
    ]
    outputs = [
        pipes[model_name](messages, max_new_tokens=256)[-1]["content"]
        for model_name in model_names
    ]
    return tuple(outputs)


# 创建 Gradio 界面
demo = gr.Interface(
    fn=generate_response,  # 函数名
    inputs=[
        gr.Textbox(label="instruction"),
        gr.Textbox(label="input"),
    ],  # 输入文本框
    outputs=[gr.Textbox(label=model_name) for model_name in model_names],
)


if __name__ == "__main__":
    demo.launch()