diff --git a/assets/long-term-memory.png b/assets/long-term-memory.png new file mode 100644 index 00000000..309ed22c Binary files /dev/null and b/assets/long-term-memory.png differ diff --git a/assets/memory-agents.png b/assets/memory-agents.png new file mode 100644 index 00000000..7d0249f4 Binary files /dev/null and b/assets/memory-agents.png differ diff --git a/assets/short-term-memory.png b/assets/short-term-memory.png new file mode 100644 index 00000000..41759488 Binary files /dev/null and b/assets/short-term-memory.png differ diff --git a/python-recipes/agents/03_memory_agent.ipynb b/python-recipes/agents/03_memory_agent.ipynb new file mode 100644 index 00000000..c701e6c8 --- /dev/null +++ b/python-recipes/agents/03_memory_agent.ipynb @@ -0,0 +1,1607 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Redis](https://redis.io/wp-content/uploads/2024/04/Logotype.svg?auto=webp&quality=85,75&width=120)\n", + "\n", + "# Agent Memory Using Redis and LangGraph\n", + "This notebook demonstrates how to manage short-term and long-term agent memory using Redis and LangGraph. We'll explore:\n", + "\n", + "1. Short-term memory management using LangGraph's checkpointer\n", + "2. Long-term memory storage and retrieval using RedisVL\n", + "3. Managing long-term memory manually vs. exposing tool access (AKA function-calling)\n", + "4. Managing conversation history size with summarization\n", + "5. Memory consolidation\n", + "\n", + "\n", + "## What We'll Build\n", + "\n", + "We're going to build two versions of a travel agent, one that manages long-term\n", + "memory manually and one that does so using tools the LLM calls.\n", + "\n", + "Here are two diagrams showing the components used in both agents:\n", + "\n", + "![diagram](../../assets/memory-agents.png)\n", + "\n", + "## Let's Begin!\n", + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "### Packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -q langchain-openai langgraph-checkpoint langgraph-checkpoint-redis \"langchain-community>=0.2.11\" tavily-python langchain-redis pydantic ulid" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Required API Keys\n", + "\n", + "You must add an OpenAI API key with billing information for this lesson. You will also need\n", + "a Tavily API key. Tavily API keys come with free credits at the time of this writing." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# NBVAL_SKIP\n", + "import getpass\n", + "import os\n", + "\n", + "\n", + "def _set_env(key: str):\n", + " if key not in os.environ:\n", + " os.environ[key] = getpass.getpass(f\"{key}:\")\n", + "\n", + "\n", + "_set_env(\"OPENAI_API_KEY\")\n", + "\n", + "# Uncomment this if you have a Tavily API key and want to\n", + "# use the web search tool.\n", + "# _set_env(\"TAVILY_API_KEY\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run redis\n", + "\n", + "### For colab\n", + "\n", + "Convert the following cell to Python to run it in Colab." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "%%sh\n", + "# Exit if this is not running in Colab\n", + "if [ -z \"$COLAB_RELEASE_TAG\" ]; then\n", + " exit 0\n", + "fi\n", + "\n", + "curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg\n", + "echo \"deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main\" | sudo tee /etc/apt/sources.list.d/redis.list\n", + "sudo apt-get update > /dev/null 2>&1\n", + "sudo apt-get install redis-stack-server > /dev/null 2>&1\n", + "redis-stack-server --daemonize yes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### For Alternative Environments\n", + "There are many ways to get the necessary redis-stack instance running\n", + "1. On cloud, deploy a [FREE instance of Redis in the cloud](https://redis.com/try-free/). Or, if you have your\n", + "own version of Redis Enterprise running, that works too!\n", + "2. Per OS, [see the docs](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/)\n", + "3. With docker: `docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest`\n", + "\n", + "## Test connection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from redis import Redis\n", + "\n", + "# Use the environment variable if set, otherwise default to localhost\n", + "REDIS_URL = os.getenv(\"REDIS_URL\", \"redis://localhost:6379\")\n", + "\n", + "redis_client = Redis.from_url(REDIS_URL)\n", + "redis_client.ping()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Short-Term vs. Long-Term Memory\n", + "\n", + "The agent uses **short-term memory** and **long-term memory**. The implementations\n", + "of short-term and long-term memory differ, as does how the agent uses them. Let's\n", + "dig into the details. We'll return to code soon!\n", + "\n", + "### Short-Term Memory\n", + "\n", + "For short-term memory, the agent keeps track of conversation history with Redis.\n", + "Because this is a LangGraph agent, we use the `RedisSaver` class to achieve\n", + "this. `RedisSaver` is what LangGraph refers to as a _checkpointer_. You can read\n", + "more about checkpointers in the [LangGraph\n", + "documentation](https://langchain-ai.github.io/langgraph/concepts/persistence/).\n", + "In short, they store state for each node in the graph, which for this agent\n", + "includes conversation history.\n", + "\n", + "Here's a diagram showing how the agent uses Redis for short-term memory. Each node\n", + "in the graph (Retrieve Users, Respond, Summarize Conversation) persists its \"state\"\n", + "to Redis. The state object contains the agent's message conversation history for\n", + "the current thread.\n", + "\n", + "\n", + "\n", + "If Redis persistence is on, then Redis will persist short-term memory to\n", + "disk. This means if you quit the agent and return with the same thread ID and\n", + "user ID, you'll resume the same conversation.\n", + "\n", + "Conversation histories can grow long and pollute an LLM's context window. To manage\n", + "this, after every \"turn\" of a conversation, the agent summarizes messages when the\n", + "conversation grows past a configurable threshold. 
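(In this notebook the threshold is 10 messages; once the history reaches it, the older messages are replaced with a single summary message, keeping only the most recent message verbatim.) 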
Checkpointers do not do this by\n", + "default, so we've created a node in the graph for summarization.\n", + "\n", + "**NOTE**: We'll see example code for the summarization node later in this notebook.\n", + "\n", + "### Long-Term Memory\n", + "\n", + "Aside from conversation history, the agent stores long-term memories in a search\n", + "index in Redis, using [RedisVL](https://docs.redisvl.com/en/latest/). Here's a\n", + "diagram showing the components involved:\n", + "\n", + "\n", + "\n", + "The agent tracks two types of long-term memories:\n", + "\n", + "- **Episodic**: User-specific experiences and preferences\n", + "- **Semantic**: General knowledge about travel destinations and requirements\n", + "\n", + "**NOTE** If you're familiar with the [CoALA\n", + "paper](https://arxiv.org/abs/2309.02427), the terms \"episodic\" and \"semantic\"\n", + "here map to the same concepts in the paper. CoALA discusses a third type of\n", + "memory, _procedural_. In our example, we consider logic encoded in Python in the\n", + "agent codebase to be its procedural memory.\n", + "\n", + "### Representing Long-Term Memory in Python\n", + "We use a couple of Pydantic models to represent long-term memories, both before\n", + "and after they're stored in Redis:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "from enum import Enum\n", + "from typing import List, Optional\n", + "\n", + "from pydantic import BaseModel, Field\n", + "import ulid\n", + "\n", + "\n", + "class MemoryType(str, Enum):\n", + " \"\"\"\n", + " The type of a long-term memory.\n", + "\n", + " EPISODIC: User specific experiences and preferences\n", + "\n", + " SEMANTIC: General knowledge on top of the user's preferences and LLM's\n", + " training data.\n", + " \"\"\"\n", + "\n", + " EPISODIC = \"episodic\"\n", + " SEMANTIC = \"semantic\"\n", + "\n", + "\n", + "class Memory(BaseModel):\n", + " \"\"\"Represents a single long-term memory.\"\"\"\n", + "\n", + " content: str\n", + " memory_type: MemoryType\n", + " metadata: str\n", + " \n", + " \n", + "class Memories(BaseModel):\n", + " \"\"\"\n", + " A list of memories extracted from a conversation by an LLM.\n", + "\n", + " NOTE: OpenAI's structured output requires us to wrap the list in an object.\n", + " \"\"\"\n", + "\n", + " memories: List[Memory]\n", + "\n", + "\n", + "class StoredMemory(Memory):\n", + " \"\"\"A stored long-term memory\"\"\"\n", + "\n", + " id: str # The redis key\n", + " memory_id: ulid.ULID = Field(default_factory=lambda: ulid.ULID())\n", + " created_at: datetime = Field(default_factory=datetime.now)\n", + " user_id: Optional[str] = None\n", + " thread_id: Optional[str] = None\n", + " memory_type: Optional[MemoryType] = None\n", + " \n", + " \n", + "class MemoryStrategy(str, Enum):\n", + " \"\"\"\n", + " Supported strategies for managing long-term memory.\n", + " \n", + " This notebook supports two strategies for working with long-term memory:\n", + "\n", + " TOOLS: The LLM decides when to store and retrieve long-term memories, using\n", + " tools (AKA, function-calling) to do so.\n", + "\n", + " MANUAL: The agent manually retrieves long-term memories relevant to the\n", + " current conversation before sending every message and analyzes every\n", + " response to extract memories to store.\n", + "\n", + " NOTE: In both cases, the agent runs a background thread to consolidate\n", + " memories, and a workflow step to summarize conversations after the history\n", + " 
grows past a threshold.\n", + " \"\"\"\n", + "\n", + " TOOLS = \"tools\"\n", + " MANUAL = \"manual\"\n", + " \n", + " \n", + "# By default, we'll use the manual strategy\n", + "memory_strategy = MemoryStrategy.MANUAL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll return to these models soon to see them in action!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Short-Term Memory Storage and Retrieval\n", + "\n", + "The `RedisSaver` class handles the basics of short-term memory storage for us,\n", + "so we don't need to do anything here." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Long-Term Memory Storage and Retrieval\n", + "\n", + "We use RedisVL to store and retrieve long-term memories with vector embeddings.\n", + "This allows for semantic search of past experiences and knowledge.\n", + "\n", + "Let's set up a new search index to store and query memories:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.index import SearchIndex\n", + "from redisvl.schema.schema import IndexSchema\n", + "\n", + "# Define schema for long-term memory index\n", + "memory_schema = IndexSchema.from_dict({\n", + " \"index\": {\n", + " \"name\": \"agent_memories\",\n", + " \"prefix\": \"memory:\",\n", + " \"key_separator\": \":\",\n", + " \"storage_type\": \"json\",\n", + " },\n", + " \"fields\": [\n", + " {\"name\": \"content\", \"type\": \"text\"},\n", + " {\"name\": \"memory_type\", \"type\": \"tag\"},\n", + " {\"name\": \"metadata\", \"type\": \"text\"},\n", + " {\"name\": \"created_at\", \"type\": \"text\"},\n", + " {\"name\": \"user_id\", \"type\": \"tag\"},\n", + " {\"name\": \"memory_id\", \"type\": \"tag\"},\n", + " {\n", + " \"name\": \"embedding\",\n", + " \"type\": \"vector\",\n", + " \"attrs\": {\n", + " \"algorithm\": \"flat\",\n", + " \"dims\": 1536, # OpenAI embedding dimension\n", + " \"distance_metric\": \"cosine\",\n", + " \"datatype\": \"float32\",\n", + " },\n", + " },\n", + " ],\n", + " }\n", + ")\n", + "\n", + "# Create search index\n", + "try:\n", + " long_term_memory_index = SearchIndex(\n", + " schema=memory_schema, redis_client=redis_client, overwrite=True\n", + " )\n", + " long_term_memory_index.create()\n", + " print(\"Long-term memory index ready\")\n", + "except Exception as e:\n", + " print(f\"Error creating index: {e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Storage and Retrieval Functions\n", + "\n", + "Now that we have a search index in Redis, we can write functions to store and\n", + "retrieve memories. We can use RedisVL to write these.\n", + "\n", + "First, we'll write a utility function to check if a memory similar to a given\n", + "memory already exists in the index. 
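The check embeds the candidate memory and runs a vector range query against the index; anything within a small cosine distance of an existing memory (0.1 by default) is treated as already stored. 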
Later, we can use this to avoid storing\n", + "duplicate memories.\n", + "\n", + "#### Checking for Similar Memories" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "from redisvl.query import VectorRangeQuery\n", + "from redisvl.query.filter import Tag\n", + "from redisvl.utils.vectorize.text.openai import OpenAITextVectorizer\n", + "\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "# If we have any memories that aren't associated with a user, we'll use this ID.\n", + "SYSTEM_USER_ID = \"system\"\n", + "\n", + "openai_embed = OpenAITextVectorizer(model=\"text-embedding-ada-002\")\n", + "\n", + "# Change this to MemoryStrategy.TOOLS to use function-calling to store and\n", + "# retrieve memories.\n", + "memory_strategy = MemoryStrategy.MANUAL\n", + "\n", + "\n", + "def similar_memory_exists(\n", + " content: str,\n", + " memory_type: MemoryType,\n", + " user_id: str = SYSTEM_USER_ID,\n", + " thread_id: Optional[str] = None,\n", + " distance_threshold: float = 0.1,\n", + ") -> bool:\n", + " \"\"\"Check if a similar long-term memory already exists in Redis.\"\"\"\n", + " query_embedding = openai_embed.embed(content)\n", + " filters = (Tag(\"user_id\") == user_id) & (Tag(\"memory_type\") == memory_type)\n", + " if thread_id:\n", + " filters = filters & (Tag(\"thread_id\") == thread_id)\n", + "\n", + " # Search for similar memories\n", + " vector_query = VectorRangeQuery(\n", + " vector=query_embedding,\n", + " num_results=1,\n", + " vector_field_name=\"embedding\",\n", + " filter_expression=filters,\n", + " distance_threshold=distance_threshold,\n", + " return_fields=[\"id\"],\n", + " )\n", + " results = long_term_memory_index.query(vector_query)\n", + " logger.debug(f\"Similar memory search results: {results}\")\n", + "\n", + " if results:\n", + " logger.debug(\n", + " f\"{len(results)} similar {'memory' if len(results) == 1 else 'memories'} found. First: \"\n", + " f\"{results[0]['id']}. 
Skipping storage.\"\n", + " )\n", + " return True\n", + "\n", + " return False\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Storing and Retrieving Long-Term Memories" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll use the `similar_memory_exists()` function when we store memories:" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "from datetime import datetime\n", + "from typing import List, Optional, Union\n", + "\n", + "import ulid\n", + "\n", + "\n", + "def store_memory(\n", + " content: str,\n", + " memory_type: MemoryType,\n", + " user_id: str = SYSTEM_USER_ID,\n", + " thread_id: Optional[str] = None,\n", + " metadata: Optional[str] = None,\n", + "):\n", + " \"\"\"Store a long-term memory in Redis, avoiding duplicates.\"\"\"\n", + " if metadata is None:\n", + " metadata = \"{}\"\n", + "\n", + " logger.info(f\"Preparing to store memory: {content}\")\n", + "\n", + " if similar_memory_exists(content, memory_type, user_id, thread_id):\n", + " logger.info(\"Similar memory found, skipping storage\")\n", + " return\n", + "\n", + " embedding = openai_embed.embed(content)\n", + "\n", + " memory_data = {\n", + " \"user_id\": user_id or SYSTEM_USER_ID,\n", + " \"content\": content,\n", + " \"memory_type\": memory_type.value,\n", + " \"metadata\": metadata,\n", + " \"created_at\": datetime.now().isoformat(),\n", + " \"embedding\": embedding,\n", + " \"memory_id\": str(ulid.ULID()),\n", + " \"thread_id\": thread_id,\n", + " }\n", + "\n", + " try:\n", + " long_term_memory_index.load([memory_data])\n", + " except Exception as e:\n", + " logger.error(f\"Error storing memory: {e}\")\n", + " return\n", + "\n", + " logger.info(f\"Stored {memory_type} memory: {content}\")\n", + " \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now that we're storing memories, we can retrieve them:" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "def retrieve_memories(\n", + " query: str,\n", + " memory_type: Union[Optional[MemoryType], List[MemoryType]] = None,\n", + " user_id: str = SYSTEM_USER_ID,\n", + " thread_id: Optional[str] = None,\n", + " distance_threshold: float = 0.1,\n", + " limit: int = 5,\n", + ") -> List[StoredMemory]:\n", + " \"\"\"Retrieve relevant memories from Redis\"\"\"\n", + " # Create vector query\n", + " logger.debug(f\"Retrieving memories for query: {query}\")\n", + " vector_query = VectorRangeQuery(\n", + " vector=openai_embed.embed(query),\n", + " return_fields=[\n", + " \"content\",\n", + " \"memory_type\",\n", + " \"metadata\",\n", + " \"created_at\",\n", + " \"memory_id\",\n", + " \"thread_id\",\n", + " \"user_id\",\n", + " ],\n", + " num_results=limit,\n", + " vector_field_name=\"embedding\",\n", + " dialect=2,\n", + " distance_threshold=distance_threshold,\n", + " )\n", + "\n", + " base_filters = [f\"@user_id:{{{user_id or SYSTEM_USER_ID}}}\"]\n", + "\n", + " if memory_type:\n", + " if isinstance(memory_type, list):\n", + " base_filters.append(f\"@memory_type:{{{'|'.join(memory_type)}}}\")\n", + " else:\n", + " base_filters.append(f\"@memory_type:{{{memory_type.value}}}\")\n", + "\n", + " if thread_id:\n", + " base_filters.append(f\"@thread_id:{{{thread_id}}}\")\n", + "\n", + " vector_query.set_filter(\" \".join(base_filters))\n", + "\n", + " # Execute search\n", + " results = long_term_memory_index.query(vector_query)\n", + "\n", + " # Parse 
results\n", + " memories = []\n", + " for doc in results:\n", + " try:\n", + " memory = StoredMemory(\n", + " id=doc[\"id\"],\n", + " memory_id=doc[\"memory_id\"],\n", + " user_id=doc[\"user_id\"],\n", + " thread_id=doc.get(\"thread_id\", None),\n", + " memory_type=MemoryType(doc[\"memory_type\"]),\n", + " content=doc[\"content\"],\n", + " created_at=doc[\"created_at\"],\n", + " metadata=doc[\"metadata\"],\n", + " )\n", + " memories.append(memory)\n", + " except Exception as e:\n", + " logger.error(f\"Error parsing memory: {e}\")\n", + " continue\n", + " return memories" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Managing Long-Term Memory Manually vs. Calling Tools\n", + "\n", + "While making LLM queries, agents can store and retrieve relevant long-term\n", + "memories in one of two ways (and more, but these are the two we'll discuss):\n", + "\n", + "1. Expose memory retrieval and storage as \"tools\" that the LLM can decide to call contextually.\n", + "2. Manually augment prompts with relevant memories, and manually extract and store relevant memories.\n", + "\n", + "These approaches both have tradeoffs.\n", + "\n", + "**Tool-calling** leaves the decision to store a memory or find relevant memories\n", + "up to the LLM. This can add latency to requests. It will generally result in\n", + "fewer calls to Redis but will also sometimes miss out on retrieving potentially\n", + "relevant context and/or extracting relevant memories from a conversation.\n", + "\n", + "**Manual memory management** will result in more calls to Redis but will produce\n", + "fewer round-trip LLM requests, reducing latency. Manually extracting memories\n", + "will generally extract more memories than tool calls, which will store more data\n", + "in Redis and should result in more context added to LLM requests. More context\n", + "means more contextual awareness but also higher token spend.\n", + "\n", + "You can test both approaches with this agent by changing the `memory_strategy`\n", + "variable.\n", + "\n", + "## Managing Memory Manually\n", + "With the manual memory management strategy, we're going to extract memories after\n", + "every interaction between the user and the agent. 
We're then going to retrieve\n", + "those memories during future interactions before we send the query.\n", + "\n", + "### Extracting Memories\n", + "We'll call this `extract_memories` function manually after each interaction:" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.messages import HumanMessage\n", + "from langchain_core.runnables.config import RunnableConfig\n", + "from langchain_openai import ChatOpenAI\n", + "from langgraph.graph.message import MessagesState\n", + "\n", + "\n", + "class RuntimeState(MessagesState):\n", + " \"\"\"Agent state (just messages for now)\"\"\"\n", + "\n", + " pass\n", + "\n", + "\n", + "memory_llm = ChatOpenAI(model=\"gpt-4o\", temperature=0.3).with_structured_output(\n", + " Memories\n", + ")\n", + "\n", + "\n", + "def extract_memories(\n", + " last_processed_message_id: Optional[str],\n", + " state: RuntimeState,\n", + " config: RunnableConfig,\n", + ") -> Optional[str]:\n", + " \"\"\"Extract and store memories in long-term memory\"\"\"\n", + " logger.debug(f\"Last message ID is: {last_processed_message_id}\")\n", + "\n", + " if len(state[\"messages\"]) < 3: # Need at least a user message and agent response\n", + " logger.debug(\"Not enough messages to extract memories\")\n", + " return last_processed_message_id\n", + "\n", + " user_id = config.get(\"configurable\", {}).get(\"user_id\", None)\n", + " if not user_id:\n", + " logger.warning(\"No user ID found in config when extracting memories\")\n", + " return last_processed_message_id\n", + "\n", + " # Get the messages\n", + " messages = state[\"messages\"]\n", + "\n", + " # Find the newest message ID (or None if no IDs)\n", + " newest_message_id = None\n", + " for msg in reversed(messages):\n", + " if hasattr(msg, \"id\") and msg.id:\n", + " newest_message_id = msg.id\n", + " break\n", + "\n", + " logger.debug(f\"Newest message ID is: {newest_message_id}\")\n", + "\n", + " # If we've already processed up to this message ID, skip\n", + " if (\n", + " last_processed_message_id\n", + " and newest_message_id\n", + " and last_processed_message_id == newest_message_id\n", + " ):\n", + " logger.debug(f\"Already processed messages up to ID {newest_message_id}\")\n", + " return last_processed_message_id\n", + "\n", + " # Find the index of the message with last_processed_message_id\n", + " start_index = 0\n", + " if last_processed_message_id:\n", + " for i, msg in enumerate(messages):\n", + " if hasattr(msg, \"id\") and msg.id == last_processed_message_id:\n", + " start_index = i + 1 # Start processing from the next message\n", + " break\n", + "\n", + " # Check if there are messages to process\n", + " if start_index >= len(messages):\n", + " logger.debug(\"No new messages to process since last processed message\")\n", + " return newest_message_id\n", + "\n", + " # Get only the messages after the last processed message\n", + " messages_to_process = messages[start_index:]\n", + "\n", + " # If there are not enough messages to process, include some context\n", + " if len(messages_to_process) < 3 and start_index > 0:\n", + " # Include up to 3 messages before the start_index for context\n", + " context_start = max(0, start_index - 3)\n", + " messages_to_process = messages[context_start:]\n", + "\n", + " # Format messages for the memory agent\n", + " message_history = \"\\n\".join(\n", + " [\n", + " f\"{'User' if isinstance(msg, HumanMessage) else 'Assistant'}: {msg.content}\"\n", + " for msg in messages_to_process\n", + " ]\n", + 
" )\n", + "\n", + " prompt = f\"\"\"\n", + " You are a long-memory manager. Your job is to analyze this message history\n", + " and extract information that might be useful in future conversations.\n", + " \n", + " Extract two types of memories:\n", + " 1. EPISODIC: Personal experiences and preferences specific to this user\n", + " Example: \"User prefers window seats\" or \"User had a bad experience in Paris\"\n", + " \n", + " 2. SEMANTIC: General facts and knowledge about travel that could be useful\n", + " Example: \"The best time to visit Japan is during cherry blossom season in April\"\n", + " \n", + " For each memory, provide:\n", + " - Type: The memory type (EPISODIC/SEMANTIC)\n", + " - Content: The actual information to store\n", + " - Metadata: Relevant tags and context (as JSON)\n", + " \n", + " IMPORTANT RULES:\n", + " 1. Only extract information that would be genuinely useful for future interactions.\n", + " 2. Do not extract procedural knowledge - that is handled by the system's built-in tools and prompts.\n", + " 3. You are a large language model, not a human - do not extract facts that you already know.\n", + " \n", + " Message history:\n", + " {message_history}\n", + " \n", + " Extracted memories:\n", + " \"\"\"\n", + "\n", + " memories_to_store: Memories = memory_llm.invoke([HumanMessage(content=prompt)]) # type: ignore\n", + "\n", + " # Store each extracted memory\n", + " for memory_data in memories_to_store.memories:\n", + " store_memory(\n", + " content=memory_data.content,\n", + " memory_type=memory_data.memory_type,\n", + " user_id=user_id,\n", + " metadata=memory_data.metadata,\n", + " )\n", + "\n", + " # Return data with the newest processed message ID\n", + " return newest_message_id" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll use this function in a background thread. We'll start the thread in manual\n", + "memory mode but not in tool mode, and we'll run it as a worker that pulls\n", + "message histories from a `Queue` to process:" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "from queue import Queue\n", + "\n", + "\n", + "DEFAULT_MEMORY_WORKER_INTERVAL = 5 * 60 # 5 minutes\n", + "DEFAULT_MEMORY_WORKER_BACKOFF_INTERVAL = 10 * 60 # 10 minutes\n", + "\n", + "\n", + "def memory_worker(\n", + " memory_queue: Queue,\n", + " user_id: str,\n", + " interval: int = DEFAULT_MEMORY_WORKER_INTERVAL,\n", + " backoff_interval: int = DEFAULT_MEMORY_WORKER_BACKOFF_INTERVAL,\n", + "):\n", + " \"\"\"Worker function that processes long-term memory extraction requests\"\"\"\n", + " key = f\"memory_worker:{user_id}:last_processed_message_id\"\n", + "\n", + " last_processed_message_id = redis_client.get(key)\n", + " logger.debug(f\"Last processed message ID: {last_processed_message_id}\")\n", + " last_processed_message_id = (\n", + " str(last_processed_message_id) if last_processed_message_id else None\n", + " )\n", + "\n", + " while True:\n", + " try:\n", + " # Get the next state and config from the queue (blocks until an item is available)\n", + " state, config = memory_queue.get()\n", + "\n", + " # Extract long-term memories from the conversation history\n", + " last_processed_message_id = extract_memories(\n", + " last_processed_message_id, state, config\n", + " )\n", + " logger.debug(\n", + " f\"Memory worker extracted memories. 
Last processed message ID: {last_processed_message_id}\"\n", + " )\n", + "\n", + " if last_processed_message_id:\n", + " logger.debug(\n", + " f\"Setting last processed message ID: {last_processed_message_id}\"\n", + " )\n", + " redis_client.set(key, last_processed_message_id)\n", + "\n", + " # Mark the task as done\n", + " memory_queue.task_done()\n", + " logger.debug(\"Memory extraction completed for queue item\")\n", + " # Wait before processing next item\n", + " time.sleep(interval)\n", + " except Exception as e:\n", + " # Wait before processing next item after an error\n", + " logger.exception(f\"Error in memory worker thread: {e}\")\n", + " time.sleep(backoff_interval)\n", + "\n", + "\n", + "# NOTE: We'll actually start the worker thread later, in the main loop." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Augmenting Queries with Relevant Memories\n", + "\n", + "For every user interaction with the agent, we'll query for relevant memories and\n", + "add them to the LLM prompt with `retrieve_relevant_memories()`.\n", + "\n", + "**NOTE:** We only run this node in the \"manual\" memory management strategy. If\n", + "using \"tools,\" the LLM will decide when to retrieve memories." + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [], + "source": [ + "def retrieve_relevant_memories(\n", + " state: RuntimeState, config: RunnableConfig\n", + ") -> RuntimeState:\n", + " \"\"\"Retrieve relevant memories based on the current conversation.\"\"\"\n", + " if not state[\"messages\"]:\n", + " logger.debug(\"No messages in state\")\n", + " return state\n", + "\n", + " latest_message = state[\"messages\"][-1]\n", + " if not isinstance(latest_message, HumanMessage):\n", + " logger.debug(\"Latest message is not a HumanMessage: \", latest_message)\n", + " return state\n", + "\n", + " user_id = config.get(\"configurable\", {}).get(\"user_id\", SYSTEM_USER_ID)\n", + "\n", + " query = str(latest_message.content)\n", + " relevant_memories = retrieve_memories(\n", + " query=query,\n", + " memory_type=[MemoryType.EPISODIC, MemoryType.SEMANTIC],\n", + " limit=5,\n", + " user_id=user_id,\n", + " distance_threshold=0.3,\n", + " )\n", + "\n", + " logger.debug(f\"All relevant memories: {relevant_memories}\")\n", + "\n", + " # We'll augment the latest human message with the relevant memories.\n", + " if relevant_memories:\n", + " memory_context = \"\\n\\n### Relevant memories from previous conversations:\\n\"\n", + "\n", + " # Group by memory type\n", + " memory_types = {\n", + " MemoryType.EPISODIC: \"User Preferences & History\",\n", + " MemoryType.SEMANTIC: \"Travel Knowledge\",\n", + " }\n", + "\n", + " for mem_type, type_label in memory_types.items():\n", + " memories_of_type = [\n", + " m for m in relevant_memories if m.memory_type == mem_type\n", + " ]\n", + " if memories_of_type:\n", + " memory_context += f\"\\n**{type_label}**:\\n\"\n", + " for mem in memories_of_type:\n", + " memory_context += f\"- {mem.content}\\n\"\n", + "\n", + " augmented_message = HumanMessage(content=f\"{query}\\n{memory_context}\")\n", + " state[\"messages\"][-1] = augmented_message\n", + "\n", + " logger.debug(f\"Augmented message: {augmented_message.content}\")\n", + "\n", + " return state.copy()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is the first function we've seen that represents a **node** in the LangGraph\n", + "graph we'll build. 
As a node representation, this function receives a `state`\n", + "object containing the runtime state of the graph, which is where conversation\n", + "history resides. Its `config` parameter contains data like the user and thread\n", + "IDs.\n", + "\n", + "This will be the starting node in the graph we'll assemble later. When a user\n", + "invokes the graph with a message, the first thing we'll do (when using the\n", + "\"manual\" memory strategy) is augment that message with potentially related\n", + "memories." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defining Tools\n", + "\n", + "Now that we have our storage functions defined, we can create **tools**. We'll\n", + "need these to set up our agent in a moment. These tools will only be used when\n", + "the agent is operating in \"tools\" memory management mode." + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.tools import tool\n", + "from typing import Dict, Optional\n", + "\n", + "\n", + "@tool\n", + "def store_memory_tool(\n", + " content: str,\n", + " memory_type: MemoryType,\n", + " metadata: Optional[Dict[str, str]] = None,\n", + " config: Optional[RunnableConfig] = None,\n", + ") -> str:\n", + " \"\"\"\n", + " Store a long-term memory in the system.\n", + "\n", + " Use this tool to save important information about user preferences,\n", + " experiences, or general knowledge that might be useful in future\n", + " interactions.\n", + " \"\"\"\n", + " config = config or RunnableConfig()\n", + " user_id = config.get(\"user_id\", SYSTEM_USER_ID)\n", + " thread_id = config.get(\"thread_id\")\n", + "\n", + " try:\n", + " # Store in long-term memory\n", + " store_memory(\n", + " content=content,\n", + " memory_type=memory_type,\n", + " user_id=user_id,\n", + " thread_id=thread_id,\n", + " metadata=str(metadata) if metadata else None,\n", + " )\n", + "\n", + " return f\"Successfully stored {memory_type} memory: {content}\"\n", + " except Exception as e:\n", + " return f\"Error storing memory: {str(e)}\"\n", + "\n", + "\n", + "@tool\n", + "def retrieve_memories_tool(\n", + " query: str,\n", + " memory_type: List[MemoryType],\n", + " limit: int = 5,\n", + " config: Optional[RunnableConfig] = None,\n", + ") -> str:\n", + " \"\"\"\n", + " Retrieve long-term memories relevant to the query.\n", + "\n", + " Use this tool to access previously stored information about user\n", + " preferences, experiences, or general knowledge.\n", + " \"\"\"\n", + " config = config or RunnableConfig()\n", + " user_id = config.get(\"user_id\", SYSTEM_USER_ID)\n", + "\n", + " try:\n", + " # Get long-term memories\n", + " stored_memories = retrieve_memories(\n", + " query=query,\n", + " memory_type=memory_type,\n", + " user_id=user_id,\n", + " limit=limit,\n", + " distance_threshold=0.3,\n", + " )\n", + "\n", + " # Format the response\n", + " response = []\n", + "\n", + " if stored_memories:\n", + " response.append(\"Long-term memories:\")\n", + " for memory in stored_memories:\n", + " response.append(f\"- [{memory.memory_type}] {memory.content}\")\n", + "\n", + " return \"\\n\".join(response) if response else \"No relevant memories found.\"\n", + "\n", + " except Exception as e:\n", + " return f\"Error retrieving memories: {str(e)}\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating the Agent\n", + "\n", + "Because we're using different LLM objects configured for different purposes and\n", + "a prebuilt ReAct 
agent, we need a node that invokes the agent and returns the\n", + "response. But before we can invoke the agent, we need to set it up. This will\n", + "involve defining the tools the agent will need." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from typing import Dict, List, Optional, Tuple, Union\n", + "\n", + "from langchain_community.tools.tavily_search import TavilySearchResults\n", + "from langchain_core.callbacks.manager import CallbackManagerForToolRun\n", + "from langchain_core.messages import AIMessage, AIMessageChunk, SystemMessage\n", + "from langgraph.prebuilt.chat_agent_executor import create_react_agent\n", + "from langgraph.checkpoint.redis import RedisSaver\n", + "\n", + "\n", + "class CachingTavilySearchResults(TavilySearchResults):\n", + " \"\"\"\n", + " An interface to Tavily search that caches results in Redis.\n", + " \n", + " Caching the results of the web search allows us to avoid rate limiting,\n", + " improve latency, and reduce costs.\n", + " \"\"\"\n", + "\n", + " def _run(\n", + " self,\n", + " query: str,\n", + " run_manager: Optional[CallbackManagerForToolRun] = None,\n", + " ) -> Tuple[Union[List[Dict[str, str]], str], Dict]:\n", + " \"\"\"Use the tool.\"\"\"\n", + " cache_key = f\"tavily_search:{query}\"\n", + " cached_result: Optional[str] = redis_client.get(cache_key) # type: ignore\n", + " if cached_result:\n", + " return json.loads(cached_result), {}\n", + " else:\n", + " result, raw_results = super()._run(query, run_manager)\n", + " redis_client.set(cache_key, json.dumps(result), ex=60 * 60)\n", + " return result, raw_results\n", + "\n", + "\n", + "# Create a checkpoint saver for short-term memory. This keeps track of the\n", + "# conversation history for each thread. Later, we'll continually summarize the\n", + "# conversation history to keep the context window manageable, while we also\n", + "# extract long-term memories from the conversation history to store in the\n", + "# long-term memory index.\n", + "redis_saver = RedisSaver(redis_client=redis_client)\n", + "redis_saver.setup()\n", + "\n", + "# Configure an LLM for the agent with a more creative temperature.\n", + "llm = ChatOpenAI(model=\"gpt-4o\", temperature=0.7)\n", + "\n", + "\n", + "# Uncomment these lines if you have a Tavily API key and want to use the web\n", + "# search tool. The agent is much more useful with this tool.\n", + "# web_search_tool = CachingTavilySearchResults(max_results=2)\n", + "# base_tools = [web_search_tool]\n", + "base_tools = []\n", + "\n", + "if memory_strategy == MemoryStrategy.TOOLS:\n", + " tools = base_tools + [store_memory_tool, retrieve_memories_tool]\n", + "elif memory_strategy == MemoryStrategy.MANUAL:\n", + " tools = base_tools\n", + "\n", + "\n", + "travel_agent = create_react_agent(\n", + " model=llm,\n", + " tools=tools,\n", + " checkpointer=redis_saver, # Short-term memory: the conversation history\n", + " prompt=SystemMessage(\n", + " content=\"\"\"\n", + " You are a travel assistant helping users plan their trips. You remember user preferences\n", + " and provide personalized recommendations based on past interactions.\n", + " \n", + " You have access to the following types of memory:\n", + " 1. Short-term memory: The current conversation thread\n", + " 2. 
Long-term memory: \n", + " - Episodic: User preferences and past trip experiences (e.g., \"User prefers window seats\")\n", + " - Semantic: General knowledge about travel destinations and requirements\n", + " \n", + " Your procedural knowledge (how to search, book flights, etc.) is built into your tools and prompts.\n", + " \n", + " Always be helpful, personal, and context-aware in your responses.\n", + " \"\"\"\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Responding to the User\n", + "\n", + "Now we can write our node that invokes the agent and responds to the user:" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [], + "source": [ + "def respond_to_user(state: RuntimeState, config: RunnableConfig) -> RuntimeState:\n", + " \"\"\"Invoke the travel agent to generate a response.\"\"\"\n", + " human_messages = [m for m in state[\"messages\"] if isinstance(m, HumanMessage)]\n", + " if not human_messages:\n", + " logger.warning(\"No HumanMessage found in state\")\n", + " return state\n", + "\n", + " try:\n", + " for result in travel_agent.stream(\n", + " {\"messages\": state[\"messages\"]}, config=config, stream_mode=\"messages\"\n", + " ):\n", + " result_messages = result.get(\"messages\", [])\n", + "\n", + " ai_messages = [\n", + " m\n", + " for m in result_messages\n", + " if isinstance(m, AIMessage) or isinstance(m, AIMessageChunk)\n", + " ]\n", + " if ai_messages:\n", + " agent_response = ai_messages[-1]\n", + " # Append only the agent's response to the original state\n", + " state[\"messages\"].append(agent_response)\n", + "\n", + " except Exception as e:\n", + " logger.error(f\"Error invoking travel agent: {e}\")\n", + " agent_response = AIMessage(\n", + " content=\"I'm sorry, I encountered an error processing your request.\"\n", + " )\n", + " return state" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summarizing Conversation History\n", + "\n", + "We've been focusing on long-term memory, but let's bounce back to short-term\n", + "memory for a moment. With `RedisSaver`, LangGraph will manage our message\n", + "history automatically. Still, the message history will continue to grow\n", + "indefinitely, until it overwhelms the LLM's token context window.\n", + "\n", + "To solve this problem, we'll add a node to the graph that summarizes the\n", + "conversation if it's grown past a threshold." + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.messages import RemoveMessage\n", + "\n", + "# An LLM configured for summarization.\n", + "summarizer = ChatOpenAI(model=\"gpt-4o\", temperature=0.3)\n", + "\n", + "# The number of messages after which we'll summarize the conversation.\n", + "MESSAGE_SUMMARIZATION_THRESHOLD = 10\n", + "\n", + "\n", + "def summarize_conversation(\n", + " state: RuntimeState, config: RunnableConfig\n", + ") -> Optional[RuntimeState]:\n", + " \"\"\"\n", + " Summarize a list of messages into a concise summary to reduce context length\n", + " while preserving important information.\n", + " \"\"\"\n", + " messages = state[\"messages\"]\n", + " current_message_count = len(messages)\n", + " if current_message_count < MESSAGE_SUMMARIZATION_THRESHOLD:\n", + " logger.debug(f\"Not summarizing conversation: {current_message_count}\")\n", + " return state\n", + "\n", + " system_prompt = \"\"\"\n", + " You are a conversation summarizer. 
Create a concise summary of the previous\n", + " conversation between a user and a travel assistant.\n", + " \n", + " The summary should:\n", + " 1. Highlight key topics, preferences, and decisions\n", + " 2. Include any specific trip details (destinations, dates, preferences)\n", + " 3. Note any outstanding questions or topics that need follow-up\n", + " 4. Be concise but informative\n", + " \n", + " Format your summary as a brief narrative paragraph.\n", + " \"\"\"\n", + "\n", + " message_content = \"\\n\".join(\n", + " [\n", + " f\"{'User' if isinstance(msg, HumanMessage) else 'Assistant'}: {msg.content}\"\n", + " for msg in messages\n", + " ]\n", + " )\n", + "\n", + " # Invoke the summarizer\n", + " summary_messages = [\n", + " SystemMessage(content=system_prompt),\n", + " HumanMessage(\n", + " content=f\"Please summarize this conversation:\\n\\n{message_content}\"\n", + " ),\n", + " ]\n", + "\n", + " summary_response = summarizer.invoke(summary_messages)\n", + "\n", + " logger.info(f\"Summarized {len(messages)} messages into a conversation summary\")\n", + "\n", + " summary_message = SystemMessage(\n", + " content=f\"\"\"\n", + " Summary of the conversation so far:\n", + " \n", + " {summary_response.content}\n", + " \n", + " Please continue the conversation based on this summary and the recent messages.\n", + " \"\"\"\n", + " )\n", + " remove_messages = [\n", + " RemoveMessage(id=msg.id) for msg in messages if msg.id is not None\n", + " ]\n", + "\n", + " state[\"messages\"] = [ # type: ignore\n", + " *remove_messages,\n", + " summary_message,\n", + " state[\"messages\"][-1],\n", + " ]\n", + "\n", + " return state.copy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Assembling the Graph\n", + "\n", + "It's time to assemble our graph!" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [], + "source": [ + "from langgraph.graph import StateGraph, END, START\n", + "\n", + "\n", + "workflow = StateGraph(RuntimeState)\n", + "\n", + "workflow.add_node(\"respond\", respond_to_user)\n", + "workflow.add_node(\"summarize_conversation\", summarize_conversation)\n", + "\n", + "if memory_strategy == MemoryStrategy.MANUAL:\n", + " # In manual memory mode, we'll retrieve relevant memories before\n", + " # responding to the user, and then augment the user's message with the\n", + " # relevant memories.\n", + " workflow.add_node(\"retrieve_memories\", retrieve_relevant_memories)\n", + " workflow.add_edge(START, \"retrieve_memories\")\n", + " workflow.add_edge(\"retrieve_memories\", \"respond\")\n", + "else:\n", + " # In tool-calling mode, we'll respond to the user and let the LLM\n", + " # decide when to retrieve and store memories, using tool calls.\n", + " workflow.add_edge(START, \"respond\")\n", + "\n", + "# Regardless of memory strategy, we'll summarize the conversation after\n", + "# responding to the user, to keep the context window manageable.\n", + "workflow.add_edge(\"respond\", \"summarize_conversation\")\n", + "workflow.add_edge(\"summarize_conversation\", END)\n", + "\n", + "# Finally, compile the graph.\n", + "graph = workflow.compile(checkpointer=redis_saver)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Consolidating Memories in a Background Thread\n", + "\n", + "We're almost ready to create the main loop that runs our graph. First, though,\n", + "let's create a worker that consolidates similar memories on a regular schedule,\n", + "using semantic search. 
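For example, if the index ends up holding both "User prefers window seats" and "User likes to sit by the window on flights," the worker can merge them into a single memory so retrieval doesn't surface near-duplicates. 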
We'll run the worker in a background thread later, in the\n", + "main loop." + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.query import FilterQuery\n", + "\n", + "\n", + "def consolidate_memories(user_id: str, batch_size: int = 10):\n", + " \"\"\"\n", + " Periodically merge similar long-term memories for a user.\n", + " \"\"\"\n", + " logger.info(f\"Starting memory consolidation for user {user_id}\")\n", + " \n", + " # For each memory type, consolidate separately\n", + "\n", + " for memory_type in MemoryType:\n", + " all_memories = []\n", + "\n", + " # Get all memories of this type for the user\n", + " of_type_for_user = (Tag(\"user_id\") == user_id) & (\n", + " Tag(\"memory_type\") == memory_type\n", + " )\n", + " filter_query = FilterQuery(filter_expression=of_type_for_user)\n", + " \n", + " for batch in long_term_memory_index.paginate(filter_query, page_size=batch_size):\n", + " all_memories.extend(batch)\n", + " \n", + " all_memories = long_term_memory_index.query(filter_query)\n", + " if not all_memories:\n", + " continue\n", + "\n", + " # Group similar memories\n", + " processed_ids = set()\n", + " for memory in all_memories:\n", + " if memory[\"id\"] in processed_ids:\n", + " continue\n", + "\n", + " memory_embedding = memory[\"embedding\"]\n", + " vector_query = VectorRangeQuery(\n", + " vector=memory_embedding,\n", + " num_results=10,\n", + " vector_field_name=\"embedding\",\n", + " filter_expression=of_type_for_user\n", + " & (Tag(\"memory_id\") != memory[\"memory_id\"]),\n", + " distance_threshold=0.1,\n", + " return_fields=[\n", + " \"content\",\n", + " \"metadata\",\n", + " ],\n", + " )\n", + " similar_memories = long_term_memory_index.query(vector_query)\n", + "\n", + " # If we found similar memories, consolidate them\n", + " if similar_memories:\n", + " combined_content = memory[\"content\"]\n", + " combined_metadata = memory[\"metadata\"]\n", + "\n", + " if combined_metadata:\n", + " try:\n", + " combined_metadata = json.loads(combined_metadata)\n", + " except Exception as e:\n", + " logger.error(f\"Error parsing metadata: {e}\")\n", + " combined_metadata = {}\n", + "\n", + " for similar in similar_memories:\n", + " # Merge the content of similar memories\n", + " combined_content += f\" {similar['content']}\"\n", + "\n", + " if similar[\"metadata\"]:\n", + " try:\n", + " similar_metadata = json.loads(similar[\"metadata\"])\n", + " except Exception as e:\n", + " logger.error(f\"Error parsing metadata: {e}\")\n", + " similar_metadata = {}\n", + "\n", + " combined_metadata = {**combined_metadata, **similar_metadata}\n", + "\n", + " # Create a consolidated memory\n", + " new_metadata = {\n", + " \"consolidated\": True,\n", + " \"source_count\": len(similar_memories) + 1,\n", + " **combined_metadata,\n", + " }\n", + " consolidated_memory = {\n", + " \"content\": summarize_memories(combined_content, memory_type),\n", + " \"memory_type\": memory_type.value,\n", + " \"metadata\": json.dumps(new_metadata),\n", + " \"user_id\": user_id,\n", + " }\n", + "\n", + " # Delete the old memories\n", + " delete_memory(memory[\"id\"])\n", + " for similar in similar_memories:\n", + " delete_memory(similar[\"id\"])\n", + "\n", + " # Store the new consolidated memory\n", + " store_memory(\n", + " content=consolidated_memory[\"content\"],\n", + " memory_type=memory_type,\n", + " user_id=user_id,\n", + " metadata=consolidated_memory[\"metadata\"],\n", + " )\n", + "\n", + " logger.info(\n", + " f\"Consolidated 
{len(similar_memories) + 1} memories into one\"\n", + " )\n", + "\n", + "\n", + "def delete_memory(memory_id: str):\n", + " \"\"\"Delete a memory from Redis\"\"\"\n", + " try:\n", + " result = long_term_memory_index.drop_keys([memory_id])\n", + " except Exception as e:\n", + " logger.error(f\"Deleting memory {memory_id} failed: {e}\")\n", + " return\n", + " if result == 0:\n", + " logger.debug(f\"Deleting memory {memory_id} failed: memory not found\")\n", + " else:\n", + " logger.info(f\"Deleted memory {memory_id}\")\n", + "\n", + "\n", + "def summarize_memories(combined_content: str, memory_type: MemoryType) -> str:\n", + " \"\"\"Use the LLM to create a concise summary of similar memories\"\"\"\n", + " try:\n", + " system_prompt = f\"\"\"\n", + " You are a memory consolidation assistant. Your task is to create a single, \n", + " concise memory from these similar memory fragments. The new memory should\n", + " be a {memory_type.value} memory.\n", + " \n", + " Combine the information without repetition while preserving all important details.\n", + " \"\"\"\n", + "\n", + " messages = [\n", + " SystemMessage(content=system_prompt),\n", + " HumanMessage(\n", + " content=f\"Consolidate these similar memories into one:\\n\\n{combined_content}\"\n", + " ),\n", + " ]\n", + "\n", + " response = summarizer.invoke(messages)\n", + " return str(response.content)\n", + " except Exception as e:\n", + " logger.error(f\"Error summarizing memories: {e}\")\n", + " # Fall back to just using the combined content\n", + " return combined_content\n", + "\n", + "\n", + "def memory_consolidation_worker(user_id: str):\n", + " \"\"\"\n", + " Worker that periodically consolidates memories for the active user.\n", + "\n", + " NOTE: In production, this would probably use a background task framework, such\n", + " as rq or Celery, and run on a schedule.\n", + " \"\"\"\n", + " while True:\n", + " try:\n", + " consolidate_memories(user_id)\n", + " # Run every 10 minutes\n", + " time.sleep(10 * 60)\n", + " except Exception as e:\n", + " logger.exception(f\"Error in memory consolidation worker: {e}\")\n", + " # If there's an error, wait an hour and try again\n", + " time.sleep(60 * 60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The Main Loop\n", + "\n", + "Now we can put everything together and run the main loop.\n", + "\n", + "Running this cell will prompt you for a user ID and a thread ID (the API keys\n", + "were set at the start of the notebook). You'll then enter a loop in which you can\n", + "type queries and see responses from the agent printed below the following cell." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import threading\n", + "\n", + "\n", + "def main(thread_id: str = \"book_flight\", user_id: str = \"demo_user\"):\n", + " \"\"\"Main interaction loop for the travel agent\"\"\"\n", + " print(\"Welcome to the Travel Assistant! (Type 'exit' to quit)\")\n", + "\n", + " config = RunnableConfig(configurable={\"thread_id\": thread_id, \"user_id\": user_id})\n", + " state = RuntimeState(messages=[])\n", + "\n", + " # If we're using the manual memory strategy, we need to create a queue for\n", + " # memory processing and start a worker thread. 
After every 'round' of a\n", + " # conversation, the main loop will add the current state and config to the\n", + " # queue for memory processing.\n", + " if memory_strategy == MemoryStrategy.MANUAL:\n", + " # Create a queue for memory processing\n", + " memory_queue = Queue()\n", + "\n", + " # Start a worker thread that will process memory extraction tasks\n", + " memory_thread = threading.Thread(\n", + " target=memory_worker, args=(memory_queue, user_id), daemon=True\n", + " )\n", + " memory_thread.start()\n", + "\n", + " # We always run consolidation in the background, regardless of memory strategy.\n", + " consolidation_thread = threading.Thread(\n", + " target=memory_consolidation_worker, args=(user_id,), daemon=True\n", + " )\n", + " consolidation_thread.start()\n", + "\n", + " while True:\n", + " user_input = input(\"\\nYou (type 'quit' to quit): \")\n", + "\n", + " if not user_input:\n", + " continue\n", + "\n", + " if user_input.lower() in [\"exit\", \"quit\"]:\n", + " print(\"Thank you for using the Travel Assistant. Goodbye!\")\n", + " break\n", + "\n", + " state[\"messages\"].append(HumanMessage(content=user_input))\n", + "\n", + " try:\n", + " # Process user input through the graph\n", + " for result in graph.stream(state, config=config, stream_mode=\"values\"):\n", + " state = RuntimeState(**result)\n", + "\n", + " logger.debug(f\"# of messages after run: {len(state['messages'])}\")\n", + "\n", + " # Find the most recent AI message, so we can print the response\n", + " ai_messages = [m for m in state[\"messages\"] if isinstance(m, AIMessage)]\n", + " if ai_messages:\n", + " message = ai_messages[-1].content\n", + " else:\n", + " logger.error(\"No AI messages after run\")\n", + " message = \"I'm sorry, I couldn't process your request properly.\"\n", + " # Add the error message to the state\n", + " state[\"messages\"].append(AIMessage(content=message))\n", + "\n", + " print(f\"\\nAssistant: {message}\")\n", + "\n", + " # Add the current state to the memory processing queue\n", + " if memory_strategy == MemoryStrategy.MANUAL:\n", + " memory_queue.put((state.copy(), config))\n", + "\n", + " except Exception as e:\n", + " logger.exception(f\"Error processing request: {e}\")\n", + " error_message = \"I'm sorry, I encountered an error processing your request.\"\n", + " print(f\"\\nAssistant: {error_message}\")\n", + " # Add the error message to the state\n", + " state[\"messages\"].append(AIMessage(content=error_message))\n", + "\n", + "\n", + "try:\n", + " user_id = input(\"Enter a user ID: \") or \"demo_user\"\n", + " thread_id = input(\"Enter a thread ID: \") or \"demo_thread\"\n", + "except Exception:\n", + " # If we're running in CI, we don't have a terminal to input from, so just exit\n", + " exit()\n", + "else:\n", + " main(thread_id, user_id)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## That's a Wrap!\n", + "\n", + "Want to make your own agent? Try the [LangGraph Quickstart](https://langchain-ai.github.io/langgraph/tutorials/introduction/). Then add our [Redis checkpointer](https://github.com/redis-developer/langgraph-redis) to give your agent fast, persistent memory!" 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}