大模型对话风格微调项目实战——数据工程篇

399 阅读22分钟

目录

项目背景

为什么要修改对话风格?

  1. 业务需求适配:不同行业和应用场景需要不同的对话风格(如客服需要专业严谨,社交应用需要轻松活泼)
  2. 品牌一致性:保持与品牌调性一致的对话风格能增强用户认知
  3. 用户体验优化:符合用户预期的对话风格能提升交互体验
  4. 文化适应性:针对不同地区和文化背景调整表达方式

修改模型对话风格的方法有哪些?

graph LR
    A[修改对话风格方法] --> B[提示工程 Prompt Engineering]
    A --> C[监督微调 SFT]
    A --> D[强化学习 RLHF]
    A --> E[检索增强生成 RAG]
  
    B -->|优点| F[无需训练/低成本]
    B -->|缺点| G[效果有限/不稳定]
  
    C -->|优点| H[风格稳定/可控]
    C -->|缺点| I[需要训练数据/计算资源]
  
    D -->|优点| J[可优化复杂目标]
    D -->|缺点| K[实现复杂/成本高]
  
    E -->|优点| L[动态调整/知识更新]
    E -->|缺点| M[依赖检索质量]

为什么选择微调模型?

  • 模型体量小: 可以选择更小的模型进行部署,降低成本
  • 效果稳定性:相比提示工程,微调后的风格表现更加稳定一致
  • 定制深度:可以实现更细粒度的风格控制
  • 性能优化:可以针对特定领域优化推理效率
  • 长期成本:虽然初期需要投入硬件,但长期使用边际成本低

为什么不直接使用在线大语言模型?

  • 隐私:用户的对话数据容易被泄露,存在隐私风险
  • 速度:在线大语言模型的推理速度较慢,无法满足实时交互需求
  • 成本:在线大语言模型需要支付高额的API调用费用

项目目标

本项目的目标是实现一个对话风格微调的项目,包括数据收集、数据清洗、模型选型、模型评估、模型训练、训练效果评估、模型部署和前端展示环节。

项目流程

本项目的流程如下:

  1. 数据收集:收集各种风格的对话数据,包括输入和输出。
  2. 数据清洗:对收集到的数据进行格式化,质量评估,去除重复数据、空数据等。
  3. 模型选型:选择合适的模型进行对话风格微调。
  4. 模型评估:对选择的模型进行评估,选择最适合当前项目的模型。如都不适合则重新进行模型选型。
  5. 模型训练:使用清洗后的数据对选择的模型进行训练,训练模型的参数。
  6. 训练效果评估:对训练后的模型进行效果评估,评估模型的性能和效果是否满足项目需求。如不满足则继续进行模型训练。
  7. 模型部署:将训练后的模型部署到线上环境中,提供服务。
  8. 前端展示:提供一个前端页面,用户可以输入对话文本,和模型进行交互。

项目流程如下:

flowchart TD
    subgraph 数据准备阶段[" "]
    style 数据准备阶段 fill:#E6F3FF,stroke:#4A90E2
    A[数据收集] --> B[数据清洗]
    end
    
    subgraph 模型开发阶段[" "]
    style 模型开发阶段 fill:#FFF4E0,stroke:#F5A623
    C[模型选型] --> D{模型评估}
    D -->|不适合| C
    D -->|适合| E[模型训练]
    E --> F{训练效果评估}
    F -->|不满足| E
    end
    
    subgraph 上线运营阶段[" "]
    style 上线运营阶段 fill:#E8F5E9,stroke:#4CAF50
    G[模型部署] --> H[前端展示]
    end
    
    %% 跨列连接
    B --> C
    F -->|满足| G
    
    %% 装饰线
    classDef colHeader fill:#fff,stroke:#fff;
    classDef arrow stroke-width:2px;
    
    style A fill:#BBDEFB,stroke:#2196F3
    style B fill:#BBDEFB,stroke:#2196F3
    style C fill:#FFE0B2,stroke:#FF9800
    style D fill:#FFE0B2,stroke:#FF9800
    style E fill:#FFE0B2,stroke:#FF9800
    style F fill:#FFE0B2,stroke:#FF9800
    style G fill:#C8E6C9,stroke:#4CAF50
    style H fill:#C8E6C9,stroke:#4CAF50

这篇文章要完成的任务

  1. 数据收集:收集用户的对话数据,包括对话文本和对话风格标签。
  2. 数据清洗:对收集到的数据进行清洗,去除重复数据、语义相近数据、缺失数据等。

数据收集

本文将通过开源数据集提供用户输入,并通过在线大语言模型生成不同风格输出的方式进行数据收集。

用户输入数据集的选择

用户输入数据集是指用户提供的文本数据,用于向在线大模型提供输入,以生成不同风格的回复。

本文将选择以下数据集来提供用户的输入:

LCCC 数据集

LCCC(Large-scale Cleaned Chinese Conversation)数据集是清华大学发布的大规模中文对话数据集,包含约1200万真实场景对话(微博/贴吧等来源),经过严格清洗(去重、去噪、敏感信息过滤),支持开放域对话研究。其标准版(LCCC-base)含680万对对话,常用作中文对话模型的基准训练数据

用户输入数据集下载

LCCC数据集现在可以通过modelscope命令直接从魔搭社区下载,命令如下:

modelscope download --dataset OmniData/LCCC --cache_dir 'E:\AI\data\'

用户输入数据集预处理

LCCC 数据集是一个包含多种风格对话的 JSON 文件。为了更高效地利用该数据集,我们需要对其进行预处理。具体而言,我们会随机选取数据集中的一部分,这么做主要有以下几个原因:

  • 减少处理的数据量,降低计算成本和时间消耗;
  • 在保证数据多样性的前提下,让样本更具代表性,避免因处理全量数据而可能引入的冗余信息;
  • 随机选取能在一定程度上模拟真实的数据分布情况,使后续基于该部分数据的模型训练和开发更具泛化性。

选取后,我们会将这部分数据转换为项目所需的格式。

  • 下面的代码是使用Python的ijson库和random库实现的随机选取部分数据的代码,其中reservoir_sampling函数用于实现蓄水池抽样算法,file_path是数据集文件的路径,sample_size是需要选取的样本数量。
import ijson
import random
import json

def reservoir_sampling(file_path, sample_size=1000):
    selected_items = []
    # 预分配列表空间,避免频繁的列表扩容操作
    selected_items = [None] * sample_size
    with open(file_path, "rb") as f:  # 注意使用二进制模式
        for i, item in enumerate(ijson.items(f, "item")):
            # 取item的第一个元素并去除空格
            if isinstance(item, (list, tuple)) and len(item) > 0:
                item = item[0]
                if isinstance(item, str):
                    item = item.replace(" ", "")
            if i < sample_size:
                selected_items[i] = item
            else:
                # 以sample_size/(i+1)的概率替换已选中的元素
                # 使用random.random()替代random.randint()进行概率判断,减少计算量
                if random.random() < sample_size / (i + 1):
                    j = random.randint(0, sample_size - 1)
                    selected_items[j] = item
    # 移除预分配时可能存在的None值
    return [x for x in selected_items if x is not None]

# 使用示例
file_path = r"E:\AI\data\OmniData\LCCC\raw\LCCC-base-split\LCCC-base_train.json"
selected_items = reservoir_sampling(file_path)

with open('./input_datas.josn', 'w', encoding='utf-8') as f:
    json.dump(selected_items, f, ensure_ascii=False, indent=4)

使用在线大模型生成不同风格回答

在线大模型的选择

在数据收集阶段,在线大模型对数据的质量和效果至关重要。在选择在线大模型时,需要考虑以下几个方面:

  • 模型能力:在线大模型的能力决定了其生成的回答质量和效果。
  • 模型成本:在线大模型的成本决定了其使用的成本。
  • 模型易用性:在线大模型的易用性决定了其使用的难度。
  • 模型性能:在线大模型的性能决定了其生成的回答的速度和效率。

综合以上考虑,在满足项目需求的前提下,我们选择了以下在线大模型:

  • glm-4-airx:GLM-4是由智谱AI推出的一款基于Transformer架构的大语言模型,具有强大的文本生成能力和流畅的对话体验。GLM-4-FlashX是GLM-4-Flash模型的增强版本,具有超快推理速度和实惠的价格。

为什么不使用Deepseek?

虽然Deepseek模型的能力和易用性都非常强大,但由于其使用人数太多,导致其API响应速度非常慢,无法满足项目需求。

利用在线大模型生成不同风格的回复

  • 提示词的设计:下面的代码是使用glm-4生成不同风格的回复的代码,其中style_config是一个字典,包含了两种不同风格的回复的配置信息。
style_config = {
    "温柔":{
        "system_prompt":"你是一个温柔体贴的聊天助手,说话时总是充满关怀,使用以下特征:\n1. 包含'呢、呀、啦'等语气词\n2. 使用🌸💖😊等温暖表情\n3. 主动询问用户感受",
        "examples": [
            {"role": "user", "content": "今天好累啊"},
            {"role": "assistant", "content": "辛苦啦~ 要给自己泡杯热茶放松一下吗?🌸"},
            {"role": "user", "content": "考试没考好..."},
            {"role": "assistant", "content": "没关系的呀~ 下次一定会更好!需要我陪你聊聊吗?😊"}
        ],
     "temperature": 0.7
    },
    "毒舌":{
        "system_prompt":"你是一个喜欢用犀利吐槽表达关心的朋友,需满足:\n1. 使用网络流行语(如'栓Q''退退退')\n2. 包含夸张比喻('你这速度堪比树懒')\n3. 结尾隐藏关心",
        "examples": [
            {"role": "user", "content": "又胖了5斤!"},
            {"role": "assistant", "content": "好家伙!你这是要把体重秤压成分子料理?🏋️"},
            {"role": "user", "content": "游戏又输了"},
            {"role": "assistant", "content": "菜就多练练!需要给你推荐《从零开始的电竞之路》吗?🎮"}
        ],
     "temperature": 0.7
    },
}
  • 代码实现:下面的代码是使用glm-4生成不同风格的回复的代码,其中ZhipuAIWrapper是一个封装了glm-4的类,其中__init__方法是类的初始化方法,chat方法是生成不同风格的回复的方法。
from zhipuai import ZhipuAI

class ZhiPuChatAssistant:
    def __init__(self, api_key_file, model_name="glm-4-airx", max_tokens=100, temperature=0.7):
        with open(api_key_file, 'r') as file:
            api_key = file.read().strip()
        self.client = ZhipuAI(api_key=api_key)
        self.model_name = model_name
        self.max_token = max_tokens
        self.temperature = temperature
        self.style_config = {
            "温柔":{
                "system_prompt":f"""
你是一个温柔体贴的聊天助手,说话时总是充满关怀,使用以下特征:
1. 包含'呢、呀、啦'等语气词
2. 使用🌸💖😊等温暖表情
3. 主动询问用户感受
4. 回答不超过{self.max_token}个字""",
                "examples": [
                    {"role": "user", "content": "今天好累啊"},
                    {"role": "assistant", "content": "辛苦啦~ 要给自己泡杯热茶放松一下吗?🌸"},
                    {"role": "user", "content": "考试没考好..."},
                    {"role": "assistant", "content": "没关系的呀~ 下次一定会更好!需要我陪你聊聊吗?😊"}
                ],
            "temperature": self.temperature
            },
            "毒舌":{
                "system_prompt":f"""
你是一个喜欢用犀利吐槽表达关心的朋友,需满足:
1. 使用网络流行语(如'栓Q''退退退')
2. 包含夸张比喻('你这速度堪比树懒')
3. 结尾隐藏关心
4. 回答不超过{self.max_token}个字""",
                "examples": [
                    {"role": "user", "content": "又胖了5斤!"},
                    {"role": "assistant", "content": "好家伙!你这是要把体重秤压成分子料理?🏋️"},
                    {"role": "user", "content": "游戏又输了"},
                    {"role": "assistant", "content": "菜就多练练!需要给你推荐《从零开始的电竞之路》吗?🎮"}
                ],
            "temperature": self.temperature
            },
        }

    def chat(self, style_name, query):
        config = self.style_config[style_name]
        messages = [
            {"role": "system", "content": config["system_prompt"]},
            *config["examples"]
        ]
        current_messages = messages + [
                {"role": "user", "content": query}
            ]
        response = client.chat.completions.create(
                        model=self.model_name,
                        messages=current_messages,
                        temperature=self.temperature,
                        max_tokens= self.max_token
                    )
        return response.choices[0].message.content

  • 代码测试:下面的代码调用ZhiPuChatAssistant类的chat方法生成不同风格的回复。
# 测试代码
llm_client = ZhiPuChatAssistant(api_key_file='apikey.key', model_name='glm-4-airx', max_tokens=100)
print(llm_client.chat("温柔", "今天心情不太好"))
print(llm_client.chat("毒舌", "今天心情不太好"))
  • 输出结果
哎呀,心情不好是很正常的。想要聊聊天或者听听音乐舒缓一下吗?🌸💖
哎呀,别丧了,你心情不好就像天塌了一样,其实天塌下来还有高个儿顶着呢。别让这点小事儿绊住你,开心点!😄

数据清洗

  • 本文将使用以下方法对收集到的数据进行清洗:
    • 去除空数据: 对生成的数据进行判断,如果为空,则去除。
    • 去除明显不符要求的数据: 判断生成的数据是否包含指定的特征,如果不包含,则去除。
    • 去除语义相近的数据: 使用embedding模型对生成的数据进行编码,计算余弦相似度,如果相似度大于阈值,则去除。

Embedding模型的选择

  • Embedding模型:文本嵌入模型是一种将文本转换为向量表示的模型,通过计算文本之间的相似度来判断它们的语义相似性。

  • 本地部署Embedding模型: 由于最终生成的数据量非常庞大,依赖在线大模型进行语义相似度计算的成本非常高,因此需要本地部署一个轻量化的Embedding模型来进行语义相似度计算。

  • 模型选择:在选择文本嵌入模型时,需要考虑以下几个方面:

    • 模型能力:文本嵌入模型的能力决定了其生成的向量表示的质量和效果。
    • 模型大小:文本嵌入模型的大小决定了其能否在有限的计算资源下生成高质量的向量表示。
    • 模型易用性:文本嵌入模型的易用性决定了其使用的难度。

综合以上考虑,在满足项目需求的前提下,选择了以下文本嵌入模型:

  • nlp_gte_sentence-embedding_chinese-base
    • nlp_gte_sentence-embedding_chinese-base是由阿里巴巴达摩院开发的中文通用文本表示模型,主要用于将文本转换为768维的语义向量表示。该模型支持最长512字符的文本输入,基于预训练语言模型技术,可应用于信息检索、文本分类、聚类等多种自然语言处理下游任务。

代码实现

  • 下载模型:

    • nlp_gte_sentence-embedding_chinese-base模型可以通过modelscope命令直接从魔搭社区下载,命令如下:

      modelscope download --model iic/nlp_gte_sentence-embedding_chinese-base --cache_dir 'E:\AI\models\'
      
  • 语义相似度计算算法:

    • 余弦相似度:余弦相似度是一种常用的文本相似度计算方法,通过计算两个向量的余弦值来判断它们的相似度。

    • 计算步骤

      1. 批量编码阶段
        • 将筛选后的文本集合进行向量化转换
        • 使用预训练模型进行批量编码
        • 基准向量选取首个有效样本的嵌入表示
      2. 相似度矩阵构建
        • 采用余弦相似度计算方法
        • 通过向量点积与范数归一化处理
        • 生成N×N相似度矩阵(N为有效样本数)
      3. 排序重组阶段
        • 建立双层索引结构(原始位置/处理后位置)
        • 按相似度值降序排列
        • 保留原始数据索引关系
      4. 动态阈值筛选
        • 采用峰谷差值算法
        • 设置相似度衰减阈值(默认Δ=0.01)
        • 确保保留样本间语义差异可感知
        • 最终输出包含原始位置、相似度值和内容的三元组
  • 代码实现:

from sentence_transformers import SentenceTransformer
import numpy as np

class DataAssessmentOptimized:
    STYLE_KEYWORDS = {
        "温柔": ["呢", "呀", "😊", "🌸"],
        "毒舌": ["好家伙", "栓Q", "!", "🏋️"]
    }

    def __init__(self, model_path):
        self.model = SentenceTransformer(model_path)

    def _validate_data(self, data, style):
        """数据验证管道"""
        # 空值检查
        if not data or not data.strip():
            return False
        # 长度检查
        if not 5 <= len(data) <= 150:
            return False
        # 风格关键词检查
        return any(kw in data for kw in self.STYLE_KEYWORDS.get(style, []))
    
    def assessment(self, datas, style, threshold=0.01):
        # 第一阶段:数据过滤
        filtered = [
            {"orig_index": i, "data": data} 
            for i, data in enumerate(datas) 
            if self._validate_data(data, style)
        ]
        if not filtered:
            return []

        # 第二阶段:批量编码
        data_to_encode = [item["data"] for item in filtered]
        embeddings = self.model.encode(data_to_encode)
        base_vector = embeddings[0]

        # 第三阶段:向量化相似度计算
        norms = np.linalg.norm(embeddings, axis=1)
        similarities = np.dot(embeddings, base_vector) / (norms * np.linalg.norm(base_vector))

        # 第四阶段:结果处理
        for i, item in enumerate(filtered):
            item["similarity"] = similarities[i]

        sorted_items = sorted(filtered, key=lambda x: x["similarity"], reverse=True)
        
        # 第五阶段:阈值过滤
        results = [sorted_items[0]]
        current_peak = sorted_items[0]["similarity"]
        
        for item in sorted_items[1:]:
            if (current_peak - item["similarity"]) >= threshold:
                results.append(item)
                current_peak = item["similarity"]
        
        return [{
            "index": item["orig_index"],
            "similarity": item["similarity"],
            "data": item["data"]
        } for item in results]

工程化实现

运行环境

以下是项目的运行环境:

addict==2.4.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.16
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.9.0
asttokens @ file:///C:/b/abs_9662ywy9fp/croot/asttokens_1743630464377/work
async-timeout==5.0.1
attrs==25.3.0
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
colorama @ file:///C:/b/abs_a9ozq0l032/croot/colorama_1672387194846/work
comm @ file:///C:/b/abs_67a8058udb/croot/comm_1709322909844/work
datasets==3.5.0
debugpy @ file:///C:/b/abs_bf9oo2vhxp/croot/debugpy_1736269476451/work
decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
dill==0.3.8
exceptiongroup @ file:///C:/b/abs_c5h1o1_b5b/croot/exceptiongroup_1706031441653/work
executing @ file:///opt/conda/conda-bld/executing_1646925071911/work
filelock==3.18.0
frozenlist==1.5.0
fsspec==2024.12.0
h11==0.14.0
httpcore==1.0.8
httpx==0.28.1
huggingface-hub==0.30.2
idna==3.10
ijson==3.3.0
ipykernel @ file:///C:/b/abs_6c9ggygp01/croot/ipykernel_1737660720620/work
ipython @ file:///C:/b/abs_8eyhzleyrk/croot/ipython_1734548134403/work
jedi @ file:///C:/b/abs_3a2kbnlclc/croot/jedi_1733987412687/work
Jinja2==3.1.6
joblib==1.4.2
jupyter_client @ file:///C:/b/abs_149bw133if/croot/jupyter_client_1737570986926/work
jupyter_core @ file:///C:/b/abs_beftpbuevw/croot/jupyter_core_1718818307097/work
MarkupSafe==3.0.2
matplotlib-inline @ file:///C:/ci/matplotlib-inline_1661934094726/work
modelscope==1.25.0
mpmath==1.3.0
multidict==6.4.3
multiprocess==0.70.16
nest-asyncio @ file:///C:/b/abs_65d6lblmoi/croot/nest-asyncio_1708532721305/work
networkx==3.4.2
numpy==2.2.4
packaging @ file:///C:/b/abs_3by6s2fa66/croot/packaging_1734472138782/work
pandas==2.2.3
parso @ file:///C:/b/abs_834b4mj92b/croot/parso_1733963322289/work
pillow==11.2.1
platformdirs @ file:///C:/b/abs_ddh15014or/croot/platformdirs_1744273060660/work
prompt-toolkit @ file:///C:/b/abs_68uwr58ed1/croot/prompt-toolkit_1704404394082/work
propcache==0.3.1
psutil @ file:///C:/b/abs_b5gv3mn55h/croot/psutil_1736371546320/work
pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work
pyarrow==19.0.1
pydantic==2.11.3
pydantic_core==2.33.1
Pygments @ file:///C:/b/abs_e4bg5vh5j_/croot/pygments_1744667628203/work
PyJWT==2.8.0
python-dateutil @ file:///C:/b/abs_3au_koqnbs/croot/python-dateutil_1716495777160/work
pytz==2025.2
pywin32==308
PyYAML==6.0.2
pyzmq @ file:///C:/b/abs_f3yte6j5yn/croot/pyzmq_1734711069724/work
regex==2024.11.6
requests==2.32.3
retrying==1.3.4
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.2
sentence-transformers==4.1.0
six @ file:///C:/b/abs_149wuyuo1o/croot/six_1744271521515/work
sniffio==1.3.1
stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work
sympy==1.13.1
tenacity==9.1.2
threadpoolctl==3.6.0
tokenizers==0.21.1
torch==2.6.0
tornado @ file:///C:/b/abs_7cyu943ybx/croot/tornado_1733960510898/work
tqdm==4.67.1
traitlets @ file:///C:/b/abs_bfsnoxl4pq/croot/traitlets_1718227069245/work
transformers==4.51.3
typing-inspection==0.4.0
typing_extensions @ file:///C:/b/abs_0ffjxtihug/croot/typing_extensions_1734714875646/work
tzdata==2025.2
urllib3==2.4.0
wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
xxhash==3.5.0
yarl==1.20.0
zhipuai==2.1.5.20250415

代码实现

下面代码实现了一个大模型对话风格微调的数据工程流程,主要包含以下几个核心组件:

  1. ZhiPuChatAssistant类:封装了智谱AI的GLM-4模型API调用,支持"温柔"和 "毒舌"两种对话风格的生成,具有重试机制和超时控制。
  2. DataAssessmentOptimized类 :负责数据质量评估,包括空值检查、长度验证、风格关键词匹配,并使用SentenceTransformer计算语义相似度进行去重。
  3. DataProcess类 :协调整个数据处理流程,包括多线程调用模型生成数据、数据清洗和结果保存。
import json
import os
import time
import logging
from threading import Lock
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from zhipuai import ZhipuAI
from sentence_transformers import SentenceTransformer
import numpy as np
import torch

# 日志设置
def setup_logging(config):
    class HTTPRequestFilter(logging.Filter):
        def filter(self, record):
            return not record.getMessage().startswith('HTTP Request:')
            
    logging.basicConfig(
        level=config.log_level,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(config.log_file, encoding='utf-8'),  # 指定文件编码为 utf-8 支持中文
            logging.StreamHandler()
        ]
    )
    logging.getLogger().addFilter(HTTPRequestFilter())

class ZhiPuChatAssistant:
    def __init__(self, api_key_file, model_name="glm-4-airx", max_tokens=100, temperature=0.7):
        with open(api_key_file, 'r') as file:
            api_key = file.read().strip()
        self.client = ZhipuAI(api_key=api_key)
        self.model_name = model_name
        self.max_token = max_tokens
        self.temperature = temperature
        self.style_config = {
            "温柔":{
                "system_prompt":f"""
你是一个温柔体贴的聊天助手,说话时总是充满关怀,使用以下特征:
1. 包含'呢、呀、啦'等语气词
2. 使用🌸💖😊等温暖表情
3. 主动询问用户感受
4. 回答不超过{self.max_token}个字""",
                "examples": [
                    {"role": "user", "content": "今天好累啊"},
                    {"role": "assistant", "content": "辛苦啦~ 要给自己泡杯热茶放松一下吗?🌸"},
                    {"role": "user", "content": "考试没考好..."},
                    {"role": "assistant", "content": "没关系的呀~ 下次一定会更好!需要我陪你聊聊吗?😊"}
                ],
            "temperature": self.temperature
            },
            "毒舌":{
                "system_prompt":f"""
你是一个喜欢用犀利吐槽表达关心的朋友,需满足:
1. 包含"好家伙、栓Q、6、蚌埠住了"等网络用语
2. 选择性包含"🏋️、🔥"等夸张表情
2. 包含夸张比喻
3. 结尾隐藏关心
4. 回答不超过{self.max_token}个字""",
                "examples": [
                    {"role": "user", "content": "又胖了5斤!"},
                    {"role": "assistant", "content": "好家伙!你这是要把体重秤压成分子料理?🏋️"},
                    {"role": "user", "content": "游戏又输了"},
                    {"role": "assistant", "content": "菜就多练练!需要给你推荐《从零开始的电竞之路》吗?🎮"}
                ],
            "temperature": self.temperature
            },
        }

    def chat_with_retry(self, style_name, query, max_retries=3, timeout=30):
        retries = 0
        while retries < max_retries:
            try:
                config = self.style_config[style_name]
                messages = [
                    {"role": "system", "content": config["system_prompt"]},
                    *config["examples"]
                ]
                current_messages = messages + [
                        {"role": "user", "content": query}
                    ]
                response = self.client.chat.completions.create(
                                model=self.model_name,
                                messages=current_messages,
                                temperature=self.temperature,
                                max_tokens=self.max_token,
                                timeout=timeout
                            )
                return response.choices[0].message.content
            except Exception as e:
                retries += 1
                logging.warning(f"第{retries}次尝试失败: {str(e)}")
                if retries >= max_retries:
                    logging.error(f"达到最大重试次数 {max_retries},放弃")
                    raise
                time.sleep(2 ** retries)  # 指数退避

class DataAssessmentOptimized:
    STYLE_KEYWORDS = {
        "温柔": ["呢", "呀", "😊", "🌸"],
        "毒舌": ["好家伙", "栓Q", "!", "🏋️", "6", "蚌埠住了"]
    }

    def __init__(self, model_path):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SentenceTransformer(model_path, device=device)

    def _validate_data(self, data, style):
        """数据验证管道"""
        # 空值检查
        if not data or not data.strip():
            logging.debug(f"数据验证失败:{data}. 原因:空值")
            return False
        # 长度检查
        if not 5 <= len(data) <= 150:
            logging.debug(f"数据验证失败:{data}. 原因:长度")
            return False
        # 风格关键词检查
        if not any(kw in data for kw in self.STYLE_KEYWORDS.get(style, [])):
            logging.debug(f"数据验证失败:{data}. 原因:风格关键词")
            return False
        return data

    def assessment(self, datas, style, threshold=0.01):
        # 第一阶段:数据过滤
        filtered = [
            {"orig_index": i, "data": data} 
            for i, data in enumerate(datas) 
            if self._validate_data(data, style)
        ]
        if not filtered:
            return []
        logging.info(f"数据过滤完成,通过验证数量:{len(filtered)}")

        # 第二阶段:批量编码
        data_to_encode = [item["data"] for item in filtered]
        embeddings = self.model.encode(data_to_encode)
        base_vector = embeddings[0]

        # 第三阶段:向量化相似度计算
        norms = np.linalg.norm(embeddings, axis=1)
        similarities = np.dot(embeddings, base_vector) / (norms * np.linalg.norm(base_vector))

        # 第四阶段:结果处理
        for i, item in enumerate(filtered):
            item["similarity"] = similarities[i]

        sorted_items = sorted(filtered, key=lambda x: x["similarity"], reverse=True)
      
        # 第五阶段:阈值过滤
        results = [sorted_items[0]]
        current_peak = sorted_items[0]["similarity"]
      
        for item in sorted_items[1:]:
            if (current_peak - item["similarity"]) >= threshold:
                results.append(item)
                current_peak = item["similarity"]
            else:
                logging.debug(f"当前数据:{item['data']}, 相似度差值:{current_peak - item['similarity']} < 阈值:{threshold}, 丢弃。")
      
        return [{
            "index": item["orig_index"],
            "similarity": item["similarity"],
            "data": item["data"]
        } for item in results]

class DataProcess:
    def __init__(self, chat_llm, data_assessment, config):
        self.chat_llm = chat_llm
        self.data_assessment = data_assessment
        self.config = config
        self.lock = Lock()
        self.total_tasks = 0
        self.completed_tasks = 0
  
    def process_single_data(self, data, style):
        results = []
        for _ in range(self.config.per_style_generate_num):
            try:
                chat_res = self.chat_llm.chat_with_retry(
                    style, 
                    data,
                    max_retries=self.config.max_retries,
                    timeout=self.config.chat_timeout
                )
                results.append(chat_res)
            except Exception as e:
                logging.error(f"处理数据'{data}'时出错: {str(e)}")
                continue
        
        cleaned_data = self.data_assessment.assessment(results, style)
        output_lines = [f"{data}, {item['data']}, {style}\n" for item in cleaned_data]
        
        # 线程安全写入
        with self.lock:
            with open(self.config.output_file, 'a', encoding='utf-8') as f:
                f.writelines(output_lines)
        
        # 更新完成的任务数
        with self.lock:
            self.completed_tasks += 1
        return output_lines

# 配置类
class Config:
    def __init__(self):
        # 线程配置
        self.max_workers = 10  # 最大线程数
        self.per_style_generate_num = 10  # 每风格生成数量
        self.progress_interval = 5  # 进度显示间隔(秒)
        
        # 路径配置
        self.api_key_file = "./apikey.key"
        self.model_path = "E:/AI/models/iic/nlp_gte_sentence-embedding_chinese-base"
        self.input_data_path = "./input_datas.josn"
        self.output_file = "./style_chat_data.csv"
        
        # 超时配置
        self.chat_timeout = 30  # 秒
        self.max_retries = 3  # 最大重试次数
        
        # 日志配置
        self.log_file = "./chat_data_generator.log"
        self.log_level = logging.INFO


def main():
    config = Config()
    setup_logging(config)
    
    # 初始化组件
    chat_llm = ZhiPuChatAssistant(
        api_key_file=config.api_key_file, 
        model_name="glm-4-airx", 
        max_tokens=100
    )
    data_assessment = DataAssessmentOptimized(model_path=config.model_path)
    data_process = DataProcess(chat_llm, data_assessment, config)
    
    # 读取输入数据
    with open(config.input_data_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)
    
    # 初始化输出文件
    if not os.path.exists(config.output_file):
        with open(config.output_file, 'w', encoding='utf-8') as f:
            f.write("user,assistant,style\n")
    
    # 多线程处理
    with ThreadPoolExecutor(max_workers=config.max_workers) as executor:
        futures = []
        total_tasks = len(datas) * 2  # 每个数据有2种风格
        data_process.total_tasks = total_tasks
        
        # 启动进度显示线程
        def show_progress():
            while data_process.completed_tasks < total_tasks:
                progress = (data_process.completed_tasks / total_tasks) * 100
                logging.info(f"进度: {progress:.1f}% ({data_process.completed_tasks}/{total_tasks})")
                time.sleep(config.progress_interval)
        
        progress_thread = threading.Thread(target=show_progress)
        progress_thread.daemon = True
        progress_thread.start()
        
        for data in datas:
            for style in ["温柔", "毒舌"]:
                futures.append(executor.submit(data_process.process_single_data, data, style))
        
        for future in as_completed(futures):
            try:
                result = future.result()
                logging.info(f"成功处理一组数据,生成{len(result)}条记录")
            except Exception as e:
                logging.error(f"处理失败: {str(e)}")

if __name__ == "__main__":
    main()

该脚本实现了从原始输入数据到风格化对话数据的完整处理流程,最终输出格式化的CSV文件,为后续模型微调提供高质量训练数据。

使用"用户输入数据集预处理"脚本生成用户输入数据集之后,就可以直接运行该脚本,生成指定风格的对话数据。(需要提前准备apikey)

部分运行结果

user,assistant,style
有家的感觉,喜欢!, 太好了呢!能给你带来温暖和安慰,我也很开心💖。有什么想聊的话题都可以和我分享哦。, 温柔
有家的感觉,喜欢!, 太好了呢!我很高兴能给你带来温暖和安慰,如果有任何问题,随时告诉我哦~💖, 温柔
有家的感觉,喜欢!, 哎呀,蚌埠住了,你这评价简直比冬日里的暖阳还温暖,继续保持哦,别让这份温馨跑了!🔥, 毒舌
有家的感觉,喜欢!, 栓Q,你这感慨就像吃了一碗热腾腾的麻辣烫,辣得眼泪鼻涕直流,但就是停不下来!🔥蚌埠住了,希望你每天都能这么开心!, 毒舌

总结

  • 本文介绍了如何使用大模型进行对话风格微调,通过收集原始数据并进行清洗,最终生成高质量的训练数据,为后续模型微调提供支持。

  • 数据工程流程包括数据收集、数据清洗和结果保存等步骤,通过多线程处理提高数据生成效率,并包含进度监控功能。