Skip to content

Commit 16ceaba

Browse files
authored
Merge pull request #1556 from chenmoneygithub/fix-saving
Fix signature saving at Predict saving
2 parents 56a29f6 + ff53b20 commit 16ceaba

File tree

5 files changed

+142
-34
lines changed

5 files changed

+142
-34
lines changed

dspy/predict/predict.py

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1+
import logging
12
import random
3+
from functools import lru_cache
24

35
from pydantic import BaseModel
46

57
import dsp
68
from dspy.predict.parameter import Parameter
7-
from dspy.primitives.program import Module
8-
99
from dspy.primitives.prediction import Prediction
10+
from dspy.primitives.program import Module
1011
from dspy.signatures.signature import ensure_signature, signature_to_template
1112

12-
import logging
13-
from functools import lru_cache
1413

1514
@lru_cache(maxsize=None)
1615
def warn_once(msg: str):
@@ -31,7 +30,13 @@ def reset(self):
3130
self.train = []
3231
self.demos = []
3332

34-
def dump_state(self, save_verbose=False):
33+
def dump_state(self, save_verbose=None):
34+
if save_verbose:
35+
logging.warning(
36+
"`save_verbose` is deprecated and will be removed in DSPy 2.6.0 release. Currently `save_verbose` "
37+
"does nothing."
38+
)
39+
3540
state_keys = ["lm", "traces", "train"]
3641
state = {k: getattr(self, k) for k in state_keys}
3742

@@ -45,33 +50,47 @@ def dump_state(self, save_verbose=False):
4550

4651
state["demos"].append(demo)
4752

48-
# If `save_verbose` save all field metadata as well.
49-
if save_verbose:
50-
fields = []
51-
for field_key in self.signature.fields.keys():
52-
field_metadata = self.signature.fields[field_key]
53-
fields.append({
54-
"name": field_key,
55-
"field_type": field_metadata.json_schema_extra["__dspy_field_type"],
56-
"description": field_metadata.json_schema_extra["desc"],
57-
"prefix": field_metadata.json_schema_extra["prefix"]
58-
})
59-
state["fields"] = fields
60-
61-
# Cache the signature instructions and the last field's name.
62-
*_, last_key = self.signature.fields.keys()
63-
state["signature_instructions"] = self.signature.instructions
64-
state["signature_prefix"] = self.signature.fields[last_key].json_schema_extra["prefix"]
65-
66-
# Some special stuff for CoT.
53+
state["signature"] = self.signature.dump_state()
54+
# `extended_signature` is a special field for `Predict`s like CoT.
6755
if hasattr(self, "extended_signature"):
68-
# Cache the signature instructions and the last field's name.
69-
state["extended_signature_instructions"] = self.extended_signature.instructions
70-
state["extended_signature_prefix"] = self.extended_signature.fields[last_key].json_schema_extra['prefix']
56+
state["extended_signature"] = self.extended_signature.dump_state()
7157

7258
return state
7359

74-
def load_state(self, state):
60+
def load_state(self, state, use_legacy_loading=False):
61+
"""Load the saved state of a `Predict` object.
62+
63+
Args:
64+
state (dict): The saved state of a `Predict` object.
65+
use_legacy_loading (bool): Whether to use the legacy loading method. Only use it when you are loading a
66+
saved state from a version of DSPy prior to v2.5.3.
67+
"""
68+
if use_legacy_loading:
69+
self._load_state_legacy(state)
70+
return
71+
if "signature" not in state:
72+
# Check if the state is from a version of DSPy prior to v2.5.3.
73+
raise ValueError(
74+
"The saved state is from a version of DSPy prior to v2.5.3. Please use `use_legacy_loading=True` to "
75+
"load the state."
76+
)
77+
78+
excluded_keys = ["signature", "extended_signature"]
79+
for name, value in state.items():
80+
# `excluded_keys` are fields that go through special handling.
81+
if name not in excluded_keys:
82+
setattr(self, name, value)
83+
84+
self.signature = self.signature.load_state(state["signature"])
85+
86+
if "extended_signature" in state:
87+
self.extended_signature.load_state(state["extended_signature"])
88+
89+
def _load_state_legacy(self, state):
90+
"""Legacy state loading for backwards compatibility.
91+
92+
This method is used to load the saved state of a `Predict` object from a version of DSPy prior to v2.5.3.
93+
"""
7594
for name, value in state.items():
7695
setattr(self, name, value)
7796

@@ -84,7 +103,7 @@ def load_state(self, state):
84103
prefix = state["signature_prefix"]
85104
*_, last_key = self.signature.fields.keys()
86105
self.signature = self.signature.with_updated_fields(last_key, prefix=prefix)
87-
106+
88107
# Some special stuff for CoT.
89108
if "extended_signature_instructions" in state:
90109
instructions = state["extended_signature_instructions"]
@@ -95,6 +114,7 @@ def load_state(self, state):
95114
*_, last_key = self.extended_signature.fields.keys()
96115
self.extended_signature = self.extended_signature.with_updated_fields(last_key, prefix=prefix)
97116

117+
98118
def __call__(self, **kwargs):
99119
return self.forward(**kwargs)
100120

dspy/primitives/module.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,17 @@ def dump_state(self, save_verbose):
152152
print(self.named_parameters())
153153
return {name: param.dump_state(save_verbose) for name, param in self.named_parameters()}
154154

155-
def load_state(self, state):
155+
def load_state(self, state, use_legacy_loading=False):
156156
for name, param in self.named_parameters():
157-
param.load_state(state[name])
157+
param.load_state(state[name], use_legacy_loading=use_legacy_loading)
158158

159159
def save(self, path, save_field_meta=False):
160160
with open(path, "w") as f:
161161
f.write(ujson.dumps(self.dump_state(save_field_meta), indent=2))
162162

163-
def load(self, path):
163+
def load(self, path, use_legacy_loading=False):
164164
with open(path) as f:
165-
self.load_state(ujson.loads(f.read()))
165+
self.load_state(ujson.loads(f.read()), use_legacy_loading=use_legacy_loading)
166166

167167

168168
def postprocess_parameter_name(name, value):

dspy/signatures/signature.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ def signature(cls) -> str:
9696
def instructions(cls) -> str:
9797
return inspect.cleandoc(getattr(cls, "__doc__", ""))
9898

99+
@instructions.setter
100+
def instructions(cls, instructions: str) -> None:
101+
setattr(cls, "__doc__", instructions)
102+
99103
def with_instructions(cls, instructions: str) -> Type["Signature"]:
100104
return Signature(cls.fields, instructions)
101105

@@ -159,6 +163,28 @@ def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signa
159163
new_fields = dict(input_fields + output_fields)
160164
return Signature(new_fields, cls.instructions)
161165

166+
def dump_state(cls):
167+
state = {"instructions": cls.instructions, "fields": []}
168+
for field in cls.fields:
169+
state["fields"].append(
170+
{
171+
"prefix": cls.fields[field].json_schema_extra["prefix"],
172+
"description": cls.fields[field].json_schema_extra["desc"],
173+
}
174+
)
175+
176+
return state
177+
178+
def load_state(cls, state):
179+
signature_copy = Signature(deepcopy(cls.fields), cls.instructions)
180+
181+
signature_copy.instructions = state["instructions"]
182+
for field, saved_field in zip(signature_copy.fields.values(), state["fields"]):
183+
field.json_schema_extra["prefix"] = saved_field["prefix"]
184+
field.json_schema_extra["desc"] = saved_field["description"]
185+
186+
return signature_copy
187+
162188
def equals(cls, other) -> bool:
163189
"""Compare the JSON schema of two Pydantic models."""
164190
if not isinstance(other, type) or not issubclass(other, BaseModel):
@@ -423,4 +449,4 @@ def infer_prefix(attribute_name: str) -> str:
423449
else:
424450
title_cased_words.append(word.capitalize())
425451

426-
return " ".join(title_cased_words)
452+
return " ".join(title_cased_words)

tests/predict/test_predict.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,27 @@ class Output(pydantic.BaseModel):
126126
# Demos don't need to keep the same types after saving and loading the state.
127127
assert new_instance.demos[0]["input"] == original_instance.demos[0].input.model_dump_json()
128128

129+
def test_signature_fields_after_dump_and_load_state(tmp_path):
130+
class CustomSignature(dspy.Signature):
131+
"""I am just an instruction."""
132+
sentence = dspy.InputField(desc="I am an innocent input!")
133+
sentiment = dspy.OutputField()
134+
135+
file_path = tmp_path / "tmp.json"
136+
original_instance = Predict(CustomSignature)
137+
original_instance.save(file_path)
138+
139+
class CustomSignature2(dspy.Signature):
140+
"""I am not a pure instruction."""
141+
sentence = dspy.InputField(desc="I am a malicious input!")
142+
sentiment = dspy.OutputField(desc="I am a malicious output!", prefix="I am a prefix!")
143+
144+
new_instance = Predict(CustomSignature2)
145+
assert new_instance.signature.dump_state() != original_instance.signature.dump_state()
146+
# After loading, the fields should be the same.
147+
new_instance.load(file_path)
148+
assert new_instance.signature.dump_state() == original_instance.signature.dump_state()
149+
129150

130151
def test_forward_method():
131152
program = Predict("question -> answer")

tests/signatures/test_signature.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,44 @@ class SignatureFour(Signature):
238238
assert "input4" in SignatureTwo.input_fields
239239
assert "input1" in SignatureOne.input_fields
240240
assert "input2" in SignatureTwo.input_fields
241+
242+
243+
def test_dump_and_load_state():
244+
class CustomSignature(dspy.Signature):
245+
"""I am just an instruction."""
246+
247+
sentence = dspy.InputField(desc="I am an innocent input!")
248+
sentiment = dspy.OutputField()
249+
250+
state = CustomSignature.dump_state()
251+
expected = {
252+
"instructions": "I am just an instruction.",
253+
"fields": [
254+
{
255+
"prefix": "Sentence:",
256+
"description": "I am an innocent input!",
257+
},
258+
{
259+
"prefix": "Sentiment:",
260+
"description": "${sentiment}",
261+
},
262+
],
263+
}
264+
assert state == expected
265+
266+
class CustomSignature2(dspy.Signature):
267+
"""I am a malicious instruction."""
268+
269+
sentence = dspy.InputField(desc="I am an malicious input!")
270+
sentiment = dspy.OutputField()
271+
272+
assert CustomSignature2.dump_state() != expected
273+
# Overwrite the state with the state of CustomSignature.
274+
loaded_signature = CustomSignature2.load_state(state)
275+
assert loaded_signature.instructions == "I am just an instruction."
276+
# After `load_state`, the state should be the same as CustomSignature.
277+
assert loaded_signature.dump_state() == expected
278+
# CustomSignature2 should not have been modified.
279+
assert CustomSignature2.instructions == "I am a malicious instruction."
280+
assert CustomSignature2.fields["sentence"].json_schema_extra["desc"] == "I am an malicious input!"
281+
assert CustomSignature2.fields["sentiment"].json_schema_extra["prefix"] == "Sentiment:"

0 commit comments

Comments
 (0)