Skip to content

Commit bf7a9fa

Browse files
feat: add vikingmem project and region env (#384)
* feat: add vikingmem project and region env * revert version * fix project * add collection type
1 parent 693cf04 commit bf7a9fa

File tree

2 files changed

+71
-21
lines changed

2 files changed

+71
-21
lines changed

veadk/integrations/ve_viking_db_memory/ve_viking_db_memory.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@
1313
# limitations under the License.
1414

1515
import json
16+
import os
1617
import threading
17-
from veadk.utils.misc import getenv
18+
1819
from volcengine.ApiInfo import ApiInfo
1920
from volcengine.auth.SignerV4 import SignerV4
2021
from volcengine.base.Service import Service
2122
from volcengine.Credentials import Credentials
2223
from volcengine.ServiceInfo import ServiceInfo
2324

25+
from veadk.utils.misc import getenv
26+
2427

2528
class VikingDBMemoryException(Exception):
2629
def __init__(self, code, request_id, message=None):
@@ -56,7 +59,9 @@ def __init__(
5659
socket_timeout=30,
5760
):
5861
env_host = getenv(
59-
"DATABASE_VIKINGMEM_BASE_URL", default_value=None, allow_false_values=True
62+
"DATABASE_VIKINGMEM_BASE_URL",
63+
default_value=None,
64+
allow_false_values=True,
6065
)
6166
if env_host:
6267
if env_host.startswith("http://"):
@@ -85,7 +90,9 @@ def __init__(
8590
self.get_body("Ping", {}, json.dumps({}))
8691
except Exception as e:
8792
raise VikingDBMemoryException(
88-
1000028, "missed", "host or region is incorrect: {}".format(str(e))
93+
1000028,
94+
"missed",
95+
"host or region is incorrect: {}".format(str(e)),
8996
) from None
9097

9198
def setHeader(self, header):
@@ -118,49 +125,70 @@ def get_api_info():
118125
"/api/memory/collection/create",
119126
{},
120127
{},
121-
{"Accept": "application/json", "Content-Type": "application/json"},
128+
{
129+
"Accept": "application/json",
130+
"Content-Type": "application/json",
131+
},
122132
),
123133
"GetCollection": ApiInfo(
124134
"POST",
125135
"/api/memory/collection/info",
126136
{},
127137
{},
128-
{"Accept": "application/json", "Content-Type": "application/json"},
138+
{
139+
"Accept": "application/json",
140+
"Content-Type": "application/json",
141+
},
129142
),
130143
"DropCollection": ApiInfo(
131144
"POST",
132145
"/api/memory/collection/delete",
133146
{},
134147
{},
135-
{"Accept": "application/json", "Content-Type": "application/json"},
148+
{
149+
"Accept": "application/json",
150+
"Content-Type": "application/json",
151+
},
136152
),
137153
"UpdateCollection": ApiInfo(
138154
"POST",
139155
"/api/memory/collection/update",
140156
{},
141157
{},
142-
{"Accept": "application/json", "Content-Type": "application/json"},
158+
{
159+
"Accept": "application/json",
160+
"Content-Type": "application/json",
161+
},
143162
),
144163
"SearchMemory": ApiInfo(
145164
"POST",
146165
"/api/memory/search",
147166
{},
148167
{},
149-
{"Accept": "application/json", "Content-Type": "application/json"},
168+
{
169+
"Accept": "application/json",
170+
"Content-Type": "application/json",
171+
},
150172
),
151173
"AddMessages": ApiInfo(
152174
"POST",
153175
"/api/memory/messages/add",
154176
{},
155177
{},
156-
{"Accept": "application/json", "Content-Type": "application/json"},
178+
{
179+
"Accept": "application/json",
180+
"Content-Type": "application/json",
181+
},
157182
),
158183
"Ping": ApiInfo(
159184
"GET",
160185
"/api/memory/ping",
161186
{},
162187
{},
163-
{"Accept": "application/json", "Content-Type": "application/json"},
188+
{
189+
"Accept": "application/json",
190+
"Content-Type": "application/json",
191+
},
164192
),
165193
}
166194
return api_info
@@ -199,7 +227,9 @@ def get_body_exception(self, api, params, body):
199227
res_json = json.loads(e.args[0].decode("utf-8"))
200228
except Exception as e:
201229
raise VikingDBMemoryException(
202-
1000028, "missed", "json load res error, res:{}".format(str(e))
230+
1000028,
231+
"missed",
232+
"json load res error, res:{}".format(str(e)),
203233
) from None
204234
code = res_json.get("code", 1000028)
205235
request_id = res_json.get("request_id", 1000028)
@@ -223,7 +253,9 @@ def get_exception(self, api, params):
223253
res_json = json.loads(e.args[0].decode("utf-8"))
224254
except Exception as e:
225255
raise VikingDBMemoryException(
226-
1000028, "missed", "json load res error, res:{}".format(str(e))
256+
1000028,
257+
"missed",
258+
"json load res error, res:{}".format(str(e)),
227259
) from None
228260
code = res_json.get("code", 1000028)
229261
request_id = res_json.get("request_id", 1000028)
@@ -241,13 +273,18 @@ def create_collection(
241273
self,
242274
collection_name,
243275
description="",
276+
project="default",
244277
custom_event_type_schemas=[],
245278
custom_entity_type_schemas=[],
246279
builtin_event_types=[],
247280
builtin_entity_types=[],
248281
):
249282
params = {
250283
"CollectionName": collection_name,
284+
"ProjectName": project,
285+
"CollectionType": os.getenv(
286+
"DATABASE_VIKINGMEM_COLLECTION_TYPE", "standard"
287+
),
251288
"Description": description,
252289
"CustomEventTypeSchemas": custom_event_type_schemas,
253290
"CustomEntityTypeSchemas": custom_entity_type_schemas,
@@ -257,8 +294,8 @@ def create_collection(
257294
res = self.json("CreateCollection", {}, json.dumps(params))
258295
return json.loads(res)
259296

260-
def get_collection(self, collection_name):
261-
params = {"CollectionName": collection_name}
297+
def get_collection(self, collection_name, project="default"):
298+
params = {"CollectionName": collection_name, "ProjectName": project}
262299
res = self.json("GetCollection", {}, json.dumps(params))
263300
return json.loads(res)
264301

veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
from pydantic import Field
2323
from typing_extensions import override
24+
from vikingdb import IAM
25+
from vikingdb.memory import VikingMem
2426

2527
import veadk.config # noqa E401
2628
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
@@ -30,9 +32,6 @@
3032
from veadk.memory.long_term_memory_backends.base_backend import (
3133
BaseLongTermMemoryBackend,
3234
)
33-
from vikingdb import IAM
34-
from vikingdb.memory import VikingMem
35-
3635
from veadk.utils.logger import get_logger
3736

3837
logger = get_logger(__name__)
@@ -49,9 +48,16 @@ class VikingDBLTMBackend(BaseLongTermMemoryBackend):
4948

5049
session_token: str = ""
5150

52-
region: str = "cn-beijing"
51+
region: str = Field(
52+
default_factory=lambda: os.getenv("DATABASE_VIKINGMEM_REGION") or "cn-beijing"
53+
)
5354
"""VikingDB memory region"""
5455

56+
volcengine_project: str = Field(
57+
default_factory=lambda: os.getenv("DATABASE_VIKINGMEM_PROJECT") or "default"
58+
)
59+
"""VikingDB memory project"""
60+
5561
memory_type: list[str] = Field(default_factory=list)
5662

5763
def model_post_init(self, __context: Any) -> None:
@@ -87,7 +93,9 @@ def precheck_index_naming(self):
8793
def _collection_exist(self) -> bool:
8894
try:
8995
client = self._get_client()
90-
client.get_collection(collection_name=self.index)
96+
client.get_collection(
97+
collection_name=self.index, project=self.volcengine_project
98+
)
9199
logger.info(f"Collection {self.index} exist.")
92100
return True
93101
except Exception:
@@ -101,6 +109,7 @@ def _create_collection(self) -> None:
101109
client = self._get_client()
102110
response = client.create_collection(
103111
collection_name=self.index,
112+
project=self.volcengine_project,
104113
description="Created by Volcengine Agent Development Kit VeADK",
105114
builtin_event_types=self.memory_type,
106115
)
@@ -156,7 +165,9 @@ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
156165
)
157166

158167
client = self._get_sdk_client()
159-
collection = client.get_collection(collection_name=self.index)
168+
collection = client.get_collection(
169+
collection_name=self.index, project_name=self.volcengine_project
170+
)
160171
response = collection.add_session(
161172
session_id=session_id,
162173
messages=messages,
@@ -181,7 +192,9 @@ def search_memory(
181192
)
182193

183194
client = self._get_sdk_client()
184-
collection = client.get_collection(collection_name=self.index)
195+
collection = client.get_collection(
196+
collection_name=self.index, project_name=self.volcengine_project
197+
)
185198
response = collection.search_memory(
186199
query=query,
187200
filter=filter,

0 commit comments

Comments
 (0)