3434logger = get_logger (__name__ )
3535
3636
37- # FIXME
3837class VikingMemConfig (BaseModel ):
3938 volcengine_ak : Optional [str ] = Field (
4039 default = getenv ("VOLCENGINE_ACCESS_KEY" ),
@@ -54,8 +53,8 @@ class VikingMemConfig(BaseModel):
5453 )
5554
5655
57- # ======= adapted from ... =======
58- class VikingDBMemoryException (Exception ):
56+ # ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py =======
57+ class VikingMemoryException (Exception ):
5958 def __init__ (self , code , request_id , message = None ):
6059 self .code = code
6160 self .request_id = request_id
@@ -67,15 +66,15 @@ def __str__(self):
6766 return self .message
6867
6968
70- class VikingDBMemoryService (Service ):
69+ class VikingMemoryService (Service ):
7170 _instance_lock = threading .Lock ()
7271
7372 def __new__ (cls , * args , ** kwargs ):
74- if not hasattr (VikingDBMemoryService , "_instance" ):
75- with VikingDBMemoryService ._instance_lock :
76- if not hasattr (VikingDBMemoryService , "_instance" ):
77- VikingDBMemoryService ._instance = object .__new__ (cls )
78- return VikingDBMemoryService ._instance
73+ if not hasattr (VikingMemoryService , "_instance" ):
74+ with VikingMemoryService ._instance_lock :
75+ if not hasattr (VikingMemoryService , "_instance" ):
76+ VikingMemoryService ._instance = object .__new__ (cls )
77+ return VikingMemoryService ._instance
7978
8079 def __init__ (
8180 self ,
@@ -88,11 +87,11 @@ def __init__(
8887 connection_timeout = 30 ,
8988 socket_timeout = 30 ,
9089 ):
91- self .service_info = VikingDBMemoryService .get_service_info (
90+ self .service_info = VikingMemoryService .get_service_info (
9291 host , region , scheme , connection_timeout , socket_timeout
9392 )
94- self .api_info = VikingDBMemoryService .get_api_info ()
95- super (VikingDBMemoryService , self ).__init__ (self .service_info , self .api_info )
93+ self .api_info = VikingMemoryService .get_api_info ()
94+ super (VikingMemoryService , self ).__init__ (self .service_info , self .api_info )
9695 if ak :
9796 self .set_ak (ak )
9897 if sk :
@@ -102,12 +101,12 @@ def __init__(
102101 try :
103102 self .get_body ("Ping" , {}, json .dumps ({}))
104103 except Exception as e :
105- raise VikingDBMemoryException (
104+ raise VikingMemoryException (
106105 1000028 , "missed" , "host or region is incorrect: {}" .format (str (e ))
107106 ) from None
108107
109108 def setHeader (self , header ):
110- api_info = VikingDBMemoryService .get_api_info ()
109+ api_info = VikingMemoryService .get_api_info ()
111110 for key in api_info :
112111 for item in header :
113112 api_info [key ].header [item ] = header [item ]
@@ -213,17 +212,17 @@ def get_body_exception(self, api, params, body):
213212 try :
214213 res_json = json .loads (e .args [0 ].decode ("utf-8" ))
215214 except Exception :
216- raise VikingDBMemoryException (
215+ raise VikingMemoryException (
217216 1000028 , "missed" , "json load res error, res:{}" .format (str (e ))
218217 ) from None
219218 code = res_json .get ("code" , 1000028 )
220219 request_id = res_json .get ("request_id" , 1000028 )
221220 message = res_json .get ("message" , None )
222221
223- raise VikingDBMemoryException (code , request_id , message )
222+ raise VikingMemoryException (code , request_id , message )
224223
225224 if res == "" :
226- raise VikingDBMemoryException (
225+ raise VikingMemoryException (
227226 1000028 ,
228227 "missed" ,
229228 "empty response due to unknown error, please contact customer service" ,
@@ -237,15 +236,15 @@ def get_exception(self, api, params):
237236 try :
238237 res_json = json .loads (e .args [0 ].decode ("utf-8" ))
239238 except Exception :
240- raise VikingDBMemoryException (
239+ raise VikingMemoryException (
241240 1000028 , "missed" , "json load res error, res:{}" .format (str (e ))
242241 ) from None
243242 code = res_json .get ("code" , 1000028 )
244243 request_id = res_json .get ("request_id" , 1000028 )
245244 message = res_json .get ("message" , None )
246- raise VikingDBMemoryException (code , request_id , message )
245+ raise VikingMemoryException (code , request_id , message )
247246 if res == "" :
248- raise VikingDBMemoryException (
247+ raise VikingMemoryException (
249248 1000028 ,
250249 "missed" ,
251250 "empty response due to unknown error, please contact customer service" ,
@@ -365,7 +364,7 @@ def format_milliseconds(timestamp_ms):
365364 return dt .strftime ("%Y%m%d %H:%M:%S" )
366365
367366
368- # ======= adapted from ... =======
367+ # ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py =======
369368
370369
371370class VikingMemoryDatabase (BaseModel , BaseDatabase ):
@@ -375,7 +374,7 @@ class VikingMemoryDatabase(BaseModel, BaseDatabase):
375374 )
376375
377376 def model_post_init (self , context : Any , / ) -> None :
378- self ._vm = VikingDBMemoryService (
377+ self ._vm = VikingMemoryService (
379378 ak = self .config .volcengine_ak , sk = self .config .volcengine_sk
380379 )
381380
@@ -516,8 +515,8 @@ def query(self, query: str, **kwargs: Any) -> list[str]:
516515 assert collection_name is not None , "collection_name is required"
517516 user_id = kwargs .get ("user_id" )
518517 assert user_id is not None , "user_id is required"
519-
520- resp = self .search_memory (collection_name , query , user_id = user_id )
518+ top_k = kwargs . get ( "top_k" , 5 )
519+ resp = self .search_memory (collection_name , query , user_id = user_id , top_k = top_k )
521520 return resp
522521
523522 def delete (self , ** kwargs : Any ):
0 commit comments