Skip to content

Commit 2313681

Browse files
committed
add db auth
1 parent de6456e commit 2313681

File tree

4 files changed

+162
-3
lines changed

4 files changed

+162
-3
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
16+
#
17+
# Licensed under the Apache License, Version 2.0 (the "License");
18+
# you may not use this file except in compliance with the License.
19+
# You may obtain a copy of the License at
20+
#
21+
# http://www.apache.org/licenses/LICENSE-2.0
22+
#
23+
# Unless required by applicable law or agreed to in writing, software
24+
# distributed under the License is distributed on an "AS IS" BASIS,
25+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26+
# See the License for the specific language governing permissions and
27+
# limitations under the License.
28+
29+
import os
30+
31+
from typing_extensions import override
32+
33+
from veadk.auth.veauth.base_veauth import BaseVeAuth
34+
from veadk.utils.logger import get_logger
35+
36+
# from veadk.utils.volcengine_sign import ve_request
37+
38+
logger = get_logger(__name__)
39+
40+
41+
class OpensearchVeAuth(BaseVeAuth):
42+
def __init__(
43+
self,
44+
access_key: str = os.getenv("VOLCENGINE_ACCESS_KEY", ""),
45+
secret_key: str = os.getenv("VOLCENGINE_SECRET_KEY", ""),
46+
) -> None:
47+
super().__init__(access_key, secret_key)
48+
49+
self._token: str = ""
50+
51+
@override
52+
def _fetch_token(self) -> None:
53+
logger.info("Fetching Opensearch STS token...")
54+
55+
# res = ve_request(
56+
# request_body={},
57+
# action="GetOrCreatePromptPilotAPIKeys",
58+
# ak=self.access_key,
59+
# sk=self.secret_key,
60+
# service="ark",
61+
# version="2024-01-01",
62+
# region="cn-beijing",
63+
# host="open.volcengineapi.com",
64+
# )
65+
# try:
66+
# self._token = res["Result"]["APIKeys"][0]["APIKey"]
67+
# except KeyError:
68+
# raise ValueError(f"Failed to get Prompt Pilot token: {res}")
69+
70+
@property
71+
def token(self) -> str:
72+
if self._token:
73+
return self._token
74+
self._fetch_token()
75+
return self._token
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
16+
#
17+
# Licensed under the Apache License, Version 2.0 (the "License");
18+
# you may not use this file except in compliance with the License.
19+
# You may obtain a copy of the License at
20+
#
21+
# http://www.apache.org/licenses/LICENSE-2.0
22+
#
23+
# Unless required by applicable law or agreed to in writing, software
24+
# distributed under the License is distributed on an "AS IS" BASIS,
25+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26+
# See the License for the specific language governing permissions and
27+
# limitations under the License.
28+
29+
import os
30+
31+
from typing_extensions import override
32+
33+
from veadk.auth.veauth.base_veauth import BaseVeAuth
34+
from veadk.utils.logger import get_logger
35+
36+
# from veadk.utils.volcengine_sign import ve_request
37+
38+
logger = get_logger(__name__)
39+
40+
41+
class PostgreSqlVeAuth(BaseVeAuth):
42+
def __init__(
43+
self,
44+
access_key: str = os.getenv("VOLCENGINE_ACCESS_KEY", ""),
45+
secret_key: str = os.getenv("VOLCENGINE_SECRET_KEY", ""),
46+
) -> None:
47+
super().__init__(access_key, secret_key)
48+
49+
self._token: str = ""
50+
51+
@override
52+
def _fetch_token(self) -> None:
53+
logger.info("Fetching PostgreSQL STS token...")
54+
55+
# res = ve_request(
56+
# request_body={},
57+
# action="GetOrCreatePromptPilotAPIKeys",
58+
# ak=self.access_key,
59+
# sk=self.secret_key,
60+
# service="ark",
61+
# version="2024-01-01",
62+
# region="cn-beijing",
63+
# host="open.volcengineapi.com",
64+
# )
65+
# try:
66+
# self._token = res["Result"]["APIKeys"][0]["APIKey"]
67+
# except KeyError:
68+
# raise ValueError(f"Failed to get Prompt Pilot token: {res}")
69+
70+
@property
71+
def token(self) -> str:
72+
if self._token:
73+
return self._token
74+
self._fetch_token()
75+
return self._token

veadk/configs/database_configs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class OpensearchConfig(BaseSettings):
3232

3333
password: str = ""
3434

35+
secret_token: str = ""
36+
3537

3638
class MysqlConfig(BaseSettings):
3739
model_config = SettingsConfigDict(env_prefix="DATABASE_MYSQL_")
@@ -46,6 +48,9 @@ class MysqlConfig(BaseSettings):
4648

4749
charset: str = "utf8"
4850

51+
secret_token: str = ""
52+
"""STS token for MySQL auth, not supported yet."""
53+
4954

5055
class PostgreSqlConfig(BaseSettings):
5156
model_config = SettingsConfigDict(env_prefix="DATABASE_POSTGRESQL_")
@@ -60,6 +65,8 @@ class PostgreSqlConfig(BaseSettings):
6065

6166
database: str = ""
6267

68+
secret_token: str = ""
69+
6370

6471
class RedisConfig(BaseSettings):
6572
model_config = SettingsConfigDict(env_prefix="DATABASE_REDIS_")
@@ -72,6 +79,9 @@ class RedisConfig(BaseSettings):
7279

7380
db: int = 0
7481

82+
secret_token: str = ""
83+
"""STS token for Redis auth, not supported yet."""
84+
7585

7686
class VikingKnowledgebaseConfig(BaseSettings):
7787
model_config = SettingsConfigDict(env_prefix="DATABASE_VIKING_")

veadk/memory/short_term_memory.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
)
3131
from veadk.utils.logger import get_logger
3232

33-
# from .short_term_memory_processor import ShortTermMemoryProcessor
34-
3533
logger = get_logger(__name__)
3634

3735
DEFAULT_LOCAL_DATABASE_PATH = "/tmp/veadk_local_database.db"
@@ -71,6 +69,7 @@ def model_post_init(self, __context: Any) -> None:
7169
logger.warning(
7270
"Backend `database` is deprecated, use `sqlite` to create short term memory."
7371
)
72+
self.backend = "sqlite"
7473
match self.backend:
7574
case "local":
7675
self.session_service = InMemorySessionService()
@@ -101,7 +100,7 @@ async def create_session(
101100
app_name: str,
102101
user_id: str,
103102
session_id: str,
104-
):
103+
) -> None:
105104
if isinstance(self.session_service, DatabaseSessionService):
106105
list_sessions_response = await self.session_service.list_sessions(
107106
app_name=app_name, user_id=user_id

0 commit comments

Comments
 (0)