Skip to content

Commit 2a442d5

Browse files
committed
feat: 支持使用短效apikey访问预置接入点
* feat: fix annotation warning * feat: support ep-m and model id * feat: support preset ep * feat: support preset ep * feat: support jwtoken with presetep See merge request: !942
1 parent 503fe59 commit 2a442d5

File tree

3 files changed

+39
-14
lines changed

3 files changed

+39
-14
lines changed

volcenginesdkarkruntime/_client.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
_DEFAULT_MANDATORY_REFRESH_TIMEOUT,
3838
_DEFAULT_STS_TIMEOUT,
3939
_DEFAULT_RESOURCE_TYPE,
40+
_PRESETENDPOINT_RESOURCE_TYPE,
4041
DEFAULT_TIMEOUT,
42+
_BOT_RESOURCE_TYPE,
4143
)
4244
from ._streaming import Stream
4345

@@ -137,12 +139,15 @@ def __init__(
137139
self.files = resources.Files(self)
138140
# self.classification = resources.Classification(self)
139141

140-
def _get_endpoint_sts_token(self, endpoint_id: str):
142+
def _get_endpoint_sts_token(self, endpoint_id: str, project_name: str = None):
141143
if self._sts_token_manager is None:
142144
if self.ak is None or self.sk is None:
143145
raise ArkAPIError("must set ak and sk before get endpoint token.")
144146
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
145-
return self._sts_token_manager.get(endpoint_id)
147+
resource_type: str = self.get_resource_type_by_endpoint_id(endpoint_id)
148+
if resource_type == _PRESETENDPOINT_RESOURCE_TYPE and (project_name is None or project_name.strip() == ""):
149+
raise ArkAPIError("must set project_name when get preset endpoint token.")
150+
return self._sts_token_manager.get(endpoint_id, resource_type=resource_type, project_name=project_name)
146151

147152
def _get_endpoint_certificate(
148153
self, endpoint_id: str
@@ -179,6 +184,16 @@ def get_model_breaker(self, model_name: str) -> ModelBreaker:
179184
with self.model_breaker_lock:
180185
return self.model_breaker_map[model_name]
181186

187+
def get_resource_type_by_endpoint_id(self, endpoint_id: str) -> str:
188+
if endpoint_id.startswith("ep-m-"):
189+
return _PRESETENDPOINT_RESOURCE_TYPE
190+
if endpoint_id.startswith("ep-"):
191+
return _DEFAULT_RESOURCE_TYPE
192+
if endpoint_id.startswith("bot-"):
193+
return _BOT_RESOURCE_TYPE
194+
# for model id, default to preset endpoint
195+
return _PRESETENDPOINT_RESOURCE_TYPE
196+
182197

183198
class AsyncArk(AsyncAPIClient):
184199
beta: beta.AsyncBeta
@@ -350,6 +365,7 @@ def _protected_refresh(
350365
ttl: int = _DEFAULT_STS_TIMEOUT,
351366
is_mandatory: bool = False,
352367
resource_type: str = _DEFAULT_RESOURCE_TYPE,
368+
project_name: str = None,
353369
):
354370
if ttl < self._advisory_refresh_timeout * 2:
355371
raise ArkAPIError(
@@ -360,7 +376,7 @@ def _protected_refresh(
360376

361377
try:
362378
api_key, expired_time = self._load_api_key(
363-
ep, ttl, resource_type=resource_type
379+
ep, ttl, resource_type=resource_type, project_name=project_name
364380
)
365381
self._endpoint_sts_tokens[ep] = (api_key, expired_time)
366382
except ApiException as e:
@@ -369,7 +385,7 @@ def _protected_refresh(
369385
else:
370386
logging.error("load api key cause error: e={}".format(e))
371387

372-
def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
388+
def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE, project_name: str = None):
373389
if not self._need_refresh(ep, self._advisory_refresh_timeout):
374390
return
375391

@@ -383,7 +399,7 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
383399
)
384400

385401
self._protected_refresh(
386-
ep, is_mandatory=is_mandatory_refresh, resource_type=resource_type
402+
ep, is_mandatory=is_mandatory_refresh, resource_type=resource_type, project_name=project_name
387403
)
388404
return
389405
finally:
@@ -394,24 +410,27 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
394410
return
395411

396412
self._protected_refresh(
397-
ep, is_mandatory=True, resource_type=resource_type
413+
ep, is_mandatory=True, resource_type=resource_type, project_name=project_name
398414
)
399415

400-
def get(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE) -> str:
401-
self._refresh(ep, resource_type=resource_type)
416+
def get(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE, project_name: str = None) -> str:
417+
self._refresh(ep, resource_type=resource_type, project_name=project_name)
402418
return self._endpoint_sts_tokens[ep][0]
403419

404420
def _load_api_key(
405421
self,
406422
ep: str,
407423
duration_seconds: int,
408424
resource_type: str = _DEFAULT_RESOURCE_TYPE,
425+
project_name: str = None,
409426
) -> Tuple[str, int]:
410427
get_api_key_request = volcenginesdkark.GetApiKeyRequest(
411428
duration_seconds=duration_seconds,
412429
resource_type=resource_type,
413430
resource_ids=[ep],
414431
)
432+
if project_name is not None and project_name.strip() != "":
433+
get_api_key_request.project_name = project_name
415434
resp: volcenginesdkark.GetApiKeyResponse = self.api_instance.get_api_key(
416435
get_api_key_request
417436
)

volcenginesdkarkruntime/_constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
CLIENT_REQUEST_HEADER = "X-Client-Request-Id"
1919
SERVER_REQUEST_HEADER = "X-Request-Id"
2020
ARK_E2E_ENCRYPTION_HEADER = "x-is-encrypted"
21+
ARK_APIKEY_PROJECT_NAME = "X-Project-Name"
2122

2223
DEFAULT_TIMEOUT_SECONDS = 600.0
2324
DEFAULT_CONNECT_TIMEOUT_SECONDS = 60.0
@@ -39,3 +40,5 @@
3940
_DEFAULT_STS_TIMEOUT = 7 * 24 * 60 * 60 # 7 days
4041

4142
_DEFAULT_RESOURCE_TYPE = "endpoint"
43+
_PRESETENDPOINT_RESOURCE_TYPE = "presetendpoint"
44+
_BOT_RESOURCE_TYPE = "bot"

volcenginesdkarkruntime/_utils/_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from typing_extensions import TypeGuard
3333

3434
from .._types import NotGiven, FileTypes, NotGivenOr
35+
from .._constants import ARK_APIKEY_PROJECT_NAME
3536

3637
_T = TypeVar("_T")
3738
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
@@ -444,12 +445,12 @@ def _insert_sts_token(args, kwargs):
444445
if (
445446
ark_client.api_key is None
446447
and model
447-
and model.startswith("ep-")
448+
and model.startswith("bot-")
448449
and ark_client.ak
449450
and ark_client.sk
450451
):
451452
default_auth_header = {
452-
"Authorization": "Bearer " + ark_client._get_endpoint_sts_token(model)
453+
"Authorization": "Bearer " + ark_client._get_bot_sts_token(model)
453454
}
454455
extra_headers = (
455456
kwargs.get("extra_headers") if kwargs.get("extra_headers") else {}
@@ -458,16 +459,18 @@ def _insert_sts_token(args, kwargs):
458459
elif (
459460
ark_client.api_key is None
460461
and model
461-
and model.startswith("bot-")
462462
and ark_client.ak
463463
and ark_client.sk
464464
):
465-
default_auth_header = {
466-
"Authorization": "Bearer " + ark_client._get_bot_sts_token(model)
467-
}
468465
extra_headers = (
469466
kwargs.get("extra_headers") if kwargs.get("extra_headers") else {}
470467
)
468+
project_name: str = None
469+
if extra_headers is not None and extra_headers.get(ARK_APIKEY_PROJECT_NAME, None) is not None:
470+
project_name = extra_headers[ARK_APIKEY_PROJECT_NAME]
471+
default_auth_header = {
472+
"Authorization": "Bearer " + ark_client._get_endpoint_sts_token(model, project_name)
473+
}
471474
kwargs["extra_headers"] = {**default_auth_header, **extra_headers}
472475

473476

0 commit comments

Comments
 (0)