第6章 工具(Tools)开发

3 阅读13分钟

第6章 工具(Tools)开发

前言

工具是MCP最常用的功能。本章将深入讲解如何设计和实现高质量的MCP工具,从参数定义到错误处理,再到性能优化。


6.1 工具的定义与规范

6.1.1 工具的完整结构

graph TB
    A["MCP工具完整结构"] --> B["元数据层"]
    A --> C["接口层"]
    A --> D["实现层"]
    
    B --> B1["名称 name"]
    B --> B2["描述 description"]
    B --> B3["分类 category"]
    B --> B4["版本 version"]
    
    C --> C1["输入参数 inputSchema"]
    C --> C2["输出格式 outputSchema"]
    C --> C3["错误类型 errors"]
    
    D --> D1["执行函数 handler"]
    D --> D2["参数验证 validation"]
    D --> D3["错误处理 error_handling"]
    D --> D4["日志记录 logging"]

6.1.2 工具的完整定义

标准工具定义格式

{
  "name": "query_sales_data",
  "description": "查询销售数据。支持按时间范围、产品类别、地区等条件过滤,返回销售总额、订单数、平均值等多维度数据。",
  "category": "data_query",
  "version": "1.0.0",
  "inputSchema": {
    "type": "object",
    "properties": {
      "start_date": {
        "type": "string",
        "format": "date",
        "description": "查询开始日期,格式YYYY-MM-DD。例如2025-01-01"
      },
      "end_date": {
        "type": "string",
        "format": "date",
        "description": "查询结束日期,格式YYYY-MM-DD。例如2025-12-31"
      },
      "category": {
        "type": "string",
        "enum": ["electronics", "clothing", "food"],
        "description": "产品类别,不指定则返回所有类别"
      },
      "min_amount": {
        "type": "number",
        "minimum": 0,
        "description": "订单最小金额过滤"
      }
    },
    "required": ["start_date", "end_date"],
    "additionalProperties": false
  },
  "outputSchema": {
    "type": "object",
    "properties": {
      "total_sales": { "type": "number" },
      "order_count": { "type": "integer" },
      "average_order_value": { "type": "number" },
      "breakdown_by_category": {
        "type": "object",
        "additionalProperties": { "type": "number" }
      }
    }
  }
}

关键字段解释

字段类型说明示例
namestring工具唯一标识符,用蛇形命名query_sales_data
descriptionstring详细描述,包括用途和能力查询销售数据...
categorystring工具分类便于管理data_query
inputSchemaJSON Schema参数定义和约束见上面示例
outputSchemaJSON Schema返回值结构见上面示例

6.1.3 JSON Schema最佳实践

参数验证规则

# Python中的参数定义示例
from pydantic import BaseModel, Field, validator
from typing import Optional, List
from datetime import date

class QuerySalesRequest(BaseModel):
    """销售数据查询请求"""
    start_date: date = Field(..., description="开始日期")
    end_date: date = Field(..., description="结束日期")
    category: Optional[str] = Field(None, description="产品类别")
    min_amount: Optional[float] = Field(None, ge=0, description="最小金额")
    
    @validator('end_date')
    def validate_date_range(cls, v, values):
        """验证日期范围"""
        if 'start_date' in values and v < values['start_date']:
            raise ValueError('end_date must be after start_date')
        return v
    
    @validator('start_date', 'end_date')
    def validate_date_not_future(cls, v):
        """不能查询未来数据"""
        if v > date.today():
            raise ValueError('Cannot query future dates')
        return v

class QuerySalesResponse(BaseModel):
    """销售数据查询响应"""
    total_sales: float
    order_count: int
    average_order_value: float
    breakdown_by_category: dict

6.2 工具开发最佳实践

6.2.1 清晰的工具命名与描述

命名规范

{动词}_{对象}_{修饰词}

示例:
✅ query_sales_data        - 查询销售数据
✅ create_order            - 创建订单
✅ update_customer_profile - 更新客户档案
❌ getData                  - 不清晰
❌ tool_1                   - 不描述性

好的描述示例

# ❌ 不好的描述
"description": "Query data"

# ✅ 好的描述
"description": """查询销售数据库中的订单和销售额信息。

支持的功能:
- 按时间范围过滤(支持天/周/月粒度)
- 按产品类别过滤(电子产品、服装、食品)
- 按金额范围过滤

返回的数据包括:
- 总销售额
- 订单数
- 平均订单价值
- 各类别的明细数据

性能:查询最多支持过去2年的数据,大规模查询可能需要10-30秒。
"""

6.2.2 参数校验与错误处理

完整的参数校验框架 (Python):

from typing import Any, Dict
from datetime import datetime, date
import logging

logger = logging.getLogger(__name__)

class ToolParameterValidator:
    """工具参数校验器"""
    
    @staticmethod
    def validate_date_range(start_date: str, end_date: str, max_days: int = 730) -> tuple:
        """
        校验日期范围
        
        Args:
            start_date: 开始日期字符串 (YYYY-MM-DD)
            end_date: 结束日期字符串 (YYYY-MM-DD)
            max_days: 最大查询天数
            
        Returns:
            (start_date, end_date) 元组
            
        Raises:
            ValueError: 日期格式或范围不合法
        """
        try:
            start = datetime.strptime(start_date, "%Y-%m-%d").date()
            end = datetime.strptime(end_date, "%Y-%m-%d").date()
        except ValueError as e:
            raise ValueError(f"日期格式错误,需要YYYY-MM-DD格式: {e}")
        
        # 检查日期顺序
        if start > end:
            raise ValueError(f"开始日期({start})不能晚于结束日期({end})")
        
        # 检查是否为未来日期
        today = date.today()
        if start > today:
            raise ValueError(f"开始日期不能为未来日期")
        
        # 检查查询范围
        days_diff = (end - start).days
        if days_diff > max_days:
            raise ValueError(f"查询范围不能超过{max_days}天,您查询了{days_diff}天")
        
        return start, end
    
    @staticmethod
    def validate_enum(value: str, allowed_values: list, field_name: str) -> str:
        """校验枚举值"""
        if value not in allowed_values:
            raise ValueError(
                f"{field_name}必须是以下值之一: {', '.join(allowed_values)}, "
                f"得到: {value}"
            )
        return value
    
    @staticmethod
    def validate_numeric_range(value: float, min_val: float = None, 
                              max_val: float = None, field_name: str = "value") -> float:
        """校验数值范围"""
        if min_val is not None and value < min_val:
            raise ValueError(f"{field_name}不能小于{min_val},得到{value}")
        if max_val is not None and value > max_val:
            raise ValueError(f"{field_name}不能大于{max_val},得到{value}")
        return value

错误处理框架

class ToolError(Exception):
    """工具执行错误基类"""
    pass

class ToolParameterError(ToolError):
    """参数错误"""
    pass

class ToolExecutionError(ToolError):
    """执行错误"""
    pass

class ToolTimeoutError(ToolError):
    """超时错误"""
    pass


async def execute_tool_with_error_handling(tool_name: str, args: dict, handler):
    """
    带完善错误处理的工具执行器
    
    Args:
        tool_name: 工具名称
        args: 工具参数
        handler: 工具处理函数
        
    Returns:
        工具执行结果或错误信息
    """
    try:
        logger.info(f"Executing tool: {tool_name} with args: {args}")
        result = await handler(**args)
        logger.info(f"Tool {tool_name} executed successfully")
        return {
            "content": [{"type": "text", "text": str(result)}],
            "is_error": False
        }
    
    except ToolParameterError as e:
        logger.warning(f"Parameter error in {tool_name}: {e}")
        return {
            "content": [{"type": "text", "text": f"参数错误: {str(e)}"}],
            "is_error": True
        }
    
    except ToolTimeoutError as e:
        logger.error(f"Timeout in {tool_name}: {e}")
        return {
            "content": [{"type": "text", "text": f"执行超时: {str(e)}"}],
            "is_error": True
        }
    
    except ToolExecutionError as e:
        logger.error(f"Execution error in {tool_name}: {e}")
        return {
            "content": [{"type": "text", "text": f"执行错误: {str(e)}"}],
            "is_error": True
        }
    
    except Exception as e:
        logger.exception(f"Unexpected error in {tool_name}")
        return {
            "content": [{"type": "text", "text": f"未知错误: {str(e)}"}],
            "is_error": True
        }

6.2.3 工具文档生成

自动生成工具文档

from typing import Callable
import inspect
import json

def generate_tool_doc(handler: Callable, name: str = None) -> dict:
    """
    从函数自动生成工具定义
    
    Args:
        handler: 工具处理函数
        name: 工具名称(默认使用函数名)
        
    Returns:
        工具定义字典
    """
    name = name or handler.__name__
    doc = inspect.getdoc(handler) or ""
    sig = inspect.signature(handler)
    
    # 构建inputSchema
    properties = {}
    required = []
    
    for param_name, param in sig.parameters.items():
        if param_name == 'self':
            continue
        
        param_type = param.annotation if param.annotation != inspect.Parameter.empty else str
        
        # 将Python类型转换为JSON Schema类型
        json_type_map = {
            str: "string",
            int: "integer",
            float: "number",
            bool: "boolean",
            list: "array",
            dict: "object"
        }
        
        properties[param_name] = {
            "type": json_type_map.get(param_type, "string"),
            "description": f"参数: {param_name}"
        }
        
        # 必需参数(没有默认值)
        if param.default == inspect.Parameter.empty:
            required.append(param_name)
    
    return {
        "name": name,
        "description": doc.split('\n')[0] if doc else "",
        "inputSchema": {
            "type": "object",
            "properties": properties,
            "required": required
        }
    }


# 使用示例
def query_balance(account_id: str, currency: str = "USD") -> dict:
    """
    查询账户余额
    
    Args:
        account_id: 账户ID
        currency: 货币类型
        
    Returns:
        包含余额的字典
    """
    pass

# 自动生成文档
tool_doc = generate_tool_doc(query_balance)
print(json.dumps(tool_doc, indent=2))
# 输出:
# {
#   "name": "query_balance",
#   "description": "查询账户余额",
#   "inputSchema": {
#     "type": "object",
#     "properties": {
#       "account_id": {"type": "string", "description": "参数: account_id"},
#       "currency": {"type": "string", "description": "参数: currency"}
#     },
#     "required": ["account_id"]
#   }
# }

6.3 常见工具类型实现

6.3.1 查询工具

import asyncio
from typing import List, Dict, Any
from datetime import date, timedelta

class QueryTool:
    """查询型工具基类"""
    
    async def query_sales_by_region(self, region: str, start_date: str, 
                                   end_date: str) -> Dict[str, Any]:
        """
        按地区查询销售数据
        
        Args:
            region: 地区代码 (CN, US, EU, etc)
            start_date: 开始日期
            end_date: 结束日期
            
        Returns:
            销售数据
        """
        # 参数验证
        validator = ToolParameterValidator()
        start, end = validator.validate_date_range(start_date, end_date)
        region = validator.validate_enum(region, ["CN", "US", "EU", "JP"], "region")
        
        # 构建查询(这里使用模拟数据)
        query = f"""
            SELECT region, SUM(amount) as total, COUNT(*) as count
            FROM sales
            WHERE region = %s AND date BETWEEN %s AND %s
            GROUP BY region
        """
        
        # 执行数据库查询(模拟)
        result = await self._execute_query(query, (region, start, end))
        
        return {
            "region": region,
            "period": f"{start} to {end}",
            "total_sales": result.get("total", 0),
            "order_count": result.get("count", 0)
        }
    
    async def _execute_query(self, query: str, params: tuple) -> Dict:
        """执行数据库查询(模拟)"""
        # 实际应用中使用真实数据库连接
        await asyncio.sleep(0.1)  # 模拟查询延迟
        return {
            "total": 50000,
            "count": 100
        }

6.3.2 操作工具

class OperationTool:
    """操作型工具(有副作用)"""
    
    async def create_order(self, customer_id: str, items: List[Dict], 
                          shipping_address: str) -> Dict[str, Any]:
        """
        创建订单
        
        Args:
            customer_id: 客户ID
            items: 商品列表 [{"product_id": "...", "quantity": 2}]
            shipping_address: 收货地址
            
        Returns:
            订单信息
        """
        # 参数验证
        if not customer_id:
            raise ToolParameterError("customer_id不能为空")
        
        if not items or len(items) == 0:
            raise ToolParameterError("必须至少包含一个商品")
        
        if not shipping_address:
            raise ToolParameterError("shipping_address不能为空")
        
        try:
            # 验证库存
            for item in items:
                stock = await self._check_stock(item["product_id"], item["quantity"])
                if not stock:
                    raise ToolExecutionError(
                        f"商品{item['product_id']}库存不足,无法创建订单"
                    )
            
            # 创建订单
            order_id = await self._persist_order({
                "customer_id": customer_id,
                "items": items,
                "shipping_address": shipping_address,
                "status": "created"
            })
            
            # 更新库存
            for item in items:
                await self._update_stock(item["product_id"], -item["quantity"])
            
            return {
                "order_id": order_id,
                "status": "created",
                "customer_id": customer_id,
                "total_items": sum(item["quantity"] for item in items)
            }
        
        except Exception as e:
            logger.error(f"Failed to create order: {e}")
            raise ToolExecutionError(f"订单创建失败: {str(e)}")
    
    async def _check_stock(self, product_id: str, quantity: int) -> bool:
        """检查库存"""
        # 模拟库存检查
        return True
    
    async def _persist_order(self, order_data: dict) -> str:
        """持久化订单"""
        # 模拟持久化
        return f"ORD-{id(order_data)}"
    
    async def _update_stock(self, product_id: str, change: int):
        """更新库存"""
        # 模拟库存更新
        pass

6.3.3 计算工具

class ComputeTool:
    """计算型工具"""
    
    async def analyze_sales_trend(self, sales_data: List[Dict]) -> Dict[str, Any]:
        """
        分析销售趋势
        
        Args:
            sales_data: 销售数据列表
            
        Returns:
            分析结果
        """
        if not sales_data:
            raise ToolParameterError("sales_data不能为空")
        
        # 计算统计指标
        amounts = [d["amount"] for d in sales_data]
        
        total = sum(amounts)
        count = len(amounts)
        average = total / count if count > 0 else 0
        
        # 计算趋势
        trend = self._calculate_trend(sales_data)
        
        # 计算增长率
        growth_rate = self._calculate_growth_rate(sales_data)
        
        return {
            "total": total,
            "count": count,
            "average": average,
            "trend": trend,
            "growth_rate": growth_rate,
            "recommendation": self._generate_recommendation(trend, growth_rate)
        }
    
    def _calculate_trend(self, data: List[Dict]) -> str:
        """计算趋势"""
        if len(data) < 2:
            return "insufficient_data"
        
        first_half_avg = sum(d["amount"] for d in data[:len(data)//2]) / (len(data)//2)
        second_half_avg = sum(d["amount"] for d in data[len(data)//2:]) / (len(data) - len(data)//2)
        
        if second_half_avg > first_half_avg * 1.1:
            return "upward"
        elif second_half_avg < first_half_avg * 0.9:
            return "downward"
        else:
            return "stable"
    
    def _calculate_growth_rate(self, data: List[Dict]) -> float:
        """计算增长率"""
        if len(data) < 2:
            return 0.0
        
        first = data[0]["amount"]
        last = data[-1]["amount"]
        
        return ((last - first) / first * 100) if first != 0 else 0
    
    def _generate_recommendation(self, trend: str, growth_rate: float) -> str:
        """生成建议"""
        if trend == "upward" and growth_rate > 20:
            return "销售表现良好,继续执行当前策略"
        elif trend == "downward":
            return "销售下滑,需要调整营销策略"
        else:
            return "销售平稳,建议进行创新尝试"

6.3.4 转换工具

class TransformTool:
    """转换型工具"""
    
    async def export_to_excel(self, data: List[Dict], filename: str) -> Dict[str, str]:
        """
        导出数据到Excel
        
        Args:
            data: 数据列表
            filename: 文件名
            
        Returns:
            导出结果
        """
        import os
        from datetime import datetime
        
        try:
            # 验证数据
            if not data:
                raise ToolParameterError("data不能为空")
            
            # 创建文件路径
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            file_path = f"/tmp/exports/{filename}_{timestamp}.xlsx"
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            
            # 模拟Excel导出
            await self._write_excel(file_path, data)
            
            return {
                "status": "success",
                "filename": os.path.basename(file_path),
                "path": file_path,
                "records": len(data)
            }
        
        except Exception as e:
            raise ToolExecutionError(f"Excel导出失败: {str(e)}")
    
    async def _write_excel(self, path: str, data: List[Dict]):
        """写入Excel文件"""
        # 实际应用中使用openpyxl或pandas
        pass

6.4 工具性能优化

6.4.1 缓存策略

import asyncio
from functools import wraps
from datetime import datetime, timedelta
from typing import Any, Callable

class CachedResult:
    """缓存结果"""
    def __init__(self, data: Any, ttl: int = 300):
        self.data = data
        self.created_at = datetime.now()
        self.ttl = ttl  # Time to live in seconds
    
    def is_expired(self) -> bool:
        """检查是否过期"""
        return (datetime.now() - self.created_at).total_seconds() > self.ttl


class ToolCacheManager:
    """工具缓存管理器"""
    
    def __init__(self, max_size: int = 100):
        self.cache: Dict[str, CachedResult] = {}
        self.max_size = max_size
    
    def _get_cache_key(self, tool_name: str, args: dict) -> str:
        """生成缓存键"""
        import hashlib
        import json
        
        key_str = f"{tool_name}:{json.dumps(args, sort_keys=True)}"
        return hashlib.md5(key_str.encode()).hexdigest()
    
    def get(self, tool_name: str, args: dict) -> Any:
        """获取缓存"""
        key = self._get_cache_key(tool_name, args)
        
        if key in self.cache:
            result = self.cache[key]
            if not result.is_expired():
                logger.info(f"Cache hit for {tool_name}")
                return result.data
            else:
                del self.cache[key]
        
        return None
    
    def set(self, tool_name: str, args: dict, data: Any, ttl: int = 300):
        """设置缓存"""
        # 简单的LRU策略
        if len(self.cache) >= self.max_size:
            # 删除最旧的缓存
            oldest_key = min(self.cache.keys(), 
                           key=lambda k: self.cache[k].created_at)
            del self.cache[oldest_key]
        
        key = self._get_cache_key(tool_name, args)
        self.cache[key] = CachedResult(data, ttl)
        logger.info(f"Cached result for {tool_name} with ttl {ttl}s")


def cached_tool(ttl: int = 300):
    """缓存装饰器"""
    cache_manager = ToolCacheManager()
    
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(self, *args, **kwargs):
            # 尝试获取缓存
            cached_result = cache_manager.get(func.__name__, kwargs)
            if cached_result is not None:
                return cached_result
            
            # 执行函数
            result = await func(self, *args, **kwargs)
            
            # 保存到缓存
            cache_manager.set(func.__name__, kwargs, result, ttl)
            
            return result
        
        return wrapper
    
    return decorator


# 使用示例
class OptimizedQueryTool:
    
    @cached_tool(ttl=3600)  # 缓存1小时
    async def get_annual_report(self, year: int) -> Dict:
        """获取年度报告(缓存)"""
        logger.info(f"Generating annual report for {year}")
        # 模拟耗时操作
        await asyncio.sleep(2)
        return {"year": year, "data": "..."}

6.4.2 异步处理

class AsyncToolExecutor:
    """异步工具执行器"""
    
    def __init__(self, max_concurrent: int = 10):
        self.max_concurrent = max_concurrent
        self.semaphore = asyncio.Semaphore(max_concurrent)
    
    async def execute_batch_tools(self, tools: List[Dict]) -> List[Dict]:
        """
        批量执行工具
        
        Args:
            tools: 工具列表,每个包含 {"name": "...", "args": {...}}
            
        Returns:
            执行结果列表
        """
        tasks = []
        
        for tool in tools:
            task = self._execute_single_tool(tool)
            tasks.append(task)
        
        # 并发执行
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        return [
            {"tool": tool["name"], "result": result}
            for tool, result in zip(tools, results)
        ]
    
    async def _execute_single_tool(self, tool: Dict):
        """执行单个工具"""
        async with self.semaphore:
            try:
                logger.info(f"Executing tool: {tool['name']}")
                # 这里调用实际的工具处理函数
                result = await self._call_tool_handler(tool)
                return result
            except Exception as e:
                logger.error(f"Tool {tool['name']} failed: {e}")
                return {"error": str(e)}
    
    async def _call_tool_handler(self, tool: Dict):
        """调用工具处理器"""
        # 模拟工具执行
        await asyncio.sleep(0.5)
        return {"status": "success"}

6.4.3 超时管理

class TimeoutManager:
    """超时管理器"""
    
    @staticmethod
    async def execute_with_timeout(coro, timeout: int = 30):
        """
        执行协程并设置超时
        
        Args:
            coro: 要执行的协程
            timeout: 超时时间(秒)
            
        Returns:
            执行结果
            
        Raises:
            ToolTimeoutError: 执行超时
        """
        try:
            result = await asyncio.wait_for(coro, timeout=timeout)
            return result
        except asyncio.TimeoutError:
            logger.error(f"Tool execution timeout after {timeout}s")
            raise ToolTimeoutError(
                f"工具执行超时({timeout}秒)。请尝试缩小查询范围或稍后重试。"
            )


# 使用示例
async def execute_tool_with_timeout(tool_name: str, handler, args: dict, 
                                    timeout: int = 30) -> Dict:
    """执行带超时的工具"""
    try:
        result = await TimeoutManager.execute_with_timeout(
            handler(**args),
            timeout=timeout
        )
        return {
            "tool": tool_name,
            "status": "success",
            "result": result
        }
    except ToolTimeoutError as e:
        return {
            "tool": tool_name,
            "status": "timeout",
            "error": str(e)
        }

本章总结

核心概念关键点
工具定义完整的元数据、接口、实现三层结构
命名规范动词_对象_修饰词,清晰有意义
参数验证JSON Schema + Pydantic的多层验证
错误处理自定义异常类、统一的错误处理框架
工具分类查询、操作、计算、转换四种类型
文档生成自动从函数生成工具定义
缓存策略LRU缓存 + TTL机制
异步处理信号量控制并发、批量执行
超时管理asyncio.wait_for保证可靠性

常见问题

Q1: 工具是否可以调用其他工具? A: 可以。工具可以依赖其他工具的结果,但要避免形成循环依赖。

Q2: 如何处理长时间运行的工具? A: 使用异步处理、设置合理超时、可选的分页结果、后台任务等方法。

Q3: 工具能否修改全局状态? A: 应避免。工具应该是幂等的,多次调用相同参数应返回相同结果。

Q4: 如何测试工具的正确性? A: 编写单元测试,测试正常情况、边界情况、错误情况,使用mock数据。

Q5: 生产环境中如何监控工具性能? A: 添加详细日志、记录执行时间、使用APM工具、设置性能告警。


实战要点

✅ 推荐做法

  • 为每个工具编写完整的单元测试
  • 提供详细清晰的工具描述和参数说明
  • 实现完善的参数验证和错误处理
  • 添加日志记录便于调试
  • 考虑缓存以提升性能
  • 设置合理的超时时间

❌ 避免的做法

  • 不要忽视参数验证
  • 不要让工具修改全局状态
  • 不要在工具中进行复杂的业务逻辑
  • 不要忘记错误处理
  • 不要设置过短的超时时间
  • 不要跳过工具的文档

延伸阅读


下一章预告:第7章将讲述如何开发MCP资源(Resources)——包括资源设计、生命周期管理、版本控制等内容。