88# NOTE: Note: It's important (temporary decision) to maintain named_parameters that's different in behavior from
99# named_sub_modules for the time being.
1010
11+
1112class BaseModule :
1213 def __init__ (self ):
1314 pass
@@ -29,7 +30,7 @@ def add_parameter(param_name, param_value):
2930 visited .add (id (param_value ))
3031 param_name = postprocess_parameter_name (param_name , param_value )
3132 named_parameters .append ((param_name , param_value ))
32-
33+
3334 elif isinstance (param_value , dspy .Module ):
3435 # When a sub-module is pre-compiled, keep it frozen.
3536 if not getattr (param_value , "_compiled" , False ):
@@ -42,7 +43,7 @@ def add_parameter(param_name, param_value):
4243 for name , value in self .__dict__ .items ():
4344 if isinstance (value , Parameter ):
4445 add_parameter (name , value )
45-
46+
4647 elif isinstance (value , dspy .Module ):
4748 # When a sub-module is pre-compiled, keep it frozen.
4849 if not getattr (value , "_compiled" , False ):
@@ -153,7 +154,11 @@ def dump_state(self, save_verbose):
153154
154155 def load_state (self , state , use_legacy_loading = False ):
155156 for name , param in self .named_parameters ():
156- param .load_state (state [name ], use_legacy_loading = use_legacy_loading )
157+ if isinstance (param , BaseModule ):
158+ param .load_state (state [name ], use_legacy_loading = use_legacy_loading )
159+ else :
160+ # `use_legacy_loading` is only applicable for BaseModule instances.
161+ param .load_state (state [name ])
157162
158163 def save (self , path , save_field_meta = False ):
159164 with open (path , "w" ) as f :
@@ -168,11 +173,11 @@ def postprocess_parameter_name(name, value):
168173 # For ChainOfThought backward compatibility, remove ending ._predict if it's there
169174 if name .endswith ("._predict" ):
170175 name = name [:- 9 ]
171-
176+
172177 if name .endswith (".self" ):
173178 name = name [:- 5 ]
174-
179+
175180 if name == "_predict" :
176181 return "self"
177-
178- return name
182+
183+ return name
0 commit comments