Skip to content

Commit b51fdda

Browse files
feat: restrict aksk auth on content gen methods
1 parent f90e885 commit b51fdda

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

volcenginesdkarkruntime/_utils/_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,18 @@ def _insert_sts_token(args, kwargs):
8787
elif ark_client.api_key is None and model and model.startswith("bot-") and ark_client.ak and ark_client.sk:
8888
default_auth_header = {"Authorization": "Bearer " + ark_client._get_bot_sts_token(model)}
8989
extra_headers = kwargs.get("extra_headers") if kwargs.get("extra_headers") else {}
90-
kwargs["extra_headers"] = {**default_auth_header, **extra_headers}
90+
kwargs["extra_headers"] = {**default_auth_header, **extra_headers}
91+
92+
93+
def disallow_aksk(func):
94+
def wrapper(*args, **kwargs):
95+
_restrict_aksk(args, kwargs)
96+
return func(*args, **kwargs)
97+
98+
return wrapper
99+
100+
def _restrict_aksk(args, kwargs):
101+
assert len(args) > 0
102+
103+
ark_client = args[0]._client
104+
assert ark_client.api_key is not None, "ak&sk authentication is currently not supported for this method, please use api key instead"

volcenginesdkarkruntime/resources/content_generation/tasks.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import httpx
55

66
from ..._types import Body, Query, Headers
7+
from ..._utils._utils import disallow_aksk
78
from ...types.content_generation.content_generation_task import ContentGenerationTask
89
from ...types.content_generation.content_generation_task_id import ContentGenerationTaskID
910
from volcenginesdkarkruntime._base_client import make_request_options
@@ -13,6 +14,8 @@
1314

1415

1516
class Tasks(SyncAPIResource):
17+
18+
@disallow_aksk
1619
def create(
1720
self,
1821
*,
@@ -39,7 +42,7 @@ def create(
3942
)
4043
return resp
4144

42-
45+
@disallow_aksk
4346
def get(
4447
self,
4548
*,
@@ -61,6 +64,7 @@ def get(
6164
)
6265
return resp
6366

67+
@disallow_aksk
6468
def list(
6569
self,
6670
page_num: int | None = None,
@@ -102,6 +106,7 @@ def list(
102106
)
103107
return resp
104108

109+
@disallow_aksk
105110
def delete(
106111
self,
107112
task_id: str,

0 commit comments

Comments
 (0)