ERNIE 4.5 on Domestic Compute: Huawei Ascend 910B Deployment Tutorial


Project Overview

This tutorial walks through deploying Baidu's ERNIE 4.5 (ERNIE-4.5-21B) on Huawei Ascend 910B NPU hardware and exposing it as an OpenAI-compatible API service. The setup supports concurrent inference and streaming output, and targets deployments on domestic (Chinese) compute platforms.
⚠️ Note: The official PaddlePaddle FastDeploy toolkit does not yet support Ascend deployment, so this project is built on the transformers framework. In testing, the deployment works but is fairly slow, so treat it as reference material for learning; high-performance deployment content will follow once vLLM and FastDeploy support matures.

Requirements

  • Hardware: Ascend 910B2 with Kunpeng 920 (64 GB) × 1
  • OS: Linux 5.15.0-101-generic #111-Ubuntu, aarch64
  • Ascend NPU driver and CANN toolkit installed (see the note after this list)
  • Python 3.10+
  • RAM: 128 GB+ recommended
  • Storage: more than 50 GB free on the disk that holds the model
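
After installing the driver and CANN, the toolkit environment variables normally have to be sourced in each shell before running NPU code. The path below is CANN's default install location and may differ on your machine:

# Load the CANN toolkit environment (adjust the path if CANN was installed elsewhere)
source /usr/local/Ascend/ascend-toolkit/set_env.sh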

Technical Steps

1. Environment Setup and Dependency Installation

1.1 Verify the NPU Driver Installation

# Check the NPU device status
npu-smi info

A correct installation prints a table of NPU device information (the original post shows a screenshot of the expected output here).

If the command produces no output or reports an error, install the driver by following the official Huawei Ascend documentation.

1.2 Automated Environment Setup

The project provides an automated environment setup script, env.sh, whose full contents are shown below:

#!/bin/bash

# 环境配置脚本 - 华为昇腾NPU环境设置
echo "开始配置华为昇腾NPU环境..."

# 检查NPU驱动是否正确安装
echo "检查NPU驱动状态..."
if [ ! -f "/usr/local/bin/npu-smi" ]; then
    echo "错误: 未找到npu-smi命令,请先安装NPU驱动"
    echo "请前往以下链接安装驱动: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html"
    exit 1
fi

# 检查npu-smi info是否有输出
echo "执行npu-smi info检查NPU设备..."
/usr/local/bin/npu-smi info > /tmp/npu_output.txt 2>&1
npu_exit_code=$?
npu_output=$(cat /tmp/npu_output.txt)

if [ $npu_exit_code -ne 0 ] || [ -z "$npu_output" ]; then
    echo "错误: npu-smi info 无输出或执行失败,NPU驱动未正确安装或NPU设备不可用"
    echo "请前往以下链接检查驱动安装: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html"
    echo "NPU驱动安装后,npu-smi info 应该有设备信息输出"
    echo "实际输出内容:"
    echo "$npu_output"
    exit 1
fi

# 显示NPU信息
echo "NPU设备信息:"
echo "$npu_output"

# 清理临时文件
rm -f /tmp/npu_output.txt

echo "NPU驱动检查通过,开始检测系统架构..."

# 检测系统架构
ARCH=$(uname -m)
echo "检测到系统架构: $ARCH"

# 根据架构选择下载链接和文件名
if [ "$ARCH" = "aarch64" ]; then
    DOWNLOAD_URL="https://gitee.com/ascend/pytorch/releases/download/v7.0.0-pytorch2.5.1/torch_npu-2.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl"
    WHEEL_FILE="torch_npu-2.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl"
    echo "使用ARM64架构的torch_npu包"
elif [ "$ARCH" = "x86_64" ]; then
    DOWNLOAD_URL="https://gitee.com/ascend/pytorch/releases/download/v7.0.0-pytorch2.5.1/torch_npu-2.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
    WHEEL_FILE="torch_npu-2.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
    echo "使用x86_64架构的torch_npu包"
else
    echo "错误: 不支持的系统架构 $ARCH"
    echo "仅支持 aarch64 和 x86_64 架构"
    exit 1
fi

echo "开始下载torch_npu包..."

# 检查文件是否已存在
if [ -f "$WHEEL_FILE" ]; then
    echo "检测到torch_npu包文件已存在,将覆盖现有文件"
fi

# 下载插件包(使用-O选项强制覆盖)
echo "正在下载torch_npu包..."
wget -O "$WHEEL_FILE" "$DOWNLOAD_URL"

if [ $? -ne 0 ]; then
    echo "错误: torch_npu包下载失败"
    exit 1
fi

echo "torch_npu包下载完成,开始安装..."

# 安装命令
pip install $WHEEL_FILE

if [ $? -ne 0 ]; then
    echo "错误: torch_npu包安装失败"
    exit 1
fi

echo "torch_npu包安装完成,正在测试..."

# 测试torch_npu是否支持(如果为True则为支持)
echo "测试torch_npu是否可用..."
result=$(python -c "import torch;import torch_npu;print(torch_npu.npu.is_available())" 2>/dev/null)

if [ "$result" = "True" ]; then
    echo "✅ torch_npu测试通过,NPU支持已启用"
else
    echo "❌ torch_npu测试失败,NPU支持未启用"
    echo "请检查NPU驱动和torch_npu安装是否正确"
    exit 1
fi

# 检查requirements.txt是否存在
if [ -f "requirements.txt" ]; then
    echo "正在安装项目依赖..."
    # 安装包依赖
    pip install -r requirements.txt
    
    if [ $? -eq 0 ]; then
        echo "✅ 项目依赖安装完成"
    else
        echo "❌ 项目依赖安装失败"
        exit 1
    fi
else
    echo "⚠️  未找到requirements.txt文件,跳过依赖安装"
fi

echo "🎉 华为昇腾NPU环境配置完成!"
echo "您现在可以使用torch_npu进行NPU加速计算了。"

The script automates the following steps (the command to run it appears after this list):

  • NPU driver status check
  • System architecture detection (x86_64/aarch64)
  • Download and installation of the torch_npu wheel
  • Installation of the project dependencies
  • Verification that the environment is usable
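
Run the script from the project root (it ships with the project as env.sh):

bash env.sh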

1.3 Verify the Environment

# Verify that torch_npu is usable
python -c "import torch; import torch_npu; print('NPU available:', torch_npu.npu.is_available())"

2. Model Download and Preparation

2.1 Automated Model Download

# Note: recent versions of huggingface_hub download this model very slowly here; pinning an older version works around it
pip install huggingface_hub==0.19.0
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download --resume-download baidu/ERNIE-4.5-21B-A3B-PT --local-dir ./baidu/ERNIE-4.5-21B-A3B-PT
pip install huggingface_hub --upgrade  # restore the latest huggingface_hub afterwards
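
If the CLI route is inconvenient, the same download can be done from Python with huggingface_hub's snapshot_download. A minimal sketch, assuming HF_ENDPOINT is exported as above:

# Python download sketch (same repo and target directory as the CLI command above)
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="baidu/ERNIE-4.5-21B-A3B-PT",
    local_dir="./baidu/ERNIE-4.5-21B-A3B-PT",
    resume_download=True,
)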

2.2 Verify the Model Files

# Check the model directory structure
ls -la ./baidu/ERNIE-4.5-21B-A3B-PT/

# Confirm the key files exist
ls -la ./baidu/ERNIE-4.5-21B-A3B-PT/config.json
ls -la ./baidu/ERNIE-4.5-21B-A3B-PT/model.safetensors.index.json
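
As an extra sanity check, a short sketch (not part of the project) can read config.json and print a few fields to confirm the download is intact:

import json

# Print a few fields from the model config as a quick integrity check (illustrative only)
with open("./baidu/ERNIE-4.5-21B-A3B-PT/config.json") as f:
    cfg = json.load(f)
print(cfg.get("model_type"), cfg.get("hidden_size"), cfg.get("num_hidden_layers"))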

3. API Service Deployment

3.1 Service Architecture

ernie_api.py implements a complete OpenAI-compatible API service. Key features:

  • OpenAI compatibility: standard /v1/chat/completions endpoint
  • Concurrency: a multi-threaded queue supports concurrent inference requests
  • Streaming: Server-Sent Events (SSE) streaming responses
  • NPU support: the inference pipeline runs on Huawei Ascend NPUs
  • Health check: an endpoint for monitoring service status

The full implementation of ernie_api.py is shown below:

import asyncio
import time
import uuid
import json
from typing import List, Optional, Dict, Any, AsyncGenerator
from datetime import datetime
from queue import Queue
from threading import Thread, Lock

import torch_npu
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn

# OpenAI API兼容的数据模型
class Message(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    model: str = "ernie-4.5"
    messages: List[Message]
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    stream: Optional[bool] = False

class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]

class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str = "baidu"

# 全局变量
app = FastAPI(title="ERNIE API Server", version="1.0.0")
model = None
tokenizer = None
request_queue = Queue()
response_dict = {}
response_lock = Lock()
stream_dict = {}
stream_lock = Lock()

# CORS中间件
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def load_model():
    """加载模型和分词器"""
    global model, tokenizer
    model_name = "./baidu/ERNIE-4.5-21B-A3B-PT"
    device = "npu:0"
    
    print("正在加载模型...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).to(device)
    print("模型加载完成")

def process_request(request_id: str, request_data: ChatCompletionRequest):
    """处理单个请求"""
    try:
        # 准备输入
        messages = [{"role": msg.role, "content": msg.content} for msg in request_data.messages]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to("npu")
        
        # 生成文本
        with torch.no_grad():
            generated_ids = model.generate(
                model_inputs.input_ids,
                max_new_tokens=request_data.max_tokens,
                temperature=request_data.temperature,
                top_p=request_data.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
        generated_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
        
        # 构造响应
        response = ChatCompletionResponse(
            id=request_id,
            created=int(time.time()),
            model=request_data.model,
            choices=[
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": generated_text
                    },
                    "finish_reason": "stop"
                }
            ],
            usage={
                "prompt_tokens": len(model_inputs.input_ids[0]),
                "completion_tokens": len(output_ids),
                "total_tokens": len(model_inputs.input_ids[0]) + len(output_ids)
            }
        )
        
        # 存储响应
        with response_lock:
            response_dict[request_id] = response
            
    except Exception as e:
        with response_lock:
            response_dict[request_id] = {"error": str(e)}

def process_stream_request(request_id: str, request_data: ChatCompletionRequest):
    """处理流式请求"""
    try:
        # 准备输入
        messages = [{"role": msg.role, "content": msg.content} for msg in request_data.messages]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to("npu")
        
        # 创建流式生成器
        streamer = TextIteratorStreamer(
            tokenizer, 
            timeout=60.0, 
            skip_prompt=True, 
            skip_special_tokens=True
        )
        
        # 在新线程中生成文本
        generation_kwargs = {
            "input_ids": model_inputs.input_ids,
            "max_new_tokens": request_data.max_tokens,
            "temperature": request_data.temperature,
            "top_p": request_data.top_p,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer
        }
        
        def generate():
            with torch.no_grad():
                model.generate(**generation_kwargs)
        
        thread = Thread(target=generate)
        thread.start()
        
        # 存储流式生成器
        with stream_lock:
            stream_dict[request_id] = {
                "streamer": streamer,
                "thread": thread,
                "created": int(time.time()),
                "model": request_data.model
            }
            
    except Exception as e:
        with stream_lock:
            stream_dict[request_id] = {"error": str(e)}

def worker():
    """工作线程处理队列中的请求"""
    while True:
        try:
            request_id, request_data = request_queue.get(timeout=1)
            if request_data.stream:
                process_stream_request(request_id, request_data)
            else:
                process_request(request_id, request_data)
            request_queue.task_done()
        except:
            continue

async def generate_stream_response(request_id: str) -> AsyncGenerator[str, None]:
    """生成流式响应"""
    with stream_lock:
        if request_id not in stream_dict:
            yield f"data: {{\"error\": \"Request not found\"}}\n\n"
            return
        
        stream_info = stream_dict[request_id]
        if "error" in stream_info:
            yield f"data: {{\"error\": \"{stream_info['error']}\"}}\n\n"
            return
    
    streamer = stream_info["streamer"]
    created = stream_info["created"]
    model_name = stream_info["model"]
    
    try:
        for new_text in streamer:
            if new_text:
                chunk = {
                    "id": request_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_name,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {
                                "content": new_text
                            },
                            "finish_reason": None
                        }
                    ]
                }
                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
                await asyncio.sleep(0.01)  # 小延迟避免过快发送
        
        # 发送结束标记
        final_chunk = {
            "id": request_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }
            ]
        }
        yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"
        
    except Exception as e:
        error_chunk = {
            "error": {
                "message": str(e),
                "type": "server_error"
            }
        }
        yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
    
    finally:
        # 清理资源
        with stream_lock:
            if request_id in stream_dict:
                del stream_dict[request_id]

@app.on_event("startup")
async def startup_event():
    """启动时加载模型并启动工作线程"""
    load_model()
    
    # 启动多个工作线程支持并发
    for i in range(2):  # 可以根据需要调整线程数
        thread = Thread(target=worker, daemon=True)
        thread.start()

@app.get("/v1/models")
async def list_models():
    """列出可用模型"""
    return {
        "object": "list",
        "data": [
            ModelInfo(
                id="ernie-4.5",
                created=int(time.time()),
                owned_by="baidu"
            )
        ]
    }

@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """聊天完成接口"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="模型尚未加载完成")
    
    request_id = str(uuid.uuid4())
    
    # 将请求加入队列
    request_queue.put((request_id, request))
    
    if request.stream:
        # 流式响应
        # 等待流式处理开始
        max_wait_time = 30
        start_time = time.time()
        
        while time.time() - start_time < max_wait_time:
            with stream_lock:
                if request_id in stream_dict:
                    break
            await asyncio.sleep(0.1)
        else:
            raise HTTPException(status_code=408, detail="流式请求初始化超时")
        
        return StreamingResponse(
            generate_stream_response(request_id),
            media_type="text/plain",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream"
            }
        )
    else:
        # 非流式响应
        max_wait_time = 300  # 最大等待时间5分钟
        start_time = time.time()
        
        while time.time() - start_time < max_wait_time:
            with response_lock:
                if request_id in response_dict:
                    response = response_dict.pop(request_id)
                    if "error" in response:
                        raise HTTPException(status_code=500, detail=response["error"])
                    return response
            
            await asyncio.sleep(0.1)
        
        raise HTTPException(status_code=408, detail="请求超时")

@app.get("/health")
async def health_check():
    """健康检查"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "queue_size": request_queue.qsize(),
        "active_streams": len(stream_dict)
    }

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=6006)

3.2 Start the API Service

# Start the service (default port 6006)
python ernie_api.py

After startup the service prints:

正在加载模型...
模型加载完成
INFO:     Started server process [PID]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:6006

3.3 Service Configuration

Key configuration parameters (in the load_model function):

def load_model():
    global model, tokenizer
    model_name = "./baidu/ERNIE-4.5-21B-A3B-PT"  # 模型路径
    device = "npu:0"  # NPU设备
    
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        trust_remote_code=True,
        torch_dtype=torch.float16  # half precision to save NPU memory
    ).to(device)

3.4 API Endpoints

Health check

curl http://localhost:6006/health
{"status":"healthy","model_loaded":true,"queue_size":0,"active_streams":0}

Model list

curl http://localhost:6006/v1/models
{"object":"list","data":[{"id":"ernie-4.5","object":"model","created":1751462626,"owned_by":"baidu"}]}

Chat completion (non-streaming)

curl -X POST http://localhost:6006/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "ernie-4.5",
    "messages": [
      {"role": "user", "content": "你好,请介绍一下你自己"}
    ],
    "max_tokens": 512,
    "temperature": 0.7
  }'

Chat completion (streaming)

curl -X POST http://localhost:6006/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "ernie-4.5",
    "messages": [
      {"role": "user", "content": "请详细介绍Python编程语言"}
    ],
    "max_tokens": 512,
    "temperature": 0.7,
    "stream": true
  }'
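
Because the service follows the OpenAI schema, the official openai Python client can also be pointed at it. A minimal sketch, assuming the openai package is installed (the api_key value is arbitrary, since this server does not validate it):

from openai import OpenAI

# Point the standard OpenAI client at the local ERNIE service
client = OpenAI(base_url="http://localhost:6006/v1", api_key="dummy-key")

resp = client.chat.completions.create(
    model="ernie-4.5",
    messages=[{"role": "user", "content": "你好,请介绍一下你自己"}],
    max_tokens=256,
)
print(resp.choices[0].message.content)

# Streaming should also work against the SSE endpoint
for chunk in client.chat.completions.create(
    model="ernie-4.5",
    messages=[{"role": "user", "content": "请用一句话介绍Python"}],
    max_tokens=128,
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="", flush=True)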

4. Functional Testing

4.1 Basic Functional Tests

Use test.py to run the full test suite:

import requests
import json
import time
import asyncio
import aiohttp
from concurrent.futures import ThreadPoolExecutor

# 服务器配置
BASE_URL = "http://localhost:6006"

def test_health_check():
    """测试健康检查"""
    print("=== 健康检查测试 ===")
    response = requests.get(f"{BASE_URL}/health")
    print(f"状态码: {response.status_code}")
    print(f"响应: {response.json()}")
    print()

def test_list_models():
    """测试模型列表"""
    print("=== 模型列表测试 ===")
    response = requests.get(f"{BASE_URL}/v1/models")
    print(f"状态码: {response.status_code}")
    print(f"响应: {json.dumps(response.json(), indent=2, ensure_ascii=False)}")
    print()

def test_single_chat():
    """测试单次聊天"""
    print("=== 单次聊天测试 ===")
    
    payload = {
        "model": "ernie-4.5",
        "messages": [
            {"role": "user", "content": "你好,请介绍一下你自己"}
        ],
        "max_tokens": 512,
        "temperature": 0.7
    }
    
    start_time = time.time()
    response = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload)
    end_time = time.time()
    
    print(f"状态码: {response.status_code}")
    print(f"响应时间: {end_time - start_time:.2f}秒")
    
    if response.status_code == 200:
        result = response.json()
        print(f"响应: {json.dumps(result, indent=2, ensure_ascii=False)}")
    else:
        print(f"错误: {response.text}")
    print()

def single_request(session, request_id):
    """单个请求函数"""
    payload = {
        "model": "ernie-4.5",
        "messages": [
            {"role": "user", "content": f"请用一句话介绍Python编程语言(请求{request_id})"}
        ],
        "max_tokens": 256,
        "temperature": 0.7
    }
    
    start_time = time.time()
    response = session.post(f"{BASE_URL}/v1/chat/completions", json=payload)
    end_time = time.time()
    
    return {
        "request_id": request_id,
        "status_code": response.status_code,
        "response_time": end_time - start_time,
        "success": response.status_code == 200
    }

def test_concurrent_requests():
    """测试并发请求"""
    print("=== 并发请求测试 ===")
    
    num_requests = 5
    
    with requests.Session() as session:
        with ThreadPoolExecutor(max_workers=num_requests) as executor:
            start_time = time.time()
            futures = [executor.submit(single_request, session, i+1) for i in range(num_requests)]
            results = [future.result() for future in futures]
            end_time = time.time()
    
    print(f"总耗时: {end_time - start_time:.2f}秒")
    print(f"成功请求: {sum(1 for r in results if r['success'])}/{num_requests}")
    
    for result in results:
        print(f"请求{result['request_id']}: 状态码={result['status_code']}, 耗时={result['response_time']:.2f}秒")
    print()

async def async_single_request(session, request_id):
    """异步单个请求"""
    payload = {
        "model": "ernie-4.5",
        "messages": [
            {"role": "user", "content": f"请简单介绍一下人工智能(异步请求{request_id})"}
        ],
        "max_tokens": 256,
        "temperature": 0.7
    }
    
    start_time = time.time()
    async with session.post(f"{BASE_URL}/v1/chat/completions", json=payload) as response:
        end_time = time.time()
        return {
            "request_id": request_id,
            "status_code": response.status,
            "response_time": end_time - start_time,
            "success": response.status == 200
        }

async def test_async_concurrent_requests():
    """测试异步并发请求"""
    print("=== 异步并发请求测试 ===")
    
    num_requests = 3
    
    async with aiohttp.ClientSession() as session:
        start_time = time.time()
        tasks = [async_single_request(session, i+1) for i in range(num_requests)]
        results = await asyncio.gather(*tasks)
        end_time = time.time()
    
    print(f"总耗时: {end_time - start_time:.2f}秒")
    print(f"成功请求: {sum(1 for r in results if r['success'])}/{num_requests}")
    
    for result in results:
        print(f"异步请求{result['request_id']}: 状态码={result['status_code']}, 耗时={result['response_time']:.2f}秒")
    print()

def test_openai_compatibility():
    """测试OpenAI兼容性"""
    print("=== OpenAI兼容性测试 ===")
    
    # 模拟OpenAI客户端调用
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer dummy-key"  # 可选,当前实现不验证
    }
    
    payload = {
        "model": "ernie-4.5",
        "messages": [
            {"role": "system", "content": "你是一个有用的AI助手。"},
            {"role": "user", "content": "解释一下什么是机器学习?"}
        ],
        "max_tokens": 512,
        "temperature": 0.8,
        "top_p": 0.9
    }
    
    response = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, headers=headers)
    
    print(f"状态码: {response.status_code}")
    if response.status_code == 200:
        result = response.json()
        print("OpenAI格式响应结构验证:")
        print(f"- id: {result.get('id')}")
        print(f"- object: {result.get('object')}")
        print(f"- model: {result.get('model')}")
        print(f"- choices数量: {len(result.get('choices', []))}")
        print(f"- usage: {result.get('usage')}")
        print(f"- 生成内容: {result['choices'][0]['message']['content'][:100]}...")
    else:
        print(f"错误: {response.text}")
    print()

def main():
    """主测试函数"""
    print("开始测试ERNIE API服务...")
    print("请确保服务已在 http://localhost:6006 启动\n")
    
    # 等待服务启动
    print("等待服务启动...")
    for i in range(30):
        try:
            response = requests.get(f"{BASE_URL}/health", timeout=2)
            if response.status_code == 200:
                print("服务已启动\n")
                break
        except:
            time.sleep(2)
    else:
        print("服务启动超时,请检查服务是否正常运行")
        return
    
    # 运行测试
    test_health_check()
    test_list_models()
    test_single_chat()
    test_concurrent_requests()
    
    # 运行异步测试
    asyncio.run(test_async_concurrent_requests())
    
    test_openai_compatibility()
    
    print("所有测试完成!")

if __name__ == "__main__":
    main()

The tests cover (the run command appears after this list):

  • Health check
  • Model listing
  • Single chat completion
  • Concurrent requests
  • Async concurrent requests
  • OpenAI compatibility
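
With the API service running on port 6006, start the suite with:

python test.py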

4.2 Streaming Tests

Use test_stream.py to test the streaming output:

import requests
import json
import time
import asyncio
import aiohttp
from concurrent.futures import ThreadPoolExecutor

# 服务器配置
BASE_URL = "http://localhost:6006"

def test_stream_chat():
    """测试流式聊天"""
    print("=== 流式聊天测试 ===")
    
    payload = {
        "model": "ernie-4.5",
        "messages": [
            {"role": "user", "content": "请详细介绍一下Python编程语言的特点和应用领域"}
        ],
        "max_tokens": 512,
        "temperature": 0.7,
        "stream": True
    }
    
    start_time = time.time()
    
    with requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, stream=True) as response:
        print(f"状态码: {response.status_code}")
        
        if response.status_code == 200:
            print("流式响应内容:")
            full_content = ""
            
            for line in response.iter_lines():
                if line:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        data = line[6:]  # 移除 'data: ' 前缀
                        
                        if data == '[DONE]':
                            print("\n[流式响应完成]")
                            break
                        
                        try:
                            chunk = json.loads(data)
                            if 'choices' in chunk and len(chunk['choices']) > 0:
                                delta = chunk['choices'][0].get('delta', {})
                                content = delta.get('content', '')
                                if content:
                                    print(content, end='', flush=True)
                                    full_content += content
                        except json.JSONDecodeError:
                            continue
            
            end_time = time.time()
            print(f"\n\n总耗时: {end_time - start_time:.2f}秒")
            print(f"生成内容长度: {len(full_content)}字符")
        else:
            print(f"错误: {response.text}")
    print()

def test_non_stream_chat():
    """测试非流式聊天对比"""
    print("=== 非流式聊天测试(对比) ===")
    
    payload = {
        "model": "ernie-4.5",
        "messages": [
            {"role": "user", "content": "请简单介绍一下人工智能"}
        ],
        "max_tokens": 256,
        "temperature": 0.7,
        "stream": False
    }
    
    start_time = time.time()
    response = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload)
    end_time = time.time()
    
    print(f"状态码: {response.status_code}")
    print(f"响应时间: {end_time - start_time:.2f}秒")
    
    if response.status_code == 200:
        result = response.json()
        content = result['choices'][0]['message']['content']
        print(f"生成内容: {content}")
        print(f"Token使用: {result['usage']}")
    else:
        print(f"错误: {response.text}")
    print()

async def test_async_stream():
    """测试异步流式请求"""
    print("=== 异步流式测试 ===")
    
    payload = {
        "model": "ernie-4.5",
        "messages": [
            {"role": "user", "content": "请用一段话描述机器学习的基本概念"}
        ],
        "max_tokens": 300,
        "temperature": 0.8,
        "stream": True
    }
    
    async with aiohttp.ClientSession() as session:
        start_time = time.time()
        async with session.post(f"{BASE_URL}/v1/chat/completions", json=payload) as response:
            print(f"状态码: {response.status}")
            
            if response.status == 200:
                print("异步流式响应:")
                full_content = ""
                
                async for line in response.content:
                    line = line.decode('utf-8').strip()
                    if line.startswith('data: '):
                        data = line[6:]
                        
                        if data == '[DONE]':
                            print("\n[异步流式响应完成]")
                            break
                        
                        try:
                            chunk = json.loads(data)
                            if 'choices' in chunk and len(chunk['choices']) > 0:
                                delta = chunk['choices'][0].get('delta', {})
                                content = delta.get('content', '')
                                if content:
                                    print(content, end='', flush=True)
                                    full_content += content
                        except json.JSONDecodeError:
                            continue
                
                end_time = time.time()
                print(f"\n\n异步总耗时: {end_time - start_time:.2f}秒")
                print(f"生成内容长度: {len(full_content)}字符")
            else:
                print(f"错误: {await response.text()}")
    print()

def test_concurrent_streams():
    """测试并发流式请求"""
    print("=== 并发流式测试 ===")
    
    def single_stream_request(request_id):
        payload = {
            "model": "ernie-4.5",
            "messages": [
                {"role": "user", "content": f"请简单介绍编程语言的发展历史(请求{request_id})"}
            ],
            "max_tokens": 200,
            "temperature": 0.7,
            "stream": True
        }
        
        start_time = time.time()
        try:
            with requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, stream=True, timeout=60) as response:
                if response.status_code == 200:
                    content_length = 0
                    for line in response.iter_lines():
                        if line:
                            line = line.decode('utf-8')
                            if line.startswith('data: '):
                                data = line[6:]
                                if data == '[DONE]':
                                    break
                                try:
                                    chunk = json.loads(data)
                                    if 'choices' in chunk and len(chunk['choices']) > 0:
                                        delta = chunk['choices'][0].get('delta', {})
                                        content = delta.get('content', '')
                                        content_length += len(content)
                                except json.JSONDecodeError:
                                    continue
                    
                    end_time = time.time()
                    return {
                        "request_id": request_id,
                        "success": True,
                        "response_time": end_time - start_time,
                        "content_length": content_length
                    }
                else:
                    return {
                        "request_id": request_id,
                        "success": False,
                        "error": response.text
                    }
        except Exception as e:
            return {
                "request_id": request_id,
                "success": False,
                "error": str(e)
            }
    
    num_requests = 3
    
    with ThreadPoolExecutor(max_workers=num_requests) as executor:
        start_time = time.time()
        futures = [executor.submit(single_stream_request, i+1) for i in range(num_requests)]
        results = [future.result() for future in futures]
        end_time = time.time()
    
    print(f"并发流式请求总耗时: {end_time - start_time:.2f}秒")
    print(f"成功请求: {sum(1 for r in results if r['success'])}/{num_requests}")
    
    for result in results:
        if result['success']:
            print(f"流式请求{result['request_id']}: 耗时={result['response_time']:.2f}秒, 内容长度={result['content_length']}字符")
        else:
            print(f"流式请求{result['request_id']}: 失败 - {result['error']}")
    print()

def main():
    """主测试函数"""
    print("开始测试ERNIE API流式服务...")
    print("请确保服务已在 http://localhost:6006 启动\n")
    
    # 等待服务启动
    print("等待服务启动...")
    for i in range(30):
        try:
            response = requests.get(f"{BASE_URL}/health", timeout=2)
            if response.status_code == 200:
                print("服务已启动\n")
                break
        except:
            time.sleep(2)
    else:
        print("服务启动超时,请检查服务是否正常运行")
        return
    
    # 运行测试
    test_stream_chat()
    test_non_stream_chat()
    
    # 运行异步测试
    asyncio.run(test_async_stream())
    
    test_concurrent_streams()
    
    print("所有流式测试完成!")

if __name__ == "__main__":
    main()

The tests cover:

  • Streaming chat
  • Non-streaming comparison
  • Async streaming
  • Concurrent streaming

4.3 Performance Benchmarking

# Measure single-request inference latency
time curl -X POST http://localhost:6006/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "ernie-4.5", "messages": [{"role": "user", "content": "你好"}], "max_tokens": 100}'

5. Performance Tuning Suggestions

  1. Model quantization: consider INT4/INT8 quantization to reduce memory usage
  2. Batching: tune the batch size to balance latency against throughput
  3. Caching: rely on the KV cache to avoid recomputation across generation steps
  4. Multi-card deployment: use model parallelism to increase serving capacity (see the sketch after this list)
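
For item 4, one possible starting point is letting transformers/accelerate shard the model across several NPUs via device_map. This is only a sketch under the assumption that your accelerate and torch_npu versions support NPU device maps; verify it on your CANN stack before relying on it (the max_memory limits below are hypothetical):

import torch
import torch_npu  # noqa: F401  (registers the NPU backend)
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "./baidu/ERNIE-4.5-21B-A3B-PT"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",                    # spread layers over all visible NPUs
    max_memory={0: "28GiB", 1: "28GiB"},  # hypothetical per-card limits
)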