Skip to content

Commit f3d7f26

Browse files
style fix for optimizers (#7839)
1 parent 2f3834a commit f3d7f26

File tree

3 files changed

+167
-153
lines changed

3 files changed

+167
-153
lines changed

dspy/teleprompt/bootstrap.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,37 +33,36 @@
3333

3434
logger = logging.getLogger(__name__)
3535

36+
3637
class BootstrapFewShot(Teleprompter):
3738
def __init__(
3839
self,
3940
metric=None,
4041
metric_threshold=None,
41-
teacher_settings: Optional[Dict]=None,
42+
teacher_settings: Optional[Dict] = None,
4243
max_bootstrapped_demos=4,
4344
max_labeled_demos=16,
4445
max_rounds=1,
4546
max_errors=5,
4647
):
47-
"""
48-
A Teleprompter class that composes a set of demos/examples to go into a predictor's prompt.
48+
"""A Teleprompter class that composes a set of demos/examples to go into a predictor's prompt.
4949
These demos come from a combination of labeled examples in the training set, and bootstrapped demos.
5050
5151
Args:
52-
metric: Callable
53-
A function that compares an expected value and predicted value, outputting the result of that comparison.
54-
metric_threshold: optional float, default `None`
55-
If the metric yields a numerical value, then check it against this threshold when
56-
deciding whether or not to accept a bootstrap example.
57-
teacher_settings: dict, optional
58-
Settings for the `teacher` model.
59-
max_bootstrapped_demos: int, default 4
60-
Maximum number of bootstrapped demonstrations to include
61-
max_labeled_demos: int, default 16
62-
Maximum number of labeled demonstrations to include.
63-
max_rounds: int, default 1
64-
Number of iterations to attempt generating the required bootstrap examples. If unsuccessful after `max_rounds`, the program ends.
65-
max_errors: int, default 5
66-
Maximum number of errors until program ends.
52+
metric (Callable): A function that compares an expected value and predicted value,
53+
outputting the result of that comparison.
54+
metric_threshold (float, optional): If the metric yields a numerical value, then check it
55+
against this threshold when deciding whether or not to accept a bootstrap example.
56+
Defaults to None.
57+
teacher_settings (dict, optional): Settings for the `teacher` model.
58+
Defaults to None.
59+
max_bootstrapped_demos (int): Maximum number of bootstrapped demonstrations to include.
60+
Defaults to 4.
61+
max_labeled_demos (int): Maximum number of labeled demonstrations to include.
62+
Defaults to 16.
63+
max_rounds (int): Number of iterations to attempt generating the required bootstrap
64+
examples. If unsuccessful after `max_rounds`, the program ends. Defaults to 1.
65+
max_errors (int): Maximum number of errors until program ends. Defaults to 5.
6766
"""
6867
self.metric = metric
6968
self.metric_threshold = metric_threshold
@@ -117,9 +116,10 @@ def _prepare_predictor_mappings(self):
117116
if hasattr(predictor1.signature, "equals"):
118117
assert predictor1.signature.equals(
119118
predictor2.signature,
120-
), (f"Student and teacher must have the same signatures. "
119+
), (
120+
f"Student and teacher must have the same signatures. "
121121
f"{type(predictor1.signature)} != {type(predictor2.signature)}"
122-
)
122+
)
123123
else:
124124
# fallback in case if .equals is not implemented (e.g. dsp.Prompt)
125125
assert predictor1.signature == predictor2.signature, (
@@ -149,7 +149,8 @@ def _bootstrap(self, *, max_bootstraps=None):
149149
self.name2traces = {name: [] for name in self.name2predictor}
150150

151151
for example_idx, example in enumerate(tqdm.tqdm(self.trainset)):
152-
if len(bootstrapped) >= max_bootstraps: break
152+
if len(bootstrapped) >= max_bootstraps:
153+
break
153154

154155
for round_idx in range(self.max_rounds):
155156
bootstrap_attempts += 1
@@ -175,8 +176,8 @@ def _bootstrap(self, *, max_bootstraps=None):
175176
# score = evaluate(self.metric, display_table=False, display_progress=True)
176177

177178
def _bootstrap_one_example(self, example, round_idx=0):
178-
name2traces = {} #self.name2traces
179-
teacher = self.teacher # .deepcopy()
179+
name2traces = {}
180+
teacher = self.teacher
180181
predictor_cache = {}
181182

182183
try:
@@ -235,10 +236,11 @@ def _bootstrap_one_example(self, example, round_idx=0):
235236

236237
name2traces[predictor_name] = name2traces.get(predictor_name, [])
237238
name2traces[predictor_name].append(demo)
238-
239+
239240
# Update the traces
240241
for name, demos in name2traces.items():
241242
from datasets.fingerprint import Hasher
243+
242244
# If there are multiple traces for the same predictor in the sample example,
243245
# sample 50/50 from the first N-1 traces or the last trace.
244246
if len(demos) > 1:

dspy/teleprompt/bootstrap_finetune.py

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from collections import defaultdict
21
import logging
2+
from collections import defaultdict
33
from typing import Any, Callable, Dict, List, Optional, Union
44

55
import dspy
@@ -12,12 +12,10 @@
1212
from dspy.primitives.program import Program
1313
from dspy.teleprompt.teleprompt import Teleprompter
1414

15-
1615
logger = logging.getLogger(__name__)
1716

1817

1918
class FinetuneTeleprompter(Teleprompter):
20-
2119
def __init__(
2220
self,
2321
train_kwargs: Optional[Union[Dict[str, Any], Dict[LM, Dict[str, Any]]]] = None,
@@ -41,23 +39,25 @@ def __init__(
4139
train_kwargs: Optional[Union[Dict[str, Any], Dict[LM, Dict[str, Any]]]] = None,
4240
adapter: Optional[Union[Adapter, Dict[LM, Adapter]]] = None,
4341
exclude_demos: bool = False,
44-
num_threads: int = 6
42+
num_threads: int = 6,
4543
):
4644
# TODO(feature): Inputs train_kwargs (a dict with string keys) and
4745
# adapter (Adapter) can depend on the LM they are used with. We are
48-
# takingthese as parameters for the time being. However, they can be
46+
# takingthese as parameters for the time being. However, they can be
4947
# attached to LMs themselves -- an LM could know which adapter it should
5048
# be used with along with the train_kwargs. This will lead the only
5149
# required argument for LM.finetune() to be the train dataset.
52-
50+
5351
super().__init__(train_kwargs=train_kwargs)
5452
self.metric = metric
5553
self.multitask = multitask
5654
self.adapter: Dict[LM, Adapter] = self.convert_to_lm_dict(adapter)
5755
self.exclude_demos = exclude_demos
5856
self.num_threads = num_threads
59-
60-
def compile(self, student: Program, trainset: List[Example], teacher: Optional[Union[Program, List[Program]]] = None) -> Program:
57+
58+
def compile(
59+
self, student: Program, trainset: List[Example], teacher: Optional[Union[Program, List[Program]]] = None
60+
) -> Program:
6161
# TODO: Print statements can be converted to logger.info if we ensure
6262
# that the default DSPy logger logs info level messages in notebook
6363
# environments.
@@ -71,24 +71,41 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[U
7171
teachers = [prepare_teacher(student, t) for t in teachers]
7272
for t in teachers:
7373
set_missing_predictor_lms(t)
74-
trace_data += bootstrap_trace_data(program=t, dataset=trainset, metric=self.metric, num_threads=self.num_threads)
74+
trace_data += bootstrap_trace_data(
75+
program=t, dataset=trainset, metric=self.metric, num_threads=self.num_threads
76+
)
7577

7678
logger.info("Preparing the train data...")
7779
key_to_data = {}
7880
for pred_ind, pred in enumerate(student.predictors()):
7981
data_pred_ind = None if self.multitask else pred_ind
8082
training_key = (pred.lm, data_pred_ind)
8183
if training_key not in key_to_data:
82-
train_data, data_format = self._prepare_finetune_data(trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind)
84+
train_data, data_format = self._prepare_finetune_data(
85+
trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind
86+
)
8387
logger.info(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
84-
finetune_kwargs = dict(lm=pred.lm, train_data=train_data, train_data_format=data_format, train_kwargs=self.train_kwargs[pred.lm])
88+
finetune_kwargs = dict(
89+
lm=pred.lm,
90+
train_data=train_data,
91+
train_data_format=data_format,
92+
train_kwargs=self.train_kwargs[pred.lm],
93+
)
8594
key_to_data[training_key] = finetune_kwargs
86-
95+
8796
logger.info("Starting LM fine-tuning...")
8897
# TODO(feature): We could run batches of fine-tuning jobs in sequence
8998
# to avoid exceeding the number of threads.
90-
err = f"BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning jobs. There are {len(key_to_data)} fine-tuning jobs to start, but the number of threads is: {self.num_threads}! If the `multitask` flag is set to False, the number of fine-tuning jobs will be equal to the number of predictors in the student program. If the `multitask` flag is set to True, the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning jobs will be less than or equal to the number of predictors."
91-
assert len(key_to_data) <= self.num_threads, err
99+
if len(key_to_data) > self.num_threads:
100+
raise ValueError(
101+
"BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning "
102+
f"jobs. There are {len(key_to_data)} fine-tuning jobs to start, but the number of threads is: "
103+
f"{self.num_threads}! If the `multitask` flag is set to False, the number of fine-tuning jobs will "
104+
"be equal to the number of predictors in the student program. If the `multitask` flag is set to True, "
105+
"the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of "
106+
"unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning "
107+
"jobs will be less than or equal to the number of predictors."
108+
)
92109
logger.info(f"{len(key_to_data)} fine-tuning job(s) to start")
93110
key_to_lm = self.finetune_lms(key_to_data)
94111

@@ -98,10 +115,10 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[U
98115
training_key = (pred.lm, data_pred_ind)
99116
pred.lm = key_to_lm[training_key]
100117
# TODO: What should the correct behavior be here? Should
101-
# BootstrapFinetune modify the prompt demos according to the
118+
# BootstrapFinetune modify the prompt demos according to the
102119
# train data?
103120
pred.demos = [] if self.exclude_demos else pred.demos
104-
121+
105122
logger.info("BootstrapFinetune has finished compiling the student program")
106123
student._compiled = True
107124
return student
@@ -120,10 +137,13 @@ def finetune_lms(finetune_dict) -> Dict[Any, LM]:
120137
# up resources for fine-tuning. This might mean introducing a new
121138
# provider method (e.g. prepare_for_finetune) that can be called
122139
# before fine-tuning is started.
123-
logger.info("Calling lm.kill() on the LM to be fine-tuned to free up resources. This won't have any effect if the LM is not running.")
140+
logger.info(
141+
"Calling lm.kill() on the LM to be fine-tuned to free up resources. This won't have any effect if the "
142+
"LM is not running."
143+
)
124144
lm.kill()
125145
key_to_job[key] = lm.finetune(**finetune_kwargs)
126-
146+
127147
key_to_lm = {}
128148
for ind, (key, job) in enumerate(key_to_job.items()):
129149
key_to_lm[key] = job.result()
@@ -143,13 +163,16 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
143163
adapter = self.adapter[lm] or lm.infer_adapter()
144164
data_format = infer_data_format(adapter)
145165
for item in trace_data:
146-
for pred_ind, _ in enumerate(item['trace']):
166+
for pred_ind, _ in enumerate(item["trace"]):
147167
include_data = pred_ind is None or pred_ind == pred_ind
148168
if include_data:
149-
call_data = build_call_data_from_trace(trace=item['trace'], pred_ind=pred_ind, adapter=adapter, exclude_demos=self.exclude_demos)
169+
call_data = build_call_data_from_trace(
170+
trace=item["trace"], pred_ind=pred_ind, adapter=adapter, exclude_demos=self.exclude_demos
171+
)
150172
data.append(call_data)
151173

152174
import random
175+
153176
random.Random(0).shuffle(data)
154177

155178
return data, data_format
@@ -189,8 +212,11 @@ def bootstrap_trace_data(
189212
# Return a list of dicts with the following keys:
190213
# example_ind, example, prediction, trace, and score (if metric != None)
191214
evaluator = Evaluate(
192-
devset=dataset, num_threads=num_threads, display_progress=True, return_outputs=True,
193-
provide_traceback=True # TODO(check with team)
215+
devset=dataset,
216+
num_threads=num_threads,
217+
display_progress=True,
218+
return_outputs=True,
219+
provide_traceback=True, # TODO(check with team)
194220
)
195221

196222
def wrapped_metric(example, prediction, trace=None):
@@ -286,11 +312,10 @@ def assert_structural_equivalency(program1: object, program2: object):
286312

287313
pzip = zip(program1.named_predictors(), program2.named_predictors())
288314
for ind, ((name1, pred1), (name2, pred2)) in enumerate(pzip):
289-
err = f"Program predictor names must match at corresponding indices for structural equivalency. The predictor names for the programs do not match at index {ind}: '{name1}' != '{name2}'"
315+
err = f"Program predictor names must match at corresponding indices for structural equivalency. The predictor names for the programs do not match at index {ind}: '{name1}' != '{name2}'"
290316
assert name1 == name2, err
291317
assert isinstance(pred1, Predict)
292318
assert isinstance(pred2, Predict)
293-
# assert pred1.signature.equals(pred2.signature)
294319

295320

296321
def assert_no_shared_predictor(program1: Program, program2: Program):
@@ -303,17 +328,18 @@ def assert_no_shared_predictor(program1: Program, program2: Program):
303328
assert not shared_ids, err
304329

305330

306-
def get_unique_lms(program: Program) -> List[LM]:
307-
lms = [pred.lm for pred in program.predictors()]
308-
lms = list(set(lms))
309-
return lms
331+
def get_unique_lms(program: Program) -> List[LM]:
332+
lms = [pred.lm for pred in program.predictors()]
333+
return list(set(lms))
334+
310335

311336
def launch_lms(program: Program):
312337
lms = get_unique_lms(program)
313338
for lm in lms:
314339
lm.launch()
315340

316-
def kill_lms(program: Program):
317-
lms = get_unique_lms(program)
318-
for lm in lms:
341+
342+
def kill_lms(program: Program):
343+
lms = get_unique_lms(program)
344+
for lm in lms:
319345
lm.kill()

0 commit comments

Comments (0)