1+ import re
12import ast
23import json
3- import re
4+ import enum
5+ import inspect
6+ import pydantic
47import textwrap
5- from typing import Any , Dict , KeysView , List , Literal , NamedTuple , get_args , get_origin
68
7- import pydantic
89from pydantic import TypeAdapter
910from pydantic .fields import FieldInfo
11+ from typing import Any , Dict , KeysView , List , Literal , NamedTuple , get_args , get_origin
1012
13+ from dspy .adapters .base import Adapter
1114from ..signatures .field import OutputField
1215from ..signatures .signature import SignatureMeta
1316from ..signatures .utils import get_dspy_field_type
14- from .base import Adapter
1517
# Matches a field header marker of the form `[[ ## field_name ## ]]` and captures the
# field name (one or more word characters) as group 1.
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
1719
1820
class FieldInfoWithName(NamedTuple):
    """
    A tuple containing a field name and its corresponding FieldInfo object.
    """

    # The name of the field.
    name: str
    # The pydantic FieldInfo object describing the field.
    info: FieldInfo
2624
2725
# Built-in field indicating that a chat turn (i.e. a user or assistant reply to a chat
# thread) has been completed.
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())
3128
3229
@@ -114,6 +111,16 @@ def format_input_list_field_value(value: List[Any]) -> str:
114111 return "\n " .join ([f"[{ idx + 1 } ] { format_blob (txt )} " for idx , txt in enumerate (value )])
115112
116113
114+ def _serialize_for_json (value ):
115+ if isinstance (value , pydantic .BaseModel ):
116+ return value .model_dump ()
117+ elif isinstance (value , list ):
118+ return [_serialize_for_json (item ) for item in value ]
119+ elif isinstance (value , dict ):
120+ return {key : _serialize_for_json (val ) for key , val in value .items ()}
121+ else :
122+ return value
123+
117124def _format_field_value (field_info : FieldInfo , value : Any ) -> str :
118125 """
119126 Formats the value of the specified field according to the field's DSPy type (input or output),
@@ -125,24 +132,17 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
125132 Returns:
126133 The formatted value of the field, represented as a string.
127134 """
128- dspy_field_type : Literal ["input" , "output" ] = get_dspy_field_type (field_info )
129- if isinstance (value , list ):
130- if dspy_field_type == "input" or field_info .annotation is str :
131- # If the field is an input field or has no special type requirements, format it as
132- # numbered list so that it's organized in a way suitable for presenting long context
133- # to an LLM (i.e. not JSON)
134- return format_input_list_field_value (value )
135- else :
136- # If the field is an output field that has strict parsing requirements, format the
137- # value as a stringified JSON Array. This ensures that downstream routines can parse
138- # the field value correctly using methods from the `ujson` or `json` packages.
139- return json .dumps (value )
140- elif isinstance (value , pydantic .BaseModel ):
141- return value .model_dump_json ()
135+
136+ if isinstance (value , list ) and field_info .annotation is str :
137+ # If the field has no special type requirements, format it as a nice numbered list for the LM.
138+ return format_input_list_field_value (value )
139+ elif isinstance (value , pydantic .BaseModel ) or isinstance (value , dict ) or isinstance (value , list ):
140+ return json .dumps (_serialize_for_json (value ))
142141 else :
143142 return str (value )
144143
145144
145+
146146def format_fields (fields_with_values : Dict [FieldInfoWithName , Any ]) -> str :
147147 """
148148 Formats the values of the specified fields according to the field's DSPy type (input or output),
@@ -166,15 +166,20 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
def parse_value(value, annotation):
    """
    Parses a raw field value into the type described by `annotation`.

    Args:
        value: The raw value, typically a string.
        annotation: The target type annotation of the field.

    Returns:
        `str(value)` when `annotation` is `str`; otherwise the result of validating the
        decoded value against `annotation` with pydantic's `TypeAdapter`.

    Raises:
        KeyError: If `annotation` is an Enum class and `value` is neither a member of it
            nor the name of one of its members.
        Any error raised by `TypeAdapter(annotation).validate_python` when the decoded
        value does not conform to `annotation`.
    """
    if annotation is str:
        return str(value)

    parsed_value = value

    if isinstance(annotation, enum.EnumMeta):
        if isinstance(value, annotation):
            # Already a member of the target Enum; a name lookup would raise KeyError.
            parsed_value = value
        else:
            # Members are looked up by name. Surrounding whitespace is stripped so that
            # " RED\n" still resolves to the member named "RED".
            member_name = value.strip() if isinstance(value, str) else value
            parsed_value = annotation[member_name]
    elif isinstance(value, str):
        try:
            parsed_value = json.loads(value)
        except json.JSONDecodeError:
            try:
                # Fall back to Python literal syntax (e.g. single-quoted strings, tuples).
                parsed_value = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError):
                # Not a literal either (TypeError covers e.g. unhashable dict keys);
                # hand the raw string to the validator.
                parsed_value = value

    return TypeAdapter(annotation).validate_python(parsed_value)
179184
180185
@@ -222,6 +227,16 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
222227 content .append (formatted_fields )
223228
224229 if role == "user" :
230+ # def type_info(v):
231+ # return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
232+ # if v.annotation is not str else ""
233+ #
234+ # content.append(
235+ # "Respond with the corresponding output fields, starting with the field "
236+ # + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
237+ # + ", and then ending with the marker for `[[ ## completed ## ]]`."
238+ # )
239+
225240 content .append (
226241 "Respond with the corresponding output fields, starting with the field "
227242 + ", then " .join (f"`{ f } `" for f in signature .output_fields )
@@ -260,10 +275,30 @@ def prepare_instructions(signature: SignatureMeta):
260275 parts .append ("Your output fields are:\n " + enumerate_fields (signature .output_fields ))
261276 parts .append ("All interactions will be structured in the following way, with the appropriate values filled in." )
262277
def field_metadata(field_name, field_info):
    """
    Builds the `{field_name}` placeholder for a field, followed by a note describing the
    format the produced value must have.

    Args:
        field_name: The name of the field.
        field_info: The field's FieldInfo; its `annotation` selects the note.

    Returns:
        `{field_name}` alone for input fields and `str`-typed fields; otherwise
        `{field_name}` followed by eight spaces and a `# note: ...` constraint hint.
    """
    type_ = field_info.annotation

    if get_dspy_field_type(field_info) == "input" or type_ is str:
        desc = ""
    elif type_ is bool:
        # bool is tested separately from int/float so it gets its own wording.
        desc = "must be True or False"
    elif type_ in (int, float):
        desc = f"must be a single {type_.__name__} value"
    elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
        desc = f"must be one of: {'; '.join(type_.__members__)}"
    elif get_origin(type_) is Literal:
        desc = f"must be one of: {'; '.join(str(x) for x in get_args(type_))}"
    else:
        # Any other type is described by the JSON schema pydantic derives for it.
        desc = "must be parseable according to the following JSON schema: "
        desc += json.dumps(pydantic.TypeAdapter(type_).json_schema())

    desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
    return f"{{{field_name}}}{desc}"
297+
def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
    """
    Formats the given signature fields for the instructions, using the result of
    `field_metadata` as each field's value.
    """
    placeholders = {}
    for name, info in fields.items():
        placeholders[FieldInfoWithName(name=name, info=info)] = field_metadata(name, info)
    return format_fields(fields_with_values=placeholders)
0 commit comments