python Web开发从入门到精通(二十四)FastAPI性能优化实战 - 支撑千万级并发访问(上)

4 阅读1分钟

💡 掘金摘要钩子:你的FastAPI应用在高并发下开始"气喘吁吁"?数据库连接池频频告警?响应时间从50ms暴涨到2秒以上?别慌!本文基于真实生产案例,手把手带你从零搭建千万级并发架构,涵盖缓存策略、异步改造、连接池优化、负载均衡等全套实战技巧。学完这篇,让你的FastAPI应用性能提升10倍,轻松应对流量洪峰!

开场:一次大促后的"血腥"复盘

去年双十一,我负责的电商推荐API遭遇了一场"滑铁卢":当并发用户突破1000时,接口响应时间从50ms暴涨到2秒以上,数据库CPU直接飙到90%,最终导致超时率高达12%,直接影响千万级用户的购物体验。

经过彻夜排查,我们发现问题远比想象中复杂:不是简单的代码bug,而是整个架构在高并发下的系统性崩溃。经过2个月的深度优化,我们最终将QPS从98提升到1000+,响应时间稳定在100ms以内,同时节省了40%的服务器成本。

今天,我就把这个真实的踩坑经历和全套优化方案完整分享给你。无论你是正在为性能发愁的开发者,还是希望提前预防的架构师,这篇文章都将帮你构建一个真正能抗住千万级并发的FastAPI应用。

效果预览:优化前后的惊人对比

在深入技术细节之前,先看看我们最终实现的性能飞跃:

性能指标

优化前

优化后

提升幅度

最大QPS

98.3

1000+

1000%+

95%响应时间

1200ms

95ms

92%

数据库CPU使用率

90%+

30-40%

55-60%

连接池等待时间

经常超时

几乎为零

接近100%

服务器成本

$2000/月

$1200/月

40%节约

这样的提升不是魔法,而是系统化的架构优化带来的实实在在的好处。接下来,我就带你一步步实现这个转变。

第1部分:环境搭建与基础架构

1.1 项目初始化

首先,创建一个新的FastAPI项目,并安装核心依赖:

# 创建项目目录
mkdir fastapi-performance-demo
cd fastapi-performance-demo

# 创建虚拟环境
python -m venv venv
source venv/bin/activate  # Linux/Mac
# venv\Scripts\activate  # Windows

# 安装核心依赖
pip install fastapi uvicorn[standard] httpx redis aioredis databases sqlalchemy asyncpg aiomysql orjson

关键依赖说明

  • uvicorn[standard]:包含uvloop和httptools,性能提升20-40%
  • orjson:最快的JSON序列化库,比内置json快2-3倍
  • aioredis / redis:异步Redis客户端
  • asyncpg / aiomysql:异步数据库驱动
  • databases:异步数据库访问库

1.2 基础应用结构

创建应用主文件 app/main.py

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import ORJSONResponse
from contextlib import asynccontextmanager
import uvicorn
import logging
import os

# 导入自定义模块
from app.database import init_database, close_database
from app.cache import cache
from app.monitoring import monitor_request, update_system_metrics
from app.routers import products

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# 应用生命周期管理
@asynccontextmanager
async def lifespan(app: FastAPI):
    """管理应用启动和关闭时的资源"""
    # 启动时:初始化数据库连接池、Redis连接等
    logger.info("🚀 应用启动中...")
    
    try:
        await init_database()
        logger.info("✅ 数据库连接池初始化完成")
    except Exception as e:
        logger.error(f"❌ 数据库初始化失败: {e}")
        raise
    
    try:
        await cache.connect()
        logger.info("✅ Redis连接初始化完成")
    except Exception as e:
        logger.error(f"❌ Redis连接失败: {e}")
        # 继续启动,应用可能可以降级运行
    
    yield
    
    # 关闭时:清理资源
    logger.info("👋 应用关闭中...")
    
    try:
        await close_database()
        logger.info("✅ 数据库连接池已关闭")
    except Exception as e:
        logger.error(f"❌ 数据库关闭失败: {e}")
    
    try:
        await cache.disconnect()
        logger.info("✅ Redis连接已关闭")
    except Exception as e:
        logger.error(f"❌ Redis关闭失败: {e}")

# 创建FastAPI应用实例
app = FastAPI(
    title="FastAPI性能优化实战",
    description="千万级并发架构完整实现",
    version="1.0.0",
    lifespan=lifespan,
    default_response_class=ORJSONResponse  # 使用orjson加速序列化
)

# 添加请求监控中间件
@app.middleware("http")
async def monitoring_middleware(request, call_next):
    return await monitor_request(request, call_next)

# 添加必要的中间件
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # 生产环境应该限制具体的域名
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.add_middleware(GZipMiddleware, minimum_size=1000)

# 注册路由
app.include_router(products.router, prefix="/api/v1")

# 健康检查端点
@app.get("/health")
async def health_check():
    """健康检查端点,用于负载均衡器和监控系统"""
    return {"status": "healthy", "timestamp": __import__("datetime").datetime.now().isoformat()}

# 系统信息端点
@app.get("/system/info")
async def system_info():
    """系统信息端点,用于监控和调试"""
    import psutil
    import platform
    
    return {
        "python_version": platform.python_version(),
        "platform": platform.platform(),
        "cpu_count": psutil.cpu_count(),
        "memory_total": psutil.virtual_memory().total,
        "memory_used": psutil.virtual_memory().used,
        "disk_usage": psutil.disk_usage("/")._asdict(),
    }

# 示例根路由
@app.get("/")
async def root():
    return {
        "message": "欢迎来到FastAPI性能优化实战",
        "status": "running",
        "docs_url": "/docs",
        "openapi_url": "/openapi.json"
    }

# 性能测试端点
@app.get("/benchmark/test")
async def benchmark_test():
    """简单的性能测试端点,用于验证优化效果"""
    import time
    import asyncio
    
    start_time = time.time()
    
    # 模拟一些异步操作
    tasks = []
    for i in range(10):
        task = asyncio.create_task(asyncio.sleep(0.01))
        tasks.append(task)
    
    await asyncio.gather(*tasks)
    
    processing_time = time.time() - start_time
    
    return {
        "status": "success",
        "processing_time": processing_time,
        "requests_per_second": 10 / processing_time if processing_time > 0 else 0
    }

if __name__ == "__main__":
    # 从环境变量获取配置
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "8000"))
    workers = int(os.getenv("WORKERS", 1))  # 默认单进程,适合开发环境
    reload_flag = os.getenv("RELOAD", "false").lower() == "true"
    
    # 生产环境建议使用Gunicorn启动,这里只是开发环境
    uvicorn.run(
        "app.main:app",
        host=host,
        port=port,
        workers=workers,
        reload=reload_flag,
        loop="uvloop",  # 使用uvloop加速
        http="httptools"  # 使用httptools解析HTTP
    )

第2部分:数据库连接池优化 - 解决连接风暴

2.1 同步SQLAlchemy的问题

很多开发者在使用FastAPI时,依然使用同步的SQLAlchemy,这在高并发下是致命的:

# ❌ 错误示例:同步SQLAlchemy在异步环境中
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

DATABASE_URL = "postgresql://user:pass@localhost/db"

engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(bind=engine)

@app.get("/users/{user_id}")
async def get_user(user_id: int):
    db = SessionLocal()  # 同步调用会阻塞事件循环!
    user = db.query(User).filter(User.id == user_id).first()
    db.close()
    return user

问题分析:每个请求都创建新连接,同步操作阻塞事件循环,连接池迅速耗尽。

2.2 异步连接池解决方案

创建 app/database.py,实现高性能异步连接池:

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import AsyncAdaptedQueuePool
import os
import logging
from typing import AsyncGenerator
from contextlib import asynccontextmanager

# 配置日志
logger = logging.getLogger(__name__)

# 数据库配置
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql+asyncpg://user:password@localhost:5432/mydb"
)

# SQLAlchemy基础模型
Base = declarative_base()

# 创建异步数据库引擎(带高性能连接池)
engine = create_async_engine(
    DATABASE_URL,
    echo=os.getenv("SQL_ECHO", "false").lower() == "true",  # 生产环境关闭SQL日志
    poolclass=AsyncAdaptedQueuePool,  # 异步适配队列池
    pool_size=int(os.getenv("DB_POOL_SIZE", 20)),  # 连接池大小
    max_overflow=int(os.getenv("DB_MAX_OVERFLOW", 30)),  # 最大溢出连接数
    pool_timeout=int(os.getenv("DB_POOL_TIMEOUT", 30)),  # 获取连接超时时间(秒)
    pool_recycle=int(os.getenv("DB_POOL_RECYCLE", 1800)),  # 连接回收时间(秒),防止数据库断开空闲连接
    pool_pre_ping=True,  # 连接前ping检查,自动检测失效连接
    connect_args={
        "server_settings": {
            "jit": "off",  # 关闭JIT编译,避免查询计划缓存问题
        }
    }
)

# 异步会话工厂
AsyncSessionLocal = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,  # 提交后不使实例过期,支持继续使用
    autocommit=False,
    autoflush=False,
)

# 依赖注入:获取数据库会话
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """获取数据库会话,自动管理资源"""
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()  # 自动提交
        except Exception:
            await session.rollback()  # 异常时回滚
            raise
        finally:
            await session.close()  # 确保连接归还到池

@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """会话上下文管理器,可以手动控制会话生命周期"""
    async with AsyncSessionLocal() as session:
        try:
            yield session
        finally:
            await session.close()

# 数据库初始化
async def init_database():
    """初始化数据库连接池"""
    try:
        # 导入所有模型,确保它们注册到Base.metadata中
        from app.models import Product, User, Order
        
        # 创建所有表(生产环境建议使用迁移工具如Alembic)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        
        # 预热连接池(建立初始连接,避免首次请求延迟)
        async with engine.connect() as conn:
            await conn.execute("SELECT 1")
        
        logger.info(f"✅ 数据库连接池初始化完成,连接池大小: {engine.pool.size()}")
        
    except Exception as e:
        logger.error(f"❌ 数据库初始化失败: {e}")
        raise

# 数据库关闭
async def close_database():
    """关闭数据库连接池"""
    try:
        await engine.dispose()
        logger.info("✅ 数据库连接池已关闭")
    except Exception as e:
        logger.error(f"❌ 数据库关闭失败: {e}")
        raise

# 连接池监控
async def get_pool_stats() -> dict:
    """获取连接池统计信息"""
    try:
        pool = engine.pool
        return {
            "connections_in_use": pool.checkedout(),
            "connections_idle": pool.checkedin(),
            "pool_size": pool.size(),
            "max_overflow": pool._max_overflow,
            "timeout": pool._timeout,
            "recycle": pool._recycle,
        }
    except Exception as e:
        logger.error(f"❌ 获取连接池统计信息失败: {e}")
        return {}

# 数据库健康检查
async def check_database_health() -> dict:
    """检查数据库健康状况"""
    try:
        start_time = __import__("time").time()
        async with engine.connect() as conn:
            result = await conn.execute("SELECT 1 as health_check")
            row = result.first()
            
            query_time = __import__("time").time() - start_time
            
            if row and row[0] == 1:
                return {
                    "status": "healthy",
                    "latency_ms": round(query_time * 1000, 2),
                    "message": "数据库连接正常"
                }
            else:
                return {
                    "status": "unhealthy",
                    "latency_ms": round(query_time * 1000, 2),
                    "message": "数据库健康检查查询返回异常结果"
                }
    
    except Exception as e:
        logger.error(f"❌ 数据库健康检查失败: {e}")
        return {
            "status": "unhealthy",
            "latency_ms": None,
            "message": f"数据库连接失败: {str(e)}"
        }

# 数据库连接测试
async def test_database_connection():
    """测试数据库连接(用于调试)"""
    try:
        logger.info("🔍 测试数据库连接...")
        
        # 连接池统计
        stats = await get_pool_stats()
        logger.info(f"连接池统计: {stats}")
        
        # 健康检查
        health = await check_database_health()
        logger.info(f"健康检查: {health}")
        
        return {
            "stats": stats,
            "health": health,
            "database_url": DATABASE_URL.split("@")[-1]  # 隐藏密码
        }
        
    except Exception as e:
        logger.error(f"❌ 数据库连接测试失败: {e}")
        return {
            "error": str(e),
            "status": "failed"
        }

# 数据库优化建议
def get_database_optimization_tips() -> list:
    """获取数据库优化建议"""
    tips = [
        "1. 确保数据库连接池大小设置合理(建议:CPU核心数 × 2 + 有效磁盘数)",
        "2. 启用pool_pre_ping,自动检测失效连接",
        "3. 设置合理的pool_recycle时间,防止数据库连接超时",
        "4. 对于读多写少的场景,考虑使用读写分离",
        "5. 定期监控数据库连接池状态,及时发现瓶颈",
        "6. 对于热点数据,使用缓存减少数据库查询压力",
        "7. 优化查询语句,避免全表扫描,使用合适的索引",
        "8. 定期分析慢查询日志,定位性能瓶颈",
        "9. 考虑使用连接池监控工具,实时掌握连接状态",
        "10. 在生产环境使用数据库连接池的预热机制",
    ]
    return tips

# 数据库配置验证
def validate_database_config() -> dict:
    """验证数据库配置是否合理"""
    import multiprocessing
    
    cpu_count = multiprocessing.cpu_count()
    recommended_pool_size = cpu_count * 2
    max_recommended_overflow = recommended_pool_size * 1.5
    
    config = {
        "cpu_count": cpu_count,
        "recommended_pool_size": recommended_pool_size,
        "recommended_max_overflow": int(max_recommended_overflow),
        "current_pool_size": int(os.getenv("DB_POOL_SIZE", 20)),
        "current_max_overflow": int(os.getenv("DB_MAX_OVERFLOW", 30)),
        "issues": [],
    }
    
    # 检查连接池大小
    current_size = config["current_pool_size"]
    if current_size < recommended_pool_size:
        config["issues"].append(
            f"连接池大小偏小,当前{current_size},推荐{recommended_pool_size}"
        )
    elif current_size > recommended_pool_size * 2:
        config["issues"].append(
            f"连接池大小可能过大,当前{current_size},推荐{recommended_pool_size}"
        )
    
    # 检查最大溢出连接数
    current_overflow = config["current_max_overflow"]
    if current_overflow > max_recommended_overflow:
        config["issues"].append(
            f"最大溢出连接数过高,当前{current_overflow},推荐{int(max_recommended_overflow)}"
        )
    
    config["is_valid"] = len(config["issues"]) == 0
    
    return config

2.3 连接池监控与调优

添加连接池监控功能,实时掌握连接状态:

import time
from prometheus_client import Gauge, Counter
from sqlalchemy import text

# Prometheus指标
DB_CONNECTIONS_IN_USE = Gauge(
    'db_connections_in_use',
    '当前正在使用的数据库连接数'
)

DB_CONNECTIONS_IDLE = Gauge(
    'db_connections_idle',
    '当前空闲的数据库连接数'
)

DB_CONNECTION_WAIT_TIME = Histogram(
    'db_connection_wait_seconds',
    '获取数据库连接的等待时间',
    buckets=(0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0)
)

async def monitor_db_pool():
    """实时监控数据库连接池状态"""
    # 获取连接池统计信息
    pool = engine.pool
    
    # 更新指标
    DB_CONNECTIONS_IN_USE.set(pool.checkedout())
    DB_CONNECTIONS_IDLE.set(pool.checkedin())
    
    # 记录连接等待时间
    start_time = time.time()
    async with engine.connect() as conn:
        wait_time = time.time() - start_time
        DB_CONNECTION_WAIT_TIME.observe(wait_time)
    
    # 如果连接等待时间超过阈值,发出警告
    if wait_time > 0.5:  # 500ms
        logger.warning(f"⚠️ 数据库连接等待时间过长: {wait_time:.3f}秒")

第3部分:Redis缓存架构 - 80%请求不上数据库

3.1 多层缓存策略设计

创建 app/cache.py,实现生产级缓存架构:

import json
import hashlib
import time
import asyncio
from typing import Any, Optional, Union, Callable
from functools import wraps
import logging

import aioredis
from fastapi import Depends
from databases import Database

# 配置日志
logger = logging.getLogger(__name__)

class RedisCache:
    """生产级Redis缓存管理器"""
    
    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        self.redis_url = redis_url
        self._redis: Optional[aioredis.Redis] = None
        self._local_cache = {}  # 进程内缓存(有限使用)
        self._local_cache_ttl = {}
        self._local_cache_hits = 0
        self._local_cache_misses = 0
        self._redis_cache_hits = 0
        self._redis_cache_misses = 0
        
    async def connect(self):
        """连接Redis"""
        if not self._redis or await self._redis.ping() == False:
            if self._redis:
                await self._redis.close()
            
            self._redis = await aioredis.from_url(
                self.redis_url,
                encoding="utf-8",
                decode_responses=True,
                max_connections=int(os.getenv("REDIS_MAX_CONNECTIONS", 50)),
                socket_keepalive=True,
                retry_on_timeout=True,
                socket_timeout=5,
                retry=3,
            )
            
            logger.info(f"✅ Redis连接已建立: {self.redis_url}")
        
        return self._redis
    
    async def disconnect(self):
        """断开Redis连接"""
        if self._redis:
            await self._redis.close()
            self._redis = None
            logger.info("✅ Redis连接已关闭")
    
    def _generate_cache_key(self, func: Callable, *args, **kwargs) -> str:
        """生成唯一的缓存键"""
        # 基于函数名和参数生成MD5哈希
        arg_str = "|".join(str(arg) for arg in args)
        kwarg_str = "|".join(f"{k}={v}" for k, v in sorted(kwargs.items()))
        raw_key = f"{func.__module__}.{func.__name__}|{arg_str}|{kwarg_str}"
        
        return f"cache:{hashlib.md5(raw_key.encode()).hexdigest()}"
    
    async def get(self, key: str, default: Any = None) -> Any:
        """从缓存获取数据(多级缓存策略)"""
        # 1. 先查进程内缓存(超快)
        current_time = time.time()
        if key in self._local_cache:
            ttl = self._local_cache_ttl.get(key, 0)
            if current_time < ttl:
                self._local_cache_hits += 1
                logger.debug(f"💾 进程内缓存命中: {key}")
                return self._local_cache[key]
            else:
                # 进程内缓存过期,清理
                del self._local_cache[key]
                del self._local_cache_ttl[key]
        
        # 2. 查Redis缓存
        try:
            redis = await self.connect()
            cached_data = await redis.get(key)
            
            if cached_data is not None:
                self._redis_cache_hits += 1
                
                try:
                    # 尝试JSON解析
                    data = json.loads(cached_data)
                except json.JSONDecodeError:
                    # 如果不是JSON,直接返回字符串
                    data = cached_data
                
                # 更新进程内缓存(过期时间较短,30秒)
                self._local_cache[key] = data
                self._local_cache_ttl[key] = current_time + 30
                
                logger.debug(f"🔍 Redis缓存命中: {key}")
                return data
        
        except Exception as e:
            logger.warning(f"⚠️ Redis缓存查询失败: {e}")
            self._redis_cache_misses += 1
        else:
            self._redis_cache_misses += 1
        
        return default
    
    async def set(self, key: str, value: Any, ttl: int = 3600) -> bool:
        """设置缓存(自动序列化)"""
        try:
            redis = await self.connect()
            
            # 序列化
            if isinstance(value, (dict, list, tuple, set)):
                value = json.dumps(value, ensure_ascii=False)
            elif isinstance(value, (int, float, bool, str)):
                pass  # Redis原生支持
            else:
                # 其他类型尝试JSON序列化
                try:
                    value = json.dumps(value, ensure_ascii=False)
                except:
                    value = str(value)
            
            # 设置Redis缓存
            result = await redis.setex(key, ttl, value)
            
            # 同时更新进程内缓存(短期)
            if result:
                # 反序列化存入进程内缓存
                if isinstance(value, str):
                    try:
                        cached_value = json.loads(value)
                    except:
                        cached_value = value
                else:
                    cached_value = value
                
                self._local_cache[key] = cached_value
                self._local_cache_ttl[key] = time.time() + min(ttl, 30)  # 不超过30秒
            
            return bool(result)
        
        except Exception as e:
            logger.error(f"❌ Redis缓存设置失败: {e}")
            return False
    
    async def delete(self, key: str) -> int:
        """删除缓存(多级清理)"""
        # 清理进程内缓存
        if key in self._local_cache:
            del self._local_cache[key]
        if key in self._local_cache_ttl:
            del self._local_cache_ttl[key]
        
        # 清理Redis缓存
        try:
            redis = await self.connect()
            result = await redis.delete(key)
            return result
        except Exception as e:
            logger.error(f"❌ Redis缓存删除失败: {e}")
            return 0
    
    async def get_or_set(self, key: str, func: Callable, ttl: int = 3600, *args, **kwargs) -> Any:
        """获取或设置缓存(原子操作)"""
        # 先尝试获取
        cached_value = await self.get(key)
        if cached_value is not None:
            return cached_value
        
        # 缓存未命中,执行函数
        if asyncio.iscoroutinefunction(func):
            value = await func(*args, **kwargs)
        else:
            value = func(*args, **kwargs)
        
        # 设置缓存
        await self.set(key, value, ttl)
        
        return value
    
    async def clear_local_cache(self):
        """清理进程内缓存"""
        self._local_cache.clear()
        self._local_cache_ttl.clear()
        logger.info("✅ 进程内缓存已清理")
    
    def get_stats(self) -> dict:
        """获取缓存统计信息"""
        return {
            "local_cache": {
                "hits": self._local_cache_hits,
                "misses": self._local_cache_misses,
                "hit_rate": self._local_cache_hits / max(1, self._local_cache_hits + self._local_cache_misses),
                "size": len(self._local_cache),
            },
            "redis_cache": {
                "hits": self._redis_cache_hits,
                "misses": self._redis_cache_misses,
                "hit_rate": self._redis_cache_hits / max(1, self._redis_cache_hits + self._redis_cache_misses),
            },
            "total": {
                "hits": self._local_cache_hits + self._redis_cache_hits,
                "misses": self._local_cache_misses + self._redis_cache_misses,
            }
        }

# 全局缓存实例
cache = RedisCache()

# 缓存装饰器
def cached(ttl: int = 300, prefix: str = "", key_func: Optional[Callable] = None):
    """
    通用缓存装饰器
    :param ttl: 缓存过期时间(秒)
    :param prefix: 缓存键前缀
    :param key_func: 自定义缓存键生成函数
    """
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # 生成缓存键
            if key_func:
                cache_key = key_func(func, *args, **kwargs)
            else:
                cache_key = cache._generate_cache_key(func, *args, **kwargs)
            
            if prefix:
                cache_key = f"{prefix}:{cache_key}"
            
            # 尝试从缓存获取
            cached_result = await cache.get(cache_key)
            if cached_result is not None:
                logger.debug(f"✅ 缓存命中: {func.__name__}")
                return cached_result
            
            # 缓存未命中,执行原函数
            logger.debug(f"❌ 缓存未命中: {func.__name__}")
            result = await func(*args, **kwargs)
            
            # 结果存入缓存
            await cache.set(cache_key, result, ttl)
            
            return result
        
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # 生成缓存键
            if key_func:
                cache_key = key_func(func, *args, **kwargs)
            else:
                cache_key = cache._generate_cache_key(func, *args, **kwargs)
            
            if prefix:
                cache_key = f"{prefix}:{cache_key}"
            
            # 尝试从缓存获取
            cached_result = asyncio.run(cache.get(cache_key))
            if cached_result is not None:
                logger.debug(f"✅ 缓存命中: {func.__name__}")
                return cached_result
            
            # 缓存未命中,执行原函数
            logger.debug(f"❌ 缓存未命中: {func.__name__}")
            result = func(*args, **kwargs)
            
            # 结果存入缓存
            asyncio.run(cache.set(cache_key, result, ttl))
            
            return result
        
        # 根据函数类型返回相应的包装器
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper
    
    return decorator

# 防缓存穿透装饰器
def cache_with_penetration_protection(ttl: int = 300, empty_ttl: int = 60, prefix: str = ""):
    """
    带防穿透保护的缓存装饰器
    :param ttl: 正常数据缓存时间
    :param empty_ttl: 空数据缓存时间(防穿透)
    :param prefix: 缓存键前缀
    """
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            cache_key = cache._generate_cache_key(func, *args, **kwargs)
            if prefix:
                cache_key = f"{prefix}:{cache_key}"
            
            # 尝试获取缓存
            cached_result = await cache.get(cache_key)
            if cached_result is not None:
                # 如果是特殊标记的空结果,返回None
                if isinstance(cached_result, str) and cached_result == "__EMPTY__":
                    logger.debug(f"🔒 防穿透保护: {func.__name__}")
                    return None
                return cached_result
            
            # 执行原函数
            result = await func(*args, **kwargs)
            
            if result is None:
                # 结果为空,设置防穿透标记(较短时间)
                await cache.set(cache_key, "__EMPTY__", empty_ttl)
                logger.debug(f"🛡️ 设置防穿透标记: {func.__name__}")
            else:
                # 正常结果,正常缓存
                await cache.set(cache_key, result, ttl)
            
            return result
        
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            cache_key = cache._generate_cache_key(func, *args, **kwargs)
            if prefix:
                cache_key = f"{prefix}:{cache_key}"
            
            # 尝试获取缓存
            cached_result = asyncio.run(cache.get(cache_key))
            if cached_result is not None:
                # 如果是特殊标记的空结果,返回None
                if isinstance(cached_result, str) and cached_result == "__EMPTY__":
                    logger.debug(f"🔒 防穿透保护: {func.__name__}")
                    return None
                return cached_result
            
            # 执行原函数
            result = func(*args, **kwargs)
            
            if result is None:
                # 结果为空,设置防穿透标记(较短时间)
                asyncio.run(cache.set(cache_key, "__EMPTY__", empty_ttl))
                logger.debug(f"🛡️ 设置防穿透标记: {func.__name__}")
            else:
                # 正常结果,正常缓存
                asyncio.run(cache.set(cache_key, result, ttl))
            
            return result
        
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper
    
    return decorator

# 缓存批量操作装饰器
def batch_cached(ttl: int = 300, prefix: str = "", key_field: str = "id):
    """
    批量缓存装饰器(适合批量查询接口)
    :param ttl: 缓存过期时间
    :param prefix: 缓存键前缀
    :param key_field: 缓存键使用的字段名
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(ids: list, *args, **kwargs):
            # 结果字典
            results = {}
            # 未命中缓存的需要查询的ID列表
            missing_ids = []
            
            # 先批量查询缓存
            for id in ids:
                cache_key = f"{prefix}:{id}" if prefix else str(id)
                cached_value = await cache.get(cache_key)
                
                if cached_value is not None:
                    results[id] = cached_value
                else:
                    missing_ids.append(id)
            
            # 如果有未命中缓存的,执行原函数查询
            if missing_ids:
                batch_result = await func(missing_ids, *args, **kwargs)
                
                # 将查询结果存入缓存
                for id, value in batch_result.items():
                    cache_key = f"{prefix}:{id}" if prefix else str(id)
                    await cache.set(cache_key, value, ttl)
                    results[id] = value
            
            return results
        
        return wrapper
    
    return decorator

# 缓存依赖注入
async def get_cache():
    """获取缓存实例(依赖注入)"""
    await cache.connect()
    return cache

# 缓存健康检查
async def check_cache_health() -> dict:
    """检查缓存健康状况"""
    try:
        start_time = time.time()
        
        # 测试Redis连接
        redis = await cache.connect()
        pong = await redis.ping()
        
        latency = time.time() - start_time
        
        if pong:
            return {
                "status": "healthy",
                "latency_ms": round(latency * 1000, 2),
                "stats": cache.get_stats(),
                "message": "Redis缓存连接正常"
            }
        else:
            return {
                "status": "unhealthy",
                "latency_ms": round(latency * 1000, 2),
                "message": "Redis Ping命令返回异常"
            }
    
    except Exception as e:
        logger.error(f"❌ 缓存健康检查失败: {e}")
        return {
            "status": "unhealthy",
            "latency_ms": None,
            "message": f"Redis连接失败: {str(e)}"
        }

# 缓存优化建议
def get_cache_optimization_tips() -> list:
    """获取缓存优化建议"""
    tips = [
        "1. 热点数据使用多级缓存策略(进程内缓存 + Redis)",
        "2. 合理设置缓存过期时间,平衡数据新鲜度和缓存命中率",
        "3. 实现防缓存穿透机制,避免恶意请求压垮数据库",
        "4. 对于批量查询接口,使用批量缓存操作减少网络开销",
        "5. 监控缓存命中率,及时调整缓存策略",
        "6. 考虑使用布隆过滤器减少无效缓存查询",
        "7. 对于写多读少的场景,谨慎使用缓存,避免数据不一致",
        "8. 实施缓存预热机制,应对流量高峰",
        "9. 定期清理过期缓存,避免内存泄漏",
        "10. 使用连接池管理Redis连接,避免连接风暴",
    ]
    return tips

3.2 缓存优化实战:商品详情接口

创建 app/routers/products.py,展示完整的缓存应用:

from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, update
from typing import List, Optional, Dict, Any
import time
import logging

from app.database import get_db
from app.cache import cached, cache_with_penetration_protection, batch_cached
from app.models import Product, User, Order
from app.schemas import (
    ProductResponse, 
    PaginationParams, 
    SearchParams,
    PaginatedResponse,
    APIResponse
)

# 配置日志
logger = logging.getLogger(__name__)

router = APIRouter(prefix="/products", tags=["商品管理"])

# ==================== 商品详情接口 ====================

@router.get("/{product_id}", response_model=ProductResponse)
@cache_with_penetration_protection(ttl=600, empty_ttl=30, prefix="product")
async def get_product(
    product_id: int,
    db: AsyncSession = Depends(get_db),
    background_tasks: Optional[BackgroundTasks] = None
):
    """
    获取商品详情(带多级缓存和防穿透保护)
    
    特性:
    1. 进程内缓存:毫秒级响应
    2. Redis缓存:分布式缓存
    3. 防穿透保护:防止恶意请求压垮数据库
    4. 缓存预热:异步更新缓存
    
    缓存策略:
    - 正常数据:缓存10分钟
    - 空数据:缓存30秒(防穿透)
    """
    start_time = time.time()
    
    logger.info(f"🔍 查询商品详情: product_id={product_id}")
    
    try:
        # 查询数据库
        result = await db.execute(
            select(Product).filter(Product.id == product_id)
        )
        product = result.scalar_one_or_none()
        
        # 计算查询耗时
        db_query_time = time.time() - start_time
        
        # 商品不存在
        if not product:
            logger.warning(f"⚠️ 商品不存在: product_id={product_id}")
            return None
        
        # 异步更新商品热度(不影响当前请求)
        if background_tasks:
            background_tasks.add_task(
                update_product_popularity, 
                product_id, 
                increment=1
            )
        
        # 构建响应
        response = ProductResponse(
            id=product.id,
            name=product.name,
            description=product.description,
            price=float(product.price),
            original_price=float(product.original_price) if product.original_price else None,
            discount_rate=product.discount_rate,
            stock=product.stock,
            sku=product.sku,
            category=product.category,
            subcategory=product.subcategory,
            brand=product.brand,
            sales=product.sales,
            popularity=product.popularity,
            rating=product.rating,
            review_count=product.review_count,
            is_published=product.is_published,
            is_featured=product.is_featured,
            created_at=product.created_at,
            updated_at=product.updated_at,
        )
        
        # 添加性能指标
        response_dict = response.dict()
        response_dict["metadata"] = {
            "db_query_time_ms": round(db_query_time * 1000, 2),
            "cached": False,  # 首次查询,未命中缓存
            "request_id": getattr(getattr(Depends(get_db), "context", None), "request_id", None),
            "server_time": time.time(),
        }
        
        logger.info(f"✅ 商品查询成功: product_id={product_id}, 耗时={db_query_time:.3f}s")
        
        return response_dict
        
    except Exception as e:
        logger.error(f"❌ 商品查询失败: product_id={product_id}, 错误={str(e)}")
        raise HTTPException(status_code=500, detail=f"商品查询失败: {str(e)}")

# ==================== 热销商品接口 ====================

@router.get("/hot/{category}", response_model=Dict[str, Any])
@cached(ttl=300, prefix="hot_products")
async def get_hot_products(
    category: str,
    limit: int = Query(20, ge=1, le=100, description="返回数量"),
    db: AsyncSession = Depends(get_db)
):
    """
    获取热销商品列表(带缓存)
    
    缓存策略:
    - 缓存5分钟
    - 按分类缓存,不同分类独立缓存
    """
    try:
        logger.info(f"🔥 查询热销商品: category={category}, limit={limit}")
        
        # 查询热销商品
        result = await db.execute(
            select(Product)
            .filter(Product.category == category, Product.is_published == True)
            .order_by(Product.sales.desc(), Product.popularity.desc())
            .limit(limit)
        )
        
        products = result.scalars().all()
        
        response = {
            "category": category,
            "count": len(products),
            "products": [
                {
                    "id": p.id,
                    "name": p.name,
                    "price": float(p.price),
                    "original_price": float(p.original_price) if p.original_price else None,
                    "discount_rate": p.discount_rate,
                    "sales": p.sales,
                    "popularity": p.popularity,
                    "rating": p.rating,
                    "main_image_url": p.main_image_url,
                }
                for p in products
            ],
            "timestamp": time.time(),
        }
        
        logger.info(f"✅ 热销商品查询成功: category={category}, 数量={len(products)}")
        
        return response
        
    except Exception as e:
        logger.error(f"❌ 热销商品查询失败: category={category}, 错误={str(e)}")
        raise HTTPException(status_code=500, detail=f"热销商品查询失败: {str(e)}")

# ==================== 商品搜索接口 ====================

@router.get("/search", response_model=Dict[str, Any])
async def search_products(
    query: str = Query(..., min_length=1, max_length=200, description="搜索关键词"),
    category: Optional[str] = Query(None, description="分类筛选"),
    min_price: Optional[float] = Query(None, ge=0, description="最低价格"),
    max_price: Optional[float] = Query(None, ge=0, description="最高价格"),
    sort_by: str = Query("popularity", description="排序字段"),
    sort_order: str = Query("desc", description="排序顺序"),
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(20, ge=1, le=100, description="每页大小"),
    db: AsyncSession = Depends(get_db)
):
    """
    商品搜索接口(支持分页和多种筛选条件)
    
    搜索逻辑:
    1. 按关键词在名称和描述中搜索
    2. 支持分类筛选
    3. 支持价格区间筛选
    4. 支持多种排序方式
    5. 支持分页
    """
    try:
        logger.info(f"🔎 商品搜索: query={query}, category={category}, page={page}, size={size}")
        
        # 构建查询
        stmt = select(Product).filter(
            Product.is_published == True,
            (
                Product.name.ilike(f"%{query}%") | 
                Product.description.ilike(f"%{query}%")
            )
        )
        
        # 分类筛选
        if category:
            stmt = stmt.filter(Product.category == category)
        
        # 价格区间筛选
        if min_price is not None:
            stmt = stmt.filter(Product.price >= min_price)
        if max_price is not None:
            stmt = stmt.filter(Product.price <= max_price)
        
        # 排序
        if sort_by == "price":
            order_by = Product.price
        elif sort_by == "sales":
            order_by = Product.sales
        elif sort_by == "rating":
            order_by = Product.rating
        elif sort_by == "created_at":
            order_by = Product.created_at
        else:  # popularity
            order_by = Product.popularity
        
        if sort_order == "asc":
            stmt = stmt.order_by(order_by.asc())
        else:
            stmt = stmt.order_by(order_by.desc())
        
        # 分页
        offset = (page - 1) * size
        stmt = stmt.offset(offset).limit(size)
        
        # 执行查询
        result = await db.execute(stmt)
        products = result.scalars().all()
        
        # 获取总数
        count_stmt = select(func.count()).select_from(Product).filter(
            Product.is_published == True,
            (
                Product.name.ilike(f"%{query}%") | 
                Product.description.ilike(f"%{query}%")
            )
        )
        
        if category:
            count_stmt = count_stmt.filter(Product.category == category)
        if min_price is not None:
            count_stmt = count_stmt.filter(Product.price >= min_price)
        if max_price is not None:
            count_stmt = count_stmt.filter(Product.price <= max_price)
        
        count_result = await db.execute(count_stmt)
        total = count_result.scalar()
        
        # 构建响应
        response = {
            "query": query,
            "category": category,
            "total": total,
            "page": page,
            "size": size,
            "pages": (total + size - 1) // size if size > 0 else 0,
            "products": [
                {
                    "id": p.id,
                    "name": p.name,
                    "price": float(p.price),
                    "original_price": float(p.original_price) if p.original_price else None,
                    "discount_rate": p.discount_rate,
                    "stock": p.stock,
                    "category": p.category,
                    "subcategory": p.subcategory,
                    "brand": p.brand,
                    "sales": p.sales,
                    "rating": p.rating,
                    "main_image_url": p.main_image_url,
                }
                for p in products
            ],
            "filters": {
                "min_price": min_price,
                "max_price": max_price,
                "sort_by": sort_by,
                "sort_order": sort_order,
            },
            "timestamp": time.time(),
        }
        
        logger.info(f"✅ 商品搜索成功: query={query}, 总数={total}, 返回数量={len(products)}")
        
        return response
        
    except Exception as e:
        logger.error(f"❌ 商品搜索失败: query={query}, 错误={str(e)}")
        raise HTTPException(status_code=500, detail=f"商品搜索失败: {str(e)}")

# ==================== 批量商品查询接口 ====================

@router.post("/batch", response_model=Dict[str, Any])
@batch_cached(ttl=300, prefix="product", key_field="id")
async def get_products_batch(
    ids: List[int],
    db: AsyncSession = Depends(get_db)
):
    """
    批量查询商品详情(带批量缓存优化)
    
    优化特性:
    1. 批量缓存查询:减少网络开销
    2. 批量数据库查询:减少连接次数
    3. 结果聚合:一次性返回所有结果
    """
    try:
        logger.info(f"📦 批量查询商品: ids={ids[:10]}{'...' if len(ids) > 10 else ''}")
        
        if not ids:
            return {"products": {}}
        
        # 批量查询数据库
        result = await db.execute(
            select(Product).filter(Product.id.in_(ids))
        )
        
        products = result.scalars().all()
        
        # 构建响应字典
        response = {}
        for product in products:
            response[product.id] = {
                "id": product.id,
                "name": product.name,
                "price": float(product.price),
                "original_price": float(product.original_price) if product.original_price else None,
                "discount_rate": product.discount_rate,
                "stock": product.stock,
                "category": product.category,
                "brand": product.brand,
                "sales": product.sales,
                "rating": product.rating,
            }
        
        # 处理未找到的商品(防穿透机制已处理)
        found_ids = set(response.keys())
        for id in ids:
            if id not in found_ids:
                # 防穿透机制确保不会频繁查询数据库
                pass
        
        logger.info(f"✅ 批量查询成功: 请求数量={len(ids)}, 找到数量={len(response)}")
        
        return {"products": response}
        
    except Exception as e:
        logger.error(f"❌ 批量查询失败: ids={ids}, 错误={str(e)}")
        raise HTTPException(status_code=500, detail=f"批量查询失败: {str(e)}")

# ==================== 商品列表接口 ====================

@router.get("/", response_model=Dict[str, Any])
async def list_products(
    category: Optional[str] = Query(None, description="分类筛选"),
    featured: Optional[bool] = Query(None, description="是否推荐"),
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(20, ge=1, le=100, description="每页大小"),
    db: AsyncSession = Depends(get_db)
):
    """
    商品列表接口(支持分页和筛选)
    """
    try:
        logger.info(f"📋 查询商品列表: category={category}, featured={featured}, page={page}, size={size}")
        
        # 构建查询
        stmt = select(Product).filter(Product.is_published == True)
        
        if category:
            stmt = stmt.filter(Product.category == category)
        
        if featured is not None:
            stmt = stmt.filter(Product.is_featured == featured)
        
        # 排序(默认按创建时间倒序)
        stmt = stmt.order_by(Product.created_at.desc())
        
        # 分页
        offset = (page - 1) * size
        stmt = stmt.offset(offset).limit(size)
        
        # 执行查询
        result = await db.execute(stmt)
        products = result.scalars().all()
        
        # 获取总数
        count_stmt = select(func.count()).select_from(Product).filter(Product.is_published == True)
        
        if category:
            count_stmt = count_stmt.filter(Product.category == category)
        
        if featured is not None:
            count_stmt = count_stmt.filter(Product.is_featured == featured)
        
        count_result = await db.execute(count_stmt)
        total = count_result.scalar()
        
        # 构建响应
        response = {
            "total": total,
            "page": page,
            "size": size,
            "pages": (total + size - 1) // size if size > 0 else 0,
            "filters": {
                "category": category,
                "featured": featured,
            },
            "products": [
                {
                    "id": p.id,
                    "name": p.name,
                    "price": float(p.price),
                    "original_price": float(p.original_price) if p.original_price else None,
                    "discount_rate": p.discount_rate,
                    "stock": p.stock,
                    "category": p.category,
                    "brand": p.brand,
                    "sales": p.sales,
                    "rating": p.rating,
                    "main_image_url": p.main_image_url,
                }
                for p in products
            ],
            "timestamp": time.time(),
        }
        
        logger.info(f"✅ 商品列表查询成功: 总数={total}, 返回数量={len(products)}")
        
        return response
        
    except Exception as e:
        logger.error(f"❌ 商品列表查询失败: category={category}, 错误={str(e)}")
        raise HTTPException(status_code=500, detail=f"商品列表查询失败: {str(e)}")

# ==================== 辅助函数 ====================

async def update_product_popularity(product_id: int, increment: int = 1):
    """
    异步更新商品热度(后台任务)
    """
    try:
        from app.cache import cache
        
        # 这里可以连接到数据库,但为了简单起见,我们直接使用缓存
        cache_key = f"product_popularity:{product_id}"
        
        # 从缓存获取当前热度
        current_popularity = await cache.get(cache_key, default=0)
        
        # 更新热度
        new_popularity = current_popularity + increment
        
        # 存入缓存(1小时过期)
        await cache.set(cache_key, new_popularity, ttl=3600)
        
        # 可以定期同步到数据库
        logger.debug(f"📈 更新商品热度: product_id={product_id}, 新热度={new_popularity}")
        
    except Exception as e:
        logger.error(f"❌ 更新商品热度失败: product_id={product_id}, 错误={str(e)}")

def generate_cache_key_for_product(product_id: int) -> str:
    """
    为商品生成缓存键
    """
    return f"product:{product_id}"

def invalidate_product_cache(product_id: int):
    """
    使商品缓存失效(当商品信息更新时调用)
    """
    try:
        from app.cache import cache
        
        cache_key = generate_cache_key_for_product(product_id)
        
        # 删除缓存
        asyncio.create_task(cache.delete(cache_key))
        
        # 删除相关的批量缓存
        asyncio.create_task(cache.delete("hot_products:*"))
        
        logger.info(f"🗑️ 商品缓存已失效: product_id={product_id}")
        
    except Exception as e:
        logger.error(f"❌ 失效商品缓存失败: product_id={product_id}, 错误={str(e)}")

# ==================== 商品统计接口 ====================

@router.get("/stats/summary")
async def get_products_stats_summary(db: AsyncSession = Depends(get_db)):
    """
    获取商品统计摘要
    """
    try:
        logger.info("📊 查询商品统计摘要")
        
        # 查询商品总数
        total_result = await db.execute(
            select(func.count()).select_from(Product)
        )
        total = total_result.scalar()
        
        # 查询已发布商品数
        published_result = await db.execute(
            select(func.count()).select_from(Product).filter(Product.is_published == True)
        )
        published = published_result.scalar()
        
        # 查询推荐商品数
        featured_result = await db.execute(
            select(func.count()).select_from(Product).filter(Product.is_featured == True)
        )
        featured = featured_result.scalar()
        
        # 查询库存总量
        stock_result = await db.execute(
            select(func.sum(Product.stock)).select_from(Product)
        )
        total_stock = stock_result.scalar() or 0
        
        # 查询总销量
        sales_result = await db.execute(
            select(func.sum(Product.sales)).select_from(Product)
        )
        total_sales = sales_result.scalar() or 0
        
        response = {
            "total": total,
            "published": published,
            "featured": featured,
            "total_stock": total_stock,
            "total_sales": total_sales,
            "publish_rate": round(published / total * 100, 2) if total > 0 else 0,
            "featured_rate": round(featured / total * 100, 2) if total > 0 else 0,
            "timestamp": time.time(),
        }
        
        logger.info(f"✅ 商品统计摘要查询成功: 总数={total}, 已发布={published}")
        
        return response
        
    except Exception as e:
        logger.error(f"❌ 商品统计摘要查询失败: 错误={str(e)}")
        raise HTTPException(status_code=500, detail=f"商品统计摘要查询失败: {str(e)}")