@@ -35,8 +35,12 @@ def _load_image_from_data_url(image_url: str):
35
35
return load_image_from_base64 (image_base64 )
36
36
37
37
38
- def fetch_image (image_url : str ) -> Image .Image :
39
- """Load PIL image from a url or base64 encoded openai GPT4V format"""
38
+ def fetch_image (image_url : str , * , image_mode : str = "RGB" ) -> Image .Image :
39
+ """
40
+ Load a PIL image from a HTTP or base64 data URL.
41
+
42
+ By default, the image is converted into RGB format.
43
+ """
40
44
if image_url .startswith ('http' ):
41
45
_validate_remote_url (image_url , name = "image_url" )
42
46
@@ -53,7 +57,7 @@ def fetch_image(image_url: str) -> Image.Image:
53
57
raise ValueError ("Invalid 'image_url': A valid 'image_url' must start "
54
58
"with either 'data:image' or 'http'." )
55
59
56
- return image
60
+ return image . convert ( image_mode )
57
61
58
62
59
63
class ImageFetchAiohttp :
@@ -70,8 +74,17 @@ def get_aiohttp_client(cls) -> aiohttp.ClientSession:
70
74
return cls .aiohttp_client
71
75
72
76
@classmethod
73
- async def fetch_image (cls , image_url : str ) -> Image .Image :
74
- """Load PIL image from a url or base64 encoded openai GPT4V format"""
77
+ async def fetch_image (
78
+ cls ,
79
+ image_url : str ,
80
+ * ,
81
+ image_mode : str = "RGB" ,
82
+ ) -> Image .Image :
83
+ """
84
+ Asynchronously load a PIL image from a HTTP or base64 data URL.
85
+
86
+ By default, the image is converted into RGB format.
87
+ """
75
88
76
89
if image_url .startswith ('http' ):
77
90
_validate_remote_url (image_url , name = "image_url" )
@@ -91,20 +104,27 @@ async def fetch_image(cls, image_url: str) -> Image.Image:
91
104
"Invalid 'image_url': A valid 'image_url' must start "
92
105
"with either 'data:image' or 'http'." )
93
106
94
- return image
107
+ return image . convert ( image_mode )
95
108
96
109
97
110
async def async_get_and_parse_image(
    image_url: str,
    *,
    image_mode: str = "RGB",
) -> MultiModalDataDict:
    """
    Asynchronously fetch an image and wrap it as multi-modal data.

    Args:
        image_url: URL forwarded to ``ImageFetchAiohttp.fetch_image``.
        image_mode: Image mode forwarded to
            ``ImageFetchAiohttp.fetch_image``. Defaults to ``"RGB"``, which
            matches that method's own default.

    Returns:
        A dict holding the fetched image under the ``"image"`` key.
    """
    # Forward the mode explicitly so callers of this wrapper have the same
    # control over conversion as direct callers of the fetcher.
    image = await ImageFetchAiohttp.fetch_image(image_url,
                                                image_mode=image_mode)
    return {"image": image}
100
113
101
114
102
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Encode a pillow image to base64 format.

    By default, the image is converted into RGB format before being encoded.

    Args:
        image: The pillow image to encode.
        image_mode: Mode the image is converted to before it is saved.
        format: Image file format handed to ``Image.save``.

    Returns:
        The saved image bytes, base64-encoded and decoded as a UTF-8 string.
    """
    # ``Image.convert`` returns a new image object; skip the call when the
    # image is already in the requested mode to avoid a redundant copy.
    if image.mode != image_mode:
        image = image.convert(image_mode)

    buffered = BytesIO()
    image.save(buffered, format=format)
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
110
130
0 commit comments