1- import re
21import ast
32import json
3+ import re
44import textwrap
5+ from typing import get_args , get_origin
56
6- from pydantic import TypeAdapter
77import pydantic
8+ from pydantic import TypeAdapter
9+
810from .base import Adapter
9- from typing import get_origin , get_args
1011
# Pattern for a field header of the form "[[ ## field_name ## ]]".
# Group 1 captures the field name (word characters only).
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
1213
1314
1415class ChatAdapter (Adapter ):
@@ -21,9 +22,11 @@ def format(self, signature, demos, inputs):
2122 # Extract demos where some of the output_fields are not filled in.
2223 incomplete_demos = [demo for demo in demos if not all (k in demo for k in signature .fields )]
2324 complete_demos = [demo for demo in demos if demo not in incomplete_demos ]
24- incomplete_demos = [demo for demo in incomplete_demos \
25- if any (k in demo for k in signature .input_fields ) and \
26- any (k in demo for k in signature .output_fields )]
25+ incomplete_demos = [
26+ demo
27+ for demo in incomplete_demos
28+ if any (k in demo for k in signature .input_fields ) and any (k in demo for k in signature .output_fields )
29+ ]
2730
2831 demos = incomplete_demos + complete_demos
2932
@@ -32,44 +35,52 @@ def format(self, signature, demos, inputs):
3235 for demo in demos :
3336 messages .append (format_turn (signature , demo , role = "user" , incomplete = demo in incomplete_demos ))
3437 messages .append (format_turn (signature , demo , role = "assistant" , incomplete = demo in incomplete_demos ))
35-
38+
3639 messages .append (format_turn (signature , inputs , role = "user" ))
3740
3841 return messages
39-
42+
def parse(self, signature, completion, _parse_values=True):
    """Split a completion into the signature's output fields.

    The completion is scanned line by line. A line that (after stripping)
    matches ``field_header_pattern`` opens a new section named after the
    captured field; every other line is appended to the current section.
    Lines seen before the first header are collected under a ``None`` key.

    Args:
        signature: Object exposing ``output_fields``, a mapping from field
            name to a field object with an ``annotation`` attribute.
        completion: Raw completion text to parse.
        _parse_values: When true, each section body is converted with
            ``parse_value`` using the field's annotation; otherwise the
            stripped text is returned unchanged.

    Returns:
        A dict mapping each output field name to its parsed value.

    Raises:
        ValueError: If a field value cannot be parsed (the underlying
            exception is attached as ``__cause__``), or if the parsed
            field names do not equal the signature's output field names.
    """
    sections = [(None, [])]

    for line in completion.splitlines():
        match = field_header_pattern.match(line.strip())
        if match:
            sections.append((match.group(1), []))
        else:
            sections[-1][1].append(line)

    sections = [(k, "\n".join(v).strip()) for k, v in sections]

    fields = {}
    for k, v in sections:
        # Only the first occurrence of each known output field is kept.
        if (k not in fields) and (k in signature.output_fields):
            try:
                fields[k] = parse_value(v, signature.output_fields[k].annotation) if _parse_values else v
            except Exception as e:
                # Chain explicitly so the original failure is reported as the direct cause.
                raise ValueError(
                    f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```"
                ) from e

    if fields.keys() != signature.output_fields.keys():
        raise ValueError(f"Expected {signature.output_fields.keys()} but got {fields.keys()}")

    return fields
6269
70+
def format_blob(blob):
    """Wrap ``blob`` in guillemet delimiters.

    A blob with no newline and no guillemet characters is wrapped inline as
    ``«blob»``. Anything else is placed in a fenced ``«««`` / ``»»»`` block
    with continuation lines indented.
    """
    needs_fence = any(marker in blob for marker in ("\n", "«", "»"))
    if not needs_fence:
        return f"«{blob}»"

    # NOTE(review): indent width reconstructed as four spaces — confirm against upstream.
    indented = "\n    ".join(blob.split("\n"))
    return f"«««\n    {indented}\n»»»"
6877
6978
def format_list(items):
    """Render a sequence of blobs as text.

    Returns ``"N/A"`` for an empty sequence, the formatted blob itself for a
    single item, and otherwise one ``[i] <blob>`` line per item, numbered
    from 1 and joined with newlines.
    """
    count = len(items)
    if count == 0:
        return "N/A"
    if count == 1:
        return format_blob(items[0])

    numbered = [f"[{position}] {format_blob(text)}" for position, text in enumerate(items, start=1)]
    return "\n".join(numbered)
7586
@@ -89,82 +100,90 @@ def format_fields(fields):
89100 v = _format_field_value (v )
90101 output .append (f"[[ ## { k } ## ]]\n { v } " )
91102
92- return ' \n \n ' .join (output ).strip ()
93-
103+ return " \n \n " .join (output ).strip ()
104+
94105
def parse_value(value, annotation):
    """Convert ``value`` to the type described by ``annotation``.

    When ``annotation`` is ``str`` the value is returned as ``str(value)``
    with no further processing. Otherwise a string value is decoded first
    as JSON, then as a Python literal; if neither succeeds the raw string is
    kept. The result is validated with
    ``TypeAdapter(annotation).validate_python``.

    Args:
        value: The raw value, typically the text of one parsed field.
        annotation: The target type annotation.

    Returns:
        The validated value produced by the type adapter, or ``str(value)``
        when ``annotation`` is ``str``.

    Raises:
        Whatever ``TypeAdapter.validate_python`` raises when the value does
        not conform to ``annotation``.
    """
    if annotation is str:
        return str(value)

    parsed_value = value
    if isinstance(value, str):
        try:
            parsed_value = json.loads(value)
        except json.JSONDecodeError:
            try:
                parsed_value = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # literal_eval is documented to raise any of these on
                # malformed input (e.g. TypeError for an unhashable dict
                # key); in every case fall back to the raw string and let
                # validation decide.
                parsed_value = value

    return TypeAdapter(annotation).validate_python(parsed_value)
104119
105120
def format_turn(signature, values, role, incomplete=False):
    """Build one chat message dict for the given role.

    For the ``"user"`` role the signature's input fields are rendered,
    followed by a reminder listing the output fields in order. For any other
    role the output fields are rendered, followed by an empty ``completed``
    marker field.

    Args:
        signature: Object exposing ``input_fields`` and ``output_fields``.
        values: Mapping from field name to value for this turn.
        role: The message role; ``"user"`` selects the input-side layout.
        incomplete: When true, missing fields are tolerated and rendered
            with a placeholder, and user turns get an explanatory preface.

    Returns:
        ``{"role": role, "content": <text>}``.

    Raises:
        ValueError: If ``incomplete`` is false and ``values`` lacks any of
            the required field names.
    """
    is_user = role == "user"
    content = []

    if is_user:
        field_names = signature.input_fields.keys()
        if incomplete:
            content.append("This is an example of the task, though some input or output fields are not supplied.")
    else:
        field_names = list(signature.output_fields.keys()) + ["completed"]
        values = {**values, "completed": ""}

    # Complete turns must supply every field that is about to be rendered.
    if not incomplete and not set(values).issuperset(set(field_names)):
        raise ValueError(f"Expected {field_names} but got {values.keys()}")

    rendered = {name: values.get(name, "Not supplied for this particular example.") for name in field_names}
    content.append(format_fields(rendered))

    if is_user:
        ordered_outputs = ", then ".join(f"`{f}`" for f in signature.output_fields)
        content.append(
            "Respond with the corresponding output fields, starting with the field "
            + ordered_outputs
            + ", and then ending with the marker for `completed`."
        )

    message_text = "\n\n".join(content).strip()
    return {"role": role, "content": message_text}
128145
129146
def get_annotation_name(annotation):
    """Return a readable name for a type annotation.

    Parameterized annotations are rendered recursively as
    ``origin[arg1, arg2, ...]``. Anything else yields its ``__name__`` when
    it has one, and ``str(annotation)`` otherwise.
    """
    origin = get_origin(annotation)
    if origin is not None:
        inner = ", ".join(get_annotation_name(arg) for arg in get_args(annotation))
        # Recurse on the origin too, since it may lack a ``__name__`` itself.
        return f"{get_annotation_name(origin)}[{inner}]"

    if hasattr(annotation, "__name__"):
        return annotation.__name__
    return str(annotation)
158+
141159
def enumerate_fields(fields):
    """Describe each field on its own numbered line.

    Every line has the form ``N. `name` (type)``. The field's ``desc`` entry
    from ``json_schema_extra`` is appended after a colon unless it equals the
    placeholder ``${name}``.

    Args:
        fields: Mapping from field name to an object exposing ``annotation``
            and ``json_schema_extra`` (a mapping with a ``"desc"`` key).

    Returns:
        The lines joined with newlines, stripped of surrounding whitespace.
    """
    lines = []
    for position, (name, field) in enumerate(fields.items(), start=1):
        entry = f"{position}. `{name}`"
        entry += f" ({get_annotation_name(field.annotation)})"

        desc = field.json_schema_extra["desc"]
        if desc != f"${{{name}}}":
            entry += f": {desc}"

        lines.append(entry)

    return "\n".join(lines).strip()
150169
def prepare_instructions(signature):
    """Assemble the instruction text that describes a signature.

    The text lists the input and output fields, shows the field layout with
    ``{name}`` placeholders followed by the ``completed`` marker, and ends
    with the signature's own instructions, dedented and then re-indented by
    eight spaces.

    Args:
        signature: Object exposing ``input_fields``, ``output_fields`` and
            an ``instructions`` string.

    Returns:
        The sections joined by blank lines, stripped of surrounding
        whitespace.
    """
    input_fields = signature.input_fields
    output_fields = signature.output_fields

    parts = [
        "Your input fields are:\n" + enumerate_fields(input_fields),
        "Your output fields are:\n" + enumerate_fields(output_fields),
        "All interactions will be structured in the following way, with the appropriate values filled in.",
        format_fields({name: f"{{{name}}}" for name in input_fields}),
        format_fields({name: f"{{{name}}}" for name in output_fields}),
        format_fields({"completed": ""}),
    ]

    # Put every instruction line on its own line, indented by eight spaces.
    line_prefix = "\n" + " " * 8
    dedented = textwrap.dedent(signature.instructions)
    objective = "".join(line_prefix + line for line in dedented.splitlines())
    parts.append(f"In adhering to this structure, your objective is: {objective}")

    # NOTE: the reminder about output-field ordering is not emitted here;
    # format_turn appends it to each user turn.

    return "\n\n".join(parts).strip()
0 commit comments