Skip to content

Commit 08644a2

Browse files
authored
Add a Quick Example (#12)
* fix primitive setup * fix lint * add helper functions * reorganize folder * fix lint * correct import order * Revert "correct import order" This reverts commit 6389c55. * fix import order * pause ubuntu/windows lint test * add error catch + format json * fix minor issues with dimension * add simple example * update notebook * update core
1 parent 72532a8 commit 08644a2

File tree

8 files changed

+376
-117
lines changed

8 files changed

+376
-117
lines changed

sigllm/pipelines/detector/mistral_detector.json

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"sklearn.impute.SimpleImputer",
55
"sigllm.primitives.transformation.Float2Scalar",
66
"sigllm.primitives.forecasting.custom.rolling_window_sequences",
7-
"sigllm.primitives.transformation.format_as_string",
7+
"sigllm.primitives.transformation.format_as_string",
88
"sigllm.primitives.forecasting.huggingface.HF",
99
"sigllm.primitives.transformation.format_as_integer",
1010
"sigllm.primitives.transformation.Scalar2Float",
@@ -36,7 +36,8 @@
3636
"steps": 5
3737
},
3838
"sigllm.primitives.transformation.format_as_integer#1": {
39-
"trunc": 1
39+
"trunc": 1,
40+
"errors": "coerce"
4041
},
4142
"sigllm.primitives.postprocessing.aggregate_rolling_window#1": {
4243
"agg": "median"
@@ -60,9 +61,9 @@
6061
"sigllm.primitives.postprocessing.aggregate_rolling_window#1": {
6162
"y": "y_hat"
6263
},
63-
"numpy.reshape#1": {
64-
"X": "y_hat"
65-
},
64+
"numpy.reshape#1": {
65+
"X": "y_hat"
66+
},
6667
"orion.primitives.timeseries_anomalies.find_anomalies#1": {
6768
"index": "target_index"
6869
}

sigllm/primitives/forecasting/huggingface.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ def __init__(self, name=DEFAULT_MODEL, sep=',', steps=1, temp=1, top_p=1,
5353
self.raw = raw
5454
self.samples = samples
5555
self.padding = padding
56-
self.max_tokens = None
57-
self.input_length = None
5856

5957
self.tokenizer = AutoTokenizer.from_pretrained(self.name, use_fast=False)
6058

@@ -105,21 +103,19 @@ def forecast(self, X, **kwargs):
105103
"""
106104
all_responses, all_probs = [], []
107105
for text in tqdm(X):
108-
x = text.flatten().tolist()
109106
tokenized_input = self.tokenizer(
110-
x,
107+
[text],
111108
return_tensors="pt"
112109
).to("cuda")
113110

114-
if self.max_tokens is None or self.input_length is None:
115-
self.input_length = tokenized_input['input_ids'].shape[1]
116-
average_length = self.input_length / len(x[0].split(','))
117-
self.max_tokens = (average_length + self.padding) * self.steps
111+
input_length = tokenized_input['input_ids'].shape[1]
112+
average_length = input_length / len(text.split(','))
113+
max_tokens = (average_length + self.padding) * self.steps
118114

119115
generate_ids = self.model.generate(
120116
**tokenized_input,
121117
do_sample=True,
122-
max_new_tokens=self.max_tokens,
118+
max_new_tokens=max_tokens,
123119
temperature=self.temp,
124120
top_p=self.top_p,
125121
bad_words_ids=self.invalid_tokens,
@@ -128,7 +124,7 @@ def forecast(self, X, **kwargs):
128124
)
129125

130126
responses = self.tokenizer.batch_decode(
131-
generate_ids[:, self.input_length:],
127+
generate_ids[:, input_length:],
132128
skip_special_tokens=True,
133129
clean_up_tokenization_spaces=False
134130
)

sigllm/primitives/jsons/sigllm.primitives.transformation.Float2Scalar.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
{
3737
"name": "minimum",
3838
"type": "float"
39+
},
40+
{
41+
"name": "decimal",
42+
"type": "int"
3943
}
4044
]
4145
},

sigllm/primitives/jsons/sigllm.primitives.transformation.Scalar2Float.json

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
"name": "minimum",
2323
"type": "float",
2424
"default": 0
25+
},
26+
{
27+
"name": "decimal",
28+
"type": "int",
29+
"default": 2
2530
}
2631
],
2732
"output": [
@@ -30,13 +35,5 @@
3035
"type": "ndarray"
3136
}
3237
]
33-
},
34-
"hyperparameters": {
35-
"fixed": {
36-
"decimal": {
37-
"type": "int",
38-
"default": 2
39-
}
40-
}
4138
}
4239
}

sigllm/primitives/transformation.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,16 @@ def format_as_string(X, sep=',', space=False):
2727
A list of string representation of each row.
2828
"""
2929
def _as_string(x):
30-
text = sep.join(list(map(str, x)))
30+
text = sep.join(list(map(str, x.flatten())))
3131

3232
if space:
3333
text = ' '.join(text)
3434

3535
return text
3636

37-
return np.apply_along_axis(_as_string, axis=1, arr=X)
37+
results = list(map(_as_string, X))
38+
39+
return np.array(results)
3840

3941

4042
def _from_string_to_integer(text, sep=',', trunc=None, errors='ignore'):
@@ -147,7 +149,7 @@ def transform(self, X):
147149

148150
values = sign * (values * 10**self.decimal).astype(int)
149151

150-
return values, self.minimum
152+
return values, self.minimum, self.decimal
151153

152154

153155
class Scalar2Float:
@@ -160,14 +162,13 @@ class Scalar2Float:
160162
105, 200, 310, 483, 500, 0 -> 1.05, 2., 3.1, 4.8342, 5, 0
161163
162164
Args:
165+
minimum (float):
166+
Bias to shift the data. Captured from Float2Scalar.
163167
decimal (int):
164168
Number of decimal points to keep from the float representation. Default to `2`.
165169
"""
166170

167-
def __init__(self, decimal=2):
168-
self.decimal = decimal
169-
170-
def transform(self, X, minimum=0):
171-
values = X * 10**(-self.decimal)
171+
def transform(self, X, minimum=0, decimal=2):
172+
values = X * 10**(-decimal)
172173

173174
return values + minimum

tests/primitives/test_transformation.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test_transform_default(self):
230230
print(converter)
231231

232232
converter.fit(data)
233-
output, minimum = converter.transform(data)
233+
output, minimum, decimal = converter.transform(data)
234234

235235
assert converter.decimal == 2
236236
assert converter.rescale is True
@@ -249,7 +249,7 @@ def test_transform_decimal_zero(self):
249249
])
250250

251251
converter.fit(data)
252-
output, minimum = converter.transform(data)
252+
output, minimum, decimal = converter.transform(data)
253253

254254
assert converter.decimal == 0
255255
assert converter.rescale is True
@@ -268,7 +268,7 @@ def test_transform_minimum_not_zero(self):
268268
])
269269

270270
converter.fit(data)
271-
output, minimum = converter.transform(data)
271+
output, minimum, decimal = converter.transform(data)
272272

273273
assert converter.decimal == 2
274274
assert converter.rescale is True
@@ -287,7 +287,7 @@ def test_transform_rescale_false(self):
287287
])
288288

289289
converter.fit(data)
290-
output, minimum = converter.transform(data)
290+
output, minimum, decimal = converter.transform(data)
291291

292292
assert converter.decimal == 2
293293
assert converter.rescale is False
@@ -306,7 +306,7 @@ def test_transform_negative(self):
306306
])
307307

308308
converter.fit(data)
309-
output, minimum = converter.transform(data)
309+
output, minimum, decimal = converter.transform(data)
310310

311311
assert converter.decimal == 2
312312
assert converter.rescale is True
@@ -325,7 +325,7 @@ def test_transform_fit_different(self):
325325
])
326326

327327
converter.fit([7, 3, 0.5])
328-
output, minimum = converter.transform(data)
328+
output, minimum, decimal = converter.transform(data)
329329

330330
assert converter.decimal == 2
331331
assert converter.rescale is True
@@ -348,12 +348,10 @@ def test_transform_default(self):
348348

349349
output = converter.transform(data)
350350

351-
assert converter.decimal == 2
352-
353351
np.testing.assert_array_equal(output, expected)
354352

355353
def test_transform_decimal_zero(self):
356-
converter = Scalar2Float(decimal=0)
354+
converter = Scalar2Float()
357355

358356
data = np.array([
359357
1, 2, 3, 4, 5, 0
@@ -362,9 +360,7 @@ def test_transform_decimal_zero(self):
362360
1., 2., 3., 4., 5., 0.
363361
])
364362

365-
output = converter.transform(data)
366-
367-
assert converter.decimal == 0
363+
output = converter.transform(data, decimal=0)
368364

369365
np.testing.assert_array_equal(output, expected)
370366

@@ -380,8 +376,6 @@ def test_transform_minimum_not_zero(self):
380376

381377
output = converter.transform(data, minimum=-1)
382378

383-
assert converter.decimal == 2
384-
385379
np.testing.assert_allclose(output, expected)
386380

387381

@@ -400,10 +394,10 @@ def test_float2scalar_scalar2float_integration():
400394
])
401395

402396
float2scalar.fit(data)
403-
transformed, minimum = float2scalar.transform(data)
397+
transformed, minimum, decimal = float2scalar.transform(data)
404398

405-
scalar2float = Scalar2Float(decimal)
399+
scalar2float = Scalar2Float()
406400

407-
output = scalar2float.transform(transformed, minimum)
401+
output = scalar2float.transform(transformed, minimum, decimal)
408402

409403
np.testing.assert_allclose(output, expected, rtol=1e-2)

tutorials/Simple Time Series Example.ipynb

Lines changed: 279 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)