1313# limitations under the License.
1414
1515import json
16+ import os
1617import re
1718import time
1819import uuid
2223from typing_extensions import override
2324
2425import veadk .config # noqa E401
25- from veadk .config import getenv
26+ from veadk .auth . veauth . utils import get_credential_from_vefaas_iam
2627from veadk .integrations .ve_viking_db_memory .ve_viking_db_memory import (
2728 VikingDBMemoryClient ,
2829)
3536
3637
3738class VikingDBLTMBackend (BaseLongTermMemoryBackend ):
38- volcengine_access_key : str = Field (
39- default_factory = lambda : getenv ("VOLCENGINE_ACCESS_KEY" )
39+ volcengine_access_key : str | None = Field (
40+ default_factory = lambda : os . getenv ("VOLCENGINE_ACCESS_KEY" )
4041 )
4142
42- volcengine_secret_key : str = Field (
43- default_factory = lambda : getenv ("VOLCENGINE_SECRET_KEY" )
43+ volcengine_secret_key : str | None = Field (
44+ default_factory = lambda : os . getenv ("VOLCENGINE_SECRET_KEY" )
4445 )
4546
47+ session_token : str = ""
48+
4649 region : str = "cn-beijing"
4750 """VikingDB memory region"""
4851
@@ -57,31 +60,41 @@ def precheck_index_naming(self):
5760 )
5861
5962 def model_post_init (self , __context : Any ) -> None :
60- self ._client = VikingDBMemoryClient (
61- ak = self .volcengine_access_key ,
62- sk = self .volcengine_secret_key ,
63- region = self .region ,
64- )
65-
6663 # check whether collection exist, if not, create it
6764 if not self ._collection_exist ():
6865 self ._create_collection ()
6966
7067 def _collection_exist (self ) -> bool :
7168 try :
72- self ._client .get_collection (collection_name = self .index )
69+ client = self ._get_client ()
70+ client .get_collection (collection_name = self .index )
7371 return True
7472 except Exception :
7573 return False
7674
7775 def _create_collection (self ) -> None :
78- response = self ._client .create_collection (
76+ client = self ._get_client ()
77+ response = client .create_collection (
7978 collection_name = self .index ,
8079 description = "Created by Volcengine Agent Development Kit VeADK" ,
8180 builtin_event_types = ["sys_event_v1" ],
8281 )
8382 return response
8483
84+ def _get_client (self ) -> VikingDBMemoryClient :
85+ if not (self .volcengine_access_key and self .volcengine_secret_key ):
86+ cred = get_credential_from_vefaas_iam ()
87+ self .volcengine_access_key = cred .access_key_id
88+ self .volcengine_secret_key = cred .secret_access_key
89+ self .session_token = cred .session_token
90+
91+ return VikingDBMemoryClient (
92+ ak = self .volcengine_access_key ,
93+ sk = self .volcengine_secret_key ,
94+ sts_token = self .session_token ,
95+ region = self .region ,
96+ )
97+
8598 @override
8699 def save_memory (self , event_strings : list [str ], ** kwargs ) -> bool :
87100 user_id = kwargs .get ("user_id" )
@@ -101,7 +114,9 @@ def save_memory(self, event_strings: list[str], **kwargs) -> bool:
101114 "default_assistant_id" : "assistant" ,
102115 "time" : int (time .time () * 1000 ),
103116 }
104- response = self ._client .add_messages (
117+
118+ client = self ._get_client ()
119+ response = client .add_messages (
105120 collection_name = self .index ,
106121 messages = messages ,
107122 metadata = metadata ,
@@ -122,7 +137,9 @@ def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
122137 "user_id" : user_id ,
123138 "memory_type" : ["sys_event_v1" ],
124139 }
125- response = self ._client .search_memory (
140+
141+ client = self ._get_client ()
142+ response = client .search_memory (
126143 collection_name = self .index , query = query , filter = filter , limit = top_k
127144 )
128145
0 commit comments