StructBERT FAQ问答-中文-通用领域-base大模型训练笔记

41 阅读3分钟

一、训练目的

FAQ问答是智能对话系统(特别是垂直领域对话系统)的核心业务场景,业务专家基于经验或数据挖掘的结果,将用户会频繁问到的业务知识以Q&A的形式维护起来,称之为知识库,当用户使用对话系统时,提问一个业务方面的问题,机器自动从知识库中找到最合适的回答。

本次训练目的使用【StructBERT FAQ问答】模型来理解用户输入意图,将用户的问题转发到不同领域的智能体,由相应的智能体进行问题响应处理。

二、训练环境

本次训练从魔搭社区开始,找到【StructBERT FAQ问答-中文-通用领域-base】模型,使用魔搭社区和阿里云合作提供的免费训练环境进行相应的训练,使用从网上下载的保险业务测试数据,编写脚本进行训练测试。

三、训练步骤

3.1 启动训练环境

1. 在魔搭社区的模型库中搜索找到【StructBERT FAQ问答-中文-通用领域-base】模型,进入该模型主页,可以了解该模型的简单介绍和应用场景;

2. 在模型主页的右上角点击【Notebook快速开发】,进入魔搭平台免费提供的模型运行实例;

image.png

3. 先选择【阿里云弹性加速计算EAIS】,再选择GPU环境并启动,即可启动训练环境;

image.png

4. 点击【查看Notebook】进入Jupyter开发界面,接下来可以编写开发训练脚本了。

image.png

3.2 基于保险业务数据进行训练调整

从网上下载的保险业务数据文件总共包含8000+条数据,格式如下图所示:

image.png

如果需要数据文件,可以私我,谢谢!

下面开始训练脚本编写,如下:

(1)加载数据

from dataclasses import dataclass 
from typing import List 
import pandas as pd 
import json 
@dataclass 
class FAQ: 
    title: str 
    sim_questions: List[str] 
    answer: str 
    faq_id: int 
# 数据量太大,下一步骤向量化操作需要很长时间,所以这里只取前200条。 
ori_data = pd.read_csv('baoxianzhidao_filter.csv', nrows=200) 
data = [] 
exist_titles = set() 
for index, row in enumerate(ori_data.iterrows()): 
    row_dict = row[1] 
    title = row_dict['title'] 
    if title not in exist_titles: 
        data.append(FAQ(title=title, answer=row_dict['reply'], sim_questions=[title], faq_id=index)) 
    exist_titles.add(title)

(2)向量化

from modelscope.pipelines import pipeline 
from modelscope.utils.constant import Tasks 
pipeline_ins = pipeline(Tasks.faq_question_answering, 'damo/nlp_structbert_faq-question-answering_chinese-base') 
bsz = 32 
all_sentence_vecs = [] batch = [] sentence_list = [faq.title for faq in data] 
for i,sent in enumerate(sentence_list): 
    batch.append(sent) 
    if len(batch) == bsz or (i == len(sentence_list)-1 and len(batch)>0): 
        sentence_vecs = pipeline_ins.get_sentence_embedding(batch)    
        all_sentence_vecs.extend(sentence_vecs) 
        batch.clear()

(3)构建向量索引

# 提前安装好faiss依赖,pip install faiss-gpu 
import faiss 
import numpy as np 
#说明:v1.3版本之后,请使用 hidden_size = pipeline_ins.model.network.bert.config.hidden_size 
hidden_size = pipeline_ins.model.network.bert.config.hidden_size 
index = faiss.IndexFlatIP(hidden_size) 
vecs = np.asarray(all_sentence_vecs, dtype='float32') 
index.add(vecs)

(4)FAQ函数封装

image.png

四、训练结果

调用上述封装好的函数,就可以获得问题回复,评估训练结果。

执行代码:ask_faq(["安邦长青树怎么样",""], [])

image.png

返回结果包含3条最优答案,按照得分从高到低输出,可以根据业务需求获取最优前N条答案。