部署一个能够处理高并发请求的 Whisper API 服务,需要仔细考虑架构设计、资源管理(尤其是 GPU)和扩展性。简单地在每个 API 请求中加载和运行模型是不可行的,会导致性能低下和资源耗尽。
核心架构思路:
- API 服务与推理工作者分离: 使用一个 Web 框架(如 FastAPI)作为 API 网关,负责接收请求、验证输入并将任务分发出去,但不直接执行耗时的 Whisper 推理。
- 任务队列: 引入消息队列(如 Celery + Redis 或 RabbitMQ)作为 API 服务和推理工作者之间的缓冲。API 服务将转录任务(包含音频数据)放入队列。
- 后台工作者 (Workers): 启动一个或多个独立的后台工作者进程。这些工作者从任务队列中获取任务,执行 Whisper 模型推理。
- 模型预加载: 每个工作者进程在启动时预先加载 Whisper 模型到内存(和 GPU 显存),避免在处理每个任务时重复加载。
- 并发处理: 运行多个工作者进程,利用多核 CPU 或多个 GPU 并行处理任务,从而提高整体吞吐量。
- 异步 API: API 服务使用异步框架(如 FastAPI)来高效处理网络连接,即使后台任务正在处理,API 服务也能继续接收新请求。
- 容器化与编排: 使用 Docker 将 API 服务和工作者打包成镜像,并使用 Docker Compose(本地/单机)或 Kubernetes(生产环境)进行部署、管理和扩展。
- 优化: 考虑使用优化的 Whisper 实现(如
faster-whisper)和可能的批处理(Batching)来进一步提升性能。
详细步骤:
第一步:环境准备与依赖安装
-
Python 环境: 确保您拥有 Python 3.8 或更高版本。强烈建议使用虚拟环境:
python -m venv whisper_env source whisper_env/bin/activate # Linux/macOS # whisper_env\Scripts\activate # Windows -
安装核心库:
openai-whisper: 官方 Whisper 库。faster-whisper: (强烈推荐) 一个基于 CTranslate2 的优化版 Whisper 实现,速度更快,内存占用更低,支持量化。fastapi: 高性能的异步 Python Web 框架。uvicorn: ASGI 服务器,用于运行 FastAPI 应用。celery: 分布式任务队列系统。redis或librabbitmq: Celery 的消息中间件(Broker)客户端库。您需要安装并运行 Redis 或 RabbitMQ 服务器。Redis 通常更简单。python-multipart: FastAPI 用于处理文件上传。soundfile或pydub: 用于处理音频文件格式(pydub可能需要ffmpeg)。
# 优先推荐 faster-whisper pip install faster-whisper fastapi "uvicorn[standard]" celery[redis] python-multipart soundfile # 如果选择官方 whisper # pip install openai-whisper fastapi "uvicorn[standard]" celery[redis] python-multipart soundfile # 如果使用 pydub 处理音频 (可能需要安装 ffmpeg) # pip install pydub # Linux (Debian/Ubuntu): sudo apt update && sudo apt install ffmpeg # macOS: brew install ffmpeg -
安装 PyTorch: Whisper 依赖 PyTorch。请根据您的系统环境(CPU 或 GPU CUDA 版本)访问 PyTorch 官网 获取正确的安装命令。例如,如果您有支持 CUDA 11.8 的 NVIDIA GPU:
# 访问官网获取最新、最适合您环境的命令 pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118注意: 如果您只使用 CPU,请选择 CPU 版本的 PyTorch。如果您使用
faster-whisper,它不直接依赖 PyTorch 进行推理,但安装 PyTorch 可能有助于环境设置或音频处理。 -
安装并运行消息中间件 (Message Broker):
-
Redis (推荐,更简单):
# Linux (Debian/Ubuntu) sudo apt update && sudo apt install redis-server sudo systemctl enable redis-server --now # macOS brew install redis brew services start redis # 或者使用 Docker docker run -d -p 6379:6379 --name whisper-redis redis:alpine -
RabbitMQ (功能更强,配置稍复杂):
# Linux (Debian/Ubuntu) sudo apt update && sudo apt install rabbitmq-server sudo systemctl enable rabbitmq-server --now # macOS brew install rabbitmq brew services start rabbitmq # 或者使用 Docker (包含管理界面) docker run -d -p 5672:5672 -p 15672:15672 --name whisper-rabbitmq rabbitmq:3-management-alpine
-
第二步:选择并准备 Whisper 模型
-
官方
openai-whisper: 模型会在首次调用whisper.load_model()时自动下载。您可以选择不同大小的模型 (tiny,base,small,medium,large,large-v2,large-v3)。模型越大,效果越好,但资源消耗(内存、显存)和推理时间也越长。 -
faster-whisper(推荐):-
性能更好,内存占用更低。
-
支持
float16和int8量化,进一步优化性能和资源。 -
模型需要先转换为 CTranslate2 格式。
faster-whisper首次加载模型时通常会自动处理下载和转换,或者您可以手动转换。 -
示例 (在 Worker 代码中加载):
from faster_whisper import WhisperModel model_size = "medium" # 或 base, small, large-v3 等 # device="cuda" 使用 GPU, device="cpu" 使用 CPU # compute_type="float16" (半精度浮点,推荐用于 GPU) # compute_type="int8_float16" (INT8 量化,速度更快,显存更低,精度略有损失) # compute_type="int8" (纯 INT8 量化) # compute_type="float32" (单精度浮点,CPU 默认) model = WhisperModel(model_size, device="cuda", compute_type="float16") print(f"Loaded faster-whisper model: {model_size} on {model.device} with compute type {model.compute_type}")
-
第三步:创建 Celery 任务 (Worker 逻辑)
创建一个 tasks.py 文件,定义用于执行音频转录的 Celery 任务。
# tasks.py
import os
import tempfile
import time
from celery import Celery, Task
from celery.utils.log import get_task_logger
# --- 选择模型实现 ---
USE_FASTER_WHISPER = True # 推荐设置为 True
if USE_FASTER_WHISPER:
from faster_whisper import WhisperModel
else:
import whisper
# --------------------
# --- 配置 Celery ---
# 使用 Redis 作为 Broker 和 Backend (结果存储)
# 请根据您的 Redis 服务器地址和数据库编号进行修改
BROKER_URL = 'redis://localhost:6379/0'
BACKEND_URL = 'redis://localhost:6379/1'
# 如果使用 RabbitMQ 作为 Broker
# BROKER_URL = 'amqp://guest:guest@localhost:5672//' # 默认用户/密码
celery_app = Celery(
'whisper_tasks',
broker=BROKER_URL,
backend=BACKEND_URL
)
# 设置 Celery 任务序列化方式为 json
celery_app.conf.update(
task_serializer='json',
result_serializer='json',
accept_content=['json']
)
logger = get_task_logger(__name__)
# --- 模型加载 ---
# 在 Celery worker 进程启动时加载模型,避免每次任务都加载
# 这是关键优化点!
# 注意: 对于多 GPU 环境,需要更复杂的逻辑来为每个 worker 分配特定 GPU。
# 下面的代码假定每个 worker 进程使用默认的第一个可用 GPU (或 CPU)。
MODEL_INSTANCE = None
MODEL_LOAD_TIME = 0
def get_model():
"""获取或加载模型实例 (单例模式)"""
global MODEL_INSTANCE, MODEL_LOAD_TIME
if MODEL_INSTANCE is None:
start_time = time.time()
logger.info("Loading Whisper model...")
try:
if USE_FASTER_WHISPER:
model_size = os.environ.get("WHISPER_MODEL_SIZE", "medium") # 从环境变量获取模型大小,默认为 medium
device = os.environ.get("WHISPER_DEVICE", "cuda") # cuda or cpu
compute_type = os.environ.get("WHISPER_COMPUTE_TYPE", "float16") # float16, int8_float16, int8, float32
model_path = os.environ.get("WHISPER_MODEL_PATH", model_size) # 可以指定本地模型路径
# 可以在这里指定下载路径 cache_dir
# model_cache_dir = "/path/to/your/model/cache"
# MODEL_INSTANCE = WhisperModel(model_path, device=device, compute_type=compute_type, download_root=model_cache_dir)
MODEL_INSTANCE = WhisperModel(model_path, device=device, compute_type=compute_type)
logger.info(f"Loaded faster-whisper model: {model_path} on {device} with compute type {compute_type}")
else:
model_size = os.environ.get("WHISPER_MODEL_SIZE", "medium")
device = os.environ.get("WHISPER_DEVICE", "cuda")
download_root = os.environ.get("WHISPER_DOWNLOAD_ROOT", None) # 指定模型下载目录
MODEL_INSTANCE = whisper.load_model(model_size, device=device, download_root=download_root)
logger.info(f"Loaded official whisper model: {model_size} on {MODEL_INSTANCE.device}")
MODEL_LOAD_TIME = time.time() - start_time
logger.info(f"Model loaded in {MODEL_LOAD_TIME:.2f} seconds.")
except Exception as e:
logger.error(f"Error loading Whisper model: {e}", exc_info=True)
raise # 抛出异常,阻止 worker 启动或任务执行
return MODEL_INSTANCE
# -----------------
# 预加载模型(可选,但推荐在 worker 启动时触发一次)
# get_model()
# --- Celery 任务定义 ---
class TranscriptionTask(Task):
"""自定义 Celery Task 类,确保模型已加载"""
_model = None
@property
def model(self):
if self._model is None:
logger.info("Model not loaded in task instance, attempting to load...")
self._model = get_model() # 确保模型被加载
return self._model
@celery_app.task(base=TranscriptionTask, bind=True)
def transcribe_audio_task(self: Task, audio_data: bytes, filename: str, language: str = None, task_config: dict = None):
"""
Celery task to transcribe audio using the pre-loaded Whisper model.
:param self: Celery task instance.
:param audio_data: Raw bytes of the audio file.
:param filename: Original filename (for format detection or logging).
:param language: Optional language code (e.g., 'en', 'zh'). If None, Whisper detects automatically.
:param task_config: Optional dictionary for Whisper parameters (e.g., beam_size, temperature).
:return: Dictionary with transcription result or error.
"""
task_id = self.request.id
logger.info(f"Task {task_id}: Received transcription job for '{filename}', language='{language}', config={task_config}")
# 确保模型已加载
try:
model = self.model
if model is None:
raise RuntimeError("Whisper model could not be loaded.")
except Exception as e:
logger.error(f"Task {task_id}: Failed to get model instance: {e}", exc_info=True)
# 更新任务状态为失败
self.update_state(state='FAILURE', meta={'exc_type': type(e).__name__, 'exc_message': str(e)})
# 返回错误信息,避免任务重试(除非配置了自动重试)
return {'status': 'FAILURE', 'error': f'Model loading error: {str(e)}'}
tmp_audio_path = None # 初始化变量
try:
# 将字节数据写入临时文件,因为很多库需要文件路径
# 考虑使用 io.BytesIO 如果库支持直接从内存读取 (例如 soundfile 可能支持)
suffix = os.path.splitext(filename)[1] if '.' in filename else '.tmp'
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_audio_file:
tmp_audio_file.write(audio_data)
tmp_audio_path = tmp_audio_file.name
logger.debug(f"Task {task_id}: Audio data saved to temporary file: {tmp_audio_path}")
transcribe_options = task_config if task_config else {}
if language:
transcribe_options['language'] = language
# 设置 beam_size, temperature 等参数 (示例)
# transcribe_options.setdefault('beam_size', 5)
# transcribe_options.setdefault('temperature', 0.0) # 可以是一个元组 (0.0, 0.2, ...) 用于温度回退
# --- 执行转录 ---
start_time = time.time()
logger.info(f"Task {task_id}: Starting transcription with options: {transcribe_options}")
if USE_FASTER_WHISPER:
# faster-whisper specific options: beam_size, temperature, vad_filter, etc.
segments, info = model.transcribe(tmp_audio_path, **transcribe_options)
# segments 是一个生成器,需要迭代获取结果
result_text = "".join([segment.text for segment in segments])
detected_language = info.language
language_probability = info.language_probability
duration = info.duration
logger.info(f"Task {task_id}: faster-whisper detected language '{detected_language}' with probability {language_probability:.2f}")
else:
# official whisper options: language, temperature, etc.
result = model.transcribe(tmp_audio_path, **transcribe_options)
result_text = result["text"]
detected_language = result["language"]
language_probability = None # 官方库不直接提供概率
duration = result.get("duration") # 可能需要计算
logger.info(f"Task {task_id}: official whisper detected language '{detected_language}'")
transcription_time = time.time() - start_time
logger.info(f"Task {task_id}: Transcription complete for '{filename}' in {transcription_time:.2f} seconds.")
# -----------------
return {
"status": "SUCCESS",
"text": result_text,
"language": detected_language,
"language_probability": language_probability,
"duration": duration,
"model_info": { # 添加一些模型信息
"implementation": "faster-whisper" if USE_FASTER_WHISPER else "official-whisper",
"size": model.model_size if USE_FASTER_WHISPER else model.ftype, # 获取模型大小信息可能需要调整
"device": str(model.device) if USE_FASTER_WHISPER else str(model.device),
"compute_type": model.compute_type if USE_FASTER_WHISPER else "float32" # 官方库主要是 float32
},
"timing": {
"transcription_seconds": transcription_time,
"model_load_seconds": MODEL_LOAD_TIME if 'MODEL_LOAD_TIME' in globals() else None
}
}
except Exception as e:
logger.error(f"Task {task_id}: Error during transcription of '{filename}': {e}", exc_info=True)
# 更新任务状态为失败
self.update_state(state='FAILURE', meta={'exc_type': type(e).__name__, 'exc_message': str(e)})
return {"status": "FAILURE", "error": str(e)}
finally:
# 清理临时文件
if tmp_audio_path and os.path.exists(tmp_audio_path):
try:
os.remove(tmp_audio_path)
logger.debug(f"Task {task_id}: Cleaned up temp file {tmp_audio_path}")
except OSError as e:
logger.warning(f"Task {task_id}: Could not delete temp file {tmp_audio_path}: {e}")
第四步:创建 FastAPI 应用 (API 网关)
创建一个 main.py 文件,定义 API 端点。
# main.py
import json
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Query, BackgroundTasks
from fastapi.responses import JSONResponse
from celery.result import AsyncResult
from tasks import transcribe_audio_task, celery_app # 导入 Celery app 和 task
import logging
import os
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
title="Whisper Transcription API",
description="API for submitting audio files for transcription using Whisper via Celery.",
version="1.0.0"
)
# --- Helper Function ---
def parse_task_config(config_json: str = None) -> dict:
"""解析 JSON 字符串格式的任务配置"""
if not config_json:
return {}
try:
config = json.loads(config_json)
if not isinstance(config, dict):
raise ValueError("Task config must be a JSON object.")
# 在这里可以添加对配置项的验证逻辑
return config
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid JSON format for task_config.")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# --- API Endpoints ---
@app.post("/transcribe", status_code=202)
async def submit_transcription(
background_tasks: BackgroundTasks, # 用于后台关闭文件句柄
audio_file: UploadFile = File(..., description="The audio file to transcribe."),
language: str = Form(None, description="Optional: Language code (e.g., 'en', 'zh'). If None, auto-detect."),
task_config: str = Form(None, description="Optional: JSON string with Whisper task parameters (e.g., '{"beam_size": 5}').")
):
"""
Submit an audio file for transcription.
This endpoint accepts an audio file and optional parameters,
then queues the transcription task using Celery.
It immediately returns a task ID. Use the /results/{task_id}
endpoint to check the status and retrieve the result.
"""
logger.info(f"Received transcription request for file: {audio_file.filename}, content_type: {audio_file.content_type}, language: {language}")
# 基础验证
if not audio_file:
raise HTTPException(status_code=400, detail="No audio file provided.")
# 检查文件大小 (示例:限制 100MB)
max_size = 100 * 1024 * 1024 # 100 MB
size = await audio_file.read() # 读取整个文件来检查大小,对于大文件可能需要流式处理
if len(size) == 0:
raise HTTPException(status_code=400, detail="Audio file is empty.")
if len(size) > max_size:
raise HTTPException(status_code=413, detail=f"File size exceeds limit of {max_size / (1024*1024)} MB.")
await audio_file.seek(0) # 重置文件指针
# 解析 task_config
try:
parsed_config = parse_task_config(task_config)
except HTTPException as e:
# 确保文件被关闭
background_tasks.add_task(audio_file.close)
raise e # 重新抛出 HTTP 异常
try:
audio_data = await audio_file.read() # 再次读取(如果之前检查大小消耗了)
# 发送任务到 Celery 队列
# .delay() 是 .apply_async() 的快捷方式
task = transcribe_audio_task.delay(audio_data, audio_file.filename, language, parsed_config)
logger.info(f"Submitted task {task.id} for file '{audio_file.filename}' to Celery.")
return {"task_id": task.id, "message": "Transcription task submitted successfully."}
except Exception as e:
logger.error(f"Failed to submit task for file '{audio_file.filename}': {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error: Failed to submit task.")
finally:
# 确保文件句柄被关闭,即使发生错误
# 使用 BackgroundTasks 确保在响应发送后关闭
background_tasks.add_task(audio_file.close)
logger.debug(f"Closed file handle for: {audio_file.filename}")
@app.get("/results/{task_id}")
async def get_transcription_result(task_id: str):
"""
Retrieve the status and result of a transcription task.
Poll this endpoint with the task_id received from /transcribe.
Possible statuses: PENDING, STARTED, SUCCESS, FAILURE, RETRY, REVOKED.
If status is SUCCESS, the 'result' field contains the transcription.
If status is FAILURE, the 'error' field contains details.
"""
logger.debug(f"Checking result for task_id: {task_id}")
task_result = AsyncResult(task_id, app=celery_app)
response = {"task_id": task_id, "status": task_result.state}
if task_result.ready():
if task_result.successful():
result = task_result.get()
response["result"] = result
logger.info(f"Task {task_id} completed successfully.")
else:
# 获取失败信息
try:
# 尝试从 backend 获取存储的异常信息
result = task_result.get(propagate=False) # propagate=False 避免在此处抛出异常
# Celery 任务失败时,异常信息通常存储在 task_result.info 中
if isinstance(task_result.info, Exception):
error_info = {'error': str(task_result.info)}
elif isinstance(task_result.info, dict): # 我们的任务在失败时返回 dict
error_info = task_result.info
else: # 其他未知情况
error_info = {'error': 'Task failed with unknown error.', 'details': repr(task_result.info)}
except Exception as e:
# 如果获取结果本身也失败
logger.error(f"Failed to retrieve failure info for task {task_id}: {e}", exc_info=True)
error_info = {'error': f'Could not retrieve failure details: {str(e)}'}
response.update(error_info) # 将错误信息合并到响应中
logger.warning(f"Task {task_id} failed. Info: {error_info}")
# 对于失败的任务,返回 500 状态码可能更合适
return JSONResponse(status_code=500, content=response)
else:
logger.debug(f"Task {task_id} is not ready yet. Status: {task_result.state}")
# 对于 PENDING 或 STARTED 状态,返回 200 OK
pass
return response
# 健康检查端点
@app.get("/health", summary="Health Check")
async def health_check():
"""Basic health check endpoint."""
# 可以添加更复杂的检查,例如 ping Celery broker
# try:
# celery_app.control.ping(timeout=1.0)
# except Exception:
# raise HTTPException(status_code=503, detail="Celery broker connection failed")
return {"status": "ok"}
# 可以添加一个获取 Worker 信息的端点 (需要配置 Celery Events 或 remote control)
# @app.get("/workers")
# async def get_workers():
# """Get information about active workers (requires Celery monitoring)."""
# try:
# inspector = celery_app.control.inspect()
# active = inspector.active()
# stats = inspector.stats()
# return {"active_tasks": active, "worker_stats": stats}
# except Exception as e:
# raise HTTPException(status_code=500, detail=f"Could not inspect workers: {e}")
第五步:运行服务和工作者
-
启动消息中间件: 确保您的 Redis 或 RabbitMQ 服务器正在运行。
-
启动 Celery Worker:
在包含 tasks.py 的项目根目录下打开一个终端,运行:
# -A 指定 Celery app 实例 (tasks 文件中的 celery_app) # worker 子命令表示启动 worker # --loglevel=info 设置日志级别 (debug, info, warning, error, critical) # --concurrency=N 设置并发处理任务的进程/线程数 (关键参数!) # -n worker1@%h 设置 worker 的名称 (可选) # --- 重要:并发设置 (concurrency) --- # 对于 CPU 任务: 可以设置为 CPU 核心数。 # 对于 GPU 任务 (常见情况): # - 如果一个 GPU 显存足够运行多个模型实例 (例如使用 int8 量化的小模型): # 可以设置 concurrency > 1,但要非常小心 GPU 显存分配和竞争。 # - 更常见和稳妥的做法是: 每个 worker 进程独占一个 GPU。 # 设置 concurrency=1,然后启动多个 worker 实例,每个实例绑定到不同的 GPU。 # 示例 1: 启动一个 worker 进程,使用默认 GPU (通常是 GPU 0) 或 CPU # 通过环境变量传递模型配置 (推荐) export WHISPER_MODEL_SIZE="medium" export WHISPER_DEVICE="cuda" # 或 "cpu" export WHISPER_COMPUTE_TYPE="float16" # 或 "int8_float16" 等 celery -A tasks.celery_app worker --loglevel=info --concurrency=1 -n worker1@%h # 示例 2: 在有多个 GPU 的机器上,为每个 GPU 启动一个 worker 进程 # 终端 1 (使用 GPU 0) export CUDA_VISIBLE_DEVICES=0 export WHISPER_MODEL_SIZE="medium" export WHISPER_DEVICE="cuda" export WHISPER_COMPUTE_TYPE="float16" celery -A tasks.celery_app worker --loglevel=info --concurrency=1 -n worker_gpu0@%h # 终端 2 (使用 GPU 1) export CUDA_VISIBLE_DEVICES=1 export WHISPER_MODEL_SIZE="medium" export WHISPER_DEVICE="cuda" export WHISPER_COMPUTE_TYPE="float16" celery -A tasks.celery_app worker --loglevel=info --concurrency=1 -n worker_gpu1@%hCUDA_VISIBLE_DEVICES环境变量用于限制该进程可见的 GPU 设备。
-
启动 FastAPI 应用:
在另一个终端中,运行 uvicorn:
# main:app 指向 main.py 文件中的 FastAPI 实例 app # --host 0.0.0.0 监听所有网络接口,允许外部访问 # --port 8000 指定服务端口 # --workers 4 (可选) 启动多个 uvicorn 工作进程来处理 API 请求 (不是 Celery worker!) # 这有助于处理大量并发的网络连接,因为 API 请求本身很快(只发送任务到队列)。 # 这个数量通常基于 CPU 核心数设置。 uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4
现在,您可以通过 http://<your-server-ip>:8000/docs 访问 FastAPI 的 Swagger UI,测试文件上传和结果获取。
第六步:容器化 (Docker)
使用 Docker 可以极大地简化部署、依赖管理和扩展过程。
-
requirements.txt: 创建一个包含所有 Python 依赖的文件。# requirements.txt fastapi uvicorn[standard] celery[redis] # 或 celery[librabbitmq] # --- 选择 Whisper 实现 --- faster-whisper # 推荐 # openai-whisper # 或者使用官方库 # ------------------------ # --- PyTorch (根据 GPU/CPU 和 CUDA 版本选择) --- # 例如 CUDA 11.8: torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 # --index-url https://download.pytorch.org/whl/cu118 # 安装时需要指定 index-url # -------------------------------------------- python-multipart soundfile # 或 pydub redis # redis 客户端库 # pydantic<2 # 如果遇到 pydantic v2 与 FastAPI/Celery 的兼容问题 -
Dockerfile.api(用于 FastAPI 服务): 这个镜像不需要 GPU 支持。# Dockerfile.api FROM python:3.10-slim WORKDIR /app # 安装系统依赖 (例如 ffmpeg 如果用 pydub) # RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg && rm -rf /var/lib/apt/lists/* COPY requirements.txt . # 安装 Python 依赖 (不安装 PyTorch GPU 版本,节省空间) # 注意:如果 tasks.py 导入了 torch,这里可能需要安装 CPU 版本的 torch RUN pip install --no-cache-dir -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 # 即使是 API 也可能需要 torch 库本身 COPY ./main.py ./tasks.py . # tasks.py 也需要复制,因为 main.py 导入了它 # 设置 Broker/Backend URL (将由 docker-compose 或 K8s 覆盖) ENV BROKER_URL=redis://redis:6379/0 ENV BACKEND_URL=redis://redis:6379/1 # ENV BROKER_URL=amqp://guest:guest@rabbitmq:5672// EXPOSE 8000 # 使用多个 uvicorn worker 提高 API 吞吐量 CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"] -
Dockerfile.worker(用于 Celery Worker): 这个镜像需要包含 CUDA 和 GPU 版 PyTorch。# Dockerfile.worker # 选择与您 PyTorch 和驱动兼容的 NVIDIA CUDA 基础镜像 FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04 WORKDIR /app # 设置非交互式前端,避免 apt 安装时卡住 ENV DEBIAN_FRONTEND=noninteractive # 安装 Python 和其他系统依赖 RUN apt-get update && apt-get install -y --no-install-recommends \ python3.10 python3-pip python3.10-venv git curl \ # 如果需要 ffmpeg: # ffmpeg \ && rm -rf /var/lib/apt/lists/* # (可选) 创建虚拟环境 # RUN python3.10 -m venv /app/venv # ENV PATH="/app/venv/bin:$PATH" COPY requirements.txt . # --- 安装 PyTorch (GPU 版本) --- # 确保这里的 CUDA 版本 (cu118) 与基础镜像 (11.8.0) 匹配 RUN pip install --no-cache-dir -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 # ----------------------------- COPY ./tasks.py . # --- 预下载/转换模型 (可选但推荐) --- # 将模型下载/转换步骤放入 Docker 构建过程,避免每个容器启动时都执行 # 设置环境变量用于模型下载/加载 ENV WHISPER_MODEL_SIZE="medium" ENV WHISPER_DEVICE="cuda" ENV WHISPER_COMPUTE_TYPE="float16" ENV WHISPER_MODEL_PATH=$WHISPER_MODEL_SIZE # 可以改为 /models/$WHISPER_MODEL_SIZE # ENV WHISPER_DOWNLOAD_ROOT=/models # 指定下载目录 # 创建模型目录 (如果需要) # RUN mkdir /models # 运行 Python 脚本来触发模型下载/转换 # 使用 CPU 下载/转换可能更通用,运行时再指定 GPU # RUN python -c "from faster_whisper import WhisperModel; WhisperModel('$WHISPER_MODEL_SIZE', device='cpu', compute_type='$WHISPER_COMPUTE_TYPE', download_root=os.environ.get('WHISPER_DOWNLOAD_ROOT'))" # 或者官方 whisper # RUN python -c "import whisper; whisper.load_model('$WHISPER_MODEL_SIZE', device='cpu', download_root=os.environ.get('WHISPER_DOWNLOAD_ROOT'))" # ---------------------------------- # 设置 Broker/Backend URL (将由 docker-compose 或 K8s 覆盖) ENV BROKER_URL=redis://redis:6379/0 ENV BACKEND_URL=redis://redis:6379/1 # ENV BROKER_URL=amqp://guest:guest@rabbitmq:5672// # 设置 worker 启动时的默认并发数 (可以被 docker-compose 或 K8s 覆盖) ENV CELERY_CONCURRENCY=1 # 运行 Celery Worker # 使用 exec 确保信号能正确传递给 Celery 进程 CMD exec celery -A tasks.celery_app worker --loglevel=info --concurrency=$CELERY_CONCURRENCY- 关键 (GPU): Worker 的 Dockerfile 需要基于
nvidia/cuda镜像,并安装与 CUDA 兼容的 PyTorch。运行时需要使用 NVIDIA Docker Runtime (或 Kubernetes 的 NVIDIA Device Plugin)。
- 关键 (GPU): Worker 的 Dockerfile 需要基于
-
docker-compose.yml(用于本地或单机部署):version: '3.8' services: redis: image: redis:alpine container_name: whisper_redis ports: - "6379:6379" volumes: - redis_data:/data restart: always # rabbitmq: # 如果使用 RabbitMQ # image: rabbitmq:3-management-alpine # container_name: whisper_rabbitmq # ports: # - "5672:5672" # - "15672:15672" # 管理界面 # volumes: # - rabbitmq_data:/var/lib/rabbitmq # restart: always api: build: context: . dockerfile: Dockerfile.api container_name: whisper_api ports: - "8000:8000" depends_on: - redis # or rabbitmq environment: # 覆盖 Dockerfile 中的默认值,连接到 compose 网络中的 redis 服务 - BROKER_URL=redis://redis:6379/0 - BACKEND_URL=redis://redis:6379/1 # - BROKER_URL=amqp://guest:guest@rabbitmq:5672// volumes: # 可以挂载代码目录方便开发时热更新 (需要 uvicorn --reload) # - .:/app restart: always worker: build: context: . dockerfile: Dockerfile.worker container_name: whisper_worker # 如果 scale > 1,名称会自动加后缀 depends_on: - redis # or rabbitmq environment: - BROKER_URL=redis://redis:6379/0 - BACKEND_URL=redis://redis:6379/1 # - BROKER_URL=amqp://guest:guest@rabbitmq:5672// # --- 传递模型和设备配置 --- - WHISPER_MODEL_SIZE=medium - WHISPER_DEVICE=cuda - WHISPER_COMPUTE_TYPE=float16 # - WHISPER_MODEL_PATH=/models/medium # 如果模型已预置在镜像中 # - NVIDIA_VISIBLE_DEVICES=all # 允许容器访问所有 GPU (需要 deploy 配置) # -------------------------- - CELERY_CONCURRENCY=1 # 每个 worker 容器运行一个处理进程 # --- GPU 访问配置 --- deploy: resources: reservations: devices: - driver: nvidia # count: 1 # 每个容器实例分配一个 GPU # 或者使用 'all' 并配合 CUDA_VISIBLE_DEVICES capabilities: [gpu] # 必须 # -------------------- volumes: # 如果模型在宿主机上,可以挂载进来 # - /path/on/host/models:/models # 挂载代码目录方便开发 # - .:/app restart: always # 使用 scale 命令可以启动多个 worker 容器实例 # docker-compose up --build -d --scale worker=2 # 启动 2 个 worker 实例 volumes: redis_data: # rabbitmq_data: -
运行 Docker Compose:
-
确保您已安装 Docker、Docker Compose v2+ 以及 NVIDIA Container Toolkit (docs.nvidia.com/datacenter/…) 以便 Docker 能使用 GPU。
-
构建并启动服务(在后台
-d):docker compose up --build -d -
扩展 Worker 数量(例如,扩展到 2 个 worker 实例,需要机器有至少 2 个 GPU 或足够显存):
docker compose up -d --scale worker=2注意: Docker Compose 的
deploy.resources.reservations.devices.count可能需要调整,或者依赖NVIDIA_VISIBLE_DEVICES和 worker 内部逻辑来分配 GPU。对于多 GPU 精细控制,Kubernetes 通常更强大。 -
查看日志:
docker compose logs -f api docker compose logs -f worker -
停止并移除容器:
docker compose down
-
第七步:生产环境部署 (高并发与扩展)
对于生产环境,推荐使用 Kubernetes (K8s) 或云服务商提供的容器编排服务(如 AWS ECS/EKS, Google GKE, Azure AKS)。
-
基础设施:
-
API 网关/负载均衡器: 使用云服务商的负载均衡器(如 AWS ELB, Google Cloud Load Balancer)或 K8s Ingress Controller (如 Nginx Ingress) 将外部流量分发到 API 服务 Pod。
-
API 服务 (FastAPI): 部署为 K8s Deployment,运行在普通的 CPU 节点上。配置 Horizontal Pod Autoscaler (HPA) 基于 CPU/内存使用率或自定义指标(如 QPS)自动伸缩 Pod 数量。
-
Celery Workers: 部署为 K8s Deployment,运行在 GPU 节点 上。
- 确保 K8s 集群安装了 NVIDIA Device Plugin (github.com/NVIDIA/k8s-…),这样 K8s 才能调度 GPU 资源。
- 在 Worker 的 Deployment 配置中请求 GPU 资源 (
nvidia.com/gpu: 1)。 - 配置 HPA 基于 Celery 队列长度(需要监控系统支持,如 KEDA + Prometheus/Redis Exporter)或自定义业务指标来自动伸缩 Worker Pod 的数量。这是实现按需扩展处理能力的关键。
-
消息中间件: 使用云服务商托管的 Redis (如 AWS ElastiCache, Google Cloud Memorystore, Azure Cache for Redis) 或 RabbitMQ 服务,它们提供高可用性、可扩展性和易管理性。
-
模型存储:
- 将模型打包进 Worker Docker 镜像(如果模型不大,如
faster-whisper的量化模型)。 - 将模型存储在共享文件系统(如 NFS, AWS EFS, GCS Filestore)并挂载到 Worker Pods。
- 将模型存储在对象存储(如 S3, GCS)并在 Worker 启动时下载。
- 将模型打包进 Worker Docker 镜像(如果模型不大,如
-
-
部署工具:
- Kubernetes: 使用 Helm Charts 或 Kustomize 来定义和管理所有 K8s 资源(Deployments, Services, Ingress, HPA, ConfigMaps, Secrets 等),实现声明式部署和版本控制。
- CI/CD: 建立 CI/CD 流水线 (如 Jenkins, GitLab CI, GitHub Actions) 自动化构建 Docker 镜像、推送镜像仓库、部署到 K8s 集群。
-
监控与日志:
- 指标监控: 使用 Prometheus 收集指标(API 延迟/错误率、Celery 任务数/队列长度、GPU 利用率/显存、Pod CPU/内存),使用 Grafana 进行可视化和告警。需要部署相应的 Exporter(如 Redis Exporter, Celery Exporter (或通过 Flower), Node Exporter, DCGM Exporter for GPU)。
- 日志聚合: 使用 EFK (Elasticsearch, Fluentd/Fluent-bit, Kibana) 或 PLG (Prometheus, Loki, Grafana) 栈收集、存储和查询所有容器的日志。
- 分布式追踪: (可选) 使用 OpenTelemetry 等工具追踪请求在 API 服务和 Worker 之间的流转,帮助诊断性能瓶颈。
第八步:优化与注意事项
-
faster-whisper: 再次强调,使用faster-whisper并开启float16或int8量化是提升性能、降低资源消耗最有效的方法之一。 -
批处理 (Batching): 如果应用场景允许一定的延迟(例如,处理后台上传的长音频),可以在 Worker 中实现批处理逻辑:收集一定数量或等待一小段时间的任务,然后将多个音频一次性传递给
model.transcribe()(faster-whisper支持文件路径列表作为输入)。这能显著提高 GPU 利用率和吞吐量。需要修改tasks.py中的 Worker 逻辑。 -
模型选择: 根据实际业务对准确率的要求和可用的计算资源,选择合适的 Whisper 模型大小。
medium或small模型在速度和效果上通常有较好的平衡。 -
GPU 显存管理: 密切监控 Worker Pod 的 GPU 显存使用情况。如果显存不足:
- 换用更小的模型。
- 使用
int8量化。 - 确保每个 GPU 节点只运行显存容量允许的 Worker Pod 数量。
- 使用显存更大的 GPU 实例类型。
-
音频预处理/后处理: 音频的格式转换、重采样、降噪等操作可能也需要时间。考虑这些操作是在 API 端完成还是在 Worker 端完成。
-
错误处理与重试: 在 Celery 任务中实现健壮的错误处理。配置合理的任务重试策略(例如,只对网络问题或临时性错误重试)。对于无法处理的音频(格式错误、文件损坏),应标记为失败并提供清晰的错误信息。
-
安全性:
- API 接口添加认证和授权(如 API Key, JWT, OAuth2)。
- 限制上传文件的大小和可接受的
Content-Type。 - 对所有输入进行验证和清理。
- 保护消息队列和后台服务的网络访问(例如,使用 K8s Network Policies)。
-
成本优化: GPU 实例通常很昂贵。
- 精确配置 HPA,避免空闲时资源浪费。
- 利用云服务商的 Spot 实例(抢占式实例)运行 Worker Pods 可以大幅降低成本,但需要应用程序能够容忍实例被中断和替换(Celery 任务队列天然适合这种场景,任务可以被其他 Worker 接管)。
- 选择性价比合适的 GPU 类型。
这份指南提供了一个相对完整的路线图。根据您的具体需求(预期并发量、延迟敏感度、预算、技术栈熟悉度),您可能需要调整其中的某些部分。构建这样一个系统是一个迭代的过程,监控和持续优化至关重要。