DataWhale AI夏令营:Baseline调优进阶

65 阅读6分钟

1.整体思路:

整体思路首先把原始数据洗干净,把用户输入的的时间格式统一,把空的地方填上"无数据",这样AI就不会被脏数据污染,给出奇怪的解释。然后解决API调用的问题。原来的代码就像一个人排队买咖啡,现在改成8个人同时去买,谁先买到谁先走。还加了缓存机制,就像记住咖啡店的位置,下次直接去,不用重新找路。如果网络卡了,会自动重试几次,不会因为一次失败就放弃。接着优化AI的"工作指南"(其实也就是promptprompt),明确告诉它数据里有什么、没什么,让它生成的问题更合理,不会问一些数据里根本没有的信息。最后加了好几道质检关卡,确保AI生成的内容靠谱,格式统一,不会给出不存在的信息

整个方案的核心就是:多线程提速 + 缓存省时间 + 重试保稳定 + 验证保质量,让整个流程又快又稳又准。

2.关键节点的优化思路:

以下我们会对一些关键的功能实现代码段进行解析,我们观察以下我们给的表格字段主要问题在于部分值缺失,这一部分缺失的值有可能导致逻辑错误,故此我们编写代码来规范化数据,同时移除无用的序号列

image.png

2.1. 数据预处理优化

策略:清洗脏数据,统一格式,确保数据一致性

def standardize_time(time_str):

    """时间标准化:统一处理不同格式的时间"""

    time_str = str(time_str)

    if pd.isna(time_str) or '无数据' in time_str:

        return '无数据'

    match = re.search(r'(\d{1,2}:\d{2})', time_str)

    if match:

        return match.group(1)

    return time_str

def preprocess_dataframe(df):

    """数据清洗和一致性校验"""

    df = df.fillna('无数据')  # 统一缺失值

    for col in ['到点''开点']:

        if col in df.columns:

            df[col] = df[col].apply(standardize_time)  # 时间标准化

    

    # 逻辑一致性校验

    for idx, row in df.iterrows():

        try:

            if row['到点'] != '无数据' and row['开点'] != '无数据':

                start = pd.to_datetime(row['到点'], format='%H:%M')

                end = pd.to_datetime(row['开点'], format='%H:%M')

                if (end - start).total_seconds() < 0:  # 检查时间逻辑

                    print(f"警告: 第{idx}行时间不合理")

        except (ValueError, TypeError):

            print(f"警告: 第{idx}行时间格式无法解析")

    

    if '序号' in df.columns:

        df = df.drop(columns=['序号'])  # 移除无用列

    return df

2.2 API缓存机制

在之前的baseline中,训练时间较长,故此我们在这里采用了API缓存机制来让训练速度更快一些,这里用了MD5来保证键的唯一性


class APICache:
    def __init__(self, cache_dir='api_cache'):
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
    
    def get_cache_key(self, text):
        """使用MD5生成唯一缓存键"""
        return hashlib.md5(text.encode('utf-8')).hexdigest()
    
    def get(self, text):
        """从缓存获取结果"""
        cache_file = os.path.join(self.cache_dir, self.get_cache_key(text))
        if os.path.exists(cache_file):
            with open(cache_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return None
    
    def set(self, text, result):
        """保存结果到缓存"""
        cache_file = os.path.join(self.cache_dir, self.get_cache_key(text))
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False)

2.3 优化重试机制

在之前的baseline中访问API的时候往往会出现很多错误,其中主要可以分为两类

  1. 永久性错误(400、401)
  2. 临时性错误(429、503)

前者应该立即抛出Bug,后者可以尝试再次调用API

同时此处也采用指数退避策略,避免对服务器造成压力,外加上随机抖动防止多线程同时重试

def _call_llm_with_retry(self, data_text, prompt):
    """智能重试:区分错误类型,指数退避"""
    combined_input = prompt + data_text
    cached_result = self.cache.get(combined_input)
    if cached_result:
        return cached_result  # 缓存命中,直接返回
    
    for attempt in range(self.max_retries):
        try:
            result = self._call_llm(data_text, prompt)
            if result:
                self.cache.set(combined_input, result)
            return result
        except requests.exceptions.HTTPError as e:
            # 区分永久性错误和临时性错误
            if 400 <= e.response.status_code < 500 and e.response.status_code != 429:
                print(f"客户端错误 (状态码 {e.response.status_code}),停止重试。")
                raise e  
            if attempt == self.max_retries - 1:
                raise e
        except Exception as e:
            if attempt == self.max_retries - 1:
                raise e
        
        # 指数退避 + 随机抖动
        wait_time = (2 ** attempt) + random.uniform(0, 1)
        print(f"API调用失败,将在 {wait_time:.1f} 秒后重试...")
        time.sleep(wait_time)
    return None

2.4 多线程:

这个属于比较经典的操作了,多线程调用API,多线程收集LLM的回答数据

def generate(self, df, prompt):
    """多线程并行处理,最大化CPU利用率"""
    all_results, error_logs = [], []
    tasks = list(df.iterrows())  # 准备所有任务
    
    with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
        # 一次性提交所有任务
        future_to_idx = {executor.submit(self._process_row, task, prompt): task[0] for task in tasks}
        
        # 异步收集结果:谁先完成先处理谁
        for future in tqdm(as_completed(future_to_idx), total=len(tasks), desc="生成QA数据"):
            try:
                result = future.result()
                if result:
                    all_results.extend(result)
            except Exception as e:
                error_logs.append(str(e))
    
    return all_results, error_logs
    

2.4 数据质量验证

AI生成问答对的过程中,我们经常遇到格式不一致(AI可能返回不同的key名称如q/question/问题)、内容质量差(生成的问题或答案可能过短、格式错误)、输出不规范(包含客套话、时间格式不统一、标点符号混乱)以及事实错误(AI可能生成与原始数据不符的内容)等问题,因此需要通过数据质量验证来确保生成的数据的质量

本次的数据质量验证的主要流程如下图所示

image.png

第一层:格式验证

兼容性处理
@staticmethod
def _validate_and_enrich_qa(qa_list, original_row_data):
    """多层验证:确保生成数据质量"""
    valid_qa = []
    for qa in qa_list:
        # 兼容多种key格式
        question = qa.get('instruction') or qa.get('q') or qa.get('question') or qa.get('问题') or qa.get('问', '')
        answer = qa.get('output') or qa.get('a') or qa.get('answer') or qa.get('答案') or qa.get('答', '')
        question = question.strip()
        answer = answer.strip()

设计思路

  • 使用 or 链式获取,确保兼容各种API返回格式
  • 处理不同AI模型可能使用的不同key名称
  • 去除首尾空白字符
基础格式验证
# 基础验证
if not question or not answer:
    continue
if len(question) < 5 or not question.endswith('?'):
    continue
if len(answer) < 2:
    continue

# 通过验证的QA对
valid_qa.append({'instruction': question, 'output': answer})

验证规则

  • 确保问题和答案都不为空
  • 问题长度至少5个字符且以问号结尾
  • 答案长度至少2个字符

第二层:后处理标准化

去除客套话
@staticmethod
def _post_process_qa(qa_pair):
    """后处理:标准化格式"""
    question = qa_pair.get('instruction', '').strip()
    answer = qa_pair.get('output', '').strip()
    
    # 去除客套话
    answer = re.sub(r'^(好的,|根据我们提供的信息,|查询结果如下:|当然,)', '', answer).strip()

解决的问题

  • AI模型经常在答案开头添加客套话
  • 这些客套话对实际内容没有价值
  • 通过正则表达式精确去除
时间格式标准化
# 统一时间格式
answer = re.sub(r'(\d{1,2})\s*点\s*(\d{2})分?', r'\1:\2', answer)

转换示例

  • 8点30分8:30
  • 09点15分09:15
  • 10点10:
车次格式标准化
# 统一车次格式
answer = re.sub(r'([A-Z]?\d+/?\d*)\s*(号?|次)列车', r'\1次', answer)

转换示例

  • K420号列车K420次
  • Z152次列车Z152次
  • T123号T123次
标点符号统一
# 统一标点
answer = answer.replace(':', ':').replace(',', ', ')

转换示例

  • 始发站:北京西始发站:北京西
  • 到点,开点到点, 开点

3 结果分数比较

可以看到返回的分数明显高了

image.png