@@ -37,7 +37,14 @@ class ChatAdapter(Adapter):
3737 def __init__ (self , callbacks : Optional [list [BaseCallback ]] = None ):
3838 super ().__init__ (callbacks )
3939
40- def __call__ (self , lm : LM , lm_kwargs : dict [str , Any ], signature : Type [Signature ], demos : list [dict [str , Any ]], inputs : dict [str , Any ]) -> list [dict [str , Any ]]:
40+ def __call__ (
41+ self ,
42+ lm : LM ,
43+ lm_kwargs : dict [str , Any ],
44+ signature : Type [Signature ],
45+ demos : list [dict [str , Any ]],
46+ inputs : dict [str , Any ],
47+ ) -> list [dict [str , Any ]]:
4148 try :
4249 return super ().__call__ (lm , lm_kwargs , signature , demos , inputs )
4350 except Exception as e :
@@ -46,8 +53,10 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature]
4653 raise e
4754 # fallback to JSONAdapter
4855 return JSONAdapter ()(lm , lm_kwargs , signature , demos , inputs )
49-
50- def format (self , signature : Type [Signature ], demos : list [dict [str , Any ]], inputs : dict [str , Any ]) -> list [dict [str , Any ]]:
56+
57+ def format (
58+ self , signature : Type [Signature ], demos : list [dict [str , Any ]], inputs : dict [str , Any ]
59+ ) -> list [dict [str , Any ]]:
5160 messages : list [dict [str , Any ]] = []
5261
5362 # Extract demos where some of the output_fields are not filled in.
@@ -88,7 +97,7 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
8897 if match :
8998 # If the header pattern is found, split the rest of the line as content
9099 header = match .group (1 )
91- remaining_content = line [match .end ():].strip ()
100+ remaining_content = line [match .end () :].strip ()
92101 sections .append ((header , [remaining_content ] if remaining_content else []))
93102 else :
94103 sections [- 1 ][1 ].append (line )
@@ -111,7 +120,9 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
111120 return fields
112121
113122 # TODO(PR): Looks ok?
114- def format_finetune_data (self , signature : Type [Signature ], demos : list [dict [str , Any ]], inputs : dict [str , Any ], outputs : dict [str , Any ]) -> dict [str , list [Any ]]:
123+ def format_finetune_data (
124+ self , signature : Type [Signature ], demos : list [dict [str , Any ]], inputs : dict [str , Any ], outputs : dict [str , Any ]
125+ ) -> dict [str , list [Any ]]:
115126 # Get system + user messages
116127 messages = self .format (signature , demos , inputs )
117128
@@ -134,7 +145,14 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role
134145 }
135146 return format_fields (fields_with_values )
136147
137- def format_turn (self , signature : Type [Signature ], values : dict [str , Any ], role : str , incomplete : bool = False , is_conversation_history : bool = False ) -> dict [str , Any ]:
148+ def format_turn (
149+ self ,
150+ signature : Type [Signature ],
151+ values : dict [str , Any ],
152+ role : str ,
153+ incomplete : bool = False ,
154+ is_conversation_history : bool = False ,
155+ ) -> dict [str , Any ]:
138156 return format_turn (signature , values , role , incomplete , is_conversation_history )
139157
140158
@@ -158,7 +176,9 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
158176 return "\n \n " .join (output ).strip ()
159177
160178
161- def format_turn (signature : Type [Signature ], values : dict [str , Any ], role : str , incomplete = False , is_conversation_history = False ):
179+ def format_turn (
180+ signature : Type [Signature ], values : dict [str , Any ], role : str , incomplete = False , is_conversation_history = False
181+ ):
162182 """
163183 Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
164184 so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
0 commit comments