@@ -57,7 +57,17 @@ def dump_state(self, save_verbose=None):
5757
5858 return state
5959
60- def load_state (self , state ):
60+ def load_state (self , state , use_legacy_loading = False ):
61+ """Load the saved state of a `Predict` object.
62+
63+ Args:
64+ state (dict): The saved state of a `Predict` object.
65+ use_legacy_loading (bool): Whether to use the legacy loading method. Only use it when you are loading a
66+ saved state from a version of DSPy prior to v2.5.3.
67+ """
68+ if use_legacy_loading :
69+ self ._load_state_legacy (state )
70+ return
6171 excluded_keys = ["signature" , "extended_signature" ]
6272 for name , value in state .items ():
6373 # `excluded_keys` are fields that go through special handling.
@@ -69,6 +79,34 @@ def load_state(self, state):
6979 if "extended_signature" in state :
7080 self .extended_signature .load_state (state ["extended_signature" ])
7181
82+ def _load_state_legacy (self , state ):
83+ """Legacy state loading for backwards compatibility.
84+
85+ This method is used to load the saved state of a `Predict` object from a version of DSPy prior to v2.5.3.
86+ """
87+ for name , value in state .items ():
88+ setattr (self , name , value )
89+
90+ # Reconstruct the signature.
91+ if "signature_instructions" in state :
92+ instructions = state ["signature_instructions" ]
93+ self .signature = self .signature .with_instructions (instructions )
94+
95+ if "signature_prefix" in state :
96+ prefix = state ["signature_prefix" ]
97+ * _ , last_key = self .signature .fields .keys ()
98+ self .signature = self .signature .with_updated_fields (last_key , prefix = prefix )
99+
100+ # Some special stuff for CoT.
101+ if "extended_signature_instructions" in state :
102+ instructions = state ["extended_signature_instructions" ]
103+ self .extended_signature = self .extended_signature .with_instructions (instructions )
104+
105+ if "extended_signature_prefix" in state :
106+ prefix = state ["extended_signature_prefix" ]
107+ * _ , last_key = self .extended_signature .fields .keys ()
108+ self .extended_signature = self .extended_signature .with_updated_fields (last_key , prefix = prefix )
109+
72110
73111 def __call__ (self , ** kwargs ):
74112 return self .forward (** kwargs )
0 commit comments