Building an AI Agent Memory System from Zero to One: Nine Strategies in Practice (with Annotated Code)
This article walks through memory systems for AI agents, focusing on reusable, deployable memory architectures. You will see the principles and trade-offs of nine common memory strategies, each accompanied by annotated key implementation code you can replicate directly.
0. Unified Interface and Agent Framework (the Foundation for Every Strategy)
Before building any memory strategy, define a unified interface and the agent orchestration flow, so that strategies become plug-and-play.
import abc

# Unified interface for all memory strategies (abstract base class)
class BaseMemoryStrategy(abc.ABC):
    @abc.abstractmethod
    def add_message(self, user_input: str, ai_response: str):
        """Write one conversation turn into memory."""
        pass

    @abc.abstractmethod
    def get_context(self, query: str) -> str:
        """Retrieve context relevant to the current query."""
        pass

    @abc.abstractmethod
    def clear(self):
        """Wipe the memory."""
        pass

class AIAgent:
    """Unified agent loop: fetch memory -> build prompt -> call LLM -> update memory."""
    def __init__(self, memory_strategy: BaseMemoryStrategy, system_prompt: str = "You are a helpful AI assistant."):
        self.memory = memory_strategy
        self.system_prompt = system_prompt

    def chat(self, user_input: str) -> str:
        # 1) Fetch the memory context
        context = self.memory.get_context(query=user_input)
        # 2) Assemble the prompt
        full_user_prompt = f"### MEMORY CONTEXT\n{context}\n\n### CURRENT REQUEST\n{user_input}"
        # 3) Call the LLM (generate_text comes from your own utility functions)
        ai_response = generate_text(self.system_prompt, full_user_prompt)
        # 4) Write the turn back into memory
        self.memory.add_message(user_input, ai_response)
        return ai_response
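Every strategy below relies on a generate_text(system_prompt, user_prompt) helper that this article treats as given. As a reference point, here is a minimal sketch assuming the official openai Python package and an OpenAI-compatible endpoint; the client setup and model name are placeholders, not part of the original setup:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def generate_text(system_prompt: str, user_prompt: str) -> str:
    # One chat completion per call; any provider with an OpenAI-compatible API works the same way
    response = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name, swap in your own
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    return response.choices[0].message.content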
1) Sequential Memory (Keep-It-All)
Characteristics: keep every turn in full and concatenate them all. The upside is total recall; the downside is a context that grows without bound.
class SequentialMemory(BaseMemoryStrategy):
    def __init__(self):
        # Keep every message in a plain list
        self.history = []

    def add_message(self, user_input: str, ai_response: str):
        # Append the user and assistant messages in order
        self.history.append({"role": "user", "content": user_input})
        self.history.append({"role": "assistant", "content": ai_response})

    def get_context(self, query: str) -> str:
        # Concatenate the full history into one block of text
        return "\n".join([f"{t['role'].capitalize()}: {t['content']}" for t in self.history])

    def clear(self):
        # Drop the entire history
        self.history = []
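A quick smoke test that exercises the class directly (the dialogue is made up; no LLM call is needed here):

mem = SequentialMemory()
mem.add_message("My name is Sam.", "Nice to meet you, Sam!")
mem.add_message("I live in Berlin.", "Berlin is a great city.")
print(mem.get_context("Where do I live?"))
# User: My name is Sam.
# Assistant: Nice to meet you, Sam!
# User: I live in Berlin.
# Assistant: Berlin is a great city.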
2) Sliding Window Memory
Characteristics: keep only the most recent N turns. Cost stays flat, but older turns are forgotten.
from collections import deque

class SlidingWindowMemory(BaseMemoryStrategy):
    def __init__(self, window_size: int = 4):
        # deque enforces the length cap automatically
        self.history = deque(maxlen=window_size)

    def add_message(self, user_input: str, ai_response: str):
        # Store one conversation turn as a single unit
        self.history.append([
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": ai_response}
        ])

    def get_context(self, query: str) -> str:
        # Flatten the deque into a context string
        ctx = []
        for turn in self.history:
            for msg in turn:
                ctx.append(f"{msg['role'].capitalize()}: {msg['content']}")
        return "\n".join(ctx)

    def clear(self):
        self.history.clear()
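With window_size=2, a third turn silently evicts the first:

mem = SlidingWindowMemory(window_size=2)
mem.add_message("First turn.", "Got it.")
mem.add_message("Second turn.", "Noted.")
mem.add_message("Third turn.", "Sure.")
print(mem.get_context(""))  # only the second and third turns remain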
3) Summarization Memory
Characteristics: once the buffer reaches a threshold number of messages, have the LLM fold them into a running summary. Well suited to long conversations.
class SummarizationMemory(BaseMemoryStrategy):
    def __init__(self, summary_threshold: int = 4):
        self.running_summary = ""
        self.buffer = []
        # The threshold counts individual messages, so 4 means every two turns
        self.summary_threshold = summary_threshold

    def add_message(self, user_input: str, ai_response: str):
        # Buffer the messages first
        self.buffer.append({"role": "user", "content": user_input})
        self.buffer.append({"role": "assistant", "content": ai_response})
        # Consolidate once the threshold is reached
        if len(self.buffer) >= self.summary_threshold:
            self._consolidate_memory()

    def _consolidate_memory(self):
        # Join the buffered messages into one text block
        buffer_text = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in self.buffer])
        # Build the summarization prompt
        prompt = (
            "You are a summarization expert.\n"
            f"### Previous Summary:\n{self.running_summary}\n\n"
            f"### New Conversation:\n{buffer_text}\n\n"
            "### Updated Summary:"
        )
        # Ask the LLM for the merged summary
        self.running_summary = generate_text("You are a summarization engine.", prompt)
        # Empty the buffer
        self.buffer = []

    def get_context(self, query: str) -> str:
        buffer_text = "\n".join([f"{m['role'].capitalize()}: {m['content']}" for m in self.buffer])
        return f"### Summary:\n{self.running_summary}\n\n### Recent:\n{buffer_text}"

    def clear(self):
        self.running_summary = ""
        self.buffer = []
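With generate_text wired up (see section 0), the default threshold of 4 messages means consolidation fires after every two turns; for example:

mem = SummarizationMemory(summary_threshold=4)
mem.add_message("I'm planning a trip to Japan in May.", "Sounds great, May is mild there.")
mem.add_message("My budget is about $3000.", "That should cover two weeks.")
# The buffer reached 4 messages, so _consolidate_memory() ran and emptied it
print(mem.running_summary)  # one LLM-written summary covering both turns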
4) Retrieval Memory (RAG)
Characteristics: use vector search to pull in relevant information from far back in the conversation; the most common approach to long-term memory.
import numpy as np
import faiss

class RetrievalMemory(BaseMemoryStrategy):
    def __init__(self, k: int = 2, embedding_dim: int | None = None):
        self.k = k
        self.embedding_dim = embedding_dim
        self.documents = []
        # Defer index creation until the embedding dimension is known
        self.index = faiss.IndexFlatL2(embedding_dim) if embedding_dim else None

    def _ensure_index(self, embedding: list):
        # On first write, infer the dimension from the vector length
        if self.embedding_dim is None:
            self.embedding_dim = len(embedding)
            self.index = faiss.IndexFlatL2(self.embedding_dim)
        # Fail loudly on a dimension mismatch
        elif len(embedding) != self.embedding_dim:
            raise ValueError(f"Embedding dim {len(embedding)} != index dim {self.embedding_dim}")

    def add_message(self, user_input: str, ai_response: str):
        docs = [f"User said: {user_input}", f"AI responded: {ai_response}"]
        for doc in docs:
            emb = get_embedding(doc)
            if emb:
                self._ensure_index(emb)
                self.documents.append(doc)
                self.index.add(np.array([emb], dtype="float32"))

    def get_context(self, query: str) -> str:
        if self.index is None or self.index.ntotal == 0:
            return "No information in memory yet."
        q = get_embedding(query)
        if not q:
            return "Could not process query for retrieval."
        if len(q) != self.embedding_dim:
            return "Query embedding dimension mismatch with index."
        D, I = self.index.search(np.array([q], dtype="float32"), self.k)
        retrieved = [self.documents[i] for i in I[0] if i != -1]
        return "### Retrieved:\n" + "\n---\n".join(retrieved)

    def clear(self):
        self.documents = []
        if self.index is not None:
            self.index.reset()
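RetrievalMemory assumes a get_embedding(text) helper that returns a plain Python list of floats. A minimal sketch, assuming the sentence-transformers package (the model choice is illustrative):

from sentence_transformers import SentenceTransformer

_model = SentenceTransformer("all-MiniLM-L6-v2")  # produces 384-dimensional vectors

def get_embedding(text: str) -> list:
    # encode() returns a numpy array; convert to a list to match the checks above
    return _model.encode(text).tolist()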
5) Memory-Augmented Memory (Simulated)
Characteristics: have the LLM identify key facts and persist them as long-lived "memory tokens".
class MemoryAugmentedMemory(BaseMemoryStrategy):
    def __init__(self, window_size: int = 2):
        self.recent_memory = SlidingWindowMemory(window_size=window_size)
        self.memory_tokens = []

    def add_message(self, user_input: str, ai_response: str):
        # Short-term memory first
        self.recent_memory.add_message(user_input, ai_response)
        # Then ask the LLM to extract any key fact
        prompt = (
            "Analyze the following turn and extract any long-term fact.\n"
            f"User: {user_input}\nAI: {ai_response}\n"
            "If none, reply 'No important fact.'"
        )
        fact = generate_text("You are a fact-extraction expert.", prompt)
        if "no important fact" not in fact.lower():
            self.memory_tokens.append(fact)

    def get_context(self, query: str) -> str:
        recent = self.recent_memory.get_context(query)
        tokens = "\n".join([f"- {t}" for t in self.memory_tokens])
        return f"### Memory Tokens:\n{tokens}\n\n### Recent:\n{recent}"

    def clear(self):
        self.recent_memory.clear()
        self.memory_tokens = []
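Usage matches the other strategies; the point is that an extracted fact outlives the two-turn window (requires generate_text):

mem = MemoryAugmentedMemory(window_size=2)
mem.add_message("Remember: I'm allergic to peanuts.", "Understood, I'll keep that in mind.")
mem.add_message("What's a good hike nearby?", "Try the ridge trail.")
mem.add_message("And dinner afterwards?", "A picnic could work.")
# The allergy turn has left the sliding window, but the fact persists as a memory token
print(mem.get_context("plan dinner"))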
6) Hierarchical Memory
Characteristics: a sliding window for short-term memory plus retrieval for long-term; turns are promoted when trigger keywords appear.
class HierarchicalMemory(BaseMemoryStrategy):
    def __init__(self, window_size: int = 2, k: int = 2, embedding_dim: int = 4096):
        self.working_memory = SlidingWindowMemory(window_size=window_size)
        self.long_term_memory = RetrievalMemory(k=k, embedding_dim=embedding_dim)
        self.promotion_keywords = ["remember", "rule", "preference", "always", "never", "allergic"]

    def add_message(self, user_input: str, ai_response: str):
        self.working_memory.add_message(user_input, ai_response)
        # Promote the turn to long-term memory when a trigger keyword appears
        if any(k in user_input.lower() for k in self.promotion_keywords):
            self.long_term_memory.add_message(user_input, ai_response)

    def get_context(self, query: str) -> str:
        working = self.working_memory.get_context(query)
        long_term = self.long_term_memory.get_context(query)
        return f"### Long-Term:\n{long_term}\n\n### Working:\n{working}"

    def clear(self):
        self.working_memory.clear()
        self.long_term_memory.clear()
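A quick check of promotion, assuming the 384-dimensional get_embedding sketch above (set embedding_dim to match your model):

mem = HierarchicalMemory(window_size=2, k=1, embedding_dim=384)
mem.add_message("Please remember I prefer window seats.", "Noted!")
mem.add_message("What's the weather today?", "Sunny and warm.")
mem.add_message("Any lunch ideas?", "How about ramen?")
# The preference has left the working window, but retrieval brings it back
print(mem.get_context("book me a seat on the flight"))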
7) Graph Memory
Characteristics: extract subject-relation-object triples into a knowledge graph; well suited to relational reasoning.
import networkx as nx
import re

class GraphMemory(BaseMemoryStrategy):
    def __init__(self):
        # Directed graph: nodes are entities, edges carry a 'relation' attribute
        self.graph = nx.DiGraph()

    def _extract_triples(self, text: str):
        # Ask the LLM for (subject, relation, object) tuples, then parse them out
        prompt = (
            "Extract Subject-Relation-Object triples as Python tuples.\n"
            f"Text:\n{text}"
        )
        response = generate_text("You are a KG extractor.", prompt)
        return re.findall(r"\(['\"](.*?)['\"],\s*['\"](.*?)['\"],\s*['\"](.*?)['\"]\)", response)

    def add_message(self, user_input: str, ai_response: str):
        triples = self._extract_triples(f"User: {user_input}\nAI: {ai_response}")
        for s, r, o in triples:
            self.graph.add_edge(s.strip(), o.strip(), relation=r.strip())

    def get_context(self, query: str) -> str:
        if not self.graph.nodes:
            return "The knowledge graph is empty."
        # Naive entity linking: capitalize query words and match them against node names
        entities = [w.capitalize() for w in query.replace("?", "").split() if w.capitalize() in self.graph.nodes]
        if not entities:
            return "No relevant entities from your query were found in the knowledge graph."
        facts = []
        for e in set(entities):
            for u, v, d in self.graph.out_edges(e, data=True):
                facts.append(f"{u} --[{d['relation']}]--> {v}")
        return "### Facts Retrieved from Knowledge Graph:\n" + "\n".join(sorted(set(facts)))

    def clear(self):
        self.graph.clear()
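A short run (requires generate_text; the exact triples depend on your model, so treat the output as illustrative):

mem = GraphMemory()
mem.add_message("Alice works at Acme.", "Got it, Alice works at Acme.")
print(mem.get_context("Where does Alice work?"))
# e.g. ### Facts Retrieved from Knowledge Graph:
# Alice --[works at]--> Acme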
8) Compression Memory
Characteristics: compress each turn into a single bare-bones factual statement; extremely token-efficient.
class CompressionMemory(BaseMemoryStrategy):
    def __init__(self):
        self.compressed_facts = []

    def add_message(self, user_input: str, ai_response: str):
        # Boil the whole turn down to one essential statement
        prompt = (
            "Compress the following into its most essential factual statement.\n"
            f"User: {user_input}\nAI: {ai_response}"
        )
        fact = generate_text("You are a data compressor.", prompt)
        self.compressed_facts.append(fact)

    def get_context(self, query: str) -> str:
        if not self.compressed_facts:
            return "No compressed facts in memory."
        return "### Compressed Facts:\n- " + "\n- ".join(self.compressed_facts)

    def clear(self):
        self.compressed_facts = []
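Each turn costs one extra LLM call but is stored as a single line (requires generate_text):

mem = CompressionMemory()
mem.add_message("My flight leaves Friday at 9am from JFK.", "Noted: Friday, 9am, JFK.")
print(mem.get_context("When is my flight?"))
# ### Compressed Facts:
# - (e.g.) The user's flight departs Friday at 9am from JFK.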
9) OS-Like Memory
Characteristics: simulate RAM/disk paging and page older information back in on demand.
class OSMemory(BaseMemoryStrategy):
    def __init__(self, ram_size: int = 2):
        self.ram_size = ram_size
        self.active_memory = deque()   # "RAM": the most recent turns
        self.passive_memory = {}       # "disk": swapped-out turns, keyed by turn id
        self.turn_count = 0

    def add_message(self, user_input: str, ai_response: str):
        turn_id = self.turn_count
        turn_data = f"User: {user_input}\nAI: {ai_response}"
        # RAM full: swap the oldest page out to disk
        if len(self.active_memory) >= self.ram_size:
            lru_id, lru_data = self.active_memory.popleft()
            self.passive_memory[lru_id] = lru_data
        # Write the new page into RAM
        self.active_memory.append((turn_id, turn_data))
        self.turn_count += 1

    def get_context(self, query: str) -> str:
        active = "\n".join([d for _, d in self.active_memory])
        # Simplified "page fault" handling: page a turn back in on a keyword hit
        paged_in = ""
        for tid, data in self.passive_memory.items():
            if any(w in data.lower() for w in query.lower().split() if len(w) > 3):
                paged_in += f"\n(Paged in Turn {tid}): {data}"
        return f"### RAM:\n{active}\n\n### Disk:\n{paged_in}"

    def clear(self):
        self.active_memory.clear()
        self.passive_memory = {}
        self.turn_count = 0
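With ram_size=2, the oldest turn is swapped to "disk" and paged back in on a keyword hit:

mem = OSMemory(ram_size=2)
mem.add_message("The database password is stored in the vault.", "Understood.")
mem.add_message("Let's discuss the UI next.", "Sure.")
mem.add_message("Now the API design.", "OK.")
# Turn 0 now lives in passive_memory; the word 'password' pages it back in
print(mem.get_context("where is the password stored"))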
Closing Thoughts: Memory Strategies Are Not About Picking One, but About Mixing and Matching (a combined sketch follows the list below)
- Short conversations: sequential or sliding-window memory is enough.
- Long conversations: summarization or compression keeps cost under control.
- Long-term memory: retrieval is the mainstream approach.
- Complex relationships: a graph enables structured reasoning.
- Large-scale systems: hierarchical or OS-style management is more robust.
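As one illustration of mixing, here is a hypothetical composite (the class name and wiring are this sketch's own, not from any library) that pairs summarization for cost control with retrieval for long-range recall:

class CompositeMemory(BaseMemoryStrategy):
    """Summarization keeps the prompt small; retrieval preserves long-range recall."""
    def __init__(self):
        self.short_term = SummarizationMemory(summary_threshold=6)
        self.long_term = RetrievalMemory(k=2)

    def add_message(self, user_input: str, ai_response: str):
        # Every turn feeds both tiers
        self.short_term.add_message(user_input, ai_response)
        self.long_term.add_message(user_input, ai_response)

    def get_context(self, query: str) -> str:
        return f"{self.long_term.get_context(query)}\n\n{self.short_term.get_context(query)}"

    def clear(self):
        self.short_term.clear()
        self.long_term.clear()

agent = AIAgent(memory_strategy=CompositeMemory())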