Skip to content

Commit dc854da

Browse files
scherkao31Salim Cherkaouisarahmish
authored
Enable Unrestricted LLM Output with Parsing + Add One-Shot Anomaly Detection Support (#39)
* Core Changed to get normal behavior in pipeline * Transformation changed * anomalies.py changed * Hugginface.py changed : no restrictions token, and also has normal as input if 1-shot * Timeseries preprocessing.py * jsons files added for primitives * jsons files added for primitives * pipelines 0shot and 1shot added * add boolean for restrict_tokens in HF * good messages.json for prompt * Added load_normal in sigllm.data * Fixed load_normal in sigllm.data * Fixed lint format * Fixed lint format Ruff * Fixed from review Sarah * Fixed lint format after working on Sarah's reviews * Dataset prompter parameters * .jons removed from input names in 1_shot pipeline.json * .jons removed from input names in 1_shot pipeline.json * fix PR issues & add unittests * add unittests for parse_anomaly_response * remove unused functions * add new functionality tests * update ubuntu image * change normal->single * fix lint * swap normal -> single --------- Co-authored-by: Salim Cherkaoui <[email protected]> Co-authored-by: Sarah Alnegheimish <[email protected]>
1 parent 6c75dc5 commit dc854da

22 files changed

+781
-54
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ on:
1111

1212
jobs:
1313
lint:
14-
runs-on: ubuntu-20.04
14+
runs-on: ubuntu-latest
1515
steps:
1616
- uses: actions/checkout@v4
1717
- name: Set up Python 3.9

sigllm/core.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,14 @@ def __repr__(self):
100100

101101
return ('SigLLM:\n{}\nhyperparameters:\n{}\n').format(pipeline, hyperparameters)
102102

103-
def detect(self, data: pd.DataFrame, visualization: bool = False, **kwargs) -> pd.DataFrame:
104-
"""Detect anomalies in the given data..
103+
def detect(
104+
self,
105+
data: pd.DataFrame,
106+
normal: pd.DataFrame = None,
107+
visualization: bool = False,
108+
**kwargs,
109+
) -> pd.DataFrame:
110+
"""Detect anomalies in the given data.
105111
106112
If ``visualization=True``, also return the visualization
107113
outputs from the MLPipeline object.
@@ -110,6 +116,10 @@ def detect(self, data: pd.DataFrame, visualization: bool = False, **kwargs) -> p
110116
data (DataFrame):
111117
Input data, passed as a ``pandas.DataFrame`` containing
112118
exactly two columns: timestamp and value.
119+
normal (DataFrame, optional):
120+
Normal reference data for one-shot prompting, passed as a ``pandas.DataFrame``
121+
containing exactly two columns: timestamp and value. If None, zero-shot
122+
prompting is used. Default to None.
113123
visualization (bool):
114124
If ``True``, also capture the ``visualization`` named
115125
output from the ``MLPipeline`` and return it as a second
@@ -125,6 +135,9 @@ def detect(self, data: pd.DataFrame, visualization: bool = False, **kwargs) -> p
125135
if not self._fitted:
126136
self._mlpipeline = self._get_mlpipeline()
127137

138+
if normal is not None:
139+
kwargs['normal'] = normal
140+
128141
result = self._detect(self._mlpipeline.fit, data, visualization, **kwargs)
129142
self._fitted = True
130143

sigllm/data.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""Data Management module.
4+
5+
This module contains functions that allow downloading demo data from Amazon S3,
6+
as well as load and work with other data stored locally.
7+
"""
8+
9+
import logging
10+
import os
11+
12+
import pandas as pd
13+
from orion.data import format_csv, load_csv
14+
15+
LOGGER = logging.getLogger(__name__)
16+
17+
DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
18+
BUCKET = 'sintel-sigllm'
19+
S3_URL = 'https://{}.s3.amazonaws.com/{}'
20+
21+
22+
def download_normal(name, data_path=DATA_PATH):
    """Load the CSV with the given name from S3.

    If the CSV has never been loaded before, it will be downloaded
    from the [sintel-sigllm bucket](https://sintel-sigllm.s3.amazonaws.com) or
    the S3 bucket specified following the `s3://{bucket}/path/to/the.csv` format,
    and then cached inside the `data` folder, within the `sigllm` package
    directory, and then returned.

    Otherwise, if it has been downloaded and cached before, it will be directly
    loaded from the `sigllm/data` folder without contacting S3.

    Args:
        name (str):
            Name of the CSV to load.
        data_path (str):
            Path to store data.

    Returns:
        pandas.DataFrame:
            A pandas.DataFrame is returned containing all the data.

    Raises:
        FileNotFoundError: If the normal file doesn't exist locally and can't
            be downloaded from S3.
    """
    try:
        url = None
        if name.startswith('s3://'):
            # Explicit S3 URI: extract bucket and key, cache under the key's basename.
            parts = name[5:].split('/', 1)
            bucket = parts[0]
            path = parts[1]
            url = S3_URL.format(bucket, path)
            filename = os.path.join(data_path, path.split('/')[-1])
        else:
            # Plain name: cache as `<name>_normal.csv`, mirroring any sub-directories.
            filename = os.path.join(data_path, name + '_normal.csv')
            data_path = os.path.join(data_path, os.path.dirname(name))

        # Cached copy wins: never contact S3 if the file already exists locally.
        if os.path.exists(filename):
            return pd.read_csv(filename)

        url = url or S3_URL.format(BUCKET, '{}_normal.csv'.format(name))
        LOGGER.info('Downloading CSV %s from %s', name, url)

        try:
            data = pd.read_csv(url)
            os.makedirs(data_path, exist_ok=True)
            data.to_csv(filename, index=False)
            return data
        except Exception as download_error:
            error_msg = (
                f'Could not download or find normal file for {name}. '
                f'Please ensure the file exists at {filename} or can be '
                f'downloaded from {url}'
            )
            LOGGER.error(error_msg)
            # Chain the cause so the original download failure stays visible.
            raise FileNotFoundError(error_msg) from download_error

    except FileNotFoundError:
        # Already the final, fully-described error — don't double-wrap it below.
        raise
    except Exception as e:
        error_msg = f'Error processing normal file for {name}: {str(e)}'
        LOGGER.error(error_msg)
        raise FileNotFoundError(error_msg) from e
85+
86+
87+
def load_normal(name, timestamp_column=None, value_column=None, start=None, end=None):
    """Load normal data from file or download if needed.

    Args:
        name (str):
            Name or path of the normal data.
        timestamp_column (str or int):
            Column index or name for timestamp.
        value_column (str or int):
            Column index or name for values.
        start (int or timestamp):
            Optional. If specified, this will be start of the sub-sequence.
        end (int or timestamp):
            Optional. If specified, this will be end of the sub-sequence.

    Returns:
        pandas.DataFrame:
            Loaded subsequence with `timestamp` and `value` columns.
    """
    if os.path.isfile(name):
        data = load_csv(name, timestamp_column, value_column)
    else:
        data = download_normal(name)

    data = format_csv(data)

    # Use explicit None checks so a legitimate 0 boundary (falsy) still
    # triggers sub-sequence selection.
    if start is not None or end is not None:
        if any(data.index.isin([start, end])):
            # Boundaries match index labels: treat them as positional slice bounds.
            data = data.iloc[start:end]
        else:
            # Otherwise filter by timestamp values. `True & Series` broadcasts,
            # so `mask` stays valid when only one bound is given.
            # NOTE(review): this reads data[timestamp_column], which may be None
            # after format_csv normalizes columns — confirm against callers.
            mask = True
            if start is not None:
                mask &= data[timestamp_column] >= start
            if end is not None:
                mask &= data[timestamp_column] <= end
            data = data[mask]

    return data

sigllm/pipelines/prompter/mistral_prompter.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
},
3232
"sigllm.primitives.prompting.huggingface.HF#1": {
3333
"name": "mistralai/Mistral-7B-Instruct-v0.2",
34-
"samples": 10
34+
"samples": 10,
35+
"restrict_tokens": true
3536
},
3637
"sigllm.primitives.prompting.anomalies.find_anomalies_in_windows#1": {
3738
"alpha": 0.4
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
{
2+
"primitives": [
3+
"mlstars.custom.timeseries_preprocessing.time_segments_aggregate",
4+
"sklearn.impute.SimpleImputer",
5+
"sigllm.primitives.transformation.Float2Scalar",
6+
"sigllm.primitives.prompting.timeseries_preprocessing.rolling_window_sequences",
7+
"sigllm.primitives.transformation.format_as_string",
8+
9+
"sigllm.primitives.prompting.huggingface.HF",
10+
"sigllm.primitives.transformation.parse_anomaly_response",
11+
"sigllm.primitives.transformation.format_as_integer",
12+
"sigllm.primitives.prompting.anomalies.val2idx",
13+
"sigllm.primitives.prompting.anomalies.find_anomalies_in_windows",
14+
"sigllm.primitives.prompting.anomalies.merge_anomalous_sequences",
15+
"sigllm.primitives.prompting.anomalies.format_anomalies"
16+
],
17+
"init_params": {
18+
"mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": {
19+
"time_column": "timestamp",
20+
"interval": 21600,
21+
"method": "mean"
22+
},
23+
"sigllm.primitives.transformation.Float2Scalar#1": {
24+
"decimal": 2,
25+
"rescale": true
26+
},
27+
"sigllm.primitives.prompting.timeseries_preprocessing.rolling_window_sequences#1": {
28+
"window_size": 100,
29+
"step_size": 40
30+
},
31+
"sigllm.primitives.transformation.format_as_string#1": {
32+
"space": false
33+
},
34+
"sigllm.primitives.prompting.huggingface.HF#1": {
35+
"name": "mistralai/Mistral-7B-Instruct-v0.2",
36+
"samples": 1,
37+
"temp": 0.01
38+
},
39+
"sigllm.primitives.prompting.anomalies.find_anomalies_in_windows#1": {
40+
"alpha": 0.4
41+
},
42+
"sigllm.primitives.prompting.anomalies.merge_anomalous_sequences#1": {
43+
"beta": 0.5
44+
}
45+
},
46+
"input_names": {
47+
"sigllm.primitives.prompting.huggingface.HF#1": {
48+
"X": "X_str"
49+
},
50+
"sigllm.primitives.transformation.parse_anomaly_response#1": {
51+
"X": "y_hat"
52+
},
53+
"sigllm.primitives.transformation.format_as_integer#1": {
54+
"X": "y_parsed"
55+
}
56+
},
57+
"output_names": {
58+
"mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": {
59+
"index": "timestamp"
60+
},
61+
"sigllm.primitives.transformation.format_as_string#1": {
62+
"X": "X_str"
63+
},
64+
"sigllm.primitives.prompting.huggingface.HF#1": {
65+
"y": "y_hat"
66+
},
67+
"sigllm.primitives.transformation.parse_anomaly_response#1": {
68+
"X": "y_parsed"
69+
},
70+
"sigllm.primitives.transformation.format_as_integer#1": {
71+
"X": "y"
72+
}
73+
}
74+
}

0 commit comments

Comments
 (0)