1+ #!/usr/bin/env python3
12"""
23Sample synthetic sequences using a trained TimeGAN model and visualise results.
34
67(e.g., feature lines and depth heatmaps) to compare real vs. synthetic data.
78
89Typical Usage:
9- python3 -m train --data_dir <PATH> --seq_len 100 --batch_size 64 --epochs 20
10+ # Using preprocessed windows
11+ python sample_viz.py --npz ./preproc_final/windows.npz \
12+ --ckpt ./ckpts/timegan_run/best.pt --z-dim 24 --h-dim 64
13+
14+ # Preprocess on-the-fly (same flags as dataset.py)
15+ python sample_viz.py --data-dir /PATH/TO/SESSION --feature-set core \
16+ --seq-len 128 --stride 32 --scaler robust --whiten pca --pca-var 0.999 \
17+ --ckpt ./ckpts/timegan_run/best.pt --z-dim 24 --h-dim 64
1018
1119Created By: Radhesh Goel (Keys-I)
1220ID: s49088276
13-
14- References:
15- -
1621"""
17- # TODO: Implement checkpoint load, sampling, basic stats, and visualisations.
22+ from __future__ import annotations
23+ import os
24+ import argparse
25+ import numpy as np
26+ import matplotlib .pyplot as plt
27+ from typing import Tuple
28+
29+ import torch
30+
31+ # local modules
32+ from modules import TimeGAN , sample_noise
33+ from dataset import LOBSTERData
34+
35+
36+ # ---------------------------
37+ # Data loading helpers
38+ # ---------------------------
39+
40+ def load_windows_npz (npz_path : str ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
41+ d = np .load (npz_path )
42+ return d ["train" ], d ["val" ], d ["test" ]
43+
44+ def load_windows_csv (args ) -> Tuple [np .ndarray , np .ndarray , np .ndarray ]:
45+ ds = LOBSTERData (
46+ data_dir = args .data_dir ,
47+ message_file = args .message ,
48+ orderbook_file = args .orderbook ,
49+ feature_set = args .feature_set ,
50+ seq_len = args .seq_len ,
51+ stride = args .stride ,
52+ splits = tuple (args .splits ),
53+ scaler = args .scaler ,
54+ headerless_message = args .headerless_message ,
55+ headerless_orderbook = args .headerless_orderbook ,
56+ whiten = args .whiten , pca_var = args .pca_var ,
57+ aug_prob = 0.0 , # no aug for visualisation builds
58+ save_dir = None ,
59+ )
60+ return ds .load_arrays ()
61+
62+
63+ # ---------------------------
64+ # Model restore + sampling
65+ # ---------------------------
66+
67+ def build_model_from_ckpt (ckpt_path : str , x_dim : int , z_dim : int , h_dim : int , device : torch .device ) -> TimeGAN :
68+ ckpt = torch .load (ckpt_path , map_location = device )
69+ args_in_ckpt = ckpt .get ("args" , {}) or {}
70+ rnn_type = args_in_ckpt .get ("rnn_type" , "gru" )
71+ enc_layers = int (args_in_ckpt .get ("enc_layers" , 2 ))
72+ dec_layers = int (args_in_ckpt .get ("dec_layers" , 2 ))
73+ gen_layers = int (args_in_ckpt .get ("gen_layers" , 2 ))
74+ sup_layers = int (args_in_ckpt .get ("sup_layers" , 1 ))
75+ dis_layers = int (args_in_ckpt .get ("dis_layers" , 1 ))
76+ dropout = float (args_in_ckpt .get ("dropout" , 0.1 ))
77+
78+ model = TimeGAN (
79+ x_dim = x_dim , z_dim = z_dim , h_dim = h_dim ,
80+ rnn_type = rnn_type , enc_layers = enc_layers , dec_layers = dec_layers ,
81+ gen_layers = gen_layers , sup_layers = sup_layers , dis_layers = dis_layers ,
82+ dropout = dropout
83+ ).to (device )
84+
85+ model .embedder .load_state_dict (ckpt ["embedder" ])
86+ model .recovery .load_state_dict (ckpt ["recovery" ])
87+ model .generator .load_state_dict (ckpt ["generator" ])
88+ model .supervisor .load_state_dict (ckpt ["supervisor" ])
89+ model .discriminator .load_state_dict (ckpt ["discriminator" ])
90+ model .eval ()
91+ return model
92+
93+ @torch .no_grad ()
94+ def sample_synthetic (model : TimeGAN , n_seq : int , seq_len : int , z_dim : int , device : torch .device ) -> np .ndarray :
95+ z = sample_noise (n_seq , seq_len , z_dim , device )
96+ e_tilde = model .generator (z )
97+ h_tilde = model .supervisor (e_tilde )
98+ x_tilde = model .recovery (h_tilde )
99+ return x_tilde .detach ().cpu ().numpy ()
100+
101+
102+ # ---------------------------
103+ # Stats + simple similarity
104+ # ---------------------------
105+
106+ def summarize (name : str , W : np .ndarray ) -> dict :
107+ # mean/std over batch+time, per-feature
108+ mu = W .mean (axis = (0 , 1 ))
109+ sd = W .std (axis = (0 , 1 ))
110+ return {"name" : name , "mean" : mu , "std" : sd }
111+
112+ def kl_hist_avg (real : np .ndarray , synth : np .ndarray , bins : int = 64 , eps : float = 1e-9 ) -> float :
113+ """
114+ Quick histogram-based KL(real || synth) averaged over features.
115+ """
116+ from scipy .special import rel_entr
117+ F = real .shape [2 ]
118+ vals = []
119+ R = real .reshape (- 1 , F )
120+ S = synth .reshape (- 1 , F )
121+ for f in range (F ):
122+ r = R [:, f ]; s = S [:, f ]
123+ lo = np .nanpercentile (np .concatenate ([r , s ]), 0.5 )
124+ hi = np .nanpercentile (np .concatenate ([r , s ]), 99.5 )
125+ if not np .isfinite (lo ) or not np .isfinite (hi ) or hi <= lo :
126+ continue
127+ pr , _ = np .histogram (r , bins = bins , range = (lo , hi ), density = True )
128+ ps , _ = np .histogram (s , bins = bins , range = (lo , hi ), density = True )
129+ pr = pr + eps ; ps = ps + eps
130+ pr = pr / pr .sum (); ps = ps / ps .sum ()
131+ vals .append (np .sum (rel_entr (pr , ps )))
132+ return float (np .mean (vals )) if vals else float ("nan" )
133+
134+
135+ # ---------------------------
136+ # Visualisations
137+ # ---------------------------
138+
139+ def plot_feature_lines (real : np .ndarray , synth : np .ndarray , outdir : str , max_feats : int = 4 , idx : int = 0 ):
140+ """
141+ Plot a few feature time-series (same sequence index) real vs synthetic.
142+ """
143+ os .makedirs (outdir , exist_ok = True )
144+ T , F = real .shape [1 ], real .shape [2 ]
145+ feats = min (F , max_feats )
146+
147+ fig , axes = plt .subplots (feats , 1 , figsize = (10 , 2.2 * feats ), sharex = True )
148+ if feats == 1 :
149+ axes = [axes ]
150+ for i in range (feats ):
151+ axes [i ].plot (real [idx , :, i ], label = "real" , linewidth = 1.2 )
152+ axes [i ].plot (synth [idx , :, i ], label = "synthetic" , linewidth = 1.2 , linestyle = "--" )
153+ axes [i ].set_ylabel (f"feat { i } " )
154+ axes [- 1 ].set_xlabel ("time" )
155+ axes [0 ].legend (loc = "upper right" )
156+ fig .suptitle ("Feature lines: real vs synthetic" )
157+ fig .tight_layout ()
158+ fig .savefig (os .path .join (outdir , "feature_lines.png" ), dpi = 150 )
159+ plt .close (fig )
160+
161+ def plot_heatmaps (real : np .ndarray , synth : np .ndarray , outdir : str , idx : int = 0 ):
162+ """
163+ Plot depth heatmaps (time x features) for a single sequence.
164+ """
165+ os .makedirs (outdir , exist_ok = True )
166+ a = real [idx ]; b = synth [idx ]
167+ # normalize each to [0,1] for visibility
168+ def norm01 (x ):
169+ lo , hi = np .percentile (x , 1 ), np .percentile (x , 99 )
170+ return np .clip ((x - lo ) / (hi - lo + 1e-9 ), 0 , 1 )
171+
172+ a = norm01 (a ); b = norm01 (b )
173+
174+ fig , axes = plt .subplots (1 , 2 , figsize = (12 , 4 ))
175+ im0 = axes [0 ].imshow (a , aspect = "auto" , origin = "lower" )
176+ axes [0 ].set_title ("Real (heatmap)" )
177+ axes [0 ].set_xlabel ("feature" ); axes [0 ].set_ylabel ("time" )
178+ fig .colorbar (im0 , ax = axes [0 ], fraction = 0.046 , pad = 0.04 )
179+
180+ im1 = axes [1 ].imshow (b , aspect = "auto" , origin = "lower" )
181+ axes [1 ].set_title ("Synthetic (heatmap)" )
182+ axes [1 ].set_xlabel ("feature" ); axes [1 ].set_ylabel ("time" )
183+ fig .colorbar (im1 , ax = axes [1 ], fraction = 0.046 , pad = 0.04 )
184+
185+ fig .tight_layout ()
186+ fig .savefig (os .path .join (outdir , "heatmaps.png" ), dpi = 150 )
187+ plt .close (fig )
188+
189+
190+ # ---------------------------
191+ # Main
192+ # ---------------------------
193+
194+ if __name__ == "__main__" :
195+ ap = argparse .ArgumentParser (description = "Sample & visualise TimeGAN outputs vs real." )
196+ # data
197+ ap .add_argument ("--npz" , type = str , help = "Path to windows.npz (train/val/test). If set, ignores --data-dir." )
198+ ap .add_argument ("--data-dir" , type = str , help = "Folder with message_10.csv and orderbook_10.csv" )
199+ ap .add_argument ("--message" , default = "message_10.csv" )
200+ ap .add_argument ("--orderbook" , default = "orderbook_10.csv" )
201+ ap .add_argument ("--feature-set" , choices = ["core" ,"raw10" ], default = "core" )
202+ ap .add_argument ("--seq-len" , type = int , default = 128 )
203+ ap .add_argument ("--stride" , type = int , default = 32 )
204+ ap .add_argument ("--splits" , type = float , nargs = 3 , default = (0.7 ,0.15 ,0.15 ))
205+ ap .add_argument ("--scaler" , choices = ["standard" ,"minmax" ,"robust" ,"quantile" ,"power" ,"none" ], default = "robust" )
206+ ap .add_argument ("--whiten" , choices = ["pca" ,"zca" ,None ], default = "pca" )
207+ ap .add_argument ("--pca-var" , type = float , default = 0.999 )
208+ ap .add_argument ("--headerless-message" , action = "store_true" )
209+ ap .add_argument ("--headerless-orderbook" , action = "store_true" )
210+
211+ # model restore
212+ ap .add_argument ("--ckpt" , type = str , required = True )
213+ ap .add_argument ("--z-dim" , type = int , required = True )
214+ ap .add_argument ("--h-dim" , type = int , required = True )
215+ ap .add_argument ("--device" , type = str , default = "cuda" if torch .cuda .is_available () else "cpu" )
216+
217+ # viz
218+ ap .add_argument ("--n-synth" , type = int , default = 128 , help = "How many synthetic windows to sample." )
219+ ap .add_argument ("--seq-index" , type = int , default = 0 , help = "Which sequence index to plot." )
220+ ap .add_argument ("--max-feats" , type = int , default = 4 , help = "Max features to show in line plot." )
221+ ap .add_argument ("--outdir" , type = str , default = "./viz_out" )
222+
223+ args = ap .parse_args ()
224+ os .makedirs (args .outdir , exist_ok = True )
225+ device = torch .device (args .device )
226+
227+ # Load real windows
228+ if args .npz :
229+ Wtr , Wval , Wte = load_windows_npz (args .npz )
230+ elif args .data_dir :
231+ Wtr , Wval , Wte = load_windows_csv (args )
232+ else :
233+ raise SystemExit ("Provide either --npz or --data-dir" )
234+
235+ # Pick a real reference set (test split)
236+ real = Wte
237+ _ , T , D = real .shape
238+
239+ # Build model & restore
240+ model = build_model_from_ckpt (args .ckpt , x_dim = D , z_dim = args .z_dim , h_dim = args .h_dim , device = device )
241+ model .eval ()
242+
243+ # Sample synthetic
244+ n_synth = min (args .n_synth , len (real ))
245+ synth = sample_synthetic (model , n_synth , T , args .z_dim , device )
246+
247+ # Basic stats
248+ s_real = summarize ("real(test)" , real )
249+ s_synth = summarize ("synthetic" , synth )
250+ print ("=== Summary (per-feature mean/std) ===" )
251+ print (f"{ s_real ['name' ]} : mean[0:5]={ s_real ['mean' ][:5 ]} , std[0:5]={ s_real ['std' ][:5 ]} " )
252+ print (f"{ s_synth ['name' ]} : mean[0:5]={ s_synth ['mean' ][:5 ]} , std[0:5]={ s_synth ['std' ][:5 ]} " )
253+
254+ # Quick KL(hist) similarity
255+ try :
256+ kl = kl_hist_avg (real [:n_synth ], synth )
257+ print (f"KL(real || synth) ~ { kl :.4f} (lower is better)" )
258+ except Exception as e :
259+ print (f"KL computation skipped: { e } " )
260+
261+ # Visualisations
262+ idx = max (0 , min (args .seq_index , n_synth - 1 ))
263+ plot_feature_lines (real , synth , args .outdir , max_feats = args .max_feats , idx = idx )
264+ plot_heatmaps (real , synth , args .outdir , idx = idx )
265+
266+ print (f"Saved plots to: { args .outdir } " )
0 commit comments