Skip to content

Commit 9170658

Browse files
isaacbmiller and okhat authored
Fix broken inspect_history and broken prompt cache (#1744)
* Fix broken inspect_history and broken prompt cache
* Remove errant print statement
* Move global inspect history into base_lm
* Remove skip parameter
* Delete examples/temp.py
* Minor adjustment to make adapters go back to original behavior

---------

Co-authored-by: Omar Khattab <[email protected]>
1 parent 2d3ed8d commit 9170658

File tree

7 files changed

+94
-19
lines changed

7 files changed

+94
-19
lines changed

dspy/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from dspy.clients import * # isort: skip
1313
from dspy.adapters import * # isort: skip
1414
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
15-
1615
settings = dsp.settings
1716

1817
configure_dspy_loggers(__name__)
@@ -70,10 +69,4 @@
7069
BootstrapRS = dspy.teleprompt.BootstrapFewShotWithRandomSearch
7170
COPRO = dspy.teleprompt.COPRO
7271
MIPROv2 = dspy.teleprompt.MIPROv2
73-
Ensemble = dspy.teleprompt.Ensemble
74-
75-
76-
# TODO: Consider if this should access settings.lm *or* a list that's shared across all LMs in the program.
77-
def inspect_history(*args, **kwargs):
78-
from dspy.clients.lm import GLOBAL_HISTORY, _inspect_history
79-
return _inspect_history(GLOBAL_HISTORY, *args, **kwargs)
72+
Ensemble = dspy.teleprompt.Ensemble

dspy/adapters/chat_adapter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
209209
else:
210210
output[-1]["text"] += formatted_field_value["text"]
211211
if assume_text:
212-
return "\n\n".join(output)
212+
return "\n\n".join(output).strip()
213213
else:
214214
return output
215215

@@ -396,7 +396,6 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
396396
parts.append(format_signature_fields_for_instructions(signature.input_fields))
397397
parts.append(format_signature_fields_for_instructions(signature.output_fields))
398398
parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}, assume_text=True))
399-
400399
instructions = textwrap.dedent(signature.instructions)
401400
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
402401
parts.append(f"In adhering to this structure, your objective is: {objective}")

dspy/clients/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .lm import LM
2-
from .base_lm import BaseLM
2+
from .base_lm import BaseLM, inspect_history

dspy/clients/base_lm.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABC, abstractmethod
22

3+
GLOBAL_HISTORY = []
34

45
class BaseLM(ABC):
56
def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs):
@@ -14,7 +15,10 @@ def __call__(self, prompt=None, messages=None, **kwargs):
1415
pass
1516

1617
def inspect_history(self, n: int = 1):
17-
_inspect_history(self, n)
18+
_inspect_history(self.history, n)
19+
20+
def update_global_history(self, entry):
21+
GLOBAL_HISTORY.append(entry)
1822

1923

2024
def _green(text: str, end: str = "\n"):
@@ -24,15 +28,21 @@ def _green(text: str, end: str = "\n"):
2428
def _red(text: str, end: str = "\n"):
2529
return "\x1b[31m" + str(text) + "\x1b[0m" + end
2630

31+
def _blue(text: str, end: str = "\n"):
32+
return "\x1b[34m" + str(text) + "\x1b[0m" + end
33+
2734

28-
def _inspect_history(lm, n: int = 1):
35+
def _inspect_history(history, n: int = 1):
2936
"""Prints the last n prompts and their completions."""
3037

31-
for item in reversed(lm.history[-n:]):
38+
for item in history[-n:]:
3239
messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
3340
outputs = item["outputs"]
41+
timestamp = item.get("timestamp", "Unknown time")
3442

3543
print("\n\n\n")
44+
print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
45+
3646
for msg in messages:
3747
print(_red(f"{msg['role'].capitalize()} message:"))
3848
if isinstance(msg["content"], str):
@@ -43,11 +53,13 @@ def _inspect_history(lm, n: int = 1):
4353
if c["type"] == "text":
4454
print(c["text"].strip())
4555
elif c["type"] == "image_url":
56+
image_str = ""
4657
if "base64" in c["image_url"].get("url", ""):
4758
len_base64 = len(c["image_url"]["url"].split("base64,")[1])
48-
print(f"<{c['image_url']['url'].split('base64,')[0]}base64,<IMAGE BASE 64 ENCODED({str(len_base64)})>")
59+
image_str = f"<{c['image_url']['url'].split('base64,')[0]}base64,<IMAGE BASE 64 ENCODED({str(len_base64)})>"
4960
else:
50-
print(f"<image_url: {c['image_url']['url']}>")
61+
image_str = f"<image_url: {c['image_url']['url']}>"
62+
print(_blue(image_str.strip()))
5163
print("\n")
5264

5365
print(_red("Response:"))
@@ -58,3 +70,7 @@ def _inspect_history(lm, n: int = 1):
5870
print(_red(choices_text, end=""))
5971

6072
print("\n\n\n")
73+
74+
def inspect_history(n: int = 1):
75+
"""The global history shared across all LMs."""
76+
return _inspect_history(GLOBAL_HISTORY, n)

dspy/clients/lm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
2424
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
2525

26-
GLOBAL_HISTORY = []
27-
2826
logger = logging.getLogger(__name__)
2927

3028
class LM(BaseLM):
@@ -109,7 +107,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
109107
model_type=self.model_type,
110108
)
111109
self.history.append(entry)
112-
GLOBAL_HISTORY.append(entry)
110+
self.update_global_history(entry)
113111

114112
return outputs
115113

dspy/utils/dummies.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def format_answer_fields(field_names_and_values: Dict[str, Any]):
205205
entry = dict(**entry, outputs=outputs, usage=0)
206206
entry = dict(**entry, cost=0)
207207
self.history.append(entry)
208+
self.update_global_history(entry)
208209

209210
return outputs
210211

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
from dspy.utils.dummies import DummyLM
3+
from dspy.clients.base_lm import GLOBAL_HISTORY
4+
import dspy
5+
6+
@pytest.fixture(autouse=True)
7+
def clear_history():
8+
GLOBAL_HISTORY.clear()
9+
yield
10+
11+
def test_inspect_history_basic(capsys):
12+
# Configure a DummyLM with some predefined responses
13+
lm = DummyLM([{"response": "Hello"}, {"response": "How are you?"}])
14+
dspy.settings.configure(lm=lm)
15+
16+
# Make some calls to generate history
17+
predictor = dspy.Predict("query: str -> response: str")
18+
predictor(query="Hi")
19+
predictor(query="What's up?")
20+
21+
# Test inspecting all history
22+
history = GLOBAL_HISTORY
23+
print(capsys)
24+
assert len(history) > 0
25+
assert isinstance(history, list)
26+
assert all(isinstance(entry, dict) for entry in history)
27+
assert all("messages" in entry for entry in history)
28+
29+
def test_inspect_history_with_n(capsys):
30+
lm = DummyLM([{"response": "One"}, {"response": "Two"}, {"response": "Three"}])
31+
dspy.settings.configure(lm=lm)
32+
33+
# Generate some history
34+
predictor = dspy.Predict("query: str -> response: str")
35+
predictor(query="First")
36+
predictor(query="Second")
37+
predictor(query="Third")
38+
39+
dspy.inspect_history(n=2)
40+
# Test getting last 2 entries
41+
out, err = capsys.readouterr()
42+
assert not "First" in out
43+
assert "Second" in out
44+
assert "Third" in out
45+
46+
def test_inspect_empty_history(capsys):
47+
# Configure fresh DummyLM
48+
lm = DummyLM([])
49+
dspy.settings.configure(lm=lm)
50+
51+
# Test inspecting empty history
52+
dspy.inspect_history()
53+
history = GLOBAL_HISTORY
54+
assert len(history) == 0
55+
assert isinstance(history, list)
56+
57+
def test_inspect_history_n_larger_than_history(capsys):
58+
lm = DummyLM([{"response": "First"}, {"response": "Second"}])
59+
dspy.settings.configure(lm=lm)
60+
61+
predictor = dspy.Predict("query: str -> response: str")
62+
predictor(query="Query 1")
63+
predictor(query="Query 2")
64+
65+
# Request more entries than exist
66+
dspy.inspect_history(n=5)
67+
history = GLOBAL_HISTORY
68+
assert len(history) == 2 # Should return all available entries

0 commit comments

Comments (0)