Skip to content
Merged
8 changes: 6 additions & 2 deletions src/zai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_jwt_token,
)


class BaseClient(HttpClient):
"""
Main client for interacting with the ZAI API
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
self.base_url = base_url

from ._version import __version__

super().__init__(
version=__version__,
base_url=base_url,
Expand All @@ -104,7 +106,7 @@ def __init__(

@property
def default_base_url(self):
raise NotImplementedError("Subclasses must define default_base_url")
raise NotImplementedError('Subclasses must define default_base_url')

@cached_property
def chat(self) -> Chat:
Expand Down Expand Up @@ -204,12 +206,14 @@ def __del__(self) -> None:

self.close()


class ZaiClient(BaseClient):
@property
def default_base_url(self):
return 'https://api.z.ai/api/paas/v4'


class ZhipuAiClient(BaseClient):
@property
def default_base_url(self):
return 'https://open.bigmodel.cn/api/paas/v4'
return 'https://open.bigmodel.cn/api/paas/v4'
7 changes: 4 additions & 3 deletions src/zai/api_resource/audio/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
class Audio(BaseAPI):
"""
API resource for audio operations

Attributes:
transcriptions (Transcriptions): Audio transcription operations
"""

@cached_property
def transcriptions(self) -> Transcriptions:
return Transcriptions(self._client)
Expand All @@ -57,7 +58,7 @@ def speech(
) -> HttpxBinaryResponseContent:
"""
Generate speech audio from text input

Arguments:
model (str): The model to use for speech generation
input (str): The text to convert to speech
Expand Down Expand Up @@ -105,7 +106,7 @@ def customization(
) -> HttpxBinaryResponseContent:
"""
Generate customized speech audio with voice cloning

Arguments:
model (str): The model to use for speech generation
input (str): The text to convert to speech
Expand Down
4 changes: 2 additions & 2 deletions src/zai/api_resource/audio/transcriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from zai._client import ZaiClient



class Transcriptions(BaseAPI):
"""
API resource for audio transcription operations
"""

def __init__(self, client: 'ZaiClient') -> None:
super().__init__(client)

Expand All @@ -54,7 +54,7 @@ def create(
) -> Completion | StreamResponse[ChatCompletionChunk]:
"""
Transcribe audio files to text

Arguments:
file (FileTypes): Audio file to transcribe
model (str): The model to use for transcription
Expand Down
2 changes: 1 addition & 1 deletion src/zai/api_resource/batch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .batches import Batches

__all__ = ["Batches"]
__all__ = ['Batches']
2 changes: 1 addition & 1 deletion src/zai/api_resource/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
'AsyncCompletions',
'Chat',
'Completions',
]
]
9 changes: 5 additions & 4 deletions src/zai/api_resource/chat/async_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class AsyncCompletions(BaseAPI):

Provides access to asynchronous chat completion operations.
"""

def __init__(self, client: 'ZaiClient') -> None:
super().__init__(client)

Expand Down Expand Up @@ -61,7 +62,7 @@ def create(
) -> AsyncTaskStatus:
"""
Create an asynchronous chat completion task

Arguments:
model (str): Model name to use for completion
request_id (Optional[str]): Request identifier
Expand Down Expand Up @@ -125,8 +126,8 @@ def create(
'tool_choice': tool_choice,
'meta': meta,
'extra': maybe_transform(extra, code_geex_params.CodeGeexExtra),
"response_format": response_format,
"thinking": thinking
'response_format': response_format,
'thinking': thinking,
}
return self._post(
'/async/chat/completions',
Expand All @@ -145,7 +146,7 @@ def retrieve_completion_result(
) -> Union[AsyncCompletion, AsyncTaskStatus]:
"""
Retrieve the result of an asynchronous chat completion task

Arguments:
id (str): The task ID to retrieve results for
extra_headers (Headers): Additional HTTP headers
Expand Down
1 change: 1 addition & 0 deletions src/zai/api_resource/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Chat(BaseAPI):

Provides access to chat completions and async completions.
"""

@cached_property
def completions(self) -> Completions:
return Completions(self._client)
Expand Down
3 changes: 2 additions & 1 deletion src/zai/api_resource/chat/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Completions(BaseAPI):
Attributes:
client (ZaiClient): The ZAI client instance
"""

def __init__(self, client: 'ZaiClient') -> None:
super().__init__(client)

Expand Down Expand Up @@ -133,7 +134,7 @@ def create(
'meta': meta,
'extra': maybe_transform(extra, code_geex_params.CodeGeexExtra),
'response_format': response_format,
"thinking": thinking
'thinking': thinking,
}
)
return self._post(
Expand Down
2 changes: 1 addition & 1 deletion src/zai/api_resource/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .embeddings import Embeddings

__all__ = ['Embeddings']
__all__ = ['Embeddings']
1 change: 1 addition & 0 deletions src/zai/api_resource/embeddings/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Embeddings(BaseAPI):
Attributes:
client (ZaiClient): The ZAI client instance
"""

def __init__(self, client: 'ZaiClient') -> None:
super().__init__(client)

Expand Down
2 changes: 1 addition & 1 deletion src/zai/api_resource/files/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .files import Files, FilesWithRawResponse

__all__ = ['Files', 'FilesWithRawResponse']
__all__ = ['Files', 'FilesWithRawResponse']
2 changes: 1 addition & 1 deletion src/zai/api_resource/images/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .images import Images

__all__ = ['Images']
__all__ = ['Images']
3 changes: 2 additions & 1 deletion src/zai/api_resource/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Images(BaseAPI):
"""
API resource for image generation operations
"""

def __init__(self, client: 'ZaiClient') -> None:
super().__init__(client)

Expand All @@ -40,7 +41,7 @@ def generations(
) -> ImagesResponded:
"""
Generate images from text prompts

Arguments:
prompt (str): Text description of the desired image
model (str): The model to use for image generation
Expand Down
3 changes: 2 additions & 1 deletion src/zai/api_resource/moderations/moderations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Moderations(BaseAPI):
"""
API resource for content moderation operations
"""

def __init__(self, client: ZaiClient) -> None:
super().__init__(client)

Expand All @@ -27,7 +28,7 @@ def create(
) -> Completion:
"""
Moderate content for safety and compliance

Arguments:
model (str): The moderation model to use
input (Union[str, List[str], Dict]): Content to moderate
Expand Down
3 changes: 2 additions & 1 deletion src/zai/api_resource/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Tools(BaseAPI):

Provides access to various tool operations including web search.
"""

def __init__(self, client: 'ZaiClient') -> None:
super().__init__(client)

Expand All @@ -50,7 +51,7 @@ def web_search(
) -> WebSearch | StreamResponse[WebSearchChunk]:
"""
Perform web search using AI models

Arguments:
model (str): The model to use for web search
request_id (Optional[str]): Unique identifier for the request
Expand Down
2 changes: 1 addition & 1 deletion src/zai/api_resource/videos/videos.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, List
from typing import TYPE_CHECKING, List, Optional

import httpx

Expand Down
1 change: 1 addition & 0 deletions src/zai/core/_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class BaseAPI:
Attributes:
_client (ZaiClient): The client instance for making API requests
"""

_client: ZaiClient

def __init__(self, client: ZaiClient) -> None:
Expand Down
4 changes: 1 addition & 3 deletions src/zai/core/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def __init__(


class APIConnectionError(APIResponseError):
def __init__(
self, *, message: str = 'Connection error.', request: httpx.Request
) -> None:
def __init__(self, *, message: str = 'Connection error.', request: httpx.Request) -> None:
super().__init__(message, request, json_data=None)


Expand Down
19 changes: 4 additions & 15 deletions src/zai/core/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,13 @@

def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes)
or isinstance(obj, tuple)
or isinstance(obj, io.IOBase)
or isinstance(obj, os.PathLike)
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
)


def assert_is_file_content(obj: object, *, key: str | None = None) -> None:
if not is_file_content(obj):
prefix = (
f'Expected entry at `{key}`'
if key is not None
else f'Expected file input `{obj!r}`'
)
prefix = f'Expected entry at `{key}`' if key is not None else f'Expected file input `{obj!r}`'
raise RuntimeError(
f'{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead. See https://github.com/openai/openai-python/tree/main#file-uploads'
) from None
Expand All @@ -56,9 +49,7 @@ def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
elif is_sequence_t(files):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(
f'Unexpected file type input {type(files)}, expected mapping or sequence'
)
raise TypeError(f'Unexpected file type input {type(files)}, expected mapping or sequence')

return files

Expand All @@ -74,9 +65,7 @@ def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_tuple_t(file):
return (file[0], _read_file_content(file[1]), *file[2:])

raise TypeError(
'Expected file types input to be a FileContent type or to be a tuple'
)
raise TypeError('Expected file types input to be a FileContent type or to be a tuple')


def _read_file_content(file: FileContent) -> HttpxFileContent:
Expand Down
2 changes: 1 addition & 1 deletion src/zai/core/_jwt_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def generate_token(apikey: str):
try:
api_key, secret = apikey.split('.')
except Exception as e:
raise Exception('invalid api_key', e)
raise Exception('Invalid API key', e)

payload = {
'api_key': api_key,
Expand Down
11 changes: 3 additions & 8 deletions src/zai/core/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class LoggerNameFilter(logging.Filter):

Currently allows all log records to pass through.
"""

def filter(self, record):
"""
Determine if the specified record is to be logged.
Expand Down Expand Up @@ -38,15 +39,9 @@ def get_log_file(log_path: str, sub_dir: str):
return os.path.join(log_dir, 'zai.log')


def get_config_dict(
log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int
) -> dict:
def get_config_dict(log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int) -> dict:
# for windows, the path should be a raw string.
log_file_path = (
log_file_path.encode('unicode-escape').decode()
if os.name == 'nt'
else log_file_path
)
log_file_path = log_file_path.encode('unicode-escape').decode() if os.name == 'nt' else log_file_path
log_level = log_level.upper()
config_dict = {
'version': 1,
Expand Down
16 changes: 8 additions & 8 deletions src/zai/types/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from .chat_completions_create_param import Reference

__all__ = [
'AgentsCompletion',
'AgentsCompletionUsage',
'AgentsCompletionChunkUsage',
'AgentsCompletionChunk',
'AgentsChoice',
'AgentsChoiceDelta',
'AgentsError',
'Reference',
'AgentsCompletion',
'AgentsCompletionUsage',
'AgentsCompletionChunkUsage',
'AgentsCompletionChunk',
'AgentsChoice',
'AgentsChoiceDelta',
'AgentsError',
'Reference',
]
Loading