Memory System Architecture for AI Agents

Overview

Human memory works as a coordinated system of short-term, long-term, and episodic memory. For an AI agent to be genuinely intelligent, it needs a comparable memory architecture. An agent without memory starts every conversation from scratch: it cannot accumulate experience, cannot personalize, and cannot handle complex tasks that span sessions.

Starting from cognitive-science memory models, this article designs a complete agent memory system, covering storage, retrieval, and compression strategies for short-term, long-term, and episodic memory.

The Memory Type Hierarchy

Agent memory system
    │
    ├── Working Memory
    │   ├── Current conversation context
    │   ├── Tool-call state
    │   └── Intermediate task results
    │
    ├── Short-term Memory
    │   ├── Summaries of recent conversation
    │   ├── Recent user preferences
    │   └── Temporary knowledge (valid within the session)
    │
    ├── Long-term Memory
    │   ├── Semantic memory: factual knowledge, user profile
    │   ├── Procedural memory: skills, tool-usage patterns
    │   └── Meta-memory: memory about memory (indexes)
    │
    └── Episodic Memory
        ├── Full conversation records
        ├── Task execution traces
        └── Success/failure experiences
Memory type         Retention        Capacity               Access pattern             Storage
Working memory      Current request  Tiny (context window)  Direct                     LLM context
Short-term memory   Session          Medium                 Keyword / last N messages  In-memory / Redis
Long-term memory    Permanent        Large                  Semantic retrieval         Vector DB + KV store
Episodic memory     Permanent        Very large             Time + semantic            Vector DB + time-series store

Working Memory: Context Window Management

Sliding Window + Summary Compression

from dataclasses import dataclass, field
from typing import Optional

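# Helpers assumed throughout this article but not defined here: call_llm(),
# call_llm_json(), and count_tokens(). As a hedge, a minimal count_tokens
# sketch using tiktoken (an assumed dependency); swap in your model's own
# tokenizer for exact counts:
import tiktoken

_ENCODING = tiktoken.get_encoding("cl100k_base")

def count_tokens(text: str) -> int:
    return len(_ENCODING.encode(text))
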
@dataclass
class Message:
    role: str
    content: str
    timestamp: float
    token_count: int = 0
    metadata: dict = field(default_factory=dict)

class WorkingMemory:
    """Manage the messages that live in the LLM context window."""

    def __init__(self, max_tokens: int = 8192, reserve_tokens: int = 2048):
        self.max_tokens = max_tokens
        self.reserve_tokens = reserve_tokens  # reserved for the model's output
        self.messages: list[Message] = []
        self.system_prompt: Optional[str] = None
        self.summary: Optional[str] = None

    @property
    def available_tokens(self) -> int:
        used = sum(m.token_count for m in self.messages)
        if self.system_prompt:
            used += count_tokens(self.system_prompt)
        if self.summary:
            used += count_tokens(self.summary)
        return self.max_tokens - self.reserve_tokens - used

    def add_message(self, message: Message):
        self.messages.append(message)

        # Trigger compression once the window overflows
        if self.available_tokens < 0:
            self._compress()

    def _compress(self):
        """Compress older messages into a running summary."""
        # Keep the most recent N messages verbatim
        keep_recent = 6
        if len(self.messages) <= keep_recent:
            return

        # Fold the earlier messages into the summary
        old_messages = self.messages[:-keep_recent]
        self.messages = self.messages[-keep_recent:]

        # Generate the summary (a small, fast model works well here)
        old_text = "\n".join(
            f"{m.role}: {m.content}" for m in old_messages
        )

        new_summary = summarize_conversation(old_text, self.summary)
        self.summary = new_summary

    def build_prompt(self) -> list[dict]:
        """构建发送给 LLM 的消息列表"""
        prompt = []

        if self.system_prompt:
            prompt.append({"role": "system", "content": self.system_prompt})

        if self.summary:
            prompt.append({
                "role": "system",
                "content": f"[Earlier conversation summary]\n{self.summary}",
            })

        for msg in self.messages:
            prompt.append({"role": msg.role, "content": msg.content})

        return prompt


def summarize_conversation(text: str, previous_summary: Optional[str]) -> str:
    """递增式摘要:在已有摘要基础上追加新内容"""
    prompt = "Summarize the following conversation, preserving key facts, "
    prompt += "user preferences, and decisions made.\n\n"

    if previous_summary:
        prompt += f"Previous summary:\n{previous_summary}\n\n"

    prompt += f"New messages:\n{text}\n\nUpdated summary:"

    # Use a fast model for summarization
    return call_llm(prompt, model="gpt-4o-mini")
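
A minimal usage sketch (the 8K window and the message text are illustrative, not recommendations):

import time

wm = WorkingMemory(max_tokens=8192, reserve_tokens=2048)
wm.system_prompt = "You are a helpful assistant."

question = "How do I profile a slow SQL query?"
wm.add_message(Message(
    role="user",
    content=question,
    timestamp=time.time(),
    token_count=count_tokens(question),
))

# Once accumulated messages overflow the window, _compress() folds the oldest
# ones into wm.summary, which build_prompt() injects as a system note.
messages_for_llm = wm.build_prompt()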

Short-term Memory: Session-level Cache

import json
import time
from typing import Any

class ShortTermMemory:
    """Session-scoped short-term memory backed by Redis."""

    def __init__(self, redis_client, session_ttl: int = 3600):
        self.redis = redis_client
        self.session_ttl = session_ttl  # expires after one hour by default

    async def remember(self, session_id: str, key: str, value: Any):
        """存入短期记忆"""
        redis_key = f"stm:{session_id}:{key}"
        await self.redis.setex(
            redis_key,
            self.session_ttl,
            json.dumps(value, ensure_ascii=False),
        )

    async def recall(self, session_id: str, key: str) -> Any:
        """从短期记忆中检索"""
        redis_key = f"stm:{session_id}:{key}"
        value = await self.redis.get(redis_key)
        return json.loads(value) if value else None

    async def remember_preference(self, session_id: str, pref: dict):
        """记住用户本次会话的偏好"""
        existing = await self.recall(session_id, "preferences") or {}
        existing.update(pref)
        await self.remember(session_id, "preferences", existing)

    async def get_session_context(self, session_id: str) -> dict:
        """获取会话的完整短期记忆"""
        pattern = f"stm:{session_id}:*"
        keys = []
        async for key in self.redis.scan_iter(match=pattern):
            keys.append(key)

        context = {}
        if keys:
            values = await self.redis.mget(keys)
            for key, value in zip(keys, values):
                short_key = key.decode().split(":")[-1]
                context[short_key] = json.loads(value) if value else None

        return context
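
A usage sketch, assuming redis-py's asyncio client (the URL is a placeholder):

import asyncio
import redis.asyncio as aioredis

async def demo():
    stm = ShortTermMemory(aioredis.from_url("redis://localhost:6379"))
    await stm.remember_preference("sess-1", {"units": "metric"})
    await stm.remember("sess-1", "current_task", "trip planning")
    print(await stm.get_session_context("sess-1"))
    # -> {'current_task': 'trip planning', 'preferences': {'units': 'metric'}}

asyncio.run(demo())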

Long-term Memory: Vector Database + KV Store

Semantic Memory (Facts and Knowledge)

import time
import uuid

from qdrant_client import AsyncQdrantClient
from qdrant_client.models import PointStruct, Filter, FieldCondition, MatchValue

class SemanticMemory:
    """Long-term semantic memory backed by a vector database."""

    def __init__(self, qdrant: AsyncQdrantClient, embedding_model):
        self.qdrant = qdrant
        self.embed = embedding_model
        self.collection = "agent_semantic_memory"

    async def store(self, user_id: str, content: str,
                    category: str, importance: float = 0.5):
        """存储一条语义记忆"""
        vector = await self.embed.encode(content)

        point = PointStruct(
            id=str(uuid.uuid4()),
            vector=vector,
            payload={
                "user_id": user_id,
                "content": content,
                "category": category,       # fact / preference / skill
                "importance": importance,    # 0-1
                "access_count": 0,
                "created_at": time.time(),
                "last_accessed": time.time(),
            },
        )

        await self.qdrant.upsert(
            collection_name=self.collection,
            points=[point],
        )

    async def recall(self, user_id: str, query: str,
                     limit: int = 5, category: str | None = None) -> list[dict]:
        """Retrieve relevant memories via semantic search."""
        vector = await self.embed.encode(query)

        filters = [FieldCondition(key="user_id", match=MatchValue(value=user_id))]
        if category:
            filters.append(
                FieldCondition(key="category", match=MatchValue(value=category))
            )

        results = await self.qdrant.query_points(
            collection_name=self.collection,
            query=vector,
            query_filter=Filter(must=filters),
            limit=limit,
        )

        memories = []
        for point in results.points:
            # Update access count and recency
            await self._update_access(point.id, point.payload.get("access_count", 0))

            memories.append({
                "content": point.payload["content"],
                "category": point.payload["category"],
                "importance": point.payload["importance"],
                "relevance": point.score,
            })

        return memories

    async def _update_access(self, point_id: str, access_count: int):
        """Update access statistics (these feed the forgetting policy)."""
        await self.qdrant.set_payload(
            collection_name=self.collection,
            payload={
                "last_accessed": time.time(),
                "access_count": access_count + 1,
            },
            points=[point_id],
        )

    async def forget(self, user_id: str, min_age_days: int = 90,
                     max_importance: float = 0.3):
        """遗忘机制:删除老旧且不重要的记忆"""
        cutoff = time.time() - min_age_days * 86400

        # Find deletion candidates (first page only; paginate via the returned offset in production)
        results = await self.qdrant.scroll(
            collection_name=self.collection,
            scroll_filter=Filter(must=[
                FieldCondition(key="user_id", match=MatchValue(value=user_id)),
            ]),
            limit=1000,
        )

        to_delete = []
        for point in results[0]:
            payload = point.payload
            if (payload["last_accessed"] < cutoff
                    and payload["importance"] < max_importance
                    and payload["access_count"] < 3):
                to_delete.append(point.id)

        if to_delete:
            await self.qdrant.delete(
                collection_name=self.collection,
                points_selector=to_delete,
            )

        return len(to_delete)
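
The class above assumes its collection already exists. A one-time setup sketch; the 1536-dimension size assumes an OpenAI-style embedding model and should match your embedder:

from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, VectorParams

async def init_memory_collections(url: str = "http://localhost:6333"):
    qdrant = AsyncQdrantClient(url=url)
    # Covers both the semantic and episodic collections used in this article
    for name in ("agent_semantic_memory", "agent_episodes"):
        if not await qdrant.collection_exists(name):
            await qdrant.create_collection(
                collection_name=name,
                vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
            )
    return qdrant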

User Profile Memory

class UserProfile:
    """用户画像的长期记忆"""

    def __init__(self, kv_store):
        self.store = kv_store

    async def update_from_conversation(self, user_id: str, messages: list[dict]):
        """从对话中自动提取用户画像信息"""
        profile = await self.get(user_id) or {}

        # Use an LLM to extract user traits
        extraction_prompt = f"""Analyze the following conversation and extract user traits.
Output JSON only, with these fields (fill in only the ones you can determine):

{{
    "language_preference": "the user's preferred language",
    "expertise_level": "beginner/intermediate/expert",
    "interests": ["list of interest areas"],
    "communication_style": "formal/casual/technical",
    "timezone_hint": "likely timezone",
    "frequently_asked_topics": ["frequently asked topics"]
}}

Conversation:
{json.dumps(messages[-10:], ensure_ascii=False)}
"""

        extracted = await call_llm_json(extraction_prompt)

        # Merge into the existing profile (new scalars overwrite; lists are merged)
        for key, value in extracted.items():
            if value is not None:
                if isinstance(value, list) and key in profile:
                    # List fields: merge and de-duplicate
                    existing = set(profile.get(key, []))
                    existing.update(value)
                    profile[key] = list(existing)
                else:
                    profile[key] = value

        profile["last_updated"] = time.time()
        await self.store.set(f"profile:{user_id}", json.dumps(profile))

        return profile

    async def get(self, user_id: str) -> dict:
        data = await self.store.get(f"profile:{user_id}")
        return json.loads(data) if data else {}
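
UserProfile only needs a kv_store exposing async get/set. A minimal Redis-backed adapter sketch (RedisKV is our name, not a library class):

class RedisKV:
    """Adapter exposing the async get/set interface UserProfile expects."""

    def __init__(self, redis_client):
        self.redis = redis_client

    async def get(self, key: str):
        value = await self.redis.get(key)
        return value.decode() if value is not None else None

    async def set(self, key: str, value: str):
        await self.redis.set(key, value)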

Episodic Memory: Learning from Experience

@dataclass
class Episode:
    """一次完整的任务执行经历"""
    episode_id: str
    user_id: str
    task_description: str
    messages: list[dict]
    tool_calls: list[dict]
    outcome: str             # success / failure / partial
    duration_seconds: float
    key_decisions: list[str]
    lessons_learned: str
    timestamp: float

class EpisodicMemory:
    """情景记忆:记录和检索历史经验"""

    def __init__(self, qdrant: AsyncQdrantClient, embedding_model):
        self.qdrant = qdrant
        self.embed = embedding_model
        self.collection = "agent_episodes"

    async def record_episode(self, episode: Episode):
        """Record one complete task episode."""
        # Embed a summary of the episode for later retrieval
        summary = (
            f"Task: {episode.task_description}\n"
            f"Outcome: {episode.outcome}\n"
            f"Decisions: {'; '.join(episode.key_decisions)}\n"
            f"Lessons: {episode.lessons_learned}"
        )
        vector = await self.embed.encode(summary)

        point = PointStruct(
            id=episode.episode_id,  # must be a UUID string or unsigned int (Qdrant requirement)
            vector=vector,
            payload={
                "user_id": episode.user_id,
                "task": episode.task_description,
                "outcome": episode.outcome,
                "decisions": episode.key_decisions,
                "lessons": episode.lessons_learned,
                "tool_calls": [tc["name"] for tc in episode.tool_calls],
                "duration": episode.duration_seconds,
                "timestamp": episode.timestamp,
            },
        )

        await self.qdrant.upsert(
            collection_name=self.collection,
            points=[point],
        )

    async def recall_similar_experiences(
        self, task_description: str, user_id: str, limit: int = 3
    ) -> list[dict]:
        """检索类似任务的历史经验"""
        vector = await self.embed.encode(task_description)

        results = await self.qdrant.query_points(
            collection_name=self.collection,
            query=vector,
            query_filter=Filter(must=[
                FieldCondition(key="user_id", match=MatchValue(value=user_id)),
            ]),
            limit=limit,
        )

        experiences = []
        for point in results.points:
            experiences.append({
                "task": point.payload["task"],
                "outcome": point.payload["outcome"],
                "decisions": point.payload["decisions"],
                "lessons": point.payload["lessons"],
                "relevance": point.score,
            })

        return experiences

    async def generate_experience_prompt(
        self, task: str, user_id: str
    ) -> str:
        """根据历史经验生成辅助提示"""
        experiences = await self.recall_similar_experiences(task, user_id)

        if not experiences:
            return ""

        prompt = "\n## Relevant Past Experiences\n"
        for i, exp in enumerate(experiences, 1):
            prompt += f"\n### Experience {i} (relevance: {exp['relevance']:.2f})\n"
            prompt += f"- Task: {exp['task']}\n"
            prompt += f"- Outcome: {exp['outcome']}\n"
            prompt += f"- Key decisions: {', '.join(exp['decisions'])}\n"
            prompt += f"- Lessons: {exp['lessons']}\n"

        prompt += (
            "\nUse these past experiences to inform your approach. "
            "Avoid repeating past mistakes and leverage successful strategies.\n"
        )

        return prompt
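
A recording sketch showing how an agent loop might feed this class; every field value below is illustrative:

# inside an async context:
episode = Episode(
    episode_id=str(uuid.uuid4()),
    user_id="u-42",
    task_description="Book a refundable flight to Tokyo under $900",
    messages=[],  # the full transcript would go here
    tool_calls=[{"name": "search_flights"}, {"name": "book_flight"}],
    outcome="success",
    duration_seconds=42.5,
    key_decisions=["Chose a one-stop itinerary to stay under budget"],
    lessons_learned="Filter by refundable fares first to avoid re-searching",
    timestamp=time.time(),
)
await episodic_memory.record_episode(episode)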

The Memory Orchestration Layer

class MemoryOrchestrator:
    """Coordinate all memory subsystems behind a single interface."""

    def __init__(self, working, short_term, semantic, episodic, profile):
        self.working = working
        self.short_term = short_term
        self.semantic = semantic
        self.episodic = episodic
        self.profile = profile
        # Keep the base prompt separate so the memory sections below are
        # rebuilt fresh each turn instead of accumulating across calls
        self.base_system_prompt = working.system_prompt

    async def prepare_context(
        self, user_id: str, session_id: str, user_message: str
    ) -> list[dict]:
        """为每次 LLM 调用准备完整的记忆上下文"""

        # 1. Retrieve relevant long-term memories
        semantic_memories = await self.semantic.recall(
            user_id, user_message, limit=5
        )

        # 2. Retrieve experiences from similar past tasks
        experience_prompt = await self.episodic.generate_experience_prompt(
            user_message, user_id
        )

        # 3. Load the user profile
        user_profile = await self.profile.get(user_id)

        # 4. Load the session's short-term memory
        session_context = await self.short_term.get_session_context(session_id)

        # 5. Assemble the system prompt
        system_parts = [self.base_system_prompt or "You are a helpful assistant."]

        if user_profile:
            system_parts.append(
                f"\n## User Profile\n{json.dumps(user_profile, ensure_ascii=False)}"
            )

        if semantic_memories:
            memory_text = "\n".join(
                f"- [{m['category']}] {m['content']}" for m in semantic_memories
            )
            system_parts.append(f"\n## Relevant Knowledge\n{memory_text}")

        if experience_prompt:
            system_parts.append(experience_prompt)

        if session_context:
            system_parts.append(
                f"\n## Session Context\n"
                f"{json.dumps(session_context, ensure_ascii=False)}"
            )

        # 6. Build the final prompt
        self.working.system_prompt = "\n".join(system_parts)
        self.working.add_message(Message(
            role="user",
            content=user_message,
            timestamp=time.time(),
            token_count=count_tokens(user_message),
        ))

        return self.working.build_prompt()

    async def post_response(
        self, user_id: str, session_id: str,
        user_message: str, assistant_response: str
    ):
        """响应后更新记忆系统"""

        # Update working memory
        self.working.add_message(Message(
            role="assistant",
            content=assistant_response,
            timestamp=time.time(),
            token_count=count_tokens(assistant_response),
        ))

        # Extract important information and persist it to long-term memory
        await self._extract_and_store(user_id, user_message, assistant_response)

        # Update the user profile
        messages = [
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_response},
        ]
        await self.profile.update_from_conversation(user_id, messages)

    async def _extract_and_store(self, user_id, user_msg, assistant_msg):
        """提取值得长期记忆的信息"""
        extraction_prompt = f"""分析以下对话,判断是否包含值得长期记忆的信息。
仅当包含用户的个人偏好、重要事实、明确决策时才输出。

用户: {user_msg}
助手: {assistant_msg}

若有值得记忆的信息,输出 JSON 数组:
[{{"content": "记忆内容", "category": "fact/preference/decision", "importance": 0.0-1.0}}]
若无值得记忆的信息,输出空数组 []
"""
        memories = await call_llm_json(extraction_prompt)

        for mem in memories:
            if mem.get("importance", 0) > 0.3:
                await self.semantic.store(
                    user_id=user_id,
                    content=mem["content"],
                    category=mem["category"],
                    importance=mem["importance"],
                )
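
Wiring it together, one conversational turn might look like the sketch below (call_llm_chat is an assumed wrapper around your chat-completion client):

async def handle_turn(orch: MemoryOrchestrator, user_id: str,
                      session_id: str, user_message: str) -> str:
    # Gather memories from every layer and build the prompt
    prompt = await orch.prepare_context(user_id, session_id, user_message)

    # Call the model (assumed helper)
    response = await call_llm_chat(prompt)

    # Write back: working memory, long-term extraction, profile update
    await orch.post_response(user_id, session_id, user_message, response)
    return response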

Memory Compression and Forgetting

Progressive Compression

class MemoryCompressor:
    """随时间推移递进压缩记忆"""

    async def compress_daily(self, user_id: str, day_messages: list):
        """Daily compression: keep the key facts."""
        summary = await call_llm(
            f"Summarize the key facts and decisions from today's conversations. "
            f"Focus on actionable information.\n\n"
            f"{format_messages(day_messages)}",  # format_messages: assumed helper that renders messages as text
            model="gpt-4o-mini",
        )
        return summary

    async def compress_weekly(self, user_id: str, daily_summaries: list):
        """Weekly compression: extract trends and patterns."""
        joined = "\n".join(daily_summaries)
        summary = await call_llm(
            f"Analyze these daily summaries and extract:\n"
            f"1. Recurring themes\n"
            f"2. Evolving preferences\n"
            f"3. Key milestones\n\n"
            f"{joined}",
            model="gpt-4o-mini",
        )
        return summary

    async def compress_monthly(self, user_id: str, weekly_summaries: list):
        """Monthly compression: distill into user-profile updates."""
        joined = "\n".join(weekly_summaries)
        update = await call_llm(
            f"Based on these weekly summaries, extract stable user characteristics "
            f"and long-term patterns:\n\n"
            f"{joined}",
            model="gpt-4o-mini",
        )
        return update
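
How these levels get invoked is left open above; one plausible wiring is a nightly job per active user, where fetch_messages_since, get_daily_summaries, and the append_* methods are hypothetical storage helpers:

async def nightly_compression(compressor: MemoryCompressor, store, user_id: str):
    # Day -> daily summary
    day_messages = await store.fetch_messages_since(user_id, hours=24)
    if day_messages:
        daily = await compressor.compress_daily(user_id, day_messages)
        await store.append_daily_summary(user_id, daily)

    # Roll the last week's daily summaries up once enough have accumulated
    dailies = await store.get_daily_summaries(user_id, days=7)
    if len(dailies) >= 7:
        weekly = await compressor.compress_weekly(user_id, dailies)
        await store.append_weekly_summary(user_id, weekly)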

Summary

Core principles of agent memory system design:

  1. Layered storage: working memory (the context window), short-term memory (Redis), and long-term memory (a vector database) each play a distinct role
  2. Smart retrieval: semantic search instead of simple chronological reads, so the most relevant memories surface first
  3. Active forgetting: memories that are unimportant and rarely accessed should be cleaned up, preventing noise from accumulating
  4. Learning from experience: episodic memory lets the agent learn from past successes and failures instead of repeating mistakes
  5. Privacy protection: memories are isolated per user, sensitive data is stored encrypted, and users can delete their data (see the sketch below)
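
As a sketch of point 5, a user-initiated deletion has to fan out across every store; FilterSelector is qdrant-client's filter-based delete selector:

from qdrant_client.models import FilterSelector

async def delete_user_memories(user_id: str, qdrant: AsyncQdrantClient, kv_store):
    # Vector memories (semantic + episodic)
    user_filter = Filter(must=[
        FieldCondition(key="user_id", match=MatchValue(value=user_id)),
    ])
    for collection in ("agent_semantic_memory", "agent_episodes"):
        await qdrant.delete(
            collection_name=collection,
            points_selector=FilterSelector(filter=user_filter),
        )

    # Profile in the KV store (the RedisKV sketch above has no delete,
    # so overwrite with an empty profile)
    await kv_store.set(f"profile:{user_id}", "{}")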

Maurice | maurice_wen@proton.me