55import textwrap
66from collections .abc import Mapping
77from itertools import chain
8-
98from typing import Any , Dict , Literal , NamedTuple
109
1110import pydantic
1211from pydantic .fields import FieldInfo
1312
1413from dspy .adapters .base import Adapter
15- from dspy .adapters .utils import parse_value , format_field_value , get_annotation_name
14+ from dspy .adapters .image_utils import try_expand_image_tags
15+ from dspy .adapters .utils import format_field_value , get_annotation_name , parse_value
1616from dspy .signatures .field import OutputField
1717from dspy .signatures .signature import Signature , SignatureMeta
1818from dspy .signatures .utils import get_dspy_field_type
19- from dspy .adapters .image_utils import try_expand_image_tags
2019
2120field_header_pattern = re .compile (r"\[\[ ## (\w+) ## \]\]" )
2221
@@ -99,9 +98,6 @@ def format_finetune_data(self, signature, demos, inputs, outputs):
9998 # Wrap the messages in a dictionary with a "messages" key
10099 return dict (messages = messages )
101100
102- def format_turn (self , signature , values , role , incomplete = False ):
103- return format_turn (signature , values , role , incomplete )
104-
105101 def format_fields (self , signature , values , role ):
106102 fields_with_values = {
107103 FieldInfoWithName (name = field_name , info = field_info ): values .get (
@@ -152,7 +148,9 @@ def format_turn(signature, values, role, incomplete=False):
152148 """
153149 if role == "user" :
154150 fields = signature .input_fields
155- message_prefix = "This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
151+ message_prefix = (
152+ "This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
153+ )
156154 else :
157155 # Add the completed field for the assistant turn
158156 fields = {** signature .output_fields , BuiltInCompletedOutputFieldInfo .name : BuiltInCompletedOutputFieldInfo .info }
@@ -167,31 +165,40 @@ def format_turn(signature, values, role, incomplete=False):
167165 messages .append (message_prefix )
168166
169167 field_messages = format_fields (
170- {FieldInfoWithName (name = k , info = v ): values .get (k , "Not supplied for this particular example." )
171- for k , v in fields .items ()},
168+ {
169+ FieldInfoWithName (name = k , info = v ): values .get (k , "Not supplied for this particular example." )
170+ for k , v in fields .items ()
171+ },
172172 )
173173 messages .append (field_messages )
174174
175+ def type_info (v ):
176+ if v .annotation is not str :
177+ return f" (must be formatted as a valid Python { get_annotation_name (v .annotation )} )"
178+ else :
179+ return ""
180+
175181 # Add output field instructions for user messages
176182 if role == "user" and signature .output_fields :
177- type_info = lambda v : f" (must be formatted as a valid Python { get_annotation_name (v .annotation )} )" if v .annotation is not str else ""
178- field_instructions = "Respond with the corresponding output fields, starting with the field " + \
179- ", then " .join (f"`[[ ## { f } ## ]]`{ type_info (v )} " for f , v in signature .output_fields .items ()) + \
180- ", and then ending with the marker for `[[ ## completed ## ]]`."
183+ field_instructions = (
184+ "Respond with the corresponding output fields, starting with the field "
185+ + ", then " .join (f"`[[ ## { f } ## ]]`{ type_info (v )} " for f , v in signature .output_fields .items ())
186+ + ", and then ending with the marker for `[[ ## completed ## ]]`."
187+ )
181188 messages .append (field_instructions )
182189 joined_messages = "\n \n " .join (msg for msg in messages )
183190 return {"role" : role , "content" : joined_messages }
184191
192+
185193def flatten_messages (messages ):
186194 """Flatten nested message lists."""
187- return list (chain .from_iterable (
188- item if isinstance (item , list ) else [item ] for item in messages
189- ))
195+ return list (chain .from_iterable (item if isinstance (item , list ) else [item ] for item in messages ))
196+
190197
191198def enumerate_fields (fields : dict ) -> str :
192199 parts = []
193200 for idx , (k , v ) in enumerate (fields .items ()):
194- parts .append (f"{ idx + 1 } . `{ k } `" )
201+ parts .append (f"{ idx + 1 } . `{ k } `" )
195202 parts [- 1 ] += f" ({ get_annotation_name (v .annotation )} )"
196203 parts [- 1 ] += f": { v .json_schema_extra ['desc' ]} " if v .json_schema_extra ["desc" ] != f"${{{ k } }}" else ""
197204
@@ -207,8 +214,8 @@ def move_type_to_front(d):
207214 return d
208215
209216
210- def prepare_schema (type_ ):
211- schema = pydantic .TypeAdapter (type_ ).json_schema ()
217+ def prepare_schema (field_type ):
218+ schema = pydantic .TypeAdapter (field_type ).json_schema ()
212219 schema = move_type_to_front (schema )
213220 return schema
214221
@@ -237,7 +244,6 @@ def field_metadata(field_name, field_info):
237244 f"must exactly match (no extra characters) one of: { '; ' .join ([str (x ) for x in field_type .__args__ ])} "
238245 )
239246 else :
240- # desc = "must be pareseable according to the following JSON schema: "
241247 desc = "must adhere to the JSON schema: "
242248 desc += json .dumps (prepare_schema (field_type ), ensure_ascii = False )
243249
0 commit comments