### OPTIMIZER TRAINING UTILS ###
2525
2626
def create_minibatch(trainset, batch_size=50, rng=None):
    """Create a minibatch from the trainset.

    Args:
        trainset: Sequence of examples to draw from. It must support
            ``len()`` and integer indexing.
        batch_size: Desired number of examples. It is clamped to
            ``len(trainset)`` so an oversized request returns every example
            (in sampled order) instead of failing.
        rng: Optional random generator exposing ``sample(population, k)``,
            e.g. a seeded ``random.Random`` instance, so callers can make the
            draw reproducible. When ``None``, the global ``random`` module is
            used.

    Returns:
        A new list containing the examples at ``batch_size`` distinct
        positions of ``trainset``, sampled without replacement.

    Raises:
        ValueError: Propagated from ``sample`` when ``batch_size`` is
            negative (standard-library ``random`` behavior).
    """
    # Ensure batch_size isn't larger than the size of the dataset
    batch_size = min(batch_size, len(trainset))

    # Compare against None explicitly rather than relying on truthiness, so a
    # caller-supplied generator is always honored even if it evaluates falsy.
    if rng is None:
        rng = random

    # Sample positions (not items) so the examples themselves never need to be
    # hashable or comparable.
    sampled_indices = rng.sample(range(len(trainset)), batch_size)

    # Create the mini-batch using the sampled indices
    return [trainset[i] for i in sampled_indices]
4043
4144
42- def eval_candidate_program (batch_size , trainset , candidate_program , evaluate ):
45+ def eval_candidate_program (batch_size , trainset , candidate_program , evaluate , rng = None ):
4346 """Evaluate a candidate program on the trainset, using the specified batch size."""
4447 # Evaluate on the full trainset
4548 if batch_size >= len (trainset ):
@@ -48,7 +51,7 @@ def eval_candidate_program(batch_size, trainset, candidate_program, evaluate):
4851 else :
4952 score = evaluate (
5053 candidate_program ,
51- devset = create_minibatch (trainset , batch_size ),
54+ devset = create_minibatch (trainset , batch_size , rng ),
5255 )
5356
5457 return score
@@ -279,6 +282,7 @@ def create_n_fewshot_demo_sets(
279282 teacher = None ,
280283 include_non_bootstrapped = True ,
281284 seed = 0 ,
285+ rng = None
282286):
283287 """
284288 This function is copied from random_search.py, and creates fewshot examples in the same way that random search does.
@@ -292,17 +296,15 @@ def create_n_fewshot_demo_sets(
292296 # Initialize demo_candidates dictionary
293297 for i , _ in enumerate (student .predictors ()):
294298 demo_candidates [i ] = []
295-
296- starter_seed = seed
297- # Shuffle the trainset with the starter seed
298- random .Random (starter_seed ).shuffle (trainset )
299+
300+ rng = rng or random .Random (seed )
299301
300302 # Go through and create each candidate set
301303 for seed in range (- 3 , num_candidate_sets ):
302304
303305 print (f"Bootstrapping set { seed + 4 } /{ num_candidate_sets + 3 } " )
304306
305- trainset2 = list (trainset )
307+ trainset_copy = list (trainset )
306308
307309 if seed == - 3 and include_non_bootstrapped :
308310 # zero-shot
@@ -316,7 +318,7 @@ def create_n_fewshot_demo_sets(
316318 # labels only
317319 teleprompter = LabeledFewShot (k = max_labeled_demos )
318320 program2 = teleprompter .compile (
319- student , trainset = trainset2 , sample = labeled_sample ,
321+ student , trainset = trainset_copy , sample = labeled_sample ,
320322 )
321323
322324 elif seed == - 1 :
@@ -329,12 +331,12 @@ def create_n_fewshot_demo_sets(
329331 teacher_settings = teacher_settings ,
330332 max_rounds = max_rounds ,
331333 )
332- program2 = program .compile (student , teacher = teacher , trainset = trainset2 )
334+ program2 = program .compile (student , teacher = teacher , trainset = trainset_copy )
333335
334336 else :
335337 # shuffled few-shot
336- random . Random ( seed ). shuffle (trainset2 )
337- size = random . Random ( seed ) .randint (min_num_samples , max_bootstrapped_demos )
338+ rng . shuffle (trainset_copy )
339+ size = rng .randint (min_num_samples , max_bootstrapped_demos )
338340
339341 teleprompter = BootstrapFewShot (
340342 metric = metric ,
@@ -347,7 +349,7 @@ def create_n_fewshot_demo_sets(
347349 )
348350
349351 program2 = teleprompter .compile (
350- student , teacher = teacher , trainset = trainset2 ,
352+ student , teacher = teacher , trainset = trainset_copy ,
351353 )
352354
353355 for i , _ in enumerate (student .predictors ()):
0 commit comments