diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6ee5c65..6f7940e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,7 +11,7 @@ on: jobs: lint: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Python 3.9 diff --git a/sigllm/core.py b/sigllm/core.py index 3e55407..0008002 100644 --- a/sigllm/core.py +++ b/sigllm/core.py @@ -100,8 +100,14 @@ def __repr__(self): return ('SigLLM:\n{}\nhyperparameters:\n{}\n').format(pipeline, hyperparameters) - def detect(self, data: pd.DataFrame, visualization: bool = False, **kwargs) -> pd.DataFrame: - """Detect anomalies in the given data.. + def detect( + self, + data: pd.DataFrame, + normal: pd.DataFrame = None, + visualization: bool = False, + **kwargs, + ) -> pd.DataFrame: + """Detect anomalies in the given data. If ``visualization=True``, also return the visualization outputs from the MLPipeline object. @@ -110,6 +116,10 @@ def detect(self, data: pd.DataFrame, visualization: bool = False, **kwargs) -> p data (DataFrame): Input data, passed as a ``pandas.DataFrame`` containing exactly two columns: timestamp and value. + normal (DataFrame, optional): + Normal reference data for one-shot prompting, passed as a ``pandas.DataFrame`` + containing exactly two columns: timestamp and value. If None, zero-shot + prompting is used. Default to None. visualization (bool): If ``True``, also capture the ``visualization`` named output from the ``MLPipeline`` and return it as a second @@ -125,6 +135,9 @@ def detect(self, data: pd.DataFrame, visualization: bool = False, **kwargs) -> p if not self._fitted: self._mlpipeline = self._get_mlpipeline() + if normal is not None: + kwargs['normal'] = normal + result = self._detect(self._mlpipeline.fit, data, visualization, **kwargs) self._fitted = True
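Usage sketch for the new `normal` argument (assumes the `SigLLM` constructor accepts a pipeline name, as in Orion; file names are illustrative, not part of this diff):

```python
import pandas as pd

from sigllm import SigLLM

data = pd.read_csv('signal.csv')      # columns: timestamp, value
normal = pd.read_csv('normal.csv')    # anomaly-free reference, same two columns

sigllm = SigLLM('mistral_prompter_1shot')

# `normal` is forwarded through kwargs into the MLPipeline, switching the
# prompter from zero-shot to one-shot prompting; omit it for zero-shot.
anomalies = sigllm.detect(data, normal=normal)
```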
+ """ + try: + url = None + if name.startswith('s3://'): + parts = name[5:].split('/', 1) + bucket = parts[0] + path = parts[1] + url = S3_URL.format(bucket, path) + filename = os.path.join(data_path, path.split('/')[-1]) + else: + filename = os.path.join(data_path, name + '_normal.csv') + data_path = os.path.join(data_path, os.path.dirname(name)) + + if os.path.exists(filename): + data = pd.read_csv(filename) + return data + + url = url or S3_URL.format(BUCKET, '{}_normal.csv'.format(name)) + LOGGER.info('Downloading CSV %s from %s', name, url) + + try: + data = pd.read_csv(url) + os.makedirs(data_path, exist_ok=True) + data.to_csv(filename, index=False) + return data + except Exception: + error_msg = ( + f'Could not download or find normal file for {name}. ' + f'Please ensure the file exists at {filename} or can be ' + f'downloaded from {url}' + ) + LOGGER.error(error_msg) + raise FileNotFoundError(error_msg) + + except Exception as e: + error_msg = f'Error processing normal file for {name}: {str(e)}' + LOGGER.error(error_msg) + raise FileNotFoundError(error_msg) + + +def load_normal(name, timestamp_column=None, value_column=None, start=None, end=None): + """Load normal data from file or download if needed. + + Args: + name (str): + Name or path of the normal data. + timestamp_column (str or int): + Column index or name for timestamp. + value_column (str or int): + Column index or name for values. + start (int or timestamp): + Optional. If specified, this will be start of the sub-sequence. + end (int or timestamp): + Optional. If specified, this will be end of the sub-sequence. + + Returns: + pandas.DataFrame: + Loaded subsequence with `timestamp` and `value` columns. + """ + if os.path.isfile(name): + data = load_csv(name, timestamp_column, value_column) + else: + data = download_normal(name) + + data = format_csv(data) + + # handle start or end is specified + if start or end: + if any(data.index.isin([start, end])): + data = data.iloc[start:end] + else: + mask = True + if start is not None: + mask &= data[timestamp_column] >= start + if end is not None: + mask &= data[timestamp_column] <= end + data = data[mask] + + return data diff --git a/sigllm/pipelines/prompter/mistral_prompter.json b/sigllm/pipelines/prompter/mistral_prompter.json index 0bc3e10..a1a5bb7 100644 --- a/sigllm/pipelines/prompter/mistral_prompter.json +++ b/sigllm/pipelines/prompter/mistral_prompter.json @@ -31,7 +31,8 @@ }, "sigllm.primitives.prompting.huggingface.HF#1": { "name": "mistralai/Mistral-7B-Instruct-v0.2", - "samples": 10 + "samples": 10, + "restrict_tokens": true }, "sigllm.primitives.prompting.anomalies.find_anomalies_in_windows#1": { "alpha": 0.4 diff --git a/sigllm/pipelines/prompter/mistral_prompter_0shot.json b/sigllm/pipelines/prompter/mistral_prompter_0shot.json new file mode 100644 index 0000000..40188e0 --- /dev/null +++ b/sigllm/pipelines/prompter/mistral_prompter_0shot.json @@ -0,0 +1,74 @@ +{ + "primitives": [ + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate", + "sklearn.impute.SimpleImputer", + "sigllm.primitives.transformation.Float2Scalar", + "sigllm.primitives.prompting.timeseries_preprocessing.rolling_window_sequences", + "sigllm.primitives.transformation.format_as_string", + + "sigllm.primitives.prompting.huggingface.HF", + "sigllm.primitives.transformation.parse_anomaly_response", + "sigllm.primitives.transformation.format_as_integer", + "sigllm.primitives.prompting.anomalies.val2idx", + "sigllm.primitives.prompting.anomalies.find_anomalies_in_windows", + 
"sigllm.primitives.prompting.anomalies.merge_anomalous_sequences", + "sigllm.primitives.prompting.anomalies.format_anomalies" + ], + "init_params": { + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "time_column": "timestamp", + "interval": 21600, + "method": "mean" + }, + "sigllm.primitives.transformation.Float2Scalar#1": { + "decimal": 2, + "rescale": true + }, + "sigllm.primitives.prompting.timeseries_preprocessing.rolling_window_sequences#1": { + "window_size": 100, + "step_size": 40 + }, + "sigllm.primitives.transformation.format_as_string#1": { + "space": false + }, + "sigllm.primitives.prompting.huggingface.HF#1": { + "name": "mistralai/Mistral-7B-Instruct-v0.2", + "samples": 1, + "temp": 0.01 + }, + "sigllm.primitives.prompting.anomalies.find_anomalies_in_windows#1": { + "alpha": 0.4 + }, + "sigllm.primitives.prompting.anomalies.merge_anomalous_sequences#1": { + "beta": 0.5 + } + }, + "input_names": { + "sigllm.primitives.prompting.huggingface.HF#1": { + "X": "X_str" + }, + "sigllm.primitives.transformation.parse_anomaly_response#1": { + "X": "y_hat" + }, + "sigllm.primitives.transformation.format_as_integer#1": { + "X": "y_parsed" + } + }, + "output_names": { + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "index": "timestamp" + }, + "sigllm.primitives.transformation.format_as_string#1": { + "X": "X_str" + }, + "sigllm.primitives.prompting.huggingface.HF#1": { + "y": "y_hat" + }, + "sigllm.primitives.transformation.parse_anomaly_response#1": { + "X": "y_parsed" + }, + "sigllm.primitives.transformation.format_as_integer#1": { + "X": "y" + } + } +} \ No newline at end of file diff --git a/sigllm/pipelines/prompter/mistral_prompter_1shot.json b/sigllm/pipelines/prompter/mistral_prompter_1shot.json new file mode 100644 index 0000000..62dc8ce --- /dev/null +++ b/sigllm/pipelines/prompter/mistral_prompter_1shot.json @@ -0,0 +1,174 @@ +{ + "primitives": [ + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate", + "sklearn.impute.SimpleImputer", + "sigllm.primitives.transformation.Float2Scalar", + "sigllm.primitives.prompting.timeseries_preprocessing.rolling_window_sequences", + "sigllm.primitives.transformation.format_as_string", + + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate", + "sklearn.impute.SimpleImputer", + "sigllm.primitives.transformation.Float2Scalar", + "sigllm.primitives.transformation.format_as_string", + + "sigllm.primitives.prompting.huggingface.HF", + "sigllm.primitives.transformation.parse_anomaly_response", + "sigllm.primitives.transformation.format_as_integer", + "sigllm.primitives.prompting.anomalies.val2idx", + "sigllm.primitives.prompting.anomalies.find_anomalies_in_windows", + "sigllm.primitives.prompting.anomalies.merge_anomalous_sequences", + "sigllm.primitives.prompting.anomalies.format_anomalies" + ], + "init_params": { + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "time_column": "timestamp", + "interval": 21600, + "method": "mean" + }, + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#2": { + "time_column": "normal_timestamp", + "interval": 21600, + "method": "mean" + }, + "sigllm.primitives.transformation.Float2Scalar#1": { + "decimal": 2, + "rescale": true + }, + "sigllm.primitives.transformation.Float2Scalar#2": { + "decimal": 2, + "rescale": true + }, + "sigllm.primitives.prompting.timeseries_preprocessing.rolling_window_sequences#1": { + "window_size": 200, + "step_size": 40 + }, + 
"sigllm.primitives.transformation.format_as_string#1": { + "space": false + }, + "sigllm.primitives.transformation.format_as_string#2": { + "space": false, + "single": true + }, + "sigllm.primitives.prompting.huggingface.HF#1": { + "name": "mistralai/Mistral-7B-Instruct-v0.2", + "anomalous_percent": 0.5, + "samples": 1, + "temp": 0.01 + }, + "sigllm.primitives.prompting.anomalies.find_anomalies_in_windows#1": { + "alpha": 0.4 + }, + "sigllm.primitives.prompting.anomalies.merge_anomalous_sequences#1": { + "beta": 0.5 + } + }, + "input_names": { + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "X": "X", + "timestamp": "timestamp" + }, + "sklearn.impute.SimpleImputer#1": { + "X": "X_processed" + }, + "sigllm.primitives.transformation.Float2Scalar#1": { + "X": "X_imputed" + }, + "sigllm.primitives.prompting.timeseries_preprocessing.rolling_window_sequences#1": { + "X": "X_scalar" + }, + "sigllm.primitives.transformation.format_as_string#1": { + "X": "X_sequences" + }, + + + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#2": { + "X": "normal", + "timestamp": "normal_timestamp" + }, + "sklearn.impute.SimpleImputer#2": { + "X": "normal_processed" + }, + "sigllm.primitives.transformation.Float2Scalar#2": { + "X": "normal_imputed" + }, + "sigllm.primitives.transformation.format_as_string#2": { + "X": "normal_scalar" + }, + + + "sigllm.primitives.prompting.huggingface.HF#1": { + "X": "X_str", + "normal": "normal_str" + }, + "sigllm.primitives.transformation.parse_anomaly_response#1": { + "X": "y_hat" + }, + "sigllm.primitives.transformation.format_as_integer#1": { + "X": "y_parsed" + }, + "sigllm.primitives.prompting.anomalies.val2idx#1": { + "y": "y_intermediate", + "X": "X_sequences" + }, + "sigllm.primitives.prompting.anomalies.find_anomalies_in_windows#1": { + "y": "y_idx" + }, + "sigllm.primitives.prompting.anomalies.merge_anomalous_sequences#1": { + "y": "y_windows" + }, + "sigllm.primitives.prompting.anomalies.format_anomalies#1": { + "y": "y_merged" + } + }, + "output_names": { + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "X": "X_processed", + "index": "timestamp" + }, + "sklearn.impute.SimpleImputer#1": { + "X": "X_imputed" + }, + "sigllm.primitives.transformation.Float2Scalar#1": { + "X": "X_scalar" + }, + "sigllm.primitives.prompting.timeseries_preprocessing.rolling_window_sequences#1": { + "X": "X_sequences" + }, + "sigllm.primitives.transformation.format_as_string#1": { + "X": "X_str" + }, + + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#2": { + "X": "normal_processed", + "index": "normal_timestamp" + }, + "sklearn.impute.SimpleImputer#2": { + "X": "normal_imputed" + }, + "sigllm.primitives.transformation.Float2Scalar#2": { + "X": "normal_scalar" + }, + "sigllm.primitives.transformation.format_as_string#2": { + "X": "normal_str" + }, + + "sigllm.primitives.prompting.huggingface.HF#1": { + "y": "y_hat" + }, + "sigllm.primitives.transformation.parse_anomaly_response#1": { + "X": "y_parsed" + }, + "sigllm.primitives.transformation.format_as_integer#1": { + "X": "y_intermediate" + }, + "sigllm.primitives.prompting.anomalies.val2idx#1": { + "y": "y_idx" + }, + "sigllm.primitives.prompting.anomalies.find_anomalies_in_windows#1": { + "y": "y_windows" + }, + "sigllm.primitives.prompting.anomalies.merge_anomalous_sequences#1": { + "y": "y_merged" + } + } +} diff --git a/sigllm/pipelines/prompter/prompter_artificialwithanomaly.json b/sigllm/pipelines/prompter/prompter_artificialwithanomaly.json new file 
mode 100644 index 0000000..eebcc81 --- /dev/null +++ b/sigllm/pipelines/prompter/prompter_artificialwithanomaly.json @@ -0,0 +1,6 @@ +{ + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "time_column": "timestamp", + "interval": 600 + } +} diff --git a/sigllm/pipelines/prompter/prompter_realadexchange.json b/sigllm/pipelines/prompter/prompter_realadexchange.json new file mode 100644 index 0000000..6b8aac0 --- /dev/null +++ b/sigllm/pipelines/prompter/prompter_realadexchange.json @@ -0,0 +1,6 @@ +{ + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "time_column": "timestamp", + "interval": 3600 + } +} diff --git a/sigllm/pipelines/prompter/prompter_realawscloudwatch.json b/sigllm/pipelines/prompter/prompter_realawscloudwatch.json new file mode 100644 index 0000000..eebcc81 --- /dev/null +++ b/sigllm/pipelines/prompter/prompter_realawscloudwatch.json @@ -0,0 +1,6 @@ +{ + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "time_column": "timestamp", + "interval": 600 + } +} diff --git a/sigllm/pipelines/prompter/prompter_realtraffic.json b/sigllm/pipelines/prompter/prompter_realtraffic.json new file mode 100644 index 0000000..eebcc81 --- /dev/null +++ b/sigllm/pipelines/prompter/prompter_realtraffic.json @@ -0,0 +1,6 @@ +{ + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "time_column": "timestamp", + "interval": 600 + } +} diff --git a/sigllm/pipelines/prompter/prompter_realtweets.json b/sigllm/pipelines/prompter/prompter_realtweets.json new file mode 100644 index 0000000..eebcc81 --- /dev/null +++ b/sigllm/pipelines/prompter/prompter_realtweets.json @@ -0,0 +1,6 @@ +{ + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "time_column": "timestamp", + "interval": 600 + } +} diff --git a/sigllm/pipelines/prompter/prompter_smap.json b/sigllm/pipelines/prompter/prompter_smap.json new file mode 100644 index 0000000..e4fe0c1 --- /dev/null +++ b/sigllm/pipelines/prompter/prompter_smap.json @@ -0,0 +1,6 @@ +{ + "mlstars.custom.timeseries_preprocessing.time_segments_aggregate#1": { + "time_column": "timestamp", + "interval": 21600 + } +} diff --git a/sigllm/primitives/jsons/sigllm.primitives.prompting.huggingface.HF.json b/sigllm/primitives/jsons/sigllm.primitives.prompting.huggingface.HF.json index 91d0530..b78afc8 100644 --- a/sigllm/primitives/jsons/sigllm.primitives.prompting.huggingface.HF.json +++ b/sigllm/primitives/jsons/sigllm.primitives.prompting.huggingface.HF.json @@ -16,6 +16,11 @@ { "name": "X", "type": "ndarray" + }, + { + "name": "normal", + "type": "ndarray", + "default": null } ], "output": [ @@ -63,6 +68,10 @@ "padding": { "type": "int", "default": 0 + }, + "restrict_tokens": { + "type": "bool", + "default": false } } } diff --git a/sigllm/primitives/jsons/sigllm.primitives.transformation.format_as_string.json b/sigllm/primitives/jsons/sigllm.primitives.transformation.format_as_string.json index 89d18f5..faa32fa 100644 --- a/sigllm/primitives/jsons/sigllm.primitives.transformation.format_as_string.json +++ b/sigllm/primitives/jsons/sigllm.primitives.transformation.format_as_string.json @@ -4,7 +4,7 @@ "Sarah Alnegheimish ", "Linh Nguyen " ], - "description": "Transform an ndarray of scalar values to an ndarray of string.", + "description": "Format X to string(s). 
Handles both normal time series (single string) and multiple windows (list of strings).", "classifiers": { "type": "preprocessor", "subtype": "tranformer" }, @@ -34,6 +34,10 @@ "space": { "type": "bool", "default": false + }, + "single": { + "type": "bool", + "default": false } } } diff --git a/sigllm/primitives/jsons/sigllm.primitives.transformation.parse_anomaly_response.json b/sigllm/primitives/jsons/sigllm.primitives.transformation.parse_anomaly_response.json new file mode 100644 index 0000000..a7ff470 --- /dev/null +++ b/sigllm/primitives/jsons/sigllm.primitives.transformation.parse_anomaly_response.json @@ -0,0 +1,25 @@ +{ + "name": "sigllm.primitives.transformation.parse_anomaly_response", + "contributors": ["Salim Cherkaoui"], + "description": "Parse LLM responses to extract anomaly values from text format.", + "classifiers": { + "type": "transformer", + "subtype": "parser" + }, + "modalities": ["text"], + "primitive": "sigllm.primitives.transformation.parse_anomaly_response", + "produce": { + "args": [ + { + "name": "X", + "type": "ndarray" + } + ], + "output": [ + { + "name": "X", + "type": "ndarray" + } + ] + } +} \ No newline at end of file diff --git a/sigllm/primitives/prompting/anomalies.py b/sigllm/primitives/prompting/anomalies.py index d70164d..82c462f 100644 --- a/sigllm/primitives/prompting/anomalies.py +++ b/sigllm/primitives/prompting/anomalies.py @@ -35,6 +35,7 @@ def val2idx(y, X): idx_win_list.append(indices) idx_list.append(idx_win_list) idx_list = np.array(idx_list, dtype=object) + return idx_list @@ -57,7 +58,6 @@ def find_anomalies_in_windows(y, alpha=0.5): idx_list = [] for samples in y: min_vote = np.ceil(alpha * len(samples)) - # print(type(samples.tolist())) flattened_res = np.concatenate(samples.tolist()) @@ -67,6 +67,7 @@ idx_list.append(final_list) idx_list = np.array(idx_list, dtype=object) + return idx_list @@ -112,7 +113,7 @@ def format_anomalies(y, timestamp, padding_size=50): Args: y (ndarray): - A 1-dimensional array of indices. + A 1-dimensional array of indices. Can be empty if no anomalies are found. timestamp (ndarray): List of full timestamp of the signal. padding_size (int): @@ -120,8 +121,12 @@ Returns: List[Tuple]: - List of intervals (start, end, score). + List of intervals (start, end, score). Empty list if no anomalies are found. """ + # Handle empty array case + if len(y) == 0: + return [] + y = timestamp[y] # Convert list of indices into list of timestamps start, end = timestamp[0], timestamp[-1] interval = timestamp[1] - timestamp[0] @@ -151,4 +156,5 @@ merged_intervals.append(current_interval) # Append the current interval if no overlap merged_intervals = [(interval[0], interval[1], 0) for interval in merged_intervals] + return merged_intervals
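A quick illustration of the new empty-input guard in `format_anomalies` (timestamp values are arbitrary):

```python
import numpy as np

from sigllm.primitives.prompting.anomalies import format_anomalies

timestamp = np.arange(0, 1000, 10)

# The guard short-circuits before the `timestamp[y]` indexing step,
# so an empty index array yields an empty interval list immediately.
assert format_anomalies(np.array([], dtype=int), timestamp) == []
```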
""" def __init__( @@ -59,6 +61,7 @@ def __init__( raw=False, samples=10, padding=0, + restrict_tokens=False, ): self.name = name self.sep = sep @@ -68,6 +71,7 @@ def __init__( self.raw = raw self.samples = samples self.padding = padding + self.restrict_tokens = restrict_tokens self.tokenizer = AutoTokenizer.from_pretrained(self.name, use_fast=False) @@ -85,16 +89,19 @@ def __init__( self.tokenizer.add_special_tokens(special_tokens_dict) self.tokenizer.pad_token = self.tokenizer.eos_token # indicate the end of the time series - # invalid tokens - valid_tokens = [] - for number in VALID_NUMBERS: - token = self.tokenizer.convert_tokens_to_ids(number) - valid_tokens.append(token) + # Only set up invalid tokens if restriction is enabled + if self.restrict_tokens: + valid_tokens = [] + for number in VALID_NUMBERS: + token = self.tokenizer.convert_tokens_to_ids(number) + valid_tokens.append(token) - valid_tokens.append(self.tokenizer.convert_tokens_to_ids(self.sep)) - self.invalid_tokens = [ - [i] for i in range(len(self.tokenizer) - 1) if i not in valid_tokens - ] + valid_tokens.append(self.tokenizer.convert_tokens_to_ids(self.sep)) + self.invalid_tokens = [ + [i] for i in range(len(self.tokenizer) - 1) if i not in valid_tokens + ] + else: + self.invalid_tokens = None self.model = AutoModelForCausalLM.from_pretrained( self.name, @@ -104,12 +111,15 @@ def __init__( self.model.eval() - def detect(self, X, **kwargs): + def detect(self, X, normal=None, **kwargs): """Use HF to detect anomalies of a signal. Args: X (ndarray): Input sequences of strings containing signal values + normal (str, optional): + A normal reference sequence for one-shot prompting. If None, + zero-shot prompting is used. Default to None. Returns: list, list: @@ -120,31 +130,67 @@ def detect(self, X, **kwargs): max_tokens = input_length * float(self.anomalous_percent) all_responses, all_generate_ids = [], [] + # Prepare the one-shot example if provided + one_shot_message = '' + if normal is not None: + one_shot_message = PROMPTS['one_shot_prefix'] + normal + '\n\n' + for text in tqdm(X): system_message = PROMPTS['system_message'] - user_message = PROMPTS['user_message'] - message = ' '.join([system_message, user_message, text, '[RESPONSE]']) - - input_length = len(self.tokenizer.encode(message[0])) + if self.restrict_tokens: + user_message = PROMPTS['user_message'] + else: + user_message = PROMPTS['user_message_2'] + + # Combine messages with one-shot example if provided + message = ' '.join([ + system_message, + one_shot_message, + user_message, + text, + '[RESPONSE]', + ]) + + input_length = len(self.tokenizer.encode(message)) tokenized_input = self.tokenizer(message, return_tensors='pt').to('cuda') - generate_ids = self.model.generate( - **tokenized_input, - do_sample=True, - max_new_tokens=max_tokens, - temperature=self.temp, - top_p=self.top_p, - bad_words_ids=self.invalid_tokens, - renormalize_logits=True, - num_return_sequences=self.samples, - ) - - responses = self.tokenizer.batch_decode( - generate_ids[:, input_length:], - skip_special_tokens=True, - clean_up_tokenization_spaces=False, - ) + generate_kwargs = { + 'do_sample': True, + 'max_new_tokens': max_tokens, + 'temperature': self.temp, + 'top_p': self.top_p, + 'renormalize_logits': True, + 'num_return_sequences': self.samples, + } + + # Only add bad_words_ids if token restriction is enabled + if self.restrict_tokens: + generate_kwargs['bad_words_ids'] = self.invalid_tokens + + generate_ids = self.model.generate(**tokenized_input, **generate_kwargs) + + if 
diff --git a/sigllm/primitives/prompting/huggingface_messages.json b/sigllm/primitives/prompting/huggingface_messages.json index 3ad1dad..e329949 100644 --- a/sigllm/primitives/prompting/huggingface_messages.json +++ b/sigllm/primitives/prompting/huggingface_messages.json @@ -1,4 +1,6 @@ { - "system_message": "You are an exceptionally intelligent assistant that detect anomalies in time series data by listing all the anomalies.", - "user_message": "Below is a [SEQUENCE], please return the anomalies in that sequence in [RESPONSE]. Only return the numbers. [SEQUENCE]" + "system_message": "You are an expert in time series analysis. Your task is to detect anomalies in time series data.", + "user_message": "Below is a [SEQUENCE], please return the anomalies in that sequence in [RESPONSE]. Only return the numbers. [SEQUENCE]", + "user_message_2": "Below is a [SEQUENCE], analyze the following time series and identify any anomalies. If you find anomalies, provide their values in the format [first_anomaly, ..., last_anomaly]. If no anomalies are found, respond with 'no anomalies'. Be concise, do not write code, do not perform any calculations, just give your answers as told: [SEQUENCE]", + "one_shot_prefix": "Here is a normal reference of the time series: [NORMAL]" } \ No newline at end of file diff --git a/sigllm/primitives/prompting/timeseries_preprocessing.py b/sigllm/primitives/prompting/timeseries_preprocessing.py index e5d3644..fee3de9 100644 --- a/sigllm/primitives/prompting/timeseries_preprocessing.py +++ b/sigllm/primitives/prompting/timeseries_preprocessing.py @@ -37,5 +37,4 @@ out_X.append(X[start:end]) X_index.append(index[start]) start = start + step_size - return np.asarray(out_X), np.asarray(X_index), window_size, step_size diff --git a/sigllm/primitives/transformation.py b/sigllm/primitives/transformation.py index 41a98fc..b8ee151 100644 --- a/sigllm/primitives/transformation.py +++ b/sigllm/primitives/transformation.py @@ -6,34 +6,43 @@ import numpy as np -def format_as_string(X, sep=',', space=False): +def format_as_string(X, sep=',', space=False, single=False): """Format X to a list of string. - Transform a 2-D array of integers to a list of strings, - seperated by the indicated seperator and space. + Transform an array of integers to string(s), separated by the + indicated separator and space. Handles two cases: + - If single=True, treats X as a single time series (window_size, 1) + - If single=False, treats X as multiple windows (num_windows, window_size, 1) Args: sep (str): String to separate each element in X. Default to `','`. space (bool): Whether to add space between each digit in the result. Default to `False`. + single (bool): + Whether to treat X as a single time series.
If True, expects (window_size, 1) + and returns a single string. If False, expects (num_windows, window_size, 1) + and returns a list of strings. Default to `False`. Returns: - ndarray: - A list of string representation of each row. + ndarray or str: + If single=True, returns one string representation. If single=False, + returns a list of string representations for each window. """ def _as_string(x): text = sep.join(list(map(str, x.flatten()))) - if space: text = ' '.join(text) - return text - results = list(map(_as_string, X)) - - return np.array(results) + if single: + # single time series (window_size, 1) + return _as_string(X) + else: + # multiple windows (num_windows, window_size, 1) + results = list(map(_as_string, X)) + return np.array(results) def _from_string_to_integer(text, sep=',', trunc=None, errors='ignore'): @@ -74,6 +83,7 @@ def format_as_integer(X, sep=',', trunc=None, errors='ignore'): Transforms a list of list of string input as 3-D array of integers, seperated by the indicated seperator and truncated based on `trunc`. + Handles empty strings by returning empty arrays. Args: sep (str): @@ -91,7 +101,7 @@ Returns: ndarray: - An array of digits values. + An array of digit values. Empty strings yield empty arrays. """ result = list() for string_list in X: @@ -100,8 +110,11 @@ raise ValueError('Input is not a list of lists.') for text in string_list: - scalar = _from_string_to_integer(text, sep, trunc, errors) - sample.append(scalar) + if not text: # empty string + sample.append(np.array([], dtype=float)) + else: + scalar = _from_string_to_integer(text, sep, trunc, errors) + sample.append(scalar) result.append(sample) @@ -171,3 +184,44 @@ values = X * 10 ** (-decimal) return values + minimum + + +def parse_anomaly_response(X): + """Parse a list of lists of LLM responses to extract anomaly values and format them as strings. + + Args: + X (List[List[str]]): + List of lists of response texts from the LLM in the format + "Answer: no anomalies" or "Answer: [val1, val2, ..., valN]". + Values must be enclosed in square brackets. + + Returns: + List[List[str]]: + List of lists of parsed responses where each element is either + "val1,val2,...,valN" if anomalies are found, or an empty string if + no anomalies are present. + """ + + def _parse_single_response(text): + text = text.strip().lower() + + if 'no anomalies' in text or 'no anomaly' in text: + return '' + + # match a bracketed list of digits, spaces and commas + pattern = r'\[([\d\s,]+)\]' + match = re.search(pattern, text) + + if match: + values = match.group(1) + values = [val.strip() for val in values.split(',') if val.strip()] + return ','.join(values) + + return '' + + result = [] + for response_list in X: + parsed_list = [_parse_single_response(response) for response in response_list] + result.append(parsed_list) + + return result
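The round-trip the new primitive implements, using inputs mirrored from the tests below:

```python
from sigllm.primitives.transformation import parse_anomaly_response

responses = [['Answer: [123, 456]', 'Answer: no anomalies']]

# Bracketed values are joined with commas; 'no anomalies' maps to ''.
print(parse_anomaly_response(responses))  # [['123,456', '']]
```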
+ """ + + def _parse_single_response(text): + text = text.strip().lower() + + if 'no anomalies' in text or 'no anomaly' in text: + return '' + + # match anything that consists of digits and commas + pattern = r'\[([\d\s,]+)\]' + match = re.search(pattern, text) + + if match: + values = match.group(1) + values = [val.strip() for val in values.split(',') if val.strip()] + return ','.join(values) + + return '' + + result = [] + for response_list in X: + parsed_list = [_parse_single_response(response) for response in response_list] + result.append(parsed_list) + + return result diff --git a/tests/primitives/test_transformation.py b/tests/primitives/test_transformation.py index eb759d9..538ceef 100644 --- a/tests/primitives/test_transformation.py +++ b/tests/primitives/test_transformation.py @@ -9,6 +9,7 @@ _from_string_to_integer, format_as_integer, format_as_string, + parse_anomaly_response, ) @@ -45,6 +46,14 @@ def test_format_as_string_decimal(self): assert output == expected + def test_format_as_string_single(self): + data = np.array([1, 2, 3, 4, 5]) + expected = '1,2,3,4,5' + + output = format_as_string(data, single=True) + + np.testing.assert_array_equal(output, expected) + class FromStringToIntegerTest(unittest.TestCase): def test__from_string_to_integer_default(self): @@ -131,6 +140,16 @@ def test_format_as_integer_list(): np.testing.assert_equal(output, expected) +def test_format_as_integer_empty(): + data = [['']] + + expected = np.array([[np.array([], dtype=float)]]) + + output = format_as_integer(data) + + np.testing.assert_equal(output, expected) + + def test_format_as_integer_2d_shape_mismatch(): data = [['1,2,3,4,5'], ['1, 294., 3 , j34,5'], ['!232, 23,3,4,5']] @@ -146,6 +165,18 @@ def test_format_as_integer_2d_shape_mismatch(): np.testing.assert_equal(o, e) +def test_format_as_integer_mixed(): + data = [[''], ['1,2,3']] + + expected = np.array([[np.array([], dtype=float)], [np.array([1.0, 2.0, 3.0])]], dtype=object) + + output = format_as_integer(data) + + for out, exp in list(zip(output, expected)): + for o, e in list(zip(out, exp)): + np.testing.assert_equal(o, e) + + def test_format_as_integer_2d_trunc(): data = [['1,2,3,4,5'], ['1,294.,3,j34,5'], ['!232, 23,3,4,5']] @@ -311,3 +342,57 @@ def test_float2scalar_scalar2float_integration(): output = scalar2float.transform(transformed, minimum, decimal) np.testing.assert_allclose(output, expected, rtol=1e-2) + + +class ParseAnomalyResponseTest(unittest.TestCase): + def test_no_anomalies(self): + data = [['Answer: no anomalies'], ['Answer: no anomaly'], ['no anomaly, with extra']] + expected = [[''], [''], ['']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_single_anomaly(self): + data = [['Answer: [123]'], ['Answer: [456]', 'answer: [789]']] + expected = [['123'], ['456', '789']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_multiple_anomalies(self): + data = [['Answer: [123, 456, 789]'], ['Answer: [111, 222, 333]']] + expected = [['123,456,789'], ['111,222,333']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_mixed_responses(self): + data = [['Answer: no anomalies', 'Answer: [123, 456]'], ['Answer: [789]', 'no anomaly']] + expected = [['', '123,456'], ['789', '']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_different_formats(self): + data = [ + ['Answer: [123, 456]', 'Answer: [ 789 , 101 ]'], + ['Answer: [1,2,3]', 'Answer: [ 4 , 5 , 6 ]'], + ] + 
expected = [['123,456', '789,101'], ['1,2,3', '4,5,6']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_empty_responses(self): + data = [[''], ['Answer: no anomalies'], ['answer'], ['no anomly']] + expected = [[''], [''], [''], ['']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_invalid_format(self): + data = [['Answer: invalid format'], ['Answer: [123, abc]']] + expected = [[''], ['']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 0000000..8efe4d0 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""Tests for `sigllm.data` module.""" + +from datetime import datetime +from unittest.mock import patch + +import pandas as pd +import pytest + +from sigllm.data import load_normal + + +@pytest.fixture +def sample_data(): + return pd.DataFrame({ + 'timestamp': pd.date_range(start='2023-01-01', periods=10, freq='D'), + 'value': range(10), + }) + + +@patch('sigllm.data.download_normal') +@patch('sigllm.data.format_csv') +def test_load_normal_without_start_end(mock_format_csv, mock_download, sample_data): + mock_format_csv.return_value = sample_data + mock_download.return_value = sample_data + + result = load_normal('test.csv') + mock_download.assert_called_once() + pd.testing.assert_frame_equal(result, sample_data) + + +@patch('sigllm.data.download_normal') +@patch('sigllm.data.format_csv') +def test_load_normal_with_index_based_start_end(mock_format_csv, mock_download, sample_data): + mock_format_csv.return_value = sample_data + mock_download.return_value = sample_data + + result = load_normal('test.csv', start=2, end=5) + expected = sample_data.iloc[2:5] + pd.testing.assert_frame_equal(result, expected) + + result = load_normal('test.csv', start=2) + expected = sample_data.iloc[2:] + pd.testing.assert_frame_equal(result, expected) + + result = load_normal('test.csv', end=5) + expected = sample_data.iloc[:5] + pd.testing.assert_frame_equal(result, expected) + + +@patch('sigllm.data.download_normal') +@patch('sigllm.data.format_csv') +def test_load_normal_with_timestamp_based_start_end(mock_format_csv, mock_download, sample_data): + mock_format_csv.return_value = sample_data + mock_download.return_value = sample_data + + start_date = datetime(2023, 1, 3) + end_date = datetime(2023, 1, 6) + result = load_normal('test.csv', timestamp_column='timestamp', start=start_date, end=end_date) + + expected = sample_data[ + (sample_data['timestamp'] >= start_date) & (sample_data['timestamp'] <= end_date) + ] + pd.testing.assert_frame_equal(result, expected) + + result = load_normal('test.csv', timestamp_column='timestamp', start=start_date) + expected = sample_data[sample_data['timestamp'] >= start_date] + pd.testing.assert_frame_equal(result, expected) + + result = load_normal('test.csv', timestamp_column='timestamp', end=end_date) + expected = sample_data[sample_data['timestamp'] <= end_date] + pd.testing.assert_frame_equal(result, expected)
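Taken together, the new pieces compose into the following end-to-end sketch (the signal name, the `S-1_normal.csv` object on S3, and the `SigLLM` constructor signature are assumptions, not part of this diff):

```python
from orion.data import load_signal

from sigllm import SigLLM
from sigllm.data import load_normal

data = load_signal('S-1')     # hypothetical demo signal

# Downloads S-1_normal.csv from the sintel-sigllm bucket on first use,
# then reads the cached copy from sigllm/data on later calls.
normal = load_normal('S-1')

sigllm = SigLLM('mistral_prompter_1shot')
anomalies = sigllm.detect(data, normal=normal)   # one-shot prompting
anomalies = sigllm.detect(data)                  # zero-shot prompting
```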