A Deep Dive into Custom Memory Storage Development in LangChain
1. Overview of Memory Storage in LangChain
1.1 The Core Role of Memory Storage
In the LangChain framework, memory storage is the key module behind conversational continuity, context awareness, and reuse of historical information. Whether you are building a chatbot, a customer-service assistant, or a task-oriented dialogue system, memory storage is responsible for persisting the history of user-AI interactions, including conversation content, task state, and user preferences. This history helps the language model understand the current context and produce replies that fit the conversation, improving the user experience. At the source level, the memory module interacts with other LangChain components (chains, agents, prompt templates) through a standardized interface, giving the whole system its context awareness.
1.2 Limitations of the Built-in Memory Implementations
LangChain ships with several built-in memory implementations, such as ConversationBufferMemory (stores the full conversation history) and ConversationSummaryMemory (stores a running summary). These built-ins have real limitations:
- Fixed storage structure: only preset data structures are supported, which is hard to reconcile with complex business requirements
- Limited extensibility: there is no convenient way to plug in a custom database or storage backend
- Narrow feature set: no data cleanup, versioning, encryption, or other advanced capabilities
- Performance bottlenecks: efficiency can degrade on large volumes of conversation data
These limitations motivate developers to build custom memory storage tailored to their application.
2. Design Principles for Custom Memory Storage
2.1 Interface Compatibility
A custom memory store must follow LangChain's memory interface so it integrates seamlessly with existing components. The core interface looks like this (BaseMemory is already an abstract base class, so there is no need to also inherit from ABC; the import path varies by LangChain version):

from typing import Any, Dict, List
from langchain.memory import BaseMemory  # some versions: from langchain.schema import BaseMemory

class CustomMemory(BaseMemory):
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Load memory variables from storage."""
        raise NotImplementedError()

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Persist the conversation context."""
        raise NotImplementedError()

    @property
    def memory_variables(self) -> List[str]:
        """Return the list of memory variable names."""
        raise NotImplementedError()

By implementing these three abstract members, a custom memory store can cooperate with components such as ConversationChain and AgentExecutor.
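As a minimal illustration of this three-method contract, here is a sketch written as a plain Python class so it runs without the langchain dependency; a real implementation would subclass BaseMemory but expose the same members:

```python
from typing import Any, Dict, List

class InMemoryHistory:
    """Minimal illustration of the three-method memory contract."""

    def __init__(self) -> None:
        self._turns: List[Dict[str, str]] = []

    @property
    def memory_variables(self) -> List[str]:
        # Names that prompt templates may reference
        return ["history"]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Render history as prompt-ready text, oldest turn first
        lines = [f"Human: {t['input']}\nAI: {t['output']}" for t in self._turns]
        return {"history": "\n".join(lines)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        self._turns.append({"input": inputs["input"], "output": outputs["output"]})

mem = InMemoryHistory()
mem.save_context({"input": "hi"}, {"output": "hello"})
print(mem.load_memory_variables({})["history"])
```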
2.2 Data Structure Design
A sensible data structure is central to a custom memory store. Common design patterns include:
- Document store: keep conversation records as JSON, which is easy to extend and query
- Relational store: manage conversations, users, tasks, and their relationships through tables
- Time-series store: store conversation fragments in time order, well suited to trend analysis
- Graph store: represent entity relationships within conversations in a graph database
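For the document-store option, a single conversation turn can be held as one JSON document. A hypothetical record shape (all field names here are illustrative, not part of any LangChain schema):

```python
import json
from datetime import datetime, timezone

# Hypothetical document-style record for one conversation turn
record = {
    "conversation_id": "default",
    "input": "What's the weather?",
    "output": "Sunny and mild.",
    "timestamp": datetime(2023, 1, 1, tzinfo=timezone.utc).isoformat(),
    "tags": ["weather"],  # extensible metadata slot
}

serialized = json.dumps(record)   # store as a JSON document
restored = json.loads(serialized)  # deserialize on read
print(restored["input"])
```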
2.3 Performance Considerations
Performance should be considered at design time:
- Index design: index the fields used by frequent queries
- Caching: add a local cache to cut database round trips
- Batch operations: support bulk reads and writes for throughput
- Async processing: use asynchronous IO for high-concurrency workloads
2.4 Security and Compliance
A custom store must meet security and compliance requirements:
- Encryption: encrypt sensitive information at rest
- Access control: enforce strict permission management
- Auditability: log data operations for compliance review
- Privacy: comply with regulations such as GDPR and CCPA
3. Implementing a Base Storage Class
3.1 Initialization and Configuration
The initializer of the custom memory store configures the storage parameters:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

class CustomMemory(BaseMemory):
    def __init__(self,
                 db_uri: str,
                 table_name: str = "conversations",
                 encryption_key: str = None):
        self.db_uri = db_uri                  # database connection string
        self.table_name = table_name          # table used for storage
        self.encryption_key = encryption_key  # optional encryption key
        # Initialize the database connection
        self.engine = create_engine(self.db_uri)
        self.Session = sessionmaker(bind=self.engine)
        # Create the table if it does not exist
        self._create_table()

Parameterized configuration supports multiple database backends (SQLite, PostgreSQL, MySQL, and so on).
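The snippets in this article query a Conversation ORM model that is never shown. A plausible schema, sketched here with the stdlib sqlite3 module so it is self-contained (in the real code this would be a SQLAlchemy declarative model with the same columns):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Assumed schema behind the Conversation model used throughout this article
conn.execute("""
    CREATE TABLE IF NOT EXISTS conversations (
        id          INTEGER PRIMARY KEY AUTOINCREMENT,
        input_text  TEXT NOT NULL,
        output_text TEXT NOT NULL,
        timestamp   TEXT NOT NULL  -- ISO-8601 UTC timestamp
    )
""")
conn.execute(
    "INSERT INTO conversations (input_text, output_text, timestamp) VALUES (?, ?, ?)",
    ("hello", "hi there", "2023-01-01T12:00:00"),
)
row = conn.execute("SELECT input_text, output_text FROM conversations").fetchone()
print(row)
```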
3.2 Loading Memory Variables
The load_memory_variables method reads the history back from storage:

def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
    session = self.Session()
    try:
        # Fetch the most recent N conversation records
        query = session.query(Conversation).order_by(Conversation.timestamp.desc()).limit(10)
        records = query.all()
        records.reverse()  # return them in chronological order
        # Convert the records into dictionaries
        memory = {
            "history": [{"input": r.input_text, "output": r.output_text} for r in records]
        }
        return memory
    finally:
        session.close()

This uses SQLAlchemy for database access; order_by plus limit restricts the result to the most recent slice of the history (the reverse puts it back in chronological order for the prompt).
3.3 Saving Context
The save_context method writes a new conversation turn to storage:

def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
    session = self.Session()
    try:
        input_text = inputs.get("input")
        output_text = outputs.get("output")
        # Encrypt sensitive fields if a key is configured
        if self.encryption_key:
            input_text = encrypt(input_text, self.encryption_key)
            output_text = encrypt(output_text, self.encryption_key)
        # Create the conversation record
        record = Conversation(
            input_text=input_text,
            output_text=output_text,
            timestamp=datetime.utcnow()
        )
        session.add(record)
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

The method supports encryption and relies on a transaction to keep writes consistent.
3.4 Declaring Variable Names
The memory_variables property declares which memory variables are exposed:

@property
def memory_variables(self) -> List[str]:
    return ["history"]

This property tells LangChain which variables are available for use in prompt templates.
4. Advanced Feature Extensions
4.1 Data Cleanup and Compaction
To keep storage from growing without bound, add a cleanup routine. Note that SQLAlchemy's Query.delete() cannot be combined with limit(), so the surplus ids are collected first and then deleted:

def clean_memory(self, max_records: int = 100) -> None:
    session = self.Session()
    try:
        # Count the stored records
        count = session.query(func.count(Conversation.id)).scalar()
        if count > max_records:
            # Collect the ids of the oldest surplus records
            surplus_ids = [
                row.id
                for row in session.query(Conversation.id)
                .order_by(Conversation.timestamp.asc())
                .limit(count - max_records)
            ]
            # Delete them by id
            session.query(Conversation).filter(
                Conversation.id.in_(surplus_ids)
            ).delete(synchronize_session=False)
            session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

Running this method periodically removes the oldest records and keeps the storage size bounded.
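The same keep-the-newest-N pruning can be expressed directly in SQL. A self-contained sketch against an illustrative conversations table using the stdlib sqlite3 module:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE conversations (id INTEGER PRIMARY KEY, timestamp TEXT)")
conn.executemany(
    "INSERT INTO conversations (timestamp) VALUES (?)",
    [(f"2023-01-01T12:00:{i:02d}",) for i in range(10)],
)

max_records = 3
# Delete everything except the max_records newest rows
conn.execute(
    """
    DELETE FROM conversations
    WHERE id NOT IN (
        SELECT id FROM conversations ORDER BY timestamp DESC LIMIT ?
    )
    """,
    (max_records,),
)
remaining = conn.execute("SELECT COUNT(*) FROM conversations").fetchone()[0]
print(remaining)
```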
4.2 Version Management
Versioning makes changes to memory data traceable:

class ConversationVersion(Base):
    __tablename__ = "conversation_versions"
    id = Column(Integer, primary_key=True)
    conversation_id = Column(Integer, ForeignKey("conversations.id"))
    version_number = Column(Integer)
    input_text = Column(Text)
    output_text = Column(Text)
    timestamp = Column(DateTime)

class CustomMemory(BaseMemory):
    def save_version(self, conversation_id: int) -> None:
        session = self.Session()
        try:
            # Fetch the current conversation record
            conversation = session.query(Conversation).filter(Conversation.id == conversation_id).first()
            # Create a version record
            version = ConversationVersion(
                conversation_id=conversation_id,
                version_number=self._get_next_version_number(conversation_id),
                input_text=conversation.input_text,
                output_text=conversation.output_text,
                timestamp=datetime.utcnow()
            )
            session.add(version)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def _get_next_version_number(self, conversation_id: int) -> int:
        session = self.Session()
        try:
            # Find the highest existing version number
            result = session.query(func.max(ConversationVersion.version_number)).filter(
                ConversationVersion.conversation_id == conversation_id
            ).scalar()
            return (result or 0) + 1
        finally:
            session.close()

Recording every change in a version table makes rollback possible.
4.3 Smart Retrieval
Better retrieval improves query efficiency:

def search_memory(self, keyword: str) -> List[Dict[str, Any]]:
    session = self.Session()
    try:
        # Find records containing the keyword
        query = session.query(Conversation).filter(
            Conversation.input_text.contains(keyword) | Conversation.output_text.contains(keyword)
        )
        records = query.all()
        return [{"input": r.input_text, "output": r.output_text} for r in records]
    finally:
        session.close()

Note that contains() compiles to a SQL LIKE filter, not true full-text search; for large datasets you would want the database's full-text facilities (e.g. SQLite FTS, PostgreSQL tsvector).
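The equivalent LIKE query, shown with stdlib sqlite3 against an illustrative table, to make the generated SQL concrete:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE conversations (input_text TEXT, output_text TEXT)")
conn.executemany(
    "INSERT INTO conversations VALUES (?, ?)",
    [("book a flight", "done"), ("weather today", "sunny")],
)
keyword = "flight"
# The LIKE pattern that contains() produces under the hood
rows = conn.execute(
    "SELECT input_text, output_text FROM conversations "
    "WHERE input_text LIKE ? OR output_text LIKE ?",
    (f"%{keyword}%", f"%{keyword}%"),
).fetchall()
print(rows)
```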
5. Integration with Other Components
5.1 Integrating with a Conversation Chain
Plugging the custom memory store into ConversationChain:

from langchain.chains import ConversationChain
from langchain.llms import OpenAI

memory = CustomMemory(db_uri="sqlite:///memory.db")
llm = OpenAI(temperature=0)
conversation = ConversationChain(
    llm=llm,
    memory=memory,
    verbose=True
)
response = conversation.predict(input="Hello")

Configured this way, the conversation chain automatically uses the history held in the custom store.
5.2 Integrating with Agents
Using the custom memory inside an agent:

from langchain.agents import AgentType, initialize_agent
from langchain.tools import BaseTool

class CustomTool(BaseTool):
    name = "custom_tool"
    description = "A custom tool that searches stored conversation history"

    def _run(self, query: str) -> str:
        # Query the memory store
        memory = CustomMemory(db_uri="sqlite:///memory.db")
        results = memory.search_memory(query)
        return str(results)

    async def _arun(self, query: str) -> str:
        raise NotImplementedError("Async execution is not implemented")

tools = [CustomTool()]
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
agent.run("Search the conversation history")

The agent can draw on the memory store to make better-informed decisions.
5.3 Integrating with Prompt Templates
Referencing memory variables in a prompt template:

from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    input_variables=["history", "input"],
    template="""
Conversation history: {history}
User input: {input}
AI reply:
"""
)
memory = CustomMemory(db_uri="sqlite:///memory.db")
formatted_prompt = prompt.format(history=memory.load_memory_variables({})["history"], input="a new question")

This lets the prompt template pull in the conversation history dynamically.
6. Performance Optimization Strategies
6.1 Caching
A local cache reduces database round trips. Note that functools.lru_cache cannot decorate load_memory_variables directly, because its dict argument is unhashable; a plain dict keyed by the sorted input items works instead:

from typing import Any, Dict

class CachedMemory(CustomMemory):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._cache: Dict[tuple, Dict[str, Any]] = {}

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Build a hashable cache key from the inputs
        key = tuple(sorted(inputs.items()))
        if key not in self._cache:
            self._cache[key] = super().load_memory_variables(inputs)
        return self._cache[key]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        super().save_context(inputs, outputs)
        # Invalidate the cache after every write
        self._cache.clear()

Clearing the cache on every write keeps reads consistent with the database at the cost of a cold cache after each save.
6.2 Batch Operations
Batch writes improve throughput:

def batch_save_context(self, contexts: List[Dict[str, Any]]) -> None:
    session = self.Session()
    try:
        records = []
        for ctx in contexts:
            input_text = ctx.get("input")
            output_text = ctx.get("output")
            record = Conversation(
                input_text=input_text,
                output_text=output_text,
                timestamp=datetime.utcnow()
            )
            records.append(record)
        session.add_all(records)
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

Committing many records in a single transaction cuts the number of database round trips.
6.3 Asynchronous IO
Asynchronous access handles high-concurrency workloads. Note that AsyncSession does not support the legacy Query API, so statements are built with select(), and `async with` takes care of closing the session:

import asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

class AsyncCustomMemory(BaseMemory):
    def __init__(self, db_uri: str, table_name: str = "conversations", encryption_key: str = None):
        self.db_uri = db_uri
        self.table_name = table_name
        self.encryption_key = encryption_key
        self.engine = create_async_engine(self.db_uri)
        self.Session = async_sessionmaker(bind=self.engine)
        self._create_table()

    async def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        async with self.Session() as session:
            # AsyncSession has no query(); build a select() statement instead
            stmt = select(Conversation).order_by(Conversation.timestamp.desc()).limit(10)
            result = await session.execute(stmt)
            records = result.scalars().all()
            return {
                "history": [{"input": r.input_text, "output": r.output_text} for r in records]
            }

    async def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        async with self.Session() as session:
            input_text = inputs.get("input")
            output_text = outputs.get("output")
            if self.encryption_key:
                input_text = encrypt(input_text, self.encryption_key)
                output_text = encrypt(output_text, self.encryption_key)
            record = Conversation(
                input_text=input_text,
                output_text=output_text,
                timestamp=datetime.utcnow()
            )
            session.add(record)
            await session.commit()

Asynchronous database access is implemented with asyncio and sqlalchemy.ext.asyncio (SQLAlchemy 1.4+ with an async driver such as aiosqlite or asyncpg).
7. Security and Compliance
7.1 Data Encryption
Encrypt sensitive data at rest (the key must be a valid Fernet key, i.e. 32 url-safe base64-encoded bytes as produced by Fernet.generate_key()):

from cryptography.fernet import Fernet

def encrypt(text: str, key: str) -> str:
    cipher_suite = Fernet(key)
    encrypted_text = cipher_suite.encrypt(text.encode())
    return encrypted_text.decode()

def decrypt(text: str, key: str) -> str:
    cipher_suite = Fernet(key)
    decrypted_text = cipher_suite.decrypt(text.encode())
    return decrypted_text.decode()

class EncryptedMemory(CustomMemory):
    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        data = super().load_memory_variables(inputs)
        if self.encryption_key:
            for item in data["history"]:
                item["input"] = decrypt(item["input"], self.encryption_key)
                item["output"] = decrypt(item["output"], self.encryption_key)
        return data

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        if self.encryption_key:
            inputs["input"] = encrypt(inputs["input"], self.encryption_key)
            outputs["output"] = encrypt(outputs["output"], self.encryption_key)
        super().save_context(inputs, outputs)

Symmetric encryption is provided by the cryptography library.
7.2 Access Control
Role-based access control:

class RoleBasedAccessMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "conversations", encryption_key: str = None, role: str = "user"):
        super().__init__(db_uri, table_name, encryption_key)
        self.role = role

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if self.role == "admin":
            return super().load_memory_variables(inputs)
        else:
            # Non-admin roles see only a subset of the data
            data = super().load_memory_variables(inputs)
            data["history"] = data["history"][-5:]  # last 5 turns only
            return data

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        if self.role in ["admin", "editor"]:
            super().save_context(inputs, outputs)
        else:
            raise PermissionError("No write permission")

Data access and operations are restricted according to the caller's role.
7.3 Audit Logging
Log every data operation (this assumes an AuditLog ORM model mapped to the audit_logs table):

class AuditableMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "conversations", encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.audit_table_name = "audit_logs"
        self._create_audit_table()

    def _create_audit_table(self):
        # Create the audit log table
        metadata = MetaData()
        Table(
            self.audit_table_name, metadata,
            Column('id', Integer, primary_key=True),
            Column('user_id', String(50)),
            Column('operation', String(20)),
            Column('timestamp', DateTime),
            Column('details', Text)
        )
        metadata.create_all(self.engine)

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        user_id = inputs.get("user_id", "anonymous")
        self._log_audit(user_id, "load", str(inputs))
        return super().load_memory_variables(inputs)

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        user_id = inputs.get("user_id", "anonymous")
        self._log_audit(user_id, "save", f"inputs: {inputs}, outputs: {outputs}")
        super().save_context(inputs, outputs)

    def _log_audit(self, user_id: str, operation: str, details: str) -> None:
        session = self.Session()
        try:
            audit_log = AuditLog(
                user_id=user_id,
                operation=operation,
                timestamp=datetime.utcnow(),
                details=details
            )
            session.add(audit_log)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to write audit log: {e}")
        finally:
            session.close()

The audit table records the user, operation type, timestamp, and details of every operation, supporting later tracing and compliance checks.
7.4 Privacy Protection
Anonymize data and purge it after a retention period:

import hashlib
from datetime import datetime, timedelta

class PrivacyAwareMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "conversations", encryption_key: str = None, retention_days: int = 30):
        super().__init__(db_uri, table_name, encryption_key)
        self.retention_days = retention_days

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        # Anonymize the user id before storing
        if "user_id" in inputs:
            inputs["user_id"] = self._anonymize_id(inputs["user_id"])
        super().save_context(inputs, outputs)

    def _anonymize_id(self, user_id: str) -> str:
        # Simple hash-based id anonymization
        hash_object = hashlib.sha256(user_id.encode())
        return hash_object.hexdigest()[:10]

    def purge_old_data(self) -> None:
        """Delete data older than the retention period."""
        cutoff_date = datetime.utcnow() - timedelta(days=self.retention_days)
        session = self.Session()
        try:
            # Delete expired records
            session.query(Conversation).filter(
                Conversation.timestamp < cutoff_date
            ).delete(synchronize_session=False)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

Hashing anonymizes user ids, and periodic purging of expired data protects user privacy.
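The hash step in isolation looks like this, with one caveat: an unsalted, truncated SHA-256 of a low-entropy id can be reversed by brute force, so production code should mix in a secret salt (the salt parameter here is an addition, not part of the original sketch):

```python
import hashlib

def anonymize_id(user_id: str, salt: str = "") -> str:
    # Salted SHA-256, truncated to 10 hex chars as in the article's sketch
    digest = hashlib.sha256((salt + user_id).encode()).hexdigest()
    return digest[:10]

alias = anonymize_id("user-42", salt="s3cret")
print(alias)
```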
8. Distributed and Cloud Storage Implementations
8.1 A Distributed Cache on Redis

import json
import redis

class RedisMemory(CustomMemory):
    def __init__(self,
                 db_uri: str,
                 redis_host: str = "localhost",
                 redis_port: int = 6379,
                 redis_db: int = 0,
                 table_name: str = "conversations",
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.redis_client = redis.Redis(host=redis_host, port=redis_port, db=redis_db)
        self.cache_prefix = "langchain:memory:"

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        conversation_id = inputs.get("conversation_id", "default")
        cache_key = f"{self.cache_prefix}{conversation_id}"
        # Try the cache first
        cached_data = self.redis_client.get(cache_key)
        if cached_data:
            return json.loads(cached_data)
        # Fall back to the database and populate the cache
        data = super().load_memory_variables(inputs)
        self.redis_client.setex(cache_key, 3600, json.dumps(data))  # cache for 1 hour
        return data

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        super().save_context(inputs, outputs)
        # Invalidate the cache entry for this conversation
        conversation_id = inputs.get("conversation_id", "default")
        cache_key = f"{self.cache_prefix}{conversation_id}"
        self.redis_client.delete(cache_key)

Redis serves as a distributed cache layer, relieving database pressure and improving response time.
8.2 Storage on AWS S3

import json
import boto3
from botocore.exceptions import NoCredentialsError

class S3Memory(CustomMemory):
    def __init__(self,
                 bucket_name: str,
                 aws_access_key_id: str,
                 aws_secret_access_key: str,
                 region_name: str = "us-east-1",
                 table_name: str = "conversations",
                 encryption_key: str = None):
        self.bucket_name = bucket_name
        self.s3_client = boto3.client(
            's3',
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name
        )
        super().__init__(db_uri="sqlite:///:memory:", table_name=table_name, encryption_key=encryption_key)

    def _get_s3_key(self, conversation_id: str) -> str:
        return f"conversations/{conversation_id}.json"

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        conversation_id = inputs.get("conversation_id", "default")
        s3_key = self._get_s3_key(conversation_id)
        try:
            response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
            content = response['Body'].read().decode('utf-8')
            return json.loads(content)
        except NoCredentialsError:
            logger.error("Invalid AWS credentials")
            return {"history": []}
        except self.s3_client.exceptions.NoSuchKey:
            return {"history": []}
        except Exception as e:
            logger.error(f"Failed to load from S3: {e}")
            return {"history": []}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        conversation_id = inputs.get("conversation_id", "default")
        s3_key = self._get_s3_key(conversation_id)
        # Load the existing history first
        current_data = self.load_memory_variables(inputs)
        history = current_data.get("history", [])
        # Append the new turn
        new_record = {
            "input": inputs.get("input"),
            "output": outputs.get("output"),
            "timestamp": datetime.utcnow().isoformat()
        }
        history.append(new_record)
        # Write back to S3
        data_to_save = {"history": history}
        try:
            self.s3_client.put_object(
                Bucket=self.bucket_name,
                Key=s3_key,
                Body=json.dumps(data_to_save).encode('utf-8')
            )
        except Exception as e:
            logger.error(f"Failed to save to S3: {e}")

AWS S3's scalability makes it a good fit for storing conversation data in large distributed applications; note the read-modify-write pattern here is not atomic, so concurrent writers to one conversation can lose updates.
8.3 Document Storage on MongoDB

from pymongo import MongoClient
from pymongo.errors import ConnectionFailure

class MongoMemory(CustomMemory):
    def __init__(self,
                 connection_string: str,
                 database_name: str = "langchain",
                 collection_name: str = "conversations",
                 table_name: str = "conversations",
                 encryption_key: str = None):
        self.client = MongoClient(connection_string)
        try:
            # Ping to verify the connection
            self.client.admin.command('ping')
        except ConnectionFailure:
            logger.error("Cannot connect to the MongoDB server")
        self.db = self.client[database_name]
        self.collection = self.db[collection_name]
        super().__init__(db_uri="sqlite:///:memory:", table_name=table_name, encryption_key=encryption_key)

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        conversation_id = inputs.get("conversation_id", "default")
        # Fetch the latest 10 records
        cursor = self.collection.find(
            {"conversation_id": conversation_id},
            {"_id": 0, "input": 1, "output": 1, "timestamp": 1}
        ).sort("timestamp", -1).limit(10)
        records = list(cursor)
        records.reverse()  # chronological order
        return {"history": records}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        conversation_id = inputs.get("conversation_id", "default")
        record = {
            "conversation_id": conversation_id,
            "input": inputs.get("input"),
            "output": outputs.get("output"),
            "timestamp": datetime.utcnow()
        }
        self.collection.insert_one(record)

MongoDB's document model stores unstructured conversation data flexibly; an index on (conversation_id, timestamp) keeps the query above fast.
9. Advanced Retrieval and Indexing
9.1 A Vector Index

from langchain.embeddings import OpenAIEmbeddings  # newer releases: langchain_openai
from langchain.vectorstores import FAISS           # newer releases: langchain_community

class VectorizedMemory(CustomMemory):
    def __init__(self,
                 db_uri: str,
                 embedding_model: str = "text-embedding-ada-002",
                 table_name: str = "conversations",
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.embeddings = OpenAIEmbeddings(model=embedding_model)
        self.vectorstore = None
        self._initialize_vectorstore()

    def _initialize_vectorstore(self):
        # Load all stored text and build the vector index
        session = self.Session()
        try:
            records = session.query(Conversation).all()
            texts = []
            metadatas = []
            for record in records:
                texts.append(f"User: {record.input_text}\nAI: {record.output_text}")
                metadatas.append({
                    "input": record.input_text,
                    "output": record.output_text,
                    "timestamp": record.timestamp
                })
            if texts:
                self.vectorstore = FAISS.from_texts(texts, self.embeddings, metadatas=metadatas)
        finally:
            session.close()

    def similarity_search(self, query: str, k: int = 4) -> List[Dict[str, Any]]:
        if not self.vectorstore:
            return []
        # Run the similarity search
        docs = self.vectorstore.similarity_search(query, k=k)
        return [doc.metadata for doc in docs]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        super().save_context(inputs, outputs)
        # Keep the vector index in sync
        text = f"User: {inputs.get('input')}\nAI: {outputs.get('output')}"
        metadata = {
            "input": inputs.get("input"),
            "output": outputs.get("output"),
            "timestamp": datetime.utcnow()
        }
        if self.vectorstore:
            self.vectorstore.add_texts([text], metadatas=[metadata])
        else:
            self.vectorstore = FAISS.from_texts([text], self.embeddings, metadatas=[metadata])

A vector index enables semantic similarity search, improving the precision with which relevant history is retrieved.
9.2 A Hybrid Retrieval Strategy

from typing import List, Tuple
from langchain.schema import Document

class HybridMemory(CustomMemory):
    def __init__(self,
                 db_uri: str,
                 embedding_model: str = "text-embedding-ada-002",
                 table_name: str = "conversations",
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.embeddings = OpenAIEmbeddings(model=embedding_model)
        self.vectorstore = None
        self._initialize_vectorstore()

    def _initialize_vectorstore(self):
        # Initialize the vector store
        session = self.Session()
        try:
            records = session.query(Conversation).all()
            texts = []
            metadatas = []
            for record in records:
                texts.append(f"User: {record.input_text}\nAI: {record.output_text}")
                metadatas.append({
                    "id": record.id,
                    "timestamp": record.timestamp
                })
            if texts:
                self.vectorstore = FAISS.from_texts(texts, self.embeddings, metadatas=metadatas)
        finally:
            session.close()

    def hybrid_search(self, query: str, k: int = 4, keyword_weight: float = 0.5) -> List[Dict[str, Any]]:
        # Hybrid retrieval (keyword + vector)
        if not self.vectorstore:
            return []
        # 1. Vector retrieval
        vector_results = self.vectorstore.similarity_search_with_score(query, k=k)
        # 2. Keyword retrieval
        keyword_results = self._keyword_search(query, k=k)
        # 3. Result fusion
        merged_results = self._merge_results(vector_results, keyword_results, keyword_weight)
        # 4. Fetch the full records
        session = self.Session()
        try:
            full_records = []
            for record_id in merged_results:
                record = session.query(Conversation).filter(Conversation.id == record_id).first()
                if record:
                    full_records.append({
                        "input": record.input_text,
                        "output": record.output_text,
                        "timestamp": record.timestamp
                    })
            return full_records
        finally:
            session.close()

    def _keyword_search(self, query: str, k: int = 4) -> List[int]:
        session = self.Session()
        try:
            # Keyword search via SQL ILIKE (case-insensitive LIKE)
            stmt = session.query(Conversation).filter(
                Conversation.input_text.ilike(f"%{query}%") |
                Conversation.output_text.ilike(f"%{query}%")
            ).order_by(Conversation.timestamp.desc()).limit(k)
            return [record.id for record in stmt.all()]
        finally:
            session.close()

    def _merge_results(self, vector_results: List[Tuple[Document, float]], keyword_results: List[int], keyword_weight: float) -> List[int]:
        # Fuse the two result sets
        vector_scores = {doc.metadata['id']: score for doc, score in vector_results}
        keyword_scores = {id: (len(keyword_results) - i) for i, id in enumerate(keyword_results)}
        # Compute a combined score per record
        combined_scores = {}
        for id in set(vector_scores.keys()).union(set(keyword_scores.keys())):
            vec_score = vector_scores.get(id, 0)
            key_score = keyword_scores.get(id, 0)
            # Normalize: FAISS returns a distance, so map it into (0, 1]
            if vec_score > 0:
                vec_score = 1 / (1 + vec_score)
            combined_score = keyword_weight * key_score + (1 - keyword_weight) * vec_score
            combined_scores[id] = combined_score
        # Sort by combined score, highest first
        return sorted(combined_scores.keys(), key=lambda k: combined_scores[k], reverse=True)

Combining keyword and vector retrieval gives more comprehensive access to the stored history than either alone.
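The fusion step can be exercised on its own. A toy run of the weighted combination (the scores and ids below are made up; vector scores are treated as distances, keyword hits are scored by rank):

```python
def merge_results(vector_scores, keyword_ids, keyword_weight):
    # Rank-based keyword score: earlier hits score higher
    keyword_scores = {id_: len(keyword_ids) - i for i, id_ in enumerate(keyword_ids)}
    combined = {}
    for id_ in set(vector_scores) | set(keyword_scores):
        vec = vector_scores.get(id_, 0.0)
        if vec > 0:
            vec = 1 / (1 + vec)  # map a distance into (0, 1]
        key = keyword_scores.get(id_, 0)
        combined[id_] = keyword_weight * key + (1 - keyword_weight) * vec
    return sorted(combined, key=combined.get, reverse=True)

# id 2 is both a keyword hit and a vector hit, so it ranks first
ranked = merge_results({1: 0.2, 2: 0.9}, [2, 3], keyword_weight=0.5)
print(ranked)
```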
9.3 A Time-Aware Index

class TimeAwareMemory(CustomMemory):
    def __init__(self,
                 db_uri: str,
                 time_decay_factor: float = 0.9,
                 table_name: str = "conversations",
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.time_decay_factor = time_decay_factor

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        current_time = datetime.utcnow()
        session = self.Session()
        try:
            # Fetch all records
            records = session.query(Conversation).all()
            # Weight each record by time decay
            weighted_records = []
            for record in records:
                time_diff = (current_time - record.timestamp).total_seconds()
                # Exponential decay: one decay step per hour of age
                weight = self.time_decay_factor ** (time_diff / 3600)
                if weight > 0.01:  # drop records whose weight is negligible
                    weighted_records.append((record, weight))
            # Sort by weight, highest first
            weighted_records.sort(key=lambda x: x[1], reverse=True)
            # Build the result
            memory = {
                "history": [{"input": r.input_text, "output": r.output_text} for r, _ in weighted_records]
            }
            return memory
        finally:
            session.close()

The time-decay function gives recent turns higher weight, improving contextual relevance.
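The decay above is a simple exponential in elapsed hours. A standalone check of the weighting (factor 0.9 per hour, cutoff 0.01, matching the numbers in the class):

```python
def decay_weight(age_seconds: float, factor: float = 0.9) -> float:
    # One decay step per hour of age
    return factor ** (age_seconds / 3600)

w_fresh = decay_weight(0)         # a brand-new record
w_day = decay_weight(24 * 3600)   # a one-day-old record
print(w_fresh, w_day)
```

With factor 0.9, a day-old record still clears the 0.01 cutoff (0.9^24 ≈ 0.08); roughly two days of age would drop it below the threshold.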
10. Multimodal Memory Storage
10.1 Image Memory

import os
import uuid

class ImageMemory(CustomMemory):
    def __init__(self,
                 db_uri: str,
                 image_storage_path: str = "./images",
                 table_name: str = "conversations",
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.image_storage_path = image_storage_path
        os.makedirs(image_storage_path, exist_ok=True)

    def save_image(self, image_data: bytes, image_name: str = None) -> str:
        """Save an image and return its stored name."""
        if not image_name:
            image_name = f"image_{uuid.uuid4().hex}.jpg"
        image_path = os.path.join(self.image_storage_path, image_name)
        try:
            # Write the image to disk
            with open(image_path, 'wb') as f:
                f.write(image_data)
            # Record the image metadata
            image_metadata = {
                "image_name": image_name,
                "path": image_path,
                "timestamp": datetime.utcnow(),
                "size": len(image_data)
            }
            # Persist the metadata to the database
            self._save_image_metadata(image_metadata)
            return image_name
        except Exception as e:
            logger.error(f"Failed to save image: {e}")
            return None

    def _save_image_metadata(self, metadata: Dict[str, Any]) -> None:
        """Persist image metadata (assumes an ImageRecord ORM model)."""
        session = self.Session()
        try:
            image_record = ImageRecord(
                image_name=metadata["image_name"],
                path=metadata["path"],
                timestamp=metadata["timestamp"],
                size=metadata["size"]
            )
            session.add(image_record)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to save image metadata: {e}")
        finally:
            session.close()

    def get_image(self, image_name: str) -> bytes:
        """Read an image back from disk."""
        image_path = os.path.join(self.image_storage_path, image_name)
        try:
            with open(image_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"Failed to read image: {e}")
            return None

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Extend save_context to handle images."""
        # Handle images in the inputs
        if "images" in inputs:
            saved_images = []
            for image_data in inputs["images"]:
                image_name = self.save_image(image_data)
                if image_name:
                    saved_images.append(image_name)
            inputs["saved_images"] = saved_images
        # Save the textual context via the base class
        super().save_context(inputs, outputs)

This extends the memory store to save and retrieve images, persisting image metadata and associating it with conversation records.
10.2 Audio Memory

import json
import os
import uuid
import librosa

class AudioMemory(CustomMemory):
    def __init__(self,
                 db_uri: str,
                 audio_storage_path: str = "./audio",
                 table_name: str = "conversations",
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.audio_storage_path = audio_storage_path
        os.makedirs(audio_storage_path, exist_ok=True)

    def save_audio(self, audio_data: bytes, audio_name: str = None, format: str = "wav") -> str:
        """Save an audio clip and return its stored name."""
        if not audio_name:
            audio_name = f"audio_{uuid.uuid4().hex}.{format}"
        audio_path = os.path.join(self.audio_storage_path, audio_name)
        try:
            # Write the audio to disk
            with open(audio_path, 'wb') as f:
                f.write(audio_data)
            # Extract audio features
            features = self._extract_audio_features(audio_path)
            # Record the audio metadata
            audio_metadata = {
                "audio_name": audio_name,
                "path": audio_path,
                "timestamp": datetime.utcnow(),
                "duration": features["duration"],
                "sample_rate": features["sample_rate"],
                "features": features["features"]
            }
            # Persist the metadata to the database
            self._save_audio_metadata(audio_metadata)
            return audio_name
        except Exception as e:
            logger.error(f"Failed to save audio: {e}")
            return None

    def _extract_audio_features(self, audio_path: str) -> Dict[str, Any]:
        """Extract audio features with librosa."""
        try:
            # Load the audio file
            y, sr = librosa.load(audio_path)
            # Extract features
            mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)
            chroma = librosa.feature.chroma_stft(y=y, sr=sr)
            mel = librosa.feature.melspectrogram(y=y, sr=sr)
            # Compute the clip duration
            duration = librosa.get_duration(y=y, sr=sr)
            return {
                "duration": duration,
                "sample_rate": sr,
                "features": {
                    "mfccs": mfccs.tolist(),
                    "chroma": chroma.tolist(),
                    "mel": mel.tolist()
                }
            }
        except Exception as e:
            logger.error(f"Failed to extract audio features: {e}")
            return {
                "duration": 0,
                "sample_rate": 0,
                "features": {}
            }

    def _save_audio_metadata(self, metadata: Dict[str, Any]) -> None:
        """Persist audio metadata (assumes an AudioRecord ORM model)."""
        session = self.Session()
        try:
            audio_record = AudioRecord(
                audio_name=metadata["audio_name"],
                path=metadata["path"],
                timestamp=metadata["timestamp"],
                duration=metadata["duration"],
                sample_rate=metadata["sample_rate"],
                features=json.dumps(metadata["features"])
            )
            session.add(audio_record)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Failed to save audio metadata: {e}")
        finally:
            session.close()

    def get_audio(self, audio_name: str) -> bytes:
        """Read an audio clip back from disk."""
        audio_path = os.path.join(self.audio_storage_path, audio_name)
        try:
            with open(audio_path, 'rb') as f:
                return f.read()
        except Exception as e:
            logger.error(f"Failed to read audio: {e}")
            return None

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Extend save_context to handle audio."""
        # Handle audio in the inputs
        if "audio" in inputs:
            audio_data = inputs["audio"]
            audio_name = self.save_audio(audio_data)
            if audio_name:
                inputs["saved_audio"] = audio_name
        # Save the textual context via the base class
        super().save_context(inputs, outputs)

This supports storing audio and extracting its features, linking audio to conversation records for multimodal memory. Note that serializing full MFCC/chroma/mel matrices as JSON can be large; summarized statistics are usually enough.
10.3 Multimodal Retrieval

class MultiModalMemory(CustomMemory):
    def __init__(self,
                 db_uri: str,
                 image_storage_path: str = "./images",
                 audio_storage_path: str = "./audio",
                 table_name: str = "conversations",
                 encryption_key: str = None):
        super().__init__(db_uri, table_name, encryption_key)
        self.image_memory = ImageMemory(db_uri, image_storage_path, table_name, encryption_key)
        self.audio_memory = AudioMemory(db_uri, audio_storage_path, table_name, encryption_key)

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """Save multimodal context."""
        # Handle images
        if "images" in inputs:
            saved_images = []
            for image_data in inputs["images"]:
                image_name = self.image_memory.save_image(image_data)
                if image_name:
                    saved_images.append(image_name)
            inputs["saved_images"] = saved_images
        # Handle audio
        if "audio" in inputs:
            audio_data = inputs["audio"]
            audio_name = self.audio_memory.save_audio(audio_data)
            if audio_name:
                inputs["saved_audio"] = audio_name
        # Save the textual context via the base class
        super().save_context(inputs, outputs)

    def multimodal_search(self, query: Dict[str, Any], k: int = 4) -> List[Dict[str, Any]]:
        """Multimodal retrieval."""
        # 1. Text retrieval
        text_results = []
        if "text" in query:
            text_results = self._text_search(query["text"], k)
        # 2. Image retrieval
        image_results = []
        if "image" in query:
            image_results = self._image_search(query["image"], k)
        # 3. Audio retrieval
        audio_results = []
        if "audio" in query:
            audio_results = self._audio_search(query["audio"], k)
        # 4. Result fusion
        merged_results = self._merge_multimodal_results(text_results, image_results, audio_results)
        return merged_results

    def _text_search(self, text_query: str, k: int = 4) -> List[Dict[str, Any]]:
        """Text retrieval."""
        # Delegate to the base class
        memory_vars = super().load_memory_variables({"input": text_query})
        return memory_vars.get("history", [])[:k]

    def _image_search(self, image_query: bytes, k: int = 4) -> List[Dict[str, Any]]:
        """Image retrieval."""
        # Simplified; a real implementation would compare image similarity with a vision model
        return []

    def _audio_search(self, audio_query: bytes, k: int = 4) -> List[Dict[str, Any]]:
        """Audio retrieval."""
        # Simplified; a real implementation would compare audio similarity with an audio model
        return []

    def _merge_multimodal_results(self, text_results: List[Dict[str, Any]],
                                  image_results: List[Dict[str, Any]],
                                  audio_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Fuse multimodal results."""
        # Simplified; a real implementation would rank by relevance score
        merged = []
        merged.extend(text_results)
        merged.extend(image_results)
        merged.extend(audio_results)
        return merged

This combines text, image, and audio retrieval into a multimodal search that correlates context across media types.
十一、测试与验证
11.1 单元测试框架
import unittest
from unittest.mock import patch, MagicMock
from datetime import datetime

class TestCustomMemory(unittest.TestCase):
    def setUp(self):
        # 创建测试数据库连接
        self.db_uri = "sqlite:///:memory:"
        self.memory = CustomMemory(db_uri=self.db_uri)

    def test_save_context(self):
        # 测试保存上下文
        inputs = {"input": "你好"}
        outputs = {"output": "你好!有什么可以帮助你的吗?"}
        self.memory.save_context(inputs, outputs)
        # 验证是否保存成功
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        self.assertEqual(len(history), 1)
        self.assertEqual(history[0]["input"], "你好")
        self.assertEqual(history[0]["output"], "你好!有什么可以帮助你的吗?")

    def test_load_memory_variables(self):
        # 测试加载记忆变量
        inputs = {"input": "今天天气如何?"}
        outputs = {"output": "今天天气晴朗,温度适宜。"}
        self.memory.save_context(inputs, outputs)
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        self.assertEqual(len(history), 1)
        self.assertEqual(history[0]["input"], "今天天气如何?")
        self.assertEqual(history[0]["output"], "今天天气晴朗,温度适宜。")

    def test_multiple_saves(self):
        # 测试多次保存
        inputs1 = {"input": "你好"}
        outputs1 = {"output": "你好!"}
        inputs2 = {"input": "再见"}
        outputs2 = {"output": "再见!"}
        self.memory.save_context(inputs1, outputs1)
        self.memory.save_context(inputs2, outputs2)
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        self.assertEqual(len(history), 2)
        self.assertEqual(history[0]["input"], "你好")
        self.assertEqual(history[0]["output"], "你好!")
        self.assertEqual(history[1]["input"], "再见")
        self.assertEqual(history[1]["output"], "再见!")

    @patch('your_module.datetime')  # your_module 应替换为 CustomMemory 实际所在模块
    def test_time_ordering(self, mock_datetime):
        # 测试时间顺序:设置固定时间
        mock_datetime.utcnow.side_effect = [
            datetime(2023, 1, 1, 12, 0, 0),
            datetime(2023, 1, 1, 12, 1, 0),
            datetime(2023, 1, 1, 12, 2, 0)
        ]
        inputs1 = {"input": "第一个消息"}
        outputs1 = {"output": "第一个回复"}
        inputs2 = {"input": "第二个消息"}
        outputs2 = {"output": "第二个回复"}
        inputs3 = {"input": "第三个消息"}
        outputs3 = {"output": "第三个回复"}
        self.memory.save_context(inputs1, outputs1)
        self.memory.save_context(inputs2, outputs2)
        self.memory.save_context(inputs3, outputs3)
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        self.assertEqual(len(history), 3)
        self.assertEqual(history[0]["input"], "第一个消息")
        self.assertEqual(history[1]["input"], "第二个消息")
        self.assertEqual(history[2]["input"], "第三个消息")
使用Python的unittest框架对自定义记忆存储的基本功能进行单元测试,确保核心功能正常工作。
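上述测试依赖前文实现的 CustomMemory 与真实数据库;若想在无数据库环境下先行验证测试逻辑,可以用一个只驻留进程内存的最小替身(接口名沿用前文约定,仅供示意):

```python
from typing import Any, Dict, List

class InMemoryMemory:
    """仅驻留内存的最小记忆替身,接口与前文 CustomMemory 保持一致(示意)。"""

    def __init__(self) -> None:
        self._history: List[Dict[str, Any]] = []

    @property
    def memory_variables(self) -> List[str]:
        # 对外暴露的记忆变量名称
        return ["history"]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        # 按对话轮次追加一条记录
        self._history.append({"input": inputs.get("input"),
                              "output": outputs.get("output")})

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # 返回历史副本,避免调用方误改内部状态
        return {"history": list(self._history)}
```

在 setUp 中用它替换 CustomMemory,即可脱离数据库运行上面的断言逻辑。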
11.2 性能测试
import timeit
import random
import string
from concurrent.futures import ThreadPoolExecutor

class TestMemoryPerformance(unittest.TestCase):
    def setUp(self):
        self.db_uri = "sqlite:///:memory:"
        self.memory = CustomMemory(db_uri=self.db_uri)

    def test_save_performance(self):
        # 测试保存性能
        def save_random_context():
            input_text = ''.join(random.choices(string.ascii_letters, k=100))
            output_text = ''.join(random.choices(string.ascii_letters, k=200))
            self.memory.save_context({"input": input_text}, {"output": output_text})
        execution_time = timeit.timeit(save_random_context, number=100)
        print(f"保存100条记录耗时: {execution_time}秒")
        self.assertLess(execution_time, 10.0, "保存性能低于预期")

    def test_load_performance(self):
        # 测试加载性能:先插入一些数据
        for i in range(1000):
            input_text = f"测试输入{i}"
            output_text = f"测试输出{i}"
            self.memory.save_context({"input": input_text}, {"output": output_text})
        def load_context():
            self.memory.load_memory_variables({})
        execution_time = timeit.timeit(load_context, number=10)
        print(f"加载10次耗时: {execution_time}秒")
        self.assertLess(execution_time, 2.0, "加载性能低于预期")

    def test_concurrent_access(self):
        # 测试并发访问
        def concurrent_task():
            input_text = ''.join(random.choices(string.ascii_letters, k=50))
            output_text = ''.join(random.choices(string.ascii_letters, k=100))
            # 交替进行读写操作
            for _ in range(10):
                self.memory.save_context({"input": input_text}, {"output": output_text})
                self.memory.load_memory_variables({})
        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(concurrent_task) for _ in range(5)]
            for future in futures:
                future.result()
        # 验证数据完整性
        memory_vars = self.memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        self.assertGreater(len(history), 0, "并发访问后数据丢失")
通过性能测试评估记忆存储在大量数据和高并发情况下的表现,识别潜在瓶颈。
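批量写入通常是最直接的优化方向。下面用标准库 sqlite3 对逐条插入与 executemany 批量插入做一个简易计时对比(仅为与具体实现无关的示意,实际差距取决于数据库类型与数据量):

```python
import sqlite3
import time

def time_inserts(n: int = 1000) -> tuple:
    """对比逐条插入与批量插入 n 条记录的耗时,返回 (逐条耗时, 批量耗时)。"""
    rows = [(f"input{i}", f"output{i}") for i in range(n)]

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE conv (input TEXT, output TEXT)")

    t0 = time.perf_counter()
    for row in rows:  # 逐条插入,每条记录一次语句调用
        conn.execute("INSERT INTO conv VALUES (?, ?)", row)
    conn.commit()
    single = time.perf_counter() - t0

    conn.execute("DELETE FROM conv")
    t0 = time.perf_counter()
    conn.executemany("INSERT INTO conv VALUES (?, ?)", rows)  # 批量插入
    conn.commit()
    batch = time.perf_counter() - t0

    conn.close()
    return single, batch
```

在 save_context 支持批量写入的实现中,可以用类似方式验证批量路径确实带来收益。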
11.3 集成测试
from langchain.chains import ConversationChain
from langchain.llms import OpenAI

class TestMemoryIntegration(unittest.TestCase):
    @patch('langchain.llms.OpenAI')
    def test_with_conversation_chain(self, mock_openai):
        # 测试与对话链的集成:设置模拟LLM
        mock_llm = MagicMock(spec=OpenAI)
        mock_llm.predict.return_value = "模拟回复"
        mock_openai.return_value = mock_llm
        # 创建自定义记忆
        memory = CustomMemory(db_uri="sqlite:///:memory:")
        # 创建对话链
        chain = ConversationChain(
            llm=mock_llm,
            memory=memory,
            verbose=False
        )
        # 进行对话
        response1 = chain.predict(input="你好")
        response2 = chain.predict(input="今天天气如何?")
        # 验证记忆
        memory_vars = memory.load_memory_variables({})
        history = memory_vars.get("history", [])
        self.assertEqual(len(history), 2)
        self.assertEqual(history[0]["input"], "你好")
        self.assertEqual(history[0]["output"], "模拟回复")
        self.assertEqual(history[1]["input"], "今天天气如何?")
        self.assertEqual(history[1]["output"], "模拟回复")

    def test_with_vector_search(self):
        # 测试与向量检索的集成
        memory = VectorizedMemory(db_uri="sqlite:///:memory:")
        # 保存一些上下文
        memory.save_context({"input": "北京天气如何?"}, {"output": "北京今天晴天,25度。"})
        memory.save_context({"input": "上海天气如何?"}, {"output": "上海今天多云,28度。"})
        memory.save_context({"input": "广州天气如何?"}, {"output": "广州今天小雨,27度。"})
        # 执行向量搜索
        results = memory.similarity_search("北京的天气", k=1)
        self.assertEqual(len(results), 1)
        self.assertIn("北京", results[0]["input"])
        self.assertIn("25度", results[0]["output"])
验证自定义记忆存储与LangChain其他组件的集成效果,确保系统整体功能正常。
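集成测试中引用的 VectorizedMemory 及其 similarity_search 未在本节给出定义;其检索核心可以用字符级词袋加余弦相似度做一个不依赖外部嵌入模型的示意(纯演示,并非该类的实际实现):

```python
import math
from collections import Counter
from typing import Any, Dict, List

def _cosine(a: Counter, b: Counter) -> float:
    """两个词袋向量的余弦相似度。"""
    common = set(a) & set(b)
    dot = sum(a[t] * b[t] for t in common)
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb) if na and nb else 0.0

def similarity_search(query: str, records: List[Dict[str, Any]],
                      k: int = 1) -> List[Dict[str, Any]]:
    """按字符级词袋余弦相似度返回与 query 最相近的 k 条记录(示意实现)。"""
    qv = Counter(query)  # 中文场景下按字符切分即可演示
    scored = sorted(records,
                    key=lambda r: _cosine(qv, Counter(r["input"])),
                    reverse=True)
    return scored[:k]
```

真实实现应替换为文本嵌入模型(如向量数据库配合 embedding 接口),但检索与排序的骨架与此一致。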
十二、监控与调优
12.1 性能监控
import time
from prometheus_client import Counter, Histogram, start_http_server

class MonitoredMemory(CustomMemory):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 初始化监控指标
        self.save_requests = Counter(
            'memory_save_requests_total',
            'Total number of save requests',
            ['status']
        )
        self.load_requests = Counter(
            'memory_load_requests_total',
            'Total number of load requests',
            ['status']
        )
        self.save_duration = Histogram(
            'memory_save_duration_seconds',
            'Duration of save operations',
            buckets=[0.01, 0.05, 0.1, 0.5, 1.0, 5.0]
        )
        self.load_duration = Histogram(
            'memory_load_duration_seconds',
            'Duration of load operations',
            buckets=[0.01, 0.05, 0.1, 0.5, 1.0, 5.0]
        )
        # 启动监控服务器
        start_http_server(8000)

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        start_time = time.time()
        try:
            super().save_context(inputs, outputs)
            self.save_requests.labels(status='success').inc()
        except Exception:
            self.save_requests.labels(status='error').inc()
            raise
        finally:
            duration = time.time() - start_time
            self.save_duration.observe(duration)

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        start_time = time.time()
        try:
            result = super().load_memory_variables(inputs)
            self.load_requests.labels(status='success').inc()
            return result
        except Exception:
            self.load_requests.labels(status='error').inc()
            raise
        finally:
            duration = time.time() - start_time
            self.load_duration.observe(duration)
使用Prometheus监控记忆存储的关键指标,包括请求次数、响应时间和错误率,为性能优化提供数据支持。
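save_context 与 load_memory_variables 中的计时与计数逻辑完全对称,可以抽成装饰器避免重复。下面的示意把指标写入普通字典(假设宿主对象带有 metrics 属性),同样的思路可以替换为上面的 Prometheus 指标:

```python
import time
from functools import wraps

def timed(metric_name: str):
    """记录方法耗时与成功/失败次数的装饰器(示意,结果写入 self.metrics 字典)。"""
    def decorator(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            start = time.perf_counter()
            status = "success"
            try:
                return fn(self, *args, **kwargs)
            except Exception:
                status = "error"
                raise
            finally:
                # finally 保证无论成功失败都记录耗时与状态
                bucket = self.metrics.setdefault(
                    metric_name, {"durations": [], "success": 0, "error": 0})
                bucket["durations"].append(time.perf_counter() - start)
                bucket[status] += 1
        return wrapper
    return decorator

class Demo:
    """演示用宿主类,metrics 属性供装饰器写入。"""
    def __init__(self):
        self.metrics = {}

    @timed("save")
    def save(self, x):
        return x * 2
```

把 timed 应用在 MonitoredMemory 的两个方法上,即可去掉两段几乎相同的 try/finally 样板代码。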
12.2 自动调优
import time

import numpy as np
from sklearn.linear_model import LinearRegression

class AutoTuningMemory(CustomMemory):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.performance_data = []
        self.model = LinearRegression()
        self.optimal_batch_size = 100
        self.auto_tune_enabled = True
        self.save_count = 0  # 保存调用计数,用于控制重训练频率

    def enable_auto_tune(self, enable: bool = True) -> None:
        self.auto_tune_enabled = enable

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        start_time = time.time()
        # 使用当前批处理大小
        batch_size = self.optimal_batch_size
        # 如果有多个条目需要保存,分批处理
        if isinstance(inputs, list) and isinstance(outputs, list):
            for i in range(0, len(inputs), batch_size):
                batch_inputs = inputs[i:i+batch_size]
                batch_outputs = outputs[i:i+batch_size]
                super().save_context(batch_inputs, batch_outputs)
        else:
            super().save_context(inputs, outputs)
        duration = time.time() - start_time
        # 收集性能数据
        self._collect_performance_data(len(inputs) if isinstance(inputs, list) else 1, duration)
        # 定期重新训练模型(用独立计数器,避免数据量封顶在1000后每次保存都触发)
        self.save_count += 1
        if self.save_count % 100 == 0 and self.auto_tune_enabled:
            self._retrain_model()

    def _collect_performance_data(self, batch_size: int, duration: float) -> None:
        # 收集批处理大小和执行时间的关系
        self.performance_data.append((batch_size, duration))
        # 限制数据点数量,避免内存溢出
        if len(self.performance_data) > 1000:
            self.performance_data.pop(0)

    def _retrain_model(self) -> None:
        # 从收集的数据中训练线性回归模型
        if len(self.performance_data) < 10:  # 至少需要10个数据点
            return
        X = np.array([[data[0]] for data in self.performance_data])
        y = np.array([data[1] for data in self.performance_data])
        self.model.fit(X, y)
        # 预测不同批处理大小的性能
        test_sizes = np.array([[i] for i in range(10, 500, 10)])
        predictions = self.model.predict(test_sizes)
        # 找到最优批处理大小(执行时间最短)
        optimal_idx = np.argmin(predictions)
        self.optimal_batch_size = int(test_sizes[optimal_idx][0])
        print(f"自动调优: 最优批处理大小已更新为 {self.optimal_batch_size}")
通过收集性能数据并使用线性回归模型,自动调整批处理大小等参数,提高系统性能。
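需要注意,直接对"批大小-总耗时"做线性回归时,预测耗时随批大小单调变化,argmin 往往退化为候选区间的端点;一个更稳妥的做法是最小化平均单位条目耗时。下面是一个不依赖 sklearn 的示意(best_batch_size 为假设的辅助函数):

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List, Tuple

def best_batch_size(samples: List[Tuple[int, float]]) -> int:
    """从 (批大小, 总耗时) 样本中选出平均单位条目耗时最低的批大小(示意)。

    总耗时对批大小近似单调递增,直接取其最小值没有意义;
    比较"每条记录的平均耗时"才能反映吞吐最优点。
    """
    per_item: Dict[int, List[float]] = defaultdict(list)
    for size, duration in samples:
        if size > 0:
            per_item[size].append(duration / size)
    # 取平均单位耗时最小的批大小
    return min(per_item, key=lambda s: mean(per_item[s]))
```

把这一指标接入 _retrain_model,可以让自动调优真正收敛到吞吐最优的批大小,而不是区间边界。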
12.3 异常处理
import logging
import os
import shutil
import time

logger = logging.getLogger(__name__)

class RobustMemory(CustomMemory):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.retry_attempts = 3
        self.retry_delay = 1

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        for attempt in range(self.retry_attempts):
            try:
                super().save_context(inputs, outputs)
                break
            except Exception as e:
                if attempt == self.retry_attempts - 1:
                    logger.error(f"保存上下文失败,已达到最大重试次数: {e}")
                    raise
                else:
                    logger.warning(f"保存上下文失败,尝试重试 ({attempt+1}/{self.retry_attempts}): {e}")
                    time.sleep(self.retry_delay * (2 ** attempt))  # 指数退避

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        for attempt in range(self.retry_attempts):
            try:
                return super().load_memory_variables(inputs)
            except Exception as e:
                if attempt == self.retry_attempts - 1:
                    logger.error(f"加载记忆变量失败,已达到最大重试次数: {e}")
                    raise
                else:
                    logger.warning(f"加载记忆变量失败,尝试重试 ({attempt+1}/{self.retry_attempts}): {e}")
                    time.sleep(self.retry_delay * (2 ** attempt))  # 指数退避

    def backup_memory(self, backup_path: str) -> None:
        """创建记忆存储的备份"""
        try:
            # 对于SQLite数据库,直接复制文件
            if self.db_uri.startswith("sqlite:///"):
                db_path = self.db_uri.replace("sqlite:///", "")
                if os.path.exists(db_path):
                    shutil.copy2(db_path, backup_path)
                    logger.info(f"记忆存储备份成功,保存到 {backup_path}")
                else:
                    logger.warning("数据库文件不存在,无法备份")
            else:
                # 对于其他数据库类型,实现相应的备份逻辑
                logger.warning(f"不支持的数据库类型,无法备份: {self.db_uri}")
        except Exception as e:
            logger.error(f"备份失败: {e}")
增强记忆存储的健壮性,通过重试机制和备份功能,确保系统在面对异常情况时能够稳定运行。
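save_context 与 load_memory_variables 中的重试逻辑同样可以抽成通用装饰器,减少重复代码(示意实现,参数名为假设):

```python
import time
from functools import wraps

def with_retry(attempts: int = 3, base_delay: float = 1.0):
    """指数退避重试装饰器,抽取 RobustMemory 中两处重复的重试逻辑(示意)。"""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return fn(*args, **kwargs)
                except Exception:
                    if attempt == attempts - 1:
                        raise  # 重试耗尽,向上抛出原始异常
                    time.sleep(base_delay * (2 ** attempt))  # 指数退避
        return wrapper
    return decorator
```

装饰目标方法后,重试次数和退避策略只需在一处维护;日志记录也可以在装饰器内部统一补充。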
十三、实战案例
13.1 智能客服系统中的记忆存储
在智能客服系统中,需要保存用户历史咨询记录、问题分类和解决方案,以便提供连贯的服务体验。
from sqlalchemy import Index, func, inspect

class CustomerServiceMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "customer_service"):
        super().__init__(db_uri, table_name)
        # 初始化客户服务特定的索引
        self._create_customer_indexes()

    def _create_customer_indexes(self):
        # 检查索引是否存在,不存在则创建
        inspector = inspect(self.engine)
        indexes = inspector.get_indexes(self.table_name)
        if not any(index['name'] == 'idx_customer_id' for index in indexes):
            # 创建客户ID索引
            idx = Index('idx_customer_id', CustomerServiceRecord.customer_id)
            idx.create(bind=self.engine)
        if not any(index['name'] == 'idx_issue_type' for index in indexes):
            # 创建问题类型索引
            idx = Index('idx_issue_type', CustomerServiceRecord.issue_type)
            idx.create(bind=self.engine)

    def get_customer_history(self, customer_id: str) -> List[Dict[str, Any]]:
        """获取特定客户的历史记录"""
        session = self.Session()
        try:
            records = session.query(CustomerServiceRecord).filter(
                CustomerServiceRecord.customer_id == customer_id
            ).order_by(CustomerServiceRecord.timestamp.desc()).all()
            return [
                {
                    "input": record.input_text,
                    "output": record.output_text,
                    "timestamp": record.timestamp,
                    "issue_type": record.issue_type
                }
                for record in records
            ]
        finally:
            session.close()

    def get_issue_statistics(self) -> Dict[str, int]:
        """获取问题类型统计"""
        session = self.Session()
        try:
            # 统计每种问题类型的数量
            statistics = session.query(
                CustomerServiceRecord.issue_type,
                func.count(CustomerServiceRecord.id)
            ).group_by(CustomerServiceRecord.issue_type).all()
            return {issue_type: count for issue_type, count in statistics}
        finally:
            session.close()

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """扩展保存上下文,添加客户服务特定字段"""
        # 提取客户ID和问题类型
        customer_id = inputs.get("customer_id", "unknown")
        issue_type = inputs.get("issue_type", "general")
        # 创建客户服务记录
        record = CustomerServiceRecord(
            customer_id=customer_id,
            input_text=inputs.get("input"),
            output_text=outputs.get("output"),
            issue_type=issue_type,
            timestamp=datetime.utcnow()
        )
        session = self.Session()
        try:
            session.add(record)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
这个自定义记忆存储专为智能客服系统设计,添加了客户ID和问题类型等特定字段,并创建了相应索引,提高查询效率。
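上面的代码引用了 CustomerServiceRecord 模型但未给出其定义;下面是一个可能的 SQLAlchemy 声明,字段按上文的用法反推(名称与类型均为假设,实际应与真实表结构对齐):

```python
from datetime import datetime

from sqlalchemy import Column, DateTime, Integer, String, Text
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class CustomerServiceRecord(Base):
    """智能客服对话记录的 ORM 模型(示意定义)。"""
    __tablename__ = "customer_service"

    id = Column(Integer, primary_key=True, autoincrement=True)
    customer_id = Column(String(64), nullable=False)    # 客户唯一标识
    issue_type = Column(String(32), default="general")  # 问题分类
    input_text = Column(Text)                           # 用户输入
    output_text = Column(Text)                          # AI回复
    timestamp = Column(DateTime, default=datetime.utcnow)
```

有了显式模型后,前面的索引创建、历史查询和统计查询才能真正落地执行。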
13.2 教育辅助系统中的记忆存储
在教育辅助系统中,需要跟踪学生的学习进度、知识点掌握情况和历史问答记录。
from typing import Optional, Union

class EducationMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "education"):
        super().__init__(db_uri, table_name)
        self.subject_embeddings = {}

    def save_learning_progress(self, student_id: str, subject: str, progress: float, knowledge_points: List[str]) -> None:
        """保存学习进度"""
        session = self.Session()
        try:
            # 检查是否已有该学生的学习记录
            record = session.query(EducationRecord).filter(
                EducationRecord.student_id == student_id,
                EducationRecord.subject == subject
            ).first()
            if record:
                # 更新现有记录
                record.progress = progress
                record.knowledge_points = json.dumps(knowledge_points)
                record.updated_at = datetime.utcnow()
            else:
                # 创建新记录
                record = EducationRecord(
                    student_id=student_id,
                    subject=subject,
                    progress=progress,
                    knowledge_points=json.dumps(knowledge_points),
                    created_at=datetime.utcnow(),
                    updated_at=datetime.utcnow()
                )
                session.add(record)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def get_learning_progress(self, student_id: str, subject: Optional[str] = None) -> Union[Dict[str, Any], List[Dict[str, Any]], None]:
        """获取学习进度:指定科目时返回单条记录(不存在则为None),否则返回全部科目的列表"""
        session = self.Session()
        try:
            if subject:
                # 获取特定科目的学习进度
                record = session.query(EducationRecord).filter(
                    EducationRecord.student_id == student_id,
                    EducationRecord.subject == subject
                ).first()
                if record:
                    return {
                        "student_id": record.student_id,
                        "subject": record.subject,
                        "progress": record.progress,
                        "knowledge_points": json.loads(record.knowledge_points),
                        "updated_at": record.updated_at
                    }
                else:
                    return None
            else:
                # 获取所有科目的学习进度
                records = session.query(EducationRecord).filter(
                    EducationRecord.student_id == student_id
                ).all()
                return [
                    {
                        "student_id": record.student_id,
                        "subject": record.subject,
                        "progress": record.progress,
                        "knowledge_points": json.loads(record.knowledge_points),
                        "updated_at": record.updated_at
                    }
                    for record in records
                ]
        finally:
            session.close()

    def recommend_resources(self, student_id: str, subject: str, max_results: int = 5) -> List[Dict[str, Any]]:
        """推荐学习资源"""
        # 1. 获取学生当前的知识状态
        progress = self.get_learning_progress(student_id, subject)
        if not progress:
            return []
        # 2. 基于知识状态生成资源推荐
        # 这里简化为固定候选集,实际应基于知识点掌握情况和资源库匹配
        resources = [
            {"title": f"{subject}基础教程", "type": "book", "difficulty": "beginner"},
            {"title": f"{subject}进阶练习", "type": "exercise", "difficulty": "intermediate"},
            {"title": f"{subject}高级理论", "type": "article", "difficulty": "advanced"},
            {"title": f"{subject}实例分析", "type": "case", "difficulty": "intermediate"},
            {"title": f"{subject}测验", "type": "quiz", "difficulty": "beginner"}
        ]
        # 根据学习进度调整推荐难度
        if progress["progress"] < 0.3:
            # 初学者推荐基础资源
            return [r for r in resources if r["difficulty"] == "beginner"][:max_results]
        elif progress["progress"] < 0.7:
            # 中级学习者推荐中级资源
            return [r for r in resources if r["difficulty"] == "intermediate"][:max_results]
        else:
            # 高级学习者推荐高级资源
            return [r for r in resources if r["difficulty"] == "advanced"][:max_results]

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
        """扩展保存上下文,添加教育特定信息"""
        # 提取学生ID和科目(student_id 预留给后续按学生维度过滤的扩展)
        student_id = inputs.get("student_id", "unknown")
        subject = inputs.get("subject", "general")
        # 保存常规对话上下文
        super().save_context(inputs, outputs)
        # 更新主题嵌入(用于资源推荐)
        self._update_subject_embedding(subject, inputs.get("input", ""))

    def _update_subject_embedding(self, subject: str, text: str) -> None:
        """更新主题嵌入向量"""
        if not text:
            return
        # 简化实现,实际应使用NLP模型生成嵌入向量
        if subject not in self.subject_embeddings:
            self.subject_embeddings[subject] = []
        # 这里只是示例,实际应使用文本嵌入模型
        self.subject_embeddings[subject].append(text)
        # 限制嵌入向量数量,避免内存溢出
        if len(self.subject_embeddings[subject]) > 100:
            self.subject_embeddings[subject].pop(0)
这个教育记忆存储不仅保存对话历史,还跟踪学生的学习进度和知识点掌握情况,支持个性化学习资源推荐。
13.3 金融投资助手的记忆存储
在金融投资助手中,需要保存用户的投资组合、交易历史和市场分析记录。
class FinancialMemory(CustomMemory):
    def __init__(self, db_uri: str, table_name: str = "financial"):
        super().__init__(db_uri, table_name)
        self.portfolio_cache = {}

    def save_portfolio(self, user_id: str, portfolio: Dict[str, Any]) -> None:
        """保存投资组合"""
        session = self.Session()
        try:
            # 检查是否已有该用户的投资组合
            record = session.query(FinancialRecord).filter(
                FinancialRecord.user_id == user_id,
                FinancialRecord.record_type == "portfolio"
            ).first()
            if record:
                # 更新现有记录
                record.content = json.dumps(portfolio)
                record.updated_at = datetime.utcnow()
            else:
                # 创建新记录
                record = FinancialRecord(
                    user_id=user_id,
                    record_type="portfolio",
                    content=json.dumps(portfolio),
                    created_at=datetime.utcnow(),
                    updated_at=datetime.utcnow()
                )
                session.add(record)
            session.commit()
            # 更新缓存
            self.portfolio_cache[user_id] = portfolio
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def get_portfolio(self, user_id: str) -> Dict[str, Any]:
        """获取投资组合"""
        # 先检查缓存
        if user_id in self.portfolio_cache:
            return self.portfolio_cache[user_id]
        session = self.Session()
        try:
            record = session.query(FinancialRecord).filter(
                FinancialRecord.user_id == user_id,
                FinancialRecord.record_type == "portfolio"
            ).first()
            if record:
                portfolio = json.loads(record.content)
                # 更新缓存
                self.portfolio_cache[user_id] = portfolio
                return portfolio
            else:
                return {"assets": [], "cash": 0.0}
        finally:
            session.close()

    def save_transaction(self, user_id: str, transaction: Dict[str, Any]) -> None:
        """保存交易记录"""
        session = self.Session()
        try:
            # 创建交易记录
            record = FinancialRecord(
                user_id=user_id,
                record_type="transaction",
                content=json.dumps(transaction),
                created_at=datetime.utcnow(),
                updated_at=datetime.utcnow()
            )
            session.add(record)
            session.commit()
            # 更新投资组合
            self._update_portfolio_after_transaction(user_id, transaction)
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def get_transaction_history(self, user_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """获取交易历史"""
        session = self.Session()
        try:
            records = session.query(FinancialRecord).filter(
                FinancialRecord.user_id == user_id,
                FinancialRecord.record_type == "transaction"
            ).order_by(FinancialRecord.created_at.desc()).limit(limit).all()
            return [json.loads(record.content) for record in records]
        finally:
            session.close()

    def _update_portfolio_after_transaction(self, user_id: str, transaction: Dict[str, Any]) -> None:
        """交易后更新投资组合"""
        portfolio = self.get_portfolio(user_id)
        assets = portfolio.get("assets", [])
        cash = portfolio.get("cash", 0.0)
        # 处理交易
        transaction_type = transaction.get("type")
        asset = transaction.get("asset")
        amount = transaction.get("amount", 0)
        price = transaction.get("price", 0)
        cost = amount * price
        if transaction_type == "buy":
            # 购买资产
            cash -= cost
            # 检查是否已有该资产
            existing_asset = next((a for a in assets if a["symbol"] == asset), None)
            if existing_asset:
                existing_asset["amount"] += amount
            else:
                assets.append({"symbol": asset, "amount": amount})
        elif transaction_type == "sell":
            # 卖出资产
            cash += cost
            # 更新资产数量
            for a in assets:
                if a["symbol"] == asset:
                    a["amount"] -= amount
                    if a["amount"] <= 0:
                        assets.remove(a)
                    break
        # 更新投资组合
        portfolio["assets"] = assets
        portfolio["cash"] = cash
        self.save_portfolio(user_id, portfolio)

    def analyze_portfolio_risk(self, user_id: str) -> Dict[str, Any]:
        """分析投资组合风险"""
        portfolio = self.get_portfolio(user_id)
        assets = portfolio.get("assets", [])
        # 简化的风险分析
        risk_score = 0.0
        diversification = len(assets)
        total_value = sum(a.get("amount", 0) * 100 for a in assets)  # 简化,假设每股100元
        for asset in assets:
            # 假设每种资产有一个风险系数(实际应从市场数据获取)
            risk_factor = 0.5  # 默认中等风险
            if asset["symbol"].startswith("BTC"):
                risk_factor = 0.8  # 加密货币高风险
            elif asset["symbol"].startswith("AAPL"):
                risk_factor = 0.3  # 蓝筹股低风险
            # 计算加权风险
            asset_value = asset.get("amount", 0) * 100
            if total_value > 0:
                risk_score += (asset_value / total_value) * risk_factor
        return {
            "risk_score": risk_score,
            "diversification": diversification,
            "recommendation": "建议分散投资" if diversification < 3 else "投资组合多元化良好"
        }
这个金融记忆存储专门为投资助手设计,支持保存和管理用户的投资组合和交易历史,并提供基本的风险分析功能。
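`_update_portfolio_after_transaction` 中的核心账务逻辑可以抽成一个不依赖数据库的纯函数,便于独立测试(示意,字段约定沿用上文):

```python
from typing import Any, Dict

def apply_transaction(portfolio: Dict[str, Any],
                      transaction: Dict[str, Any]) -> Dict[str, Any]:
    """将一笔买/卖交易应用到投资组合,返回新组合(纯函数,不修改入参)。"""
    assets = [dict(a) for a in portfolio.get("assets", [])]
    cash = portfolio.get("cash", 0.0)
    symbol = transaction.get("asset")
    amount = transaction.get("amount", 0)
    price = transaction.get("price", 0)
    cost = amount * price

    if transaction.get("type") == "buy":
        cash -= cost
        existing = next((a for a in assets if a["symbol"] == symbol), None)
        if existing:
            existing["amount"] += amount
        else:
            assets.append({"symbol": symbol, "amount": amount})
    elif transaction.get("type") == "sell":
        cash += cost
        for a in assets:
            if a["symbol"] == symbol:
                a["amount"] -= amount
                break
        # 清理已清仓的资产
        assets = [a for a in assets if a["amount"] > 0]

    return {"assets": assets, "cash": cash}
```

FinancialMemory 可以在数据库事务内复用这个纯函数,账务规则本身则能用普通单元测试覆盖。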
十四、未来发展方向
14.1 大模型集成
随着大型语言模型能力的不断增强,未来的记忆存储将更加紧密地与大模型集成,实现:
- 智能摘要:自动生成对话摘要,减少存储开销
- 知识蒸馏:从历史对话中提取结构化知识
- 预测性记忆:基于用户历史行为预测未来需求
- 语义压缩:使用向量表示法更高效地存储对话内容
14.2 联邦学习与隐私保护
在敏感领域应用中,隐私保护将变得更加重要:
- 联邦学习记忆:在不共享原始数据的情况下训练记忆模型
- 差分隐私:在存储和检索过程中添加噪声保护隐私
- 零知识证明:验证记忆操作的正确性而不泄露内容
- 动态权限控制:根据上下文动态调整数据访问权限
14.3 边缘计算与分布式记忆
随着边缘设备能力的提升,记忆存储将向分布式架构发展:
- 边缘记忆缓存:在本地设备存储常用对话历史
- 分布式记忆网络:跨设备共享和同步记忆数据
- 雾计算记忆:在边缘和云端之间平衡存储和计算
- 去中心化记忆:使用区块链技术实现可信记忆存储
14.4 多模态与情境感知
未来的记忆存储将支持更丰富的模态和情境感知:
- 多感官记忆:整合文本、图像、音频、视频等多种模态
- 情感记忆:记录和分析对话中的情感状态
- 情境感知:结合时间、地点、环境等情境信息
- 物理世界映射:将记忆与现实世界的空间和物体关联
14.5 自我进化与元学习
记忆存储系统将具备自我进化能力:
- 自适应记忆结构:根据使用模式自动调整存储结构
- 元学习记忆:学习如何更好地组织和检索记忆
- 持续学习:在不遗忘旧知识的情况下学习新知识
- 记忆重组:定期重构记忆以提高检索效率和准确性
14.6 量子计算与记忆存储
量子计算技术可能为记忆存储带来革命性变化:
- 量子存储:利用量子态存储更多信息
- 量子搜索算法:加速大规模记忆检索
- 量子加密:提供更高级别的记忆安全保护
- 量子机器学习:增强记忆分析和处理能力
通过持续的技术创新,LangChain自定义记忆存储将在未来的AI系统中扮演更加核心和智能的角色,为用户提供更加个性化、高效和安全的交互体验。