Skip to content

Commit c775263

Browse files
authored
Refactor Core Module (#14)
* remove fit func (wip) * refactor core and test example * fix lint * edit markdown text
1 parent 08644a2 commit c775263

File tree

3 files changed

+167
-107
lines changed

3 files changed

+167
-107
lines changed

sigllm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import os
1010

11+
from sigllm.core import SigLLM # noqa: F401
12+
1113
_BASE_PATH = os.path.abspath(os.path.dirname(__file__))
1214
MLBLOCKS_PRIMITIVES = os.path.join(_BASE_PATH, 'primitives', 'jsons')
1315
MLBLOCKS_PIPELINES = tuple([

sigllm/core.py

Lines changed: 86 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,20 @@
33
"""
44
Main module.
55
6-
This is an extension to Orion's core module
6+
SigLLM is an extension to Orion's core module
77
"""
8+
import logging
89
from typing import Union
910

11+
import pandas as pd
1012
from mlblocks import MLPipeline
1113
from orion import Orion
1214

13-
from sigllm.primitives.prompting.anomalies import get_anomaly_list_within_seq, str2idx
14-
from sigllm.primitives.prompting.data import sig2str
15+
LOGGER = logging.getLogger(__name__)
16+
17+
INTERVAL_PRIMITIVE = "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1"
18+
DECIMAL_PRIMITIVE = "sigllm.primitives.transformation.Float2Scalar#1"
19+
WINDOW_SIZE_PRIMITIVE = "sigllm.primitives.forecasting.custom.rolling_window_sequences#1"
1520

1621

1722
class SigLLM(Orion):
@@ -28,49 +33,95 @@ class SigLLM(Orion):
2833
* An ``str`` with the name of a registered pipeline.
2934
* An ``MLPipeline`` instance.
3035
* A ``dict`` with an ``MLPipeline`` specification.
36+
interval (int):
37+
Number of time points between one sample and another.
38+
decimal (int):
39+
Number of decimal points to keep from the float representation.
3140
window_size (int):
3241
Size of the input window.
33-
steps (int):
34-
Number of steps ahead to forecast.
35-
3642
hyperparameters (dict):
3743
Additional hyperparameters to set to the Pipeline.
3844
"""
45+
DEFAULT_PIPELINE = 'mistral_detector'
46+
47+
def _augment_hyperparameters(self, primitive, key, value):
48+
if self._hyperparameters is None:
49+
self._hyperparameters = {
50+
primitive: {}
51+
}
52+
else:
53+
if primitive not in self._hyperparameters:
54+
self._hyperparameters[primitive] = {}
55+
56+
if value:
57+
self._hyperparameters[primitive][key] = value
3958

40-
def __init__(self, pipeline: Union[str, dict, MLPipeline] = None,
41-
hyperparameters: dict = None):
59+
def __init__(self, pipeline: Union[str, dict, MLPipeline] = None, interval: int = None,
60+
decimal: int = None, window_size: int = None, hyperparameters: dict = None):
4261
self._pipeline = pipeline or self.DEFAULT_PIPELINE
4362
self._hyperparameters = hyperparameters
4463
self._mlpipeline = self._get_mlpipeline()
4564
self._fitted = False
4665

66+
self.interval = interval
67+
self.decimal = decimal
68+
self.window_size = window_size
4769

48-
def get_anomalies(seq, msg_func, model_func, num_iters=1, alpha=0.5):
49-
"""Get LLM anomaly detection results.
70+
self._augment_hyperparameters(INTERVAL_PRIMITIVE, 'interval', interval)
71+
self._augment_hyperparameters(DECIMAL_PRIMITIVE, 'decimal', decimal)
72+
self._augment_hyperparameters(WINDOW_SIZE_PRIMITIVE, 'window_size', window_size)
5073

51-
The function get the LLM's anomaly detection and converts them into an 1D array
74+
def __repr__(self):
75+
if isinstance(self._pipeline, MLPipeline):
76+
pipeline = '\n'.join(
77+
' {}'.format(primitive) for primitive in self._pipeline.to_dict()['primitives'])
5278

53-
Args:
54-
seq (ndarray):
55-
The sequence to detect anomalies.
56-
msg_func (func):
57-
Function to create message prompt.
58-
model_func (func):
59-
Function to get LLM answer.
60-
num_iters (int):
61-
Number of times to run the same query.
62-
alpha (float):
63-
Percentage of total number of votes that an index needs to have to be
64-
considered anomalous. Default: 0.5
65-
66-
Returns:
67-
ndarray:
68-
1D array containing anomalous indices of the sequence.
69-
"""
70-
message = msg_func(sig2str(seq, space=True))
71-
res_list = []
72-
for i in range(num_iters):
73-
res = model_func(message)
74-
ano_ind = str2idx(res, len(seq))
75-
res_list.append(ano_ind)
76-
return get_anomaly_list_within_seq(res_list, alpha=alpha)
79+
elif isinstance(self._pipeline, dict):
80+
pipeline = '\n'.join(
81+
' {}'.format(primitive) for primitive in self._pipeline['primitives'])
82+
83+
else:
84+
pipeline = ' {}'.format(self._pipeline)
85+
86+
hyperparameters = None
87+
if self._hyperparameters is not None:
88+
hyperparameters = '\n'.join(
89+
' {}: {}'.format(step, value) for step, value in self._hyperparameters.items())
90+
91+
return (
92+
'SigLLM:\n{}\n'
93+
'hyperparameters:\n{}\n'
94+
).format(
95+
pipeline,
96+
hyperparameters
97+
)
98+
99+
def detect(self, data: pd.DataFrame, visualization: bool = False, **kwargs) -> pd.DataFrame:
100+
"""Detect anomalies in the given data..
101+
102+
If ``visualization=True``, also return the visualization
103+
outputs from the MLPipeline object.
104+
105+
Args:
106+
data (DataFrame):
107+
Input data, passed as a ``pandas.DataFrame`` containing
108+
exactly two columns: timestamp and value.
109+
visualization (bool):
110+
If ``True``, also capture the ``visualization`` named
111+
output from the ``MLPipeline`` and return it as a second
112+
output.
113+
114+
Returns:
115+
DataFrame or tuple:
116+
If visualization is ``False``, it returns the events
117+
DataFrame. If visualization is ``True``, it returns a
118+
tuple containing the events DataFrame followed by the
119+
visualization outputs dict.
120+
"""
121+
if not self._fitted:
122+
self._mlpipeline = self._get_mlpipeline()
123+
124+
result = self._detect(self._mlpipeline.fit, data, visualization, **kwargs)
125+
self._fitted = True
126+
127+
return result

tutorials/Simple Time Series Example.ipynb

Lines changed: 79 additions & 72 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)