1.整体思路:
整体思路首先把原始数据洗干净,把用户输入的的时间格式统一,把空的地方填上"无数据",这样AI就不会被脏数据污染,给出奇怪的解释。然后解决API调用的问题。原来的代码就像一个人排队买咖啡,现在改成8个人同时去买,谁先买到谁先走。还加了缓存机制,就像记住咖啡店的位置,下次直接去,不用重新找路。如果网络卡了,会自动重试几次,不会因为一次失败就放弃。接着优化AI的"工作指南"(其实也就是promptprompt),明确告诉它数据里有什么、没什么,让它生成的问题更合理,不会问一些数据里根本没有的信息。最后加了好几道质检关卡,确保AI生成的内容靠谱,格式统一,不会给出不存在的信息
整个方案的核心就是:多线程提速 + 缓存省时间 + 重试保稳定 + 验证保质量,让整个流程又快又稳又准。
2.关键节点的优化思路:
以下我们会对一些关键的功能实现代码段进行解析,我们观察以下我们给的表格字段主要问题在于部分值缺失,这一部分缺失的值有可能导致逻辑错误,故此我们编写代码来规范化数据,同时移除无用的序号列
2.1. 数据预处理优化
策略:清洗脏数据,统一格式,确保数据一致性
def standardize_time(time_str):
"""时间标准化:统一处理不同格式的时间"""
time_str = str(time_str)
if pd.isna(time_str) or '无数据' in time_str:
return '无数据'
match = re.search(r'(\d{1,2}:\d{2})', time_str)
if match:
return match.group(1)
return time_str
def preprocess_dataframe(df):
"""数据清洗和一致性校验"""
df = df.fillna('无数据') # 统一缺失值
for col in ['到点', '开点']:
if col in df.columns:
df[col] = df[col].apply(standardize_time) # 时间标准化
# 逻辑一致性校验
for idx, row in df.iterrows():
try:
if row['到点'] != '无数据' and row['开点'] != '无数据':
start = pd.to_datetime(row['到点'], format='%H:%M')
end = pd.to_datetime(row['开点'], format='%H:%M')
if (end - start).total_seconds() < 0: # 检查时间逻辑
print(f"警告: 第{idx}行时间不合理")
except (ValueError, TypeError):
print(f"警告: 第{idx}行时间格式无法解析")
if '序号' in df.columns:
df = df.drop(columns=['序号']) # 移除无用列
return df
2.2 API缓存机制
在之前的baseline中,训练时间较长,故此我们在这里采用了API缓存机制来让训练速度更快一些,这里用了MD5来保证键的唯一性
class APICache:
def __init__(self, cache_dir='api_cache'):
self.cache_dir = cache_dir
os.makedirs(self.cache_dir, exist_ok=True)
def get_cache_key(self, text):
"""使用MD5生成唯一缓存键"""
return hashlib.md5(text.encode('utf-8')).hexdigest()
def get(self, text):
"""从缓存获取结果"""
cache_file = os.path.join(self.cache_dir, self.get_cache_key(text))
if os.path.exists(cache_file):
with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f)
return None
def set(self, text, result):
"""保存结果到缓存"""
cache_file = os.path.join(self.cache_dir, self.get_cache_key(text))
with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False)
2.3 优化重试机制
在之前的baseline中访问API的时候往往会出现很多错误,其中主要可以分为两类
- 永久性错误(400、401)
- 临时性错误(429、503)
前者应该立即抛出Bug,后者可以尝试再次调用API
同时此处也采用指数退避策略,避免对服务器造成压力,外加上随机抖动防止多线程同时重试
def _call_llm_with_retry(self, data_text, prompt):
"""智能重试:区分错误类型,指数退避"""
combined_input = prompt + data_text
cached_result = self.cache.get(combined_input)
if cached_result:
return cached_result # 缓存命中,直接返回
for attempt in range(self.max_retries):
try:
result = self._call_llm(data_text, prompt)
if result:
self.cache.set(combined_input, result)
return result
except requests.exceptions.HTTPError as e:
# 区分永久性错误和临时性错误
if 400 <= e.response.status_code < 500 and e.response.status_code != 429:
print(f"客户端错误 (状态码 {e.response.status_code}),停止重试。")
raise e
if attempt == self.max_retries - 1:
raise e
except Exception as e:
if attempt == self.max_retries - 1:
raise e
# 指数退避 + 随机抖动
wait_time = (2 ** attempt) + random.uniform(0, 1)
print(f"API调用失败,将在 {wait_time:.1f} 秒后重试...")
time.sleep(wait_time)
return None
2.4 多线程:
这个属于比较经典的操作了,多线程调用API,多线程收集LLM的回答数据
def generate(self, df, prompt):
"""多线程并行处理,最大化CPU利用率"""
all_results, error_logs = [], []
tasks = list(df.iterrows()) # 准备所有任务
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
# 一次性提交所有任务
future_to_idx = {executor.submit(self._process_row, task, prompt): task[0] for task in tasks}
# 异步收集结果:谁先完成先处理谁
for future in tqdm(as_completed(future_to_idx), total=len(tasks), desc="生成QA数据"):
try:
result = future.result()
if result:
all_results.extend(result)
except Exception as e:
error_logs.append(str(e))
return all_results, error_logs
2.4 数据质量验证
AI生成问答对的过程中,我们经常遇到格式不一致(AI可能返回不同的key名称如q/question/问题)、内容质量差(生成的问题或答案可能过短、格式错误)、输出不规范(包含客套话、时间格式不统一、标点符号混乱)以及事实错误(AI可能生成与原始数据不符的内容)等问题,因此需要通过数据质量验证来确保生成的数据的质量
本次的数据质量验证的主要流程如下图所示
第一层:格式验证
兼容性处理
@staticmethod
def _validate_and_enrich_qa(qa_list, original_row_data):
"""多层验证:确保生成数据质量"""
valid_qa = []
for qa in qa_list:
# 兼容多种key格式
question = qa.get('instruction') or qa.get('q') or qa.get('question') or qa.get('问题') or qa.get('问', '')
answer = qa.get('output') or qa.get('a') or qa.get('answer') or qa.get('答案') or qa.get('答', '')
question = question.strip()
answer = answer.strip()
设计思路:
- 使用
or链式获取,确保兼容各种API返回格式 - 处理不同AI模型可能使用的不同key名称
- 去除首尾空白字符
基础格式验证
# 基础验证
if not question or not answer:
continue
if len(question) < 5 or not question.endswith('?'):
continue
if len(answer) < 2:
continue
# 通过验证的QA对
valid_qa.append({'instruction': question, 'output': answer})
验证规则:
- 确保问题和答案都不为空
- 问题长度至少5个字符且以问号结尾
- 答案长度至少2个字符
第二层:后处理标准化
去除客套话
@staticmethod
def _post_process_qa(qa_pair):
"""后处理:标准化格式"""
question = qa_pair.get('instruction', '').strip()
answer = qa_pair.get('output', '').strip()
# 去除客套话
answer = re.sub(r'^(好的,|根据我们提供的信息,|查询结果如下:|当然,)', '', answer).strip()
解决的问题:
- AI模型经常在答案开头添加客套话
- 这些客套话对实际内容没有价值
- 通过正则表达式精确去除
时间格式标准化
# 统一时间格式
answer = re.sub(r'(\d{1,2})\s*点\s*(\d{2})分?', r'\1:\2', answer)
转换示例:
8点30分→8:3009点15分→09:1510点→10:
车次格式标准化
# 统一车次格式
answer = re.sub(r'([A-Z]?\d+/?\d*)\s*(号?|次)列车', r'\1次', answer)
转换示例:
K420号列车→K420次Z152次列车→Z152次T123号→T123次
标点符号统一
# 统一标点
answer = answer.replace(':', ':').replace(',', ', ')
转换示例:
始发站:北京西→始发站:北京西到点,开点→到点, 开点
3 结果分数比较
可以看到返回的分数明显高了