AI网关架构设计:统一管理多LLM提供商的工程实践

2 阅读1分钟

引言

当一个企业同时使用OpenAI、Anthropic、Azure OpenAI、本地部署的LLaMA……如何统一管理这些提供商?如何实现智能路由、故障转移、成本控制和访问审计?

AI网关(AI Gateway) 正是为这一需求而生的基础设施组件。它在业务应用和LLM提供商之间架设统一的代理层,解决多提供商管理的复杂性。本文将深度解析AI网关的设计架构和工程实现。


一、AI网关的核心价值

1.1 没有AI网关的痛苦

现状(没有AI网关):
应用A → OpenAI API(各自独立的SDK)
应用B → Anthropic API(各自独立的SDK)
应用C → Azure OpenAI(各自独立的SDK)
应用D → 本地LLM(各自独立的SDK)

问题:
- 每个应用独立管理API密钥和成本
- 某个提供商宕机,所有依赖它的应用都挂
- 无法全局限速和配额管理
- 没有统一的访问日志和审计能力
- 切换提供商需要修改每个应用的代码

1.2 AI网关的核心能力

理想状态(有AI网关):

应用A ─┐
应用B ─┤→ [AI Gateway] ─→ OpenAI
应用C ─┤             ├─→ Anthropic
应用D ─┘             ├─→ Azure OpenAI
                      └─→ 本地LLM

AI网关提供:
✅ 统一API接口(OpenAI兼容格式)
✅ 智能路由(基于成本、延迟、能力)
✅ 自动故障转移
✅ 全局速率限制
✅ 访问审计和成本归因
✅ 响应缓存
✅ 敏感信息脱敏

二、核心架构设计

2.1 网关整体架构

from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import StreamingResponse
from typing import Optional, AsyncIterator
import asyncio
import time
import json
import logging

app = FastAPI(title="AI Gateway")
logger = logging.getLogger(__name__)

class AIGateway:
    """AI网关核心组件"""
    
    def __init__(self):
        self.router = ProviderRouter()
        self.rate_limiter = RateLimiter()
        self.cache = ResponseCache()
        self.auditor = AuditLogger()
        self.cost_tracker = CostTracker()

2.2 提供商路由器

from dataclasses import dataclass, field
from enum import Enum
import httpx
import random

class RoutingStrategy(Enum):
    COST_OPTIMIZED = "cost_optimized"      # 最低成本
    LATENCY_OPTIMIZED = "latency_optimized" # 最低延迟
    CAPABILITY_BASED = "capability_based"   # 按能力路由
    ROUND_ROBIN = "round_robin"             # 轮询
    FAILOVER = "failover"                   # 故障转移

@dataclass
class ProviderConfig:
    name: str
    api_base: str
    api_key: str
    model_mapping: dict  # 统一模型名 → 提供商模型名
    
    # 能力声明
    max_context_tokens: int = 128000
    supports_vision: bool = False
    supports_function_calling: bool = True
    
    # 价格(per 1M tokens)
    input_price_per_1m: float = 3.0
    output_price_per_1m: float = 15.0
    
    # 实时指标(动态更新)
    avg_latency_ms: float = 1000
    error_rate: float = 0.0
    is_available: bool = True
    
    # 权重(用于加权轮询)
    weight: int = 1

class ProviderRouter:
    """智能提供商路由器"""
    
    def __init__(self):
        self.providers: dict[str, ProviderConfig] = {}
        self._round_robin_idx = 0
    
    def register_provider(self, config: ProviderConfig):
        self.providers[config.name] = config
    
    def select_provider(
        self,
        request_context: dict,
        strategy: RoutingStrategy = RoutingStrategy.COST_OPTIMIZED
    ) -> Optional[ProviderConfig]:
        """根据策略选择最优提供商"""
        
        available = [
            p for p in self.providers.values()
            if p.is_available and self._meets_requirements(p, request_context)
        ]
        
        if not available:
            return None
        
        if strategy == RoutingStrategy.COST_OPTIMIZED:
            return min(available, key=lambda p: p.input_price_per_1m)
        
        elif strategy == RoutingStrategy.LATENCY_OPTIMIZED:
            return min(available, key=lambda p: p.avg_latency_ms)
        
        elif strategy == RoutingStrategy.ROUND_ROBIN:
            provider = available[self._round_robin_idx % len(available)]
            self._round_robin_idx += 1
            return provider
        
        elif strategy == RoutingStrategy.CAPABILITY_BASED:
            return self._route_by_capability(available, request_context)
        
        elif strategy == RoutingStrategy.FAILOVER:
            # 按权重排序,选第一个可用的
            return sorted(available, key=lambda p: -p.weight)[0]
        
        return available[0]
    
    def _meets_requirements(
        self,
        provider: ProviderConfig,
        request_context: dict
    ) -> bool:
        """检查提供商是否满足请求需求"""
        
        # 检查context长度
        token_estimate = request_context.get("estimated_tokens", 0)
        if token_estimate > provider.max_context_tokens:
            return False
        
        # 检查是否需要视觉能力
        if request_context.get("has_images") and not provider.supports_vision:
            return False
        
        # 检查是否需要函数调用
        if request_context.get("has_tools") and not provider.supports_function_calling:
            return False
        
        return True
    
    def _route_by_capability(
        self,
        providers: list,
        context: dict
    ) -> ProviderConfig:
        """按任务复杂度路由"""
        
        complexity = context.get("complexity_score", 0.5)
        
        if complexity > 0.8:
            # 复杂任务:选最强模型(通常也最贵)
            return max(providers, key=lambda p: p.max_context_tokens)
        elif complexity < 0.3:
            # 简单任务:选最便宜的
            return min(providers, key=lambda p: p.input_price_per_1m)
        else:
            # 中等任务:按延迟
            return min(providers, key=lambda p: p.avg_latency_ms)
    
    async def update_provider_metrics(
        self,
        provider_name: str,
        latency_ms: float,
        success: bool
    ):
        """更新提供商实时指标(指数移动平均)"""
        
        if provider_name not in self.providers:
            return
        
        p = self.providers[provider_name]
        alpha = 0.1  # 平滑系数
        
        # 更新平均延迟
        p.avg_latency_ms = alpha * latency_ms + (1 - alpha) * p.avg_latency_ms
        
        # 更新错误率
        error_val = 0 if success else 1
        p.error_rate = alpha * error_val + (1 - alpha) * p.error_rate
        
        # 如果错误率过高,标记为不可用
        if p.error_rate > 0.5:
            p.is_available = False
            logger.warning(f"提供商 {provider_name} 错误率过高({p.error_rate:.1%}),暂时下线")
            
            # 5分钟后自动重试
            asyncio.create_task(self._schedule_recovery(provider_name, 300))
    
    async def _schedule_recovery(self, provider_name: str, delay_seconds: int):
        await asyncio.sleep(delay_seconds)
        if provider_name in self.providers:
            self.providers[provider_name].is_available = True
            self.providers[provider_name].error_rate = 0.0
            logger.info(f"提供商 {provider_name} 已恢复上线")

2.3 故障转移机制

class FailoverHandler:
    """自动故障转移处理器"""
    
    def __init__(self, router: ProviderRouter, max_retries: int = 3):
        self.router = router
        self.max_retries = max_retries
    
    async def execute_with_failover(
        self,
        request: dict,
        excluded_providers: set = None
    ) -> dict:
        """执行请求,自动故障转移"""
        
        excluded = excluded_providers or set()
        errors = []
        
        for attempt in range(self.max_retries):
            provider = self.router.select_provider(
                request,
                strategy=RoutingStrategy.FAILOVER
            )
            
            if provider is None or provider.name in excluded:
                break
            
            try:
                start = time.time()
                result = await self._call_provider(provider, request)
                latency = (time.time() - start) * 1000
                
                # 更新成功指标
                await self.router.update_provider_metrics(
                    provider.name, latency, success=True
                )
                
                result["_provider_used"] = provider.name
                result["_latency_ms"] = latency
                return result
            
            except Exception as e:
                latency = (time.time() - start) * 1000
                
                await self.router.update_provider_metrics(
                    provider.name, latency, success=False
                )
                
                errors.append({
                    "provider": provider.name,
                    "error": str(e)
                })
                excluded.add(provider.name)
                
                logger.warning(
                    f"提供商 {provider.name} 失败: {e},"
                    f"尝试故障转移({attempt+1}/{self.max_retries})"
                )
        
        raise Exception(f"所有提供商均失败: {errors}")
    
    async def _call_provider(
        self,
        provider: ProviderConfig,
        request: dict
    ) -> dict:
        """调用指定提供商"""
        
        # 转换模型名称
        model = request.get("model", "gpt-4o")
        provider_model = provider.model_mapping.get(model, model)
        provider_request = {**request, "model": provider_model}
        
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{provider.api_base}/chat/completions",
                json=provider_request,
                headers={
                    "Authorization": f"Bearer {provider.api_key}",
                    "Content-Type": "application/json"
                }
            )
            response.raise_for_status()
            return response.json()

三、速率限制与配额管理

import redis.asyncio as redis
from typing import Optional

class RateLimiter:
    """基于Redis的分布式速率限制"""
    
    def __init__(self, redis_url: str):
        self.redis = redis.from_url(redis_url)
    
    async def check_and_consume(
        self,
        tenant_id: str,
        tokens_needed: int,
        limits: dict
    ) -> dict:
        """
        检查并消耗配额
        
        Returns:
            {"allowed": bool, "remaining": int, "reset_at": int}
        """
        pipe = self.redis.pipeline()
        
        now = int(time.time())
        
        # 分钟级限制
        minute_key = f"rl:{tenant_id}:min:{now // 60}"
        # 天级限制
        day_key = f"rl:{tenant_id}:day:{now // 86400}"
        # 月级Token配额
        month_key = f"rl:{tenant_id}:tokens:month:{now // (86400 * 30)}"
        
        async with self.redis.pipeline() as pipe:
            pipe.incrby(minute_key, tokens_needed)
            pipe.expire(minute_key, 120)
            pipe.incrby(day_key, 1)
            pipe.expire(day_key, 172800)
            pipe.incrby(month_key, tokens_needed)
            pipe.expire(month_key, 5184000)
            
            results = await pipe.execute()
        
        minute_total = results[0]
        day_requests = results[2]
        month_tokens = results[4]
        
        # 检查是否超限
        if minute_total > limits.get("tokens_per_minute", 100000):
            return {
                "allowed": False,
                "reason": "minute_token_limit",
                "remaining": 0,
                "reset_at": (now // 60 + 1) * 60
            }
        
        if day_requests > limits.get("requests_per_day", 10000):
            return {
                "allowed": False,
                "reason": "daily_request_limit",
                "remaining": 0,
                "reset_at": (now // 86400 + 1) * 86400
            }
        
        if month_tokens > limits.get("tokens_per_month", 10000000):
            return {
                "allowed": False,
                "reason": "monthly_token_limit",
                "remaining": 0,
                "reset_at": (now // (86400 * 30) + 1) * 86400 * 30
            }
        
        return {
            "allowed": True,
            "remaining": limits.get("tokens_per_minute", 100000) - minute_total,
            "reset_at": (now // 60 + 1) * 60
        }

四、FastAPI接口层

from fastapi import FastAPI, Request, Header
from fastapi.responses import JSONResponse, StreamingResponse

app = FastAPI(title="AI Gateway v1.0")

@app.post("/v1/chat/completions")
async def chat_completions(
    request: Request,
    authorization: str = Header(...)
):
    """OpenAI兼容的统一入口"""
    
    # 解析租户
    api_key = authorization.replace("Bearer ", "")
    tenant = await validate_api_key(api_key)
    if not tenant:
        raise HTTPException(status_code=401, detail="Invalid API key")
    
    body = await request.json()
    
    # 速率限制检查
    rate_check = await gateway.rate_limiter.check_and_consume(
        tenant["id"],
        tokens_needed=estimate_tokens(body),
        limits=tenant.get("limits", {})
    )
    
    if not rate_check["allowed"]:
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "type": "rate_limit_exceeded",
                    "message": f"配额超限: {rate_check['reason']}",
                    "reset_at": rate_check["reset_at"]
                }
            }
        )
    
    # 路由并执行
    failover = FailoverHandler(gateway.router)
    
    try:
        result = await failover.execute_with_failover(body)
        
        # 审计日志
        await gateway.auditor.log({
            "tenant_id": tenant["id"],
            "model": body.get("model"),
            "provider": result.get("_provider_used"),
            "tokens": result.get("usage", {}),
            "latency_ms": result.get("_latency_ms")
        })
        
        return JSONResponse(content=result)
    
    except Exception as e:
        raise HTTPException(status_code=503, detail=str(e))

五、总结

AI网关是企业AI基础设施的核心中间件。关键设计要点:

  1. 统一接口:采用OpenAI兼容格式,业务应用无需改动
  2. 多策略路由:成本、延迟、能力三维度智能选择
  3. 自动故障转移:指数移动平均监测健康状态,自动切换
  4. 分布式限速:Redis支撑多实例部署,支持分钟/天/月三级配额
  5. 完整审计:每次调用记录租户、提供商、成本,支持费用分摊

从长远看,AI网关会逐渐向AI可观测性平台演进,集成Prompt管理、A/B测试、模型评估等能力,成为企业AI工程化的核心基础设施。