Datawhale AI夏令营:Task2:理解赛事重难点

106 阅读3分钟

基于AI的列车时刻表理解系统实现

项目背景

列车时刻表包含大量信息,如车次、检票口、站台、到发时间等。本项目旨在利用AI技术自动解析这些信息并生成问答对,帮助用户快速获取所需的列车信息。

技术栈

  • Python 3.9
  • Pandas:数据处理
  • Requests:API调用
  • Regular Expressions:文本提取
  • Qwen/Qwen3-8B:大语言模型

实现步骤

1. 数据读取与预处理

首先读取Excel格式的列车数据,并进行缺失值填充:

import pandas as pd

# 读取数据
data = pd.read_excel('data/info_table.xlsx')
data = data.fillna('无数据')

2. 大模型调用函数实现

实现调用Qwen3-8B模型的函数,用于生成问答对:

def call_llm(content: str):
    """
    调用大模型
    
    Args:
        content: 模型对话文本
    
    Returns:
        list: 问答对列表
    """
    # 调用大模型(硅基流动免费模型)
    url = "https://api.siliconflow.cn/v1/chat/completions"
    payload = {
        "model": "Qwen/Qwen3-8B",
        "messages": [
            {
                "role": "user",
                "content": content
            }
        ]
    }
    headers = {
        "Authorization": "Bearer ###",  # 需要替换为实际API密钥
        "Content-Type": "application/json"
    }
    resp = requests.request("POST", url, json=payload, headers=headers).json()
    
    # 使用正则提取大模型返回的json
    content = resp['choices'][0]['message']['content'].split('</think>')[-1]
    pattern = re.compile(r'^```json\s*([\s\S]*?)```$', re.IGNORECASE)
    match = pattern.match(content.strip())
    if match:
        json_str = match.group(1).strip()
        return json.loads(json_str)
    else:
        return content

3. 问题列表生成

根据列车数据生成相关问题,如检票口、站台、目的地等:

def create_question_list(row: dict):
    """
    根据一行数据创建问题列表
    
    Args:
        row: 一行数据的字典形式
    
    Returns:
        list: 问题列表
    """
    question_list = []
    # 检票口
    question_list.append(f'{row["车次"]}号车次应该从哪个检票口检票?')
    # 站台
    question_list.append(f'{row["车次"]}号车次应该从哪个站台上车?')
    # 目的地
    question_list.append(f'{row["车次"]}次列车的终到站是哪里?')
    
    return question_list

4. 生成训练数据

遍历数据,调用大模型生成答案,并将结果保存为JSON格式:

prompt = '''你是列车的乘务员,请你基于给定的列车班次信息回答用户的问题。
# 列车班次信息
{}

# 用户问题列表
{}

'''
output_format = '''# 输出格式
按json格式输出,且只需要输出一个json即可
```json
[{
    "q": "用户问题",
    "a": "问题答案"
},
...
]
```'''

train_data_list = []
error_data_list = []
# 遍历数据
i = 1
for idx, row in tqdm(data.iterrows(), desc='遍历生成答案', total=len(data)):
    try:
        # 组装数据
        row = dict(row)
        row['到点'] = str(row['到点'])
        row['开点'] = str(row['开点'])
        # 创建问题对
        question_list = create_question_list(row)
        # 大模型生成答案
        llm_result = call_llm(prompt.format(row, question_list) + output_format)
        # 总结结果
        train_data_list += llm_result
    except:
        error_data_list.append(row)
        continue

# 转换训练集
data_list = []
for data in tqdm(train_data_list, total=len(train_data_list)):
    if isinstance(data, str):
        continue
    data_list.append({'instruction': data['q'], 'output': data['a']})

json.dump(data_list, open('single_row.json', 'w', encoding='utf-8'), ensure_ascii=False)

项目结构

├── data/
│   ├── info_table.csv
│   └── info_table.xlsx
├── train_data/
│   └── single_row.json
├── baseline.ipynb
└── baseline.py

注意事项

  1. 需替换代码中的API密钥为实际申请的密钥
  2. 目前仅实现了基础功能,可通过增加问题类型和优化Prompt来提升模型回答质量
  3. Datawhale活动链接(AI夏令营活动报名 - Datawhale)

总结

本项目实现了一个基于大语言模型的列车时刻表理解系统,能够自动生成问答对,为后续模型训练提供数据支持。通过优化问题生成和模型调用方式,可以进一步提升系统的实用性和准确性。