1- from collections import defaultdict
21import logging
2+ from collections import defaultdict
33from typing import Any , Callable , Dict , List , Optional , Union
44
55import dspy
1212from dspy .primitives .program import Program
1313from dspy .teleprompt .teleprompt import Teleprompter
1414
15-
1615logger = logging .getLogger (__name__ )
1716
1817
1918class FinetuneTeleprompter (Teleprompter ):
20-
2119 def __init__ (
2220 self ,
2321 train_kwargs : Optional [Union [Dict [str , Any ], Dict [LM , Dict [str , Any ]]]] = None ,
@@ -41,23 +39,25 @@ def __init__(
4139 train_kwargs : Optional [Union [Dict [str , Any ], Dict [LM , Dict [str , Any ]]]] = None ,
4240 adapter : Optional [Union [Adapter , Dict [LM , Adapter ]]] = None ,
4341 exclude_demos : bool = False ,
44- num_threads : int = 6
42+ num_threads : int = 6 ,
4543 ):
4644 # TODO(feature): Inputs train_kwargs (a dict with string keys) and
4745 # adapter (Adapter) can depend on the LM they are used with. We are
48- # taking these as parameters for the time being. However, they can be
46 # taking these as parameters for the time being. However, they can be
4947 # attached to LMs themselves -- an LM could know which adapter it should
5048 # be used with along with the train_kwargs. This will lead the only
5149 # required argument for LM.finetune() to be the train dataset.
52-
50+
5351 super ().__init__ (train_kwargs = train_kwargs )
5452 self .metric = metric
5553 self .multitask = multitask
5654 self .adapter : Dict [LM , Adapter ] = self .convert_to_lm_dict (adapter )
5755 self .exclude_demos = exclude_demos
5856 self .num_threads = num_threads
59-
60- def compile (self , student : Program , trainset : List [Example ], teacher : Optional [Union [Program , List [Program ]]] = None ) -> Program :
57+
58+ def compile (
59+ self , student : Program , trainset : List [Example ], teacher : Optional [Union [Program , List [Program ]]] = None
60+ ) -> Program :
6161 # TODO: Print statements can be converted to logger.info if we ensure
6262 # that the default DSPy logger logs info level messages in notebook
6363 # environments.
@@ -71,24 +71,41 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[U
7171 teachers = [prepare_teacher (student , t ) for t in teachers ]
7272 for t in teachers :
7373 set_missing_predictor_lms (t )
74- trace_data += bootstrap_trace_data (program = t , dataset = trainset , metric = self .metric , num_threads = self .num_threads )
74+ trace_data += bootstrap_trace_data (
75+ program = t , dataset = trainset , metric = self .metric , num_threads = self .num_threads
76+ )
7577
7678 logger .info ("Preparing the train data..." )
7779 key_to_data = {}
7880 for pred_ind , pred in enumerate (student .predictors ()):
7981 data_pred_ind = None if self .multitask else pred_ind
8082 training_key = (pred .lm , data_pred_ind )
8183 if training_key not in key_to_data :
82- train_data , data_format = self ._prepare_finetune_data (trace_data = trace_data , lm = pred .lm , pred_ind = data_pred_ind )
84+ train_data , data_format = self ._prepare_finetune_data (
85+ trace_data = trace_data , lm = pred .lm , pred_ind = data_pred_ind
86+ )
8387 logger .info (f"Using { len (train_data )} data points for fine-tuning the model: { pred .lm .model } " )
84- finetune_kwargs = dict (lm = pred .lm , train_data = train_data , train_data_format = data_format , train_kwargs = self .train_kwargs [pred .lm ])
88+ finetune_kwargs = dict (
89+ lm = pred .lm ,
90+ train_data = train_data ,
91+ train_data_format = data_format ,
92+ train_kwargs = self .train_kwargs [pred .lm ],
93+ )
8594 key_to_data [training_key ] = finetune_kwargs
86-
95+
8796 logger .info ("Starting LM fine-tuning..." )
8897 # TODO(feature): We could run batches of fine-tuning jobs in sequence
8998 # to avoid exceeding the number of threads.
90- err = f"BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning jobs. There are { len (key_to_data )} fine-tuning jobs to start, but the number of threads is: { self .num_threads } ! If the `multitask` flag is set to False, the number of fine-tuning jobs will be equal to the number of predictors in the student program. If the `multitask` flag is set to True, the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning jobs will be less than or equal to the number of predictors."
91- assert len (key_to_data ) <= self .num_threads , err
99+ if len (key_to_data ) > self .num_threads :
100+ raise ValueError (
101+ "BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning "
102+ f"jobs. There are { len (key_to_data )} fine-tuning jobs to start, but the number of threads is: "
103+ f"{ self .num_threads } ! If the `multitask` flag is set to False, the number of fine-tuning jobs will "
104+ "be equal to the number of predictors in the student program. If the `multitask` flag is set to True, "
105+ "the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of "
106+ "unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning "
107+ "jobs will be less than or equal to the number of predictors."
108+ )
92109 logger .info (f"{ len (key_to_data )} fine-tuning job(s) to start" )
93110 key_to_lm = self .finetune_lms (key_to_data )
94111
@@ -98,10 +115,10 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[U
98115 training_key = (pred .lm , data_pred_ind )
99116 pred .lm = key_to_lm [training_key ]
100117 # TODO: What should the correct behavior be here? Should
101- # BootstrapFinetune modify the prompt demos according to the
118+ # BootstrapFinetune modify the prompt demos according to the
102119 # train data?
103120 pred .demos = [] if self .exclude_demos else pred .demos
104-
121+
105122 logger .info ("BootstrapFinetune has finished compiling the student program" )
106123 student ._compiled = True
107124 return student
@@ -120,10 +137,13 @@ def finetune_lms(finetune_dict) -> Dict[Any, LM]:
120137 # up resources for fine-tuning. This might mean introducing a new
121138 # provider method (e.g. prepare_for_finetune) that can be called
122139 # before fine-tuning is started.
123- logger .info ("Calling lm.kill() on the LM to be fine-tuned to free up resources. This won't have any effect if the LM is not running." )
140+ logger .info (
141+ "Calling lm.kill() on the LM to be fine-tuned to free up resources. This won't have any effect if the "
142+ "LM is not running."
143+ )
124144 lm .kill ()
125145 key_to_job [key ] = lm .finetune (** finetune_kwargs )
126-
146+
127147 key_to_lm = {}
128148 for ind , (key , job ) in enumerate (key_to_job .items ()):
129149 key_to_lm [key ] = job .result ()
@@ -143,13 +163,16 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
143163 adapter = self .adapter [lm ] or lm .infer_adapter ()
144164 data_format = infer_data_format (adapter )
145165 for item in trace_data :
146- for pred_ind , _ in enumerate (item [' trace' ]):
166+ for pred_ind , _ in enumerate (item [" trace" ]):
147167 include_data = pred_ind is None or pred_ind == pred_ind
148168 if include_data :
149- call_data = build_call_data_from_trace (trace = item ['trace' ], pred_ind = pred_ind , adapter = adapter , exclude_demos = self .exclude_demos )
169+ call_data = build_call_data_from_trace (
170+ trace = item ["trace" ], pred_ind = pred_ind , adapter = adapter , exclude_demos = self .exclude_demos
171+ )
150172 data .append (call_data )
151173
152174 import random
175+
153176 random .Random (0 ).shuffle (data )
154177
155178 return data , data_format
@@ -189,8 +212,11 @@ def bootstrap_trace_data(
189212 # Return a list of dicts with the following keys:
190213 # example_ind, example, prediction, trace, and score (if metric != None)
191214 evaluator = Evaluate (
192- devset = dataset , num_threads = num_threads , display_progress = True , return_outputs = True ,
193- provide_traceback = True # TODO(check with team)
215+ devset = dataset ,
216+ num_threads = num_threads ,
217+ display_progress = True ,
218+ return_outputs = True ,
219+ provide_traceback = True , # TODO(check with team)
194220 )
195221
196222 def wrapped_metric (example , prediction , trace = None ):
@@ -286,11 +312,10 @@ def assert_structural_equivalency(program1: object, program2: object):
286312
287313 pzip = zip (program1 .named_predictors (), program2 .named_predictors ())
288314 for ind , ((name1 , pred1 ), (name2 , pred2 )) in enumerate (pzip ):
289- err = f"Program predictor names must match at corresponding indices for structural equivalency. The predictor names for the programs do not match at index { ind } : '{ name1 } ' != '{ name2 } '"
315+ err = f"Program predictor names must match at corresponding indices for structural equivalency. The predictor names for the programs do not match at index { ind } : '{ name1 } ' != '{ name2 } '"
290316 assert name1 == name2 , err
291317 assert isinstance (pred1 , Predict )
292318 assert isinstance (pred2 , Predict )
293- # assert pred1.signature.equals(pred2.signature)
294319
295320
296321def assert_no_shared_predictor (program1 : Program , program2 : Program ):
@@ -303,17 +328,18 @@ def assert_no_shared_predictor(program1: Program, program2: Program):
303328 assert not shared_ids , err
304329
305330
def get_unique_lms(program: Program) -> List[LM]:
    """Collect the distinct LM instances attached to *program*'s predictors.

    Duplicates are dropped via a set, so the order of the returned list is
    arbitrary.
    """
    return list({predictor.lm for predictor in program.predictors()})
310335
def launch_lms(program: Program):
    """Launch every distinct LM used by *program* (one launch call per LM)."""
    for lm in get_unique_lms(program):
        lm.launch()
def kill_lms(program: Program):
    """Kill every distinct LM used by *program*, releasing its resources."""
    for lm in get_unique_lms(program):
        lm.kill()