Introduction
When an enterprise uses OpenAI, Anthropic, Azure OpenAI, and a locally deployed LLaMA all at once, how do you manage these providers in a unified way? How do you implement intelligent routing, failover, cost control, and access auditing?
The AI Gateway is the infrastructure component built for exactly this need. It sits as a unified proxy layer between business applications and LLM providers, taming the complexity of multi-provider management. This article walks through the design and engineering of an AI gateway in depth.
1. The Core Value of an AI Gateway
1.1 Life Without an AI Gateway
Status quo (no AI gateway):
App A → OpenAI API (its own SDK)
App B → Anthropic API (its own SDK)
App C → Azure OpenAI (its own SDK)
App D → local LLM (its own SDK)
Problems:
- Each application manages its own API keys and costs
- When one provider goes down, every application that depends on it goes down too
- No global rate limiting or quota management
- No unified access logs or audit capability
- Switching providers means changing code in every application
1.2 What an AI Gateway Provides
Ideal state (with an AI gateway):
App A ─┐
App B ─┤→ [AI Gateway] ─→ OpenAI
App C ─┤       ├─→ Anthropic
App D ─┘       ├─→ Azure OpenAI
               └─→ local LLM
The AI gateway provides:
✅ A unified API surface (OpenAI-compatible format)
✅ Intelligent routing (by cost, latency, capability)
✅ Automatic failover
✅ Global rate limiting
✅ Access auditing and cost attribution
✅ Response caching
✅ Redaction of sensitive data
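Because the gateway speaks the OpenAI wire format, existing clients only need to change their base URL. A minimal client-side sketch, assuming the gateway is reachable at a hypothetical internal address and issues its own tenant API keys:

# Point the official openai SDK at the gateway instead of api.openai.com.
# base_url and api_key here are placeholder assumptions.
from openai import OpenAI

client = OpenAI(
    base_url="http://gateway.internal/v1",  # the gateway's OpenAI-compatible endpoint
    api_key="tenant-api-key",               # a tenant key issued by the gateway
)

resp = client.chat.completions.create(
    model="gpt-4o",  # unified model name; the gateway maps it per provider
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)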
2. Core Architecture Design
2.1 Overall Gateway Architecture
from fastapi import FastAPI, Request, HTTPException
from typing import Optional
import asyncio
import time
import json
import logging
app = FastAPI(title="AI Gateway")
logger = logging.getLogger(__name__)
class AIGateway:
    """Core components of the AI gateway."""
    def __init__(self, redis_url: str = "redis://localhost:6379"):  # default assumes a local Redis
        self.router = ProviderRouter()
        self.rate_limiter = RateLimiter(redis_url)  # RateLimiter needs a Redis URL (see section 3)
        self.cache = ResponseCache()
        self.auditor = AuditLogger()
        self.cost_tracker = CostTracker()
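ResponseCache, AuditLogger, and CostTracker are referenced above but not defined in this article. A minimal in-memory sketch of the first two, purely illustrative (a production version would back these with Redis and a data warehouse):

# Illustrative sketches, not definitive implementations.
import hashlib
import json
from typing import Optional

class ResponseCache:
    """In-memory response cache keyed by a hash of the request body."""
    def __init__(self):
        self._store: dict[str, dict] = {}

    def _key(self, body: dict) -> str:
        return hashlib.sha256(json.dumps(body, sort_keys=True).encode()).hexdigest()

    def get(self, body: dict) -> Optional[dict]:
        return self._store.get(self._key(body))

    def put(self, body: dict, response: dict):
        self._store[self._key(body)] = response

class AuditLogger:
    """Writes one structured record per gateway call."""
    async def log(self, record: dict):
        logging.getLogger("audit").info(json.dumps(record, ensure_ascii=False))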
2.2 The Provider Router
from dataclasses import dataclass
from enum import Enum
import httpx

class RoutingStrategy(Enum):
    COST_OPTIMIZED = "cost_optimized"        # lowest cost
    LATENCY_OPTIMIZED = "latency_optimized"  # lowest latency
    CAPABILITY_BASED = "capability_based"    # route by capability
    ROUND_ROBIN = "round_robin"              # rotate through providers
    FAILOVER = "failover"                    # primary/backup failover

@dataclass
class ProviderConfig:
    name: str
    api_base: str
    api_key: str
    model_mapping: dict  # unified model name → provider model name
    # Capability declarations
    max_context_tokens: int = 128000
    supports_vision: bool = False
    supports_function_calling: bool = True
    # Pricing (per 1M tokens)
    input_price_per_1m: float = 3.0
    output_price_per_1m: float = 15.0
    # Live metrics (updated dynamically)
    avg_latency_ms: float = 1000
    error_rate: float = 0.0
    is_available: bool = True
    # Weight (used for failover priority ordering)
    weight: int = 1
class ProviderRouter:
    """Routes each request to the best provider."""
    def __init__(self):
        self.providers: dict[str, ProviderConfig] = {}
        self._round_robin_idx = 0

    def register_provider(self, config: ProviderConfig):
        self.providers[config.name] = config

    def select_provider(
        self,
        request_context: dict,
        strategy: RoutingStrategy = RoutingStrategy.COST_OPTIMIZED,
        exclude: Optional[set] = None,
    ) -> Optional[ProviderConfig]:
        """Select the best provider for this request under the given strategy."""
        exclude = exclude or set()  # providers to skip, e.g. ones that already failed
        available = [
            p for p in self.providers.values()
            if p.is_available
            and p.name not in exclude
            and self._meets_requirements(p, request_context)
        ]
        if not available:
            return None
        if strategy == RoutingStrategy.COST_OPTIMIZED:
            return min(available, key=lambda p: p.input_price_per_1m)
        elif strategy == RoutingStrategy.LATENCY_OPTIMIZED:
            return min(available, key=lambda p: p.avg_latency_ms)
        elif strategy == RoutingStrategy.ROUND_ROBIN:
            provider = available[self._round_robin_idx % len(available)]
            self._round_robin_idx += 1
            return provider
        elif strategy == RoutingStrategy.CAPABILITY_BASED:
            return self._route_by_capability(available, request_context)
        elif strategy == RoutingStrategy.FAILOVER:
            # Sort by weight, highest first, and take the best available
            return sorted(available, key=lambda p: -p.weight)[0]
        return available[0]
    def _meets_requirements(
        self,
        provider: ProviderConfig,
        request_context: dict
    ) -> bool:
        """Check whether a provider can serve this request."""
        # Context-length check
        token_estimate = request_context.get("estimated_tokens", 0)
        if token_estimate > provider.max_context_tokens:
            return False
        # Vision capability check
        if request_context.get("has_images") and not provider.supports_vision:
            return False
        # Function-calling capability check
        if request_context.get("has_tools") and not provider.supports_function_calling:
            return False
        return True
    def _route_by_capability(
        self,
        providers: list,
        context: dict
    ) -> ProviderConfig:
        """Route by task complexity."""
        complexity = context.get("complexity_score", 0.5)
        if complexity > 0.8:
            # Complex task: pick the most capable model (usually also the priciest);
            # max context length serves here as a rough proxy for capability
            return max(providers, key=lambda p: p.max_context_tokens)
        elif complexity < 0.3:
            # Simple task: pick the cheapest
            return min(providers, key=lambda p: p.input_price_per_1m)
        else:
            # Medium task: pick the lowest latency
            return min(providers, key=lambda p: p.avg_latency_ms)
    async def update_provider_metrics(
        self,
        provider_name: str,
        latency_ms: float,
        success: bool
    ):
        """Update live provider metrics with an exponential moving average."""
        if provider_name not in self.providers:
            return
        p = self.providers[provider_name]
        alpha = 0.1  # smoothing factor
        # Update average latency
        p.avg_latency_ms = alpha * latency_ms + (1 - alpha) * p.avg_latency_ms
        # Update error rate
        error_val = 0 if success else 1
        p.error_rate = alpha * error_val + (1 - alpha) * p.error_rate
        # Take the provider offline if its error rate climbs too high
        if p.error_rate > 0.5:
            p.is_available = False
            logger.warning(
                f"Provider {provider_name} error rate too high ({p.error_rate:.1%}), taking it offline"
            )
            # Automatically retry after 5 minutes
            asyncio.create_task(self._schedule_recovery(provider_name, 300))

    async def _schedule_recovery(self, provider_name: str, delay_seconds: int):
        await asyncio.sleep(delay_seconds)
        if provider_name in self.providers:
            self.providers[provider_name].is_available = True
            self.providers[provider_name].error_rate = 0.0
            logger.info(f"Provider {provider_name} is back online")
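Wiring the router up is a matter of registering each provider once at startup. A sketch with placeholder keys and illustrative prices, assuming each provider exposes an OpenAI-compatible chat/completions endpoint (which is what _call_provider below expects):

# Registering providers at startup. Keys, prices, weights, and the Anthropic
# model name are illustrative assumptions, not authoritative values.
router = ProviderRouter()
router.register_provider(ProviderConfig(
    name="openai",
    api_base="https://api.openai.com/v1",
    api_key="sk-...",
    model_mapping={"gpt-4o": "gpt-4o"},
    supports_vision=True,
    input_price_per_1m=2.5,
    output_price_per_1m=10.0,
    weight=10,  # preferred primary under FAILOVER
))
router.register_provider(ProviderConfig(
    name="anthropic",
    api_base="https://api.anthropic.com/v1",
    api_key="sk-ant-...",
    model_mapping={"gpt-4o": "claude-sonnet-example"},  # hypothetical mapping
    supports_vision=True,
    input_price_per_1m=3.0,
    output_price_per_1m=15.0,
    weight=5,  # fallback
))

# Pick the cheapest provider able to serve a 2000-token text request
provider = router.select_provider(
    {"estimated_tokens": 2000, "has_images": False},
    strategy=RoutingStrategy.COST_OPTIMIZED,
)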
2.3 The Failover Mechanism
class FailoverHandler:
    """Executes a request with automatic failover across providers."""
    def __init__(self, router: ProviderRouter, max_retries: int = 3):
        self.router = router
        self.max_retries = max_retries

    async def execute_with_failover(
        self,
        request: dict,
        excluded_providers: set = None
    ) -> dict:
        """Execute the request, failing over to the next provider on error."""
        excluded = excluded_providers or set()
        errors = []
        for attempt in range(self.max_retries):
            # Skip providers that already failed during this request,
            # so each retry lands on a fresh provider
            provider = self.router.select_provider(
                request,
                strategy=RoutingStrategy.FAILOVER,
                exclude=excluded,
            )
            if provider is None:
                break
            start = time.time()
            try:
                result = await self._call_provider(provider, request)
                latency = (time.time() - start) * 1000
                # Record the success
                await self.router.update_provider_metrics(
                    provider.name, latency, success=True
                )
                result["_provider_used"] = provider.name
                result["_latency_ms"] = latency
                return result
            except Exception as e:
                latency = (time.time() - start) * 1000
                await self.router.update_provider_metrics(
                    provider.name, latency, success=False
                )
                errors.append({
                    "provider": provider.name,
                    "error": str(e)
                })
                excluded.add(provider.name)
                logger.warning(
                    f"Provider {provider.name} failed: {e}, "
                    f"failing over ({attempt + 1}/{self.max_retries})"
                )
        raise Exception(f"All providers failed: {errors}")
    async def _call_provider(
        self,
        provider: ProviderConfig,
        request: dict
    ) -> dict:
        """Call one specific provider (assumes an OpenAI-compatible endpoint)."""
        # Translate the unified model name into the provider's model name
        model = request.get("model", "gpt-4o")
        provider_model = provider.model_mapping.get(model, model)
        provider_request = {**request, "model": provider_model}
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{provider.api_base}/chat/completions",
                json=provider_request,
                headers={
                    "Authorization": f"Bearer {provider.api_key}",
                    "Content-Type": "application/json"
                }
            )
            response.raise_for_status()
            return response.json()
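Used standalone, the handler looks like this; a sketch where router is the ProviderRouter configured in the registration example above:

# Driving the failover handler directly, outside the HTTP layer.
async def demo_failover():
    failover = FailoverHandler(router, max_retries=3)
    result = await failover.execute_with_failover({
        "model": "gpt-4o",
        "messages": [{"role": "user", "content": "ping"}],
    })
    # _provider_used / _latency_ms are metadata fields added by the gateway
    print(result["_provider_used"], result["_latency_ms"])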
3. Rate Limiting and Quota Management
import time
import redis.asyncio as redis

class RateLimiter:
    """Distributed rate limiting backed by Redis."""
    def __init__(self, redis_url: str):
        self.redis = redis.from_url(redis_url)

    async def check_and_consume(
        self,
        tenant_id: str,
        tokens_needed: int,
        limits: dict
    ) -> dict:
        """
        Check and consume quota in a single pass.

        Note: counters are incremented before the limit check, so a denied
        request still consumes quota within its window. This keeps the
        pipeline to one round trip; a strict check-then-consume would need
        an atomic Lua script.

        Returns:
            {"allowed": bool, "remaining": int, "reset_at": int}
        """
        now = int(time.time())
        # Per-minute token window
        minute_key = f"rl:{tenant_id}:min:{now // 60}"
        # Per-day request window
        day_key = f"rl:{tenant_id}:day:{now // 86400}"
        # Per-month token quota (month approximated as 30 days)
        month_key = f"rl:{tenant_id}:tokens:month:{now // (86400 * 30)}"
        async with self.redis.pipeline() as pipe:
            pipe.incrby(minute_key, tokens_needed)
            pipe.expire(minute_key, 120)
            pipe.incrby(day_key, 1)
            pipe.expire(day_key, 172800)
            pipe.incrby(month_key, tokens_needed)
            pipe.expire(month_key, 5184000)
            results = await pipe.execute()
        minute_total = results[0]
        day_requests = results[2]
        month_tokens = results[4]
        # Limit checks
        if minute_total > limits.get("tokens_per_minute", 100000):
            return {
                "allowed": False,
                "reason": "minute_token_limit",
                "remaining": 0,
                "reset_at": (now // 60 + 1) * 60
            }
        if day_requests > limits.get("requests_per_day", 10000):
            return {
                "allowed": False,
                "reason": "daily_request_limit",
                "remaining": 0,
                "reset_at": (now // 86400 + 1) * 86400
            }
        if month_tokens > limits.get("tokens_per_month", 10000000):
            return {
                "allowed": False,
                "reason": "monthly_token_limit",
                "remaining": 0,
                "reset_at": (now // (86400 * 30) + 1) * 86400 * 30
            }
        return {
            "allowed": True,
            "remaining": limits.get("tokens_per_minute", 100000) - minute_total,
            "reset_at": (now // 60 + 1) * 60
        }
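A quick usage sketch, with example limits and a local Redis URL as assumptions:

# Exercising the limiter for one tenant request.
async def demo_rate_limit():
    limiter = RateLimiter("redis://localhost:6379")
    verdict = await limiter.check_and_consume(
        tenant_id="tenant-42",
        tokens_needed=1200,  # estimated tokens for this call
        limits={
            "tokens_per_minute": 100000,
            "requests_per_day": 10000,
            "tokens_per_month": 10_000_000,
        },
    )
    if not verdict["allowed"]:
        print("throttled:", verdict["reason"], "retry at", verdict["reset_at"])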
4. The FastAPI Interface Layer
from fastapi import FastAPI, Request, Header, HTTPException
from fastapi.responses import JSONResponse

app = FastAPI(title="AI Gateway v1.0")
gateway = AIGateway()  # wires up router, rate limiter, cache, auditor

@app.post("/v1/chat/completions")
async def chat_completions(
    request: Request,
    authorization: str = Header(...)
):
    """OpenAI-compatible unified entry point."""
    # Resolve the tenant from the API key
    api_key = authorization.replace("Bearer ", "")
    tenant = await validate_api_key(api_key)
    if not tenant:
        raise HTTPException(status_code=401, detail="Invalid API key")
    body = await request.json()
    # Rate-limit check
    rate_check = await gateway.rate_limiter.check_and_consume(
        tenant["id"],
        tokens_needed=estimate_tokens(body),
        limits=tenant.get("limits", {})
    )
    if not rate_check["allowed"]:
        return JSONResponse(
            status_code=429,
            content={
                "error": {
                    "type": "rate_limit_exceeded",
                    "message": f"Quota exceeded: {rate_check['reason']}",
                    "reset_at": rate_check["reset_at"]
                }
            }
        )
    # Route and execute with failover
    failover = FailoverHandler(gateway.router)
    try:
        result = await failover.execute_with_failover(body)
        # Audit log
        await gateway.auditor.log({
            "tenant_id": tenant["id"],
            "model": body.get("model"),
            "provider": result.get("_provider_used"),
            "tokens": result.get("usage", {}),
            "latency_ms": result.get("_latency_ms")
        })
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=503, detail=str(e))
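The endpoint above leans on two helpers that this article does not define: validate_api_key and estimate_tokens. A minimal sketch of both, assuming tenants live in an in-memory table and using a crude length-based token estimate:

# Hypothetical helpers; a real gateway would back these with a database
# lookup and a proper tokenizer (e.g. tiktoken) instead.
from typing import Optional

TENANTS = {
    "tenant-api-key": {"id": "tenant-42", "limits": {"tokens_per_minute": 100000}},
}

async def validate_api_key(api_key: str) -> Optional[dict]:
    """Map a gateway API key to its tenant record, or None if unknown."""
    return TENANTS.get(api_key)

def estimate_tokens(body: dict) -> int:
    """Rough pre-call token estimate: ~4 characters per token on average."""
    text = "".join(str(m.get("content", "")) for m in body.get("messages", []))
    return max(1, len(text) // 4) + int(body.get("max_tokens", 0))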
5. Summary
The AI gateway is the core middleware of enterprise AI infrastructure. The key design points:
- Unified interface: an OpenAI-compatible format, so business applications need no changes
- Multi-strategy routing: intelligent selection along the three dimensions of cost, latency, and capability
- Automatic failover: health tracked with an exponential moving average, switching providers automatically
- Distributed rate limiting: Redis supports multi-instance deployment, with minute/day/month quota tiers
- Complete auditing: every call records tenant, provider, and cost, enabling chargeback
In the long run, AI gateways will evolve toward AI observability platforms, absorbing prompt management, A/B testing, and model evaluation, and becoming the core infrastructure of enterprise AI engineering.