1+ from __future__ import annotations
2+ from typing import Iterable , Literal , Tuple
3+
4+ import numpy as np
5+ from numpy .typing import NDArray
6+ import matplotlib .pyplot as plt
7+
8+ Metric = Literal ["spread" , "mpr" ]
9+
10+ def extract_seq_lengths (
11+ sequences : Iterable [NDArray [np .floating ]]
12+ ) -> Tuple [NDArray [np .int32 ], int ]:
13+ lengths = np .asarray ([int (s .shape [0 ]) for s in sequences ], dtype = np .int32 )
14+ return lengths , int (lengths .max (initial = 0 ))
15+
16+ def sample_noise (
17+ batch_size : int ,
18+ z_dim : int ,
19+ seq_len : int ,
20+ * ,
21+ mean : float | None = None ,
22+ std : float | None = None ,
23+ rng : np .random .Generator | None = None ,
24+ ) -> NDArray [np .float32 ]:
25+ if rng is None :
26+ rng = np .random .default_rng ()
27+
28+ if (mean is None ) ^ (std is None ):
29+ raise ValueError ("Provide both mean and std, or neither" )
30+
31+ if mean is None and std is None :
32+ out = rng .random ((batch_size , seq_len , z_dim ), dtype = np .float32 )
33+ else :
34+ interval = float (std ) * np .sqrt (12.0 )
35+ lo = float (mean ) - interval / 2.0
36+ hi = float (mean ) + interval / 2.0
37+ out = rng .uniform (lo , hi , size = (batch_size , seq_len , z_dim )).astype (np .float32 )
38+
39+ return out
40+
41+ def minmax_scale (
42+ data : NDArray [np .floating ],
43+ epsilon : float = 1e-7
44+ )-> Tuple [NDArray [np .float32 ], NDArray [np .float32 ], NDArray [np .float32 ]]:
45+ if data .ndim != 3 :
46+ raise ValueError (f"Expected data with 3 dimensions [N, T, F], got shape { data .shape } " )
47+
48+ fmin = np .min (data , axis = (0 , 1 )).astype (np .float32 )
49+ fmax = np .max (data , axis = (0 , 1 )).astype (np .float32 )
50+ denom = (fmax - fmin ).astype (np .float32 )
51+
52+ norm = (data .astype (np .float32 ) - fmin ) / (denom + epsilon )
53+ return norm , fmin , fmax
54+
55+ def minmax_inverse (
56+ norm : NDArray [np .floating ],
57+ fmin : NDArray [np .floating ],
58+ fmax : NDArray [np .floating ],
59+ ) -> NDArray [np .float32 ]:
60+ """
61+ Inverse of `minmax_scale`.
62+
63+ Args:
64+ norm: scaled data [N,T,F] or [...,F]
65+ fmin: per-feature minima [F]
66+ fmax: per-feature maxima [F]
67+
68+ Returns:
69+ original-scale data, float32
70+ """
71+ fmin = np .asarray (fmin , dtype = np .float32 )
72+ fmax = np .asarray (fmax , dtype = np .float32 )
73+ return norm .astype (np .float32 ) * (fmax - fmin ) + fmin
74+
75+ def _spread (series : NDArray [np .floating ]) -> NDArray [np .float64 ]:
76+ """
77+ Compute spread = best_ask - best_bid from a 2D array [T, F] with
78+ columns: best ask at index 0 and best bid at index 2.
79+ """
80+ if series .ndim != 2 or series .shape [1 ] < 3 :
81+ raise ValueError ("Expected shape [T, >=3]; columns 0 (ask) and 2 (bid) required." )
82+ return (series [:, 0 ] - series [:, 2 ]).astype (np .float64 )
83+
84+
85+ def _midprice_returns (series : NDArray [np .floating ]) -> NDArray [np .float64 ]:
86+ """
87+ Compute log midprice returns from a 2D array [T, F] with ask at 0 and bid at 2.
88+ """
89+ if series .ndim != 2 or series .shape [1 ] < 3 :
90+ raise ValueError ("Expected shape [T, >=3]; columns 0 (ask) and 2 (bid) required." )
91+ mid = 0.5 * (series [:, 0 ] + series [:, 2 ])
92+ # avoid log(0)
93+ mid = np .clip (mid , a_min = np .finfo (np .float64 ).tiny , a_max = None )
94+ r = np .log (mid [1 :]) - np .log (mid [:- 1 ])
95+ return r .astype (np .float64 )
96+
97+ def kl_divergence_hist (
98+ real : NDArray [np .floating ],
99+ fake : NDArray [np .floating ],
100+ metric : Literal ["spread" , "mpr" ] = "spread" ,
101+ * ,
102+ bins : int = 100 ,
103+ show_plot : bool = False ,
104+ epsilon : float = 1e-12
105+ ) -> float :
106+ if real .ndim != 2 or fake .ndim != 2 :
107+ raise ValueError ("Inputs must be 2D arrays [T, F]." )
108+
109+ if metric == "spread" :
110+ r_series = _spread (real )
111+ f_series = _spread (fake )
112+ elif metric == "mpr" :
113+ r_series = _midprice_returns (real )
114+ f_series = _midprice_returns (fake )
115+ else :
116+ raise ValueError ("metric must be 'spread' or 'mpr'." )
117+
118+ lo = float (min (r_series .min (initial = 0.0 ), f_series .min (initial = 0.0 )))
119+ hi = float (max (r_series .max (initial = 0.0 ), f_series .max (initial = 0.0 )))
120+
121+ # if degenerate, expand a hair to avoid zero-width bins
122+ if not np .isfinite (lo ) or not np .isfinite (hi ) or hi <= lo :
123+ hi = lo + 1e-6
124+
125+ r_hist , edges = np .histogram (r_series , bins = bins , range = (lo , hi ), density = False )
126+ f_hist , _ = np .histogram (f_series , bins = edges , density = False )
127+
128+ # convert to probability masses with smoothing
129+ r_p = (r_hist .astype (np .float64 ) + epsilon )
130+ f_p = (f_hist .astype (np .float64 ) + epsilon )
131+ r_p /= r_p .sum ()
132+ f_p /= f_p .sum ()
133+
134+ # KL(real || fake) = sum p * log(p/q)
135+ mask = r_p > 0 # should be true after smoothing, but keep for safety
136+ kl = np .sum (r_p [mask ] * (np .log (r_p [mask ]) - np .log (f_p [mask ])))
137+
138+ if show_plot :
139+ centers = 0.5 * (edges [:- 1 ] + edges [1 :])
140+ plt .plot (centers , r_p , label = "real" )
141+ plt .plot (centers , f_p , label = "fake" )
142+ plt .title (f"Histogram ({ metric } ); KL={ kl :.4g} " )
143+ plt .legend ()
144+ plt .show ()
145+
146+ # numerical guard: KL should be >= 0
147+ return float (max (kl , 0.0 ))
0 commit comments