Datawhale | AI+X AI 夏令营2024 Task3 笔记(3)

139 阅读18分钟

#AI夏令营 #Datawhale #夏令营

今天开始做 baseline 2,早排队早拿到分。baseline 2 主要是使用官方零代码微调方式上分,之后这种方式可能不会经常用,prompt 的潜力还没发挥到极致,用这个有点浪费算力。

大致流程如下:

image.png

1. 数据集制作

对应流程图里的数据处理这一步

链接:aistudio.baidu.com/projectdeta…

还是先 fork 一个自己的版本,然后开始操作

1.1 环境配置

先对原始群聊数据做初步抽取,准备一下讯飞 3.5 的 API 环境配置。和 baseline 1 的配置一样。

!pip uninstall websocket-client
WARNING: Skipping websocket-client as it is not installed.
!pip install --upgrade spark_ai_python websocket-client
Looking in indexes: https://mirror.baidu.com/pypi/simple/, https://mirrors.aliyun.com/pypi/simple/
WARNING: Skipping page https://mirror.baidu.com/pypi/simple/spark-ai-python/ because the GET request got Content-Type: application/octet-stream. The only supported Content-Types are application/vnd.pypi.simple.v1+json, application/vnd.pypi.simple.v1+html, and text/html
Collecting spark_ai_python
  Downloading https://mirrors.aliyun.com/pypi/packages/c6/cc/0f3a96f46d763a305e2f94ecc4f5500d47aa80966b84347f7b49b52d6c83/spark_ai_python-0.3.31-py3-none-any.whl (344 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 344.8/344.8 kB 648.2 kB/s eta 0:00:00a 0:00:01
Collecting websocket-client
  Downloading https://mirrors.aliyun.com/pypi/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl (58 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.8/58.8 kB 632.3 kB/s eta 0:00:00a 0:00:01
Requirement already satisfied: aiohttp>3.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from spark_ai_python) (3.9.5)
Requirement already satisfied: httpx in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from spark_ai_python) (0.27.0)
Collecting jsonpatch (from spark_ai_python)
  Downloading https://mirrors.aliyun.com/pypi/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl (12 kB)
Requirement already satisfied: nest-asyncio<2.0.0,>=1.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from spark_ai_python) (1.6.0)
Requirement already satisfied: packaging in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from spark_ai_python) (24.0)
Requirement already satisfied: pydantic in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from spark_ai_python) (2.7.0)
Collecting python-dotenv (from spark_ai_python)
  Downloading https://mirrors.aliyun.com/pypi/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl (19 kB)
Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from spark_ai_python) (6.0.1)
Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from spark_ai_python) (2.31.0)
WARNING: Skipping page https://mirror.baidu.com/pypi/simple/tenacity/ because the GET request got Content-Type: application/octet-stream. The only supported Content-Types are application/vnd.pypi.simple.v1+json, application/vnd.pypi.simple.v1+html, and text/html
Collecting tenacity (from spark_ai_python)
  Downloading https://mirrors.aliyun.com/pypi/packages/e3/ee/b179c3ab5cb842d75c65339c4b86b572eaf8f43407890bd1d2c7b72eb829/tenacity-8.4.2-py3-none-any.whl (28 kB)
Requirement already satisfied: websockets in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from spark_ai_python) (11.0.3)
Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from aiohttp>3.3->spark_ai_python) (1.3.1)
Requirement already satisfied: attrs>=17.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from aiohttp>3.3->spark_ai_python) (23.2.0)
Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from aiohttp>3.3->spark_ai_python) (1.4.1)
Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from aiohttp>3.3->spark_ai_python) (6.0.5)
Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from aiohttp>3.3->spark_ai_python) (1.9.4)
Requirement already satisfied: async-timeout<5.0,>=4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from aiohttp>3.3->spark_ai_python) (4.0.3)
Requirement already satisfied: anyio in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from httpx->spark_ai_python) (4.3.0)
Requirement already satisfied: certifi in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from httpx->spark_ai_python) (2024.2.2)
Requirement already satisfied: httpcore==1.* in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from httpx->spark_ai_python) (1.0.5)
Requirement already satisfied: idna in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from httpx->spark_ai_python) (3.7)
Requirement already satisfied: sniffio in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from httpx->spark_ai_python) (1.3.1)
Requirement already satisfied: h11<0.15,>=0.13 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from httpcore==1.*->httpx->spark_ai_python) (0.14.0)
Collecting jsonpointer>=1.9 (from jsonpatch->spark_ai_python)
  Downloading https://mirrors.aliyun.com/pypi/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl (7.6 kB)
Requirement already satisfied: annotated-types>=0.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from pydantic->spark_ai_python) (0.6.0)
Requirement already satisfied: pydantic-core==2.18.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from pydantic->spark_ai_python) (2.18.1)
Requirement already satisfied: typing-extensions>=4.6.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from pydantic->spark_ai_python) (4.11.0)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from requests->spark_ai_python) (3.3.2)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from requests->spark_ai_python) (2.2.1)
Requirement already satisfied: exceptiongroup>=1.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages (from anyio->httpx->spark_ai_python) (1.2.1)
Installing collected packages: websocket-client, tenacity, python-dotenv, jsonpointer, jsonpatch, spark_ai_python
Successfully installed jsonpatch-1.33 jsonpointer-3.0.0 python-dotenv-1.0.1 spark_ai_python-0.3.31 tenacity-8.4.2 websocket-client-1.8.0
WARNING: Skipping page https://mirror.baidu.com/pypi/simple/pip/ because the GET request got Content-Type: application/octet-stream. The only supported Content-Types are application/vnd.pypi.simple.v1+json, application/vnd.pypi.simple.v1+html, and text/html

环境准备好,把 API 调用抽出来写成一个函数

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import numpy as np
from tqdm import tqdm


def chatbot(prompt):
    #星火认知大模型Spark3.5 Max的URL值,其他版本大模型URL值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
    SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
    #星火认知大模型调用秘钥信息,请前往讯飞开放平台控制台(https://console.xfyun.cn/services/bm35)查看
    SPARKAI_APP_ID = ''
    SPARKAI_API_SECRET = ''
    SPARKAI_API_KEY = ''
    #星火认知大模型Spark3.5 Max的domain值,其他版本大模型domain值请前往文档(https://www.xfyun.cn/doc/spark/Web.html)查看
    SPARKAI_DOMAIN = 'generalv3.5'
    spark = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    messages = [ChatMessage(
        role="user",
        content=prompt
    )]
    handler = ChunkPrintHandler()
    a = spark.generate([messages], callbacks=[handler])
    return a.generations[0][0].message.content

1.2 数据处理 Prompt

这里我们对原群聊对话设计了一个总结 Prompt,目的是将原始对话内容进行精简,方便做微调数据。

一方面直接将群聊对话作为数据集的话,会导致上下文过长,超过限制,还有上下文太长会导致抽取效果变差。

过长的上下文也会导致训练时长和费用倍增。(比如我做了一个数据集要花 3000 多块钱跑完。就算能跑可能也要 1 - 2 天...)

好了我们来说说 prompt。这个 prompt 相较于 baseline 1 区别比较明显,对需要抽取的任务做了一次总结。总结了四个方面:

字段分组.png

通过总结后的数据一方面节约了微调的运算资源,一方面也让数据被清洗后更容易被模型理解,达到更好的抽取效果。

content = ''
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

****群聊对话****
{content}

****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''

1.3 训练数据集制作

jsonl_data 是用来训练的规范单行数据,需要由训练数据组成一个 jsonl 文件(每行是一个 json 数据的文件),格式如下:

jsonl_data = {"instruction":"假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。","input":"请调小空气净化器的湿度到1","output":"{\"intent\":\"CONTROL\",\"slots\":[{\"name\":\"device\",\"normValue\":\"airCleaner\",\"value\":\"空气净化器\"},{\"name\":\"insType\",\"normValue\":\"set\",\"value\":\"调小\"},{\"name\":\"attr\",\"normValue\":\"humidity\",\"value\":\"湿度\"},{\"name\":\"attrValue\",\"normValue\":\"1\",\"value\":\"1\"}],\"sample\":\"请调小空气净化器的湿度到1\"}"}
print(jsonl_data)
{'instruction': '假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。', 'input': '请调小空气净化器的湿度到1', 'output': '{"intent":"CONTROL","slots":[{"name":"device","normValue":"airCleaner","value":"空气净化器"},{"name":"insType","normValue":"set","value":"调小"},{"name":"attr","normValue":"humidity","value":"湿度"},{"name":"attrValue","normValue":"1","value":"1"}],"sample":"请调小空气净化器的湿度到1"}'}
print(jsonl_data["instruction"])
print(jsonl_data["input"])
print(jsonl_data["output"])
假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。
请调小空气净化器的湿度到1
{"intent":"CONTROL","slots":[{"name":"device","normValue":"airCleaner","value":"空气净化器"},{"name":"insType","normValue":"set","value":"调小"},{"name":"attr","normValue":"humidity","value":"湿度"},{"name":"attrValue","normValue":"1","value":"1"}],"sample":"请调小空气净化器的湿度到1"}

需要训练的数据文件在官网下载后是 train.json 这里直接导入到根目录,无需重复下载

import json

# 打开并读取JSON文件
with open('train.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

这里我们通过星火 3.5 API 清洗原来的数据,总结后按照刚才看到得单行 jsonl 存储格式将数据存入 traindata.jsonl 中。大家可以经过处理后自行查阅 traindata.jsonl 文件,看看都有啥。

这里的训练时长大概 40 min 左右,请耐心等待。这段等待的时间可以看看后面的内容。

# 训练集制作

# 打开一个文件用于写入,如果文件已存在则会被覆盖
with open('traindata.jsonl', 'w', encoding='utf-8') as file:
    # 训练集行数(130)不符合要求,范围:1500~90000000
    # 遍历数据列表,并将每一行写入文件
    # 这里为了满足微调需求我们重复12次数据集 130*12=1560
    
    for line_data in tqdm(data):
        line_input = line_data["chat_text"] 
        line_output = line_data["infos"]
        content = line_input
        
        prompt = f'''
                你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

                ****群聊对话****
                {content}

                ****分析数据****
                客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
                客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
                客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
                跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

                ****注意****
                1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
                2.不要输出分析内容
                3.输出内容格式为md格式
                '''
        res = chatbot(prompt=prompt)
        # print(res)
        line_write = {
            "instruction":jsonl_data["instruction"],
            "input":json.dumps(res, ensure_ascii=False),
            "output":json.dumps(line_output, ensure_ascii=False)
        }
        # 因为数据共有130行,为了能满足训练需要的1500条及以上,我们将正常训练数据扩充12倍。
        for time in range(12):
            file.write(json.dumps(line_write, ensure_ascii=False) + '\n')  # '\n' 用于在每行末尾添加换行符
 68%|██████▊   | 88/129 [23:50<12:02, 17.61s/it]2024-07-05 08:07:09 CST - SparkPythonSDK - ERROR - [/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/sparkai/llm/llm.py:687] - SparkLLMClient wait LLM api response timeout 30 seconds
100%|██████████| 129/129 [36:01<00:00, 16.76s/it]

1.4 测试集数据制作

测试数据和训练数据相似,都是通过 API 清洗后存储。

# 验证集制作(提交版本)
# input,target

import json

# 打开并读取JSON文件
with open('test_data.json', 'r', encoding='utf-8') as file:
    data_test = json.load(file)

这里的验证数据我们以 csv 文件存储,有 input 和 target 两列,由于我们没有这些数据的真实标签,我这里将 target 列设置为 '-'。

测试集 text.csv 文件大概需要 20 min 能得到,也请大家耐心等待~

import csv

# 打开一个文件用于写入CSV数据
with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:
    # 创建一个csv writer对象
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["input","target"])
    # 遍历数据列表,并将每一行写入CSV文件
    for line_data in tqdm(data_test):
        content = line_data["chat_text"]
        prompt = f'''
                你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

                ****群聊对话****
                {content}

                ****分析数据****
                客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
                客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
                客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
                跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

                ****注意****
                1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
                2.不要输出分析内容
                3.输出内容格式为md格式
                '''
        res = chatbot(prompt=prompt)
        
        # print(line_data["chat_text"])
        ## 文件内容校验失败: test.jsonl(不含表头起算)第1行的内容不符合规则,限制每组input和target字符数量总和上限为8000,当前行字符数量:10721
        line_list = [res, "-"]   
        csvwriter.writerow(line_list)
        # break
100%|██████████| 55/55 [16:09<00:00, 17.63s/it]

看一下里面是什么样的

print(data_test)

太长了自己去看吧

2. 模型零代码微调

2.1 训练数据上传

链接:training.xfyun.cn/overview

image.png

image.png

image.png

  • 点击确定,等待上传成功

对于 jsonl 数据,运行成功表示训练数据集上传结束

image.png

2.2 测试数据上传

回到:training.xfyun.cn/dataset/dat…

这次我们选择测试集即可。

image.png

上传我们的test.csv文件即可。

image.png

CSV 文件上传成功之后就没有上面说的运行成功,其实运行成功的意思是将 jsonl 数据格式转化成统一格式,转换任务成功而已。CSV 没有转换过程,上传完就完事了。

2.3 平台微调

进入创建微调页面:training.xfyun.cn/model/add

基本配置与版本配置如下,我们选择性价比较好的 Spark Lite 模型~

image.png

版本信息这里不用管,可选项都是唯一选项。然后选择数据集

image.png

训练集和测试集都选择刚才上传好的,如果前面选择了 Spark Pro 的话这里没有测试集的选项,没关系,没有就不上传了。

image.png

参数配置里默认的训练次数是 10,省点钱,改成 1。

image.png

计费规则这里选择刚进页面送的代金券,记得勾选上自动付费。我这里预计费用 0 元是因为我填的信息不完整,我记得如果是 Spark Lite 的话训练一次是 19 元左右,Spark Pro 的话训练一次是 77 元左右。

训练时长的话 Lite 和 Pro 都是预计半小时以上。

这次的数据集生成的质量不算高,而且用重复内容扩充,用 Pro 真的浪费了。

image.png

最后也是这样的状态,排队要排好久,慢慢等吧。排完队之后训练的时间才开始算。

3. 推理

训练完成后,进入模型页面。

training.xfyun.cn/model

image.png

左侧点击模型管理,右侧找到训练好的模型,点击详情。

进去之后,点击右侧发布为服务

image.png

如果已经发布过,这里的按钮会变成这样

image.png

点击发布为服务之后,会有一个窗口让你绑定应用,就绑定到这次比赛对应的应用里就可以。不会和之前没有微调过的模型混淆。等一会讲为什么不会混淆。

发布之后要看详细信息在页面左侧选我的模型服务

image.png

接着拿到 resourceId(模型服务列表这边对应卡片里)、APPID、APIKey、APISecret

回到 aistudio,aistudio.baidu.com/projectdeta…

从第三部分开始

# 定义写入函数

def write_json(json_file_path, data):
    #"""写入json文件"""
    with open(json_file_path, 'w') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

在 main.ipynb 的微调推理部分填入 APPID、APIKey、APISecret(注意顺序)

import SparkApi
import json
#以下密钥信息从控制台获取
appid = ""     #填写控制台中获取的 APPID 信息
api_secret = ""   #填写控制台中获取的 APISecret 信息
api_key = ""    #填写控制台中获取的 APIKey 信息

#调用微调大模型时,设置为“patch”
domain = "patchv3"

#云端环境的服务地址
Spark_url = "wss://spark-api-n.xf-yun.com/v1.1/chat"  # 微调v1.5环境的地址
# Spark_url = "wss://spark-api-n.xf-yun.com/v3.1/chat"  # 微调v3.0环境的地址


text =[]

# length = 0

def getText(role,content):
    jsoncon = {}
    jsoncon["role"] = role
    jsoncon["content"] = content
    text.append(jsoncon)
    return text

def getlength(text):
    length = 0
    for content in text:
        temp = content["content"]
        leng = len(temp)
        length += leng
    return length

def checklen(text):
    while (getlength(text) > 8000):
        del text[0]
    return text
    


def core_run(text,prompt):
    # print('prompt',prompt)
    text.clear
    Input = prompt
    question = checklen(getText("user",Input))
    SparkApi.answer =""
    # print("星火:",end = "")
    SparkApi.main(appid,api_key,api_secret,Spark_url,domain,question)
    getText("assistant",SparkApi.answer)
    # print(text)
    return text[-1]['content']

text = []
res = core_run(text,'你好吗?')

这里面出现了一个新的模块 SparkApi,仔细看环境里的文件,有一个 SparkApi.py,就是这个文件

image.png

在SparkApi.py文件的108行,引号中填入你的resourceId

import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import websocket  # 使用websocket_client
answer = ""

class Ws_Param(object):
    # 初始化
    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url

    # 生成url
    def create_url(self):
        # 生成RFC1123格式的时间戳
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # 拼接字符串
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # 进行hmac-sha256进行加密
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # 将请求的鉴权参数组合为字典
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        # 拼接鉴权参数,生成url
        url = self.Spark_url + '?' + urlencode(v)
        # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
        return url


# 收到websocket错误的处理
def on_error(ws, error):
    print("### error:", error)


# 收到websocket关闭的处理
def on_close(ws,one,two):
    print(" ")


# 收到websocket连接建立的处理
def on_open(ws):
    thread.start_new_thread(run, (ws,))


def run(ws, *args):
    data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
    ws.send(data)


# 收到websocket消息的处理
def on_message(ws, message):
    # print(message)
    data = json.loads(message)
    code = data['header']['code']
    if code != 0:
        print(f'请求错误: {code}, {data}')
        ws.close()
    else:
        choices = data["payload"]["choices"]
        status = choices["status"]
        content = choices["text"][0]["content"]
        # print(content,end ="")
        global answer
        answer += content
        # print(1)
        if status == 2:
            ws.close()


def gen_params(appid, domain,question):
    """
    通过appid和用户的提问来生成请参数
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234",
            "patch_id": [""] #调用微调大模型时必传, 否则不传。对应resourceId
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": 0.1,
                "max_tokens": 4096
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data


def main(appid, api_key, api_secret, Spark_url,domain, question):
    # print("星火:")
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()
    ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

上面的信息都填好之后继续往下运行 main.ipynb

import pandas as pd
import re

# 读取Excel文件
df_test = pd.read_csv('test.csv',)

data_dict_empty = {
                "基本信息-姓名": "",
                "基本信息-手机号码": "",
                "基本信息-邮箱": "",
                "基本信息-地区": "",
                "基本信息-详细地址": "",
                "基本信息-性别": "",
                "基本信息-年龄": "",
                "基本信息-生日": "",
                "咨询类型": [],
                "意向产品": [],
                "购买异议点": [],
                "客户预算-预算是否充足": "",
                "客户预算-总体预算金额": "",
                "客户预算-预算明细": "",
                "竞品信息": "",
                "客户是否有意向": "",
                "客户是否有卡点": "",
                "客户购买阶段": "",
                "下一步跟进计划-参与人": [],
                "下一步跟进计划-时间点": "",
                "下一步跟进计划-具体事项": ""
            }
submit_data = []
for id,line_data in tqdm(enumerate(df_test['input'])):
    # print(line_data)
    content = line_data
    text = []
    prompt = json.dumps(content,ensure_ascii=False)
    
    # print(json.dumps(content,ensure_ascii=False))
    res = core_run(text,prompt)
    try:
        data_dict = json.loads(res)
    except json.JSONDecodeError as e:
        data_dict = data_dict_empty
    submit_data.append({"infos":data_dict,"index":id+1})
# 预计执行8min

这个地方 fork 过来的代码有一点小瑕疵,最后一句里面应该改成这样:submit_data.append({"infos":[data_dict],"index":id+1}),相信聪明的你一定看出来这是为什么了,比赛要求的格式就是这样的

[
    {
        "infos": [],
        "index": 1
    },
    {
        "infos": [],
        "index": 2
    },
    {
        "infos": [],
        "index": 3
    },
    
    ...
    
    {
        "infos": [],
        "index": 55
    },
]

如果不改的话,最后提交的 json 格式不对,评分就卡在那里没有结果了。

和群内助教反馈了一下,可能只有我这边有这个问题?他不需要修改这里也能正常评分得到结果。

image.png

反正如果也有人遇到评分半天没结果的情况,可以参考我这么做。

4. 提交文件

快去提交结果吧

challenge.xfyun.cn/h5/detail?t…

大概是这样的分数,这只是一个很简单的版本。

image.png