7 | 7 | and saves model checkpoints and plots. The model is imported from ``modules.py`` |
8 | 8 | and data loaders from ``dataset.py``. |
9 | 9 |
10 | | -Typical Usage: |
11 | | - python3 -m predict --ckpt checkpoints/best.pt --n 8 --seq_len 120 --out outputs/predictions |
12 | | -
13 | 10 | Created By: Radhesh Goel (Keys-I) |
14 | 11 | ID: s49088276 |
15 | 12 |
16 | 13 | References: |
17 | 14 | - |
18 | 15 | """ |
19 | | -from __future__ import annotations |
20 | | -import os, json, math, time, argparse, random |
21 | | -from dataclasses import asdict |
22 | | -from typing import Tuple, Optional |
23 | | - |
24 | | -import numpy as np |
25 | | -import torch |
26 | | -from torch.utils.data import TensorDataset, DataLoader |
27 | | - |
28 | | -# local imports |
29 | | -from dataset import LOBSTERData |
30 | | -from modules import ( |
31 | | - TimeGAN, sample_noise, make_optim, |
32 | | - timegan_autoencoder_step, timegan_supervisor_step, timegan_joint_step, |
33 | | - LossWeights |
34 | | -) |
35 | | - |
36 | | -# ------------------------- |
37 | | -# utils |
38 | | -# ------------------------- |
39 | | -def set_seed(seed: int = 1337): |
40 | | - random.seed(seed); np.random.seed(seed) |
41 | | - torch.manual_seed(seed); torch.cuda.manual_seed_all(seed) |
42 | | - |
43 | | -def shape_from_npz(npz_path: str) -> Tuple[int,int,int]: |
44 | | - d = np.load(npz_path) |
45 | | - w = d["train"] |
46 | | - return tuple(w.shape) # num_seq, seq_len, x_dim |
47 | | - |
48 | | -def build_loaders_from_npz(npz_path: str, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader, int, int]: |
49 | | - d = np.load(npz_path) |
50 | | - W_train = torch.from_numpy(d["train"]).float() |
51 | | - W_val = torch.from_numpy(d["val"]).float() |
52 | | - W_test = torch.from_numpy(d["test"]).float() |
53 | | - T = W_train.size(1); D = W_train.size(2) |
54 | | - train_dl = DataLoader(TensorDataset(W_train), batch_size=batch_size, shuffle=True, drop_last=True) |
55 | | - val_dl = DataLoader(TensorDataset(W_val), batch_size=batch_size, shuffle=False) |
56 | | - test_dl = DataLoader(TensorDataset(W_test), batch_size=batch_size, shuffle=False) |
57 | | - return train_dl, val_dl, test_dl, T, D |
| 16 | +from dataset import load_data |
| 17 | +from modules import TimeGAN |
| 18 | +from src.helpers.args import Options |
58 | 19 |
59 | | -def build_loaders_from_csv(args, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader, int, int]: |
60 | | - ds = LOBSTERData( |
61 | | - data_dir=args.data_dir, |
62 | | - message_file=args.message, |
63 | | - orderbook_file=args.orderbook, |
64 | | - feature_set=args.feature_set, |
65 | | - seq_len=args.seq_len, |
66 | | - stride=args.stride, |
67 | | - splits=tuple(args.splits), |
68 | | - scaler=args.scaler, |
69 | | - headerless_message=args.headerless_message, |
70 | | - headerless_orderbook=args.headerless_orderbook, |
71 | | - # optional whitening & aug flags if you want them in training too: |
72 | | - whiten=args.whiten, pca_var=args.pca_var, |
73 | | - aug_prob=args.aug_prob, aug_jitter_std=args.aug_jitter_std, |
74 | | - aug_scaling_std=args.aug_scaling_std, aug_timewarp_max=args.aug_timewarp_max, |
75 | | - save_dir=args.save_dir, |
76 | | - ) |
77 | | - W_train, W_val, W_test = ds.load_arrays() |
78 | | - T = W_train.shape[1]; D = W_train.shape[2] |
79 | | - train_dl = DataLoader(TensorDataset(torch.from_numpy(W_train).float()), batch_size=batch_size, shuffle=True, drop_last=True) |
80 | | - val_dl = DataLoader(TensorDataset(torch.from_numpy(W_val).float()), batch_size=batch_size, shuffle=False) |
81 | | - test_dl = DataLoader(TensorDataset(torch.from_numpy(W_test).float()), batch_size=batch_size, shuffle=False) |
82 | | - # Persist meta if saving: |
83 | | - if args.save_dir: |
84 | | - meta = ds.get_meta() |
85 | | - with open(os.path.join(args.save_dir, "meta.train.json"), "w") as f: |
86 | | - json.dump(meta, f, indent=2) |
87 | | - return train_dl, val_dl, test_dl, T, D |
88 | 20 |
89 | | -def save_ckpt(path: str, model: TimeGAN, opt_gs, opt_d, step: int, args, extra=None): |
90 | | - os.makedirs(os.path.dirname(path), exist_ok=True) |
91 | | - payload = { |
92 | | - "step": step, |
93 | | - "args": vars(args), |
94 | | - "embedder": model.embedder.state_dict(), |
95 | | - "recovery": model.recovery.state_dict(), |
96 | | - "generator": model.generator.state_dict(), |
97 | | - "supervisor": model.supervisor.state_dict(), |
98 | | - "discriminator": model.discriminator.state_dict(), |
99 | | - "opt_gs": opt_gs.state_dict(), |
100 | | - "opt_d": opt_d.state_dict(), |
101 | | - "extra": extra or {}, |
102 | | - } |
103 | | - torch.save(payload, path) |
| 21 | +def train() -> None: |
| 22 | +    # parse CLI args as before
| 23 | + opt = Options().parse() |
104 | 24 |
105 | | -# ------------------------- |
106 | | -# train loops |
107 | | -# ------------------------- |
108 | | -def run_autoencoder_phase(model, train_dl, device, opt_gs, epochs: int, amp: bool, clip: Optional[float]): |
109 | | - scaler = torch.amp.GradScaler('cuda', enabled=amp) |
110 | | - for ep in range(1, epochs+1): |
111 | | - t0 = time.time() |
112 | | - logs = [] |
113 | | - for (xb,) in train_dl: |
114 | | - xb = xb.to(device, non_blocking=True) |
115 | | - opt_gs.zero_grad(set_to_none=True) |
116 | | - if amp: |
117 | | - with torch.amp.autocast('cuda'): |
118 | | - out = timegan_autoencoder_step(model, xb, opt_gs) |
119 | | - else: |
120 | | - out = timegan_autoencoder_step(model, xb, opt_gs) |
121 | | - # timegan_autoencoder_step already steps opt; clip if needed |
122 | | - if clip is not None: |
123 | | - torch.nn.utils.clip_grad_norm_(model.embedder.parameters(), clip) |
124 | | - torch.nn.utils.clip_grad_norm_(model.recovery.parameters(), clip) |
125 | | - logs.append(out["recon"]) |
126 | | - dt = time.time()-t0 |
127 | | - print(f"[AE] epoch {ep}/{epochs} recon={np.mean(logs):.6f} ({dt:.1f}s)") |
| 25 | + # train_data: [N, T, F]; val/test should be 2D [T, F] for quick metrics |
| 26 | + train_data, val_data, test_data = load_data(opt) |
| 27 | + # if val/test come windowed [N, T, F], flatten to [T', F] |
| 28 | + if getattr(val_data, "ndim", None) == 3: |
| 29 | + val_data = val_data.reshape(-1, val_data.shape[-1]) |
| 30 | + if getattr(test_data, "ndim", None) == 3: |
| 31 | + test_data = test_data.reshape(-1, test_data.shape[-1]) |
128 | 32 |
129 | | -def run_supervisor_phase(model, train_dl, device, opt_gs, epochs: int, amp: bool, clip: Optional[float]): |
130 | | - for ep in range(1, epochs+1): |
131 | | - t0 = time.time() |
132 | | - logs = [] |
133 | | - for (xb,) in train_dl: |
134 | | - xb = xb.to(device, non_blocking=True) |
135 | | - out = timegan_supervisor_step(model, xb, opt_gs) |
136 | | - if clip is not None: |
137 | | - torch.nn.utils.clip_grad_norm_(model.supervisor.parameters(), clip) |
138 | | - logs.append(out["sup"]) |
139 | | - dt = time.time()-t0 |
140 | | - print(f"[SUP] epoch {ep}/{epochs} sup={np.mean(logs):.6f} ({dt:.1f}s)") |
| 33 | + # build and train |
| 34 | + model = TimeGAN(opt, train_data, val_data, test_data, load_weights=False) |
| 35 | + model.train_model() |
141 | 36 |
142 | | -def evaluate_moment(model, loader, device, z_dim: int) -> float: |
143 | | - # rough eval: moment loss on validation set (lower is better) |
144 | | - from modules import moment_loss |
145 | | - model.eval() |
146 | | - vals = [] |
147 | | - with torch.no_grad(): |
148 | | - for (xb,) in loader: |
149 | | - xb = xb.to(device) |
150 | | - z = sample_noise(xb.size(0), xb.size(1), z_dim, device) |
151 | | - # generate one batch |
152 | | - paths = model.forward_gen_paths(xb, z) |
153 | | - x_tilde = paths["X_tilde"] |
154 | | - vals.append(float(moment_loss(xb, x_tilde).cpu())) |
155 | | - return float(np.mean(vals)) if vals else math.inf |
156 | 37 |
157 | | -def run_joint_phase(model, train_dl, val_dl, device, opt_gs, opt_d, |
158 | | - z_dim: int, epochs: int, amp: bool, clip: Optional[float], |
159 | | - loss_weights: LossWeights, ckpt_dir: Optional[str], args=None): |
160 | | - best_val = math.inf |
161 | | - step = 0 |
162 | | - for ep in range(1, epochs+1): |
163 | | - t0 = time.time() |
164 | | - logs = {"d": [], "g_adv": [], "g_sup": [], "g_mom": [], "g_fm": [], "recon": [], "cons": [], "g_total": []} |
165 | | - for (xb,) in train_dl: |
166 | | - xb = xb.to(device, non_blocking=True) |
167 | | - z = sample_noise(xb.size(0), xb.size(1), z_dim, device) |
168 | | - out = timegan_joint_step(model, xb, z, opt_gs, opt_d, loss_weights) |
169 | | - if clip is not None: |
170 | | - torch.nn.utils.clip_grad_norm_(list(model.embedder.parameters())+ |
171 | | - list(model.recovery.parameters())+ |
172 | | - list(model.generator.parameters())+ |
173 | | - list(model.supervisor.parameters()), clip) |
174 | | - torch.nn.utils.clip_grad_norm_(model.discriminator.parameters(), clip) |
175 | | - for k, v in out.items(): logs[k].append(v) |
176 | | - step += 1 |
177 | | - |
178 | | - # validation (moment) |
179 | | - val_m = evaluate_moment(model, val_dl, device, z_dim) |
180 | | - dt = time.time()-t0 |
181 | | - log_line = " ".join([f"{k}={np.mean(v):.4f}" for k,v in logs.items()]) |
182 | | - print(f"[JOINT] epoch {ep}/{epochs} {log_line} | val_moment={val_m:.4f} ({dt:.1f}s)") |
183 | | - |
184 | | - # save best |
185 | | - if ckpt_dir: |
186 | | - if val_m < best_val: |
187 | | - best_val = val_m |
188 | | - save_ckpt(os.path.join(ckpt_dir, "best.pt"), model, opt_gs, opt_d, step, args=args, |
189 | | - extra={"val_moment": val_m}) |
190 | | - save_ckpt(os.path.join(ckpt_dir, f"step_{step}.pt"), model, opt_gs, opt_d, step, args=args, |
191 | | - extra={"val_moment": val_m}) |
192 | | - |
193 | | -# ------------------------- |
194 | | -# main |
195 | | -# ------------------------- |
196 | 38 | if __name__ == "__main__": |
197 | | - p = argparse.ArgumentParser(description="Train TimeGAN on LOBSTERData.") |
198 | | - # data sources |
199 | | - p.add_argument("--npz", type=str, help="Path to windows.npz (train/val/test). If set, ignores --data-dir.") |
200 | | - p.add_argument("--data-dir", type=str, help="Folder with message_10.csv and orderbook_10.csv") |
201 | | - p.add_argument("--message", default="message_10.csv") |
202 | | - p.add_argument("--orderbook", default="orderbook_10.csv") |
203 | | - p.add_argument("--feature-set", choices=["core","raw10"], default="core") |
204 | | - p.add_argument("--seq-len", type=int, default=128) |
205 | | - p.add_argument("--stride", type=int, default=32) |
206 | | - p.add_argument("--splits", type=float, nargs=3, default=(0.7,0.15,0.15)) |
207 | | - p.add_argument("--scaler", choices=["standard","minmax","robust","quantile","power","none"], default="robust") |
208 | | - p.add_argument("--whiten", choices=["pca","zca",None], default="pca") |
209 | | - p.add_argument("--pca-var", type=float, default=0.999) |
210 | | - p.add_argument("--headerless-message", action="store_true") |
211 | | - p.add_argument("--headerless-orderbook", action="store_true") |
212 | | - p.add_argument("--save-dir", type=str, default=None, help="If set during CSV mode, saves NPZ/meta here.") |
213 | | - |
214 | | - # model |
215 | | - p.add_argument("--x-dim", type=str, default="auto", help="'auto' infers from data; else int") |
216 | | - p.add_argument("--z-dim", type=int, default=24) |
217 | | - p.add_argument("--h-dim", type=int, default=64) |
218 | | - p.add_argument("--rnn-type", choices=["gru","lstm"], default="gru") |
219 | | - p.add_argument("--enc-layers", type=int, default=2) |
220 | | - p.add_argument("--dec-layers", type=int, default=2) |
221 | | - p.add_argument("--gen-layers", type=int, default=2) |
222 | | - p.add_argument("--sup-layers", type=int, default=1) |
223 | | - p.add_argument("--dis-layers", type=int, default=1) |
224 | | - p.add_argument("--dropout", type=float, default=0.1) |
225 | | - |
226 | | - # training |
227 | | - p.add_argument("--batch-size", type=int, default=64) |
228 | | - p.add_argument("--seed", type=int, default=1337) |
229 | | - p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") |
230 | | - p.add_argument("--amp", action="store_true", help="Enable mixed precision.") |
231 | | - p.add_argument("--clip", type=float, default=1.0, help="Grad clip norm; set <=0 to disable.") |
232 | | - p.add_argument("--ae-epochs", type=int, default=10) |
233 | | - p.add_argument("--sup-epochs", type=int, default=10) |
234 | | - p.add_argument("--joint-epochs", type=int, default=50) |
235 | | - p.add_argument("--lr", type=float, default=1e-3) |
236 | | - p.add_argument("--ckpt-dir", type=str, default="./ckpts") |
237 | | - |
238 | | - # augmentation passthrough when using CSV mode |
239 | | - p.add_argument("--aug-prob", type=float, default=0.0) |
240 | | - p.add_argument("--aug-jitter-std", type=float, default=0.01) |
241 | | - p.add_argument("--aug-scaling-std", type=float, default=0.05) |
242 | | - p.add_argument("--aug-timewarp-max", type=float, default=0.1) |
243 | | - |
244 | | - args = p.parse_args() |
245 | | - set_seed(args.seed) |
246 | | - device = torch.device(args.device) |
247 | | - os.makedirs(args.ckpt_dir, exist_ok=True) |
248 | | - run_dir = os.path.join(args.ckpt_dir, f"timegan_{time.strftime('%Y%m%d-%H%M%S')}") |
249 | | - os.makedirs(run_dir, exist_ok=True) |
250 | | - |
251 | | - # Data |
252 | | - if args.npz: |
253 | | - train_dl, val_dl, test_dl, T, D = build_loaders_from_npz(args.npz, args.batch_size) |
254 | | - elif args.data_dir: |
255 | | - train_dl, val_dl, test_dl, T, D = build_loaders_from_csv(args, args.batch_size) |
256 | | - else: |
257 | | - raise SystemExit("Provide either --npz or --data-dir") |
258 | | - |
259 | | - x_dim = D if args.x_dim == "auto" else int(args.x_dim) |
260 | | - |
261 | | - # Model & optims |
262 | | - model = TimeGAN( |
263 | | - x_dim=x_dim, z_dim=args.z_dim, h_dim=args.h_dim, |
264 | | - rnn_type=args.rnn_type, enc_layers=args.enc_layers, dec_layers=args.dec_layers, |
265 | | - gen_layers=args.gen_layers, sup_layers=args.sup_layers, dis_layers=args.dis_layers, |
266 | | - dropout=args.dropout |
267 | | - ).to(device) |
268 | | - |
269 | | - opt_gs = make_optim(list(model.embedder.parameters()) + |
270 | | - list(model.recovery.parameters()) + |
271 | | - list(model.generator.parameters()) + |
272 | | - list(model.supervisor.parameters()), lr=args.lr) |
273 | | - opt_d = make_optim(model.discriminator.parameters(), lr=args.lr) |
274 | | - |
275 | | - # Phase 1: autoencoder pretrain |
276 | | - if args.ae_epochs > 0: |
277 | | - run_autoencoder_phase(model, train_dl, device, opt_gs, args.ae_epochs, amp=args.amp, clip=args.clip if args.clip>0 else None) |
278 | | - save_ckpt(os.path.join(run_dir, "after_autoencoder.pt"), model, opt_gs, opt_d, step=0, args=args) |
279 | | - |
280 | | - # Phase 2: supervisor pretrain |
281 | | - if args.sup_epochs > 0: |
282 | | - run_supervisor_phase(model, train_dl, device, opt_gs, args.sup_epochs, amp=args.amp, clip=args.clip if args.clip>0 else None) |
283 | | - save_ckpt(os.path.join(run_dir, "after_supervisor.pt"), model, opt_gs, opt_d, step=0, args=args) |
284 | | - |
285 | | - # Phase 3: joint training |
286 | | - if args.joint_epochs > 0: |
287 | | - run_joint_phase( |
288 | | - model, train_dl, val_dl, device, opt_gs, opt_d, |
289 | | - z_dim=args.z_dim, epochs=args.joint_epochs, amp=args.amp, |
290 | | - clip=args.clip if args.clip>0 else None, |
291 | | - loss_weights=LossWeights(), ckpt_dir=run_dir, args=args |
292 | | - ) |
293 | | - |
294 | | - |
295 | | - # Final test moment score |
296 | | - test_m = evaluate_moment(model, test_dl, device, args.z_dim) |
297 | | - print(f"[DONE] test moment loss: {test_m:.6f}") |
298 | | - |
| 39 | + train() |
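
Note: a minimal NumPy sketch of the flattening step added in train() above, for the case where load_data returns windowed [N, T, F] validation/test splits. The array below is synthetic and only illustrates the reshape, not the real LOBSTER features.

import numpy as np

# Synthetic stand-in for a windowed split: N=2 windows, T=3 steps, F=2 features.
val_windowed = np.arange(12, dtype=np.float32).reshape(2, 3, 2)   # [N, T, F]

# Same reshape as in train(): collapse the window axis into time, keeping features last.
val_flat = val_windowed.reshape(-1, val_windowed.shape[-1])       # [N*T, F]
print(val_flat.shape)  # (6, 2)

The flattened [T', F] arrays are what TimeGAN(opt, train_data, val_data, test_data, load_weights=False) receives for its quick validation/test metrics, while train_data stays windowed as [N, T, F].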