2828-
2929"""
3030from __future__ import annotations
31- from pathlib import Path
31+
32+ import math
3233from dataclasses import dataclass
34+ from pathlib import Path
3335from typing import Optional , Tuple
3436
35- import math
3637import numpy as np
37- from numpy .typing import NDArray
38-
3938import torch
4039import torch .nn as nn
4140import torch .optim as optim
41+ from numpy .typing import NDArray
4242
4343from src .dataset import batch_generator
44- from src .helpers .args import Options
44+ from src .helpers .args import ModulesOptions as Options
4545from src .helpers .constants import (
4646 WEIGHTS_DIR ,
4747 OUTPUT_DIR ,
@@ -135,6 +135,7 @@ class Generator(nn.Module):
135135 """
136136 Generator: random noise Z → latent sequence E.
137137 """
138+
138139 def __init__ (self , z_dim : int , hidden_dim : int , num_layers : int ) -> None :
139140 super ().__init__ ()
140141 self .rnn = nn .GRU (
@@ -152,17 +153,19 @@ def forward(self, z: torch.Tensor, apply_sigmoid: bool = True) -> torch.Tensor:
152153 g = self .proj (g )
153154 return self .act (g ) if apply_sigmoid else g
154155
156+
155157class Supervisor (nn .Module ):
156158 """
157159 Supervisor: next-step latent supervision H_t → H_{t+1}.
158160 """
161+
159162 def __init__ (self , hidden_dim : int , num_layers : int ) -> None :
160163 super ().__init__ ()
161164 self .rnn = nn .GRU (
162- input_size = hidden_dim ,
163- hidden_size = hidden_dim ,
164- num_layers = num_layers ,
165- batch_first = True ,
165+ input_size = hidden_dim ,
166+ hidden_size = hidden_dim ,
167+ num_layers = num_layers ,
168+ batch_first = True ,
166169 )
167170 self .proj = nn .Linear (hidden_dim , hidden_dim )
168171 self .act = nn .Sigmoid ()
@@ -176,13 +179,14 @@ def forward(self, h: torch.Tensor, apply_sigmoid: bool = True) -> torch.Tensor:
176179
177180class Discriminator (nn .Module ):
178181 """Discriminator: classify latent sequences (real vs synthetic)."""
182+
179183 def __init__ (self , hidden_dim : int , num_layers : int ) -> None :
180184 super ().__init__ ()
181185 self .rnn = nn .GRU (
182- input_size = hidden_dim ,
183- hidden_size = hidden_dim ,
184- num_layers = num_layers ,
185- batch_first = True ,
186+ input_size = hidden_dim ,
187+ hidden_size = hidden_dim ,
188+ num_layers = num_layers ,
189+ batch_first = True ,
186190 )
187191 # note: No sigmoid here; BCEWithLogitsLoss expects raw logits
188192 self .proj = nn .Linear (hidden_dim , 1 )
@@ -193,6 +197,7 @@ def forward(self, h: torch.Tensor) -> torch.Tensor:
193197 # produce a logit per timestep
194198 return self .proj (d )
195199
200+
196201@dataclass
197202class TimeGANHandles :
198203 encoder : Encoder
@@ -201,17 +206,19 @@ class TimeGANHandles:
201206 supervisor : Supervisor
202207 discriminator : Discriminator
203208
209+
204210class TimeGAN :
205211 """
206212 End-to-end TimeGAN wrapper with training & generation utilities.
207213 """
214+
208215 def __init__ (
209- self ,
210- opt : Options | object ,
211- train_data : NDArray [np .float32 ],
212- val_data : NDArray [np .float32 ],
213- test_data : NDArray [np .float32 ],
214- load_weights : bool = False ,
216+ self ,
217+ opt : Options | object ,
218+ train_data : NDArray [np .float32 ],
219+ val_data : NDArray [np .float32 ],
220+ test_data : NDArray [np .float32 ],
221+ load_weights : bool = False ,
215222 ) -> None :
216223 # set seed & device
217224 set_seed (getattr (opt , "manualseed" , None ))
@@ -322,7 +329,6 @@ def _supervised_step(self, x: torch.Tensor) -> float:
322329 self .optS .step ()
323330 return float (loss .detach ().cpu ())
324331
325-
326332 def _generator_step (self , x : torch .Tensor , z : torch .Tensor ) -> float :
327333 # build graph
328334 h_real = self .netE (x )
@@ -347,7 +353,11 @@ def _generator_step(self, x: torch.Tensor, z: torch.Tensor) -> float:
347353 sup = self .mse (s_real [:, :- 1 , :], h_real [:, 1 :, :])
348354
349355 loss = adv + self .opt .w_gamma * adv_e + self .opt .w_g * (v1 + v2 ) + torch .sqrt (sup + 1e-12 )
350- self .optG .zero_grad (); self .optS .zero_grad (); loss .backward (); self .optG .step (); self .optS .step ()
356+ self .optG .zero_grad ()
357+ self .optS .zero_grad ()
358+ loss .backward ()
359+ self .optG .step ()
360+ self .optS .step ()
351361 return float (loss .detach ().cpu ())
352362
353363 def _discriminator_step (self , x : torch .Tensor , z : torch .Tensor ) -> float :
@@ -359,9 +369,9 @@ def _discriminator_step(self, x: torch.Tensor, z: torch.Tensor) -> float:
359369 y_fake = self .netD (h_hat )
360370 y_fake_e = self .netD (e_hat )
361371 loss = (
362- self .bce_logits (y_real , torch .ones_like (y_real ))
363- + self .bce_logits (y_fake , torch .zeros_like (y_fake ))
364- + self .opt .w_gamma * self .bce_logits (y_fake_e , torch .zeros_like (y_fake_e ))
372+ self .bce_logits (y_real , torch .ones_like (y_real ))
373+ + self .bce_logits (y_fake , torch .zeros_like (y_fake ))
374+ + self .opt .w_gamma * self .bce_logits (y_fake_e , torch .zeros_like (y_fake_e ))
365375 )
366376 # optional hinge to avoid overshooting
367377 if loss .item () > 0.15 :
@@ -373,12 +383,12 @@ def _discriminator_step(self, x: torch.Tensor, z: torch.Tensor) -> float:
373383 def train_model (self ) -> None :
374384 # phase 1: encoder-recovery pretrain
375385 for it in range (self .num_iterations ):
376- x , _T = batch_generator (self .train_norm , None , self .batch_size ) # T unused
386+ x , _T = batch_generator (self .train_norm , None , self .batch_size ) # T unused
377387 x = torch .as_tensor (x , dtype = torch .float32 )
378388 (x ,) = self ._to_device (x )
379389 er = self ._pretrain_er_step (x )
380390 if (it + 1 ) % max (1 , self .validate_interval // 2 ) == 0 :
381- pass # keep output quiet by default
391+ pass # keep output quiet by default
382392
383393 # phase 2: supervisor
384394 for it in range (self .num_iterations ):
@@ -432,7 +442,13 @@ def generate(
432442
433443 assert num_rows > 0
434444 windows_needed = math .ceil (num_rows / self .seq_len )
435- z = sample_noise (windows_needed , self .z_dim , self .seq_len )
445+ z = sample_noise (
446+ windows_needed ,
447+ self .z_dim ,
448+ self .seq_len ,
449+ mean = mean ,
450+ std = std ,
451+ )
436452 z = torch .as_tensor (z , dtype = torch .float32 , device = self .device )
437453 e_hat = self .netG (z )
438454 h_hat = self .netS (e_hat )
0 commit comments