1- from typing import List , Type
1+ from typing import List , Type , Union , Any
22import nanoid
3- import numpy as np
43from redis .asyncio import Redis
54from redis .commands .search .query import Query
65from models import (
76 MemoryMessage ,
87 OpenAIClientWrapper ,
8+ AnthropicClientWrapper ,
99 RedisearchResult ,
1010 SearchResults ,
11+ ModelProvider ,
12+ get_model_config ,
1113)
1214import logging
1315
1921async def index_messages (
2022 messages : List [MemoryMessage ],
2123 session_id : str ,
22- openai_client : OpenAIClientWrapper ,
24+ client : OpenAIClientWrapper , # Only OpenAI supports embeddings currently
2325 redis_conn : Redis ,
2426) -> None :
2527 """Index messages in Redis for vector search"""
@@ -28,7 +30,7 @@ async def index_messages(
2830 contents = [msg .content for msg in messages ]
2931
3032 # Get embeddings from OpenAI
31- embeddings = await openai_client .create_embedding (contents )
33+ embeddings = await client .create_embedding (contents )
3234
3335 # Index each message with its embedding
3436 for index , embedding in enumerate (embeddings ):
@@ -64,16 +66,18 @@ class Unset:
6466async def search_messages (
6567 query : str ,
6668 session_id : str ,
67- openai_client : OpenAIClientWrapper ,
69+ client : OpenAIClientWrapper , # Only OpenAI supports embeddings currently
6870 redis_conn : Redis ,
6971 distance_threshold : float | Type [Unset ] = Unset ,
7072 limit : int = 10 ,
7173) -> SearchResults :
7274 """Search for messages using vector similarity"""
7375 try :
7476 # Get embedding for query
75- query_embedding = await openai_client .create_embedding ([query ])
77+ query_embedding = await client .create_embedding ([query ])
7678 vector = query_embedding .tobytes ()
79+
80+ # Set up query parameters
7781 params = {"vec" : vector }
7882
7983 if distance_threshold and distance_threshold is not Unset :
@@ -85,26 +89,46 @@ async def search_messages(
8589 base_query = Query (
8690 f"@session:{{{ session_id } }}=>[KNN { limit } @vector $vec AS dist]"
8791 )
92+
8893 q = (
8994 base_query .return_fields ("role" , "content" , "dist" )
9095 .sort_by ("dist" , asc = True )
9196 .paging (0 , limit )
9297 .dialect (2 )
9398 )
9499
100+ # Execute search
95101 raw_results = await redis_conn .ft (REDIS_INDEX_NAME ).search (
96102 q ,
97103 query_params = params , # type: ignore
98104 )
99105
100- # Parse results
101- results = [
102- RedisearchResult (role = doc .role , content = doc .content , dist = doc .dist )
103- for doc in raw_results .docs
104- ]
106+ # Parse results safely
107+ results = []
108+ total_results = 0
109+
110+ # Check if raw_results has the expected attributes
111+ if hasattr (raw_results , "docs" ) and isinstance (raw_results .docs , list ):
112+ for doc in raw_results .docs :
113+ if (
114+ hasattr (doc , "role" )
115+ and hasattr (doc , "content" )
116+ and hasattr (doc , "dist" )
117+ ):
118+ results .append (
119+ RedisearchResult (
120+ role = doc .role , content = doc .content , dist = float (doc .dist )
121+ )
122+ )
123+
124+ total_results = getattr (raw_results , "total" , len (results ))
125+ else :
126+ # Handle the case where raw_results doesn't have the expected structure
127+ logger .warning ("Unexpected search result format" )
128+ total_results = 0
105129
106130 logger .info (f"Found { len (results )} results for query in session { session_id } " )
107- return SearchResults (total = raw_results . total , docs = results )
131+ return SearchResults (total = total_results , docs = results )
108132 except Exception as e :
109133 logger .error (f"Error searching messages: { e } " )
110134 raise
0 commit comments