Skip to content

Commit 3cf8b0c

Browse files
committed
feat(metrics): add min–max scaling/inverse, noise sampler, spread/MPR KL histogram
Introduce utilities for TimeGAN-LOB: extract_seq_lengths, sample_noise (supports RNG + optional mean/std via uniform with matched σ), minmax_scale/minmax_inverse over [N,T,F], and KL(real||fake) via histograms for 'spread' and 'mpr' with smoothing + optional plot. Adds strong shape/type guards, finite-range handling, and safe midprice log-returns.
1 parent 337ff87 commit 3cf8b0c

File tree

2 files changed

+147
-0
lines changed

2 files changed

+147
-0
lines changed

recognition/TimeLOB_TimeGAN_49088276/src/helpers/arg2.py

Whitespace-only changes.
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from __future__ import annotations
2+
from typing import Iterable, Literal, Tuple
3+
4+
import numpy as np
5+
from numpy.typing import NDArray
6+
import matplotlib.pyplot as plt
7+
8+
# Names of the supported comparison metrics: "spread" (ask minus bid) and
# "mpr" (log midprice returns).
Metric = Literal["spread", "mpr"]
9+
10+
def extract_seq_lengths(
    sequences: Iterable[NDArray[np.floating]],
) -> Tuple[NDArray[np.int32], int]:
    """Return the length of every sequence together with the longest length.

    Args:
        sequences: iterable of per-sequence containers. Objects exposing a
            ``shape`` attribute are measured by ``shape[0]``; anything else
            (e.g. a plain list of rows) is measured with ``len``.

    Returns:
        ``(lengths, max_len)`` where ``lengths`` is an int32 array with one
        entry per sequence and ``max_len`` is the largest entry, or 0 when
        ``sequences`` is empty.
    """
    lengths = np.asarray(
        [int(s.shape[0]) if hasattr(s, "shape") else len(s) for s in sequences],
        dtype=np.int32,
    )
    # `initial=0` keeps the reduction defined when there are no sequences.
    return lengths, int(lengths.max(initial=0))
15+
16+
def sample_noise(
    batch_size: int,
    z_dim: int,
    seq_len: int,
    *,
    mean: float | None = None,
    std: float | None = None,
    rng: np.random.Generator | None = None,
) -> NDArray[np.float32]:
    """Draw uniform noise of shape ``[batch_size, seq_len, z_dim]``.

    With neither ``mean`` nor ``std`` the samples come from U[0, 1). When both
    are given, samples come from a uniform distribution centred on ``mean``
    whose width is ``std * sqrt(12)``, which gives it standard deviation
    ``std``.

    Args:
        batch_size: size of the first output axis.
        z_dim: size of the last output axis.
        seq_len: size of the middle output axis.
        mean: centre of the uniform interval (requires ``std``).
        std: target standard deviation, must be >= 0 (requires ``mean``).
        rng: generator to draw from; a fresh ``default_rng()`` is used if None.

    Returns:
        float32 array of shape ``[batch_size, seq_len, z_dim]``.

    Raises:
        ValueError: if exactly one of ``mean``/``std`` is given, or ``std`` is
            negative.
    """
    if rng is None:
        rng = np.random.default_rng()

    if (mean is None) != (std is None):
        raise ValueError("Provide both mean and std, or neither")

    shape = (batch_size, seq_len, z_dim)
    if mean is None:
        return rng.random(shape, dtype=np.float32)

    # A negative std would put the lower bound above the upper bound.
    if float(std) < 0.0:
        raise ValueError("std must be non-negative")

    # A uniform distribution of width w has standard deviation w / sqrt(12).
    half_width = float(std) * np.sqrt(12.0) / 2.0
    lo = float(mean) - half_width
    hi = float(mean) + half_width
    return rng.uniform(lo, hi, size=shape).astype(np.float32)
40+
41+
def minmax_scale(
    data: NDArray[np.floating],
    epsilon: float = 1e-7,
) -> Tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:
    """Scale each feature of a ``[N, T, F]`` array with min-max normalisation.

    The minimum and maximum are taken per feature over the first two axes and
    every value is mapped to ``(x - fmin) / (fmax - fmin + epsilon)``.

    Args:
        data: array-like with three dimensions ``[N, T, F]``.
        epsilon: added to the per-feature range so constant features do not
            divide by zero.

    Returns:
        ``(norm, fmin, fmax)``: the scaled data ``[N, T, F]`` and the
        per-feature minima and maxima ``[F]``, all float32.

    Raises:
        ValueError: if ``data`` is not 3-dimensional, or has no samples along
            N or T (min/max would be undefined).
    """
    arr = np.asarray(data, dtype=np.float32)
    if arr.ndim != 3:
        raise ValueError(f"Expected data with 3 dimensions [N, T, F], got shape {arr.shape}")
    if arr.shape[0] == 0 or arr.shape[1] == 0:
        raise ValueError(f"Cannot min-max scale empty data, got shape {arr.shape}")

    fmin = arr.min(axis=(0, 1))
    fmax = arr.max(axis=(0, 1))

    norm = (arr - fmin) / ((fmax - fmin) + epsilon)
    return norm.astype(np.float32, copy=False), fmin, fmax
54+
55+
def minmax_inverse(
    norm: NDArray[np.floating],
    fmin: NDArray[np.floating],
    fmax: NDArray[np.floating],
    epsilon: float = 1e-7,
) -> NDArray[np.float32]:
    """
    Inverse of `minmax_scale`.

    Maps scaled values back with ``norm * (fmax - fmin + epsilon) + fmin``.
    The ``epsilon`` term mirrors the one `minmax_scale` adds to its
    denominator; without it the round trip is noticeably off for features
    whose range is comparable to ``epsilon``.

    Args:
        norm: scaled data [N,T,F] or [...,F]
        fmin: per-feature minima [F]
        fmax: per-feature maxima [F]
        epsilon: the smoothing constant that was passed to `minmax_scale`
            (same default); pass 0.0 to multiply by the raw range only.

    Returns:
        original-scale data, float32
    """
    scaled = np.asarray(norm, dtype=np.float32)
    lo = np.asarray(fmin, dtype=np.float32)
    hi = np.asarray(fmax, dtype=np.float32)
    restored = scaled * ((hi - lo) + epsilon) + lo
    return restored.astype(np.float32, copy=False)
74+
75+
def _spread(series: NDArray[np.floating]) -> NDArray[np.float64]:
    """
    Compute spread = best_ask - best_bid from a 2D array [T, F] with
    columns: best ask at index 0 and best bid at index 2.

    Both columns are converted to float64 before subtracting, so the result
    does not depend on the input dtype (e.g. unsigned integers cannot wrap).

    Raises:
        ValueError: if the input is not 2D or has fewer than 3 columns.
    """
    arr = np.asarray(series)
    if arr.ndim != 2 or arr.shape[1] < 3:
        raise ValueError("Expected shape [T, >=3]; columns 0 (ask) and 2 (bid) required.")
    ask = arr[:, 0].astype(np.float64)
    bid = arr[:, 2].astype(np.float64)
    return ask - bid
83+
84+
85+
def _midprice_returns(series: NDArray[np.floating]) -> NDArray[np.float64]:
    """
    Compute log midprice returns from a 2D array [T, F] with ask at 0 and bid at 2.

    The midprice ``0.5 * (ask + bid)`` is computed in float64 and floored at
    the smallest positive normal float64 before taking logs. Returns an array
    of length ``T - 1`` (empty when ``T < 2``).

    Raises:
        ValueError: if the input is not 2D or has fewer than 3 columns.
    """
    arr = np.asarray(series)
    if arr.ndim != 2 or arr.shape[1] < 3:
        raise ValueError("Expected shape [T, >=3]; columns 0 (ask) and 2 (bid) required.")
    # Work in float64: the float64 `tiny` floor is not representable in
    # float32, so a lower-precision midprice could be clipped to 0 instead.
    mid = 0.5 * (arr[:, 0].astype(np.float64) + arr[:, 2].astype(np.float64))
    # avoid log(0)
    mid = np.clip(mid, a_min=np.finfo(np.float64).tiny, a_max=None)
    return np.diff(np.log(mid))
96+
97+
def kl_divergence_hist(
    real: NDArray[np.floating],
    fake: NDArray[np.floating],
    metric: Literal["spread", "mpr"] = "spread",
    *,
    bins: int = 100,
    show_plot: bool = False,
    epsilon: float = 1e-12,
) -> float:
    """Estimate KL(real || fake) of a derived series using shared histograms.

    The chosen metric series is computed for both inputs, non-finite samples
    are dropped, and both series are histogrammed on the same edges spanning
    their pooled range. ``epsilon`` is added to every bin count before
    normalising so empty bins do not produce ``log(0)``.

    Args:
        real: 2D array [T, F] of reference data.
        fake: 2D array [T, F] of generated data.
        metric: ``"spread"`` (ask - bid) or ``"mpr"`` (log midprice returns).
        bins: number of histogram bins.
        show_plot: if True, plot both smoothed histograms with matplotlib.
        epsilon: additive smoothing applied to each bin count.

    Returns:
        The KL divergence as a float, clamped below at 0.0.

    Raises:
        ValueError: if either input is not 2D or ``metric`` is unknown.
    """
    real = np.asarray(real)
    fake = np.asarray(fake)
    if real.ndim != 2 or fake.ndim != 2:
        raise ValueError("Inputs must be 2D arrays [T, F].")

    if metric == "spread":
        r_series = _spread(real)
        f_series = _spread(fake)
    elif metric == "mpr":
        r_series = _midprice_returns(real)
        f_series = _midprice_returns(fake)
    else:
        raise ValueError("metric must be 'spread' or 'mpr'.")

    # NaN/inf samples cannot be binned and would make the range non-finite.
    r_series = r_series[np.isfinite(r_series)]
    f_series = f_series[np.isfinite(f_series)]

    # Span exactly the pooled data; fall back to 0 only when nothing is left.
    pooled = np.concatenate((r_series, f_series))
    if pooled.size:
        lo = float(pooled.min())
        hi = float(pooled.max())
    else:
        lo = hi = 0.0

    # if degenerate, expand a hair to avoid zero-width bins
    if hi <= lo:
        hi = lo + max(1e-6, abs(lo) * 1e-6)

    r_hist, edges = np.histogram(r_series, bins=bins, range=(lo, hi), density=False)
    f_hist, _ = np.histogram(f_series, bins=edges, density=False)

    # convert to probability masses with smoothing
    r_p = r_hist.astype(np.float64) + epsilon
    f_p = f_hist.astype(np.float64) + epsilon
    r_p /= r_p.sum()
    f_p /= f_p.sum()

    # KL(real || fake) = sum p * log(p/q)
    mask = r_p > 0  # should be true after smoothing, but keep for safety
    kl = np.sum(r_p[mask] * (np.log(r_p[mask]) - np.log(f_p[mask])))

    if show_plot:
        centers = 0.5 * (edges[:-1] + edges[1:])
        plt.plot(centers, r_p, label="real")
        plt.plot(centers, f_p, label="fake")
        plt.title(f"Histogram ({metric}); KL={kl:.4g}")
        plt.legend()
        plt.show()

    # numerical guard: KL should be >= 0
    return float(max(kl, 0.0))

0 commit comments

Comments
 (0)