Skip to content

Commit 24d3089

Browse files
committed
FEAT: OpenAI image edit api support (#4110)
1 parent 37728dc commit 24d3089

File tree

8 files changed

+796
-15
lines changed

8 files changed

+796
-15
lines changed

.github/workflows/python.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ jobs:
4545
sortPaths: "xinference"
4646
configuration: "--check-only --diff --sp setup.cfg"
4747
- name: mypy
48-
run: pip install 'mypy<1.16.0' && mypy --install-types --non-interactive xinference
48+
run: pip install 'mypy==1.18.1' && mypy --install-types --non-interactive xinference
4949
- name: codespell
5050
run: pip install codespell && codespell --ignore-words-list thirdparty xinference
5151
- name: Set up Node.js

xinference/api/restful_api.py

Lines changed: 462 additions & 3 deletions
Large diffs are not rendered by default.

xinference/client/restful/async_restful_client.py

Lines changed: 158 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(self, model_uid: str, base_url: str, auth_headers: Dict):
7474
self._model_uid = model_uid
7575
self._base_url = base_url
7676
self.auth_headers = auth_headers
77+
self.timeout = aiohttp.ClientTimeout(total=1800)
7778
self.session = aiohttp.ClientSession(
7879
connector=aiohttp.TCPConnector(force_close=True)
7980
)
@@ -356,7 +357,7 @@ async def image_to_image(
356357
else:
357358
# Single image
358359
files.append(("image", ("image", image, "application/octet-stream")))
359-
response = await self.session.post(url, files=files, headers=self.auth_headers)
360+
response = await self.session.post(url, data=files, headers=self.auth_headers)
360361
if response.status != 200:
361362
raise RuntimeError(
362363
f"Failed to variants the images, detail: {await _get_error_string(response)}"
@@ -366,6 +367,157 @@ async def image_to_image(
366367
await _release_response(response)
367368
return response_data
368369

370+
async def image_edit(
371+
self,
372+
image: Union[Union[str, bytes], List[Union[str, bytes]]],
373+
prompt: str,
374+
mask: Optional[Union[str, bytes]] = None,
375+
n: int = 1,
376+
size: Optional[str] = None,
377+
response_format: str = "url",
378+
**kwargs,
379+
) -> "ImageList":
380+
"""
381+
Edit image(s) by the input text and optional mask.
382+
383+
Parameters
384+
----------
385+
image: `Union[Union[str, bytes], List[Union[str, bytes]]]`
386+
The input image(s) to edit. Can be:
387+
- Single image: file path, URL, or binary image data
388+
- Multiple images: list of file paths, URLs, or binary image data
389+
When multiple images are provided, the first image is used as the primary image
390+
and subsequent images are used as reference images for better editing results.
391+
prompt: `str`
392+
The prompt or prompts to guide image editing. If not defined, you need to pass `prompt_embeds`.
393+
mask: `Optional[Union[str, bytes]]`, optional
394+
An optional mask image. White pixels in the mask are repainted while black pixels are preserved.
395+
If provided, this will trigger inpainting mode. If not provided, this will trigger image-to-image mode.
396+
n: `int`, defaults to 1
397+
The number of images to generate per prompt. Must be between 1 and 10.
398+
size: `Optional[str]`, optional
399+
The width*height in pixels of the generated image. If not specified, uses the original image size.
400+
response_format: `str`, defaults to `url`
401+
The format in which the generated images are returned. Must be one of url or b64_json.
402+
**kwargs
403+
Additional parameters to pass to the model.
404+
405+
Returns
406+
-------
407+
ImageList
408+
A list of edited image objects.
409+
410+
Raises
411+
------
412+
RuntimeError
413+
If the image editing request fails.
414+
415+
Examples
416+
--------
417+
# Single image editing
418+
result = await model.image_edit(
419+
image="path/to/image.png",
420+
prompt="make this image look like a painting"
421+
)
422+
423+
# Multiple image editing with reference images
424+
result = await model.image_edit(
425+
image=["primary_image.png", "reference1.jpg", "reference2.png"],
426+
prompt="edit the main image using the style from reference images"
427+
)
428+
"""
429+
url = f"{self._base_url}/v1/images/edits"
430+
params = {
431+
"model": self._model_uid,
432+
"prompt": prompt,
433+
"n": n,
434+
"size": size,
435+
"response_format": response_format,
436+
"kwargs": json.dumps(kwargs),
437+
}
438+
params = _filter_params(params)
439+
files: List[Any] = []
440+
for key, value in params.items():
441+
files.append((key, (None, value)))
442+
443+
# Handle single image or multiple images
444+
import aiohttp
445+
446+
data = aiohttp.FormData()
447+
448+
# Add all parameters as form fields
449+
for key, value in params.items():
450+
if value is not None:
451+
data.add_field(key, str(value))
452+
453+
# Handle single image or multiple images
454+
if isinstance(image, list):
455+
# Validate image list is not empty
456+
if len(image) == 0:
457+
raise ValueError("Image list cannot be empty")
458+
# Multiple images - send as image[] array
459+
for i, img in enumerate(image):
460+
if isinstance(img, str):
461+
# File path - read file content
462+
with open(img, "rb") as f:
463+
content = f.read()
464+
data.add_field(
465+
f"image[]",
466+
content,
467+
filename=f"image_{i}.png",
468+
content_type="image/png",
469+
)
470+
else:
471+
# Binary data
472+
data.add_field(
473+
f"image[]",
474+
img,
475+
filename=f"image_{i}.png",
476+
content_type="image/png",
477+
)
478+
else:
479+
# Single image
480+
if isinstance(image, str):
481+
# File path - read file content
482+
with open(image, "rb") as f:
483+
content = f.read()
484+
data.add_field(
485+
"image", content, filename="image.png", content_type="image/png"
486+
)
487+
else:
488+
# Binary data
489+
data.add_field(
490+
"image", image, filename="image.png", content_type="image/png"
491+
)
492+
493+
if mask is not None:
494+
if isinstance(mask, str):
495+
# File path - read file content
496+
with open(mask, "rb") as f:
497+
content = f.read()
498+
data.add_field(
499+
"mask", content, filename="mask.png", content_type="image/png"
500+
)
501+
else:
502+
# Binary data
503+
data.add_field(
504+
"mask", mask, filename="mask.png", content_type="image/png"
505+
)
506+
507+
try:
508+
response = await self.session.post(
509+
url, data=data, headers=self.auth_headers
510+
)
511+
if response.status != 200:
512+
raise RuntimeError(
513+
f"Failed to edit the images, detail: {await _get_error_string(response)}"
514+
)
515+
516+
response_data = await response.json()
517+
return response_data
518+
finally:
519+
await _release_response(response) if "response" in locals() else None
520+
369521
async def inpainting(
370522
self,
371523
image: Union[str, bytes],
@@ -436,7 +588,7 @@ async def inpainting(
436588
("mask_image", mask_image, "application/octet-stream"),
437589
)
438590
)
439-
response = await self.session.post(url, files=files, headers=self.auth_headers)
591+
response = await self.session.post(url, data=files, headers=self.auth_headers)
440592
if response.status != 200:
441593
raise RuntimeError(
442594
f"Failed to inpaint the images, detail: {await _get_error_string(response)}"
@@ -457,7 +609,7 @@ async def ocr(self, image: Union[str, bytes], **kwargs):
457609
for key, value in params.items():
458610
files.append((key, (None, value)))
459611
files.append(("image", ("image", image, "application/octet-stream")))
460-
response = await self.session.post(url, files=files, headers=self.auth_headers)
612+
response = await self.session.post(url, data=files, headers=self.auth_headers)
461613
if response.status != 200:
462614
raise RuntimeError(
463615
f"Failed to ocr the images, detail: {await _get_error_string(response)}"
@@ -547,7 +699,7 @@ async def image_to_video(
547699
for key, value in params.items():
548700
files.append((key, (None, value)))
549701
files.append(("image", ("image", image, "application/octet-stream")))
550-
response = await self.session.post(url, files=files, headers=self.auth_headers)
702+
response = await self.session.post(url, data=files, headers=self.auth_headers)
551703
if response.status != 200:
552704
raise RuntimeError(
553705
f"Failed to create the video from image, detail: {await _get_error_string(response)}"
@@ -987,8 +1139,9 @@ def __init__(self, base_url, api_key: Optional[str] = None):
9871139
self.base_url = base_url
9881140
self._headers: Dict[str, str] = {}
9891141
self._cluster_authed = False
1142+
self.timeout = aiohttp.ClientTimeout(total=1800)
9901143
self.session = aiohttp.ClientSession(
991-
connector=aiohttp.TCPConnector(force_close=True)
1144+
connector=aiohttp.TCPConnector(force_close=True), timeout=self.timeout
9921145
)
9931146
self._check_cluster_authenticated()
9941147
if api_key is not None and self._cluster_authed:

xinference/client/restful/restful_client.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,137 @@ def image_to_image(
329329
response_data = response.json()
330330
return response_data
331331

332+
def image_edit(
333+
self,
334+
image: Union[Union[str, bytes], List[Union[str, bytes]]],
335+
prompt: str,
336+
mask: Optional[Union[str, bytes]] = None,
337+
n: int = 1,
338+
size: Optional[str] = None,
339+
response_format: str = "url",
340+
**kwargs,
341+
) -> "ImageList":
342+
"""
343+
Edit image(s) by the input text and optional mask.
344+
345+
Parameters
346+
----------
347+
image: `Union[Union[str, bytes], List[Union[str, bytes]]]`
348+
The input image(s) to edit. Can be:
349+
- Single image: file path, URL, or binary image data
350+
- Multiple images: list of file paths, URLs, or binary image data
351+
When multiple images are provided, the first image is used as the primary image
352+
and subsequent images are used as reference images for better editing results.
353+
prompt: `str`
354+
The prompt or prompts to guide image editing. If not defined, you need to pass `prompt_embeds`.
355+
mask: `Optional[Union[str, bytes]]`, optional
356+
An optional mask image. White pixels in the mask are repainted while black pixels are preserved.
357+
If provided, this will trigger inpainting mode. If not provided, this will trigger image-to-image mode.
358+
n: `int`, defaults to 1
359+
The number of images to generate per prompt. Must be between 1 and 10.
360+
size: `Optional[str]`, optional
361+
The width*height in pixels of the generated image. If not specified, uses the original image size.
362+
response_format: `str`, defaults to `url`
363+
The format in which the generated images are returned. Must be one of url or b64_json.
364+
**kwargs
365+
Additional parameters to pass to the model.
366+
367+
Returns
368+
-------
369+
ImageList
370+
A list of edited image objects.
371+
372+
Raises
373+
------
374+
RuntimeError
375+
If the image editing request fails.
376+
377+
Examples
378+
--------
379+
# Single image editing
380+
result = model.image_edit(
381+
image="path/to/image.png",
382+
prompt="make this image look like a painting"
383+
)
384+
385+
# Multiple image editing with reference images
386+
result = model.image_edit(
387+
image=["primary_image.png", "reference1.jpg", "reference2.png"],
388+
prompt="edit the main image using the style from reference images"
389+
)
390+
"""
391+
url = f"{self._base_url}/v1/images/edits"
392+
params = {
393+
"model": self._model_uid,
394+
"prompt": prompt,
395+
"n": n,
396+
"size": size,
397+
"response_format": response_format,
398+
"kwargs": json.dumps(kwargs),
399+
}
400+
files: List[Any] = []
401+
for key, value in params.items():
402+
if value is not None:
403+
files.append((key, (None, value)))
404+
405+
# Handle single image or multiple images using requests format
406+
if isinstance(image, list):
407+
# Validate image list is not empty
408+
if len(image) == 0:
409+
raise ValueError("Image list cannot be empty")
410+
# Multiple images - send as image[] array
411+
for i, img in enumerate(image):
412+
if isinstance(img, str):
413+
# File path - open file
414+
f = open(img, "rb")
415+
files.append(
416+
(f"image[]", (f"image_{i}", f, "application/octet-stream"))
417+
)
418+
else:
419+
# Binary data
420+
files.append(
421+
(f"image[]", (f"image_{i}", img, "application/octet-stream"))
422+
)
423+
else:
424+
# Single image
425+
if isinstance(image, str):
426+
# File path - open file
427+
f = open(image, "rb")
428+
files.append(("image", ("image", f, "application/octet-stream")))
429+
else:
430+
# Binary data
431+
files.append(("image", ("image", image, "application/octet-stream")))
432+
433+
if mask is not None:
434+
if isinstance(mask, str):
435+
# File path - open file
436+
f = open(mask, "rb")
437+
files.append(("mask", ("mask", f, "application/octet-stream")))
438+
else:
439+
# Binary data
440+
files.append(("mask", ("mask", mask, "application/octet-stream")))
441+
442+
try:
443+
response = self.session.post(url, files=files, headers=self.auth_headers)
444+
if response.status_code != 200:
445+
raise RuntimeError(
446+
f"Failed to edit the images, detail: {_get_error_string(response)}"
447+
)
448+
449+
response_data = response.json()
450+
return response_data
451+
finally:
452+
# Close all opened files
453+
for file_item in files:
454+
if (
455+
len(file_item) >= 2
456+
and hasattr(file_item[1], "__len__")
457+
and len(file_item[1]) >= 2
458+
):
459+
file_obj = file_item[1][1]
460+
if hasattr(file_obj, "close"):
461+
file_obj.close()
462+
332463
def inpainting(
333464
self,
334465
image: Union[str, bytes],

xinference/core/supervisor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,6 +1661,9 @@ async def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
16611661
if isinstance(worker_ref, list):
16621662
# get first worker to fetch information if model across workers
16631663
worker_ref = worker_ref[0]
1664+
assert not isinstance(
1665+
worker_ref, (list, tuple)
1666+
), "worker_ref must be a single worker"
16641667
return await worker_ref.get_model(model_uid=replica_model_uid)
16651668

16661669
@log_async(logger=logger)
@@ -1673,6 +1676,9 @@ async def get_model_status(self, replica_model_uid: str):
16731676
if isinstance(worker_ref, list):
16741677
# get status from first shard if model has multiple shards across workers
16751678
worker_ref = worker_ref[0]
1679+
assert not isinstance(
1680+
worker_ref, (list, tuple)
1681+
), "worker_ref must be a single worker"
16761682
return await worker_ref.get_model_status(replica_model_uid)
16771683

16781684
@log_async(logger=logger)
@@ -1691,6 +1697,9 @@ async def describe_model(self, model_uid: str) -> Dict[str, Any]:
16911697
if isinstance(worker_ref, list):
16921698
# get status from first shard if model has multiple shards across workers
16931699
worker_ref = worker_ref[0]
1700+
assert not isinstance(
1701+
worker_ref, (list, tuple)
1702+
), "worker_ref must be a single worker"
16941703
info = await worker_ref.describe_model(model_uid=replica_model_uid)
16951704
info["replica"] = replica_info.replica
16961705
return info
@@ -1766,6 +1775,9 @@ async def abort_request(
17661775
if isinstance(worker_ref, list):
17671776
# get status from first shard if model has multiple shards across workers
17681777
worker_ref = worker_ref[0]
1778+
assert not isinstance(
1779+
worker_ref, (list, tuple)
1780+
), "worker_ref must be a single worker"
17691781
model_ref = await worker_ref.get_model(model_uid=rep_mid)
17701782
result_info = await model_ref.abort_request(request_id, block_duration)
17711783
res["msg"] = result_info

0 commit comments

Comments
 (0)