Skip to content

Commit dd3a034

Browse files
author
hexiaochun
committed
fix: add embeddings; misc
1 parent d3b2ec3 commit dd3a034

21 files changed

+569
-115
lines changed

volcenginesdkarkruntime/_base_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def __init__(self, **kwargs: Any) -> None:
4949
kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
5050
kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
5151
kwargs.setdefault("follow_redirects", True)
52-
kwargs.setdefault("base_url", BASE_URL)
5352
super().__init__(**kwargs)
5453

5554

volcenginesdkarkruntime/_client.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
ak: str | None = None,
4040
sk: str | None = None,
4141
api_key: str | None = None,
42+
region: str = "cn-beijing",
4243
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
4344
max_retries: int = DEFAULT_MAX_RETRIES,
4445
http_client: Client | None = None,
@@ -66,6 +67,7 @@ def __init__(
6667
self.ak = ak
6768
self.sk = sk
6869
self.api_key = api_key
70+
self.region = region
6971

7072
assert (api_key is not None) or (ak is not None and sk is not None), "you need to support api_key or ak&sk"
7173

@@ -81,12 +83,15 @@ def __init__(
8183
self._sts_token_manager: StsTokenManager | None = None
8284

8385
self.chat = resources.Chat(self)
86+
self.embeddings = resources.Embeddings(self)
87+
# self.tokenization = resources.Tokenization(self)
88+
# self.classification = resources.Classification(self)
8489

8590
def _get_endpoint_sts_token(self, endpoint_id: str):
8691
if self._sts_token_manager is None:
8792
if self.ak is None or self.sk is None:
8893
raise ArkAPIError("must set ak and sk before get endpoint token.")
89-
self._sts_token_manager = StsTokenManager(self.ak, self.sk)
94+
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
9095
return self._sts_token_manager.get(endpoint_id)
9196

9297
@property
@@ -105,6 +110,7 @@ def __init__(
105110
sk: str | None = None,
106111
api_key: str | None = None,
107112
base_url: str | URL = BASE_URL,
113+
region: str = "cn-beijing",
108114
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
109115
max_retries: int = DEFAULT_MAX_RETRIES,
110116
http_client: AsyncClient | None = None,
@@ -131,6 +137,7 @@ def __init__(
131137
self.ak = ak
132138
self.sk = sk
133139
self.api_key = api_key
140+
self.region = region
134141

135142
assert (api_key is not None) or (ak is not None and sk is not None), "you need to support api_key or ak&sk"
136143

@@ -146,12 +153,15 @@ def __init__(
146153
self._sts_token_manager: StsTokenManager | None = None
147154

148155
self.chat = resources.AsyncChat(self)
156+
self.embeddings = resources.AsyncEmbeddings(self)
157+
# self.tokenization = resources.AsyncTokenization(self)
158+
# self.classification = resources.AsyncClassification(self)
149159

150160
def _get_endpoint_sts_token(self, endpoint_id: str):
151161
if self._sts_token_manager is None:
152162
if self.ak is None or self.sk is None:
153163
raise ArkAPIError("must set ak and sk before get endpoint token.")
154-
self._sts_token_manager = StsTokenManager(self.ak, self.sk)
164+
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
155165
return self._sts_token_manager.get(endpoint_id)
156166

157167
@property
@@ -169,7 +179,7 @@ class StsTokenManager(object):
169179
# refreshed credentials.
170180
_mandatory_refresh_timeout: int = _DEFAULT_MANDATORY_REFRESH_TIMEOUT
171181

172-
def __init__(self, ak: str, sk: str):
182+
def __init__(self, ak: str, sk: str, region: str):
173183
self._endpoint_sts_tokens: Dict[str, Tuple[str, int]] = defaultdict(lambda: ("", 0))
174184
self._refresh_lock = threading.Lock()
175185

@@ -178,7 +188,8 @@ def __init__(self, ak: str, sk: str):
178188
configuration = volcenginesdkcore.Configuration()
179189
configuration.ak = ak
180190
configuration.sk = sk
181-
configuration.region = "cn-beijing"
191+
configuration.region = region
192+
configuration.schema = "https"
182193

183194
volcenginesdkcore.Configuration.set_default(configuration)
184195
self.api_instance = volcenginesdkark.ARKApi()
@@ -190,8 +201,8 @@ def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
190201
return self._endpoint_sts_tokens[ep][1] - time.time() < refresh_in
191202

192203
def _protected_refresh(self, ep: str, ttl: int = _DEFAULT_STS_TIMEOUT, is_mandatory: bool = False):
193-
if ttl < _DEFAULT_ADVISORY_REFRESH_TIMEOUT * 2:
194-
raise ArkAPIError("ttl should not be under {} seconds.".format(_DEFAULT_ADVISORY_REFRESH_TIMEOUT * 2))
204+
if ttl < self._advisory_refresh_timeout * 2:
205+
raise ArkAPIError("ttl should not be under {} seconds.".format(self._advisory_refresh_timeout * 2))
195206

196207
try:
197208
api_key, expired_time = self._load_api_key(

volcenginesdkarkruntime/_constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
# default timeout is 1 minute
1313
DEFAULT_TIMEOUT = httpx.Timeout(timeout=60.0, connect=60.0)
14-
DEFAULT_MAX_RETRIES = 1
14+
DEFAULT_MAX_RETRIES = 2
1515
DEFAULT_CONNECTION_LIMITS = httpx.Limits(
1616
max_connections=1000, max_keepalive_connections=100
1717
)

volcenginesdkarkruntime/_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
self.type = None
6767

6868
def __str__(self):
69-
return f"{self.message} (request_id: {self.request_id})"
69+
return f"{self.message}, request_id: {self.request_id}"
7070

7171

7272
class ArkAPIResponseValidationError(ArkAPIError):

volcenginesdkarkruntime/_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,7 @@ def __repr__(self) -> str:
6262

6363

6464
NOT_GIVEN = NotGiven()
65+
66+
Headers = Dict[str, str]
67+
Query = Dict[str, object]
68+
Body = Dict[str, object]

volcenginesdkarkruntime/_utils/_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,5 @@ def _insert_sts_token(args, kwargs):
8282
model = kwargs.get("model", "")
8383
if ark_client.api_key is None and model and model.startswith("ep-") and ark_client.ak and ark_client.sk:
8484
default_auth_header = {"Authorization": "Bearer " + ark_client._get_endpoint_sts_token(model)}
85-
kwargs["extra_headers"] = {**default_auth_header, **kwargs.get("extra_headers", {})}
85+
extra_headers = kwargs.get("extra_headers") if kwargs.get("extra_headers") else {}
86+
kwargs["extra_headers"] = {**default_auth_header, **extra_headers}
Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
11
from .chat import Chat, AsyncChat
2+
from .embeddings import Embeddings, AsyncEmbeddings
3+
from .tokenization import Tokenization, AsyncTokenization
4+
from .classification import Classification, AsyncClassification
25

3-
__all__ = ["Chat", "AsyncChat"]
6+
# Public names of the resources package. Every resource class imported above
# is listed, including the classification pair, so that the star-import
# surface matches the set of imported resources.
__all__ = [
    "Chat",
    "AsyncChat",
    "Embeddings",
    "AsyncEmbeddings",
    "Tokenization",
    "AsyncTokenization",
    "Classification",
    "AsyncClassification",
]

volcenginesdkarkruntime/resources/chat/completions.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import httpx
66
from typing_extensions import Literal
77

8+
from ..._types import Body, Query, Headers
89
from ..._utils._utils import with_sts_token, async_with_sts_token
910
from ..._base_client import make_request_options
1011
from ..._resource import SyncAPIResource, AsyncAPIResource
@@ -59,9 +60,9 @@ def create(
5960
top_logprobs: Optional[int] | None = None,
6061
top_p: Optional[float] | None = None,
6162
user: str | None = None,
62-
extra_headers: Dict[str, str] | None = None,
63-
extra_query: Dict[str, object] | None = None,
64-
extra_body: Dict[str, object] | None = None,
63+
extra_headers: Headers | None = None,
64+
extra_query: Query | None = None,
65+
extra_body: Body | None = None,
6566
timeout: float | httpx.Timeout | None = None,
6667
) -> ChatCompletion | Stream[ChatCompletionChunk]:
6768
return self._post(
@@ -127,9 +128,9 @@ async def create(
127128
top_logprobs: Optional[int] | None = None,
128129
top_p: Optional[float] | None = None,
129130
user: str | None = None,
130-
extra_headers: Dict[str, str] | None = None,
131-
extra_query: Dict[str, object] | None = None,
132-
extra_body: Dict[str, object] | None = None,
131+
extra_headers: Headers | None = None,
132+
extra_query: Query | None = None,
133+
extra_body: Body | None = None,
133134
timeout: float | httpx.Timeout | None = None,
134135
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
135136
return await self._post(
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from __future__ import annotations
2+
3+
from typing import List
4+
5+
import httpx
6+
7+
from .._base_client import (
8+
make_request_options,
9+
)
10+
from .._compat import cached_property
11+
from .._resource import SyncAPIResource, AsyncAPIResource
12+
from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper
13+
from .._types import Body, Query, Headers
14+
from .._utils._utils import with_sts_token, async_with_sts_token
15+
from ..types.create_classification_response import CreateClassificationResponse
16+
17+
__all__ = ["Classification", "AsyncClassification"]
18+
19+
20+
class Classification(SyncAPIResource):
    """Synchronous resource that posts requests to the ``/classification`` path."""

    @cached_property
    def with_raw_response(self) -> ClassificationWithRawResponse:
        """Return a ``ClassificationWithRawResponse`` built around this resource."""
        return ClassificationWithRawResponse(self)

    @with_sts_token
    def create(
        self,
        *,
        query: str,
        model: str,
        labels: List[str],
        user: str | None = None,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None = None,
    ) -> CreateClassificationResponse:
        """POST a classification request and return the parsed response.

        Args:
            query: Sent as the ``query`` field of the request body.
            model: Sent as the ``model`` field of the request body.
            labels: Sent as the ``labels`` field of the request body.
            user: Optional. When not ``None`` it is sent as the ``user``
                field of the request body; when ``None`` the field is
                omitted, so the request is unchanged for callers that
                never set it.
            extra_headers: Forwarded to ``make_request_options``.
            extra_query: Forwarded to ``make_request_options``.
            extra_body: Forwarded to ``make_request_options``.
            timeout: Forwarded to ``make_request_options``.

        Returns:
            The response cast to ``CreateClassificationResponse``.
        """
        payload = {
            "query": query,
            "model": model,
            "labels": labels,
        }
        # ``user`` used to be accepted and then silently dropped; forward it
        # only when the caller actually supplied a value.
        if user is not None:
            payload["user"] = user

        return self._post(
            "/classification",
            body=payload,
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
            ),
            cast_to=CreateClassificationResponse,
        )
55+
56+
57+
class AsyncClassification(AsyncAPIResource):
    """Asynchronous resource that posts requests to the ``/classification`` path."""

    @cached_property
    def with_raw_response(self) -> AsyncClassificationWithRawResponse:
        """Return an ``AsyncClassificationWithRawResponse`` built around this resource."""
        return AsyncClassificationWithRawResponse(self)

    @async_with_sts_token
    async def create(
        self,
        *,
        query: str,
        model: str,
        labels: List[str],
        user: str | None = None,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None = None,
    ) -> CreateClassificationResponse:
        """POST a classification request and return the parsed response.

        Args:
            query: Sent as the ``query`` field of the request body.
            model: Sent as the ``model`` field of the request body.
            labels: Sent as the ``labels`` field of the request body.
            user: Optional. When not ``None`` it is sent as the ``user``
                field of the request body; when ``None`` the field is
                omitted, so the request is unchanged for callers that
                never set it.
            extra_headers: Forwarded to ``make_request_options``.
            extra_query: Forwarded to ``make_request_options``.
            extra_body: Forwarded to ``make_request_options``.
            timeout: Forwarded to ``make_request_options``.

        Returns:
            The awaited response cast to ``CreateClassificationResponse``.
        """
        payload = {
            "query": query,
            "model": model,
            "labels": labels,
        }
        # ``user`` used to be accepted and then silently dropped; forward it
        # only when the caller actually supplied a value.
        if user is not None:
            payload["user"] = user

        return await self._post(
            "/classification",
            body=payload,
            options=make_request_options(
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                timeout=timeout,
            ),
            cast_to=CreateClassificationResponse,
        )
92+
93+
94+
class ClassificationWithRawResponse:
    """Holds a ``Classification`` resource and a wrapped version of its ``create``."""

    def __init__(self, classification: Classification) -> None:
        # Keep a reference to the resource being wrapped.
        self._classification = classification
        # ``create`` is passed through ``to_raw_response_wrapper``.
        wrapped_create = to_raw_response_wrapper(classification.create)
        self.create = wrapped_create
101+
102+
103+
class AsyncClassificationWithRawResponse:
    """Holds an ``AsyncClassification`` resource and a wrapped version of its ``create``."""

    def __init__(self, classification: AsyncClassification) -> None:
        # Keep a reference to the resource being wrapped.
        self._classification = classification
        # ``create`` is passed through ``async_to_raw_response_wrapper``.
        wrapped_create = async_to_raw_response_wrapper(classification.create)
        self.create = wrapped_create

0 commit comments

Comments
 (0)