@@ -68,6 +68,64 @@ def parse(self, argv: Optional[list | str]) -> Namespace:
6868
6969 return ns
7070
71+ class ModulesOptions :
72+ """
73+ Hyperparameters for modules & training. Designed to feel like an `opt` object.
74+
75+ Usage:
76+ mods = ModulesOptions().parse(argv_after_flag)
77+ # Access:
78+ mods.batch_size, mods.seq_len, mods.z_dim, mods.hidden_dim, mods.num_layer,
79+ mods.lr, mods.beta1, mods.w_gamma, mods.w_g
80+ """
81+ def __init__ (self ) -> None :
82+ parser = ArgumentParser (
83+ prog = "timeganlob_modules" ,
84+ description = "Module/model hyperparameters and training weights." ,
85+ )
86+ # Core shapes
87+ parser .add_argument ("--batch-size" , type = int , default = 128 )
88+ parser .add_argument ("--seq-len" , type = int , default = 128 ,
89+ help = "Sequence length (kept here for convenience to sync with data)." )
90+ parser .add_argument ("--z-dim" , type = int , default = 40 ,
91+ help = "Latent/input feature dim (e.g., LOB feature count)." )
92+ parser .add_argument ("--hidden-dim" , type = int , default = 64 ,
93+ help = "Module hidden size." )
94+ parser .add_argument ("--num-layer" , type = int , default = 3 ,
95+ help = "Number of stacked layers per RNN/TCN block." )
96+
97+ # Optimizer
98+ parser .add_argument ("--lr" , type = float , default = 1e-4 ,
99+ help = "Learning rate (generator/supervisor/discriminator if shared)." )
100+ parser .add_argument ("--beta1" , type = float , default = 0.5 ,
101+ help = "Adam beta1." )
102+
103+ # Loss weights
104+ parser .add_argument ("--w-gamma" , type = float , default = 1.0 ,
105+ help = "Supervisor loss weight (γ)." )
106+ parser .add_argument ("--w-g" , type = float , default = 1.0 ,
107+ help = "Generator adversarial loss weight (g)." )
108+
109+ self ._parser = parser
110+
111+ def parse (self , argv : Optional [list | str ]) -> Namespace :
112+ m = self ._parser .parse_args (argv )
113+
114+ # Provide both snake_case and "opt-like" names already as attributes
115+ # (so downstream code can do opt.lr, opt.beta1, opt.w_gamma, opt.w_g).
116+ ns = Namespace (
117+ batch_size = m .batch_size ,
118+ seq_len = m .seq_len ,
119+ z_dim = m .z_dim ,
120+ hidden_dim = m .hidden_dim ,
121+ num_layer = m .num_layer ,
122+ lr = m .lr ,
123+ beta1 = m .beta1 ,
124+ w_gamma = m .w_gamma ,
125+ w_g = m .w_g ,
126+ )
127+ return ns
128+
71129class Options :
72130 """
73131 Top-level options that *route* anything after `--dataset` to DatasetOptions.
@@ -92,6 +150,14 @@ def __init__(self) -> None:
92150 "Example: --dataset --seq-len 256 --no-shuffle"
93151 ),
94152 )
153+ parser .add_argument (
154+ "--modules" ,
155+ nargs = REMAINDER ,
156+ help = (
157+ "All arguments following this flag are parsed by ModulesOptions. "
158+ "Example: --modules --batch-size 256 --hidden-dim 128 --lr 3e-4"
159+ ),
160+ )
95161 self ._parser = parser
96162
97163 def parse (self , argv : Optional [list | str ] = None ) -> Namespace :
0 commit comments