LangChain Custom Memory Storage Development: A Deep Dive (39)



1. Overview of Memory Storage in LangChain

1.1 The Core Role of Memory Storage

Within the LangChain framework, memory storage is the key module behind conversational continuity, context awareness, and reuse of historical information. Whether you are building a chatbot, a customer-service assistant, or a task-oriented dialogue system, the memory store is responsible for persisting the history of user-AI interactions, including conversation content, task state, and user preferences. This historical data helps the language model understand the current context and generate replies that fit it, improving the user experience. At the source-code level, the memory module interacts with other LangChain components (chains, agents, prompt templates) through standardized interfaces, providing context awareness to the whole system.

1.2 Limitations of the Built-in Memory Implementations

LangChain ships with several built-in memory implementations, such as ConversationBufferMemory (stores the full conversation history) and ConversationSummaryMemory (stores a running summary). These built-in classes have some limitations, however:

  1. Fixed storage structure: only the preset data structures are supported, which is hard to adapt to complex business scenarios
  2. Limited extensibility: there is no convenient way to plug in a custom database or storage system
  3. Narrow functionality: advanced features such as data cleanup, version management, and encryption are missing
  4. Performance bottlenecks: efficiency can degrade when handling large volumes of conversation data

These limitations push developers toward building custom memory storage tailored to their application's needs.

2. Design Principles for Custom Memory Storage

2.1 Interface Compatibility

A custom memory store must follow LangChain's memory interface contract so that it integrates seamlessly with existing components. The core interface is defined as follows:

from abc import ABC
from typing import Any, Dict, List

from langchain.schema import BaseMemory

class CustomMemory(BaseMemory, ABC):
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Load memory variables from storage."""
        raise NotImplementedError()

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Persist the conversation context."""
        raise NotImplementedError()

    @property
    def memory_variables(self) -> List[str]:
        """Return the list of memory variable names."""
        raise NotImplementedError()

By implementing these three abstract members, a custom memory store can work together with components such as ConversationChain and AgentExecutor.
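
As a minimal sketch of this contract (kept free of any database or LangChain dependency for illustration; the class name is hypothetical), an in-memory implementation of the three members could look like:

```python
from typing import Any, Dict, List

class InMemoryCustomMemory:
    """Minimal in-memory sketch of the memory interface contract."""

    def __init__(self) -> None:
        self._history: List[Dict[str, str]] = []

    @property
    def memory_variables(self) -> List[str]:
        # Names that prompt templates may reference
        return ["history"]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Return a copy so callers cannot mutate internal state
        return {"history": list(self._history)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        # Persist one input/output turn
        self._history.append(
            {"input": inputs.get("input", ""), "output": outputs.get("output", "")}
        )
```

The database-backed implementations in the following sections fill in the same three members against real storage.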

2.2 Data Structure Design

A sensible data structure design is central to a custom memory store. Common patterns include:

  1. Document storage: store conversation records as JSON, which is easy to extend and query
  2. Relational storage: manage conversations, users, tasks, and their relations via table schemas
  3. Time-series storage: store conversation fragments in time order, suitable for analyzing historical trends
  4. Graph storage: use a graph database to represent entity relations within conversations
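
For the document-style pattern, one conversation turn might be stored as a JSON document like the following (the field names are illustrative, not part of any LangChain schema):

```python
import json
from datetime import datetime, timezone

# One conversation turn as a document-style record (illustrative fields)
record = {
    "conversation_id": "default",
    "input": "What is LangChain?",
    "output": "LangChain is a framework for building LLM applications.",
    "timestamp": datetime.now(timezone.utc).isoformat(),
    "metadata": {"user_id": "u123", "tags": ["faq"]},
}
serialized = json.dumps(record, ensure_ascii=False)
```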

2.3 Performance Considerations

Performance needs to be considered at the design stage:

  1. Index design: build indexes on frequently queried fields
  2. Caching: add a local cache to reduce database round trips
  3. Batch operations: support bulk reads and writes for efficiency
  4. Asynchronous processing: use async IO to handle high-concurrency workloads

2.4 Security and Compliance

A custom store must meet security and compliance requirements:

  1. Data encryption: encrypt sensitive information at rest
  2. Access control: enforce strict permission management
  3. Compliance auditing: log data operations
  4. Privacy protection: follow regulations such as GDPR and CCPA

3. Implementing the Base Storage Class

3.1 Initialization and Configuration

The initializer of the custom memory store configures its storage parameters:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

class CustomMemory(BaseMemory):
    def __init__(self, 
                 db_uri: str, 
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        self.db_uri = db_uri  # database connection string
        self.table_name = table_name  # name of the storage table
        self.encryption_key = encryption_key  # optional encryption key
        
        # Initialize the database connection
        self.engine = create_engine(self.db_uri)
        self.Session = sessionmaker(bind=self.engine)
        
        # Create the table if it does not exist yet
        self._create_table()

Parameterized configuration makes it possible to target different databases (such as SQLite, PostgreSQL, and MySQL).
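
The snippets below query a `Conversation` ORM model and call a `_create_table` helper that are not shown in the article; a sketch consistent with those queries, using SQLAlchemy's declarative API, might look like this. `_create_table` can then simply call `Base.metadata.create_all(self.engine)`:

```python
from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, Text, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Conversation(Base):
    """Conversation record assumed by the CustomMemory examples (a sketch)."""
    __tablename__ = "conversations"
    id = Column(Integer, primary_key=True)
    input_text = Column(Text)
    output_text = Column(Text)
    timestamp = Column(DateTime, default=datetime.utcnow, index=True)
```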

3.2 Loading Memory Variables

The load_memory_variables method reads historical data from storage:

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        session = self.Session()
        try:
            # Fetch the most recent N conversation records
            query = session.query(Conversation).order_by(Conversation.timestamp.desc()).limit(10)
            records = query.all()
            records.reverse()  # restore chronological order
            
            # Convert the records into dictionary form
            memory = {
                "history": [{"input": r.input_text, "output": r.output_text} for r in records]
            }
            return memory
        finally:
            session.close()

SQLAlchemy is used here for database access; order_by combined with limit restricts the query to the most recent records.

3.3 Saving Context

The save_context method writes a new conversation record to storage:

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        session = self.Session()
        try:
            input_text = inputs.get("input")
            output_text = outputs.get("output")
            
            # Build the conversation record
            record = Conversation(
                input_text=input_text,
                output_text=output_text,
                timestamp=datetime.utcnow()
            )
            
            # Encrypt sensitive fields if a key is configured
            if self.encryption_key:
                record.input_text = encrypt(input_text, self.encryption_key)
                record.output_text = encrypt(output_text, self.encryption_key)
            
            session.add(record)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

This method supports encrypting the data and uses a transaction to guarantee consistency.

3.4 Defining Variable Names

The memory_variables property declares which memory variables are available:

    @property
    def memory_variables(self) -> List[str]:
        return ["history"]

This property tells LangChain which variables can be referenced in prompt templates.

4. Advanced Feature Extensions

4.1 Data Cleanup and Compression

To keep storage from growing without bound, a cleanup routine can be added:

    def clean_memory(self, max_records: int = 100) -> None:
        session = self.Session()
        try:
            # Count the stored records
            count = session.query(func.count(Conversation.id)).scalar()
            if count > max_records:
                # SQLAlchemy does not allow delete() after limit(), so
                # collect the ids of the oldest surplus records first
                surplus_ids = [
                    r.id for r in session.query(Conversation.id)
                    .order_by(Conversation.timestamp.asc())
                    .limit(count - max_records)
                ]
                session.query(Conversation).filter(
                    Conversation.id.in_(surplus_ids)
                ).delete(synchronize_session=False)
                session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

Running this method periodically removes old records and keeps the storage size under control.

4.2 Version Management

Version management makes it possible to track changes to memory data:

class ConversationVersion(Base):
    __tablename__ = "conversation_versions"
    id = Column(Integer, primary_key=True)
    conversation_id = Column(Integer, ForeignKey("conversations.id"))
    version_number = Column(Integer)
    input_text = Column(Text)
    output_text = Column(Text)
    timestamp = Column(DateTime)

class CustomMemory(BaseMemory):
    def save_version(self, conversation_id: int) -> None:
        session = self.Session()
        try:
            # Fetch the current conversation record
            conversation = session.query(Conversation).filter(Conversation.id == conversation_id).first()
            if conversation is None:
                raise ValueError(f"conversation {conversation_id} not found")
            
            # Create the version record
            version = ConversationVersion(
                conversation_id=conversation_id,
                version_number=self._get_next_version_number(conversation_id),
                input_text=conversation.input_text,
                output_text=conversation.output_text,
                timestamp=datetime.utcnow()
            )
            
            session.add(version)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
    
    def _get_next_version_number(self, conversation_id: int) -> int:
        session = self.Session()
        try:
            # Look up the highest existing version number
            result = session.query(func.max(ConversationVersion.version_number)).filter(
                ConversationVersion.conversation_id == conversation_id
            ).scalar()
            return (result or 0) + 1
        finally:
            session.close()

Each change is recorded in a version table, enabling data rollback.

4.3 Smarter Retrieval

Enhanced retrieval improves query efficiency:

    def search_memory(self, keyword: str) -> List[Dict[str, Any]]:
        session = self.Session()
        try:
            # Substring search (compiles to SQL LIKE) over inputs and outputs
            query = session.query(Conversation).filter(
                Conversation.input_text.contains(keyword) | Conversation.output_text.contains(keyword)
            )
            records = query.all()
            
            return [{"input": r.input_text, "output": r.output_text} for r in records]
        finally:
            session.close()

Note that `contains` compiles to a SQL LIKE substring match rather than true full-text search; for the latter, use database-specific features such as PostgreSQL's tsvector or SQLite's FTS5.

5. Integration with Other Components

5.1 Integrating with Conversation Chains

Integrating the custom memory store into a ConversationChain:

from langchain.chains import ConversationChain
from langchain.llms import OpenAI

memory = CustomMemory(db_uri="sqlite:///memory.db")
llm = OpenAI(temperature=0)
conversation = ConversationChain(
    llm=llm, 
    memory=memory,
    verbose=True
)

response = conversation.predict(input="Hello")

This way, the conversation chain automatically uses the history kept in the custom store.

5.2 Integrating with Agents

Using custom memory inside an agent:

from langchain.agents import AgentType, initialize_agent
from langchain.tools import BaseTool

class CustomTool(BaseTool):
    name = "custom_tool"
    description = "Searches the conversation history kept in the custom memory store"
    def _run(self, query: str) -> str:
        # Query the memory store
        memory = CustomMemory(db_uri="sqlite:///memory.db")
        results = memory.search_memory(query)
        return str(results)

    async def _arun(self, query: str) -> str:
        raise NotImplementedError("async version not implemented")

tools = [CustomTool()]
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
agent.run("Search the conversation history")

Agents can leverage the memory store to make smarter decisions.

5.3 Integrating with Prompt Templates

Referencing memory variables in a prompt template:

from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    input_variables=["history", "input"],
    template="""
    Conversation history: {history}
    User input: {input}
    AI reply:
    """
)

memory = CustomMemory(db_uri="sqlite:///memory.db")
formatted_prompt = prompt.format(history=memory.load_memory_variables({})["history"], input="a new question")

This lets the prompt template pull in historical conversation content dynamically.

6. Performance Optimization Strategies

6.1 Caching

Add a local cache to reduce database access:

from typing import Any, Dict

class CachedMemory(CustomMemory):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cache: Dict[str, Dict[str, Any]] = {}
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # functools.lru_cache cannot be used here because dict arguments
        # are unhashable, so key a manual cache on a stable representation
        cache_key = repr(sorted(inputs.items()))
        if cache_key not in self.cache:
            self.cache[cache_key] = super().load_memory_variables(inputs)
        return self.cache[cache_key]
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        super().save_context(inputs, outputs)
        # Invalidate cached reads after every write
        self.cache.clear()

A plain dictionary cache is used instead of functools.lru_cache, because lru_cache cannot hash dictionary arguments; every write invalidates the cache so that reads stay consistent.

6.2 Batch Operations

Support batch writes for better throughput:

    def batch_save_context(self, contexts: List[Dict[str, Any]]) -> None:
        session = self.Session()
        try:
            records = []
            for ctx in contexts:
                input_text = ctx.get("input")
                output_text = ctx.get("output")
                record = Conversation(
                    input_text=input_text,
                    output_text=output_text,
                    timestamp=datetime.utcnow()
                )
                records.append(record)
            
            session.add_all(records)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

This method commits multiple records in a single transaction, reducing database round trips.

6.3 Asynchronous IO

Handle high-concurrency workloads asynchronously:

import asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

class AsyncCustomMemory(BaseMemory):
    def __init__(self, db_uri: str, table_name: str = "conversations", encryption_key: str = None):
        self.db_uri = db_uri
        self.table_name = table_name
        self.encryption_key = encryption_key
        
        self.engine = create_async_engine(self.db_uri)
        self.Session = async_sessionmaker(bind=self.engine, class_=AsyncSession)
        
        # Note: creating tables against an async engine requires
        # engine.begin() with run_sync(); kept synchronous here for brevity
        self._create_table()
    
    async def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        async with self.Session() as session:
            # Async sessions use select() + execute() instead of Query
            stmt = select(Conversation).order_by(Conversation.timestamp.desc()).limit(10)
            result = await session.execute(stmt)
            records = result.scalars().all()
            return {
                "history": [{"input": r.input_text, "output": r.output_text} for r in records]
            }
    
    async def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        async with self.Session() as session:
            try:
                input_text = inputs.get("input")
                output_text = outputs.get("output")
                record = Conversation(
                    input_text=input_text,
                    output_text=output_text,
                    timestamp=datetime.utcnow()
                )
                if self.encryption_key:
                    record.input_text = encrypt(input_text, self.encryption_key)
                    record.output_text = encrypt(output_text, self.encryption_key)
                session.add(record)
                await session.commit()
            except Exception:
                await session.rollback()
                raise

Asynchronous database access is implemented with asyncio and sqlalchemy.ext.asyncio.

7. Security and Compliance Implementation

7.1 Data Encryption

Encrypt sensitive data at rest:

from cryptography.fernet import Fernet

def encrypt(text: str, key: str) -> str:
    cipher_suite = Fernet(key)
    encrypted_text = cipher_suite.encrypt(text.encode())
    return encrypted_text.decode()

def decrypt(text: str, key: str) -> str:
    cipher_suite = Fernet(key)
    decrypted_text = cipher_suite.decrypt(text.encode())
    return decrypted_text.decode()

class EncryptedMemory(CustomMemory):
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        data = super().load_memory_variables(inputs)
        if self.encryption_key:
            for item in data["history"]:
                item["input"] = decrypt(item["input"], self.encryption_key)
                item["output"] = decrypt(item["output"], self.encryption_key)
        return data
    
    # save_context is inherited unchanged: the base class already encrypts
    # both fields when encryption_key is set, so encrypting here as well
    # would double-encrypt the data.

Symmetric encryption is implemented with the cryptography library.
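
Note that `Fernet` requires a 32-byte url-safe base64 key rather than an arbitrary passphrase; a valid key can be generated once and kept in a secrets manager:

```python
from cryptography.fernet import Fernet

# Generate a key once and store it securely (never hard-code it)
key = Fernet.generate_key()

token = Fernet(key).encrypt("sensitive text".encode())
plaintext = Fernet(key).decrypt(token).decode()
```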

7.2 Access Control

Implement role-based access control:

class RoleBasedAccessMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "conversations", encryption_key: str = None, role: str = "user"):
        super().__init__(db_uri, table_name, encryption_key)
        self.role = role
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if self.role == "admin":
            return super().load_memory_variables(inputs)
        else:
            # Return only a subset of the data
            data = super().load_memory_variables(inputs)
            data["history"] = data["history"][-5:]  # only the 5 most recent turns
            return data
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        if self.role in ["admin", "editor"]:
            super().save_context(inputs, outputs)
        else:
            raise PermissionError("no write permission")

Data access and operations are restricted according to the user's role.

7.3 Audit Logging

Log every data operation:

import logging

logger = logging.getLogger(__name__)

class AuditableMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "conversations", encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.audit_table_name = "audit_logs"
        self._create_audit_table()
    
    def _create_audit_table(self):
        # Create the audit log table
        metadata = MetaData()
        Table(
            self.audit_table_name, metadata,
            Column('id', Integer, primary_key=True),
            Column('user_id', String(50)),
            Column('operation', String(20)),
            Column('timestamp', DateTime),
            Column('details', Text)
        )
        metadata.create_all(self.engine)
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        user_id = inputs.get("user_id", "anonymous")
        self._log_audit(user_id, "load", str(inputs))
        return super().load_memory_variables(inputs)
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        user_id = inputs.get("user_id", "anonymous")
        self._log_audit(user_id, "save", f"inputs: {inputs}, outputs: {outputs}")
        super().save_context(inputs, outputs)
    
    def _log_audit(self, user_id: str, operation: str, details: str) -> None:
        session = self.Session()
        try:
            audit_log = AuditLog(
                user_id=user_id,
                operation=operation,
                timestamp=datetime.utcnow(),
                details=details
            )
            session.add(audit_log)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"failed to write audit log: {e}")
        finally:
            session.close()

The audit log table records the user, operation type, timestamp, and details of every operation, which supports later tracing and compliance checks.
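
The `_log_audit` method above references an `AuditLog` ORM model that is not shown; a sketch whose columns mirror `_create_audit_table` could be:

```python
from sqlalchemy import Column, DateTime, Integer, String, Text, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class AuditLog(Base):
    """Audit record assumed by _log_audit (a sketch)."""
    __tablename__ = "audit_logs"
    id = Column(Integer, primary_key=True)
    user_id = Column(String(50))
    operation = Column(String(20))
    timestamp = Column(DateTime)
    details = Column(Text)
```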

7.4 Privacy Protection

Implement data anonymization and expiry-based cleanup:

import hashlib
from datetime import datetime, timedelta

class PrivacyAwareMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "conversations", encryption_key: str = None, retention_days: int = 30):
        super().__init__(db_uri, table_name, encryption_key)
        self.retention_days = retention_days
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        # Anonymize the user identifier
        if "user_id" in inputs:
            inputs["user_id"] = self._anonymize_id(inputs["user_id"])
        
        super().save_context(inputs, outputs)
    
    def _anonymize_id(self, user_id: str) -> str:
        # Simple ID anonymization via hashing
        hash_object = hashlib.sha256(user_id.encode())
        return hash_object.hexdigest()[:10]
    
    def purge_old_data(self) -> None:
        """Delete data older than the retention window."""
        cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
        session = self.Session()
        try:
            # Delete expired records
            session.query(Conversation).filter(
                Conversation.timestamp < cutoff_date
            ).delete(synchronize_session=False)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

User IDs are anonymized by hashing, and expired data is purged periodically to protect user privacy.

8. Distributed and Cloud Storage Implementations

8.1 Redis-based Distributed Caching

import redis
import json

class RedisMemory(CustomMemory):
    def __init__(self, 
                 db_uri: str, 
                 redis_host: str = "localhost", 
                 redis_port: int = 6379, 
                 redis_db: int = 0,
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.redis_client = redis.Redis(host=redis_host, port=redis_port, db=redis_db)
        self.cache_prefix = "langchain:memory:"
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        conversation_id = inputs.get("conversation_id", "default")
        cache_key = f"{self.cache_prefix}{conversation_id}"
        
        # Try the cache first
        cached_data = self.redis_client.get(cache_key)
        if cached_data:
            return json.loads(cached_data)
        
        # Fall back to the database, then populate the cache
        data = super().load_memory_variables(inputs)
        self.redis_client.setex(cache_key, 3600, json.dumps(data))  # cache for one hour
        return data
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        super().save_context(inputs, outputs)
        
        # Invalidate the cache entry
        conversation_id = inputs.get("conversation_id", "default")
        cache_key = f"{self.cache_prefix}{conversation_id}"
        self.redis_client.delete(cache_key)

Redis serves as a distributed caching layer, reducing database load and improving response times.

8.2 AWS S3-based Storage

import boto3
from botocore.exceptions import NoCredentialsError

class S3Memory(CustomMemory):
    def __init__(self, 
                 bucket_name: str, 
                 aws_access_key_id: str, 
                 aws_secret_access_key: str,
                 region_name: str = "us-east-1",
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        self.bucket_name = bucket_name
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name
        )
        super().__init__(db_uri="sqlite:///:memory:", table_name=table_name, encryption_key=encryption_key)
    
    def _get_s3_key(self, conversation_id: str) -> str:
        return f"conversations/{conversation_id}.json"
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        conversation_id = inputs.get("conversation_id", "default")
        s3_key = self._get_s3_key(conversation_id)
        
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
            content = response['Body'].read().decode('utf-8')
            return json.loads(content)
        except NoCredentialsError:
            logger.error("AWS credentials error")
            return {"history": []}
        except self.s3_client.exceptions.NoSuchKey:
            return {"history": []}
        except Exception as e:
            logger.error(f"failed to load from S3: {e}")
            return {"history": []}
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        conversation_id = inputs.get("conversation_id", "default")
        s3_key = self._get_s3_key(conversation_id)
        
        # Load the existing data first
        current_data = self.load_memory_variables(inputs)
        history = current_data.get("history", [])
        
        # Append the new record
        new_record = {
            "input": inputs.get("input"),
            "output": outputs.get("output"),
            "timestamp": datetime.utcnow().isoformat()
        }
        history.append(new_record)
        
        # Write the document back to S3
        data_to_save = {"history": history}
        try:
            self.s3_client.put_object(
                Bucket=self.bucket_name,
                Key=s3_key,
                Body=json.dumps(data_to_save).encode('utf-8')
            )
        except Exception as e:
            logger.error(f"failed to save to S3: {e}")

AWS S3's scalability makes it a good fit for storing conversation data in large distributed applications.

8.3 MongoDB-based Document Storage

from pymongo import MongoClient
from pymongo.errors import ConnectionFailure

class MongoMemory(CustomMemory):
    def __init__(self, 
                 connection_string: str, 
                 database_name: str = "langchain", 
                 collection_name: str = "conversations",
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        self.client = MongoClient(connection_string)
        try:
            # Test the connection
            self.client.admin.command('ping')
        except ConnectionFailure:
            logger.error("unable to connect to the MongoDB server")
        
        self.db = self.client[database_name]
        self.collection = self.db[collection_name]
        super().__init__(db_uri="sqlite:///:memory:", table_name=table_name, encryption_key=encryption_key)
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        conversation_id = inputs.get("conversation_id", "default")
        
        # Fetch the 10 most recent records
        cursor = self.collection.find(
            {"conversation_id": conversation_id},
            {"_id": 0, "input": 1, "output": 1, "timestamp": 1}
        ).sort("timestamp", -1).limit(10)
        
        records = list(cursor)
        records.reverse()  # chronological order
        
        return {"history": records}
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        conversation_id = inputs.get("conversation_id", "default")
        
        record = {
            "conversation_id": conversation_id,
            "input": inputs.get("input"),
            "output": outputs.get("output"),
            "timestamp": datetime.utcnow()
        }
        
        self.collection.insert_one(record)

MongoDB's document model offers flexible storage for unstructured conversation data.

9. Advanced Retrieval and Indexing

9.1 Vector Index Implementation

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

class VectorizedMemory(CustomMemory):
    def __init__(self, 
                 db_uri: str, 
                 embedding_model: str = "text-embedding-ada-002",
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.embeddings = OpenAIEmbeddings(model=embedding_model)
        self.vectorstore = None
        self._initialize_vectorstore()
    
    def _initialize_vectorstore(self):
        # Load all texts from the database and build the vector index
        session = self.Session()
        try:
            records = session.query(Conversation).all()
            texts = []
            metadatas = []
            
            for record in records:
                texts.append(f"User: {record.input_text}\nAI: {record.output_text}")
                metadatas.append({
                    "input": record.input_text,
                    "output": record.output_text,
                    "timestamp": record.timestamp
                })
            
            if texts:
                self.vectorstore = FAISS.from_texts(texts, self.embeddings, metadatas=metadatas)
        finally:
            session.close()
    
    def similarity_search(self, query: str, k: int = 4) -> List[Dict[str, Any]]:
        if not self.vectorstore:
            return []
        
        # Run the similarity search
        docs = self.vectorstore.similarity_search(query, k=k)
        return [doc.metadata for doc in docs]
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        super().save_context(inputs, outputs)
        
        # Update the vector index
        text = f"User: {inputs.get('input')}\nAI: {outputs.get('output')}"
        metadata = {
            "input": inputs.get("input"),
            "output": outputs.get("output"),
            "timestamp": datetime.utcnow()
        }
        
        if self.vectorstore:
            self.vectorstore.add_texts([text], metadatas=[metadata])
        else:
            self.vectorstore = FAISS.from_texts([text], self.embeddings, metadatas=[metadata])

A vector index enables semantic similarity search, improving the precision of retrieving relevant history.

9.2 Hybrid Retrieval Strategy

class HybridMemory(CustomMemory):
    def __init__(self, 
                 db_uri: str, 
                 embedding_model: str = "text-embedding-ada-002",
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.embeddings = OpenAIEmbeddings(model=embedding_model)
        self.vectorstore = None
        self._initialize_vectorstore()
    
    def _initialize_vectorstore(self):
        # Initialize the vector store
        session = self.Session()
        try:
            records = session.query(Conversation).all()
            texts = []
            metadatas = []
            
            for record in records:
                texts.append(f"User: {record.input_text}\nAI: {record.output_text}")
                metadatas.append({
                    "id": record.id,
                    "timestamp": record.timestamp
                })
            
            if texts:
                self.vectorstore = FAISS.from_texts(texts, self.embeddings, metadatas=metadatas)
        finally:
            session.close()
    
    def hybrid_search(self, query: str, k: int = 4, keyword_weight: float = 0.5) -> List[Dict[str, Any]]:
        # Hybrid retrieval: keywords plus vectors
        if not self.vectorstore:
            return []
        
        # 1. Vector search
        vector_results = self.vectorstore.similarity_search_with_score(query, k=k)
        
        # 2. Keyword search
        keyword_results = self._keyword_search(query, k=k)
        
        # 3. Merge the results
        merged_results = self._merge_results(vector_results, keyword_results, keyword_weight)
        
        # 4. Fetch the full records
        session = self.Session()
        try:
            full_records = []
            for record_id in merged_results:
                record = session.query(Conversation).filter(Conversation.id == record_id).first()
                if record:
                    full_records.append({
                        "input": record.input_text,
                        "output": record.output_text,
                        "timestamp": record.timestamp
                    })
            return full_records
        finally:
            session.close()
    
    def _keyword_search(self, query: str, k: int = 4) -> List[int]:
        session = self.Session()
        try:
            # Keyword search via SQL LIKE (avoid shadowing the query argument)
            matches = session.query(Conversation).filter(
                Conversation.input_text.ilike(f"%{query}%") | 
                Conversation.output_text.ilike(f"%{query}%")
            ).order_by(Conversation.timestamp.desc()).limit(k)
            
            return [record.id for record in matches.all()]
        finally:
            session.close()
    
    def _merge_results(self, vector_results: List[Tuple[Document, float]], keyword_results: List[int], keyword_weight: float) -> List[int]:
        # Fuse the two result sets
        vector_scores = {doc.metadata['id']: score for doc, score in vector_results}
        keyword_scores = {rid: (len(keyword_results) - i) for i, rid in enumerate(keyword_results)}
        
        # Compute a combined score
        combined_scores = {}
        for rid in set(vector_scores.keys()).union(keyword_scores.keys()):
            vec_score = vector_scores.get(rid, 0)
            key_score = keyword_scores.get(rid, 0)
            
            # Normalization: FAISS returns a distance (lower is better),
            # so map it into the (0, 1] range where higher is better
            if vec_score > 0:
                vec_score = 1 / (1 + vec_score)
            
            combined_scores[rid] = keyword_weight * key_score + (1 - keyword_weight) * vec_score
        
        # Sort by combined score, highest first
        return sorted(combined_scores.keys(), key=lambda rid: combined_scores[rid], reverse=True)

Combining keyword and vector retrieval gives more comprehensive access to historical records.

9.3 Time-aware Indexing

class TimeAwareMemory(CustomMemory):
    def __init__(self, 
                 db_uri: str, 
                 time_decay_factor: float = 0.9,
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.time_decay_factor = time_decay_factor
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        current_time = datetime.utcnow()
        session = self.Session()
        try:
            # Fetch all records (fine for small stores; add a limit for large ones)
            records = session.query(Conversation).all()
            
            # Weight each record by time decay
            weighted_records = []
            for record in records:
                time_diff = (current_time - record.timestamp).total_seconds()
                # Multiplicative decay per hour of age
                weight = self.time_decay_factor ** (time_diff / 3600)
                
                if weight > 0.01:  # skip records whose weight is negligible
                    weighted_records.append((record, weight))
            
            # Sort by weight, highest first
            weighted_records.sort(key=lambda x: x[1], reverse=True)
            
            # Build the result
            memory = {
                "history": [{"input": r.input_text, "output": r.output_text} for r, _ in weighted_records]
            }
            return memory
        finally:
            session.close()

A time-decay function gives recent conversation records higher weight, improving contextual relevance.
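
The decay computation can be checked in isolation: with a decay factor of 0.9, a one-hour-old record keeps weight 0.9, and a ten-hour-old record drops to 0.9**10 ≈ 0.35:

```python
def time_decay_weight(age_seconds: float, decay_factor: float = 0.9) -> float:
    """Weight shrinks by decay_factor for every hour of age."""
    return decay_factor ** (age_seconds / 3600)

# Fresh records keep full weight; older ones fade smoothly
w_now = time_decay_weight(0)
w_1h = time_decay_weight(3600)
w_10h = time_decay_weight(10 * 3600)
```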

10. Multimodal Memory Storage

10.1 Image Memory Handling

import base64
import io
import os
import uuid

from PIL import Image

class ImageMemory(CustomMemory):
    def __init__(self, 
                 db_uri: str, 
                 image_storage_path: str = "./images",
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.image_storage_path = image_storage_path
        os.makedirs(image_storage_path, exist_ok=True)
    
    def save_image(self, image_data: bytes, image_name: str = None) -> str:
        """Save the image and return its stored file name."""
        if not image_name:
            image_name = f"image_{uuid.uuid4().hex}.jpg"
        
        image_path = os.path.join(self.image_storage_path, image_name)
        
        try:
            # Write the image to disk
            with open(image_path, 'wb') as f:
                f.write(image_data)
            
            # Collect the image metadata
            image_metadata = {
                "image_name": image_name,
                "path": image_path,
                "timestamp": datetime.utcnow(),
                "size": len(image_data)
            }
            
            # Persist the metadata to the database
            self._save_image_metadata(image_metadata)
            
            return image_name
        except Exception as e:
            logger.error(f"failed to save image: {e}")
            return None
    
    def _save_image_metadata(self, metadata: Dict[str, Any]) -> None:
        """Persist image metadata to the database."""
        session = self.Session()
        try:
            image_record = ImageRecord(
                image_name=metadata["image_name"],
                path=metadata["path"],
                timestamp=metadata["timestamp"],
                size=metadata["size"]
            )
            session.add(image_record)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"failed to save image metadata: {e}")
        finally:
            session.close()
    
    def get_image(self, image_name: str) -> bytes:
        """获取图像数据"""
        image_path = os.path.join(self.image_storage_path, image_name)
        
        try:
            with open(image_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"获取图像失败: {e}")
            return None
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """扩展保存上下文,处理图像"""
        # 处理输入中的图像
        if "images" in inputs:
            saved_images = []
            for image_data in inputs["images"]:
                image_name = self.save_image(image_data)
                if image_name:
                    saved_images.append(image_name)
            inputs["saved_images"] = saved_images
        
        # 调用基类方法保存文本上下文
        super().save_context(inputs, outputs)

This extends the memory store to support storing and retrieving images, saving image metadata and associating it with conversation records.
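As an alternative to random UUID filenames, image files can be named by a hash of their content, which also deduplicates identical uploads for free. A small hashlib-based sketch, independent of the ImageMemory class above (the function name is hypothetical):

```python
import hashlib

def content_addressed_name(image_data: bytes, extension: str = "jpg") -> str:
    """Derive a deterministic filename from the image bytes (SHA-256)."""
    digest = hashlib.sha256(image_data).hexdigest()
    return f"image_{digest[:16]}.{extension}"

name_a = content_addressed_name(b"\x89PNG fake image bytes")
name_b = content_addressed_name(b"\x89PNG fake image bytes")
# identical bytes map to the identical name, so duplicates are stored once
```

Swapping this into `save_image` would mean re-uploaded images reuse the existing file instead of accumulating copies.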

10.2 Audio Memory Handling

import os
import json
import uuid
import logging
from datetime import datetime

import librosa

logger = logging.getLogger(__name__)

class AudioMemory(CustomMemory):
    def __init__(self, 
                 db_uri: str, 
                 audio_storage_path: str = "./audio",
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.audio_storage_path = audio_storage_path
        os.makedirs(audio_storage_path, exist_ok=True)
    
    def save_audio(self, audio_data: bytes, audio_name: str = None, format: str = "wav") -> str:
        """保存音频并返回存储路径"""
        if not audio_name:
            audio_name = f"audio_{uuid.uuid4().hex}.{format}"
        
        audio_path = os.path.join(self.audio_storage_path, audio_name)
        
        try:
            # 保存音频
            with open(audio_path, 'wb') as f:
                f.write(audio_data)
            
            # 提取音频特征
            features = self._extract_audio_features(audio_path)
            
            # 记录音频元数据
            audio_metadata = {
                "audio_name": audio_name,
                "path": audio_path,
                "timestamp": datetime.utcnow(),
                "duration": features["duration"],
                "sample_rate": features["sample_rate"],
                "features": features["features"]
            }
            
            # 保存元数据到数据库
            self._save_audio_metadata(audio_metadata)
            
            return audio_name
        except Exception as e:
            logger.error(f"保存音频失败: {e}")
            return None
    
    def _extract_audio_features(self, audio_path: str) -> Dict[str, Any]:
        """提取音频特征"""
        try:
            # 加载音频文件
            y, sr = librosa.load(audio_path)
            
            # 提取特征
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            mel = librosa.feature.melspectrogram(y=y, sr=sr)
            
            # 计算音频时长
            duration = librosa.get_duration(y=y, sr=sr)
            
            return {
                "duration": duration,
                "sample_rate": sr,
                "features": {
                    "mfccs": mfccs.tolist(),
                    "chroma": chroma.tolist(),
                    "mel": mel.tolist()
                }
            }
        except Exception as e:
            logger.error(f"提取音频特征失败: {e}")
            return {
                "duration": 0,
                "sample_rate": 0,
                "features": {}
            }
    
    def _save_audio_metadata(self, metadata: Dict[str, Any]) -> None:
        """保存音频元数据到数据库"""
        session = self.Session()
        try:
            audio_record = AudioRecord(
                audio_name=metadata["audio_name"],
                path=metadata["path"],
                timestamp=metadata["timestamp"],
                duration=metadata["duration"],
                sample_rate=metadata["sample_rate"],
                features=json.dumps(metadata["features"])
            )
            session.add(audio_record)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"保存音频元数据失败: {e}")
        finally:
            session.close()
    
    def get_audio(self, audio_name: str) -> bytes:
        """获取音频数据"""
        audio_path = os.path.join(self.audio_storage_path, audio_name)
        
        try:
            with open(audio_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"获取音频失败: {e}")
            return None
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """扩展保存上下文,处理音频"""
        # 处理输入中的音频
        if "audio" in inputs:
            audio_data = inputs["audio"]
            audio_name = self.save_audio(audio_data)
            if audio_name:
                inputs["saved_audio"] = audio_name
        
        # 调用基类方法保存文本上下文
        super().save_context(inputs, outputs)

This supports storing audio data and extracting audio features, associating audio with conversation records to enable multimodal memory.
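Serializing the full MFCC/chroma/mel matrices to JSON grows linearly with audio length; a common compaction is to summarize each feature matrix into a fixed-length vector (mean and standard deviation over the time axis). A sketch with plain numpy (the matrix here is synthetic; in the class above it would come from librosa):

```python
import numpy as np

def summarize_features(matrix: np.ndarray) -> list:
    """Collapse an (n_features, n_frames) matrix to a fixed 2*n_features vector."""
    mean = matrix.mean(axis=1)  # per-coefficient average over time
    std = matrix.std(axis=1)    # per-coefficient spread over time
    return np.concatenate([mean, std]).tolist()

# fake 13-coefficient MFCC matrix with 2 frames
mfccs = np.arange(26, dtype=float).reshape(13, 2)
vec = summarize_features(mfccs)  # always length 26, regardless of duration
```

Storing only this summary keeps metadata rows small and makes audio records directly comparable with a vector distance.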

10.3 Multimodal Retrieval

class MultiModalMemory(CustomMemory):
    def __init__(self, 
                 db_uri: str, 
                 image_storage_path: str = "./images",
                 audio_storage_path: str = "./audio",
                 table_name: str = "conversations", 
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.image_memory = ImageMemory(db_uri, image_storage_path, table_name, encryption_key)
        self.audio_memory = AudioMemory(db_uri, audio_storage_path, table_name, encryption_key)
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """保存多模态上下文"""
        # 处理图像
        if "images" in inputs:
            saved_images = []
            for image_data in inputs["images"]:
                image_name = self.image_memory.save_image(image_data)
                if image_name:
                    saved_images.append(image_name)
            inputs["saved_images"] = saved_images
        
        # 处理音频
        if "audio" in inputs:
            audio_data = inputs["audio"]
            audio_name = self.audio_memory.save_audio(audio_data)
            if audio_name:
                inputs["saved_audio"] = audio_name
        
        # 调用基类方法保存文本上下文
        super().save_context(inputs, outputs)
    
    def multimodal_search(self, query: Dict[str, Any], k: int = 4) -> List[Dict[str, Any]]:
        """多模态检索"""
        # 1. 文本检索
        text_results = []
        if "text" in query:
            text_results = self._text_search(query["text"], k)
        
        # 2. 图像检索
        image_results = []
        if "image" in query:
            image_results = self._image_search(query["image"], k)
        
        # 3. 音频检索
        audio_results = []
        if "audio" in query:
            audio_results = self._audio_search(query["audio"], k)
        
        # 4. 结果融合
        merged_results = self._merge_multimodal_results(text_results, image_results, audio_results)
        
        return merged_results
    
    def _text_search(self, text_query: str, k: int = 4) -> List[Dict[str, Any]]:
        """文本检索实现"""
        # 调用基类的检索方法
        memory_vars = super().load_memory_variables({"input": text_query})
        return memory_vars.get("history", [])[:k]
    
    def _image_search(self, image_query: bytes, k: int = 4) -> List[Dict[str, Any]]:
        """图像检索实现"""
        # 简化实现,实际应使用计算机视觉模型进行图像相似度比较
        return []
    
    def _audio_search(self, audio_query: bytes, k: int = 4) -> List[Dict[str, Any]]:
        """音频检索实现"""
        # 简化实现,实际应使用音频处理模型进行音频相似度比较
        return []
    
    def _merge_multimodal_results(self, text_results: List[Dict[str, Any]], 
                                 image_results: List[Dict[str, Any]], 
                                 audio_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """融合多模态检索结果"""
        # 简化实现,实际应根据相关性评分进行排序
        merged = []
        merged.extend(text_results)
        merged.extend(image_results)
        merged.extend(audio_results)
        return merged

This integrates text, image, and audio retrieval, enabling context association across media types.
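The `_merge_multimodal_results` stub above simply concatenates the lists; a hedged sketch of the weighted score fusion it alludes to, assuming each result dict carries a `score` field (which the simplified classes above do not yet populate):

```python
def merge_by_weighted_score(text_results, image_results, audio_results,
                            weights=(0.5, 0.3, 0.2), k=4):
    """Rank results from all modalities by modality-weighted relevance score."""
    scored = []
    for weight, results in zip(weights, (text_results, image_results, audio_results)):
        for result in results:
            scored.append((weight * result.get("score", 0.0), result))
    # highest weighted score first, keep the top k overall
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [result for _, result in scored[:k]]

merged = merge_by_weighted_score(
    [{"id": "t1", "score": 0.9}],
    [{"id": "i1", "score": 0.95}],
    [{"id": "a1", "score": 0.4}],
    k=2,
)
```

The per-modality weights are a tuning knob: raising the text weight favors transcript matches even when an image match scores higher in isolation.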

XI. Testing and Validation

11.1 Unit Testing Framework

import unittest
from unittest.mock import patch, MagicMock
from datetime import datetime

class TestCustomMemory(unittest.TestCase):
    def setUp(self):
        # 创建测试数据库连接
        self.db_uri = "sqlite:///:memory:"
        self.memory = CustomMemory(db_uri=self.db_uri)
    
    def test_save_context(self):
        # 测试保存上下文
        inputs = {"input": "你好"}
        outputs = {"output": "你好!有什么可以帮助你的吗?"}
        
        self.memory.save_context(inputs, outputs)
        
        # 验证是否保存成功
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        
        self.assertEqual(len(history), 1)
        self.assertEqual(history[0]["input"], "你好")
        self.assertEqual(history[0]["output"], "你好!有什么可以帮助你的吗?")
    
    def test_load_memory_variables(self):
        # 测试加载记忆变量
        inputs = {"input": "今天天气如何?"}
        outputs = {"output": "今天天气晴朗,温度适宜。"}
        
        self.memory.save_context(inputs, outputs)
        
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        
        self.assertEqual(len(history), 1)
        self.assertEqual(history[0]["input"], "今天天气如何?")
        self.assertEqual(history[0]["output"], "今天天气晴朗,温度适宜。")
    
    def test_multiple_saves(self):
        # 测试多次保存
        inputs1 = {"input": "你好"}
        outputs1 = {"output": "你好!"}
        
        inputs2 = {"input": "再见"}
        outputs2 = {"output": "再见!"}
        
        self.memory.save_context(inputs1, outputs1)
        self.memory.save_context(inputs2, outputs2)
        
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        
        self.assertEqual(len(history), 2)
        self.assertEqual(history[0]["input"], "你好")
        self.assertEqual(history[0]["output"], "你好!")
        self.assertEqual(history[1]["input"], "再见")
        self.assertEqual(history[1]["output"], "再见!")
    
    @patch('your_module.datetime')
    def test_time_ordering(self, mock_datetime):
        # 测试时间顺序
        # 设置固定时间
        mock_datetime.utcnow.side_effect = [
            datetime(2023, 1, 1, 12, 0, 0),
            datetime(2023, 1, 1, 12, 1, 0),
            datetime(2023, 1, 1, 12, 2, 0)
        ]
        
        inputs1 = {"input": "第一个消息"}
        outputs1 = {"output": "第一个回复"}
        
        inputs2 = {"input": "第二个消息"}
        outputs2 = {"output": "第二个回复"}
        
        inputs3 = {"input": "第三个消息"}
        outputs3 = {"output": "第三个回复"}
        
        self.memory.save_context(inputs1, outputs1)
        self.memory.save_context(inputs2, outputs2)
        self.memory.save_context(inputs3, outputs3)
        
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        
        self.assertEqual(len(history), 3)
        self.assertEqual(history[0]["input"], "第一个消息")
        self.assertEqual(history[1]["input"], "第二个消息")
        self.assertEqual(history[2]["input"], "第三个消息")

Python's unittest framework is used to unit-test the basic features of the custom memory store, verifying that the core functionality works as expected.
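When the database layer is irrelevant to what a test checks, a DB-free stub that mimics the same `save_context`/`load_memory_variables` interface keeps tests fast and deterministic. A minimal sketch (it mirrors the interface only; it does not inherit from LangChain's BaseMemory):

```python
class InMemoryStub:
    """Minimal stand-in exposing the same interface as CustomMemory."""

    def __init__(self):
        self._history = []

    @property
    def memory_variables(self):
        return ["history"]

    def save_context(self, inputs, outputs):
        self._history.append({"input": inputs["input"], "output": outputs["output"]})

    def load_memory_variables(self, inputs):
        # return a copy so callers cannot mutate internal state
        return {"history": list(self._history)}

stub = InMemoryStub()
stub.save_context({"input": "hi"}, {"output": "hello"})
history = stub.load_memory_variables({})["history"]
```

Injecting such a stub into a chain under test isolates chain logic from storage behavior, so failures point at one layer, not both.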

11.2 Performance Testing

import timeit
import random
import string
from concurrent.futures import ThreadPoolExecutor

class TestMemoryPerformance(unittest.TestCase):
    def setUp(self):
        self.db_uri = "sqlite:///:memory:"
        self.memory = CustomMemory(db_uri=self.db_uri)
    
    def test_save_performance(self):
        # 测试保存性能
        def save_random_context():
            input_text = ''.join(random.choices(string.ascii_letters, k=100))
            output_text = ''.join(random.choices(string.ascii_letters, k=200))
            self.memory.save_context({"input": input_text}, {"output": output_text})
        
        execution_time = timeit.timeit(save_random_context, number=100)
        print(f"保存100条记录耗时: {execution_time}秒")
        self.assertLess(execution_time, 10.0, "保存性能低于预期")
    
    def test_load_performance(self):
        # 测试加载性能
        # 先插入一些数据
        for i in range(1000):
            input_text = f"测试输入{i}"
            output_text = f"测试输出{i}"
            self.memory.save_context({"input": input_text}, {"output": output_text})
        
        def load_context():
            self.memory.load_memory_variables({})
        
        execution_time = timeit.timeit(load_context, number=10)
        print(f"加载10次耗时: {execution_time}秒")
        self.assertLess(execution_time, 2.0, "加载性能低于预期")
    
    def test_concurrent_access(self):
        # 测试并发访问
        def concurrent_task():
            input_text = ''.join(random.choices(string.ascii_letters, k=50))
            output_text = ''.join(random.choices(string.ascii_letters, k=100))
            
            # 交替进行读写操作
            for _ in range(10):
                self.memory.save_context({"input": input_text}, {"output": output_text})
                self.memory.load_memory_variables({})
        
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(concurrent_task) for _ in range(5)]
            for future in futures:
                future.result()
        
        # 验证数据完整性
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        self.assertGreater(len(history), 0, "并发访问后数据丢失")

Performance tests evaluate how the memory store behaves under large data volumes and high concurrency, helping identify potential bottlenecks.
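Total or average timings hide tail latency; once per-operation durations are collected, percentiles tell a fuller story. A small nearest-rank percentile sketch, independent of the test class above (the durations list is made-up sample data):

```python
import math

def percentile(values, p):
    """Nearest-rank percentile (no interpolation): value at rank ceil(p% * n)."""
    ordered = sorted(values)
    idx = min(len(ordered) - 1, max(0, math.ceil(p / 100 * len(ordered)) - 1))
    return ordered[idx]

# sample per-save durations in seconds, with one slow outlier
durations = [0.01, 0.012, 0.011, 0.013, 0.010, 0.25, 0.012, 0.011, 0.014, 0.012]
p50 = percentile(durations, 50)
p95 = percentile(durations, 95)
```

Asserting on p95 rather than the mean catches regressions where most saves are fast but a fraction stall, which an average would smooth over.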

11.3 Integration Testing

from langchain.chains import ConversationChain
from langchain.llms import OpenAI

class TestMemoryIntegration(unittest.TestCase):
    @patch('langchain.llms.OpenAI')
    def test_with_conversation_chain(self, mock_openai):
        # 测试与对话链的集成
        # 设置模拟LLM
        mock_llm = MagicMock(spec=OpenAI)
        mock_llm.predict.return_value = "模拟回复"
        mock_openai.return_value = mock_llm
        
        # 创建自定义记忆
        memory = CustomMemory(db_uri="sqlite:///:memory:")
        
        # 创建对话链
        chain = ConversationChain(
            llm=mock_llm,
            memory=memory,
            verbose=False
        )
        
        # 进行对话
        response1 = chain.predict(input="你好")
        response2 = chain.predict(input="今天天气如何?")
        
        # 验证记忆
        memory_vars = memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        
        self.assertEqual(len(history), 2)
        self.assertEqual(history[0]["input"], "你好")
        self.assertEqual(history[0]["output"], "模拟回复")
        self.assertEqual(history[1]["input"], "今天天气如何?")
        self.assertEqual(history[1]["output"], "模拟回复")
    
    def test_with_vector_search(self):
        # 测试与向量检索的集成
        memory = VectorizedMemory(db_uri="sqlite:///:memory:")
        
        # 保存一些上下文
        memory.save_context({"input": "北京天气如何?"}, {"output": "北京今天晴天,25度。"})
        memory.save_context({"input": "上海天气如何?"}, {"output": "上海今天多云,28度。"})
        memory.save_context({"input": "广州天气如何?"}, {"output": "广州今天小雨,27度。"})
        
        # 执行向量搜索
        results = memory.similarity_search("北京的天气", k=1)
        
        self.assertEqual(len(results), 1)
        self.assertIn("北京", results[0]["input"])
        self.assertIn("25度", results[0]["output"])

Integration tests verify that the custom memory store works correctly with other LangChain components and that the system functions end to end.
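Instead of patching the OpenAI class, integration tests can inject a deterministic fake model directly. A sketch of a canned-response fake (it mimics only the `predict` call used above, not the full LLM interface):

```python
class FakeLLM:
    """Returns scripted responses in order, recording the prompts it saw."""

    def __init__(self, responses):
        self.responses = list(responses)
        self.calls = []

    def predict(self, text: str) -> str:
        self.calls.append(text)
        # cycle on exhaustion so extra turns do not crash the test
        return self.responses[(len(self.calls) - 1) % len(self.responses)]

llm = FakeLLM(["first reply", "second reply"])
first = llm.predict("hello")
second = llm.predict("goodbye")
```

Because responses are scripted, assertions about what ends up in memory are exact, and `calls` lets the test verify which prompts (including injected history) actually reached the model.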

XII. Monitoring and Tuning

12.1 Performance Monitoring

import time
from prometheus_client import Counter, Histogram, start_http_server

class MonitoredMemory(CustomMemory):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        # 初始化监控指标
        self.save_requests = Counter(
            'memory_save_requests_total', 
            'Total number of save requests',
            ['status']
        )
        
        self.load_requests = Counter(
            'memory_load_requests_total', 
            'Total number of load requests',
            ['status']
        )
        
        self.save_duration = Histogram(
            'memory_save_duration_seconds', 
            'Duration of save operations',
            buckets=[0.01, 0.05, 0.1, 0.5, 1.0, 5.0]
        )
        
        self.load_duration = Histogram(
            'memory_load_duration_seconds', 
            'Duration of load operations',
            buckets=[0.01, 0.05, 0.1, 0.5, 1.0, 5.0]
        )
        
        # Start the metrics HTTP server (do this only once per process:
        # a second instance would fail because port 8000 is already bound
        # and the metric names are already registered)
        start_http_server(8000)
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        start_time = time.time()
        try:
            super().save_context(inputs, outputs)
            self.save_requests.labels(status='success').inc()
        except Exception as e:
            self.save_requests.labels(status='error').inc()
            raise e
        finally:
            duration = time.time() - start_time
            self.save_duration.observe(duration)
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        start_time = time.time()
        try:
            result = super().load_memory_variables(inputs)
            self.load_requests.labels(status='success').inc()
            return result
        except Exception as e:
            self.load_requests.labels(status='error').inc()
            raise e
        finally:
            duration = time.time() - start_time
            self.load_duration.observe(duration)

Prometheus tracks the memory store's key metrics, including request counts, latency, and error rates, providing data to guide performance optimization.
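The save/load wrappers above duplicate the same timing pattern; a decorator can factor it out. A dependency-free sketch that records counts and cumulative durations in plain dicts (a Prometheus Counter/Histogram pair would slot into the same places):

```python
import time
from functools import wraps

metrics = {"calls": {}, "total_seconds": {}}

def timed(name):
    """Record call count and cumulative duration under the given metric name."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                # runs on success and on exception alike
                metrics["calls"][name] = metrics["calls"].get(name, 0) + 1
                metrics["total_seconds"][name] = (
                    metrics["total_seconds"].get(name, 0.0)
                    + time.perf_counter() - start
                )
        return wrapper
    return decorator

@timed("memory_save")
def save_stub():
    time.sleep(0.01)

save_stub()
save_stub()
```

With this in place, `MonitoredMemory.save_context` and `load_memory_variables` shrink to one decorator line each instead of repeating try/finally blocks.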

12.2 Automatic Tuning

import time

import numpy as np
from sklearn.linear_model import LinearRegression

class AutoTuningMemory(CustomMemory):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.performance_data = []
        self.model = LinearRegression()
        self.optimal_batch_size = 100
        self.auto_tune_enabled = True
    
    def enable_auto_tune(self, enable: bool = True) -> None:
        self.auto_tune_enabled = enable
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        start_time = time.time()
        
        # 使用当前批处理大小
        batch_size = self.optimal_batch_size
        
        # 如果有多个条目需要保存,分批处理
        if isinstance(inputs, list) and isinstance(outputs, list):
            for i in range(0, len(inputs), batch_size):
                batch_inputs = inputs[i:i+batch_size]
                batch_outputs = outputs[i:i+batch_size]
                super().save_context(batch_inputs, batch_outputs)
        else:
            super().save_context(inputs, outputs)
        
        duration = time.time() - start_time
        
        # 收集性能数据
        self._collect_performance_data(len(inputs) if isinstance(inputs, list) else 1, duration)
        
        # 定期重新训练模型
        if len(self.performance_data) % 100 == 0 and self.auto_tune_enabled:
            self._retrain_model()
    
    def _collect_performance_data(self, batch_size: int, duration: float) -> None:
        # 收集批处理大小和执行时间的关系
        self.performance_data.append((batch_size, duration))
        
        # 限制数据点数量,避免内存溢出
        if len(self.performance_data) > 1000:
            self.performance_data.pop(0)
    
    def _retrain_model(self) -> None:
        # 从收集的数据中训练线性回归模型
        if len(self.performance_data) < 10:  # 至少需要10个数据点
            return
        
        X = np.array([[data[0]] for data in self.performance_data])
        y = np.array([data[1] for data in self.performance_data])
        
        self.model.fit(X, y)
        
        # Predict latency for a range of candidate batch sizes
        test_sizes = np.array([[i] for i in range(10, 500, 10)])
        predictions = self.model.predict(test_sizes)
        
        # With a linear model (total time = a*size + b), minimizing the raw
        # duration would always pick the smallest size, so minimize the
        # predicted time per item instead, which amortizes fixed overhead
        per_item_time = predictions / test_sizes.flatten()
        optimal_idx = np.argmin(per_item_time)
        self.optimal_batch_size = int(test_sizes[optimal_idx][0])
        
        print(f"Auto-tuning: optimal batch size updated to {self.optimal_batch_size}")

By collecting performance data and fitting a linear regression model, parameters such as the batch size are adjusted automatically to improve system performance.

12.3 Exception Handling

import os
import time
import shutil
import logging

logger = logging.getLogger(__name__)

class RobustMemory(CustomMemory):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.retry_attempts = 3
        self.retry_delay = 1
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        for attempt in range(self.retry_attempts):
            try:
                super().save_context(inputs, outputs)
                break
            except Exception as e:
                if attempt == self.retry_attempts - 1:
                    logger.error(f"保存上下文失败,已达到最大重试次数: {e}")
                    raise
                else:
                    logger.warning(f"保存上下文失败,尝试重试 ({attempt+1}/{self.retry_attempts}): {e}")
                    time.sleep(self.retry_delay * (2 ** attempt))  # 指数退避
    
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        for attempt in range(self.retry_attempts):
            try:
                return super().load_memory_variables(inputs)
            except Exception as e:
                if attempt == self.retry_attempts - 1:
                    logger.error(f"加载记忆变量失败,已达到最大重试次数: {e}")
                    raise
                else:
                    logger.warning(f"加载记忆变量失败,尝试重试 ({attempt+1}/{self.retry_attempts}): {e}")
                    time.sleep(self.retry_delay * (2 ** attempt))  # 指数退避
    
    def backup_memory(self, backup_path: str) -> None:
        """创建记忆存储的备份"""
        try:
            # 对于SQLite数据库,直接复制文件
            if self.db_uri.startswith("sqlite:///"):
                db_path = self.db_uri.replace("sqlite:///", "")
                if os.path.exists(db_path):
                    shutil.copy2(db_path, backup_path)
                    logger.info(f"记忆存储备份成功,保存到 {backup_path}")
                else:
                    logger.warning("数据库文件不存在,无法备份")
            else:
                # 对于其他数据库类型,实现相应的备份逻辑
                logger.warning(f"不支持的数据库类型,无法备份: {self.db_uri}")
        except Exception as e:
            logger.error(f"备份失败: {e}")

Retry logic and a backup facility make the memory store more robust, keeping the system stable in the face of transient failures.
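The retry-with-exponential-backoff pattern in RobustMemory can also be expressed as a reusable decorator, so each method does not repeat the loop. A sketch (the base delay is shortened here for illustration):

```python
import time
from functools import wraps

def retry(attempts=3, base_delay=0.01):
    """Retry a function with exponential backoff; re-raise after the last attempt."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return fn(*args, **kwargs)
                except Exception:
                    if attempt == attempts - 1:
                        raise
                    time.sleep(base_delay * (2 ** attempt))  # 1x, 2x, 4x, ...
        return wrapper
    return decorator

state = {"calls": 0}

@retry(attempts=3)
def flaky():
    """Simulated transient failure: succeeds on the third call."""
    state["calls"] += 1
    if state["calls"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

result = flaky()
```

One caveat: only retry operations that are idempotent or deduplicated, otherwise a save that actually succeeded before the timeout can be written twice.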

XIII. Practical Case Studies

13.1 Memory Storage in an Intelligent Customer Service System

An intelligent customer service system needs to keep each user's past inquiries, issue categories, and resolutions so it can deliver a coherent service experience.

from sqlalchemy import Index, func, inspect

class CustomerServiceMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "customer_service"):
        super().__init__(db_uri, table_name)
        # 初始化客户服务特定的索引
        self._create_customer_indexes()
    
    def _create_customer_indexes(self):
        # 创建客户ID和问题类型的索引
        session = self.Session()
        try:
            # 检查索引是否存在
            inspector = inspect(self.engine)
            indexes = inspector.get_indexes(self.table_name)
            
            if not any(index['name'] == 'idx_customer_id' for index in indexes):
                # Create the customer ID index
                idx = Index('idx_customer_id', CustomerServiceRecord.customer_id)
                idx.create(bind=self.engine)
            
            if not any(index['name'] == 'idx_issue_type' for index in indexes):
                # Create the issue type index
                idx = Index('idx_issue_type', CustomerServiceRecord.issue_type)
                idx.create(bind=self.engine)
        finally:
            session.close()
    
    def get_customer_history(self, customer_id: str) -> List[Dict[str, Any]]:
        """获取特定客户的历史记录"""
        session = self.Session()
        try:
            records = session.query(CustomerServiceRecord).filter(
                CustomerServiceRecord.customer_id == customer_id
            ).order_by(CustomerServiceRecord.timestamp.desc()).all()
            
            return [
                {
                    "input": record.input_text,
                    "output": record.output_text,
                    "timestamp": record.timestamp,
                    "issue_type": record.issue_type
                }
                for record in records
            ]
        finally:
            session.close()
    
    def get_issue_statistics(self) -> Dict[str, int]:
        """获取问题类型统计"""
        session = self.Session()
        try:
            # 统计每种问题类型的数量
            statistics = session.query(
                CustomerServiceRecord.issue_type,
                func.count(CustomerServiceRecord.id)
            ).group_by(CustomerServiceRecord.issue_type).all()
            
            return {issue_type: count for issue_type, count in statistics}
        finally:
            session.close()
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """扩展保存上下文,添加客户服务特定字段"""
        # 提取客户ID和问题类型
        customer_id = inputs.get("customer_id", "unknown")
        issue_type = inputs.get("issue_type", "general")
        
        # 创建客户服务记录
        record = CustomerServiceRecord(
            customer_id=customer_id,
            input_text=inputs.get("input"),
            output_text=outputs.get("output"),
            issue_type=issue_type,
            timestamp=datetime.utcnow()
        )
        
        session = self.Session()
        try:
            session.add(record)
            session.commit()
        except Exception as e:
            session.rollback()
            raise e
        finally:
            session.close()

This customer-service memory store adds fields specific to support work, such as customer ID and issue type, with matching indexes to speed up the corresponding queries.

13.2 Memory Storage in an Educational Assistance System

An educational assistance system needs to track each student's learning progress, mastery of knowledge points, and question-and-answer history.

class EducationMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "education"):
        super().__init__(db_uri, table_name)
        self.subject_embeddings = {}
    
    def save_learning_progress(self, student_id: str, subject: str, progress: float, knowledge_points: List[str]) -> None:
        """保存学习进度"""
        session = self.Session()
        try:
            # 检查是否已有该学生的学习记录
            record = session.query(EducationRecord).filter(
                EducationRecord.student_id == student_id,
                EducationRecord.subject == subject
            ).first()
            
            if record:
                # 更新现有记录
                record.progress = progress
                record.knowledge_points = json.dumps(knowledge_points)
                record.updated_at = datetime.utcnow()
            else:
                # 创建新记录
                record = EducationRecord(
                    student_id=student_id,
                    subject=subject,
                    progress=progress,
                    knowledge_points=json.dumps(knowledge_points),
                    created_at=datetime.utcnow(),
                    updated_at=datetime.utcnow()
                )
                session.add(record)
            
            session.commit()
        except Exception as e:
            session.rollback()
            raise e
        finally:
            session.close()
    
    def get_learning_progress(self, student_id: str, subject: str = None) -> Any:
        """Get learning progress: one record dict (or None) when a subject is
        given, otherwise a list of records covering all subjects"""
        session = self.Session()
        try:
            if subject:
                # 获取特定科目的学习进度
                record = session.query(EducationRecord).filter(
                    EducationRecord.student_id == student_id,
                    EducationRecord.subject == subject
                ).first()
                
                if record:
                    return {
                        "student_id": record.student_id,
                        "subject": record.subject,
                        "progress": record.progress,
                        "knowledge_points": json.loads(record.knowledge_points),
                        "updated_at": record.updated_at
                    }
                else:
                    return None
            else:
                # 获取所有科目的学习进度
                records = session.query(EducationRecord).filter(
                    EducationRecord.student_id == student_id
                ).all()
                
                return [
                    {
                        "student_id": record.student_id,
                        "subject": record.subject,
                        "progress": record.progress,
                        "knowledge_points": json.loads(record.knowledge_points),
                        "updated_at": record.updated_at
                    }
                    for record in records
                ]
        finally:
            session.close()
    
    def recommend_resources(self, student_id: str, subject: str, max_results: int = 5) -> List[Dict[str, Any]]:
        """推荐学习资源"""
        # 1. 获取学生当前的知识状态
        progress = self.get_learning_progress(student_id, subject)
        if not progress:
            return []
        
        # 2. 基于知识状态生成资源推荐
        # 这里简化为随机推荐,实际应基于知识点掌握情况和资源库匹配
        resources = [
            {"title": f"{subject}基础教程", "type": "book", "difficulty": "beginner"},
            {"title": f"{subject}进阶练习", "type": "exercise", "difficulty": "intermediate"},
            {"title": f"{subject}高级理论", "type": "article", "difficulty": "advanced"},
            {"title": f"{subject}实例分析", "type": "case", "difficulty": "intermediate"},
            {"title": f"{subject}测验", "type": "quiz", "difficulty": "beginner"}
        ]
        
        # 根据学习进度调整推荐难度
        if progress["progress"] < 0.3:
            # 初学者推荐基础资源
            return [r for r in resources if r["difficulty"] == "beginner"][:max_results]
        elif progress["progress"] < 0.7:
            # 中级学习者推荐中级资源
            return [r for r in resources if r["difficulty"] == "intermediate"][:max_results]
        else:
            # 高级学习者推荐高级资源
            return [r for r in resources if r["difficulty"] == "advanced"][:max_results]
    
    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """扩展保存上下文,添加教育特定信息"""
        # 提取学生ID和科目
        student_id = inputs.get("student_id", "unknown")
        subject = inputs.get("subject", "general")
        
        # 保存常规对话上下文
        super().save_context(inputs, outputs)
        
        # 更新主题嵌入(用于资源推荐)
        self._update_subject_embedding(subject, inputs.get("input", ""))
    
    def _update_subject_embedding(self, subject: str, text: str) -> None:
        """更新主题嵌入向量"""
        if not text:
            return
        
        # 简化实现,实际应使用NLP模型生成嵌入向量
        if subject not in self.subject_embeddings:
            self.subject_embeddings[subject] = []
        
        # 这里只是示例,实际应使用文本嵌入模型
        self.subject_embeddings[subject].append(text)
        
        # 限制嵌入向量数量,避免内存溢出
        if len(self.subject_embeddings[subject]) > 100:
            self.subject_embeddings[subject].pop(0)

This education memory store saves conversation history and also tracks learning progress and knowledge-point mastery, supporting personalized resource recommendations.

13.3 Memory Storage for a Financial Investment Assistant

A financial investment assistant needs to keep each user's portfolio, transaction history, and market-analysis records.

class FinancialMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "financial"):
        super().__init__(db_uri, table_name)
        self.portfolio_cache = {}
    
    def save_portfolio(self, user_id: str, portfolio: Dict[str, Any]) -> None:
        """保存投资组合"""
        session = self.Session()
        try:
            # 检查是否已有该用户的投资组合
            record = session.query(FinancialRecord).filter(
                FinancialRecord.user_id == user_id,
                FinancialRecord.record_type == "portfolio"
            ).first()
            
            if record:
                # 更新现有记录
                record.content = json.dumps(portfolio)
                record.updated_at = datetime.utcnow()
            else:
                # 创建新记录
                record = FinancialRecord(
                    user_id=user_id,
                    record_type="portfolio",
                    content=json.dumps(portfolio),
                    created_at=datetime.utcnow(),
                    updated_at=datetime.utcnow()
                )
                session.add(record)
            
            session.commit()
            
            # 更新缓存
            self.portfolio_cache[user_id] = portfolio
        except Exception as e:
            session.rollback()
            raise e
        finally:
            session.close()
    
    def get_portfolio(self, user_id: str) -> Dict[str, Any]:
        """Load a user's portfolio."""
        # Serve from the cache when possible
        if user_id in self.portfolio_cache:
            return self.portfolio_cache[user_id]
        
        session = self.Session()
        try:
            record = session.query(FinancialRecord).filter(
                FinancialRecord.user_id == user_id,
                FinancialRecord.record_type == "portfolio"
            ).first()
            
            if record:
                portfolio = json.loads(record.content)
                # Populate the cache for subsequent reads
                self.portfolio_cache[user_id] = portfolio
                return portfolio
            else:
                return {"assets": [], "cash": 0.0}
        finally:
            session.close()
    
    def save_transaction(self, user_id: str, transaction: Dict[str, Any]) -> None:
        """Save a transaction record."""
        session = self.Session()
        try:
            # Create the transaction record
            record = FinancialRecord(
                user_id=user_id,
                record_type="transaction",
                content=json.dumps(transaction),
                created_at=datetime.utcnow(),
                updated_at=datetime.utcnow()
            )
            
            session.add(record)
            session.commit()
            
            # Apply the transaction to the portfolio
            self._update_portfolio_after_transaction(user_id, transaction)
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
    
    def get_transaction_history(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Fetch the most recent transactions."""
        session = self.Session()
        try:
            records = session.query(FinancialRecord).filter(
                FinancialRecord.user_id == user_id,
                FinancialRecord.record_type == "transaction"
            ).order_by(FinancialRecord.created_at.desc()).limit(limit).all()
            
            return [json.loads(record.content) for record in records]
        finally:
            session.close()
    
    def _update_portfolio_after_transaction(self, user_id: str, transaction: Dict[str, Any]) -> None:
        """Update the portfolio after a transaction."""
        portfolio = self.get_portfolio(user_id)
        assets = portfolio.get("assets", [])
        cash = portfolio.get("cash", 0.0)
        
        # Unpack the transaction
        transaction_type = transaction.get("type")
        asset = transaction.get("asset")
        amount = transaction.get("amount", 0)
        price = transaction.get("price", 0)
        cost = amount * price
        
        if transaction_type == "buy":
            # Buy: debit cash (simplification: no check for sufficient funds)
            cash -= cost
            
            # Add to an existing holding if there is one
            existing_asset = next((a for a in assets if a["symbol"] == asset), None)
            if existing_asset:
                existing_asset["amount"] += amount
            else:
                assets.append({"symbol": asset, "amount": amount})
        elif transaction_type == "sell":
            # Sell: credit cash (simplification: no check for sufficient holdings)
            cash += cost
            
            # Reduce the holding, removing it once it reaches zero
            for a in assets:
                if a["symbol"] == asset:
                    a["amount"] -= amount
                    if a["amount"] <= 0:
                        assets.remove(a)
                    break
        
        # Persist the updated portfolio
        portfolio["assets"] = assets
        portfolio["cash"] = cash
        self.save_portfolio(user_id, portfolio)
    
    def analyze_portfolio_risk(self, user_id: str) -> Dict[str, Any]:
        """Analyze portfolio risk."""
        portfolio = self.get_portfolio(user_id)
        assets = portfolio.get("assets", [])
        
        # Simplified risk analysis
        risk_score = 0.0
        diversification = len(assets)
        # Simplification: value every position at 100 per share
        total_value = sum(a.get("amount", 0) * 100 for a in assets)
        
        for asset in assets:
            # Assume a fixed risk factor per asset class
            # (a real system would derive this from market data)
            risk_factor = 0.5  # default: medium risk
            if asset["symbol"].startswith("BTC"):
                risk_factor = 0.8  # cryptocurrency: high risk
            elif asset["symbol"].startswith("AAPL"):
                risk_factor = 0.3  # blue-chip stock: lower risk
            
            # Weight each asset's risk by its share of total portfolio value
            asset_value = asset.get("amount", 0) * 100
            if total_value > 0:
                risk_score += (asset_value / total_value) * risk_factor
        
        return {
            "risk_score": risk_score,
            "diversification": diversification,
            "recommendation": "Consider diversifying your holdings" if diversification < 3 else "Portfolio is well diversified"
        }

This financial memory store is tailored to an investment assistant: it persists and manages each user's portfolio and transaction history, and provides basic risk analysis on top of them.
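The portfolio-update arithmetic in `_update_portfolio_after_transaction` can be isolated as a pure function, which makes it easy to unit-test without a database. The sketch below mirrors the logic above (with the same simplifications: no validation of available cash or holdings); `apply_transaction` is a hypothetical helper, not part of the class.

```python
from typing import Any, Dict

def apply_transaction(portfolio: Dict[str, Any], tx: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new portfolio reflecting a buy/sell transaction.

    Mirrors FinancialMemory._update_portfolio_after_transaction minus the
    persistence; same simplifications (no cash or holdings validation).
    """
    assets = [dict(a) for a in portfolio.get("assets", [])]  # copy, don't mutate input
    cash = portfolio.get("cash", 0.0)
    cost = tx.get("amount", 0) * tx.get("price", 0)
    symbol = tx.get("asset")

    if tx.get("type") == "buy":
        cash -= cost
        existing = next((a for a in assets if a["symbol"] == symbol), None)
        if existing:
            existing["amount"] += tx["amount"]
        else:
            assets.append({"symbol": symbol, "amount": tx["amount"]})
    elif tx.get("type") == "sell":
        cash += cost
        for a in assets:
            if a["symbol"] == symbol:
                a["amount"] -= tx["amount"]
                if a["amount"] <= 0:
                    assets.remove(a)
                break

    return {"assets": assets, "cash": cash}
```

For example, buying 2 shares at 50 from `{"assets": [], "cash": 200.0}` yields cash 100.0 and a single holding of 2 shares; selling them all at 60 empties the holdings and credits 120.0 back.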

14. Future Directions

14.1 Integration with Large Models

As large language models grow more capable, memory storage will integrate with them more tightly, enabling:

  1. Intelligent summarization: automatically generate conversation summaries to reduce storage overhead
  2. Knowledge distillation: extract structured knowledge from past conversations
  3. Predictive memory: anticipate future needs from a user's historical behavior
  4. Semantic compression: store conversation content more efficiently as vector representations

14.2 Federated Learning and Privacy Protection

Privacy protection will become even more important in sensitive domains:

  1. Federated memory: train memory models without sharing raw data
  2. Differential privacy: add noise during storage and retrieval to protect privacy
  3. Zero-knowledge proofs: verify the correctness of memory operations without revealing their content
  4. Dynamic access control: adjust data-access permissions based on context

14.3 Edge Computing and Distributed Memory

As edge devices grow more capable, memory storage will move toward distributed architectures:

  1. Edge memory caches: keep frequently used conversation history on the local device
  2. Distributed memory networks: share and synchronize memory data across devices
  3. Fog-computing memory: balance storage and compute between the edge and the cloud
  4. Decentralized memory: use blockchain technology for trusted memory storage

14.4 Multimodality and Context Awareness

Future memory stores will support richer modalities and situational awareness:

  1. Multi-sensory memory: integrate text, images, audio, video, and other modalities
  2. Emotional memory: record and analyze the emotional state of conversations
  3. Context awareness: incorporate situational signals such as time, location, and environment
  4. Physical-world mapping: link memories to real-world spaces and objects

14.5 Self-Evolution and Meta-Learning

Memory systems will gain the ability to evolve themselves:

  1. Adaptive memory structures: automatically adjust the storage layout based on usage patterns
  2. Meta-learning memory: learn how to organize and retrieve memories more effectively
  3. Continual learning: acquire new knowledge without forgetting old knowledge
  4. Memory consolidation: periodically restructure memories to improve retrieval efficiency and accuracy

14.6 Quantum Computing and Memory Storage

Quantum computing could bring revolutionary changes to memory storage:

  1. Quantum storage: use quantum states to store more information
  2. Quantum search algorithms: accelerate retrieval over large memory stores
  3. Quantum encryption: provide stronger security guarantees for stored memories
  4. Quantum machine learning: enhance memory analysis and processing

Through continued innovation along these lines, LangChain custom memory storage will play an ever more central and intelligent role in future AI systems, delivering more personalized, efficient, and secure interactions.