Commit 9e5967d

Treat None rollout_id as absent (#8745)
1 parent 61619c1 commit 9e5967d

File tree

14 files changed (+126, -40 lines)


docs/docs/cheatsheet.md

Lines changed: 2 additions & 2 deletions
@@ -463,7 +463,7 @@ dspy.configure_cache(

 ### BestofN

-Runs a module up to `N` times with different temperatures and returns the best prediction, as defined by the `reward_fn`, or the first prediction that passes the `threshold`.
+Runs a module up to `N` times with different rollout IDs (bypassing cache) and returns the best prediction, as defined by the `reward_fn`, or the first prediction that passes the `threshold`.

 ```python
 import dspy
@@ -478,7 +478,7 @@ best_of_3(question="What is the capital of Belgium?").answer

 ### Refine

-Refines a module by running it up to `N` times with different temperatures and returns the best prediction, as defined by the `reward_fn`, or the first prediction that passes the `threshold`. After each attempt (except the final one), `Refine` automatically generates detailed feedback about the module's performance and uses this feedback as hints for subsequent runs, creating an iterative refinement process.
+Refines a module by running it up to `N` times with different rollout IDs (bypassing cache) and returns the best prediction, as defined by the `reward_fn`, or the first prediction that passes the `threshold`. After each attempt (except the final one), `Refine` automatically generates detailed feedback about the module's performance and uses this feedback as hints for subsequent runs, creating an iterative refinement process.

 ```python
 import dspy
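
For quick reference, a runnable sketch of the `BestOfN` usage the cheatsheet documents; the model string and the `one_word_answer` reward are illustrative assumptions, not part of the diff:

```python
import dspy

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # model string is an assumption

qa = dspy.ChainOfThought("question -> answer")

def one_word_answer(args, pred):
    # Illustrative reward: 1.0 for a single-word answer, else 0.0.
    return 1.0 if len(pred.answer.split()) == 1 else 0.0

best_of_3 = dspy.BestOfN(module=qa, N=3, reward_fn=one_word_answer, threshold=1.0)
best_of_3(question="What is the capital of Belgium?").answer  # e.g. "Brussels"
```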

docs/docs/learn/programming/language_models.md

Lines changed: 8 additions & 0 deletions
@@ -166,6 +166,14 @@ gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', temperature=0.9, max_tokens=3000, st

 By default LMs in DSPy are cached. If you repeat the same call, you will get the same outputs. But you can turn off caching by setting `cache=False`.

+If you want to keep caching enabled but force a new request (for example, to obtain diverse outputs),
+pass a unique `rollout_id` in your call. Different values ensure a different cache entry while
+still caching future calls with the same inputs and `rollout_id`.
+
+```python linenums="1"
+lm("Say this is a test!", rollout_id=1)
+```
+

 ## Inspecting output and usage metadata.
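To make the documented behavior concrete, a small sketch (the model string is an assumption): repeating a call with the same `rollout_id` hits the cache, while a new value forces a fresh request:

```python
import dspy

lm = dspy.LM("openai/gpt-4o-mini")  # any provider/model; this one is an assumption

a = lm("Say this is a test!", rollout_id=1)  # fresh request, cached under rollout 1
b = lm("Say this is a test!", rollout_id=1)  # cache hit: identical to `a`
c = lm("Say this is a test!", rollout_id=2)  # new rollout: bypasses the old entry
```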

docs/docs/tutorials/output_refinement/best-of-n-and-refine.md

Lines changed: 4 additions & 4 deletions
@@ -1,14 +1,14 @@
 # Output Refinement: BestOfN and Refine

-Both `BestOfN` and `Refine` are DSPy modules designed to improve the reliability and quality of predictions by making multiple `LM` calls with different parameter settings. Both modules stop when they have reached `N` attempts or when the `reward_fn` returns an award above the `threshold`.
+Both `BestOfN` and `Refine` are DSPy modules designed to improve the reliability and quality of predictions by making multiple `LM` calls with different rollout IDs to bypass caching. Both modules stop when they have reached `N` attempts or when the `reward_fn` returns a reward above the `threshold`.

 ## BestOfN

-`BestOfN` is a module that runs the provided module multiple times (up to `N`) with different temperature settings. It returns either the first prediction that passes a specified threshold or the one with the highest reward if none meets the threshold.
+`BestOfN` is a module that runs the provided module multiple times (up to `N`) with different rollout IDs. It returns either the first prediction that passes a specified threshold or the one with the highest reward if none meets the threshold.

 ### Basic Usage

-Lets say we wanted to have the best chance of getting a one word answer from the model. We could use `BestOfN` to try multiple temperature settings and return the best result.
+Let's say we wanted to have the best chance of getting a one-word answer from the model. We could use `BestOfN` to try multiple rollout IDs and return the best result.

 ```python
 import dspy
@@ -86,7 +86,7 @@ refine = dspy.Refine(

 Both modules serve similar purposes but differ in their approach:

-- `BestOfN` simply tries different temperature settings and selects the best resulting prediction as defined by the `reward_fn`.
+- `BestOfN` simply tries different rollout IDs and selects the best resulting prediction as defined by the `reward_fn`.
 - `Refine` adds a feedback loop, using the LM to generate detailed feedback about the module's own performance based on the previous prediction and the code in the `reward_fn`. This feedback is then used as hints for subsequent runs.

 ## Practical Examples
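
For context, a hedged sketch of the `Refine` API this tutorial describes, complementing the `BestOfN` snippet above (the model string and reward function are illustrative assumptions):

```python
import dspy

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # model string is an assumption

def one_word_answer(args, pred):
    # Illustrative reward: favor single-word answers.
    return 1.0 if len(pred.answer.split()) == 1 else 0.0

refine = dspy.Refine(
    module=dspy.ChainOfThought("question -> answer"),
    N=3,
    reward_fn=one_word_answer,
    threshold=1.0,
)
refine(question="What is the capital of Belgium?")
```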

dspy/clients/base_lm.py

Lines changed: 10 additions & 2 deletions
@@ -110,7 +110,12 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
         raise NotImplementedError("Subclasses must implement this method.")

     def copy(self, **kwargs):
-        """Returns a copy of the language model with possibly updated parameters."""
+        """Returns a copy of the language model with possibly updated parameters.
+
+        Any provided keyword arguments update the corresponding attributes or LM kwargs of
+        the copy. For example, ``lm.copy(rollout_id=1)`` returns an LM whose requests use a
+        different rollout ID to bypass cache collisions.
+        """

         import copy

@@ -121,7 +126,10 @@ def copy(self, **kwargs):
             if hasattr(self, key):
                 setattr(new_instance, key, value)
             if (key in self.kwargs) or (not hasattr(self, key)):
-                new_instance.kwargs[key] = value
+                if value is None:
+                    new_instance.kwargs.pop(key, None)
+                else:
+                    new_instance.kwargs[key] = value

         return new_instance
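The behavioral change in plain terms, as a sketch (model string assumed): `copy(rollout_id=None)` now removes the key from the copy's kwargs instead of storing a literal `None`:

```python
import dspy

lm = dspy.LM("openai/gpt-4o-mini", rollout_id=5)  # model string is an assumption

forked = lm.copy(rollout_id=7)      # new value: distinct cache entries
cleared = lm.copy(rollout_id=None)  # None: key dropped, i.e. treated as absent

assert forked.kwargs["rollout_id"] == 7
assert "rollout_id" not in cleared.kwargs
```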

dspy/clients/lm.py

Lines changed: 27 additions & 0 deletions
@@ -59,6 +59,10 @@ def __init__(
             provider: The provider to use. If not specified, the provider will be inferred from the model.
             finetuning_model: The model to finetune. In some providers, the models available for finetuning are different
                 from the models available for inference.
+            rollout_id: Optional integer used to differentiate cache entries for otherwise
+                identical requests. Different values bypass DSPy's caches while still caching
+                future calls with the same inputs and rollout ID. This argument is stripped
+                before sending requests to the provider.
         """
         # Remember to update LM.copy() if you modify the constructor!
         self.model = model
@@ -85,8 +89,12 @@ def __init__(
                 "`dspy.LM(...)`, e.g., dspy.LM('openai/gpt-5', temperature=1.0, max_tokens=20000)"
             )
             self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
+            if self.kwargs.get("rollout_id") is None:
+                self.kwargs.pop("rollout_id", None)
         else:
             self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
+            if self.kwargs.get("rollout_id") is None:
+                self.kwargs.pop("rollout_id", None)

     def _get_cached_completion_fn(self, completion_fn, cache):
         ignored_args_for_cache_key = ["api_key", "api_base", "base_url"]
@@ -102,10 +110,13 @@ def _get_cached_completion_fn(self, completion_fn, cache):

     def forward(self, prompt=None, messages=None, **kwargs):
         # Build the request.
+        kwargs = dict(kwargs)
         cache = kwargs.pop("cache", self.cache)

         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
+        if kwargs.get("rollout_id") is None:
+            kwargs.pop("rollout_id", None)

         if self.model_type == "chat":
             completion = litellm_completion
@@ -129,10 +140,13 @@ def forward(self, prompt=None, messages=None, **kwargs):

     async def aforward(self, prompt=None, messages=None, **kwargs):
         # Build the request.
+        kwargs = dict(kwargs)
         cache = kwargs.pop("cache", self.cache)

         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
+        if kwargs.get("rollout_id") is None:
+            kwargs.pop("rollout_id", None)

         if self.model_type == "chat":
             completion = alitellm_completion
@@ -296,6 +310,8 @@ async def async_stream_completion():

 def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
     cache = cache or {"no-cache": True, "no-store": True}
+    request = dict(request)
+    request.pop("rollout_id", None)
     stream_completion = _get_stream_completion_fn(request, cache, sync=True)
     if stream_completion is None:
         return litellm.completion(
@@ -310,6 +326,8 @@ def litellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):

 def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
     cache = cache or {"no-cache": True, "no-store": True}
+    request = dict(request)
+    request.pop("rollout_id", None)
     # Extract the provider and model from the model string.
     # TODO: Not all the models are in the format of "provider/model"
     model = request.pop("model").split("/", 1)
@@ -336,6 +354,8 @@ def litellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):

 async def alitellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
     cache = cache or {"no-cache": True, "no-store": True}
+    request = dict(request)
+    request.pop("rollout_id", None)
     stream_completion = _get_stream_completion_fn(request, cache, sync=False)
     if stream_completion is None:
         return await litellm.acompletion(
@@ -350,6 +370,8 @@ async def alitellm_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):

 async def alitellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
     cache = cache or {"no-cache": True, "no-store": True}
+    request = dict(request)
+    request.pop("rollout_id", None)
     model = request.pop("model").split("/", 1)
     provider, model = model[0] if len(model) > 1 else "openai", model[-1]

@@ -373,6 +395,8 @@ async def alitellm_text_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):

 def litellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
     cache = cache or {"no-cache": True, "no-store": True}
+    request = dict(request)
+    request.pop("rollout_id", None)
     request = _convert_chat_request_to_responses_request(request)

     return litellm.responses(
@@ -385,6 +409,8 @@ def litellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):

 async def alitellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
     cache = cache or {"no-cache": True, "no-store": True}
+    request = dict(request)
+    request.pop("rollout_id", None)
     request = _convert_chat_request_to_responses_request(request)

     return await litellm.aresponses(
@@ -395,6 +421,7 @@ async def alitellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
     )

 def _convert_chat_request_to_responses_request(request: dict[str, Any]):
+    request = dict(request)
     if "messages" in request:
         content_blocks = []
         for msg in request.pop("messages"):
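
The same two-line pattern recurs in every completion wrapper above. A standalone sketch of the idea (a hypothetical helper, not DSPy's actual code): `rollout_id` participates in the cache key upstream, so each wrapper copies the request and pops the key, avoiding both mutating the caller's dict and sending an unknown parameter to litellm:

```python
from typing import Any

def strip_rollout_id(request: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper mirroring the commit's pattern."""
    request = dict(request)          # shallow copy: don't mutate the caller's dict
    request.pop("rollout_id", None)  # absent and None are now equivalent
    return request

req = {"model": "openai/gpt-4o-mini", "messages": [], "rollout_id": 3}
assert "rollout_id" not in strip_rollout_id(req)
assert req["rollout_id"] == 3  # original request untouched
```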

dspy/predict/best_of_n.py

Lines changed: 8 additions & 6 deletions
@@ -14,7 +14,7 @@ def __init__(
         fail_count: int | None = None,
     ):
         """
-        Runs a module up to `N` times with different temperatures and returns the best prediction
+        Runs a module up to `N` times with different rollout IDs and returns the best prediction
         out of `N` attempts or the first prediction that passes the `threshold`.

         Args:
@@ -53,12 +53,14 @@ def one_word_answer(args, pred):

     def forward(self, **kwargs):
         lm = self.module.get_lm() or dspy.settings.lm
-        temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / self.N) for i in range(self.N)]
-        temps = list(dict.fromkeys(temps))[: self.N]
+        base_rollout = lm.kwargs.get("rollout_id")
+        start = 0 if base_rollout is None else base_rollout
+        rollout_ids = [start + i for i in range(self.N)]
+        rollout_ids = list(dict.fromkeys(rollout_ids))[: self.N]
         best_pred, best_trace, best_reward = None, None, -float("inf")

-        for idx, t in enumerate(temps):
-            lm_ = lm.copy(temperature=t)
+        for idx, rid in enumerate(rollout_ids):
+            lm_ = lm.copy(rollout_id=rid)
             mod = self.module.deepcopy()
             mod.set_lm(lm_)

@@ -77,7 +79,7 @@ def forward(self, **kwargs):
                     break

             except Exception as e:
-                print(f"BestOfN: Attempt {idx + 1} failed with temperature {t}: {e}")
+                print(f"BestOfN: Attempt {idx + 1} failed with rollout id {rid}: {e}")
                 if idx > self.fail_count:
                     raise e
                 self.fail_count -= 1
dspy/predict/refine.py

Lines changed: 9 additions & 7 deletions
@@ -48,9 +48,9 @@ def __init__(
         fail_count: int | None = None,
     ):
         """
-        Refines a module by running it up to N times with different temperatures and returns the best prediction.
+        Refines a module by running it up to N times with different rollout IDs and returns the best prediction.

-        This module runs the provided module multiple times with varying temperature settings and selects
+        This module runs the provided module multiple times with varying rollout identifiers and selects
         either the first prediction that exceeds the specified threshold or the one with the highest reward.
         If no prediction meets the threshold, it automatically generates feedback to improve future predictions.

@@ -96,14 +96,16 @@ def one_word_answer(args, pred):

     def forward(self, **kwargs):
         lm = self.module.get_lm() or dspy.settings.lm
-        temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / self.N) for i in range(self.N)]
-        temps = list(dict.fromkeys(temps))[: self.N]
+        base_rollout = lm.kwargs.get("rollout_id")
+        start = 0 if base_rollout is None else base_rollout
+        rollout_ids = [start + i for i in range(self.N)]
+        rollout_ids = list(dict.fromkeys(rollout_ids))[: self.N]
         best_pred, best_trace, best_reward = None, None, -float("inf")
         advice = None
         adapter = dspy.settings.adapter or dspy.ChatAdapter()

-        for idx, t in enumerate(temps):
-            lm_ = lm.copy(temperature=t)
+        for idx, rid in enumerate(rollout_ids):
+            lm_ = lm.copy(rollout_id=rid)
             mod = self.module.deepcopy()
             mod.set_lm(lm_)

@@ -167,7 +169,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
                 # print(f"Advice for each module: {advice}")

             except Exception as e:
-                print(f"Refine: Attempt failed with temperature {t}: {e}")
+                print(f"Refine: Attempt failed with rollout id {rid}: {e}")
                 if idx > self.fail_count:
                     raise e
                 self.fail_count -= 1

dspy/propose/grounded_proposer.py

Lines changed: 3 additions & 11 deletions
@@ -330,7 +330,6 @@ def propose_instructions_for_program(
         demo_candidates,
         trial_logs,
         N,  # noqa: N803
-        T,  # noqa: N803
     ) -> list[str]:
         """This method is responsible for returning the full set of new instructions for our program, given the specified criteria."""

@@ -375,7 +374,6 @@ def propose_instructions_for_program(
                 program=program,
                 predictor=predictor,
                 pred_i=pred_i,
-                T=T,
                 demo_candidates=demo_candidates,
                 demo_set_i=demo_set_i,
                 trial_logs=trial_logs,
@@ -390,7 +388,6 @@ def propose_instruction_for_predictor(
         program,
         predictor,
         pred_i,
-        T,  # noqa: N803
         demo_candidates,
         demo_set_i,
         trial_logs,
@@ -414,14 +411,10 @@ def propose_instruction_for_predictor(
             verbose=self.verbose
         )

-        # Generate a new instruction for our predictor, using the temperature specified for this round
-        original_temp = self.prompt_model.kwargs["temperature"]
+        # Generate a new instruction for our predictor using a unique rollout id to bypass cache
+        rollout_lm = self.prompt_model.copy(rollout_id=self.rng.randint(0, 10**9))

-        epsilon = self.rng.uniform(0.01, 0.05)
-        modified_temp = T + epsilon
-
-        with dspy.settings.context(lm=self.prompt_model):
-            self.prompt_model.kwargs["temperature"] = modified_temp
+        with dspy.settings.context(lm=rollout_lm):
             proposed_instruction = instruction_generator(
                 demo_candidates=demo_candidates,
                 pred_i=pred_i,
@@ -432,7 +425,6 @@ def propose_instruction_for_predictor(
                 num_demos_in_context = self.num_demos_in_context,
                 tip=tip,
             ).proposed_instruction
-        self.prompt_model.kwargs["temperature"] = original_temp

         # Log the trace used to generate the new instruction, along with the new instruction itself
         if self.verbose:
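
The proposer's new pattern in isolation, as a sketch (the model string and RNG seed are assumptions): rather than nudging temperature by a random epsilon and restoring it afterward, it forks the prompt model with a random `rollout_id`, leaving the original LM's state untouched:

```python
import random

import dspy

rng = random.Random(0)
prompt_model = dspy.LM("openai/gpt-4o-mini")  # model string is an assumption

# Fork with a random rollout id: a fresh cache entry, no shared mutable state.
rollout_lm = prompt_model.copy(rollout_id=rng.randint(0, 10**9))

with dspy.settings.context(lm=rollout_lm):
    pass  # run the instruction generator here, as in the diff above
```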

dspy/teleprompt/bootstrap.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,7 @@ def _bootstrap_one_example(self, example, round_idx=0):
         try:
             with dspy.settings.context(trace=[], **self.teacher_settings):
                 lm = dspy.settings.lm
-                lm = lm.copy(temperature=0.7 + 0.001 * round_idx) if round_idx > 0 else lm
+                lm = lm.copy(rollout_id=round_idx) if round_idx > 0 else lm
                 new_settings = {"lm": lm} if round_idx > 0 else {}

                 with dspy.settings.context(**new_settings):

dspy/teleprompt/infer_rules.py

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ class CustomRulesInduction(dspy.Signature):

     def forward(self, examples_text):
         with dspy.settings.context(**self.teacher_settings):
-            lm = dspy.settings.lm.copy(temperature=self.rng.uniform(0.9, 1.0))
+            lm = dspy.settings.lm.copy(rollout_id=self.rng.randint(0, 10**9))
             with dspy.settings.context(lm=lm):
                 rules = self.rules_induction(examples_text=examples_text).natural_language_rules
