Skip to content

Commit 8cd2b76

Browse files
committed
feat(viz): add sampling script to generate and save synthetic LOB data
Parses Options, loads data, restores TimeGAN from checkpoint, generates exactly len(test) rows, and saves to OUTPUT_DIR/gen_data.npy. Keeps API aligned with current dataset/modules helpers.
1 parent 1291868 commit 8cd2b76

File tree

3 files changed

+163
-234
lines changed

3 files changed

+163
-234
lines changed

recognition/TimeLOB_TimeGAN_49088276/src/helpers/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@
2121
), (
2222
f"TRAIN_TEST_SPLIT must sum to 1.0 (got {sum(TRAIN_TEST_SPLIT):.8f})"
2323
)
24+
25+
NUM_LEVELS = 10
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
Generate LOB depth heatmaps and compute SSIM between real vs synthetic images.
3+
Refactored to be faster, cleaner, and compatible with the new modules/utils.
4+
"""
5+
from __future__ import annotations
6+
7+
from pathlib import Path
8+
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
from numpy.typing import NDArray
12+
from skimage import img_as_float
13+
from skimage.metrics import structural_similarity as ssim
14+
15+
from args import Options
16+
from constants import NUM_LEVELS
17+
from src.dataset import load_data
18+
from src.helpers.constants import OUTPUT_DIR
19+
from src.modules import TimeGAN
20+
21+
22+
def get_ssim(img1_path: Path | str, img2_path: Path | str) -> float:
23+
"""
24+
Compute SSIM between two image files.
25+
26+
Uses `channel_axis=2` (new skimage API). Images are read via matplotlib.
27+
"""
28+
img1 = img_as_float(plt.imread(str(img1_path)))
29+
img2 = img_as_float(plt.imread(str(img2_path)))
30+
31+
# if grayscale, add channel axis
32+
if img1.ndim == 2:
33+
img1 = img1[..., None]
34+
if img2.ndim == 2:
35+
img2 = img2[..., None]
36+
return float(ssim(img1, img2, channel_axis=2, data_range=1.0))
37+
38+
39+
def plot_heatmap(
40+
data_2d: NDArray, # shape [T, F]
41+
*,
42+
title: str | None = None,
43+
save_path: Path | str | None = None,
44+
show: bool = True,
45+
dpi: int = 150,
46+
) -> None:
47+
"""
48+
Scatter-based depth heatmap.
49+
50+
Assumes features are interleaved per level: [ask_price, ask_vol, bid_price, bid_vol] x NUM_LEVELS.
51+
Colors: red=ask, blue=bid, alpha encodes relative volume in [0,1].
52+
"""
53+
T, F = data_2d.shape
54+
assert F >= 4 * NUM_LEVELS, "Expected at least 4 features per level"
55+
56+
# slice views
57+
# for each level L: price indices = 4*L + (0 for ask, 2 for bid)
58+
# vol indices = price_idx + 1
59+
prices_ask = np.stack([data_2d[:, 4 * L + 0] for L in range(NUM_LEVELS)], axis=1) # [T, L]
60+
vols_ask = np.stack([data_2d[:, 4 * L + 1] for L in range(NUM_LEVELS)], axis=1) # [T, L]
61+
prices_bid = np.stack([data_2d[:, 4 * L + 2] for L in range(NUM_LEVELS)], axis=1) # [T, L]
62+
vols_bid = np.stack([data_2d[:, 4 * L + 3] for L in range(NUM_LEVELS)], axis=1) # [T, L]
63+
64+
# Normalise volumes for alpha
65+
max_vol = float(np.max([vols_ask.max(initial=0), vols_bid.max(initial=0)])) or 1.0
66+
a_ask = (vols_ask / max_vol).astype(np.float32)
67+
a_bid = (vols_bid / max_vol).astype(np.float32)
68+
69+
# build scatter arrays
70+
# x: time indices repeated for each level
71+
t_idx = np.arange(T, dtype=np.float32)[:, None]
72+
x_ask = np.repeat(t_idx, NUM_LEVELS, axis=1).ravel()
73+
x_bid = x_ask.copy()
74+
y_ask = prices_ask.astype(np.float32).ravel()
75+
y_bid = prices_bid.astype(np.float32).ravel()
76+
77+
# colors rgba
78+
c_ask = np.stack([
79+
np.full_like(y_ask, 0.99), # r
80+
np.full_like(y_ask, 0.05), # g
81+
np.full_like(y_ask, 0.05), # b
82+
a_ask.astype(np.float32).ravel(), # A
83+
], axis=1)
84+
c_bid = np.stack([
85+
np.full_like(y_ask, 0.05), # r
86+
np.full_like(y_ask, 0.05), # g
87+
np.full_like(y_ask, 0.99), # b
88+
a_bid.astype(np.float32).ravel(), # A
89+
], axis=1)
90+
91+
# limits
92+
pmin = float(np.minimum(prices_ask.min(initial=0), prices_bid.min(initial=0)))
93+
pmax = float(np.maximum(prices_ask.max(initial=0), prices_bid.max(initial=0)))
94+
95+
# plot
96+
fig, ax = plt.subplots(figsize=(10, 6), dpi=dpi)
97+
ax.set_ylim(pmin, pmax)
98+
ax.set_xlabel("Time")
99+
ax.set_ylabel("Price")
100+
if title:
101+
ax.set_title(title)
102+
103+
ax.scatter(x_ask, y_ask, c=c_ask)
104+
ax.scatter(x_bid, y_bid, c=c_bid)
105+
106+
fig.tight_layout()
107+
if save_path is not None:
108+
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
109+
fig.savefig(str(save_path), bbox_inches="tight")
110+
if show:
111+
plt.show()
112+
plt.close(fig)
113+
114+
if "__main__" == __name__:
115+
# cli
116+
opt = Options().parse()
117+
118+
# data
119+
train, val, test = load_data(opt)
120+
121+
# model (load weights)
122+
model = TimeGAN(opt, train, val, test, load_weights=True)
123+
124+
# real heatmap from test data
125+
real_path = Path(OUTPUT_DIR) / "real.png"
126+
plot_heatmap(test, title="Real LOB Depth", save_path=real_path, show=False)
127+
128+
for i in range(3):
129+
synth = model.generate(num_rows=len(test))
130+
synth_path = Path(OUTPUT_DIR) / f"synthetic_heatmap_{i}.png"
131+
plot_heatmap(synth, title=f"Synthetic LOB Depth #{i}", save_path=synth_path, show=False)
132+
score = get_ssim(real_path, synth_path)
133+
print(f"SSIM(real, synthetic_{i}) = {score:.4f}")

0 commit comments

Comments
 (0)