Skip to content

Commit 30ecf48

Browse files
author
潘婉宁
committed
feat: support json schema
1 parent b033d1f commit 30ecf48

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+4746
-113
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
"httpx>=0.23.0, <1",
3030
"anyio>=3.5.0, <5",
3131
"cached-property; python_version < '3.8'",
32-
"cryptography>=43.0.3, <43.0.4"
32+
"cryptography>=43.0.3, <43.0.4",
33+
"jiter>=0.4.0, <1"
3334
]
3435
},
3536
)

volcenginesdkarkruntime/_base_client.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@
4242
ArkAPIResponseValidationError,
4343
)
4444
from ._models import construct_type
45+
from ._request_options import RequestOptions, ExtraRequestOptions
4546
from ._response import ArkAPIResponse, ArkAsyncAPIResponse
4647
from ._streaming import SSEDecoder, SSEBytesDecoder, Stream, AsyncStream
47-
from ._types import ResponseT, NotGiven, NOT_GIVEN
48-
from ._request_options import RequestOptions, ExtraRequestOptions
49-
from ._utils._utils import _gen_request_id
48+
from ._types import ResponseT, NotGiven, NOT_GIVEN, PostParser
49+
from ._utils._utils import _gen_request_id, is_given
5050

5151
_T = TypeVar("_T")
5252
_StreamT = TypeVar("_StreamT", bound=Stream[Any])
@@ -90,6 +90,7 @@ def make_request_options(
9090
extra_query: Dict[str, Any] | None = None,
9191
extra_body: Dict[str, Any] | None = None,
9292
timeout: float | httpx.Timeout | None = None,
93+
post_parser: PostParser | NotGiven = NOT_GIVEN,
9394
) -> ExtraRequestOptions:
9495
options: ExtraRequestOptions = {}
9596
if extra_headers is not None:
@@ -107,6 +108,10 @@ def make_request_options(
107108
if timeout:
108109
options["timeout"] = timeout
109110

111+
if is_given(post_parser):
112+
# internal
113+
options["post_parser"] = post_parser # type: ignore
114+
110115
return options
111116

112117

@@ -524,6 +529,7 @@ def _process_response(
524529
self,
525530
*,
526531
cast_to: Type[ResponseT],
532+
options: RequestOptions,
527533
response: httpx.Response,
528534
stream: bool,
529535
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
@@ -537,6 +543,7 @@ def _process_response(
537543
cast_to=cast("type[ResponseT]", cast_to), # pyright: ignore[reportUnnecessaryCast]
538544
stream=stream,
539545
stream_cls=stream_cls,
546+
options=options,
540547
)
541548
if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
542549
return cast(ResponseT, api_response)
@@ -557,6 +564,7 @@ def post(
557564
opts = RequestOptions.construct( # type: ignore
558565
method="post",
559566
url=path,
567+
files=files,
560568
body=body,
561569
**options,
562570
)
@@ -618,6 +626,7 @@ def post_without_retry(
618626
method="post",
619627
url=path,
620628
body=body,
629+
files=files,
621630
**options,
622631
)
623632

@@ -745,6 +754,7 @@ async def post(
745754
method="post",
746755
url=path,
747756
body=body,
757+
files=files,
748758
**options,
749759
)
750760

@@ -801,6 +811,7 @@ async def post_without_retry(
801811
method="post",
802812
url=path,
803813
body=body,
814+
files=files,
804815
**options,
805816
)
806817

@@ -910,6 +921,7 @@ async def _request(
910921
response=response,
911922
stream=stream,
912923
stream_cls=stream_cls,
924+
options=options,
913925
)
914926

915927
async def _retry_request(
@@ -945,6 +957,7 @@ async def _process_response(
945957
self,
946958
*,
947959
cast_to: Type[ResponseT],
960+
options: RequestOptions,
948961
response: httpx.Response,
949962
stream: bool,
950963
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
@@ -958,6 +971,7 @@ async def _process_response(
958971
cast_to=cast("type[ResponseT]", cast_to), # pyright: ignore[reportUnnecessaryCast]
959972
stream=stream,
960973
stream_cls=stream_cls,
974+
options=options,
961975
)
962976
if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
963977
return cast(ResponseT, api_response)

volcenginesdkarkruntime/_client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import volcenginesdkark
1717

1818
from . import resources
19+
from .resources.beta import beta
1920
from ._base_client import SyncAPIClient, AsyncAPIClient
2021
from ._constants import (
2122
DEFAULT_MAX_RETRIES,
@@ -35,6 +36,7 @@
3536

3637

3738
class Ark(SyncAPIClient):
39+
beta: beta.Beta
3840
chat: resources.Chat
3941
bot_chat: resources.BotChat
4042
embeddings: resources.Embeddings
@@ -100,6 +102,7 @@ def __init__(
100102
self._sts_token_manager: StsTokenManager | None = None
101103
self._certificate_manager: E2ECertificateManager | None = None
102104

105+
self.beta = beta.Beta(self)
103106
self.chat = resources.Chat(self)
104107
self.bot_chat = resources.BotChat(self)
105108
self.embeddings = resources.Embeddings(self)
@@ -155,6 +158,7 @@ def get_model_breaker(self, model_name: str) -> ModelBreaker:
155158

156159

157160
class AsyncArk(AsyncAPIClient):
161+
beta: beta.AsyncBeta
158162
chat: resources.AsyncChat
159163
bot_chat: resources.AsyncBotChat
160164
embeddings: resources.AsyncEmbeddings
@@ -220,6 +224,7 @@ def __init__(
220224
self._sts_token_manager: StsTokenManager | None = None
221225
self._certificate_manager: E2ECertificateManager | None = None
222226

227+
self.beta = beta.AsyncBeta(self)
223228
self.chat = resources.AsyncChat(self)
224229
self.bot_chat = resources.AsyncBotChat(self)
225230
self.embeddings = resources.AsyncEmbeddings(self)

volcenginesdkarkruntime/_compat.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
44
from datetime import date, datetime
5-
from typing_extensions import Self
6-
5+
from typing_extensions import Self, Literal
76
import pydantic
87
from pydantic.fields import FieldInfo
8+
from ._types import IncEx
99

1010
_T = TypeVar("_T")
1111
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
@@ -68,7 +68,6 @@ def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
6868
parse_datetime as parse_datetime,
6969
)
7070

71-
7271
# refactored config
7372
if TYPE_CHECKING:
7473
from pydantic import ConfigDict as ConfigDict
@@ -138,17 +137,25 @@ def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
138137
def model_dump(
139138
model: pydantic.BaseModel,
140139
*,
140+
exclude: IncEx | None = None,
141141
exclude_unset: bool = False,
142142
exclude_defaults: bool = False,
143+
warnings: bool = True,
144+
mode: Literal["json", "python"] = "python",
143145
) -> dict[str, Any]:
144-
if PYDANTIC_V2:
146+
if PYDANTIC_V2 or hasattr(model, "model_dump"):
145147
return model.model_dump(
148+
mode=mode,
149+
exclude=exclude,
146150
exclude_unset=exclude_unset,
147151
exclude_defaults=exclude_defaults,
152+
# warnings are not supported in Pydantic v1
153+
warnings=warnings if PYDANTIC_V2 else True,
148154
)
149155
return cast(
150156
"dict[str, Any]",
151157
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
158+
exclude=exclude,
152159
exclude_unset=exclude_unset,
153160
exclude_defaults=exclude_defaults,
154161
),
@@ -161,6 +168,18 @@ def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
161168
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
162169

163170

171+
def model_parse_json(model: type[_ModelT], data: str | bytes) -> _ModelT:
172+
if PYDANTIC_V2:
173+
return model.model_validate_json(data)
174+
return model.parse_raw(data) # pyright: ignore[reportDeprecated]
175+
176+
177+
def model_json_schema(model: type[_ModelT]) -> dict[str, Any]:
178+
if PYDANTIC_V2:
179+
return model.model_json_schema()
180+
return model.schema() # pyrigh
181+
182+
164183
# generic models
165184
if TYPE_CHECKING:
166185

volcenginesdkarkruntime/_exceptions.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
from __future__ import annotations
44

5-
from typing import Optional
5+
from typing import TYPE_CHECKING, Optional
66
from typing_extensions import Literal
77

88
import httpx
99

10+
if TYPE_CHECKING:
11+
from .types.chat import ChatCompletion
12+
1013
__all__ = [
1114
"ArkBadRequestError",
1215
"ArkAuthenticationError",
@@ -159,3 +162,27 @@ class ArkRateLimitError(ArkAPIStatusError):
159162

160163
class ArkInternalServerError(ArkAPIStatusError):
161164
pass
165+
166+
167+
class ArkLengthFinishReasonError(ArkAPIStatusError):
168+
completion: ChatCompletion
169+
"""The completion that caused this error.
170+
171+
Note: this will *not* be a complete `ChatCompletion` object when streaming as `usage`
172+
will not be included.
173+
"""
174+
175+
def __init__(self, *, completion: ChatCompletion) -> None:
176+
msg = "Could not parse response content as the length limit was reached"
177+
if completion.usage:
178+
msg += f" - {completion.usage}"
179+
180+
super().__init__(msg)
181+
self.completion = completion
182+
183+
184+
class ArkContentFilterFinishReasonError(ArkAPIStatusError):
185+
def __init__(self) -> None:
186+
super().__init__(
187+
"Could not parse response content as the request was rejected by the content filter",
188+
)

0 commit comments

Comments
 (0)