Skip to content

Commit 35a3f67

Browse files
author
hexiaochun
committed
feat: 增加refresh
1 parent 8b5c278 commit 35a3f67

File tree

13 files changed

+400
-200
lines changed

13 files changed

+400
-200
lines changed

volcenginesdkarkruntime/_base_client.py

Lines changed: 140 additions & 138 deletions
Large diffs are not rendered by default.

volcenginesdkarkruntime/_client.py

Lines changed: 165 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
11
from __future__ import annotations
22

3+
import logging
34
import os
4-
from typing import Dict
5+
import threading
6+
import time
7+
from collections import defaultdict
58

6-
import httpx
7-
from httpx import Timeout, URL
9+
from httpx import Timeout, URL, Client, AsyncClient
10+
from typing import Dict, Tuple
811

9-
from . import resources, _exceptions
12+
from volcenginesdkcore.rest import ApiException
13+
from ._exceptions import ArkAPIError
14+
15+
import volcenginesdkark
16+
17+
from . import resources
1018
from ._base_client import SyncAPIClient, AsyncAPIClient
11-
from ._constants import DEFAULT_MAX_RETRIES, BASE_URL
12-
from ._exceptions import ArkError, ArkAPIStatusError
19+
from ._constants import (
20+
DEFAULT_MAX_RETRIES,
21+
BASE_URL,
22+
_DEFAULT_ADVISORY_REFRESH_TIMEOUT,
23+
_DEFAULT_MANDATORY_REFRESH_TIMEOUT,
24+
_DEFAULT_STS_TIMEOUT,
25+
DEFAULT_TIMEOUT
26+
)
1327
from ._streaming import Stream
1428

1529
__all__ = ["Ark", "AsyncArk"]
@@ -22,32 +36,51 @@ def __init__(
2236
self,
2337
*,
2438
base_url: str | URL = BASE_URL,
25-
api_key: str | None = None,
26-
timeout: float | Timeout | None = None,
39+
ak: str | None = None,
40+
sk: str | None = None,
41+
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
2742
max_retries: int = DEFAULT_MAX_RETRIES,
28-
default_query: Dict[str, object] | None = None,
29-
http_client: httpx.Client | None = None,
43+
http_client: Client | None = None,
3044
) -> None:
31-
if api_key is None:
32-
api_key = os.environ.get("ARK_API_KEY")
33-
self.api_key = api_key
45+
"""init ark client, this client is thread unsafe. If need to use in multi thread, init a new `Ark` client in
46+
each thread
47+
48+
Args:
49+
ak: access key id
50+
sk: secret access key
51+
timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
52+
max_retries: times of retry when request failed. default 1
53+
http_client: specify customized http_client
54+
Returns:
55+
ark client
56+
"""
57+
58+
if ak is None:
59+
ak = os.environ.get("VOLC_ACCESSKEY")
60+
if sk is None:
61+
sk = os.environ.get("VOLC_SECRETKEY")
62+
self.ak = ak
63+
self.sk = sk
3464

3565
super().__init__(
3666
base_url=base_url,
3767
max_retries=max_retries,
3868
timeout=timeout,
3969
http_client=http_client,
40-
custom_query=default_query,
70+
custom_query=None,
4171
)
4272

4373
self._default_stream_cls = Stream
74+
self._sts_token_manager: StsTokenManager | None = None
4475

4576
self.chat = resources.Chat(self)
4677

47-
@property
48-
def auth_headers(self) -> dict[str, str]:
49-
api_key = self.api_key
50-
return {"Authorization": f"Bearer {api_key}"}
78+
def _get_endpoint_sts_token(self, endpoint_id: str):
79+
if self._sts_token_manager is None:
80+
if self.ak is None or self.sk is None:
81+
raise ArkAPIError("must set ak and sk before get endpoint token.")
82+
self._sts_token_manager = StsTokenManager(self.ak, self.sk)
83+
return self._sts_token_manager.get(endpoint_id)
5184

5285

5386
class AsyncArk(AsyncAPIClient):
@@ -56,31 +89,133 @@ class AsyncArk(AsyncAPIClient):
5689
def __init__(
5790
self,
5891
*,
92+
ak: str | None = None,
93+
sk: str | None = None,
5994
base_url: str | URL = BASE_URL,
60-
api_key: str | None = None,
61-
timeout: float | Timeout | None = None,
95+
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
6296
max_retries: int = DEFAULT_MAX_RETRIES,
63-
default_query: Dict[str, object] | None = None,
64-
http_client: httpx.Client | None = None,
97+
http_client: AsyncClient | None = None,
6598
) -> None:
66-
if api_key is None:
67-
api_key = os.environ.get("ARK_API_KEY")
68-
self.api_key = api_key
99+
"""init async ark client, this client is thread unsafe
100+
101+
Args:
102+
ak: access key id
103+
sk: secret access key
104+
timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
105+
max_retries: times of retry when request failed. default 1
106+
http_client: specify customized http_client
107+
Returns:
108+
async ark client
109+
"""
110+
111+
if ak is None:
112+
ak = os.environ.get("VOLC_ACCESSKEY")
113+
if sk is None:
114+
sk = os.environ.get("VOLC_SECRETKEY")
115+
self.ak = ak
116+
self.sk = sk
69117

70118
super().__init__(
71119
base_url=base_url,
72120
max_retries=max_retries,
73121
timeout=timeout,
74122
http_client=http_client,
75-
custom_query=default_query,
123+
custom_query=None,
76124
)
77125

78126
self._default_stream_cls = Stream
127+
self._sts_token_manager: StsTokenManager | None = None
79128

80129
self.chat = resources.AsyncChat(self)
81130

131+
def _get_endpoint_sts_token(self, endpoint_id: str):
132+
if self._sts_token_manager is None:
133+
if self.ak is None or self.sk is None:
134+
raise ArkAPIError("must set ak and sk before get endpoint token.")
135+
self._sts_token_manager = StsTokenManager(self.ak, self.sk)
136+
return self._sts_token_manager.get(endpoint_id)
137+
138+
139+
class StsTokenManager(object):
140+
141+
# The time at which we'll attempt to refresh, but not
142+
# block if someone else is refreshing.
143+
_advisory_refresh_timeout: int = _DEFAULT_ADVISORY_REFRESH_TIMEOUT
144+
# The time at which all threads will block waiting for
145+
# refreshed credentials.
146+
_mandatory_refresh_timeout: int = _DEFAULT_MANDATORY_REFRESH_TIMEOUT
147+
148+
def __init__(self, ak: str, sk: str):
149+
self._endpoint_sts_tokens: Dict[str, Tuple[str, int]] = defaultdict(lambda: ("", 0))
150+
self._refresh_lock = threading.Lock()
151+
152+
import volcenginesdkcore
153+
154+
configuration = volcenginesdkcore.Configuration()
155+
configuration.ak = ak
156+
configuration.sk = sk
157+
configuration.region = "cn-beijing"
158+
159+
volcenginesdkcore.Configuration.set_default(configuration)
160+
self.api_instance = volcenginesdkark.ARKApi()
161+
162+
def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
163+
if refresh_in is None:
164+
refresh_in = self._advisory_refresh_timeout
165+
166+
return self._endpoint_sts_tokens[ep][1] - time.time() < refresh_in
167+
168+
def _protected_refresh(self, ep: str, ttl: int = _DEFAULT_STS_TIMEOUT, is_mandatory: bool = False):
169+
if ttl < _DEFAULT_ADVISORY_REFRESH_TIMEOUT * 2:
170+
raise ArkAPIError("ttl should not be under {} seconds.".format(_DEFAULT_ADVISORY_REFRESH_TIMEOUT * 2))
171+
172+
try:
173+
api_key, expired_time = self._load_api_key(
174+
ep, ttl
175+
)
176+
self._endpoint_sts_tokens[ep] = (api_key, expired_time)
177+
except ApiException as e:
178+
if is_mandatory:
179+
raise ArkAPIError("load api key cause error: e={}".format(e))
180+
else:
181+
logging.error("load api key cause error: e={}".format(e))
182+
183+
def _refresh(self, ep: str):
184+
if not self._need_refresh(ep, self._advisory_refresh_timeout):
185+
return
186+
187+
if self._refresh_lock.acquire(False):
188+
if not self._need_refresh(ep, self._advisory_refresh_timeout):
189+
return
190+
191+
try:
192+
is_mandatory_refresh = self._need_refresh(
193+
ep, self._mandatory_refresh_timeout
194+
)
195+
196+
self._protected_refresh(ep, is_mandatory=is_mandatory_refresh)
197+
return
198+
finally:
199+
self._refresh_lock.release()
200+
elif self._need_refresh(ep, self._mandatory_refresh_timeout):
201+
with self._refresh_lock:
202+
if not self._need_refresh(ep, self._mandatory_refresh_timeout):
203+
return
204+
205+
self._protected_refresh(ep, is_mandatory=True)
206+
207+
def get(self, ep: str) -> str:
208+
self._refresh(ep)
209+
return self._endpoint_sts_tokens[ep][0]
210+
211+
def _load_api_key(self, ep: str, duration_seconds: int) -> Tuple[str, int]:
212+
get_api_key_request = volcenginesdkark.GetApiKeyRequest(
213+
duration_seconds=duration_seconds,
214+
resource_type="endpoint",
215+
resource_ids=[ep],
216+
)
217+
resp: volcenginesdkark.GetApiKeyResponse = self.api_instance.get_api_key(
218+
get_api_key_request
219+
)
82220

83-
@property
84-
def auth_headers(self) -> dict[str, str]:
85-
api_key = self.api_key
86-
return {"Authorization": f"Bearer {api_key}"}
221+
return resp.api_key, resp.expired_time

volcenginesdkarkruntime/_constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,7 @@
1818

1919
INITIAL_RETRY_DELAY = 0.5
2020
MAX_RETRY_DELAY = 8.0
21+
22+
_DEFAULT_MANDATORY_REFRESH_TIMEOUT = 10 * 60 # 10 min
23+
_DEFAULT_ADVISORY_REFRESH_TIMEOUT = 30 * 60 # 30 min
24+
_DEFAULT_STS_TIMEOUT = 7 * 24 * 60 * 60 # 7 days

volcenginesdkarkruntime/_exceptions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class ArkError(Exception):
2525

2626
class ArkAPIError(ArkError):
2727
message: str
28-
request: httpx.Request
29-
body: object | None
30-
request_id: str
28+
request: Optional[httpx.Request] = None
29+
body: Optional[object] = None
30+
request_id: Optional[str] = None
3131
"""The API response body.
3232
3333
If the API responded with a valid JSON structure then this property will be the
@@ -45,10 +45,10 @@ class ArkAPIError(ArkError):
4545
def __init__(
4646
self,
4747
message: str,
48-
request: httpx.Request,
48+
request: Optional[httpx.Request] = None,
4949
*,
50-
body: object | None,
51-
request_id: str,
50+
body: Optional[object] = None,
51+
request_id: Optional[str] = None,
5252
) -> None:
5353
super().__init__(message)
5454
self.request = request

volcenginesdkarkruntime/_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pydantic
99
import pydantic.generics
1010
from pydantic.fields import FieldInfo
11-
from pydantic.v1.typing import get_origin
1211
from typing_extensions import (
1312
Literal,
1413
ClassVar,
@@ -28,6 +27,7 @@
2827
is_literal_type,
2928
GenericModel as BaseGenericModel,
3029
parse_obj,
30+
get_origin,
3131
ConfigDict,
3232
)
3333
from ._types import ModelT
@@ -37,7 +37,7 @@
3737
from ._utils._utils import is_mapping, is_list, coerce_boolean, lru_cache
3838

3939
if TYPE_CHECKING:
40-
from pydantic_core.core_schema import ModelField
40+
from pydantic_core.core_schema import ModelField, ModelFieldsSchema
4141

4242
__all__ = ["BaseModel", "GenericModel"]
4343

volcenginesdkarkruntime/_response.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,21 @@
1919
)
2020

2121
import httpx
22-
import pydantic
2322
from typing_extensions import ParamSpec, override, get_origin
2423

25-
from ._constants import CLIENT_REQUEST_HEADER, RAW_RESPONSE_HEADER
26-
from ._exceptions import ArkError, ArkAPIResponseValidationError
24+
from ._constants import CLIENT_REQUEST_HEADER, RAW_RESPONSE_HEADER # type: ignore
25+
from ._exceptions import ArkError
2726
from ._streaming import Stream, AsyncStream
28-
from ._types import ResponseT
29-
from ._utils import extract_type_arg, is_annotated_type
27+
from ._utils import extract_type_arg, is_annotated_type # type: ignore
3028

3129
if TYPE_CHECKING:
3230
from ._base_client import BaseClient
3331

3432
P = ParamSpec("P")
3533
R = TypeVar("R")
3634
_T = TypeVar("_T")
37-
_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]")
38-
_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]")
35+
_APIResponseT = TypeVar("_APIResponseT", bound="ArkAPIResponse[Any]")
36+
_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="ArkAsyncAPIResponse[Any]")
3937
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])
4038

4139
log: logging.Logger = logging.getLogger(__name__)
@@ -118,7 +116,7 @@ def _parse(self) -> R:
118116
self._stream_cls(
119117
cast_to=extract_type_arg(self._stream_cls),
120118
response=self.http_response,
121-
client=self._client,
119+
client=self._client, # type: ignore
122120
),
123121
)
124122

@@ -166,7 +164,7 @@ def _parse(self) -> R:
166164

167165
class ArkAPIResponse(BaseAPIResponse[R]):
168166
@property
169-
def request_id(self) -> str | None:
167+
def request_id(self) -> str:
170168
return self.http_response.headers.get(CLIENT_REQUEST_HEADER, "")
171169

172170
def parse(self) -> R:
@@ -206,7 +204,7 @@ def iter_lines(self) -> Iterator[str]:
206204

207205
class ArkAsyncAPIResponse(BaseAPIResponse[R]):
208206
@property
209-
def request_id(self) -> str | None:
207+
def request_id(self) -> str:
210208
return self.http_response.headers.get(CLIENT_REQUEST_HEADER, "")
211209

212210
async def parse(self) -> R:

0 commit comments

Comments
 (0)