Skip to content

Commit 691ab10

Browse files
backward compatibility
1 parent 557e126 commit 691ab10

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

dspy/predict/predict.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,17 @@ def dump_state(self, save_verbose=None):
5757

5858
return state
5959

60-
def load_state(self, state):
60+
def load_state(self, state, use_legacy_loading=False):
61+
"""Load the saved state of a `Predict` object.
62+
63+
Args:
64+
state (dict): The saved state of a `Predict` object.
65+
use_legacy_loading (bool): Whether to use the legacy loading method. Only use it when you are loading a
66+
saved state from a version of DSPy prior to v2.5.3.
67+
"""
68+
if use_legacy_loading:
69+
self._load_state_legacy(state)
70+
return
6171
excluded_keys = ["signature", "extended_signature"]
6272
for name, value in state.items():
6373
# `excluded_keys` are fields that go through special handling.
@@ -69,6 +79,34 @@ def load_state(self, state):
6979
if "extended_signature" in state:
7080
self.extended_signature.load_state(state["extended_signature"])
7181

82+
def _load_state_legacy(self, state):
83+
"""Legacy state loading for backwards compatibility.
84+
85+
This method is used to load the saved state of a `Predict` object from a version of DSPy prior to v2.5.3.
86+
"""
87+
for name, value in state.items():
88+
setattr(self, name, value)
89+
90+
# Reconstruct the signature.
91+
if "signature_instructions" in state:
92+
instructions = state["signature_instructions"]
93+
self.signature = self.signature.with_instructions(instructions)
94+
95+
if "signature_prefix" in state:
96+
prefix = state["signature_prefix"]
97+
*_, last_key = self.signature.fields.keys()
98+
self.signature = self.signature.with_updated_fields(last_key, prefix=prefix)
99+
100+
# Some special stuff for CoT.
101+
if "extended_signature_instructions" in state:
102+
instructions = state["extended_signature_instructions"]
103+
self.extended_signature = self.extended_signature.with_instructions(instructions)
104+
105+
if "extended_signature_prefix" in state:
106+
prefix = state["extended_signature_prefix"]
107+
*_, last_key = self.extended_signature.fields.keys()
108+
self.extended_signature = self.extended_signature.with_updated_fields(last_key, prefix=prefix)
109+
72110

73111
def __call__(self, **kwargs):
74112
return self.forward(**kwargs)

0 commit comments

Comments
 (0)