8
8
import contextlib
9
9
import re
10
10
from collections .abc import AsyncIterator
11
- from datetime import datetime
12
- from typing import Any , Literal
11
+ from typing import TYPE_CHECKING , Any , Literal
12
+
13
+ if TYPE_CHECKING :
14
+ from typing_extensions import Self
13
15
14
16
import httpx
15
17
import ulid
@@ -71,15 +73,15 @@ def __init__(self, config: MemoryClientConfig):
71
73
timeout = config .timeout ,
72
74
)
73
75
74
- async def close (self ):
76
+ async def close (self ) -> None :
75
77
"""Close the underlying HTTP client."""
76
78
await self ._client .aclose ()
77
79
78
- async def __aenter__ (self ):
80
+ async def __aenter__ (self ) -> "Self" :
79
81
"""Support using the client as an async context manager."""
80
82
return self
81
83
82
- async def __aexit__ (self , exc_type , exc_val , exc_tb ) :
84
+ async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> None :
83
85
"""Close the client when exiting the context manager."""
84
86
await self .close ()
85
87
@@ -176,13 +178,13 @@ async def get_session_memory(
176
178
params ["namespace" ] = self .config .default_namespace
177
179
178
180
if window_size is not None :
179
- params ["window_size" ] = window_size
181
+ params ["window_size" ] = str ( window_size )
180
182
181
183
if model_name is not None :
182
184
params ["model_name" ] = model_name
183
185
184
186
if context_window_max is not None :
185
- params ["context_window_max" ] = context_window_max
187
+ params ["context_window_max" ] = str ( context_window_max )
186
188
187
189
try :
188
190
response = await self ._client .get (
@@ -861,31 +863,11 @@ def validate_memory_record(self, memory: ClientMemoryRecord | MemoryRecord) -> N
861
863
if memory .id and not self ._is_valid_ulid (memory .id ):
862
864
raise MemoryValidationError (f"Invalid ID format: { memory .id } " )
863
865
864
- if (
865
- hasattr (memory , "created_at" )
866
- and memory .created_at
867
- and not isinstance (memory .created_at , datetime )
868
- ):
869
- try :
870
- datetime .fromisoformat (str (memory .created_at ))
871
- except ValueError as e :
872
- raise MemoryValidationError (
873
- f"Invalid created_at format: { memory .created_at } "
874
- ) from e
875
-
876
- if (
877
- hasattr (memory , "last_accessed" )
878
- and memory .last_accessed
879
- and not isinstance (memory .last_accessed , datetime )
880
- ):
881
- try :
882
- datetime .fromisoformat (str (memory .last_accessed ))
883
- except ValueError as e :
884
- raise MemoryValidationError (
885
- f"Invalid last_accessed format: { memory .last_accessed } "
886
- ) from e
866
+ # created_at is validated by Pydantic
887
867
888
- def validate_search_filters (self , ** filters ) -> None :
868
+ # last_accessed is validated by Pydantic
869
+
870
+ def validate_search_filters (self , ** filters : Any ) -> None :
889
871
"""Validate search filter parameters before API call."""
890
872
valid_filter_keys = {
891
873
"session_id" ,
@@ -1022,7 +1004,10 @@ async def append_messages_to_working_memory(
1022
1004
{"role" : msg .role , "content" : msg .content }
1023
1005
)
1024
1006
else :
1025
- converted_existing_messages .append (msg )
1007
+ # Fallback for any other message type
1008
+ converted_existing_messages .append (
1009
+ {"role" : "user" , "content" : str (msg )}
1010
+ )
1026
1011
1027
1012
# Convert new messages to dict format if they're objects
1028
1013
new_messages = []
@@ -1074,21 +1059,21 @@ async def memory_prompt(
1074
1059
Returns:
1075
1060
Dict with messages hydrated with relevant memory context
1076
1061
"""
1077
- payload = {"query" : query }
1062
+ payload : dict [ str , Any ] = {"query" : query }
1078
1063
1079
1064
# Add session parameters if provided
1080
1065
if session_id is not None :
1081
- session_params = {"session_id" : session_id }
1066
+ session_params : dict [ str , Any ] = {"session_id" : session_id }
1082
1067
if namespace is not None :
1083
1068
session_params ["namespace" ] = namespace
1084
1069
elif self .config .default_namespace is not None :
1085
1070
session_params ["namespace" ] = self .config .default_namespace
1086
1071
if window_size is not None :
1087
- session_params ["window_size" ] = window_size
1072
+ session_params ["window_size" ] = str ( window_size )
1088
1073
if model_name is not None :
1089
1074
session_params ["model_name" ] = model_name
1090
1075
if context_window_max is not None :
1091
- session_params ["context_window_max" ] = context_window_max
1076
+ session_params ["context_window_max" ] = str ( context_window_max )
1092
1077
payload ["session" ] = session_params
1093
1078
1094
1079
# Add long-term search parameters if provided
@@ -1101,7 +1086,10 @@ async def memory_prompt(
1101
1086
json = payload ,
1102
1087
)
1103
1088
response .raise_for_status ()
1104
- return response .json ()
1089
+ result = response .json ()
1090
+ if isinstance (result , dict ):
1091
+ return result
1092
+ return {"response" : result }
1105
1093
except httpx .HTTPStatusError as e :
1106
1094
self ._handle_http_error (e .response )
1107
1095
raise
@@ -1143,7 +1131,7 @@ async def hydrate_memory_prompt(
1143
1131
Dict with messages hydrated with relevant long-term memories
1144
1132
"""
1145
1133
# Build long-term search parameters
1146
- long_term_search = {"limit" : limit }
1134
+ long_term_search : dict [ str , Any ] = {"limit" : limit }
1147
1135
1148
1136
if session_id is not None :
1149
1137
long_term_search ["session_id" ] = session_id
@@ -1171,7 +1159,9 @@ async def hydrate_memory_prompt(
1171
1159
long_term_search = long_term_search ,
1172
1160
)
1173
1161
1174
- def _deep_merge_dicts (self , base : dict , updates : dict ) -> dict :
1162
+ def _deep_merge_dicts (
1163
+ self , base : dict [str , Any ], updates : dict [str , Any ]
1164
+ ) -> dict [str , Any ]:
1175
1165
"""Recursively merge two dictionaries."""
1176
1166
result = base .copy ()
1177
1167
for key , value in updates .items ():
0 commit comments