
Commit e2f1b74

feat(predict): add TimeGAN sampling & visualisation script (lines + heatmaps + stats)
Loads windows from an NPZ archive or from CSV via LOBSTERData, restores a trained checkpoint, samples synthetic sequences, prints per-feature mean/std and a quick histogram-based KL, and saves feature-line plots and depth heatmaps to --outdir.
1 parent 53ee1ea commit e2f1b74
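
Note on the --npz input: the new load_windows_npz helper reads the archive back as d["train"], d["val"], d["test"], so the file is expected to hold the three windowed splits under exactly those keys, each shaped (num_windows, seq_len, num_features). A minimal sketch of writing a compatible archive (the shapes and the writer below are illustrative assumptions, not the dataset.py save path):

    import numpy as np

    # Hypothetical windowed splits: 128-step windows; the feature count depends on the preprocessing.
    train = np.zeros((700, 128, 24), dtype=np.float32)
    val = np.zeros((150, 128, 24), dtype=np.float32)
    test = np.zeros((150, 128, 24), dtype=np.float32)
    np.savez("./preproc_final/windows.npz", train=train, val=val, test=test)

    # sample_viz.py recovers them with np.load(path)["train"] / ["val"] / ["test"].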

File tree

1 file changed: +254 -5 lines changed
  • recognition/TimeLOB_TimeGAN_49088276/src

Lines changed: 254 additions & 5 deletions
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
"""
Sample synthetic sequences using a trained TimeGAN model and visualise results.

@@ -6,12 +7,260 @@
(e.g., feature lines and depth heatmaps) to compare real vs. synthetic data.

Typical Usage:
-python3 -m train --data_dir <PATH> --seq_len 100 --batch_size 64 --epochs 20
+# Using preprocessed windows
+python sample_viz.py --npz ./preproc_final/windows.npz \
+    --ckpt ./ckpts/timegan_run/best.pt --z-dim 24 --h-dim 64
+
+# Preprocess on-the-fly (same flags as dataset.py)
+python sample_viz.py --data-dir /PATH/TO/SESSION --feature-set core \
+    --seq-len 128 --stride 32 --scaler robust --whiten pca --pca-var 0.999 \
+    --ckpt ./ckpts/timegan_run/best.pt --z-dim 24 --h-dim 64

Created By: Radhesh Goel (Keys-I)
ID: s49088276
-
-References:
--
"""
-# TODO: Implement checkpoint load, sampling, basic stats, and visualisations.
+from __future__ import annotations
+import os
+import argparse
+import numpy as np
+import matplotlib.pyplot as plt
+from typing import Tuple
+
+import torch
+
+# local modules
+from modules import TimeGAN, sample_noise
+from dataset import LOBSTERData
+
+
+# ---------------------------
+# Data loading helpers
+# ---------------------------
+
+def load_windows_npz(npz_path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    d = np.load(npz_path)
+    return d["train"], d["val"], d["test"]
+
+def load_windows_csv(args) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    ds = LOBSTERData(
+        data_dir=args.data_dir,
+        message_file=args.message,
+        orderbook_file=args.orderbook,
+        feature_set=args.feature_set,
+        seq_len=args.seq_len,
+        stride=args.stride,
+        splits=tuple(args.splits),
+        scaler=args.scaler,
+        headerless_message=args.headerless_message,
+        headerless_orderbook=args.headerless_orderbook,
+        whiten=args.whiten, pca_var=args.pca_var,
+        aug_prob=0.0,  # no aug for visualisation builds
+        save_dir=None,
+    )
+    return ds.load_arrays()
+
+
+# ---------------------------
+# Model restore + sampling
+# ---------------------------
+
+def build_model_from_ckpt(ckpt_path: str, x_dim: int, z_dim: int, h_dim: int, device: torch.device) -> TimeGAN:
+    ckpt = torch.load(ckpt_path, map_location=device)
+    args_in_ckpt = ckpt.get("args", {}) or {}
+    rnn_type = args_in_ckpt.get("rnn_type", "gru")
+    enc_layers = int(args_in_ckpt.get("enc_layers", 2))
+    dec_layers = int(args_in_ckpt.get("dec_layers", 2))
+    gen_layers = int(args_in_ckpt.get("gen_layers", 2))
+    sup_layers = int(args_in_ckpt.get("sup_layers", 1))
+    dis_layers = int(args_in_ckpt.get("dis_layers", 1))
+    dropout = float(args_in_ckpt.get("dropout", 0.1))
+
+    model = TimeGAN(
+        x_dim=x_dim, z_dim=z_dim, h_dim=h_dim,
+        rnn_type=rnn_type, enc_layers=enc_layers, dec_layers=dec_layers,
+        gen_layers=gen_layers, sup_layers=sup_layers, dis_layers=dis_layers,
+        dropout=dropout
+    ).to(device)
+
+    model.embedder.load_state_dict(ckpt["embedder"])
+    model.recovery.load_state_dict(ckpt["recovery"])
+    model.generator.load_state_dict(ckpt["generator"])
+    model.supervisor.load_state_dict(ckpt["supervisor"])
+    model.discriminator.load_state_dict(ckpt["discriminator"])
+    model.eval()
+    return model
+
+@torch.no_grad()
+def sample_synthetic(model: TimeGAN, n_seq: int, seq_len: int, z_dim: int, device: torch.device) -> np.ndarray:
+    z = sample_noise(n_seq, seq_len, z_dim, device)
+    e_tilde = model.generator(z)
+    h_tilde = model.supervisor(e_tilde)
+    x_tilde = model.recovery(h_tilde)
+    return x_tilde.detach().cpu().numpy()
+
+
+# ---------------------------
+# Stats + simple similarity
+# ---------------------------
+
+def summarize(name: str, W: np.ndarray) -> dict:
+    # mean/std over batch+time, per-feature
+    mu = W.mean(axis=(0, 1))
+    sd = W.std(axis=(0, 1))
+    return {"name": name, "mean": mu, "std": sd}
+
+def kl_hist_avg(real: np.ndarray, synth: np.ndarray, bins: int = 64, eps: float = 1e-9) -> float:
+    """
+    Quick histogram-based KL(real || synth) averaged over features.
+    """
+    from scipy.special import rel_entr
+    F = real.shape[2]
+    vals = []
+    R = real.reshape(-1, F)
+    S = synth.reshape(-1, F)
+    for f in range(F):
+        r = R[:, f]; s = S[:, f]
+        lo = np.nanpercentile(np.concatenate([r, s]), 0.5)
+        hi = np.nanpercentile(np.concatenate([r, s]), 99.5)
+        if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
+            continue
+        pr, _ = np.histogram(r, bins=bins, range=(lo, hi), density=True)
+        ps, _ = np.histogram(s, bins=bins, range=(lo, hi), density=True)
+        pr = pr + eps; ps = ps + eps
+        pr = pr / pr.sum(); ps = ps / ps.sum()
+        vals.append(np.sum(rel_entr(pr, ps)))
+    return float(np.mean(vals)) if vals else float("nan")
+
+
+# ---------------------------
+# Visualisations
+# ---------------------------
+
+def plot_feature_lines(real: np.ndarray, synth: np.ndarray, outdir: str, max_feats: int = 4, idx: int = 0):
+    """
+    Plot a few feature time-series (same sequence index) real vs synthetic.
+    """
+    os.makedirs(outdir, exist_ok=True)
+    T, F = real.shape[1], real.shape[2]
+    feats = min(F, max_feats)
+
+    fig, axes = plt.subplots(feats, 1, figsize=(10, 2.2 * feats), sharex=True)
+    if feats == 1:
+        axes = [axes]
+    for i in range(feats):
+        axes[i].plot(real[idx, :, i], label="real", linewidth=1.2)
+        axes[i].plot(synth[idx, :, i], label="synthetic", linewidth=1.2, linestyle="--")
+        axes[i].set_ylabel(f"feat {i}")
+    axes[-1].set_xlabel("time")
+    axes[0].legend(loc="upper right")
+    fig.suptitle("Feature lines: real vs synthetic")
+    fig.tight_layout()
+    fig.savefig(os.path.join(outdir, "feature_lines.png"), dpi=150)
+    plt.close(fig)
+
+def plot_heatmaps(real: np.ndarray, synth: np.ndarray, outdir: str, idx: int = 0):
+    """
+    Plot depth heatmaps (time x features) for a single sequence.
+    """
+    os.makedirs(outdir, exist_ok=True)
+    a = real[idx]; b = synth[idx]
+    # normalize each to [0,1] for visibility
+    def norm01(x):
+        lo, hi = np.percentile(x, 1), np.percentile(x, 99)
+        return np.clip((x - lo) / (hi - lo + 1e-9), 0, 1)
+
+    a = norm01(a); b = norm01(b)
+
+    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
+    im0 = axes[0].imshow(a, aspect="auto", origin="lower")
+    axes[0].set_title("Real (heatmap)")
+    axes[0].set_xlabel("feature"); axes[0].set_ylabel("time")
+    fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
+
+    im1 = axes[1].imshow(b, aspect="auto", origin="lower")
+    axes[1].set_title("Synthetic (heatmap)")
+    axes[1].set_xlabel("feature"); axes[1].set_ylabel("time")
+    fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
+
+    fig.tight_layout()
+    fig.savefig(os.path.join(outdir, "heatmaps.png"), dpi=150)
+    plt.close(fig)
+
+
+# ---------------------------
+# Main
+# ---------------------------
+
+if __name__ == "__main__":
+    ap = argparse.ArgumentParser(description="Sample & visualise TimeGAN outputs vs real.")
+    # data
+    ap.add_argument("--npz", type=str, help="Path to windows.npz (train/val/test). If set, ignores --data-dir.")
+    ap.add_argument("--data-dir", type=str, help="Folder with message_10.csv and orderbook_10.csv")
+    ap.add_argument("--message", default="message_10.csv")
+    ap.add_argument("--orderbook", default="orderbook_10.csv")
+    ap.add_argument("--feature-set", choices=["core","raw10"], default="core")
+    ap.add_argument("--seq-len", type=int, default=128)
+    ap.add_argument("--stride", type=int, default=32)
+    ap.add_argument("--splits", type=float, nargs=3, default=(0.7,0.15,0.15))
+    ap.add_argument("--scaler", choices=["standard","minmax","robust","quantile","power","none"], default="robust")
+    ap.add_argument("--whiten", choices=["pca","zca",None], default="pca")
+    ap.add_argument("--pca-var", type=float, default=0.999)
+    ap.add_argument("--headerless-message", action="store_true")
+    ap.add_argument("--headerless-orderbook", action="store_true")
+
+    # model restore
+    ap.add_argument("--ckpt", type=str, required=True)
+    ap.add_argument("--z-dim", type=int, required=True)
+    ap.add_argument("--h-dim", type=int, required=True)
+    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+
+    # viz
+    ap.add_argument("--n-synth", type=int, default=128, help="How many synthetic windows to sample.")
+    ap.add_argument("--seq-index", type=int, default=0, help="Which sequence index to plot.")
+    ap.add_argument("--max-feats", type=int, default=4, help="Max features to show in line plot.")
+    ap.add_argument("--outdir", type=str, default="./viz_out")
+
+    args = ap.parse_args()
+    os.makedirs(args.outdir, exist_ok=True)
+    device = torch.device(args.device)
+
+    # Load real windows
+    if args.npz:
+        Wtr, Wval, Wte = load_windows_npz(args.npz)
+    elif args.data_dir:
+        Wtr, Wval, Wte = load_windows_csv(args)
+    else:
+        raise SystemExit("Provide either --npz or --data-dir")
+
+    # Pick a real reference set (test split)
+    real = Wte
+    _, T, D = real.shape
+
+    # Build model & restore
+    model = build_model_from_ckpt(args.ckpt, x_dim=D, z_dim=args.z_dim, h_dim=args.h_dim, device=device)
+    model.eval()
+
+    # Sample synthetic
+    n_synth = min(args.n_synth, len(real))
+    synth = sample_synthetic(model, n_synth, T, args.z_dim, device)
+
+    # Basic stats
+    s_real = summarize("real(test)", real)
+    s_synth = summarize("synthetic", synth)
+    print("=== Summary (per-feature mean/std) ===")
+    print(f"{s_real['name']}: mean[0:5]={s_real['mean'][:5]}, std[0:5]={s_real['std'][:5]}")
+    print(f"{s_synth['name']}: mean[0:5]={s_synth['mean'][:5]}, std[0:5]={s_synth['std'][:5]}")
+
+    # Quick KL(hist) similarity
+    try:
+        kl = kl_hist_avg(real[:n_synth], synth)
+        print(f"KL(real || synth) ~ {kl:.4f} (lower is better)")
+    except Exception as e:
+        print(f"KL computation skipped: {e}")
+
+    # Visualisations
+    idx = max(0, min(args.seq_index, n_synth - 1))
+    plot_feature_lines(real, synth, args.outdir, max_feats=args.max_feats, idx=idx)
+    plot_heatmaps(real, synth, args.outdir, idx=idx)
+
+    print(f"Saved plots to: {args.outdir}")
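
Note on the checkpoint layout: build_model_from_ckpt above expects --ckpt to be a torch.save'd dict with one state_dict per sub-network ("embedder", "recovery", "generator", "supervisor", "discriminator") plus an optional "args" dict of architecture hyperparameters; missing entries fall back to the defaults hard-coded in the function. A minimal sketch of a compatible save call (the training script's actual save code is not part of this commit, so the hyperparameter values, x_dim, and path below are assumptions):

    import torch
    from modules import TimeGAN  # local module, same constructor signature as used above

    hparams = {"rnn_type": "gru", "enc_layers": 2, "dec_layers": 2,
               "gen_layers": 2, "sup_layers": 1, "dis_layers": 1, "dropout": 0.1}
    model = TimeGAN(x_dim=24, z_dim=24, h_dim=64, **hparams)  # x_dim must match the real windows

    torch.save(
        {
            "args": hparams,
            "embedder": model.embedder.state_dict(),
            "recovery": model.recovery.state_dict(),
            "generator": model.generator.state_dict(),
            "supervisor": model.supervisor.state_dict(),
            "discriminator": model.discriminator.state_dict(),
        },
        "./ckpts/timegan_run/best.pt",
    )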

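For reference, the quantity printed as "KL(real || synth)" is the discrete relative entropy between per-feature histograms, averaged over features. With p and q the eps-smoothed, renormalised histograms of a feature's real and synthetic values over B shared bins (range clipped to the joint 0.5-99.5 percentiles), kl_hist_avg accumulates

    D_KL(p \| q) = \sum_{i=1}^{B} p_i \log \frac{p_i}{q_i}

which is what the term-wise scipy.special.rel_entr(pr, ps) sums to; lower values mean the synthetic marginal distributions sit closer to the real ones.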
0 commit comments
