1+ import logging
12import random
3+ from functools import lru_cache
24
35from pydantic import BaseModel
46
57import dsp
68from dspy .predict .parameter import Parameter
7- from dspy .primitives .program import Module
8-
99from dspy .primitives .prediction import Prediction
10+ from dspy .primitives .program import Module
1011from dspy .signatures .signature import ensure_signature , signature_to_template
1112
12- import logging
13- from functools import lru_cache
1413
1514@lru_cache (maxsize = None )
1615def warn_once (msg : str ):
@@ -31,7 +30,13 @@ def reset(self):
3130 self .train = []
3231 self .demos = []
3332
34- def dump_state (self , save_verbose = False ):
33+ def dump_state (self , save_verbose = None ):
34+ if save_verbose :
35+ logging .warning (
36+ "`save_verbose` is deprecated and will be removed in DSPy 2.6.0 release. Currently `save_verbose` "
37+ "does nothing."
38+ )
39+
3540 state_keys = ["lm" , "traces" , "train" ]
3641 state = {k : getattr (self , k ) for k in state_keys }
3742
@@ -45,33 +50,47 @@ def dump_state(self, save_verbose=False):
4550
4651 state ["demos" ].append (demo )
4752
48- # If `save_verbose` save all field metadata as well.
49- if save_verbose :
50- fields = []
51- for field_key in self .signature .fields .keys ():
52- field_metadata = self .signature .fields [field_key ]
53- fields .append ({
54- "name" : field_key ,
55- "field_type" : field_metadata .json_schema_extra ["__dspy_field_type" ],
56- "description" : field_metadata .json_schema_extra ["desc" ],
57- "prefix" : field_metadata .json_schema_extra ["prefix" ]
58- })
59- state ["fields" ] = fields
60-
61- # Cache the signature instructions and the last field's name.
62- * _ , last_key = self .signature .fields .keys ()
63- state ["signature_instructions" ] = self .signature .instructions
64- state ["signature_prefix" ] = self .signature .fields [last_key ].json_schema_extra ["prefix" ]
65-
66- # Some special stuff for CoT.
53+ state ["signature" ] = self .signature .dump_state ()
54+ # `extended_signature` is a special field for `Predict`s like CoT.
6755 if hasattr (self , "extended_signature" ):
68- # Cache the signature instructions and the last field's name.
69- state ["extended_signature_instructions" ] = self .extended_signature .instructions
70- state ["extended_signature_prefix" ] = self .extended_signature .fields [last_key ].json_schema_extra ['prefix' ]
56+ state ["extended_signature" ] = self .extended_signature .dump_state ()
7157
7258 return state
7359
74- def load_state (self , state ):
60+ def load_state (self , state , use_legacy_loading = False ):
61+ """Load the saved state of a `Predict` object.
62+
63+ Args:
64+ state (dict): The saved state of a `Predict` object.
65+ use_legacy_loading (bool): Whether to use the legacy loading method. Only use it when you are loading a
66+ saved state from a version of DSPy prior to v2.5.3.
67+ """
68+ if use_legacy_loading :
69+ self ._load_state_legacy (state )
70+ return
71+ if "signature" not in state :
72+ # Check if the state is from a version of DSPy prior to v2.5.3.
73+ raise ValueError (
74+ "The saved state is from a version of DSPy prior to v2.5.3. Please use `use_legacy_loading=True` to "
75+ "load the state."
76+ )
77+
78+ excluded_keys = ["signature" , "extended_signature" ]
79+ for name , value in state .items ():
80+ # `excluded_keys` are fields that go through special handling.
81+ if name not in excluded_keys :
82+ setattr (self , name , value )
83+
84+ self .signature = self .signature .load_state (state ["signature" ])
85+
86+ if "extended_signature" in state :
87+ self .extended_signature .load_state (state ["extended_signature" ])
88+
89+ def _load_state_legacy (self , state ):
90+ """Legacy state loading for backwards compatibility.
91+
92+ This method is used to load the saved state of a `Predict` object from a version of DSPy prior to v2.5.3.
93+ """
7594 for name , value in state .items ():
7695 setattr (self , name , value )
7796
@@ -84,7 +103,7 @@ def load_state(self, state):
84103 prefix = state ["signature_prefix" ]
85104 * _ , last_key = self .signature .fields .keys ()
86105 self .signature = self .signature .with_updated_fields (last_key , prefix = prefix )
87-
106+
88107 # Some special stuff for CoT.
89108 if "extended_signature_instructions" in state :
90109 instructions = state ["extended_signature_instructions" ]
@@ -95,6 +114,7 @@ def load_state(self, state):
95114 * _ , last_key = self .extended_signature .fields .keys ()
96115 self .extended_signature = self .extended_signature .with_updated_fields (last_key , prefix = prefix )
97116
117+
98118 def __call__ (self , ** kwargs ):
99119 return self .forward (** kwargs )
100120
0 commit comments