Commit b4fbbc1

feat(model): wire full TimeGAN training/generation, checkpoints, and quick KL validation
Adds ER (encoder-recovery) pretrain, supervised, and joint training loops; Adam optimizers; save/load helpers; device/seed utils; and a generation API that inverse-scales outputs to the original feature space. Includes GRU-based Encoder/Recovery/Generator/Supervisor/Discriminator modules with Xavier/orthogonal init and a logits-only discriminator head ready for BCEWithLogitsLoss.
1 parent aeb67f4 commit b4fbbc1
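
The "quick KL validation" named in the title is not part of this diff. A minimal sketch of such a check, assuming a per-feature histogram estimate (the function name, binning, and smoothing constant are illustrative, not from the commit):

import numpy as np

def quick_kl(real: np.ndarray, synth: np.ndarray, bins: int = 50) -> float:
    """Average histogram-based KL(real || synth) across features."""
    kls = []
    for j in range(real.shape[-1]):
        lo = float(min(real[..., j].min(), synth[..., j].min()))
        hi = float(max(real[..., j].max(), synth[..., j].max()))
        p, _ = np.histogram(real[..., j], bins=bins, range=(lo, hi))
        q, _ = np.histogram(synth[..., j], bins=bins, range=(lo, hi))
        p = p.astype(np.float64) + 1e-8   # smooth empty bins
        q = q.astype(np.float64) + 1e-8
        p /= p.sum()
        q /= q.sum()
        kls.append(float(np.sum(p * np.log(p / q))))
    return float(np.mean(kls))

A lower value on held-out windows would indicate that the synthetic marginals track the real ones.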

File tree

1 file changed: +43 -27 lines changed
  • recognition/TimeLOB_TimeGAN_49088276/src


recognition/TimeLOB_TimeGAN_49088276/src/modules.py

Lines changed: 43 additions & 27 deletions
@@ -28,20 +28,20 @@
 -
 """
 from __future__ import annotations
-from pathlib import Path
+
+import math
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Optional, Tuple
 
-import math
 import numpy as np
-from numpy.typing import NDArray
-
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from numpy.typing import NDArray
 
 from src.dataset import batch_generator
-from src.helpers.args import Options
+from src.helpers.args import ModulesOptions as Options
 from src.helpers.constants import (
     WEIGHTS_DIR,
     OUTPUT_DIR,
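
The Xavier/orthogonal initialization mentioned in the commit title sits outside this hunk. A minimal sketch of the usual scheme for GRU-based modules (the helper name is an assumption):

import torch.nn as nn

def init_weights(module: nn.Module) -> None:
    # Xavier for input-to-hidden and linear weights, orthogonal for
    # recurrent (hidden-to-hidden) weights, zeros for biases.
    for name, param in module.named_parameters():
        if "weight_ih" in name:
            nn.init.xavier_uniform_(param)
        elif "weight_hh" in name:
            nn.init.orthogonal_(param)
        elif "weight" in name and param.dim() >= 2:
            nn.init.xavier_uniform_(param)
        elif "bias" in name:
            nn.init.zeros_(param)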
@@ -135,6 +135,7 @@ class Generator(nn.Module):
     """
     Generator: random noise Z → latent sequence E.
     """
+
     def __init__(self, z_dim: int, hidden_dim: int, num_layers: int) -> None:
         super().__init__()
         self.rnn = nn.GRU(
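
The hunk cuts off inside the nn.GRU call, so for orientation, here is the shape contract of a batch_first GRU generator with illustrative sizes (z_dim=8, hidden_dim=24, num_layers=3 are assumptions, not values from this repo):

import torch
import torch.nn as nn

rnn = nn.GRU(input_size=8, hidden_size=24, num_layers=3, batch_first=True)
proj = nn.Linear(24, 24)

z = torch.randn(4, 30, 8)      # (batch, seq_len, z_dim) noise
g, _ = rnn(z)                  # (batch, seq_len, hidden_dim)
e = torch.sigmoid(proj(g))     # latent sequence E, squashed to (0, 1)
print(e.shape)                 # torch.Size([4, 30, 24])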
@@ -152,17 +153,19 @@ def forward(self, z: torch.Tensor, apply_sigmoid: bool = True) -> torch.Tensor:
         g = self.proj(g)
         return self.act(g) if apply_sigmoid else g
 
+
 class Supervisor(nn.Module):
     """
     Supervisor: next-step latent supervision H_t → H_{t+1}.
     """
+
     def __init__(self, hidden_dim: int, num_layers: int) -> None:
         super().__init__()
         self.rnn = nn.GRU(
-                input_size=hidden_dim,
-                hidden_size=hidden_dim,
-                num_layers=num_layers,
-                batch_first=True,
+            input_size=hidden_dim,
+            hidden_size=hidden_dim,
+            num_layers=num_layers,
+            batch_first=True,
         )
         self.proj = nn.Linear(hidden_dim, hidden_dim)
         self.act = nn.Sigmoid()
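
The "next-step" in the Supervisor docstring becomes concrete in the supervised loss later in this diff, which compares supervisor outputs at steps 0..T-2 with encoder latents at steps 1..T-1. In isolation:

import torch
import torch.nn.functional as F

h = torch.randn(4, 30, 24)     # encoder latents H_t
s = torch.randn(4, 30, 24)     # supervisor predictions of H_{t+1}
# s at step t predicts h at step t+1, so drop the last prediction
# and the first target before comparing
sup = F.mse_loss(s[:, :-1, :], h[:, 1:, :])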
@@ -176,13 +179,14 @@ def forward(self, h: torch.Tensor, apply_sigmoid: bool = True) -> torch.Tensor:
 
 class Discriminator(nn.Module):
     """Discriminator: classify latent sequences (real vs synthetic)."""
+
     def __init__(self, hidden_dim: int, num_layers: int) -> None:
         super().__init__()
         self.rnn = nn.GRU(
-                input_size=hidden_dim,
-                hidden_size=hidden_dim,
-                num_layers=num_layers,
-                batch_first=True,
+            input_size=hidden_dim,
+            hidden_size=hidden_dim,
+            num_layers=num_layers,
+            batch_first=True,
         )
         # note: No sigmoid here; BCEWithLogitsLoss expects raw logits
         self.proj = nn.Linear(hidden_dim, 1)
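
Keeping the head sigmoid-free lets BCEWithLogitsLoss fuse the sigmoid and the log-loss into one numerically stable operation. A quick demonstration of the equivalence:

import torch
import torch.nn as nn

logits = torch.randn(4, 30, 1)                    # raw per-timestep scores
target = torch.ones_like(logits)
stable = nn.BCEWithLogitsLoss()(logits, target)
naive = nn.BCELoss()(torch.sigmoid(logits), target)
assert torch.allclose(stable, naive, atol=1e-5)   # same value, safer gradients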
@@ -193,6 +197,7 @@ def forward(self, h: torch.Tensor) -> torch.Tensor:
         # produce a logit per timestep
         return self.proj(d)
 
+
 @dataclass
 class TimeGANHandles:
     encoder: Encoder
@@ -201,17 +206,19 @@ class TimeGANHandles:
     supervisor: Supervisor
     discriminator: Discriminator
 
+
 class TimeGAN:
     """
     End-to-end TimeGAN wrapper with training & generation utilities.
     """
+
     def __init__(
-            self,
-            opt: Options | object,
-            train_data: NDArray[np.float32],
-            val_data: NDArray[np.float32],
-            test_data: NDArray[np.float32],
-            load_weights: bool = False,
+        self,
+        opt: Options | object,
+        train_data: NDArray[np.float32],
+        val_data: NDArray[np.float32],
+        test_data: NDArray[np.float32],
+        load_weights: bool = False,
     ) -> None:
         # set seed & device
         set_seed(getattr(opt, "manualseed", None))
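
set_seed comes from the helpers referenced in the imports above; its body is not in this diff. A minimal sketch of what such a utility typically seeds (an assumption, not the file's exact code):

import random
import numpy as np
import torch

def set_seed(seed: int | None) -> None:
    if seed is None:
        return                              # leave RNGs untouched
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)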
@@ -322,7 +329,6 @@ def _supervised_step(self, x: torch.Tensor) -> float:
         self.optS.step()
         return float(loss.detach().cpu())
 
-
     def _generator_step(self, x: torch.Tensor, z: torch.Tensor) -> float:
         # build graph
         h_real = self.netE(x)
@@ -347,7 +353,11 @@ def _generator_step(self, x: torch.Tensor, z: torch.Tensor) -> float:
         sup = self.mse(s_real[:, :-1, :], h_real[:, 1:, :])
 
         loss = adv + self.opt.w_gamma * adv_e + self.opt.w_g * (v1 + v2) + torch.sqrt(sup + 1e-12)
-        self.optG.zero_grad(); self.optS.zero_grad(); loss.backward(); self.optG.step(); self.optS.step()
+        self.optG.zero_grad()
+        self.optS.zero_grad()
+        loss.backward()
+        self.optG.step()
+        self.optS.step()
         return float(loss.detach().cpu())
 
     def _discriminator_step(self, x: torch.Tensor, z: torch.Tensor) -> float:
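
v1 and v2 are computed just above this hunk. In the reference TimeGAN they are moment-matching terms on the decoded sequences: match per-feature standard deviations and means of fake vs real. Roughly (a sketch of that formulation, not this file's exact lines):

import torch

def moment_terms(x_hat: torch.Tensor, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # v1: distance between per-feature std devs; v2: between per-feature means
    v1 = torch.mean(torch.abs(
        torch.sqrt(x_hat.var(dim=0, unbiased=False) + 1e-6)
        - torch.sqrt(x.var(dim=0, unbiased=False) + 1e-6)
    ))
    v2 = torch.mean(torch.abs(x_hat.mean(dim=0) - x.mean(dim=0)))
    return v1, v2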
@@ -359,9 +369,9 @@ def _discriminator_step(self, x: torch.Tensor, z: torch.Tensor) -> float:
         y_fake = self.netD(h_hat)
         y_fake_e = self.netD(e_hat)
         loss = (
-                self.bce_logits(y_real, torch.ones_like(y_real))
-                + self.bce_logits(y_fake, torch.zeros_like(y_fake))
-                + self.opt.w_gamma * self.bce_logits(y_fake_e, torch.zeros_like(y_fake_e))
+            self.bce_logits(y_real, torch.ones_like(y_real))
+            + self.bce_logits(y_fake, torch.zeros_like(y_fake))
+            + self.opt.w_gamma * self.bce_logits(y_fake_e, torch.zeros_like(y_fake_e))
         )
         # optional hinge to avoid overshooting
         if loss.item() > 0.15:
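
The 0.15 "hinge" mirrors the heuristic in the reference TimeGAN: only step the discriminator while its loss is above a threshold, so it never races too far ahead of the generator. Factored out as a standalone sketch (the function name is illustrative):

import torch

def guarded_d_step(loss: torch.Tensor, optD: torch.optim.Optimizer,
                   threshold: float = 0.15) -> float:
    # skip the update when the discriminator is already winning
    if loss.item() > threshold:
        optD.zero_grad()
        loss.backward()
        optD.step()
    return float(loss.detach().cpu())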
@@ -373,12 +383,12 @@ def _discriminator_step(self, x: torch.Tensor, z: torch.Tensor) -> float:
     def train_model(self) -> None:
         # phase 1: encoder-recovery pretrain
         for it in range(self.num_iterations):
-            x, _T = batch_generator(self.train_norm, None, self.batch_size) # T unused
+            x, _T = batch_generator(self.train_norm, None, self.batch_size)  # T unused
             x = torch.as_tensor(x, dtype=torch.float32)
             (x,) = self._to_device(x)
             er = self._pretrain_er_step(x)
             if (it + 1) % max(1, self.validate_interval // 2) == 0:
-                pass # keep output quiet by default
+                pass  # keep output quiet by default
 
         # phase 2: supervisor
         for it in range(self.num_iterations):
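
train_model therefore runs the three standard TimeGAN phases back to back. Reduced to its schedule, with the per-step methods abstracted as callables (a sketch of the structure, not the file's code):

from typing import Callable

def run_schedule(n_iters: int,
                 pretrain_er: Callable[[], float],
                 supervised: Callable[[], float],
                 generator: Callable[[], float],
                 discriminator: Callable[[], float]) -> None:
    for _ in range(n_iters):    # phase 1: encoder/recovery pretrain
        pretrain_er()
    for _ in range(n_iters):    # phase 2: supervisor warm-up
        supervised()
    for _ in range(n_iters):    # phase 3: joint adversarial training
        generator()
        discriminator()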
@@ -432,7 +442,13 @@ def generate(
 
         assert num_rows > 0
         windows_needed = math.ceil(num_rows / self.seq_len)
-        z = sample_noise(windows_needed, self.z_dim, self.seq_len)
+        z = sample_noise(
+            windows_needed,
+            self.z_dim,
+            self.seq_len,
+            mean=mean,
+            std=std,
+        )
         z = torch.as_tensor(z, dtype=torch.float32, device=self.device)
         e_hat = self.netG(z)
         h_hat = self.netS(e_hat)
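
Downstream of this hunk, the commit message says generated windows are inverse-scaled to the original feature space. Assuming min-max normalization (a common TimeGAN choice, not confirmed by this hunk), the inverse transform is:

import numpy as np

def inverse_minmax(x_norm: np.ndarray, data_min: np.ndarray, data_max: np.ndarray) -> np.ndarray:
    # undo x_norm = (x - min) / (max - min); guard constant features
    return x_norm * np.maximum(data_max - data_min, 1e-12) + data_min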
