Skip to content

Commit 9739f82

Browse files
isaacbmiller and okhat authored
Allow for arbitrary examples containing DSPy.Images (#1801)
* Fix dataset download
* WIP complex image types
* WIP fixing complex images
* Refactor chat_adapter somewhat working
* Ruff fixes
* remove print and update notebooks
* Tests failing on purpose - added None support and new str repr
* remove extra notebook
* Tests passing
* Clean comments
* Allow for proper image serialization
* ruff
* Remove assume text
* ruff
* remove excess prints
* fix test docstring
* Fix image repr
* tests
* Refactor tests to be more readable
* clean: ruff fix and add test
* fix: change test to use model_dump instead of model_dump_json

---------

Co-authored-by: Omar Khattab <[email protected]>
1 parent 0bf4c50 commit 9739f82

File tree

10 files changed

+460
-355
lines changed

10 files changed

+460
-355
lines changed

dspy/adapters/chat_adapter.py

Lines changed: 40 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import textwrap
77
from collections.abc import Mapping
88
from itertools import chain
9-
from typing import Any, Dict, List, Literal, NamedTuple, Union
9+
10+
from typing import Any, Dict, Literal, NamedTuple
1011

1112
import pydantic
1213
from pydantic import TypeAdapter
@@ -17,6 +18,7 @@
1718
from dspy.signatures.field import OutputField
1819
from dspy.signatures.signature import Signature, SignatureMeta
1920
from dspy.signatures.utils import get_dspy_field_type
21+
from dspy.adapters.image_utils import try_expand_image_tags
2022

2123
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
2224

@@ -50,12 +52,12 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict
5052

5153
prepared_instructions = prepare_instructions(signature)
5254
messages.append({"role": "system", "content": prepared_instructions})
53-
5455
for demo in demos:
5556
messages.append(format_turn(signature, demo, role="user", incomplete=demo in incomplete_demos))
5657
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))
5758

5859
messages.append(format_turn(signature, inputs, role="user"))
60+
messages = try_expand_image_tags(messages)
5961
return messages
6062

6163
def parse(self, signature, completion):
@@ -110,11 +112,10 @@ def format_fields(self, signature, values, role):
110112
for field_name, field_info in signature.fields.items()
111113
if field_name in values
112114
}
113-
114115
return format_fields(fields_with_values)
115116

116117

117-
def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]:
118+
def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
118119
"""
119120
Formats the values of the specified fields according to the field's DSPy type (input or output),
120121
annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
@@ -124,23 +125,14 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
124125
fields_with_values: A dictionary mapping information about a field to its corresponding
125126
value.
126127
Returns:
127-
The joined formatted values of the fields, represented as a string or a list of dicts
128+
The joined formatted values of the fields, represented as a string
128129
"""
129130
output = []
130131
for field, field_value in fields_with_values.items():
131-
formatted_field_value = format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
132-
if assume_text:
133-
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")
134-
else:
135-
output.append({"type": "text", "text": f"[[ ## {field.name} ## ]]\n"})
136-
if isinstance(formatted_field_value, dict) and formatted_field_value.get("type") == "image_url":
137-
output.append(formatted_field_value)
138-
else:
139-
output[-1]["text"] += formatted_field_value["text"]
140-
if assume_text:
141-
return "\n\n".join(output).strip()
142-
else:
143-
return output
132+
formatted_field_value = format_field_value(field_info=field.info, value=field_value)
133+
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")
134+
135+
return "\n\n".join(output).strip()
144136

145137

146138
def parse_value(value, annotation):
@@ -180,92 +172,43 @@ def format_turn(signature, values, role, incomplete=False):
180172
A chat message that can be appended to a chat thread. The message contains two string fields:
181173
``role`` ("user" or "assistant") and ``content`` (the message text).
182174
"""
183-
fields_to_collapse = []
184-
content = []
185-
186175
if role == "user":
187176
fields = signature.input_fields
188-
if incomplete:
189-
fields_to_collapse.append(
190-
{
191-
"type": "text",
192-
"text": "This is an example of the task, though some input or output fields are not supplied.",
193-
}
194-
)
177+
message_prefix = "This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
195178
else:
196-
fields = signature.output_fields
197-
# Add the built-in field indicating that the chat turn has been completed
198-
fields[BuiltInCompletedOutputFieldInfo.name] = BuiltInCompletedOutputFieldInfo.info
179+
# Add the completed field for the assistant turn
180+
fields = {**signature.output_fields, BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info}
199181
values = {**values, BuiltInCompletedOutputFieldInfo.name: ""}
200-
field_names = fields.keys()
201-
if not incomplete:
202-
if not set(values).issuperset(set(field_names)):
203-
raise ValueError(f"Expected {field_names} but got {values.keys()}")
182+
message_prefix = ""
204183

205-
fields_to_collapse.extend(
206-
format_fields(
207-
fields_with_values={
208-
FieldInfoWithName(name=field_name, info=field_info): values.get(
209-
field_name, "Not supplied for this particular example."
210-
)
211-
for field_name, field_info in fields.items()
212-
},
213-
assume_text=False,
214-
)
215-
)
216-
217-
if role == "user":
218-
output_fields = list(signature.output_fields.keys())
184+
if not incomplete and not set(values).issuperset(fields.keys()):
185+
raise ValueError(f"Expected {fields.keys()} but got {values.keys()}")
219186

220-
def type_info(v):
221-
return (
222-
f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
223-
if v.annotation is not str
224-
else ""
225-
)
187+
messages = []
188+
if message_prefix:
189+
messages.append(message_prefix)
226190

227-
if output_fields:
228-
fields_to_collapse.append(
229-
{
230-
"type": "text",
231-
"text": "Respond with the corresponding output fields, starting with the field "
232-
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
233-
+ ", and then ending with the marker for `[[ ## completed ## ]]`.",
234-
}
235-
)
236-
237-
# flatmap the list if any items are lists otherwise keep the item
238-
flattened_list = list(
239-
chain.from_iterable(item if isinstance(item, list) else [item] for item in fields_to_collapse)
191+
field_messages = format_fields(
192+
{FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.")
193+
for k, v in fields.items()},
240194
)
241-
242-
if all(message.get("type", None) == "text" for message in flattened_list):
243-
content = "\n\n".join(message.get("text") for message in flattened_list)
244-
return {"role": role, "content": content}
245-
246-
# Collapse all consecutive text messages into a single message.
247-
collapsed_messages = []
248-
for item in flattened_list:
249-
# First item is always added
250-
if not collapsed_messages:
251-
collapsed_messages.append(item)
252-
continue
253-
254-
# If the current item is image, add to collapsed_messages
255-
if item.get("type") == "image_url":
256-
if collapsed_messages[-1].get("type") == "text":
257-
collapsed_messages[-1]["text"] += "\n"
258-
collapsed_messages.append(item)
259-
# If the previous item is text and current item is text, append to the previous item
260-
elif collapsed_messages[-1].get("type") == "text":
261-
collapsed_messages[-1]["text"] += "\n\n" + item["text"]
262-
# If the previous item is not text(aka image), add the current item as a new item
263-
else:
264-
item["text"] = "\n\n" + item["text"]
265-
collapsed_messages.append(item)
266-
267-
return {"role": role, "content": collapsed_messages}
268-
195+
messages.append(field_messages)
196+
197+
# Add output field instructions for user messages
198+
if role == "user" and signature.output_fields:
199+
type_info = lambda v: f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" if v.annotation is not str else ""
200+
field_instructions = "Respond with the corresponding output fields, starting with the field " + \
201+
", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + \
202+
", and then ending with the marker for `[[ ## completed ## ]]`."
203+
messages.append(field_instructions)
204+
joined_messages = "\n\n".join(msg for msg in messages)
205+
return {"role": role, "content": joined_messages}
206+
207+
def flatten_messages(messages):
    """Flatten one level of nesting in ``messages``.

    Items that are lists are spliced into the result element by element;
    every other item is kept as-is. Only the outermost level is flattened,
    so a list nested inside a list item survives unchanged.

    Returns:
        A new list; ``messages`` itself is not modified.
    """
    flat = []
    for entry in messages:
        if isinstance(entry, list):
            flat.extend(entry)
        else:
            flat.append(entry)
    return flat
269212

270213
def enumerate_fields(fields: dict) -> str:
271214
parts = []
@@ -328,12 +271,11 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
328271
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
329272
for field_name, field_info in fields.items()
330273
},
331-
assume_text=True,
332274
)
333275

334276
parts.append(format_signature_fields_for_instructions(signature.input_fields))
335277
parts.append(format_signature_fields_for_instructions(signature.output_fields))
336-
parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}, assume_text=True))
278+
parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}))
337279
instructions = textwrap.dedent(signature.instructions)
338280
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
339281
parts.append(f"In adhering to this structure, your objective is: {objective}")

dspy/adapters/image_utils.py

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import base64
22
import io
33
import os
4-
from typing import Union
4+
from typing import Any, Dict, List, Union
55
from urllib.parse import urlparse
6+
import re
67

78
import pydantic
89
import requests
@@ -17,13 +18,20 @@
1718

1819
class Image(pydantic.BaseModel):
1920
url: str
20-
21+
22+
model_config = {
23+
'frozen': True,
24+
'str_strip_whitespace': True,
25+
'validate_assignment': True,
26+
'extra': 'forbid',
27+
}
28+
2129
@pydantic.model_validator(mode="before")
2230
@classmethod
2331
def validate_input(cls, values):
2432
# Allow the model to accept either a URL string or a dictionary with a single 'url' key
2533
if isinstance(values, str):
26-
# if a string, assume its the URL directly and wrap it in a dict
34+
# if a string, assume it's the URL directly and wrap it in a dict
2735
return {"url": values}
2836
elif isinstance(values, dict) and set(values.keys()) == {"url"}:
2937
# if it's a dict, ensure it has only the 'url' key
@@ -44,14 +52,21 @@ def from_file(cls, file_path: str):
4452

4553
@classmethod
4654
def from_PIL(cls, pil_image):
47-
import PIL
55+
return cls(url=encode_image(pil_image))
4856

49-
return cls(url=encode_image(PIL.Image.open(pil_image)))
57+
@pydantic.model_serializer()
58+
def serialize_model(self):
59+
return "<DSPY_IMAGE_START>" + self.url + "<DSPY_IMAGE_END>"
5060

51-
def __repr__(self):
52-
len_base64 = len(self.url.split("base64,")[1])
53-
return f"Image(url = {self.url.split('base64,')[0]}base64,<IMAGE_BASE_64_ENCODED({str(len_base64)})>)"
61+
def __str__(self):
62+
return self.serialize_model()
5463

64+
def __repr__(self):
65+
if "base64" in self.url:
66+
len_base64 = len(self.url.split("base64,")[1])
67+
image_type = self.url.split(";")[0].split("/")[-1]
68+
return f"Image(url=data:image/{image_type};base64,<IMAGE_BASE_64_ENCODED({str(len_base64)})>)"
69+
return f"Image(url='{self.url}')"
5570

5671
def is_url(string: str) -> bool:
5772
"""Check if a string is a valid URL."""
@@ -95,6 +110,7 @@ def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_imag
95110
return image
96111
else:
97112
# Unsupported string format
113+
print(f"Unsupported image string: {image}")
98114
raise ValueError(f"Unsupported image string: {image}")
99115
elif PIL_AVAILABLE and isinstance(image, PILImage.Image):
100116
# PIL Image
@@ -103,11 +119,12 @@ def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_imag
103119
# Raw bytes
104120
if not PIL_AVAILABLE:
105121
raise ImportError("Pillow is required to process image bytes.")
106-
img = Image.open(io.BytesIO(image))
122+
img = PILImage.open(io.BytesIO(image))
107123
return _encode_pil_image(img)
108124
elif isinstance(image, Image):
109125
return image.url
110126
else:
127+
print(f"Unsupported image type: {type(image)}")
111128
raise ValueError(f"Unsupported image type: {type(image)}")
112129

113130

@@ -133,8 +150,7 @@ def _encode_image_from_url(image_url: str) -> str:
133150
encoded_image = base64.b64encode(response.content).decode("utf-8")
134151
return f"data:image/{file_extension};base64,{encoded_image}"
135152

136-
137-
def _encode_pil_image(image: "Image.Image") -> str:
153+
def _encode_pil_image(image: 'PILImage') -> str:
138154
"""Encode a PIL Image object to a base64 data URI."""
139155
buffered = io.BytesIO()
140156
file_extension = (image.format or "PNG").lower()
@@ -151,9 +167,7 @@ def _get_file_extension(path_or_url: str) -> str:
151167

152168
def is_image(obj) -> bool:
153169
"""Check if the object is an image or a valid image reference."""
154-
if PIL_AVAILABLE and isinstance(obj, Image.Image):
155-
return True
156-
if isinstance(obj, (bytes, bytearray)):
170+
if PIL_AVAILABLE and isinstance(obj, PILImage.Image):
157171
return True
158172
if isinstance(obj, str):
159173
if obj.startswith("data:image/"):
@@ -163,3 +177,52 @@ def is_image(obj) -> bool:
163177
elif is_url(obj):
164178
return True
165179
return False
180+
181+
def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Expand serialized image tags in every message whose content is a string.

    A message is rewritten only when its ``content`` is a ``str`` containing
    the ``<DSPY_IMAGE_START>`` marker; in that case the content is replaced by
    the result of ``expand_image_tags``. Messages without a ``content`` key,
    with ``None`` content, or with non-string content (for example a list of
    content parts) are left untouched instead of raising.

    Args:
        messages: Chat messages, each a dict that may carry a ``content`` entry.

    Returns:
        The same ``messages`` list; matching message dicts are updated in place.
    """
    for message in messages:
        content = message.get("content")
        # Only plain strings can carry serialized tags; guarding on the type
        # avoids a TypeError on ``None`` content.
        if isinstance(content, str) and "<DSPY_IMAGE_START>" in content:
            message["content"] = expand_image_tags(content)
    return messages
188+
189+
def expand_image_tags(text: str) -> Union[str, List[Dict[str, Any]]]:
    """Split a content string into text and image parts at each image tag.

    An image tag has the form ``<DSPY_IMAGE_START>url<DSPY_IMAGE_END>``,
    optionally wrapped in double quotes (the quotes are consumed with the tag).
    Text between tags is stripped of surrounding whitespace, and segments that
    are empty after stripping are dropped.

    Args:
        text: The text content that may contain image tags.

    Returns:
        ``text`` unchanged when it contains no image tag; otherwise a list of
        ``{"type": "text", "text": ...}`` and
        ``{"type": "image_url", "image_url": {"url": ...}}`` dicts in the
        order they appear in ``text``.
    """
    image_tag_pattern = re.compile(r'"?<DSPY_IMAGE_START>(.*?)<DSPY_IMAGE_END>"?')

    content_parts: List[Dict[str, Any]] = []
    cursor = 0
    # One left-to-right pass: the text is never re-sliced or re-searched, so a
    # string holding several large base64 payloads is not copied once per image.
    for match in image_tag_pattern.finditer(text):
        preceding = text[cursor:match.start()].strip()
        if preceding:
            content_parts.append({"type": "text", "text": preceding})
        content_parts.append({"type": "image_url", "image_url": {"url": match.group(1)}})
        cursor = match.end()

    # Every match contributes an image part, so an empty list means no tags.
    if not content_parts:
        return text

    trailing = text[cursor:].strip()
    if trailing:
        content_parts.append({"type": "text", "text": trailing})

    return content_parts

dspy/adapters/json_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,11 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
151151
Returns:
152152
The formatted value of the field, represented as a string.
153153
"""
154+
# TODO: Wasnt this easy to fix?
154155
if field_info.annotation is Image:
155156
raise NotImplementedError("Images are not yet supported in JSON mode.")
156157

157-
return format_field_value(field_info=field_info, value=value, assume_text=True)
158+
return format_field_value(field_info=field_info, value=value)
158159

159160

160161
def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:

0 commit comments

Comments
 (0)