Chapter 11: MCP Applications for Databases and Data Management



Introduction

Data is the lifeblood of the modern enterprise. This chapter shows how to build a unified data-access layer with MCP so that an LLM can interact with a variety of databases safely and efficiently, powering data analysis, report generation, decision support, and similar scenarios.


11.1 Case 1: A Unified Query Interface for Multiple Databases

11.1.1 Application Scenario

graph TB
    A["Enterprise data environment"] --> B["PostgreSQL"]
    A --> C["MySQL"]
    A --> D["MongoDB"]
    A --> E["Redis"]
    A --> F["Elasticsearch"]
    
    G["Traditional approach"] --> G1["Different drivers"]
    G --> G2["Duplicated code"]
    G --> G3["Hard to maintain"]
    
    H["MCP approach"] --> H1["Unified interface"]
    H --> H2["Consistent queries"]
    H --> H3["Access control"]
    H --> H4["Audit logging"]
    
    I["Claude access"] --> H
    H --> B
    H --> C
    H --> D
    H --> E
    H --> F

Typical scenarios

  • The enterprise runs multiple database systems (relational, document, cache, search)
  • Different departments maintain different databases
  • Unified data querying and analysis is needed
  • Database internals must not be exposed to end users

11.1.2 Architecture Design

import json
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

# The adapters below log through this module-level logger
logger = logging.getLogger(__name__)

@dataclass
class QueryResult:
    """查询结果"""
    database: str
    query: str
    rows: List[Dict]
    execution_time_ms: float
    row_count: int
    timestamp: datetime


class DatabaseAdapter(ABC):
    """Base class for database adapters."""
    
    def __init__(self, config: Dict):
        self.config = config
        self.connection = None
    
    @abstractmethod
    async def connect(self) -> bool:
        """Open a connection."""
        pass
    
    @abstractmethod
    async def disconnect(self):
        """Close the connection."""
        pass
    
    @abstractmethod
    async def execute_query(self, sql: str, parameters: Dict = None) -> QueryResult:
        """Execute a query."""
        pass
    
    @abstractmethod
    async def get_schema(self, table_name: str = None) -> List[Dict]:
        """Return schema information for a table, or list all tables."""


class PostgreSQLAdapter(DatabaseAdapter):
    """PostgreSQL adapter."""
    
    async def connect(self) -> bool:
        """Open a PostgreSQL connection."""
        import asyncpg
        
        try:
            self.connection = await asyncpg.connect(
                user=self.config['user'],
                password=self.config['password'],
                database=self.config['database'],
                host=self.config['host'],
                port=self.config.get('port', 5432)
            )
            logger.info("Connected to PostgreSQL")
            return True
        except Exception as e:
            logger.error(f"PostgreSQL connection failed: {e}")
            return False
    
    async def disconnect(self):
        """Close the connection."""
        if self.connection:
            await self.connection.close()
    
    async def execute_query(self, sql: str, parameters: Dict = None) -> QueryResult:
        """Execute a SQL query."""
        import time
        
        # Validate the SQL first (blocks destructive statements)
        if not self._validate_sql(sql):
            raise ValueError("SQL validation failed: query not allowed")
        
        start = time.time()
        
        try:
            if parameters:
                # asyncpg uses positional $1, $2, ... placeholders; this
                # relies on the dict's insertion order matching them
                rows = await self.connection.fetch(sql, *parameters.values())
            else:
                rows = await self.connection.fetch(sql)
            
            # Convert asyncpg Records to plain dicts
            result_rows = [dict(row) for row in rows]
            
            execution_time = (time.time() - start) * 1000
            
            return QueryResult(
                database="PostgreSQL",
                query=sql,
                rows=result_rows,
                execution_time_ms=execution_time,
                row_count=len(result_rows),
                timestamp=datetime.now()
            )
        
        except Exception as e:
            logger.error(f"Query execution failed: {e}")
            raise
    
    def _validate_sql(self, sql: str) -> bool:
        """Coarse SQL safety check: reject statements that begin with a
        destructive keyword. A prefix check alone is not sufficient
        protection; combine it with a read-only database role."""
        dangerous_keywords = ['DROP', 'DELETE', 'TRUNCATE', 'ALTER']
        sql_upper = sql.upper().strip()
        
        for keyword in dangerous_keywords:
            if sql_upper.startswith(keyword):
                return False
        
        return True
    
    async def get_schema(self, table_name: str = None) -> List[Dict]:
        """Return column metadata for a table, or list all public tables."""
        if table_name:
            # Parameterized to avoid SQL injection via the table name
            query = """
            SELECT column_name, data_type, is_nullable
            FROM information_schema.columns
            WHERE table_name = $1
            """
            rows = await self.connection.fetch(query, table_name)
        else:
            query = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'public'
            """
            rows = await self.connection.fetch(query)
        
        return [dict(row) for row in rows]


class MongoDBAdapter(DatabaseAdapter):
    """MongoDB adapter.
    
    Note: pymongo is a synchronous driver, so these calls block the
    event loop; for a fully async server consider the motor driver.
    """
    
    async def connect(self) -> bool:
        """Open a MongoDB connection."""
        from pymongo import MongoClient
        from pymongo.errors import ConnectionFailure
        
        try:
            connection_string = (
                f"mongodb://{self.config['user']}:{self.config['password']}"
                f"@{self.config['host']}:{self.config.get('port', 27017)}"
                f"/{self.config['database']}"
            )
            
            self.connection = MongoClient(connection_string)
            self.connection.admin.command('ping')  # verify connectivity
            
            logger.info("Connected to MongoDB")
            return True
        
        except ConnectionFailure as e:
            logger.error(f"MongoDB connection failed: {e}")
            return False
    
    async def disconnect(self):
        """Close the client."""
        if self.connection:
            self.connection.close()
    
    async def get_schema(self, table_name: str = None) -> List[Dict]:
        """MongoDB is schemaless; list collection names instead."""
        db = self.connection[self.config['database']]
        return [{"collection": name} for name in db.list_collection_names()]
    
    async def execute_query(self, query: Dict, collection: str = None) -> QueryResult:
        """Execute a MongoDB query."""
        import time
        
        if not collection:
            raise ValueError("Collection name required for MongoDB")
        
        start = time.time()
        
        try:
            db = self.connection[self.config['database']]
            collection_obj = db[collection]
            
            # Only {'$find': {...}} filter documents are supported here
            if isinstance(query, dict) and '$find' in query:
                cursor = collection_obj.find(query['$find'])
                rows = list(cursor)
            else:
                rows = []
            
            # ObjectId is not JSON-serializable; convert it to a string
            for row in rows:
                if '_id' in row:
                    row['_id'] = str(row['_id'])
            
            execution_time = (time.time() - start) * 1000
            
            return QueryResult(
                database="MongoDB",
                query=json.dumps(query),
                rows=rows,
                execution_time_ms=execution_time,
                row_count=len(rows),
                timestamp=datetime.now()
            )
        
        except Exception as e:
            logger.error(f"Query execution failed: {e}")
            raise


class UnifiedDatabaseManager:
    """Unified manager over all registered database adapters."""
    
    def __init__(self):
        self.adapters: Dict[str, DatabaseAdapter] = {}
        self.query_history: List[Dict] = []
    
    def register_database(self, name: str, adapter_type: str, config: Dict):
        """Register a database connection."""
        if adapter_type == "postgresql":
            adapter = PostgreSQLAdapter(config)
        elif adapter_type == "mongodb":
            adapter = MongoDBAdapter(config)
        # Other database types...
        else:
            raise ValueError(f"Unknown adapter type: {adapter_type}")
        
        self.adapters[name] = adapter
        logger.info(f"Registered database: {name}")
    
    async def query(self, database_name: str, query: str, 
                   parameters: Dict = None, collection: str = None) -> QueryResult:
        """Execute a query against a registered database."""
        if database_name not in self.adapters:
            raise ValueError(f"Database not found: {database_name}")
        
        adapter = self.adapters[database_name]
        
        # Record query history for auditing
        history_entry = {
            "database": database_name,
            "query": query,
            "timestamp": datetime.now().isoformat(),
            "user": "system"  # should come from authentication info
        }
        self.query_history.append(history_entry)
        
        # MongoDB takes a collection name instead of SQL parameters
        if isinstance(adapter, MongoDBAdapter):
            result = await adapter.execute_query(query, collection)
        else:
            result = await adapter.execute_query(query, parameters)
        
        logger.info(
            f"Query executed: {database_name} ({result.execution_time_ms:.0f}ms, "
            f"{result.row_count} rows)"
        )
        
        return result
    
    async def get_databases(self) -> List[str]:
        """Return all registered database names."""
        return list(self.adapters.keys())
    
    async def get_schema(self, database_name: str, 
                        table_name: str = None) -> List[Dict]:
        """Return schema information for a registered database."""
        if database_name not in self.adapters:
            raise ValueError(f"Database not found: {database_name}")
        
        adapter = self.adapters[database_name]
        return await adapter.get_schema(table_name)


# Usage example
async def setup_unified_database():
    """Set up the unified database manager."""
    manager = UnifiedDatabaseManager()
    
    # Register PostgreSQL
    manager.register_database(
        "sales_db",
        "postgresql",
        {
            "host": "localhost",
            "port": 5432,
            "user": "admin",
            "password": "secret",
            "database": "sales"
        }
    )
    
    # Register MongoDB
    manager.register_database(
        "logs_db",
        "mongodb",
        {
            "host": "localhost",
            "port": 27017,
            "user": "admin",
            "password": "secret",
            "database": "logs"
        }
    )
    
    return manager
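The `_validate_sql` prefix check above blocks only statements that *begin* with a dangerous keyword, so it misses stacked statements (`SELECT 1; DROP TABLE t`) and writable CTEs (`WITH x AS (DELETE ...) SELECT ...`). A stricter, allowlist-style sketch follows; the function name `is_read_only` and the exact keyword set are our own, and a read-only database role remains the real safety net:

```python
import re

# Keywords that should never appear anywhere in a read-only query
DANGEROUS = {"DROP", "DELETE", "TRUNCATE", "ALTER", "INSERT",
             "UPDATE", "GRANT", "REVOKE", "CREATE"}

def is_read_only(sql: str) -> bool:
    """Allowlist check: a single statement that starts with SELECT or
    WITH and contains no write/DDL keyword anywhere. Deliberately
    conservative: a string literal like 'delete me' is also rejected."""
    stripped = sql.strip().rstrip(";").strip()
    if not stripped or ";" in stripped:  # reject stacked statements
        return False
    if set(re.findall(r"[A-Za-z_]+", stripped.upper())) & DANGEROUS:
        return False
    return stripped.split(None, 1)[0].upper() in {"SELECT", "WITH"}
```

Unlike a prefix check, this rejects `WITH x AS (DELETE FROM t RETURNING *) SELECT * FROM x` because `DELETE` appears among the tokens.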

11.1.3 MCP Server Implementation

import json

from mcp.server import Server
from mcp.types import TextContent

class DatabaseMCPServer(Server):
    """数据库MCP服务器"""
    
    def __init__(self, database_manager: UnifiedDatabaseManager):
        super().__init__("database-server")
        self.db = database_manager
        self._register_tools()
        self._register_resources()
    
    def _register_tools(self):
        """注册工具"""
        
        @self.tool("query_database")
        async def query_database(
            database: str,
            query: str,
            parameters: dict = None
        ) -> str:
            """
            查询数据库
            
            Args:
                database: 数据库名称
                query: SQL查询语句
                parameters: 查询参数
            
            Returns:
                查询结果JSON
            """
            try:
                result = await self.db.query(database, query, parameters)
                
                return json.dumps({
                    "success": True,
                    "database": result.database,
                    "rows": result.rows[:100],  # 限制返回行数
                    "total_rows": result.row_count,
                    "execution_time_ms": result.execution_time_ms,
                    "timestamp": result.timestamp.isoformat()
                }, ensure_ascii=False)
            
            except Exception as e:
                return json.dumps({
                    "success": False,
                    "error": str(e)
                })
        
        @self.tool("list_databases")
        async def list_databases() -> str:
            """
            列出所有可用的数据库
            
            Returns:
                数据库列表JSON
            """
            databases = await self.db.get_databases()
            return json.dumps({
                "databases": databases,
                "count": len(databases)
            })
        
        @self.tool("get_table_schema")
        async def get_table_schema(database: str, table: str) -> str:
            """
            获取表结构
            
            Args:
                database: 数据库名称
                table: 表名称
            
            Returns:
                表结构信息JSON
            """
            try:
                schema = await self.db.get_schema(database, table)
                return json.dumps({
                    "success": True,
                    "database": database,
                    "table": table,
                    "schema": schema
                })
            
            except Exception as e:
                return json.dumps({
                    "success": False,
                    "error": str(e)
                })
    
    def _register_resources(self):
        """注册资源"""
        
        @self.resource("databases://summary")
        async def database_summary() -> TextContent:
            """
            Database summary resource
            
            Returns:
                database statistics
            """
            databases = await self.db.get_databases()
            
            summary = {
                "total_databases": len(databases),
                "databases": databases,
                "query_history_count": len(self.db.query_history),
                "last_query": self.db.query_history[-1] if self.db.query_history else None
            }
            
            return TextContent(
                type="text",
                text=json.dumps(summary, indent=2, ensure_ascii=False)
            )
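One design point in `query_database` above is worth calling out: rows are truncated to 100, but `total_rows` still reports the full count, so the model can tell the result was clipped and issue a narrower query. That envelope can be exercised standalone; the helper name `build_query_envelope` is ours:

```python
import json
from datetime import datetime

def build_query_envelope(rows, database="sales_db", execution_time_ms=12.5):
    """Build the same JSON envelope query_database returns: truncate
    rows to 100 but keep the untruncated count in total_rows."""
    return json.dumps({
        "success": True,
        "database": database,
        "rows": rows[:100],
        "total_rows": len(rows),
        "execution_time_ms": execution_time_ms,
        "timestamp": datetime.now().isoformat(),
    }, ensure_ascii=False)

payload = json.loads(build_query_envelope([{"id": i} for i in range(250)]))
```

Here `payload["rows"]` holds 100 items while `payload["total_rows"]` is 250, which is the truncation signal the caller needs.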

11.2 Case 2: Data Analysis and Visualization with MCP

11.2.1 Requirements Analysis

graph TB
    A["Sales data sources"] --> B["Data extraction"]
    B --> C["Data cleaning"]
    C --> D["Data analysis"]
    
    D --> E["Trend analysis"]
    D --> F["Benchmark analysis"]
    D --> G["Anomaly detection"]
    
    E --> H["Visualization"]
    F --> H
    G --> H
    
    H --> I["Report generation"]
    H --> J["Claude analysis"]

11.2.2 Analysis Tool Implementation

import pandas as pd
import numpy as np
from typing import Dict, List

class SalesAnalysisTool:
    """Sales data analysis toolkit."""
    
    def __init__(self, data: pd.DataFrame):
        self.data = data
    
    def trend_analysis(self, period: str = "month") -> Dict:
        """
        Trend analysis
        
        Args:
            period: analysis period (day/week/month/quarter)
        
        Returns:
            trend analysis results
        """
        # Group by period and aggregate
        # (pandas >= 2.2 prefers the 'ME'/'QE' spellings over 'M'/'Q')
        freq_map = {"day": "D", "week": "W", "month": "M", "quarter": "Q"}
        grouped = self.data.groupby(
            pd.Grouper(key='date', freq=freq_map.get(period, "D"))
        )
        
        sales_by_period = grouped.agg({
            'amount': ['sum', 'mean', 'count'],
            'quantity': 'sum'
        }).round(2)
        
        # Period-over-period growth rate
        sales_trend = sales_by_period['amount']['sum'].pct_change() * 100
        
        return {
            "period": period,
            "sales_by_period": sales_by_period.to_dict(),
            "growth_rate": sales_trend.to_dict(),
            "highest_period": sales_by_period['amount']['sum'].idxmax(),
            "lowest_period": sales_by_period['amount']['sum'].idxmin()
        }
    
    def product_analysis(self) -> Dict:
        """
        Product analysis
        
        Returns:
            product sales breakdown
        """
        product_stats = self.data.groupby('product').agg({
            'amount': ['sum', 'mean'],
            'quantity': 'sum',
            'order_id': 'count'
        }).round(2)
        
        # Top products by revenue
        top_products = product_stats.nlargest(10, ('amount', 'sum'))
        
        return {
            "total_products": len(product_stats),
            "top_10_products": top_products.to_dict(),
            "product_mix": (product_stats[('amount', 'sum')] / 
                          product_stats[('amount', 'sum')].sum() * 100).to_dict()
        }
    
    def customer_analysis(self) -> Dict:
        """
        Customer analysis
        
        Returns:
            customer distribution and characteristics
        """
        customer_stats = self.data.groupby('customer_id').agg({
            'amount': ['sum', 'mean', 'count'],
            'date': ['min', 'max']
        }).round(2)
        
        # Tier customers by total-spend quantiles
        quantiles = customer_stats[('amount', 'sum')].quantile([0.8, 0.95])
        
        vip_customers = len(customer_stats[customer_stats[('amount', 'sum')] > quantiles[0.95]])
        high_value_customers = len(customer_stats[
            (customer_stats[('amount', 'sum')] > quantiles[0.8]) &
            (customer_stats[('amount', 'sum')] <= quantiles[0.95])
        ])
        
        return {
            "total_customers": len(customer_stats),
            "vip_count": vip_customers,
            "high_value_count": high_value_customers,
            "avg_customer_value": customer_stats[('amount', 'sum')].mean(),
            "customer_lifetime_value": customer_stats[('amount', 'sum')].quantile([0.25, 0.5, 0.75]).to_dict()
        }
    
    def anomaly_detection(self, threshold: float = 2.0) -> List[Dict]:
        """
        Anomaly detection
        
        Args:
            threshold: standard-deviation threshold
        
        Returns:
            list of anomalous orders
        """
        # Z-score per product, computed on a copy so the source frame
        # is not mutated
        data = self.data.copy()
        data['z_score'] = data.groupby('product')['amount'].transform(
            lambda x: np.abs((x - x.mean()) / x.std())
        )
        
        anomalies = data[data['z_score'] > threshold]
        
        return [
            {
                "order_id": row['order_id'],
                "product": row['product'],
                "amount": row['amount'],
                "z_score": row['z_score'],
                "reason": "unusually high" if row['amount'] > data['amount'].mean() else "unusually low"
            }
            for _, row in anomalies.iterrows()
        ]


# Usage example
async def sales_analysis_example():
    """Sales analysis example."""
    
    # Simulated sales data
    dates = pd.date_range('2025-01-01', periods=365, freq='D')
    data = pd.DataFrame({
        'date': dates,
        'order_id': range(1, 366),
        'customer_id': np.random.randint(1, 100, 365),
        'product': np.random.choice(['Product A', 'Product B', 'Product C'], 365),
        'amount': np.random.normal(1000, 300, 365),
        'quantity': np.random.randint(1, 10, 365)
    })
    
    analyzer = SalesAnalysisTool(data)
    
    # Trend analysis
    trend = analyzer.trend_analysis(period="month")
    print("Trend analysis:", trend)
    
    # Product analysis
    products = analyzer.product_analysis()
    print("Product analysis:", products)
    
    # Customer analysis
    customers = analyzer.customer_analysis()
    print("Customer analysis:", customers)
    
    # Anomaly detection
    anomalies = analyzer.anomaly_detection()
    print(f"Detected {len(anomalies)} anomalies")

11.3 Production Deployment

11.3.1 Deployment Architecture

graph TB
    subgraph "Client layer"
        Claude["Claude + MCP"]
    end
    
    subgraph "MCP server"
        AuthMiddleware["Auth middleware"]
        RateLimiter["Rate limiter"]
        DBServer["Database MCP server"]
        Analyzer["Analysis engine"]
    end
    
    subgraph "Data layer"
        PG["PostgreSQL"]
        Mongo["MongoDB"]
        Redis["Redis cache"]
    end
    
    subgraph "Monitoring layer"
        Logs["Logging"]
        Metrics["Metrics collection"]
    end
    
    Claude --> AuthMiddleware
    AuthMiddleware --> RateLimiter
    RateLimiter --> DBServer
    RateLimiter --> Analyzer
    
    DBServer --> PG
    DBServer --> Mongo
    Analyzer --> PG
    
    PG --> Redis
    
    DBServer --> Logs
    Analyzer --> Metrics

11.3.2 Docker Deployment

FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application
COPY . .

# Run the server
CMD ["python", "-m", "mcp_server.database_server"]

requirements.txt

asyncpg>=0.29.0
pymongo>=4.6.0
pandas>=2.1.0
pydantic>=2.4.0
mcp>=0.1.0

Chapter Summary

Key points

  • Adapter pattern: a unified interface over different databases
  • Access control: SQL validation and query auditing
  • Performance: caching, connection pooling, query optimization
  • Data analysis: trends, anomalies, benchmarking
  • Monitoring: logs, metrics, alerting
  • Deployment: containerized with Docker

Common Questions

Q1: How do you prevent SQL injection? A: Use parameterized queries, SQL validation, and allowlist checks. Never build SQL by string concatenation.
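The difference is easy to demonstrate. The sketch below uses the standard library's sqlite3 (with `?` placeholders) purely so it runs without a server; asyncpg uses `$1`-style placeholders, but the principle is identical:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT)")
conn.executemany("INSERT INTO users VALUES (?)", [("alice",), ("bob",)])

malicious = "alice' OR '1'='1"

# Unsafe: string concatenation lets the input rewrite the query,
# so the injected OR clause matches every row
unsafe = conn.execute(
    f"SELECT * FROM users WHERE name = '{malicious}'"
).fetchall()

# Safe: the driver binds the value as data; it can never become SQL
safe = conn.execute(
    "SELECT * FROM users WHERE name = ?", (malicious,)
).fetchall()
```

The concatenated query returns both rows; the parameterized query returns none, because no user is literally named `alice' OR '1'='1`.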

Q2: How do you handle large result sets? A: Use pagination, streaming, and server-side aggregation. Never load the full data set at once.
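A minimal pagination sketch (the helper `paginate_query` is ours, and it assumes the base query has already passed read-only validation):

```python
def paginate_query(base_sql: str, page: int, page_size: int = 100) -> str:
    """Append LIMIT/OFFSET so each call fetches one page. OFFSET
    pagination degrades on deep pages; keyset pagination
    (WHERE id > last_seen_id ORDER BY id) scales better."""
    if page < 1 or page_size < 1:
        raise ValueError("page and page_size must be >= 1")
    offset = (page - 1) * page_size
    return f"{base_sql.rstrip().rstrip(';')} LIMIT {page_size} OFFSET {offset}"

sql = paginate_query("SELECT * FROM orders ORDER BY id", page=3, page_size=50)
# → "SELECT * FROM orders ORDER BY id LIMIT 50 OFFSET 100"
```

Combined with the 100-row cap in `query_database`, this lets the model walk a large table one page at a time instead of receiving a truncated dump.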

Q3: How should cross-database transactions be handled? A: Prefer event sourcing or two-phase commit. MCP tools should generally keep each operation within a single database.

Q4: How is data encrypted? A: Use TLS in transit, database-level encryption at rest, and field-level encryption for sensitive columns.

Q5: How do you monitor database performance? A: Collect slow-query logs, connection counts, execution times, and other metrics, and review them regularly.
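A slow-query log can start as a simple decorator around the execute path. A sketch (the threshold and logger name are our own choices):

```python
import functools
import logging
import time

logger = logging.getLogger("db.slowlog")

def log_slow_queries(threshold_ms: float = 500.0):
    """Time the wrapped call and emit a warning when it exceeds
    threshold_ms: a minimal slow-query log."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                elapsed_ms = (time.perf_counter() - start) * 1000
                if elapsed_ms > threshold_ms:
                    logger.warning("slow query: %s took %.0fms",
                                   func.__name__, elapsed_ms)
        return wrapper
    return decorator

@log_slow_queries(threshold_ms=10)
def run_report():
    time.sleep(0.05)  # simulate a slow query
    return "done"
```

An async variant wraps a coroutine the same way, with `async def wrapper` and `await func(*args, **kwargs)` inside the `try`.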


Practical Tips

✅ Recommended practices

  • Use the adapter pattern to support multiple databases
  • Enforce strict access control and auditing
  • Optimize frequently used queries
  • Manage connections with a connection pool
  • Review query logs regularly
  • Establish backup and recovery procedures

❌ Practices to avoid

  • Don't expose database connections directly
  • Don't skip permission checks
  • Don't ignore query performance
  • Don't log sensitive data
  • Don't return unbounded result sets



Next chapter preview: Chapter 12 covers MCP applications for file systems and document management, letting Claude manage enterprise documents intelligently!