Skip to content

Commit eb60d72

Browse files
committed
feat(dataset): add DataOptions CLI; robust split handling; logging; fix batch_generator
Introduce DataOptions wrapper with flags (--seq_len, --data_dir, --orderbook_filename, --no_shuffle, --keep_zero_rows, --splits, --log_level). Support ORDERBOOK_DEFAULT/SPLITS_DEFAULT fallbacks; accept proportions or cumulative cutoffs; replace prints with logging; add CLI entrypoint. Fix batch_generator index sampling and time=None handling; return constant T_mb; return windowed splits from load_data.
1 parent bc932cc commit eb60d72

File tree

3 files changed

+82
-12
lines changed

3 files changed

+82
-12
lines changed

recognition/TimeLOB_TimeGAN_49088276/src/dataset.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ class DatasetConfig:
6464
"""
6565
seq_len: int
6666
data_dir: Path = field(default_factory=lambda: Path(DATA_DIR))
67-
filename: str = ORDERBOOK_FILENAME
67+
orderbook_filename: str = ORDERBOOK_FILENAME
6868
splits: Tuple[float, float, float] = TRAIN_TEST_SPLIT
69-
shuffle: bool = True
69+
shuffle_windows: bool = True
7070
dtype: type = np.float32
7171
filter_zero_rows: bool = True
7272

@@ -75,8 +75,8 @@ def from_namespace(cls, arg: Namespace) -> "DatasetConfig":
7575
return cls(
7676
seq_len=getattr(arg, "seq_len", 128),
7777
data_dir=Path(getattr(arg, "data_dir", DATA_DIR)),
78-
filename=getattr(arg, "filename", ORDERBOOK_FILENAME),
79-
shuffle=getattr(arg, "shuffle", True),
78+
orderbook_filename=getattr(arg, "orderbook_filename", ORDERBOOK_FILENAME),
79+
shuffle_windows=getattr(arg, "shuffle_windows", True),
8080
dtype=getattr(arg, "dtype", np.float32),
8181
filter_zero_rows=getattr(arg, "filter_zero_rows", True),
8282
)
@@ -119,7 +119,7 @@ def make_windows(
119119
Window the selected split into shape (num_windows, seq_len, num_features).
120120
"""
121121
data = self._select_split(split)
122-
return self._windowize(data, self.cfg.seq_len, self.cfg.shuffle)
122+
return self._windowize(data, self.cfg.seq_len, self.cfg.shuffle_windows)
123123

124124
def dataset_windowed(
125125
self
@@ -133,7 +133,7 @@ def dataset_windowed(
133133
return train_w, val_w, test_w
134134

135135
def _read_raw(self) -> NDArray[np.int64]:
136-
path = Path(self.cfg.data_dir, self.cfg.filename)
136+
path = Path(self.cfg.data_dir, self.cfg.orderbook_filename)
137137
if not path.exists():
138138
msg = (
139139
f"{path} not found.\n"
@@ -166,6 +166,7 @@ def _split_chronological(self) -> None:
166166
self._train = self._filtered[:t_cutoff]
167167
self._val = self._filtered[t_cutoff:v_cutoff]
168168
self._test = self._filtered[v_cutoff:]
169+
169170
assert all(
170171
len(d) > 5 for d in (self._train, self._val, self._test)
171172
), "Each split must have at least 5 windows."
@@ -186,7 +187,7 @@ def _windowize(
186187
self,
187188
data: NDArray[np.float32],
188189
seq_len: int,
189-
shuffle: bool
190+
shuffle_windows: bool
190191
) -> NDArray[np.float32]:
191192
n_samples, n_features = data.shape
192193
n_windows = n_samples - seq_len + 1
@@ -196,7 +197,7 @@ def _windowize(
196197
out = np.empty((n_windows, seq_len, n_features), dtype=self.cfg.dtype)
197198
for i in range(n_windows):
198199
out[i] = data[i: i + seq_len]
199-
if shuffle:
200+
if shuffle_windows:
200201
np.random.shuffle(out)
201202
return out
202203

@@ -217,13 +218,13 @@ def batch_generator(
217218
if `time` is None, uses a constant length equal to data.shape[1] (seq_len).
218219
"""
219220
n = len(data)
220-
idx = np.random.randint(n)[:batch_size]
221+
idx = np.random.choice(n, size=batch_size, replace=True)
221222
data_mb = data[idx].astype(np.float32)
222223
if time is not None:
223-
T_mb = np.full((batch_size,), data_mb.shape[1], dtype=np.int32)
224+
t_mb = np.full((batch_size,), data_mb.shape[1], dtype=np.int32)
224225
else:
225-
T_mb = time[idx].astype(np.int32)
226-
return data_mb, T_mb
226+
t_mb = time[idx].astype(np.int32)
227+
return data_mb, t_mb
227228

228229

229230
def load_data(arg: Namespace) -> tuple[NDArray[np.float32], NDArray[np.float32], NDArray[np.float32]]:

recognition/TimeLOB_TimeGAN_49088276/src/helpers/arg2.py

Whitespace-only changes.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
Options for the entire model
3+
"""
4+
from __future__ import annotations
5+
6+
from argparse import ArgumentParser, Namespace
7+
from typing import Optional
8+
9+
import numpy as np
10+
11+
from src.helpers.constants import DATA_DIR, TRAIN_TEST_SPLIT, ORDERBOOK_FILENAME
12+
13+
# ORDERBOOK_FILENAME is imported unconditionally from src.helpers.constants
# just above, so a guarded re-import of that same name could never pick a
# different value (and its broad `except Exception` could never fire without
# the module having already failed to import). Bind the alias directly.
ORDERBOOK_DEFAULT = ORDERBOOK_FILENAME
19+
20+
class DataOptions:
    """
    Thin wrapper around argparse that produces a Namespace suitable for DatasetConfig.

    The returned Namespace carries the attribute names that
    ``DatasetConfig.from_namespace`` looks up with ``getattr``:
    ``seq_len``, ``data_dir``, ``orderbook_filename``, ``shuffle_windows``,
    ``dtype`` and ``filter_zero_rows`` (plus ``splits``).

    Usage:
        opts = DataOptions().parse()
        train_w, val_w, test_w = load_data(opts)
    """

    def __init__(self) -> None:
        """Build the argument parser; no arguments are parsed yet."""
        parser = ArgumentParser(
            prog="timeganlob_dataset",
            description="Lightweight LOBSTER preprocessing + MinMax scaling",
        )
        # Both the hyphenated and the underscored spellings are accepted so
        # these flags match the style of the other options (--data_dir, ...).
        parser.add_argument(
            "--seq-len", "--seq_len",
            dest="seq_len",
            type=int,
            default=128,
            help="Number of rows per window.",
        )
        parser.add_argument("--data_dir", type=str, default=str(DATA_DIR))
        parser.add_argument("--orderbook_filename", type=str, default=ORDERBOOK_FILENAME)
        parser.add_argument(
            "--no-shuffle", "--no_shuffle",
            dest="no_shuffle",
            action="store_true",
            help="Disable shuffling of windowed sequences"
        )
        parser.add_argument(
            "--keep_zero_rows",
            action="store_true",
            help="Do NOT filter rows containing zeros."
        )
        parser.add_argument(
            "--splits",
            type=float,
            nargs=3,
            metavar=("TRAIN", "VAL", "TEST"),
            help="Either proportions that sum to ~1.0 or cumulative cutoffs (e.g., 0.6 0.8 1.0).",
            default=None,
        )
        self._parser = parser

    def parse(self, argv: Optional[list | str] = None) -> Namespace:
        """
        Parse command-line arguments into a Namespace for DatasetConfig.

        Args:
            argv: A list of argument strings, a single command-line string
                (split with shell-like rules), or None to let argparse read
                ``sys.argv[1:]``.

        Returns:
            Namespace with ``seq_len``, ``data_dir``, ``orderbook_filename``,
            ``splits``, ``shuffle_windows``, ``dtype``, ``filter_zero_rows``
            and the raw ``keep_zero_rows`` flag.
        """
        if isinstance(argv, str):
            # parse_args would otherwise iterate the string character by character.
            import shlex

            argv = shlex.split(argv)

        args = self._parser.parse_args(argv)

        return Namespace(
            seq_len=args.seq_len,
            data_dir=args.data_dir,
            orderbook_filename=args.orderbook_filename,
            splits=tuple(args.splits) if args.splits is not None else TRAIN_TEST_SPLIT,
            shuffle_windows=not args.no_shuffle,
            dtype=np.float32,
            # DatasetConfig reads `filter_zero_rows`; filtering is on unless
            # the user explicitly asked to keep zero rows.
            filter_zero_rows=not args.keep_zero_rows,
            # Raw flag value, exposed under the CLI option's own name.
            keep_zero_rows=args.keep_zero_rows,
        )

0 commit comments

Comments
 (0)