|
99 | 99 |
|
100 | 100 | logger = logging.getLogger(__name__)
|
101 | 101 |
|
# Debounce configuration for thread-aware extraction
EXTRACTION_DEBOUNCE_TTL = 300  # seconds (5 minutes)
EXTRACTION_DEBOUNCE_KEY_PREFIX = "extraction_debounce"


async def should_extract_session_thread(session_id: str, redis: Redis) -> bool:
    """
    Check if enough time has passed since last thread-aware extraction for this session.

    This implements a debounce mechanism to avoid constantly re-extracting memories
    from the same conversation thread as new messages arrive.

    Args:
        session_id: The session ID to check
        redis: Redis client

    Returns:
        True if extraction should proceed, False if debounced
    """
    debounce_key = f"{EXTRACTION_DEBOUNCE_KEY_PREFIX}:{session_id}"

    # Atomically claim the debounce window with SET NX EX. The previous
    # EXISTS-then-SETEX sequence was racy: two concurrent promotions could both
    # observe the key as missing and both run a (duplicate) extraction. SET with
    # nx=True performs the check-and-set as a single Redis command; it returns a
    # truthy value only for the caller that actually created the key.
    acquired = await redis.set(
        debounce_key, "extracting", nx=True, ex=EXTRACTION_DEBOUNCE_TTL
    )
    if acquired:
        logger.info(
            f"Starting thread-aware extraction for session {session_id} (debounce set for {EXTRACTION_DEBOUNCE_TTL}s)"
        )
        return True

    # Key already exists: we are still inside the debounce window.
    remaining_ttl = await redis.ttl(debounce_key)
    logger.info(
        f"Skipping thread-aware extraction for session {session_id} (debounced, {remaining_ttl}s remaining)"
    )
    return False
| 139 | + |
| 140 | + |
async def extract_memories_from_session_thread(
    session_id: str,
    namespace: str | None = None,
    user_id: str | None = None,
    llm_client: OpenAIClientWrapper | AnthropicClientWrapper | None = None,
) -> list[MemoryRecord]:
    """
    Extract memories from the entire conversation thread in working memory.

    This provides full conversational context for proper contextual grounding,
    allowing pronouns and references to be resolved across the entire thread.

    Args:
        session_id: The session ID to extract memories from
        namespace: Optional namespace for the memories
        user_id: Optional user ID for the memories
        llm_client: Optional LLM client for extraction

    Returns:
        List of extracted memory records with proper contextual grounding.
        Returns an empty list when there are no messages or extraction fails
        (best-effort: errors are logged, not raised).
    """
    # Local import to avoid a circular dependency with the working_memory module.
    from agent_memory_server.working_memory import get_working_memory

    # Get the complete working memory thread
    working_memory = await get_working_memory(
        session_id=session_id, namespace=namespace, user_id=user_id
    )

    if not working_memory or not working_memory.messages:
        logger.info(f"No working memory messages found for session {session_id}")
        return []

    # Build full conversation context from all messages, prefixing each line
    # with the speaker role (e.g. "[USER]: ") when available for better grounding.
    conversation_messages = []
    for msg in working_memory.messages:
        role_prefix = (
            f"[{msg.role.upper()}]: " if hasattr(msg, "role") and msg.role else ""
        )
        conversation_messages.append(f"{role_prefix}{msg.content}")

    full_conversation = "\n".join(conversation_messages)

    logger.info(
        f"Extracting memories from {len(working_memory.messages)} messages in session {session_id}"
    )
    logger.debug(
        f"Full conversation context length: {len(full_conversation)} characters"
    )

    # Use the enhanced extraction prompt with contextual grounding
    from agent_memory_server.extraction import DISCRETE_EXTRACTION_PROMPT

    client = llm_client or await get_model_client(settings.generation_model)

    try:
        response = await client.create_chat_completion(
            model=settings.generation_model,
            prompt=DISCRETE_EXTRACTION_PROMPT.format(
                message=full_conversation,
                top_k_topics=settings.top_k_topics,
                # .astimezone() attaches the local timezone so %Z actually
                # renders a zone name; a naive datetime.now() leaves %Z empty.
                current_datetime=datetime.now().astimezone().strftime(
                    "%A, %B %d, %Y at %I:%M %p %Z"
                ),
            ),
            response_format={"type": "json_object"},
        )

        extraction_result = json.loads(response.choices[0].message.content)
        memories_data = extraction_result.get("memories", [])

        logger.info(
            f"Extracted {len(memories_data)} memories from session thread {session_id}"
        )

        # Convert to MemoryRecord objects. Skip malformed entries instead of
        # letting a single KeyError discard the whole batch in the except below.
        extracted_memories = []
        for memory_data in memories_data:
            text = memory_data.get("text")
            if not text:
                logger.warning(
                    "Skipping extracted memory without text for session %s: %r",
                    session_id,
                    memory_data,
                )
                continue
            memory = MemoryRecord(
                id=str(ULID()),
                text=text,
                memory_type=memory_data.get("type", "semantic"),
                topics=memory_data.get("topics", []),
                entities=memory_data.get("entities", []),
                session_id=session_id,
                namespace=namespace,
                user_id=user_id,
                discrete_memory_extracted="t",  # Mark as extracted
            )
            extracted_memories.append(memory)

        return extracted_memories

    except Exception:
        # Best-effort extraction: log with traceback and degrade to no memories.
        logger.exception(
            "Error extracting memories from session thread %s", session_id
        )
        return []
| 237 | + |
102 | 238 |
|
103 | 239 | async def extract_memory_structure(memory: MemoryRecord):
|
104 | 240 | redis = await get_redis_conn()
|
@@ -1131,23 +1267,32 @@ async def promote_working_memory_to_long_term(
|
1131 | 1267 | updated_memories = []
|
1132 | 1268 | extracted_memories = []
|
1133 | 1269 |
|
1134 |
| - # Find messages that haven't been extracted yet for discrete memory extraction |
| 1270 | + # Thread-aware discrete memory extraction with debouncing |
1135 | 1271 | unextracted_messages = [
|
1136 | 1272 | message
|
1137 | 1273 | for message in current_working_memory.messages
|
1138 | 1274 | if message.discrete_memory_extracted == "f"
|
1139 | 1275 | ]
|
1140 | 1276 |
|
1141 | 1277 | if settings.enable_discrete_memory_extraction and unextracted_messages:
|
1142 |
| - logger.info(f"Extracting memories from {len(unextracted_messages)} messages") |
1143 |
| - extracted_memories = await extract_memories_from_messages( |
1144 |
| - messages=unextracted_messages, |
1145 |
| - session_id=session_id, |
1146 |
| - user_id=user_id, |
1147 |
| - namespace=namespace, |
1148 |
| - ) |
1149 |
| - for message in unextracted_messages: |
1150 |
| - message.discrete_memory_extracted = "t" |
| 1278 | + # Check if we should run thread-aware extraction (debounced) |
| 1279 | + if await should_extract_session_thread(session_id, redis): |
| 1280 | + logger.info( |
| 1281 | + f"Running thread-aware extraction from {len(current_working_memory.messages)} total messages in session {session_id}" |
| 1282 | + ) |
| 1283 | + extracted_memories = await extract_memories_from_session_thread( |
| 1284 | + session_id=session_id, |
| 1285 | + namespace=namespace, |
| 1286 | + user_id=user_id, |
| 1287 | + ) |
| 1288 | + |
| 1289 | + # Mark ALL messages in the session as extracted since we processed the full thread |
| 1290 | + for message in current_working_memory.messages: |
| 1291 | + message.discrete_memory_extracted = "t" |
| 1292 | + |
| 1293 | + else: |
| 1294 | + logger.info(f"Skipping extraction for session {session_id} - debounced") |
| 1295 | + extracted_memories = [] |
1151 | 1296 |
|
1152 | 1297 | for memory in current_working_memory.memories:
|
1153 | 1298 | if memory.persisted_at is None:
|
|
0 commit comments