@@ -59,6 +59,25 @@ def simulate_backprop(self, unet):
5959 unet .load_state_dict (updated_state_dict )
6060 return unet
6161
62+ def test_from_pretrained (self ):
63+ # Save the model parameters to a temporary directory
64+ unet , ema_unet = self .get_models ()
65+ with tempfile .TemporaryDirectory () as tmpdir :
66+ ema_unet .save_pretrained (tmpdir )
67+
68+ # Load the EMA model from the saved directory
69+ loaded_ema_unet = EMAModel .from_pretrained (tmpdir , model_cls = UNet2DConditionModel , foreach = False )
70+
71+ # Check that the shadow parameters of the loaded model match the original EMA model
72+ for original_param , loaded_param in zip (ema_unet .shadow_params , loaded_ema_unet .shadow_params ):
73+ assert torch .allclose (original_param , loaded_param , atol = 1e-4 )
74+
75+ # Verify that the optimization step is also preserved
76+ assert loaded_ema_unet .optimization_step == ema_unet .optimization_step
77+
78+ # Check the decay value
79+ assert loaded_ema_unet .decay == ema_unet .decay
80+
6281 def test_optimization_steps_updated (self ):
6382 unet , ema_unet = self .get_models ()
6483 # Take the first (hypothetical) EMA step.
@@ -194,6 +213,25 @@ def simulate_backprop(self, unet):
194213 unet .load_state_dict (updated_state_dict )
195214 return unet
196215
216+ def test_from_pretrained (self ):
217+ # Save the model parameters to a temporary directory
218+ unet , ema_unet = self .get_models ()
219+ with tempfile .TemporaryDirectory () as tmpdir :
220+ ema_unet .save_pretrained (tmpdir )
221+
222+ # Load the EMA model from the saved directory
223+ loaded_ema_unet = EMAModel .from_pretrained (tmpdir , model_cls = UNet2DConditionModel , foreach = True )
224+
225+ # Check that the shadow parameters of the loaded model match the original EMA model
226+ for original_param , loaded_param in zip (ema_unet .shadow_params , loaded_ema_unet .shadow_params ):
227+ assert torch .allclose (original_param , loaded_param , atol = 1e-4 )
228+
229+ # Verify that the optimization step is also preserved
230+ assert loaded_ema_unet .optimization_step == ema_unet .optimization_step
231+
232+ # Check the decay value
233+ assert loaded_ema_unet .decay == ema_unet .decay
234+
197235 def test_optimization_steps_updated (self ):
198236 unet , ema_unet = self .get_models ()
199237 # Take the first (hypothetical) EMA step.
0 commit comments