66import textwrap
77from collections .abc import Mapping
88from itertools import chain
9- from typing import Any , Dict , List , Literal , NamedTuple , Union
9+
10+ from typing import Any , Dict , Literal , NamedTuple
1011
1112import pydantic
1213from pydantic import TypeAdapter
1718from dspy .signatures .field import OutputField
1819from dspy .signatures .signature import Signature , SignatureMeta
1920from dspy .signatures .utils import get_dspy_field_type
21+ from dspy .adapters .image_utils import try_expand_image_tags
2022
2123field_header_pattern = re .compile (r"\[\[ ## (\w+) ## \]\]" )
2224
@@ -50,12 +52,12 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict
5052
5153 prepared_instructions = prepare_instructions (signature )
5254 messages .append ({"role" : "system" , "content" : prepared_instructions })
53-
5455 for demo in demos :
5556 messages .append (format_turn (signature , demo , role = "user" , incomplete = demo in incomplete_demos ))
5657 messages .append (format_turn (signature , demo , role = "assistant" , incomplete = demo in incomplete_demos ))
5758
5859 messages .append (format_turn (signature , inputs , role = "user" ))
60+ messages = try_expand_image_tags (messages )
5961 return messages
6062
6163 def parse (self , signature , completion ):
@@ -110,11 +112,10 @@ def format_fields(self, signature, values, role):
110112 for field_name , field_info in signature .fields .items ()
111113 if field_name in values
112114 }
113-
114115 return format_fields (fields_with_values )
115116
116117
117- def format_fields (fields_with_values : Dict [FieldInfoWithName , Any ], assume_text = True ) -> Union [ str , List [ dict ]] :
118+ def format_fields (fields_with_values : Dict [FieldInfoWithName , Any ]) -> str :
118119 """
119120 Formats the values of the specified fields according to the field's DSPy type (input or output),
120121 annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
@@ -124,23 +125,14 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
124125 fields_with_values: A dictionary mapping information about a field to its corresponding
125126 value.
126127 Returns:
127- The joined formatted values of the fields, represented as a string or a list of dicts
128+ The joined formatted values of the fields, represented as a string
128129 """
129130 output = []
130131 for field , field_value in fields_with_values .items ():
131- formatted_field_value = format_field_value (field_info = field .info , value = field_value , assume_text = assume_text )
132- if assume_text :
133- output .append (f"[[ ## { field .name } ## ]]\n { formatted_field_value } " )
134- else :
135- output .append ({"type" : "text" , "text" : f"[[ ## { field .name } ## ]]\n " })
136- if isinstance (formatted_field_value , dict ) and formatted_field_value .get ("type" ) == "image_url" :
137- output .append (formatted_field_value )
138- else :
139- output [- 1 ]["text" ] += formatted_field_value ["text" ]
140- if assume_text :
141- return "\n \n " .join (output ).strip ()
142- else :
143- return output
132+ formatted_field_value = format_field_value (field_info = field .info , value = field_value )
133+ output .append (f"[[ ## { field .name } ## ]]\n { formatted_field_value } " )
134+
135+ return "\n \n " .join (output ).strip ()
144136
145137
146138def parse_value (value , annotation ):
@@ -180,92 +172,43 @@ def format_turn(signature, values, role, incomplete=False):
180172 A chat message that can be appended to a chat thread. The message contains two string fields:
181173 ``role`` ("user" or "assistant") and ``content`` (the message text).
182174 """
183- fields_to_collapse = []
184- content = []
185-
186175 if role == "user" :
187176 fields = signature .input_fields
188- if incomplete :
189- fields_to_collapse .append (
190- {
191- "type" : "text" ,
192- "text" : "This is an example of the task, though some input or output fields are not supplied." ,
193- }
194- )
177+ message_prefix = "This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
195178 else :
196- fields = signature .output_fields
197- # Add the built-in field indicating that the chat turn has been completed
198- fields [BuiltInCompletedOutputFieldInfo .name ] = BuiltInCompletedOutputFieldInfo .info
179+ # Add the completed field for the assistant turn
180+ fields = {** signature .output_fields , BuiltInCompletedOutputFieldInfo .name : BuiltInCompletedOutputFieldInfo .info }
199181 values = {** values , BuiltInCompletedOutputFieldInfo .name : "" }
200- field_names = fields .keys ()
201- if not incomplete :
202- if not set (values ).issuperset (set (field_names )):
203- raise ValueError (f"Expected { field_names } but got { values .keys ()} " )
182+ message_prefix = ""
204183
205- fields_to_collapse .extend (
206- format_fields (
207- fields_with_values = {
208- FieldInfoWithName (name = field_name , info = field_info ): values .get (
209- field_name , "Not supplied for this particular example."
210- )
211- for field_name , field_info in fields .items ()
212- },
213- assume_text = False ,
214- )
215- )
216-
217- if role == "user" :
218- output_fields = list (signature .output_fields .keys ())
184+ if not incomplete and not set (values ).issuperset (fields .keys ()):
185+ raise ValueError (f"Expected { fields .keys ()} but got { values .keys ()} " )
219186
220- def type_info (v ):
221- return (
222- f" (must be formatted as a valid Python { get_annotation_name (v .annotation )} )"
223- if v .annotation is not str
224- else ""
225- )
187+ messages = []
188+ if message_prefix :
189+ messages .append (message_prefix )
226190
227- if output_fields :
228- fields_to_collapse .append (
229- {
230- "type" : "text" ,
231- "text" : "Respond with the corresponding output fields, starting with the field "
232- + ", then " .join (f"`[[ ## { f } ## ]]`{ type_info (v )} " for f , v in signature .output_fields .items ())
233- + ", and then ending with the marker for `[[ ## completed ## ]]`." ,
234- }
235- )
236-
237- # flatmap the list if any items are lists otherwise keep the item
238- flattened_list = list (
239- chain .from_iterable (item if isinstance (item , list ) else [item ] for item in fields_to_collapse )
191+ field_messages = format_fields (
192+ {FieldInfoWithName (name = k , info = v ): values .get (k , "Not supplied for this particular example." )
193+ for k , v in fields .items ()},
240194 )
241-
242- if all (message .get ("type" , None ) == "text" for message in flattened_list ):
243- content = "\n \n " .join (message .get ("text" ) for message in flattened_list )
244- return {"role" : role , "content" : content }
245-
246- # Collapse all consecutive text messages into a single message.
247- collapsed_messages = []
248- for item in flattened_list :
249- # First item is always added
250- if not collapsed_messages :
251- collapsed_messages .append (item )
252- continue
253-
254- # If the current item is image, add to collapsed_messages
255- if item .get ("type" ) == "image_url" :
256- if collapsed_messages [- 1 ].get ("type" ) == "text" :
257- collapsed_messages [- 1 ]["text" ] += "\n "
258- collapsed_messages .append (item )
259- # If the previous item is text and current item is text, append to the previous item
260- elif collapsed_messages [- 1 ].get ("type" ) == "text" :
261- collapsed_messages [- 1 ]["text" ] += "\n \n " + item ["text" ]
262- # If the previous item is not text(aka image), add the current item as a new item
263- else :
264- item ["text" ] = "\n \n " + item ["text" ]
265- collapsed_messages .append (item )
266-
267- return {"role" : role , "content" : collapsed_messages }
268-
195+ messages .append (field_messages )
196+
197+ # Add output field instructions for user messages
198+ if role == "user" and signature .output_fields :
199+ type_info = lambda v : f" (must be formatted as a valid Python { get_annotation_name (v .annotation )} )" if v .annotation is not str else ""
200+ field_instructions = "Respond with the corresponding output fields, starting with the field " + \
201+ ", then " .join (f"`[[ ## { f } ## ]]`{ type_info (v )} " for f , v in signature .output_fields .items ()) + \
202+ ", and then ending with the marker for `[[ ## completed ## ]]`."
203+ messages .append (field_instructions )
204+ joined_messages = "\n \n " .join (msg for msg in messages )
205+ return {"role" : role , "content" : joined_messages }
206+
207+ def flatten_messages (messages ):
208+ """Flatten nested message lists."""
209+ return list (chain .from_iterable (
210+ item if isinstance (item , list ) else [item ] for item in messages
211+ ))
269212
270213def enumerate_fields (fields : dict ) -> str :
271214 parts = []
@@ -328,12 +271,11 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
328271 FieldInfoWithName (name = field_name , info = field_info ): field_metadata (field_name , field_info )
329272 for field_name , field_info in fields .items ()
330273 },
331- assume_text = True ,
332274 )
333275
334276 parts .append (format_signature_fields_for_instructions (signature .input_fields ))
335277 parts .append (format_signature_fields_for_instructions (signature .output_fields ))
336- parts .append (format_fields ({BuiltInCompletedOutputFieldInfo : "" }, assume_text = True ))
278+ parts .append (format_fields ({BuiltInCompletedOutputFieldInfo : "" }))
337279 instructions = textwrap .dedent (signature .instructions )
338280 objective = ("\n " + " " * 8 ).join (["" ] + instructions .splitlines ())
339281 parts .append (f"In adhering to this structure, your objective is: { objective } " )
0 commit comments