Designing a Multi-Game Card AI Decision API: Interface Spec, Architecture, and Performance Optimization
Our team built an AI decision API covering mahjong (8 rule variants), Guandan, Dou Dizhu, and Texas Hold'em. This post walks through the API design, the interface specification, our performance optimizations, and the pitfalls we hit in real-world deployments.
1. Use Cases
Who calls a card-game AI decision API?
| Client type | Typical need | Latency budget |
|---|---|---|
| Card-game platforms | Takeover for disconnected players, AI opponent matchmaking | <200ms |
| Companion robots | Playing mahjong/Guandan with elderly users | <500ms |
| Teaching apps | AI coach move suggestions, post-game review | <1s |
| Game developers | AI opponents for single-player games | <300ms |
Core requirement: given the current game state as input, return the optimal move decision in real time.
2. API Design
2.1 Core endpoint: decision request
POST /api/v1/decision
Content-Type: application/json
Authorization: Bearer <api_key>

{
  "game_type": "sichuan_mahjong",
  "game_state": {
    "hand": ["1wan", "2wan", "3wan", "5tiao", "5tiao", "9tong", ...],
    "discards": [
      {"player": 0, "tile": "dong", "turn": 3},
      {"player": 2, "tile": "7wan", "turn": 4}
    ],
    "melds": [],
    "dora_indicators": [],
    "seat": 0,
    "round_wind": "east",
    "prevalent_wind": "east"
  },
  "options": {
    "top_k": 3,
    "temperature": 1.0,
    "include_analysis": true
  }
}
2.2 Response format
{
  "action": {
    "type": "discard",
    "tile": "9tong",
    "confidence": 0.78
  },
  "alternatives": [
    {"type": "discard", "tile": "1wan", "confidence": 0.15},
    {"type": "discard", "tile": "dong", "confidence": 0.07}
  ],
  "analysis": {
    "shanten": 1,
    "waiting_tiles": ["4wan", "7wan"],
    "danger_tiles": ["9tong"],
    "win_probability": 0.23
  },
  "meta": {
    "model_version": "v2.3.1",
    "latency_ms": 156,
    "game_type": "sichuan_mahjong"
  }
}
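For reference, here is a minimal client sketch. The endpoint URL and API key are placeholders, and the helper names are ours for illustration; it uses only the Python standard library:

```python
import json
import urllib.request

API_URL = "https://api.example.com/api/v1/decision"  # placeholder URL
API_KEY = "your_api_key"                             # placeholder key

def build_payload(game_type, game_state, top_k=3, temperature=1.0):
    """Assemble the request body from section 2.1."""
    return {
        "game_type": game_type,
        "game_state": game_state,
        "options": {"top_k": top_k, "temperature": temperature,
                    "include_analysis": True},
    }

def request_decision(game_type, game_state, **options):
    """POST the payload with a bearer token and return the parsed JSON response."""
    req = urllib.request.Request(
        API_URL,
        data=json.dumps(build_payload(game_type, game_state, **options)).encode(),
        headers={"Content-Type": "application/json",
                 "Authorization": f"Bearer {API_KEY}"},
    )
    with urllib.request.urlopen(req, timeout=1.0) as resp:
        return json.loads(resp.read())
```

The 1-second timeout matches the loosest latency budget in the table above; platform clients would set it tighter.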
2.3 Supported game_type values
SUPPORTED_GAMES = {
    # Mahjong (8 variants)
    "sichuan_mahjong",    # Sichuan bloody-battle (xuezhan daodi)
    "guangdong_mahjong",  # Guangdong push-and-win (tuidaohu)
    "changsha_mahjong",   # Changsha mahjong
    "hongzhong_mahjong",  # Red-dragon mahjong
    "riichi_mahjong",     # Japanese riichi mahjong
    "mcr_mahjong",        # Chinese official rules (MCR)
    "hongkong_mahjong",   # Hong Kong mahjong
    "taiwan_mahjong",     # Taiwanese mahjong
    # Other games
    "guandan",            # Guandan
    "doudizhu",           # Dou Dizhu
    "texas_holdem",       # No-limit Texas Hold'em
}
2.4 Difficulty control
AI strength is controlled through the temperature parameter:
# Server-side implementation
import torch.nn.functional as F

def apply_difficulty(logits, temperature):
    """
    temperature=0.1 → near-greedy, almost always picks the best move (master level)
    temperature=1.0 → sample according to the policy distribution (normal strength)
    temperature=3.0 → near-random (beginner sparring partner)
    """
    scaled = logits / max(temperature, 0.01)
    return F.softmax(scaled, dim=-1)
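A quick standalone illustration of what temperature does to a policy distribution, using a plain-Python softmax and made-up logits (no torch needed):

```python
import math

def softmax_with_temperature(logits, temperature):
    """Same idea as apply_difficulty, in plain Python."""
    t = max(temperature, 0.01)
    scaled = [x / t for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]  # subtract max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5]  # made-up action scores
for t in (0.1, 1.0, 3.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```

At temperature 0.1 nearly all mass lands on the best action; at 3.0 the three actions are close to uniform, which is what makes the AI feel beatable.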
3. System Architecture
Client
  ↓ HTTPS
Nginx (SSL termination + rate limiting)
  ↓
API Gateway (FastAPI)
  ├── Auth middleware (JWT / API key)
  ├── Request validation (Pydantic schema)
  ├── Routing (by game_type)
  ↓
Inference cluster
  ├── Mahjong nodes (GPU, shared backbone + 8 rule-specific heads)
  ├── Guandan nodes (GPU)
  ├── Dou Dizhu nodes (GPU)
  └── Texas Hold'em nodes (GPU)
  ↓
Redis (session cache + game state)
  ↓
PostgreSQL (call logs + billing)
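For context, the Nginx rate-limiting layer can look roughly like the sketch below; the rate, burst, certificate paths, and upstream name are illustrative, not our production values:

```nginx
# Rate-limit per API key (carried in the Authorization header)
limit_req_zone $http_authorization zone=api_key_zone:10m rate=50r/s;

server {
    listen 443 ssl;
    ssl_certificate     /etc/nginx/certs/api.pem;   # placeholder paths
    ssl_certificate_key /etc/nginx/certs/api.key;

    location /api/ {
        limit_req zone=api_key_zone burst=20 nodelay;
        proxy_pass http://fastapi_upstream;  # assumed upstream name
    }
}
```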
Core FastAPI server code:
import time

import torch
from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel

app = FastAPI(title="Card Game AI Decision API")

class DecisionRequest(BaseModel):
    game_type: str
    game_state: dict
    options: dict = {}

class DecisionResponse(BaseModel):
    action: dict
    alternatives: list = []
    analysis: dict = {}
    meta: dict = {}

# Model pool: one inference model per game_type
model_pool = {}

@app.on_event("startup")
def load_models():
    for game_type in SUPPORTED_GAMES:
        model_pool[game_type] = load_model(game_type)
        model_pool[game_type].eval()

@app.post("/api/v1/decision", response_model=DecisionResponse)
async def get_decision(req: DecisionRequest, api_key: str = Depends(verify_api_key)):
    if req.game_type not in SUPPORTED_GAMES:
        raise HTTPException(400, f"Unsupported game: {req.game_type}")
    model = model_pool[req.game_type]
    start = time.time()
    # Encode the input state
    state_tensor = encode_state(req.game_state, req.game_type)
    # Run inference
    with torch.no_grad():
        logits = model(state_tensor)
    # Apply difficulty scaling and mask out illegal actions
    temperature = req.options.get("temperature", 1.0)
    legal_mask = get_legal_mask(req.game_state, req.game_type)
    logits[~legal_mask] = -float('inf')
    probs = apply_difficulty(logits, temperature)
    # Take the top-K candidate actions
    top_k = req.options.get("top_k", 3)
    top_probs, top_indices = torch.topk(probs, top_k)
    latency = (time.time() - start) * 1000
    return DecisionResponse(
        action=decode_action(top_indices[0], req.game_type, top_probs[0]),
        alternatives=[decode_action(idx, req.game_type, p)
                      for idx, p in zip(top_indices[1:], top_probs[1:])],
        meta={"latency_ms": round(latency), "game_type": req.game_type},
    )
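The legal-action masking step above deserves a closer look: setting illegal logits to -inf before softmax guarantees those actions get exactly zero probability. A plain-Python sketch with made-up logits and mask:

```python
import math

def masked_probs(logits, legal_mask):
    """Zero out illegal actions by masking their logits to -inf before softmax."""
    masked = [x if legal else -math.inf for x, legal in zip(logits, legal_mask)]
    m = max(masked)
    exps = [math.exp(x - m) if x != -math.inf else 0.0 for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.2, 0.3, 2.5, -0.4]     # made-up action scores
legal = [True, True, False, True]  # action 2 is illegal in this state
print(masked_probs(logits, legal)) # action 2 ends up with probability 0.0
```

Masking before the temperature-scaled softmax (rather than after) keeps the remaining probabilities properly normalized, so Top-K sampling never emits an illegal move.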
4. Performance Optimization
4.1 Model inference optimizations
import asyncio

import torch

# 1. Compile with TorchScript
model = torch.jit.script(model)

# 2. Half-precision inference
model = model.half()  # FP16 gave us ~40% faster inference

# 3. Batched inference: merge requests arriving within the same time window
class BatchInferenceEngine:
    def __init__(self, model, max_batch=32, max_wait_ms=10):
        self.model = model
        self.max_batch = max_batch
        self.max_wait_ms = max_wait_ms
        self.queue = asyncio.Queue()

    async def infer(self, state):
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((state, future))
        return await future

    async def collect_batch(self):
        # Wait for the first request, then keep draining the queue until
        # the batch is full or max_wait_ms has elapsed
        batch = [await self.queue.get()]
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.max_wait_ms / 1000
        while len(batch) < self.max_batch:
            timeout = deadline - loop.time()
            if timeout <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(self.queue.get(), timeout))
            except asyncio.TimeoutError:
                break
        return batch

    async def batch_loop(self):
        while True:
            batch = await self.collect_batch()
            states = torch.stack([s for s, _ in batch])
            with torch.no_grad():
                results = self.model(states)
            for (_, future), result in zip(batch, results):
                future.set_result(result)
4.2 Latency numbers
| Optimization stage | P50 latency | P99 latency |
|---|---|---|
| Baseline (FP32, per-request) | 180ms | 350ms |
| + TorchScript | 120ms | 240ms |
| + FP16 | 85ms | 170ms |
| + Batched inference | 60ms | 130ms |
5. Pitfalls
Pitfall 1: input-encoding differences across rule sets
The 8 mahjong variants have different tile inventories (riichi has red dora tiles; MCR has flower tiles). Our initial unified encoding lost information for riichi and MCR.
Fix: each rule set gets its own StateEncoder, all emitting feature vectors of the same fixed dimension.
Pitfall 2: game-state consistency
Clients may send an inconsistent game_state (e.g. a hand that contradicts the discard history). We added a strict state-validation middleware:
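The per-rule encoder idea can be sketched as a registry mapping game_type to an encoder, each emitting the same fixed-width feature vector; the class names, field layout, and FEATURE_DIM value here are illustrative, not our production code:

```python
FEATURE_DIM = 512  # illustrative shared output width

class BaseStateEncoder:
    def encode(self, state: dict) -> list:
        raise NotImplementedError

class RiichiEncoder(BaseStateEncoder):
    def encode(self, state: dict) -> list:
        vec = [0.0] * FEATURE_DIM
        # ... rule-specific features (red dora, riichi declarations) go here
        return vec

class McrEncoder(BaseStateEncoder):
    def encode(self, state: dict) -> list:
        vec = [0.0] * FEATURE_DIM
        # ... rule-specific features (flower tiles, fan counting) go here
        return vec

ENCODERS = {
    "riichi_mahjong": RiichiEncoder(),
    "mcr_mahjong": McrEncoder(),
    # ... one entry per supported game_type
}

def encode_state(state: dict, game_type: str) -> list:
    return ENCODERS[game_type].encode(state)
```

The shared model never sees rule-specific structure, only the fixed-width vector, which is what lets the rule heads stay thin.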
def validate_state(state, game_type):
    total_tiles = count_all_tiles(state)
    expected = GAME_TILE_COUNTS[game_type]
    if total_tiles != expected:
        raise HTTPException(400, f"Tile count mismatch: {total_tiles} vs {expected}")
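A complementary per-tile check we found useful (a sketch, not the production validator, and the helper names are ours): no tile may appear across the visible zones more often than it exists in the set — four copies for standard mahjong tiles.

```python
from collections import Counter

def visible_tiles(state):
    """Collect every tile visible to the server: hand, discards, melds."""
    tiles = list(state.get("hand", []))
    tiles += [d["tile"] for d in state.get("discards", [])]
    for meld in state.get("melds", []):
        tiles += meld.get("tiles", [])
    return tiles

def check_tile_copies(state, max_copies=4):
    """Reject states where any tile appears more often than physically possible."""
    counts = Counter(visible_tiles(state))
    for tile, n in counts.items():
        if n > max_copies:
            raise ValueError(f"Tile {tile} appears {n} times (max {max_copies})")

bad_state = {
    "hand": ["1wan", "1wan", "1wan", "1wan"],
    "discards": [{"player": 1, "tile": "1wan", "turn": 2}],  # fifth copy → invalid
}
```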
Pitfall 3: cold-start latency
The model's first inference carries an extra 2-3 seconds of CUDA kernel compilation. Fix: run a warmup pass with dummy data when the service starts.
6. Takeaways
The core challenge in a multi-game card AI API is unifying rule systems that differ wildly, while keeping inference latency low enough for real-time play. Key takeaways:
- A shared backbone with rule-specific heads lets 11 game variants share one inference service
- TorchScript + FP16 + batched inference together brought P99 latency down to 130ms
- Strict state validation is the foundation of production stability
Author team: 长沙赢麻哒文化传播 | malinguo.com