Skip to content

Commit aeb67f4

Browse files
committed
feat(model): extend TimeGAN with training loop, ckpt I/O, KL check, and generation API
Adds full wrapper (optimizers, ER pretrain, supervised, joint phases), checkpoint save/load, quick KL(spread) validation, and deterministic helpers. Integrates dataset batcher and utils (minmax, noise). Exposes encoder/recovery/generator/supervisor/discriminator and device/seed utilities.
1 parent 259567c commit aeb67f4

File tree

3 files changed

+348
-10
lines changed

3 files changed

+348
-10
lines changed

recognition/TimeLOB_TimeGAN_49088276/src/helpers/args.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,64 @@ def parse(self, argv: Optional[list | str]) -> Namespace:
6868

6969
return ns
7070

71+
class ModulesOptions:
72+
"""
73+
Hyperparameters for modules & training. Designed to feel like an `opt` object.
74+
75+
Usage:
76+
mods = ModulesOptions().parse(argv_after_flag)
77+
# Access:
78+
mods.batch_size, mods.seq_len, mods.z_dim, mods.hidden_dim, mods.num_layer,
79+
mods.lr, mods.beta1, mods.w_gamma, mods.w_g
80+
"""
81+
def __init__(self) -> None:
82+
parser = ArgumentParser(
83+
prog="timeganlob_modules",
84+
description="Module/model hyperparameters and training weights.",
85+
)
86+
# Core shapes
87+
parser.add_argument("--batch-size", type=int, default=128)
88+
parser.add_argument("--seq-len", type=int, default=128,
89+
help="Sequence length (kept here for convenience to sync with data).")
90+
parser.add_argument("--z-dim", type=int, default=40,
91+
help="Latent/input feature dim (e.g., LOB feature count).")
92+
parser.add_argument("--hidden-dim", type=int, default=64,
93+
help="Module hidden size.")
94+
parser.add_argument("--num-layer", type=int, default=3,
95+
help="Number of stacked layers per RNN/TCN block.")
96+
97+
# Optimizer
98+
parser.add_argument("--lr", type=float, default=1e-4,
99+
help="Learning rate (generator/supervisor/discriminator if shared).")
100+
parser.add_argument("--beta1", type=float, default=0.5,
101+
help="Adam beta1.")
102+
103+
# Loss weights
104+
parser.add_argument("--w-gamma", type=float, default=1.0,
105+
help="Supervisor loss weight (γ).")
106+
parser.add_argument("--w-g", type=float, default=1.0,
107+
help="Generator adversarial loss weight (g).")
108+
109+
self._parser = parser
110+
111+
def parse(self, argv: Optional[list | str]) -> Namespace:
112+
m = self._parser.parse_args(argv)
113+
114+
# Provide both snake_case and "opt-like" names already as attributes
115+
# (so downstream code can do opt.lr, opt.beta1, opt.w_gamma, opt.w_g).
116+
ns = Namespace(
117+
batch_size=m.batch_size,
118+
seq_len=m.seq_len,
119+
z_dim=m.z_dim,
120+
hidden_dim=m.hidden_dim,
121+
num_layer=m.num_layer,
122+
lr=m.lr,
123+
beta1=m.beta1,
124+
w_gamma=m.w_gamma,
125+
w_g=m.w_g,
126+
)
127+
return ns
128+
71129
class Options:
72130
"""
73131
Top-level options that *route* anything after `--dataset` to DatasetOptions.
@@ -92,6 +150,14 @@ def __init__(self) -> None:
92150
"Example: --dataset --seq-len 256 --no-shuffle"
93151
),
94152
)
153+
parser.add_argument(
154+
"--modules",
155+
nargs=REMAINDER,
156+
help=(
157+
"All arguments following this flag are parsed by ModulesOptions. "
158+
"Example: --modules --batch-size 256 --hidden-dim 128 --lr 3e-4"
159+
),
160+
)
95161
self._parser = parser
96162

97163
def parse(self, argv: Optional[list | str] = None) -> Namespace:

recognition/TimeLOB_TimeGAN_49088276/src/helpers/constants.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
"""
44
from math import isclose
55
from typing import Literal
6+
67
OUTPUT_DIR = "outs"
8+
WEIGHTS_DIR = "weights"
9+
DATA_DIR = "data"
10+
11+
ORDERBOOK_FILENAME = "AMZN_2012-06-21_34200000_57600000_orderbook_10.csv"
712

813
# Training hyperparameters for TimeGAN
914
NUM_TRAINING_ITERATIONS = 25_000
@@ -16,8 +21,3 @@
1621
), (
1722
f"TRAIN_TEST_SPLIT must sum to 1.0 (got {sum(TRAIN_TEST_SPLIT):.8f})"
1823
)
19-
20-
DATA_DIR = "data"
21-
ORDERBOOK_FILENAME = "AMZN_2012-06-21_34200000_57600000_orderbook_10.csv"
22-
23-
DATANAME = Literal["message", "orderbook"]

0 commit comments

Comments
 (0)