Skip to content

Commit 015c649

Browse files
Merge pull request #1169 from stanfordnlp/mipro_v2
MIPRO optimizer updates for paper release
2 parents 01c8de0 + 1ee5479 commit 015c649

File tree

14 files changed

+10195
-191
lines changed

14 files changed

+10195
-191
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
/docs/downloads/
99
/docs/experiments/
1010

11+
/examples/qa/hotpot/MIPRO_notebook_cache/
12+
/examples/nli/scone/MIPRO_notebook_cache/
13+
/examples/nli/scone/ScoNe/
14+
/examples/nli/scone/compiled_program.dspy
15+
/examples/qa/hotpot/compiled_program.dspy
16+
/ScoNe/
17+
1118
# Byte-compiled / optimized / DLL files
1219
__pycache__/
1320
*.py[cod]

dsp/primitives/predict.py

Lines changed: 35 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -61,41 +61,6 @@ def _generate(template: Template, **kwargs) -> Callable:
6161

6262
generator = dsp.settings.lm
6363

64-
def extend_generation(completion: Example, field_names: list[str], stage:str, max_depth: int, original_example:Example):
65-
"""If the required fields are not present in the completion, extend the generation."""
66-
assert max_depth > 0, "Max depth exceeded - failed to complete in one pass - increase max_tokens"
67-
# remove content of last field to avoid half-completed content
68-
for field_name in get_all_fields_following_missing_field(completion, field_names):
69-
completion.pop(field_name, None)
70-
71-
# Recurse with greedy decoding and a shorter length.
72-
max_tokens = (kwargs.get("max_tokens") or
73-
kwargs.get("max_output_tokens") or
74-
dsp.settings.lm.kwargs.get("max_tokens") or
75-
dsp.settings.lm.kwargs.get('max_output_tokens'))
76-
77-
78-
if max_tokens is None:
79-
raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.")
80-
max_tokens = min(max(75, max_tokens // 2), max_tokens)
81-
keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys())
82-
max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens"
83-
new_kwargs = {
84-
**kwargs,
85-
max_tokens_key: max_tokens,
86-
"n": 1,
87-
"temperature": 0.0,
88-
}
89-
90-
_, finished_completion = generate(template, **new_kwargs)(
91-
completion,
92-
stage=stage,
93-
max_depth=max_depth - 1,
94-
original_example=original_example,
95-
)
96-
return finished_completion.data[0]
97-
98-
9964
def do_generate(
10065
example: Example, stage: str, max_depth: int = 2, original_example=None,
10166
):
@@ -112,19 +77,45 @@ def do_generate(
11277
completions: list[dict[str, Any]] = generator(prompt, **kwargs)
11378
completions: list[Example] = [template.extract(example, p) for p in completions]
11479

115-
# Find the completions that are unfinished.
80+
# Find the completions that are most complete.
11681
field_names: list[str] = [field.input_variable for field in template.fields]
11782

118-
finished_completions = []
119-
for completion in completions:
120-
if all((completion.get(key, "") != "") for key in field_names):
121-
finished_completions.append(completion)
122-
continue
123-
finished_completions.append(
124-
extend_generation(completion, field_names, stage, max_depth, original_example),
83+
last_field_idx = 0
84+
for field_idx, key in enumerate(field_names):
85+
completions_ = [
86+
c for c in completions if key in c.keys() and c[key] is not None
87+
]
88+
89+
# Filter out completions that are missing fields that are present in at least one completion.
90+
if len(completions_):
91+
completions = completions_
92+
last_field_idx = field_idx + 1
93+
94+
# If none of the completions is completed (i.e., none has the final field set).
95+
if last_field_idx < len(field_names):
96+
# Pick the first completion that has gone farthest.
97+
completion = completions[0]
98+
completion[field_names[last_field_idx]] = ""
99+
100+
# Recurse with greedy decoding and a shorter length.
101+
max_tokens = kwargs.get("max_tokens", dsp.settings.lm.kwargs["max_tokens"])
102+
max_tokens = min(max(75, max_tokens // 2), max_tokens)
103+
new_kwargs = {
104+
**kwargs,
105+
"max_tokens": max_tokens,
106+
"n": 1,
107+
"temperature": 0.0,
108+
}
109+
110+
assert max_depth > 0
111+
return generate(template, **new_kwargs)(
112+
completion,
113+
stage=stage,
114+
max_depth=max_depth - 1,
115+
original_example=original_example,
125116
)
126117

127-
completions = Completions(finished_completions, template=template)
118+
completions = Completions(completions, template=template)
128119
example = example.copy(completions=completions)
129120

130121
if len(completions) == 1:
@@ -161,15 +152,6 @@ def do_generate(
161152

162153
return do_generate
163154

164-
def get_all_fields_following_missing_field(completion: Example, field_names: list[str]) -> list[str]:
165-
"""Returns every field following the first missing field"""
166-
for i, field_name in enumerate(field_names):
167-
if field_name not in completion:
168-
return field_names[i:]
169-
if completion[field_name] == "":
170-
return field_names[i:]
171-
return []
172-
173155

174156
def generate_sc(
175157
example, prompt, normalize=True, extract=None, prediction_field=None, **kwargs,

dspy/propose/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .grounded_proposer import GroundedProposer
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import re
2+
3+
import dspy
4+
from dspy.propose.utils import strip_prefix
5+
6+
7+
class ObservationSummarizer(dspy.Signature):
    # The string below is the class docstring (``__doc__``); presumably dspy
    # uses it as the instruction text for this signature -- confirm against
    # dspy.Signature. Its wording is runtime text and is kept verbatim.
    """Given a series of observations I have made about my dataset, please summarize them into a brief 2-3 sentence summary which highlights only the most important details."""

    # Input: the accumulated free-text observations about the dataset.
    observations = dspy.InputField(
        desc="Observations I have made about my dataset",
    )
    # Output: the condensed summary of those observations.
    summary = dspy.OutputField(
        desc="Two to Three sentence summary of only the most significant highlights of my observations",
    )
11+
12+
class DatasetDescriptor(dspy.Signature):
    # The adjacent string literals below are folded into one class docstring
    # (``__doc__``); presumably dspy uses it as the instruction text -- confirm
    # against dspy.Signature.
    # NOTE(review): "conciceness" (here) and "Somethings that holds" (in the
    # output description) are misspelled/ungrammatical, but they are part of
    # the runtime prompt, so they are reproduced byte-for-byte.
    (
        "Given several examples from a dataset please write observations about trends that hold for most or all of the samples. "
        "Some areas you may consider in your observations: topics, content, syntax, conciceness, etc. "
        "It will be useful to make an educated guess as to the nature of the task this dataset will enable. Don't be afraid to be creative"
    )

    # Input: a textual rendering of a batch of dataset examples.
    examples = dspy.InputField(
        desc="Sample data points from the dataset",
    )
    # Output: free-text observations about that batch.
    observations = dspy.OutputField(
        desc="Somethings that holds true for most or all of the data you observed",
    )
19+
20+
class DatasetDescriptorWithPriorObservations(dspy.Signature):
    # The adjacent string literals below are folded into one class docstring
    # (``__doc__``); presumably dspy uses it as the instruction text -- confirm
    # against dspy.Signature.
    # NOTE(review): "conciceness" (here) and "Somethings that holds" (in the
    # output description) are misspelled/ungrammatical, but they are part of
    # the runtime prompt, so they are reproduced byte-for-byte.
    (
        "Given several examples from a dataset please write observations about trends that hold for most or all of the samples. "
        "I will also provide you with a few observations I have already made. Please add your own observations or if you feel the observations are comprehensive say 'COMPLETE' "
        "Some areas you may consider in your observations: topics, content, syntax, conciceness, etc. "
        "It will be useful to make an educated guess as to the nature of the task this dataset will enable. Don't be afraid to be creative"
    )

    # Input: a textual rendering of a batch of dataset examples.
    examples = dspy.InputField(
        desc="Sample data points from the dataset",
    )
    # Input: the observations gathered from earlier batches.
    prior_observations = dspy.InputField(
        desc="Some prior observations I made about the data",
    )
    # Output: additional observations, or the sentinel word COMPLETE.
    observations = dspy.OutputField(
        desc="Somethings that holds true for most or all of the data you observed or COMPLETE if you have nothing to add",
    )
29+
30+
def order_input_keys_in_string(unordered_repr):
    """Return ``unordered_repr`` with each ``input_keys={...}`` group sorted.

    Every non-empty ``input_keys={a, b, ...}`` occurrence in the text is
    rewritten so that its comma-separated entries are whitespace-stripped,
    sorted as plain strings, and re-joined with ``", "``. Text outside those
    groups is returned untouched, as are empty ``input_keys={}`` groups
    (the pattern requires at least one character between the braces).

    Args:
        unordered_repr: The text to normalize (in this file, the ``repr`` of
            a slice of the training set).

    Returns:
        The text with all ``input_keys`` groups in sorted order.
    """

    def _sorted_key_set(found):
        # group(1) is everything between the braces of one input_keys set.
        names = [name.strip() for name in found.group(1).split(",")]
        names.sort()
        return "input_keys={" + ", ".join(names) + "}"

    # re.sub invokes the callback once per match and splices in its result.
    return re.sub(r"input_keys=\{([^\}]+)\}", _sorted_key_set, unordered_repr)
47+
48+
def create_dataset_summary(trainset, view_data_batch_size, prompt_model, log_file=None):
    """Build a short natural-language summary of ``trainset`` using ``prompt_model``.

    The first ``view_data_batch_size`` examples are described with
    ``DatasetDescriptor``. Further batches are then shown to
    ``DatasetDescriptorWithPriorObservations`` together with the observations
    gathered so far, and any new observations are appended. Finally the
    accumulated observations are condensed with ``ObservationSummarizer``.

    Collecting follow-up observations is best-effort: any exception raised in
    that phase is printed and the observations gathered up to that point are
    summarized instead.

    Args:
        trainset: Sliceable collection of training examples; slices of it are
            rendered with ``repr`` and sent to the model.
        view_data_batch_size: Number of examples shown to the model per call.
        prompt_model: Language model installed via ``dspy.settings.context``
            for every call made here.
        log_file: Optional writable text stream for progress logging. When it
            is falsy (the default ``None``), nothing is logged.

    Returns:
        The ``summary`` field of the summarizer's prediction, passed through
        ``strip_prefix``.
    """
    upper_lim = min(len(trainset), view_data_batch_size)
    with dspy.settings.context(lm=prompt_model):
        observation = dspy.Predict(DatasetDescriptor, n=1, temperature=1.0)(
            examples=order_input_keys_in_string(repr(trainset[0:upper_lim])),
        )
    observations = observation["observations"]

    if log_file:
        log_file.write("PRODUCING DATASET SUMMARY\n")

    skips = 0
    try:
        max_calls = 10
        calls = 0
        for b in range(view_data_batch_size, len(trainset), view_data_batch_size):
            # NOTE: the counter is bumped before the check, so at most
            # ``max_calls - 1`` follow-up batches are actually sent.
            calls += 1
            if calls >= max_calls:
                break
            print(f"b: {b}")
            upper_lim = min(len(trainset), b + view_data_batch_size)
            with dspy.settings.context(lm=prompt_model):
                output = dspy.Predict(DatasetDescriptorWithPriorObservations, n=1, temperature=1.0)(
                    prior_observations=observations,
                    examples=order_input_keys_in_string(repr(trainset[b:upper_lim])),
                )
            # The model answers "COMPLETE" (any casing) when it has nothing to
            # add; stop after five such answers.
            if len(output["observations"]) >= 8 and output["observations"][:8].upper() == "COMPLETE":
                skips += 1
                if skips >= 5:
                    break
                continue
            observations += output["observations"]

            # Guarded: ``log_file`` defaults to None. Writing unconditionally
            # raised AttributeError, which the handler below swallowed, ending
            # the loop after the first follow-up batch.
            if log_file:
                log_file.write(f"observations {observations}\n")
    except Exception as e:
        # Deliberate best-effort: fall back to whatever has been collected.
        print(f"e {e}. using observations from past round for a summary.")

    with dspy.settings.context(lm=prompt_model):
        summary = dspy.Predict(ObservationSummarizer, n=1, temperature=1.0)(observations=observations)
    print(f"summary: {summary}")
    if log_file:
        log_file.write(f"summary: {summary}\n")

    return strip_prefix(summary.summary)

0 commit comments

Comments
 (0)