2020 ResidualBlock ,
2121 ResidualBlockTS ,
2222)
23- from qolmat .imputations .diffusions .utils import get_num_params
2423
2524logging .basicConfig (
2625 format = "%(asctime)s %(levelname)-8s %(message)s" ,
@@ -176,8 +175,8 @@ def _q_sample(
176175 epsilon = torch .randn_like (x , device = self .device )
177176 return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon , epsilon
178177
179- def _set_eps_model (self ) -> None :
180- self . _eps_model = AutoEncoder (
178+ def _get_eps_model (self ) -> AutoEncoder :
179+ model = AutoEncoder (
181180 num_noise_steps = self .num_noise_steps ,
182181 dim_input = self .dim_input ,
183182 residual_block = ResidualBlock (
@@ -186,12 +185,35 @@ def _set_eps_model(self) -> None:
186185 dim_embedding = self .dim_embedding ,
187186 num_blocks = self .num_blocks ,
188187 p_dropout = self .p_dropout ,
189- ).to (self .device )
188+ )
189+ return model
190+
191+ def _set_eps_model (self ) -> None :
192+ model = self ._get_eps_model ()
193+ self ._eps_model = model .to (self .device )
190194
191195 self .optimiser = torch .optim .Adam (
192196 self ._eps_model .parameters (), lr = self .lr
193197 )
194198
199+ def get_num_params (self ) -> int :
200+ """Compute the number of parameters of the underlying model.
201+
202+ Returns
203+ -------
204+ int: Number of parameters if the model has been fitted,
205+ 0 otherwise.
206+
207+ """
208+ if hasattr (self , "_eps_model" ):
209+ model_parameters = filter (
210+ lambda p : p .requires_grad , self ._eps_model .parameters ()
211+ )
212+ params = sum ([np .prod (p .size ()) for p in model_parameters ])
213+ return int (params )
214+ else :
215+ return 0
216+
195217 def _print_valid (self , epoch : int , time_duration : float ) -> None :
196218 """Print model performance on validation data.
197219
@@ -206,8 +228,9 @@ def _print_valid(self, epoch: int, time_duration: float) -> None:
206228 self .time_durations .append (time_duration )
207229 print_step = 1 if int (self .epochs / 10 ) == 0 else int (self .epochs / 10 )
208230 if self .print_valid and epoch == 0 :
231+ n_params = self .get_num_params ()
209232 logging .info (
210- f"Num params of { self .__class__ .__name__ } : { self . num_params } "
233+ f"Num params of { self .__class__ .__name__ } : { n_params } "
211234 )
212235 if self .print_valid and epoch % print_step == 0 :
213236 string_valid = f"Epoch { epoch } : "
@@ -526,7 +549,6 @@ def fit(
526549 )
527550
528551 self ._set_eps_model ()
529- self .num_params : int = get_num_params (self ._eps_model )
530552 self .summary : Dict [str , List ] = {
531553 "epoch_loss" : [],
532554 }
0 commit comments