别凭感觉写 Prompt!如何构建 LLM 文本分类的“本地评测跑分”系统?

2 阅读4分钟

📌 痛点:改了一版 Prompt,怎么量化效果?

很多 LLM 开发者改完 Prompt 后,只凭感觉说“好像准了一点”。但没有量化指标,优化就是盲人摸象

专业的做法是用一批人工标注的数据(Ground Truth)来评估模型输出,计算准确率、精确率、召回率、F1 分数。

但问题来了:sklearn 安装包体积大,有时候线上环境不方便装。其实,这几个指标的计算逻辑非常简单,完全可以纯手写

这篇文章就带你手撕分类评估指标,并绘制漂亮的混淆矩阵热力图。

📊 第一步:准备 Ground Truth

Ground Truth 是一批人工标注的“标准答案”。我的做法是随机抽 500 条数据,自己标注好正确的需求层次,保存为 ground_truth.csv

微博原文需求层次
又延误了,上班要迟到了基础层
工作人员帮我找回了钱包尊重层
车厢空调太冷了舒适层

🧮 第二步:手写混淆矩阵

混淆矩阵是评估的基础。我们不需要 sklearn.metrics.confusion_matrix,几行代码就能搞定。

def evaluate_classification(true_labels, predicted_labels):
    # 获取所有类别
    unique_labels = sorted(set(true_labels + predicted_labels))
    num_classes = len(unique_labels)
    label_to_idx = {label: i for i, label in enumerate(unique_labels)}
    
    # 初始化混淆矩阵
    confusion_matrix = [[0] * num_classes for _ in range(num_classes)]
    
    # 统计
    for true, pred in zip(true_labels, predicted_labels):
        true_idx = label_to_idx[true]
        pred_idx = label_to_idx[pred]
        confusion_matrix[true_idx][pred_idx] += 1
    
    return confusion_matrix, unique_labels

核心逻辑

  • label_to_idx 把文字标签映射成数字索引。
  • confusion_matrix[true_idx][pred_idx] 就是真实类别为 true、预测为 pred 的样本数。

📐 第三步:手写 Precision、Recall、F1

有了混淆矩阵,我们可以进一步计算每个类别的 TP、FP、FN。

# 初始化统计
class_stats = {label: {'tp': 0, 'fp': 0, 'fn': 0} for label in unique_labels}

for true, pred in zip(true_labels, predicted_labels):
    if true == pred:
        class_stats[true]['tp'] += 1
    else:
        class_stats[true]['fn'] += 1
        class_stats[pred]['fp'] += 1

# 计算指标
for label in unique_labels:
    tp = class_stats[label]['tp']
    fp = class_stats[label]['fp']
    fn = class_stats[label]['fn']
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"{label}: P={precision:.3f}, R={recall:.3f}, F1={f1:.3f}")

关键点

  • 分母可能为 0,需要加 if 判断。
  • TP(真正例):预测对的正样本。
  • FP(假正例):把别的类预测成这个类。
  • FN(假负例):把这个类预测成别的类。

📈 第四步:绘制混淆矩阵热力图

数字表格不够直观,用 seaborn 画热力图,一眼看出哪些类别容易混淆。

import matplotlib.pyplot as plt
import seaborn as sns

plt.rcParams['font.family'] = 'SimHei'  # 支持中文

plt.figure(figsize=(10, 8))
sns.heatmap(
    confusion_matrix,
    annot=True,           # 显示数字
    fmt='d',              # 整数格式
    cmap='Blues',         # 蓝色渐变
    xticklabels=unique_labels,
    yticklabels=unique_labels
)
plt.title('混淆矩阵 - 真实 vs 预测')
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300)

效果

  • 对角线越亮,分类越准。
  • 非对角线亮块表示混淆严重。比如“保障层”容易被误判为“舒适层”,就需要回头优化这两个层次的 Prompt 定义。

📊 完整代码整合

def evaluate_classification(true_labels, predicted_labels):
    unique_labels = sorted(set(true_labels + predicted_labels))
    
    # 混淆矩阵
    label_to_idx = {label: i for i, label in enumerate(unique_labels)}
    cm = [[0] * len(unique_labels) for _ in range(len(unique_labels))]
    for t, p in zip(true_labels, predicted_labels):
        cm[label_to_idx[t]][label_to_idx[p]] += 1
    
    # 各类指标
    class_stats = {label: {'tp': 0, 'fp': 0, 'fn': 0} for label in unique_labels}
    for t, p in zip(true_labels, predicted_labels):
        if t == p:
            class_stats[t]['tp'] += 1
        else:
            class_stats[t]['fn'] += 1
            class_stats[p]['fp'] += 1
    
    print(f"{'类别':<10}{'精确率':<10}{'召回率':<10}{'F1分数':<10}")
    for label in unique_labels:
        tp, fp, fn = class_stats[label]['tp'], class_stats[label]['fp'], class_stats[label]['fn']
        p = tp / (tp + fp) if tp + fp else 0
        r = tp / (tp + fn) if tp + fn else 0
        f1 = 2 * p * r / (p + r) if p + r else 0
        print(f"{label:<10}{p:<10.3f}{r:<10.3f}{f1:<10.3f}")
    
    accuracy = sum(class_stats[l]['tp'] for l in unique_labels) / len(true_labels)
    print(f"\n总准确率: {accuracy:.4f}")
    
    return cm, unique_labels

💡 总结:手写评估指标的好处

  1. 轻量:不依赖 sklearn,部署环境更简单。
  2. 可定制:可以随时加入自己需要的指标(比如 Macro/Micro F1)。
  3. 理解原理:手写一遍,你对 Precision/Recall 的理解会更深刻。

🔗 完整代码

👉 nanjing-metro-analysis/scripts/01_demand_classification/classify_demand_deepseek.py