用户说 AI 推理太慢了,但我不知道慢在哪,直到…

63 阅读6分钟

image.png

聊聊 AI 后端那些事儿 · 第 4 篇(完结篇) | 预估阅读:8 分钟


用户说慢,但慢在哪?

周五下班前,产品经理发来消息:

"用户反馈生成图片太慢了,能优化一下吗?"

小禾回复:"图片生成本来就要时间啊,AI 推理需要十几秒。"

产品经理:"不是,用户说点完按钮要等很久才开始转圈。"

小禾愣了一下。

点完按钮到开始转圈,那是接口响应时间。

但接口响应慢,慢在哪?是参数处理慢?模型推理慢?还是文件保存慢?

小禾打开代码,一脸茫然。

没有任何性能数据,只能猜。


最简方案:一个响应头

小禾决定先加个计时,看看整体耗时:

import time
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """记录请求处理时间"""
    start_time = time.time()

    response = await call_next(request)

    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = f"{process_time:.4f}"

    return response

就这么几行代码。

现在每个响应头里都有处理时间:

$ curl -I http://localhost:8000/api/generate/shot-image

HTTP/1.1 200 OK
content-type: application/json
x-process-time: 12.3456

12.3 秒,确实有点久。

但这只是总时间,还是不知道慢在哪。


进阶:详细的请求日志

小禾完善了日志中间件:

import time
import uuid
from fastapi import Request
from app.core.logger import logger

@app.middleware("http")
async def log_requests(request: Request, call_next):
    """详细的请求日志"""

    # 生成请求 ID
    request_id = str(uuid.uuid4())[:8]

    # 记录请求开始
    logger.info(
        f"[{request_id}] --> {request.method} {request.url.path}",
        extra={
            "request_id": request_id,
            "method": request.method,
            "path": request.url.path,
            "client": request.client.host if request.client else "unknown"
        }
    )

    start_time = time.time()

    # 处理请求
    response = await call_next(request)

    # 计算耗时
    process_time = time.time() - start_time

    # 超过 5 秒用警告级别
    log_func = logger.warning if process_time > 5 else logger.info
    log_func(
        f"[{request_id}] <-- {response.status_code} in {process_time:.2f}s",
        extra={
            "request_id": request_id,
            "status_code": response.status_code,
            "process_time": process_time
        }
    )

    # 添加响应头
    response.headers["X-Request-ID"] = request_id
    response.headers["X-Process-Time"] = f"{process_time:.4f}"

    return response

日志输出:

2024-01-15 14:30:00 | INFO    | [a1b2c3d4] --> POST /api/generate/shot-image
2024-01-15 14:30:12 | WARNING | [a1b2c3d4] <-- 200 in 12.34s
2024-01-15 14:30:12 | INFO    | [e5f6g7h8] --> GET /api/stories
2024-01-15 14:30:12 | INFO    | [e5f6g7h8] <-- 200 in 0.05s

慢请求一眼就能看出来。

但还是不知道这 12 秒花在哪了。


分段计时:找到真正的瓶颈

小禾写了个分段计时器:

# app/core/timer.py
import time
from contextlib import contextmanager
from app.core.logger import logger

class Timer:
    """分段计时器"""

    def __init__(self, name: str):
        self.name = name
        self.segments = []
        self.start_time = time.time()

    @contextmanager
    def segment(self, segment_name: str):
        """记录一个分段的耗时"""
        segment_start = time.time()
        yield
        segment_time = time.time() - segment_start
        self.segments.append((segment_name, segment_time))

    def report(self):
        """输出计时报告"""
        total = time.time() - self.start_time

        lines = [f"[{self.name}] 总耗时: {total:.2f}s"]
        for name, duration in self.segments:
            percentage = (duration / total) * 100 if total > 0 else 0
            lines.append(f"  - {name}: {duration:.2f}s ({percentage:.1f}%)")

        logger.info("\n".join(lines))

在接口里使用:

@router.post("/shot-image")
async def generate_shot_image(request: GenerateShotImageRequest):
    timer = Timer("generate_shot_image")

    with timer.segment("参数处理"):
        processed_prompt = preprocess_prompt(request.prompt)
        params = build_generation_params(request)

    with timer.segment("模型推理"):
        result = await model.generate(
            prompt=processed_prompt,
            **params
        )

    with timer.segment("保存文件"):
        image_path = save_image(result.image)
        thumbnail_path = create_thumbnail(image_path)

    with timer.segment("构建响应"):
        response = build_response(image_path, thumbnail_path, result)

    timer.report()
    return response

现在日志输出:

[generate_shot_image] 总耗时: 12.50s
  - 参数处理: 0.02s (0.2%)
  - 模型推理: 11.80s (94.4%)
  - 保存文件: 0.50s (4.0%)
  - 构建响应: 0.18s (1.4%)

真相大白:94% 的时间都花在模型推理上

这不是代码的问题,是 AI 就是这么慢。

小禾把这个数据发给产品经理:"模型推理本身要 12 秒,这个优化不了。要不我们加个进度提示?"


慢请求告警

小禾还加了个自动告警机制:

SLOW_THRESHOLD = 10  # 秒

@app.middleware("http")
async def slow_request_alert(request: Request, call_next):
    start = time.time()
    response = await call_next(request)
    duration = time.time() - start

    if duration > SLOW_THRESHOLD:
        logger.warning(
            f"慢请求告警: {request.method} {request.url.path} "
            f"耗时 {duration:.2f}s (阈值: {SLOW_THRESHOLD}s)",
            extra={
                "alert_type": "slow_request",
                "path": request.url.path,
                "duration": duration
            }
        )

    return response

超过阈值就告警,不用盯着日志看了。


简单的统计分析

小禾还做了个简单的性能统计:

from collections import defaultdict

# 请求统计
request_stats = defaultdict(list)

@app.middleware("http")
async def collect_stats(request: Request, call_next):
    start = time.time()
    response = await call_next(request)
    duration = time.time() - start

    # 收集统计
    path = request.url.path
    request_stats[path].append(duration)

    # 只保留最近 1000 条
    if len(request_stats[path]) > 1000:
        request_stats[path] = request_stats[path][-1000:]

    return response

@app.get("/metrics")
async def get_metrics():
    """返回性能统计"""
    metrics = {}
    for path, times in request_stats.items():
        if times:
            metrics[path] = {
                "count": len(times),
                "avg": round(sum(times) / len(times), 3),
                "max": round(max(times), 3),
                "min": round(min(times), 3),
                "p95": round(sorted(times)[int(len(times) * 0.95)], 3)
            }
    return metrics

请求 /metrics

{
    "/api/generate/shot-image": {
        "count": 156,
        "avg": 11.234,
        "max": 25.678,
        "min": 8.901,
        "p95": 15.432
    },
    "/api/stories": {
        "count": 1000,
        "avg": 0.045,
        "max": 0.234,
        "min": 0.012,
        "p95": 0.089
    }
}

哪个接口慢,一目了然。


性能追踪层级图

flowchart TB
    A[用户感知慢] --> B{哪里慢?}

    B --> C[网络层]
    B --> D[后端处理]
    B --> E[前端渲染]

    D --> F[X-Process-Time<br/>快速判断是不是后端问题]

    F --> G{后端慢?}
    G -->|是| H[分段计时<br/>找到具体瓶颈]
    G -->|否| I[排查网络或前端]

    H --> J[参数处理]
    H --> K[业务逻辑]
    H --> L[数据库查询]
    H --> M[外部调用]
    H --> N[文件IO]

分析问题的思路

小禾总结了一套分析流程:

  1. 用户说慢 -> 先看 X-Process-Time,确定是不是后端问题
  2. 后端慢 -> 看请求日志,找到具体慢的接口
  3. 找到接口 -> 加分段计时,看慢在哪个环节
  4. 找到环节 -> 针对性优化

大部分情况下:

慢在哪可能的原因优化方向
参数处理复杂的数据转换简化逻辑、缓存
数据库查询没有索引、N+1 问题加索引、批量查询
外部调用第三方服务慢异步化、超时控制
模型推理AI 本身就慢换更快的模型、量化
文件 IO大文件读写异步IO、对象存储

一行代码版本

如果你只想加一行代码,用这个:

@app.middleware("http")
async def timing(request: Request, call_next):
    start = time.time()
    response = await call_next(request)
    response.headers["X-Process-Time"] = f"{time.time() - start:.4f}"
    return response

就这一行(好吧,严格来说是五行),能解决 80% 的性能排查问题。


性能追踪清单

层级方案用途
基础X-Process-Time 响应头快速判断是前端还是后端问题
进阶请求日志追踪每个请求的处理情况
深入分段计时找到具体的性能瓶颈
监控慢请求告警及时发现性能问题
分析统计接口长期性能趋势分析

小禾的感悟

用户说慢,
你说优化,
但不知道慢在哪。

这就是盲人摸象。

加一个响应头,
就知道是前端还是后端。

加一段日志,
就知道是哪个接口。

加一个计时器,
就知道是哪个环节。

性能优化不是玄学,
是数据驱动的工程。

没有数据,
一切都是猜。

有了数据,
问题自己会说话。

一行代码,
价值千金。

小禾把性能追踪加到了所有服务里。

从此,再也不用对着"太慢了"的反馈一脸茫然。


合集完结撒花

四篇文章,我们一起经历了后端开发的各种挑战:

  1. 生命周期管理:让 GPU 显存不再泄漏
  2. 异常处理:让每个错误都有迹可循
  3. 数据验证:让接口固若金汤
  4. 性能追踪:让慢请求无处遁形

这些都是小禾在实战中踩过的坑,也是每个后端开发者早晚会遇到的问题。

希望这些经验能帮你少走弯路。

下一个合集,我们聊聊和 AI 模型打交道的那些事儿。

敬请期待。


📌 我在公众号「程序员义拉冠」持续分享 AI 应用开发实战。点此 获取本系列合集