Skip to content

Commit f4ea7f1

Browse files
author
AllenBaranov
committed
Tutorial Notebook + trunc behavior
1 parent 8159622 commit f4ea7f1

File tree

4 files changed

+669
-175
lines changed

4 files changed

+669
-175
lines changed

sigllm/primitives/formatting/json_format.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,28 @@ def _extract_d0_values(self, sample):
7676

7777
def _format_as_integer_legacy(self, X, trunc=None):
7878
"""
79-
Legacy format_as_integer behavior.
79+
Extract d0 values from parsed output.
8080
81-
- If trunc is None: returns all values (full round-trip for validation)
82-
- If trunc is set: extracts only d0 values and truncates (for pipeline)
81+
- trunc=None: return all d0 values (num_windows, num_samples, num_d0_values)
82+
- trunc=int: return 3D array (num_windows, num_samples, trunc)
8383
"""
84-
batch_rows = []
85-
for window in X:
86-
samples = []
87-
for sample in window:
88-
if trunc is None:
89-
tokens = re.findall(r'd\d+:(\d+)', sample)
90-
values = [int(v) for v in tokens]
91-
else:
92-
values = self._extract_d0_values(sample)[:trunc]
93-
samples.append(values)
94-
batch_rows.append(samples)
95-
return np.array(batch_rows, dtype=object)
84+
if trunc is None:
85+
batch_rows = []
86+
for window in X:
87+
samples = []
88+
for sample in window:
89+
samples.append(self._extract_d0_values(sample))
90+
batch_rows.append(samples)
91+
return np.array(batch_rows, dtype=object)
92+
93+
num_windows = len(X)
94+
num_samples = len(X[0]) if num_windows > 0 else 0
95+
result = np.zeros((num_windows, num_samples, trunc), dtype=int)
96+
97+
for i, window in enumerate(X):
98+
for j, sample in enumerate(window):
99+
d0_values = self._extract_d0_values(sample)
100+
for k in range(min(trunc, len(d0_values))):
101+
result[i, j, k] = d0_values[k]
102+
103+
return result

sigllm/primitives/formatting/multivariate_formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self, method_name: str, verbose: bool = False, **kwargs):
1010
self.metadata = {}
1111
self.verbose = verbose
1212

13-
if self.method_name != "persistence_control":
13+
if self.method_name != "persistence_control" and self.config.get('trunc', None) == None:
1414
test_multivariate_formatting_validity(self, verbose=verbose)
1515

1616

sigllm/primitives/formatting/utils.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,43 +20,30 @@ def test_multivariate_formatting_validity(method, verbose=False):
2020
if verbose:
2121
print("Testing multivariate formatting method validity")
2222

23-
#raw_data = create_test_data()[:, 1:]
2423
raw_data = create_test_data().to_numpy()[:, 1:]
2524
windowed_data = np.array([raw_data[i:i+15,:] for i in range(0, len(raw_data)-15, 1)])
2625
data = (1000 * windowed_data).astype(int)
2726
if verbose:
2827
print(data.shape)
2928

30-
# Temporarily disable trunc for validation (we need full round-trip)
31-
original_trunc = method.config.get('trunc')
32-
had_trunc = 'trunc' in method.config
33-
method.config['trunc'] = None
34-
35-
try:
36-
string_data = method.format_as_string(data, **method.config)
37-
LLM_mock_output = np.array(string_data).reshape(-1, 1)
38-
if verbose:
39-
print(LLM_mock_output)
40-
integer_data = method.format_as_integer(LLM_mock_output, **method.config)
41-
if verbose:
42-
print(f"Format as string output: {string_data}")
43-
44-
assert isinstance(string_data, list)
45-
assert isinstance(string_data[0], str)
46-
assert isinstance(integer_data, np.ndarray)
47-
48-
if method.method_name == "univariate_control":
49-
assert np.all(integer_data.flatten() == data[:, :, 0].flatten())
50-
else:
51-
assert np.all(integer_data.flatten() == data.flatten())
52-
53-
print("Validation suite passed")
54-
finally:
55-
# Restore original trunc value
56-
if had_trunc:
57-
method.config['trunc'] = original_trunc
58-
elif 'trunc' in method.config:
59-
del method.config['trunc']
29+
string_data = method.format_as_string(data, **method.config)
30+
LLM_mock_output = np.array(string_data).reshape(-1, 1)
31+
if verbose:
32+
print(f"LLM mock output: {LLM_mock_output}")
33+
integer_data = method.format_as_integer(LLM_mock_output, **method.config)
34+
if verbose:
35+
print(f"Format as string output: {string_data}")
36+
37+
assert isinstance(string_data, list)
38+
assert isinstance(string_data[0], str)
39+
assert isinstance(integer_data, np.ndarray)
40+
41+
if len(integer_data.flatten()) == len(data.flatten()):
42+
assert np.all(integer_data.flatten() == data.flatten())
43+
elif len(integer_data.flatten()) == len(data[:, :, 0].flatten()):
44+
assert np.all(integer_data.flatten() == data[:, :, 0].flatten())
45+
else:
46+
raise ValueError(f"Validation suite failed: Dimensions do not match")
6047

6148

6249

tutorials/pipelines/multivariate-detector-pipeline.ipynb

Lines changed: 627 additions & 128 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)