Skip to content

Commit 1905366

Browse files
feat: ark image generation sdk
1 parent b27f80a commit 1905366

File tree

8 files changed

+194
-1
lines changed

8 files changed

+194
-1
lines changed

volcenginesdkarkruntime/_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class Ark(SyncAPIClient):
4242
context: resources.Context
4343
multimodal_embeddings: resources.MultimodalEmbeddings
4444
content_generation: resources.ContentGeneration
45+
images: resources.Images
4546
batch_chat: resources.BatchChat
4647
model_breaker_map: dict[str, ModelBreaker]
4748
model_breaker_lock: threading.Lock
@@ -71,7 +72,6 @@ def __init__(
7172
Returns:
7273
ark client
7374
"""
74-
7575
if ak is None:
7676
ak = os.environ.get("VOLC_ACCESSKEY")
7777
if sk is None:
@@ -107,6 +107,7 @@ def __init__(
107107
self.context = resources.Context(self)
108108
self.multimodal_embeddings = resources.MultimodalEmbeddings(self)
109109
self.content_generation = resources.ContentGeneration(self)
110+
self.images = resources.Images(self)
110111
self.batch_chat = resources.BatchChat(self)
111112
self.model_breaker_map = defaultdict(ModelBreaker)
112113
self.model_breaker_lock = threading.Lock()
@@ -161,6 +162,7 @@ class AsyncArk(AsyncAPIClient):
161162
context: resources.AsyncContext
162163
multimodal_embeddings: resources.AsyncMultimodalEmbeddings
163164
content_generation: resources.AsyncContentGeneration
165+
images: resources.AsyncImages
164166
batch_chat: resources.AsyncBatchChat
165167
model_breaker_map: dict[str, ModelBreaker]
166168
model_breaker_lock: asyncio.Lock
@@ -225,6 +227,7 @@ def __init__(
225227
self.context = resources.AsyncContext(self)
226228
self.multimodal_embeddings = resources.AsyncMultimodalEmbeddings(self)
227229
self.content_generation = resources.AsyncContentGeneration(self)
230+
self.images = resources.AsyncImages(self)
228231
self.batch_chat = resources.AsyncBatchChat(self)
229232
self.model_breaker_map = defaultdict(ModelBreaker)
230233
self.model_breaker_lock = asyncio.Lock()

volcenginesdkarkruntime/resources/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .context import Context, AsyncContext
77
from .multimodal_embeddings import MultimodalEmbeddings, AsyncMultimodalEmbeddings
88
from .content_generation import ContentGeneration, AsyncContentGeneration
9+
from .images import Images, AsyncImages
910
from .batch_chat import BatchChat, AsyncBatchChat
1011

1112
__all__ = [
@@ -24,6 +25,8 @@
2425
"AsyncContext",
2526
"ContentGeneration",
2627
"AsyncContentGeneration",
28+
"Images",
29+
"AsyncImages",
2730
"BatchChat",
2831
"AsyncBatchChat",
2932
]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from volcenginesdkarkruntime.resources.images.images import (
2+
Images,
3+
AsyncImages,
4+
)
5+
6+
__all__ = ["Images", "AsyncImages"]
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from __future__ import annotations
2+
3+
import httpx
4+
5+
from ..._base_client import make_request_options
6+
from ..._utils._utils import with_sts_token, async_with_sts_token
7+
from ..._resource import SyncAPIResource, AsyncAPIResource
8+
from ...types.images import ImagesResponse
9+
from ..._types import Body, Query, Headers
10+
11+
12+
class Images(SyncAPIResource):
13+
@with_sts_token
14+
def generate(
15+
self,
16+
*,
17+
model: str,
18+
prompt: str,
19+
response_format: str | None = None,
20+
size: str | None = None,
21+
seed: int | None = None,
22+
guidance_scale: float | None = None,
23+
watermark: bool | None = None,
24+
extra_headers: Headers | None = None,
25+
extra_query: Query | None = None,
26+
extra_body: Body | None = None,
27+
timeout: float | httpx.Timeout | None = None,
28+
) -> ImagesResponse:
29+
resp = self._post(
30+
"/images/generations",
31+
body={
32+
"model": model,
33+
"prompt": prompt,
34+
"response_format": response_format,
35+
"size": size,
36+
"seed": seed,
37+
"guidance_scale": guidance_scale,
38+
"watermark": watermark,
39+
},
40+
options=make_request_options(
41+
extra_headers=extra_headers,
42+
extra_query=extra_query,
43+
extra_body=extra_body,
44+
timeout=timeout,
45+
),
46+
cast_to=ImagesResponse,
47+
)
48+
49+
return resp
50+
51+
52+
class AsyncImages(AsyncAPIResource):
53+
@async_with_sts_token
54+
async def generate(
55+
self,
56+
*,
57+
model: str,
58+
prompt: str,
59+
response_format: str | None = None,
60+
size: str | None = None,
61+
seed: int | None = None,
62+
guidance_scale: float | None = None,
63+
watermark: bool | None = None,
64+
extra_headers: Headers | None = None,
65+
extra_query: Query | None = None,
66+
extra_body: Body | None = None,
67+
timeout: float | httpx.Timeout | None = None,
68+
) -> ImagesResponse:
69+
return await self._post(
70+
"/images/generations",
71+
body={
72+
"model": model,
73+
"prompt": prompt,
74+
"response_format": response_format,
75+
"size": size,
76+
"seed": seed,
77+
"guidance_scale": guidance_scale,
78+
"watermark": watermark,
79+
},
80+
options=make_request_options(
81+
extra_headers=extra_headers,
82+
extra_query=extra_query,
83+
extra_body=extra_body,
84+
timeout=timeout,
85+
),
86+
cast_to=ImagesResponse,
87+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .images import ImagesResponse
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import List
2+
3+
from volcenginesdkarkruntime._models import BaseModel
4+
5+
__all__ = ["ImagesResponse"]
6+
7+
8+
class Usage(BaseModel):
9+
generated_images: int
10+
"""The number of images generated."""
11+
12+
13+
class Image(BaseModel):
14+
url: str
15+
"""The URL of the generated image, if any."""
16+
17+
b64_json: str
18+
"""The Base 64 encoded string of the generated image, if any."""
19+
20+
21+
class Error(BaseModel):
22+
message: str
23+
"""The reason for failed image generation"""
24+
25+
code: str
26+
"""The error code for failed image generation"""
27+
28+
29+
class ImagesResponse(BaseModel):
30+
model: str
31+
"""The model used to generated the images."""
32+
33+
data: List[Image]
34+
"""The generated images."""
35+
36+
error: Error
37+
"""The error body, if applicable."""
38+
39+
usage: Usage
40+
"""The usage information for the generation of images."""
41+
42+
created_at: int
43+
"""The Unix timestamp when the image was generated."""
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import asyncio
2+
3+
from volcenginesdkarkruntime import AsyncArk
4+
5+
# Authentication
6+
# 1.If you authorize your endpoint using an API key, you can set your api key to environment variable "ARK_API_KEY"
7+
# or specify api key by Ark(api_key="${YOUR_API_KEY}").
8+
# Note: If you use an API key, this API key will not be refreshed.
9+
# To prevent the API from expiring and failing after some time, choose an API key with no expiration date.
10+
client = AsyncArk()
11+
12+
13+
async def main():
14+
print("----- async generate images -----")
15+
16+
result = client.images.generate(
17+
model="${YOUR_ENDPOINT_ID}",
18+
prompt="龙与地下城女骑士背景是起伏的平原,目光从镜头转向平原",
19+
seed=1234567890,
20+
watermark=True,
21+
size="512x512",
22+
guidance_scale=2.5,
23+
)
24+
25+
print(await result)
26+
27+
if __name__ == "__main__":
28+
asyncio.run(main())
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from volcenginesdkarkruntime import Ark
2+
3+
# Authentication
4+
# 1.If you authorize your endpoint using an API key, you can set your api key to environment variable "ARK_API_KEY"
5+
# or specify api key by Ark(api_key="${YOUR_API_KEY}").
6+
# Note: If you use an API key, this API key will not be refreshed.
7+
# To prevent the API from expiring and failing after some time, choose an API key with no expiration date.
8+
client = Ark()
9+
10+
if __name__ == "__main__":
11+
print("----- generate images -----")
12+
13+
result = client.images.generate(
14+
model="${YOUR_ENDPOINT_ID}",
15+
prompt="龙与地下城女骑士背景是起伏的平原,目光从镜头转向平原",
16+
seed=1234567890,
17+
watermark=True,
18+
size="512x512",
19+
guidance_scale=2.5,
20+
)
21+
22+
print(result)

0 commit comments

Comments
 (0)