Commit 1291868
feat(train): add CLI entrypoint to run TimeGAN end-to-end
Parses Options, loads datasets via load_data, constructs TimeGAN, and executes the full three-phase schedule with checkpoints. Keeps modules/dataset imports minimal to match current package layout.
1 parent b4fbbc1 commit 1291868

File tree

  • recognition/TimeLOB_TimeGAN_49088276/src

1 file changed: +17 additions, -276 deletions

recognition/TimeLOB_TimeGAN_49088276/src/train.py

Lines changed: 17 additions & 276 deletions
@@ -7,292 +7,33 @@
 and saves model checkpoints and plots. The model is imported from ``modules.py``
 and data loaders from ``dataset.py``.
 
-Typical Usage:
-    python3 -m predict --ckpt checkpoints/best.pt --n 8 --seq_len 120 --out outputs/predictions
-
 Created By: Radhesh Goel (Keys-I)
 ID: s49088276
 
 References:
     -
 """
-from __future__ import annotations
-import os, json, math, time, argparse, random
-from dataclasses import asdict
-from typing import Tuple, Optional
-
-import numpy as np
-import torch
-from torch.utils.data import TensorDataset, DataLoader
-
-# local imports
-from dataset import LOBSTERData
-from modules import (
-    TimeGAN, sample_noise, make_optim,
-    timegan_autoencoder_step, timegan_supervisor_step, timegan_joint_step,
-    LossWeights
-)
-
-# -------------------------
-# utils
-# -------------------------
-def set_seed(seed: int = 1337):
-    random.seed(seed); np.random.seed(seed)
-    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
-
-def shape_from_npz(npz_path: str) -> Tuple[int,int,int]:
-    d = np.load(npz_path)
-    w = d["train"]
-    return tuple(w.shape)  # num_seq, seq_len, x_dim
-
-def build_loaders_from_npz(npz_path: str, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader, int, int]:
-    d = np.load(npz_path)
-    W_train = torch.from_numpy(d["train"]).float()
-    W_val = torch.from_numpy(d["val"]).float()
-    W_test = torch.from_numpy(d["test"]).float()
-    T = W_train.size(1); D = W_train.size(2)
-    train_dl = DataLoader(TensorDataset(W_train), batch_size=batch_size, shuffle=True, drop_last=True)
-    val_dl = DataLoader(TensorDataset(W_val), batch_size=batch_size, shuffle=False)
-    test_dl = DataLoader(TensorDataset(W_test), batch_size=batch_size, shuffle=False)
-    return train_dl, val_dl, test_dl, T, D
+from dataset import load_data
+from modules import TimeGAN
+from src.helpers.args import Options
 
-def build_loaders_from_csv(args, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader, int, int]:
-    ds = LOBSTERData(
-        data_dir=args.data_dir,
-        message_file=args.message,
-        orderbook_file=args.orderbook,
-        feature_set=args.feature_set,
-        seq_len=args.seq_len,
-        stride=args.stride,
-        splits=tuple(args.splits),
-        scaler=args.scaler,
-        headerless_message=args.headerless_message,
-        headerless_orderbook=args.headerless_orderbook,
-        # optional whitening & aug flags if you want them in training too:
-        whiten=args.whiten, pca_var=args.pca_var,
-        aug_prob=args.aug_prob, aug_jitter_std=args.aug_jitter_std,
-        aug_scaling_std=args.aug_scaling_std, aug_timewarp_max=args.aug_timewarp_max,
-        save_dir=args.save_dir,
-    )
-    W_train, W_val, W_test = ds.load_arrays()
-    T = W_train.shape[1]; D = W_train.shape[2]
-    train_dl = DataLoader(TensorDataset(torch.from_numpy(W_train).float()), batch_size=batch_size, shuffle=True, drop_last=True)
-    val_dl = DataLoader(TensorDataset(torch.from_numpy(W_val).float()), batch_size=batch_size, shuffle=False)
-    test_dl = DataLoader(TensorDataset(torch.from_numpy(W_test).float()), batch_size=batch_size, shuffle=False)
-    # Persist meta if saving:
-    if args.save_dir:
-        meta = ds.get_meta()
-        with open(os.path.join(args.save_dir, "meta.train.json"), "w") as f:
-            json.dump(meta, f, indent=2)
-    return train_dl, val_dl, test_dl, T, D
 
-def save_ckpt(path: str, model: TimeGAN, opt_gs, opt_d, step: int, args, extra=None):
-    os.makedirs(os.path.dirname(path), exist_ok=True)
-    payload = {
-        "step": step,
-        "args": vars(args),
-        "embedder": model.embedder.state_dict(),
-        "recovery": model.recovery.state_dict(),
-        "generator": model.generator.state_dict(),
-        "supervisor": model.supervisor.state_dict(),
-        "discriminator": model.discriminator.state_dict(),
-        "opt_gs": opt_gs.state_dict(),
-        "opt_d": opt_d.state_dict(),
-        "extra": extra or {},
-    }
-    torch.save(payload, path)
+def train() -> None:
+    # parse cli args as before
+    opt = Options().parse()
 
-# -------------------------
-# train loops
-# -------------------------
-def run_autoencoder_phase(model, train_dl, device, opt_gs, epochs: int, amp: bool, clip: Optional[float]):
-    scaler = torch.amp.GradScaler('cuda', enabled=amp)
-    for ep in range(1, epochs+1):
-        t0 = time.time()
-        logs = []
-        for (xb,) in train_dl:
-            xb = xb.to(device, non_blocking=True)
-            opt_gs.zero_grad(set_to_none=True)
-            if amp:
-                with torch.amp.autocast('cuda'):
-                    out = timegan_autoencoder_step(model, xb, opt_gs)
-            else:
-                out = timegan_autoencoder_step(model, xb, opt_gs)
-            # timegan_autoencoder_step already steps opt; clip if needed
-            if clip is not None:
-                torch.nn.utils.clip_grad_norm_(model.embedder.parameters(), clip)
-                torch.nn.utils.clip_grad_norm_(model.recovery.parameters(), clip)
-            logs.append(out["recon"])
-        dt = time.time()-t0
-        print(f"[AE] epoch {ep}/{epochs} recon={np.mean(logs):.6f} ({dt:.1f}s)")
+    # train_data: [N, T, F]; val/test should be 2D [T, F] for quick metrics
+    train_data, val_data, test_data = load_data(opt)
+    # if val/test come windowed [N, T, F], flatten to [T', F]
+    if getattr(val_data, "ndim", None) == 3:
+        val_data = val_data.reshape(-1, val_data.shape[-1])
+    if getattr(test_data, "ndim", None) == 3:
+        test_data = test_data.reshape(-1, test_data.shape[-1])
 
-def run_supervisor_phase(model, train_dl, device, opt_gs, epochs: int, amp: bool, clip: Optional[float]):
-    for ep in range(1, epochs+1):
-        t0 = time.time()
-        logs = []
-        for (xb,) in train_dl:
-            xb = xb.to(device, non_blocking=True)
-            out = timegan_supervisor_step(model, xb, opt_gs)
-            if clip is not None:
-                torch.nn.utils.clip_grad_norm_(model.supervisor.parameters(), clip)
-            logs.append(out["sup"])
-        dt = time.time()-t0
-        print(f"[SUP] epoch {ep}/{epochs} sup={np.mean(logs):.6f} ({dt:.1f}s)")
+    # build and train
+    model = TimeGAN(opt, train_data, val_data, test_data, load_weights=False)
+    model.train_model()
 
-def evaluate_moment(model, loader, device, z_dim: int) -> float:
-    # rough eval: moment loss on validation set (lower is better)
-    from modules import moment_loss
-    model.eval()
-    vals = []
-    with torch.no_grad():
-        for (xb,) in loader:
-            xb = xb.to(device)
-            z = sample_noise(xb.size(0), xb.size(1), z_dim, device)
-            # generate one batch
-            paths = model.forward_gen_paths(xb, z)
-            x_tilde = paths["X_tilde"]
-            vals.append(float(moment_loss(xb, x_tilde).cpu()))
-    return float(np.mean(vals)) if vals else math.inf
 
-def run_joint_phase(model, train_dl, val_dl, device, opt_gs, opt_d,
-                    z_dim: int, epochs: int, amp: bool, clip: Optional[float],
-                    loss_weights: LossWeights, ckpt_dir: Optional[str], args=None):
-    best_val = math.inf
-    step = 0
-    for ep in range(1, epochs+1):
-        t0 = time.time()
-        logs = {"d": [], "g_adv": [], "g_sup": [], "g_mom": [], "g_fm": [], "recon": [], "cons": [], "g_total": []}
-        for (xb,) in train_dl:
-            xb = xb.to(device, non_blocking=True)
-            z = sample_noise(xb.size(0), xb.size(1), z_dim, device)
-            out = timegan_joint_step(model, xb, z, opt_gs, opt_d, loss_weights)
-            if clip is not None:
-                torch.nn.utils.clip_grad_norm_(list(model.embedder.parameters())+
-                                               list(model.recovery.parameters())+
-                                               list(model.generator.parameters())+
-                                               list(model.supervisor.parameters()), clip)
-                torch.nn.utils.clip_grad_norm_(model.discriminator.parameters(), clip)
-            for k, v in out.items(): logs[k].append(v)
-            step += 1
-
-        # validation (moment)
-        val_m = evaluate_moment(model, val_dl, device, z_dim)
-        dt = time.time()-t0
-        log_line = " ".join([f"{k}={np.mean(v):.4f}" for k,v in logs.items()])
-        print(f"[JOINT] epoch {ep}/{epochs} {log_line} | val_moment={val_m:.4f} ({dt:.1f}s)")
-
-        # save best
-        if ckpt_dir:
-            if val_m < best_val:
-                best_val = val_m
-                save_ckpt(os.path.join(ckpt_dir, "best.pt"), model, opt_gs, opt_d, step, args=args,
-                          extra={"val_moment": val_m})
-            save_ckpt(os.path.join(ckpt_dir, f"step_{step}.pt"), model, opt_gs, opt_d, step, args=args,
-                      extra={"val_moment": val_m})
-
-# -------------------------
-# main
-# -------------------------
 if __name__ == "__main__":
-    p = argparse.ArgumentParser(description="Train TimeGAN on LOBSTERData.")
-    # data sources
-    p.add_argument("--npz", type=str, help="Path to windows.npz (train/val/test). If set, ignores --data-dir.")
-    p.add_argument("--data-dir", type=str, help="Folder with message_10.csv and orderbook_10.csv")
-    p.add_argument("--message", default="message_10.csv")
-    p.add_argument("--orderbook", default="orderbook_10.csv")
-    p.add_argument("--feature-set", choices=["core","raw10"], default="core")
-    p.add_argument("--seq-len", type=int, default=128)
-    p.add_argument("--stride", type=int, default=32)
-    p.add_argument("--splits", type=float, nargs=3, default=(0.7,0.15,0.15))
-    p.add_argument("--scaler", choices=["standard","minmax","robust","quantile","power","none"], default="robust")
-    p.add_argument("--whiten", choices=["pca","zca",None], default="pca")
-    p.add_argument("--pca-var", type=float, default=0.999)
-    p.add_argument("--headerless-message", action="store_true")
-    p.add_argument("--headerless-orderbook", action="store_true")
-    p.add_argument("--save-dir", type=str, default=None, help="If set during CSV mode, saves NPZ/meta here.")
-
-    # model
-    p.add_argument("--x-dim", type=str, default="auto", help="'auto' infers from data; else int")
-    p.add_argument("--z-dim", type=int, default=24)
-    p.add_argument("--h-dim", type=int, default=64)
-    p.add_argument("--rnn-type", choices=["gru","lstm"], default="gru")
-    p.add_argument("--enc-layers", type=int, default=2)
-    p.add_argument("--dec-layers", type=int, default=2)
-    p.add_argument("--gen-layers", type=int, default=2)
-    p.add_argument("--sup-layers", type=int, default=1)
-    p.add_argument("--dis-layers", type=int, default=1)
-    p.add_argument("--dropout", type=float, default=0.1)
-
-    # training
-    p.add_argument("--batch-size", type=int, default=64)
-    p.add_argument("--seed", type=int, default=1337)
-    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
-    p.add_argument("--amp", action="store_true", help="Enable mixed precision.")
-    p.add_argument("--clip", type=float, default=1.0, help="Grad clip norm; set <=0 to disable.")
-    p.add_argument("--ae-epochs", type=int, default=10)
-    p.add_argument("--sup-epochs", type=int, default=10)
-    p.add_argument("--joint-epochs", type=int, default=50)
-    p.add_argument("--lr", type=float, default=1e-3)
-    p.add_argument("--ckpt-dir", type=str, default="./ckpts")
-
-    # augmentation passthrough when using CSV mode
-    p.add_argument("--aug-prob", type=float, default=0.0)
-    p.add_argument("--aug-jitter-std", type=float, default=0.01)
-    p.add_argument("--aug-scaling-std", type=float, default=0.05)
-    p.add_argument("--aug-timewarp-max", type=float, default=0.1)
-
-    args = p.parse_args()
-    set_seed(args.seed)
-    device = torch.device(args.device)
-    os.makedirs(args.ckpt_dir, exist_ok=True)
-    run_dir = os.path.join(args.ckpt_dir, f"timegan_{time.strftime('%Y%m%d-%H%M%S')}")
-    os.makedirs(run_dir, exist_ok=True)
-
-    # Data
-    if args.npz:
-        train_dl, val_dl, test_dl, T, D = build_loaders_from_npz(args.npz, args.batch_size)
-    elif args.data_dir:
-        train_dl, val_dl, test_dl, T, D = build_loaders_from_csv(args, args.batch_size)
-    else:
-        raise SystemExit("Provide either --npz or --data-dir")
-
-    x_dim = D if args.x_dim == "auto" else int(args.x_dim)
-
-    # Model & optims
-    model = TimeGAN(
-        x_dim=x_dim, z_dim=args.z_dim, h_dim=args.h_dim,
-        rnn_type=args.rnn_type, enc_layers=args.enc_layers, dec_layers=args.dec_layers,
-        gen_layers=args.gen_layers, sup_layers=args.sup_layers, dis_layers=args.dis_layers,
-        dropout=args.dropout
-    ).to(device)
-
-    opt_gs = make_optim(list(model.embedder.parameters()) +
-                        list(model.recovery.parameters()) +
-                        list(model.generator.parameters()) +
-                        list(model.supervisor.parameters()), lr=args.lr)
-    opt_d = make_optim(model.discriminator.parameters(), lr=args.lr)
-
-    # Phase 1: autoencoder pretrain
-    if args.ae_epochs > 0:
-        run_autoencoder_phase(model, train_dl, device, opt_gs, args.ae_epochs, amp=args.amp, clip=args.clip if args.clip>0 else None)
-        save_ckpt(os.path.join(run_dir, "after_autoencoder.pt"), model, opt_gs, opt_d, step=0, args=args)
-
-    # Phase 2: supervisor pretrain
-    if args.sup_epochs > 0:
-        run_supervisor_phase(model, train_dl, device, opt_gs, args.sup_epochs, amp=args.amp, clip=args.clip if args.clip>0 else None)
-        save_ckpt(os.path.join(run_dir, "after_supervisor.pt"), model, opt_gs, opt_d, step=0, args=args)
-
-    # Phase 3: joint training
-    if args.joint_epochs > 0:
-        run_joint_phase(
-            model, train_dl, val_dl, device, opt_gs, opt_d,
-            z_dim=args.z_dim, epochs=args.joint_epochs, amp=args.amp,
-            clip=args.clip if args.clip>0 else None,
-            loss_weights=LossWeights(), ckpt_dir=run_dir, args=args
-        )
-
-
-    # Final test moment score
-    test_m = evaluate_moment(model, test_dl, device, args.z_dim)
-    print(f"[DONE] test moment loss: {test_m:.6f}")
-
+    train()
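Typical usage of the new entrypoint, in the spirit of the "Typical Usage" line dropped from the docstring. The module path below is an assumption based on the `from src.helpers.args import Options` import (run from recognition/TimeLOB_TimeGAN_49088276/ so the `src` package resolves), and the flags are placeholders: the real option names are defined in src/helpers/args.Options and are not part of this diff.

    python3 -m src.train --data_dir <lobster_csv_dir> --seq_len 128

Under the hood this just calls train(), which parses Options, loads the train/val/test splits via load_data, and hands them to TimeGAN.train_model() for the three-phase schedule described in the commit message.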
