Skip to content

Commit cd20016

Browse files
committed
add auto token for vikingdb memory
1 parent 3666a87 commit cd20016

File tree

1 file changed

+32
-15
lines changed

1 file changed

+32
-15
lines changed

veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import json
16+
import os
1617
import re
1718
import time
1819
import uuid
@@ -22,7 +23,7 @@
2223
from typing_extensions import override
2324

2425
import veadk.config # noqa E401
25-
from veadk.config import getenv
26+
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
2627
from veadk.integrations.ve_viking_db_memory.ve_viking_db_memory import (
2728
VikingDBMemoryClient,
2829
)
@@ -35,14 +36,16 @@
3536

3637

3738
class VikingDBLTMBackend(BaseLongTermMemoryBackend):
38-
volcengine_access_key: str = Field(
39-
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY")
39+
volcengine_access_key: str | None = Field(
40+
default_factory=lambda: os.getenv("VOLCENGINE_ACCESS_KEY")
4041
)
4142

42-
volcengine_secret_key: str = Field(
43-
default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY")
43+
volcengine_secret_key: str | None = Field(
44+
default_factory=lambda: os.getenv("VOLCENGINE_SECRET_KEY")
4445
)
4546

47+
session_token: str = ""
48+
4649
region: str = "cn-beijing"
4750
"""VikingDB memory region"""
4851

@@ -57,31 +60,41 @@ def precheck_index_naming(self):
5760
)
5861

5962
def model_post_init(self, __context: Any) -> None:
60-
self._client = VikingDBMemoryClient(
61-
ak=self.volcengine_access_key,
62-
sk=self.volcengine_secret_key,
63-
region=self.region,
64-
)
65-
6663
# check whether collection exist, if not, create it
6764
if not self._collection_exist():
6865
self._create_collection()
6966

7067
def _collection_exist(self) -> bool:
7168
try:
72-
self._client.get_collection(collection_name=self.index)
69+
client = self._get_client()
70+
client.get_collection(collection_name=self.index)
7371
return True
7472
except Exception:
7573
return False
7674

7775
def _create_collection(self) -> None:
78-
response = self._client.create_collection(
76+
client = self._get_client()
77+
response = client.create_collection(
7978
collection_name=self.index,
8079
description="Created by Volcengine Agent Development Kit VeADK",
8180
builtin_event_types=["sys_event_v1"],
8281
)
8382
return response
8483

84+
def _get_client(self) -> VikingDBMemoryClient:
85+
if not (self.volcengine_access_key and self.volcengine_secret_key):
86+
cred = get_credential_from_vefaas_iam()
87+
self.volcengine_access_key = cred.access_key_id
88+
self.volcengine_secret_key = cred.secret_access_key
89+
self.session_token = cred.session_token
90+
91+
return VikingDBMemoryClient(
92+
ak=self.volcengine_access_key,
93+
sk=self.volcengine_secret_key,
94+
sts_token=self.session_token,
95+
region=self.region,
96+
)
97+
8598
@override
8699
def save_memory(self, event_strings: list[str], **kwargs) -> bool:
87100
user_id = kwargs.get("user_id")
@@ -101,7 +114,9 @@ def save_memory(self, event_strings: list[str], **kwargs) -> bool:
101114
"default_assistant_id": "assistant",
102115
"time": int(time.time() * 1000),
103116
}
104-
response = self._client.add_messages(
117+
118+
client = self._get_client()
119+
response = client.add_messages(
105120
collection_name=self.index,
106121
messages=messages,
107122
metadata=metadata,
@@ -122,7 +137,9 @@ def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
122137
"user_id": user_id,
123138
"memory_type": ["sys_event_v1"],
124139
}
125-
response = self._client.search_memory(
140+
141+
client = self._get_client()
142+
response = client.search_memory(
126143
collection_name=self.index, query=query, filter=filter, limit=top_k
127144
)
128145

0 commit comments

Comments
 (0)