Skip to content

Commit 92cadb0

Browse files
author
BitsAdmin
committed
Merge branch 'feat/add_file_sdk' into 'integration_2025-11-13_1083916009218'
feat: [development task] ark runtime (1828603) See merge request iaasng/volcengine-python-sdk!918
2 parents 084685a + 9428ca3 commit 92cadb0

File tree

226 files changed

+1563
-408
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

226 files changed

+1563
-408
lines changed

volcenginesdkarkruntime/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) [2025] [OpenAI]
32
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
43
# SPDX-License-Identifier: Apache-2.0

volcenginesdkarkruntime/_base_client.py

Lines changed: 149 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) [2025] [OpenAI]
32
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
43
# SPDX-License-Identifier: Apache-2.0
@@ -19,19 +18,20 @@
1918
from random import random
2019
from types import TracebackType
2120
from typing import (
22-
Type,
23-
Dict,
24-
TypeVar,
25-
Any,
26-
Optional,
27-
cast,
2821
TYPE_CHECKING,
22+
Any,
23+
Dict,
24+
Type,
2925
Union,
3026
Generic,
27+
Mapping,
28+
TypeVar,
3129
Iterable,
32-
AsyncIterator,
3330
Iterator,
34-
Generator
31+
Optional,
32+
Generator,
33+
AsyncIterator,
34+
cast,
3535
)
3636
from typing_extensions import override
3737

@@ -44,6 +44,7 @@
4444
from httpx._types import RequestFiles
4545

4646
from . import _exceptions # type: ignore
47+
from ._qs import Querystring
4748
from ._constants import (
4849
DEFAULT_MAX_RETRIES,
4950
DEFAULT_TIMEOUT,
@@ -60,6 +61,7 @@
6061
ArkAPIStatusError,
6162
ArkAPIResponseValidationError,
6263
)
64+
from ._files import to_httpx_files, async_to_httpx_files
6365
from ._models import construct_type, GenericModel
6466
from ._request_options import RequestOptions, ExtraRequestOptions
6567
from ._response import ArkAPIResponse, ArkAsyncAPIResponse
@@ -71,14 +73,17 @@
7173
PostParser,
7274
Body,
7375
Query,
76+
HttpxRequestFiles,
7477
)
75-
from ._utils._utils import _gen_request_id, is_given, is_mapping
76-
from ._compat import model_copy, PYDANTIC_V2
78+
from ._utils._utils import _gen_request_id, is_given, is_mapping, is_dict, is_list
79+
from ._compat import model_copy, PYDANTIC_V2, model_dump
7780

7881
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
7982
AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")
8083

8184
_T = TypeVar("_T")
85+
_T_co = TypeVar("_T_co", covariant=True)
86+
8287
_StreamT = TypeVar("_StreamT", bound=Stream[Any])
8388
_AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any])
8489

@@ -177,6 +182,10 @@ def __init__(
177182
"max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number`"
178183
)
179184

185+
@property
186+
def qs(self) -> Querystring:
187+
return Querystring()
188+
180189
@property
181190
def auth_headers(self) -> dict[str, str]:
182191
return {}
@@ -219,9 +228,13 @@ def _should_stream_response_body(self, request: httpx.Request) -> bool:
219228
def _build_request(
220229
self,
221230
options: RequestOptions,
231+
*,
232+
retries_taken: int = 0,
222233
) -> httpx.Request:
223234
if log.isEnabledFor(logging.DEBUG):
224-
log.debug("Request options: %s", options.model_dump(exclude_unset=True))
235+
log.debug("Request options: %s", model_dump(options, exclude_unset=True))
236+
237+
kwargs: dict[str, Any] = {}
225238

226239
body = options.body
227240
if options.extra_body is not None:
@@ -236,16 +249,105 @@ def _build_request(
236249

237250
headers = self._build_headers(options)
238251
params = options.params
252+
content_type = headers.get("Content-Type")
253+
files = options.files
254+
255+
# If the given Content-Type header is multipart/form-data then it
256+
# has to be removed so that httpx can generate the header with
257+
# additional information for us as it has to be in this form
258+
# for the server to be able to correctly parse the request:
259+
# multipart/form-data; boundary=---abc--
260+
if content_type is not None and content_type.startswith("multipart/form-data"):
261+
if "boundary" not in content_type:
262+
# only remove the header if the boundary hasn't been explicitly set
263+
# as the caller doesn't want httpx to come up with their own boundary
264+
headers.pop("Content-Type")
265+
266+
# As we are now sending multipart/form-data instead of application/json
267+
# we need to tell httpx to use it, https://www.python-httpx.org/advanced/clients/#multipart-file-encoding
268+
if body:
269+
if not is_dict(body):
270+
raise TypeError(
271+
f"Expected query input to be a dictionary for multipart requests but got {type(body)} instead."
272+
)
273+
kwargs["data"] = self._serialize_multipartform(body)
274+
275+
# httpx determines whether or not to send a "multipart/form-data"
276+
# request based on the truthiness of the "files" argument.
277+
# This gets around that issue by generating a dict value that
278+
# evaluates to true.
279+
#
280+
# https://github.com/encode/httpx/discussions/2399#discussioncomment-3814186
281+
if not files:
282+
files = cast(HttpxRequestFiles, ForceMultipartDict())
283+
284+
prepared_url = self._prepare_url(options.url)
285+
if "_" in prepared_url.host:
286+
# work around https://github.com/encode/httpx/discussions/2880
287+
kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")}
288+
289+
is_body_allowed = options.method.lower() != "get"
290+
291+
if is_body_allowed:
292+
if isinstance(body, bytes):
293+
kwargs["content"] = body
294+
else:
295+
kwargs["json"] = body if is_given(body) else None
296+
kwargs["files"] = files
297+
else:
298+
headers.pop("Content-Type", None)
299+
kwargs.pop("data", None)
239300

301+
# TODO: report this error to httpx
240302
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
241303
headers=headers,
242-
timeout=options.timeout if options.timeout else self.timeout,
304+
timeout=self.timeout
305+
if isinstance(options.timeout, NotGiven)
306+
else options.timeout,
243307
method=options.method,
244-
url=self._prepare_url(options.url),
245-
params=params, # type: ignore
246-
json=body,
308+
url=prepared_url,
309+
# the `Query` type that we use is incompatible with qs'
310+
# `Params` type as it needs to be typed as `Mapping[str, object]`
311+
# so that passing a `TypedDict` doesn't cause an error.
312+
# https://github.com/microsoft/pyright/issues/3526#event-6715453066
313+
params=self.qs.stringify(cast(Mapping[str, Any], params))
314+
if params
315+
else None,
316+
**kwargs,
247317
)
248318

319+
def _serialize_multipartform(
320+
self, data: Mapping[object, object]
321+
) -> dict[str, object]:
322+
items = self.qs.stringify_items(
323+
# TODO: type ignore is required as stringify_items is well typed but we can't be
324+
# well typed without heavy validation.
325+
data, # type: ignore
326+
array_format="brackets",
327+
)
328+
serialized: dict[str, object] = {}
329+
for key, value in items:
330+
existing = serialized.get(key)
331+
332+
if not existing:
333+
serialized[key] = value
334+
continue
335+
336+
# If a value has already been set for this key then that
337+
# means we're sending data like `array[]=[1, 2, 3]` and we
338+
# need to tell httpx that we want to send multiple values with
339+
# the same key which is done by using a list or a tuple.
340+
#
341+
# Note: 2d arrays should never result in the same key at both
342+
# levels so it's safe to assume that if the value is a list,
343+
# it was because we changed it to be a list.
344+
if is_list(existing):
345+
existing.append(value)
346+
else:
347+
serialized[key] = [existing, value]
348+
349+
return serialized
350+
249351
def _calculate_retry_timeout(
250352
self,
251353
remaining_retries: int,
@@ -595,7 +697,7 @@ def post(
595697
opts = RequestOptions.construct( # type: ignore
596698
method="post",
597699
url=path,
598-
files=files,
700+
files=to_httpx_files(files),
599701
body=body,
600702
**options,
601703
)
@@ -678,7 +780,9 @@ def get_api_list(
678780
options: ExtraRequestOptions = {},
679781
method: str = "get",
680782
) -> AsyncPageT:
681-
opts = RequestOptions.construct(method=method, url=path, json_data=body, **options)
783+
opts = RequestOptions.construct(
784+
method=method, url=path, json_data=body, **options
785+
)
682786
return self._request_api_list(model, page, opts)
683787

684788
def _request_api_list(
@@ -815,7 +919,7 @@ async def post(
815919
method="post",
816920
url=path,
817921
body=body,
818-
files=files,
922+
files=await async_to_httpx_files(files),
819923
**options,
820924
)
821925

@@ -890,7 +994,9 @@ async def get_api_list(
890994
options: ExtraRequestOptions = {},
891995
method: str = "get",
892996
) -> AsyncPageT:
893-
opts = RequestOptions.construct(method=method, url=path, json_data=body, **options)
997+
opts = RequestOptions.construct(
998+
method=method, url=path, json_data=body, **options
999+
)
8941000
return await self._request_api_list(model, page, opts)
8951001

8961002
async def _request_api_list(
@@ -1229,7 +1335,9 @@ def get_next_page(self: SyncPageT) -> SyncPageT:
12291335
)
12301336

12311337
options = self._info_to_options(info)
1232-
return self._client._request_api_list(self._model, page=self.__class__, options=options)
1338+
return self._client._request_api_list(
1339+
self._model, page=self.__class__, options=options
1340+
)
12331341

12341342

12351343
class AsyncPaginator(Generic[_T, AsyncPageT]):
@@ -1309,4 +1417,23 @@ async def get_next_page(self: AsyncPageT) -> AsyncPageT:
13091417
)
13101418

13111419
options = self._info_to_options(info)
1312-
return await self._client._request_api_list(self._model, page=self.__class__, options=options)
1420+
return await self._client._request_api_list(
1421+
self._model, page=self.__class__, options=options
1422+
)
1423+
1424+
1425+
class ForceMultipartDict(Dict[str, None]):
1426+
def __bool__(self) -> bool:
1427+
return True
1428+
1429+
1430+
def _merge_mappings(
1431+
obj1: Mapping[_T_co, Union[_T, None]],
1432+
obj2: Mapping[_T_co, Union[_T, None]],
1433+
) -> Dict[_T_co, _T]:
1434+
"""Merge two mappings of the same type, removing any values that are instances of `Omit`.
1435+
1436+
In cases with duplicate keys the second mapping takes precedence.
1437+
"""
1438+
merged = {**obj1, **obj2}
1439+
return {key: value for key, value in merged.items() if value is not None}

volcenginesdkarkruntime/_client.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) [2025] [OpenAI]
32
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
43
# SPDX-License-Identifier: Apache-2.0
@@ -65,6 +64,7 @@ class Ark(SyncAPIClient):
6564
batch: batch.Batch
6665
model_breaker_map: dict[str, ModelBreaker]
6766
model_breaker_lock: threading.Lock
67+
files: resources.Files
6868

6969
def __init__(
7070
self,
@@ -134,6 +134,7 @@ def __init__(
134134
self.batch = batch.Batch(self)
135135
self.model_breaker_map = defaultdict(ModelBreaker)
136136
self.model_breaker_lock = threading.Lock()
137+
self.files = resources.Files(self)
137138
# self.classification = resources.Classification(self)
138139

139140
def _get_endpoint_sts_token(self, endpoint_id: str):
@@ -143,7 +144,9 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
143144
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
144145
return self._sts_token_manager.get(endpoint_id)
145146

146-
def _get_endpoint_certificate(self, endpoint_id: str) -> Tuple[key_agreement_client, str, str, float]:
147+
def _get_endpoint_certificate(
148+
self, endpoint_id: str
149+
) -> Tuple[key_agreement_client, str, str, float]:
147150
if self._certificate_manager is None:
148151
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
149152
if (
@@ -194,6 +197,7 @@ class AsyncArk(AsyncAPIClient):
194197
batch: batch.AsyncBatch
195198
model_breaker_map: dict[str, ModelBreaker]
196199
model_breaker_lock: asyncio.Lock
200+
files: resources.AsyncFiles
197201

198202
def __init__(
199203
self,
@@ -263,6 +267,7 @@ def __init__(
263267
self.batch = batch.AsyncBatch(self)
264268
self.model_breaker_map = defaultdict(ModelBreaker)
265269
self.model_breaker_lock = asyncio.Lock()
270+
self.files = resources.AsyncFiles(self)
266271
# self.classification = resources.AsyncClassification(self)
267272

268273
def _get_endpoint_sts_token(self, endpoint_id: str):
@@ -279,7 +284,9 @@ def _get_bot_sts_token(self, bot_id: str):
279284
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
280285
return self._sts_token_manager.get(bot_id, resource_type="bot")
281286

282-
def _get_endpoint_certificate(self, endpoint_id: str) -> Tuple[key_agreement_client, str, str, float]:
287+
def _get_endpoint_certificate(
288+
self, endpoint_id: str
289+
) -> Tuple[key_agreement_client, str, str, float]:
283290
if self._certificate_manager is None:
284291
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
285292
if (
@@ -429,7 +436,9 @@ def __init__(
429436
base_url: str | URL = BASE_URL,
430437
api_key: str | None = None,
431438
):
432-
self._certificate_manager: Dict[str, Tuple[key_agreement_client, str, str, float]] = {}
439+
self._certificate_manager: Dict[
440+
str, Tuple[key_agreement_client, str, str, float]
441+
] = {}
433442

434443
# local cache prepare
435444
self._init_local_cert_cache()
@@ -542,7 +551,8 @@ def _init_local_cert_cache(self):
542551
pass
543552
except Exception as e:
544553
raise ArkAPIError(
545-
"failed to create certificate directory %s: %s\n" % (self._cert_storage_path, e)
554+
"failed to create certificate directory %s: %s\n"
555+
% (self._cert_storage_path, e)
546556
)
547557

548558
def get(self, ep: str) -> Tuple[key_agreement_client, str, str, float]:
@@ -558,9 +568,7 @@ def get(self, ep: str) -> Tuple[key_agreement_client, str, str, float]:
558568
self._save_cert_to_file(ep, cert_pem)
559569
ring, key, exp_time = get_cert_info(cert_pem)
560570
self._certificate_manager[ep] = (
561-
key_agreement_client(
562-
certificate_pem_string=cert_pem
563-
),
571+
key_agreement_client(certificate_pem_string=cert_pem),
564572
ring,
565573
key,
566574
exp_time,

volcenginesdkarkruntime/_compat.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) [2025] [OpenAI]
32
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
43
# SPDX-License-Identifier: Apache-2.0

volcenginesdkarkruntime/_constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) [2025] [OpenAI]
32
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
43
# SPDX-License-Identifier: Apache-2.0

volcenginesdkarkruntime/_exceptions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) [2025] [OpenAI]
32
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
43
# SPDX-License-Identifier: Apache-2.0

volcenginesdkarkruntime/_files.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# Copyright (c) [2025] [OpenAI]
32
# Copyright (c) [2025] [ByteDance Ltd. and/or its affiliates.]
43
# SPDX-License-Identifier: Apache-2.0

0 commit comments

Comments
 (0)