diff --git a/sigllm/pipelines/prompter/mistral_prompter_0shot.json b/sigllm/pipelines/prompter/mistral_prompter_0shot.json index 40188e0..48b02c7 100644 --- a/sigllm/pipelines/prompter/mistral_prompter_0shot.json +++ b/sigllm/pipelines/prompter/mistral_prompter_0shot.json @@ -7,7 +7,7 @@ "sigllm.primitives.transformation.format_as_string", "sigllm.primitives.prompting.huggingface.HF", - "sigllm.primitives.transformation.parse_anomaly_response", + "sigllm.primitives.prompting.anomalies.parse_anomaly_response", "sigllm.primitives.transformation.format_as_integer", "sigllm.primitives.prompting.anomalies.val2idx", "sigllm.primitives.prompting.anomalies.find_anomalies_in_windows", diff --git a/sigllm/pipelines/prompter/mistral_prompter_1shot.json b/sigllm/pipelines/prompter/mistral_prompter_1shot.json index 62dc8ce..eb729a3 100644 --- a/sigllm/pipelines/prompter/mistral_prompter_1shot.json +++ b/sigllm/pipelines/prompter/mistral_prompter_1shot.json @@ -12,7 +12,7 @@ "sigllm.primitives.transformation.format_as_string", "sigllm.primitives.prompting.huggingface.HF", - "sigllm.primitives.transformation.parse_anomaly_response", + "sigllm.primitives.prompting.anomalies.parse_anomaly_response", "sigllm.primitives.transformation.format_as_integer", "sigllm.primitives.prompting.anomalies.val2idx", "sigllm.primitives.prompting.anomalies.find_anomalies_in_windows", diff --git a/sigllm/primitives/jsons/sigllm.primitives.transformation.parse_anomaly_response.json b/sigllm/primitives/jsons/sigllm.primitives.prompting.anomalies.parse_anomaly_response.json similarity index 63% rename from sigllm/primitives/jsons/sigllm.primitives.transformation.parse_anomaly_response.json rename to sigllm/primitives/jsons/sigllm.primitives.prompting.anomalies.parse_anomaly_response.json index a7ff470..13b327d 100644 --- a/sigllm/primitives/jsons/sigllm.primitives.transformation.parse_anomaly_response.json +++ b/sigllm/primitives/jsons/sigllm.primitives.prompting.anomalies.parse_anomaly_response.json @@ -1,5 +1,5 @@ { - "name": "sigllm.primitives.transformation.parse_anomaly_response", + "name": "sigllm.primitives.prompting.anomalies.parse_anomaly_response", "contributors": ["Salim Cherkaoui"], "description": "Parse LLM responses to extract anomaly values from text format.", "classifiers": { @@ -7,7 +7,7 @@ "subtype": "parser" }, "modalities": ["text"], - "primitive": "sigllm.primitives.transformation.parse_anomaly_response", + "primitive": "sigllm.primitives.prompting.anomalies.parse_anomaly_response", "produce": { "args": [ { @@ -21,5 +21,13 @@ "type": "ndarray" } ] + }, + "hyperparameters": { + "fixed": { + "interval": { + "type": "bool", + "default": false + } + } } } \ No newline at end of file diff --git a/sigllm/primitives/prompting/anomalies.py b/sigllm/primitives/prompting/anomalies.py index 82c462f..298cb76 100644 --- a/sigllm/primitives/prompting/anomalies.py +++ b/sigllm/primitives/prompting/anomalies.py @@ -5,8 +5,84 @@ This module contains functions that help filter LLMs results to get the final anomalies. """ +import ast +import re + import numpy as np +PATTERN = r'\[([\d\s,]+)\]' + + +def _clean_response(text): + text = text.strip().lower() + text = re.sub(r',+', ',', text) + + if 'no anomalies' in text or 'no anomaly' in text: + return '' + + return text + + +def _parse_list_response(text): + clean = _clean_response(text) + + # match anything that consists of digits and commas + match = re.search(PATTERN, clean) + + if match: + values = match.group(1) + values = [val.strip() for val in values.split(',') if val.strip()] + return ','.join(values) + + return '' + + +def _parse_interval_response(text): + clean = _clean_response(text) + match = re.finditer(PATTERN, clean) + + if match: + values = list() + for m in match: + interval = ast.literal_eval(m.group()) + if len(interval) == 2: + start, end = ast.literal_eval(m.group()) + values.extend(list(range(start, end + 1))) + + return values + + return [] + + +def parse_anomaly_response(X, interval=False): + """Parse a list of lists of LLM responses to extract anomaly values and format them as strings. + + Args: + X (List[List[str]]): + List of lists of response texts from the LLM in the format + "Answer: no anomalies" or "Answer: [val1, val2, ..., valN]." + values must be within brackets. + interval (bool): + Whether to parse the response as a list "Answer: [val1, val2, ..., valN]." + or list of intervals "Answer: [[s1, e1], [s2, e2], ..., [sn, en]]." + + Returns: + List[List[str]]: + List of lists of parsed responses where each element is either + "val1,val2,...,valN" if anomalies are found, or empty string if + no anomalies are present. + """ + method = _parse_list_response + if interval: + method = _parse_interval_response + + result = [] + for response_list in X: + parsed_list = [method(response) for response in response_list] + result.append(parsed_list) + + return result + def val2idx(y, X): """Convert detected anomalies values into indices. diff --git a/sigllm/primitives/transformation.py b/sigllm/primitives/transformation.py index b8ee151..5965861 100644 --- a/sigllm/primitives/transformation.py +++ b/sigllm/primitives/transformation.py @@ -184,44 +184,3 @@ def transform(self, X, minimum=0, decimal=2): values = X * 10 ** (-decimal) return values + minimum - - -def parse_anomaly_response(X): - """Parse a list of lists of LLM responses to extract anomaly values and format them as strings. - - Args: - X (List[List[str]]): - List of lists of response texts from the LLM in the format - "Answer: no anomalies" or "Answer: [val1, val2, ..., valN]." - values must be within brackets. - - Returns: - List[List[str]]: - List of lists of parsed responses where each element is either - "val1,val2,...,valN" if anomalies are found, or empty string if - no anomalies are present. - """ - - def _parse_single_response(text): - text = text.strip().lower() - - if 'no anomalies' in text or 'no anomaly' in text: - return '' - - # match anything that consists of digits and commas - pattern = r'\[([\d\s,]+)\]' - match = re.search(pattern, text) - - if match: - values = match.group(1) - values = [val.strip() for val in values.split(',') if val.strip()] - return ','.join(values) - - return '' - - result = [] - for response_list in X: - parsed_list = [_parse_single_response(response) for response in response_list] - result.append(parsed_list) - - return result diff --git a/tests/primitives/prompting/test_anomalies.py b/tests/primitives/prompting/test_anomalies.py index 4722123..e494680 100644 --- a/tests/primitives/prompting/test_anomalies.py +++ b/tests/primitives/prompting/test_anomalies.py @@ -1,12 +1,17 @@ # -*- coding: utf-8 -*- +import unittest import numpy as np from pytest import fixture from sigllm.primitives.prompting.anomalies import ( + _clean_response, + _parse_interval_response, + _parse_list_response, find_anomalies_in_windows, format_anomalies, merge_anomalous_sequences, + parse_anomaly_response, val2idx, ) @@ -100,10 +105,160 @@ def test_val2idx(anomalous_val, windows): # timestamp2interval - - def test_format_anomalies(idx_list, timestamp): expected = [(1000, 1820, 0), (5950, 6950, 0), (7390, 8390, 0), (11530, 12840, 0)] result = format_anomalies(idx_list, timestamp) assert expected == result + + +def test_clean_response_no_anomalies(): + test_cases = [ + 'no anomalies', + 'NO ANOMALIES', + ' no anomalies ', + 'There are no anomalies in this data', + 'No anomaly detected', + ' No anomaly ', + ] + for text in test_cases: + assert _clean_response(text) == '' + + +def test_clean_response_with_anomalies(): + test_cases = [ + ('[1, 2, 3]', '[1, 2, 3]'), + (' [1, 2, 3] ', '[1, 2, 3]'), + ('Anomalies found at [1, 2, 3]', 'anomalies found at [1, 2, 3]'), + ('ANOMALIES AT [1, 2, 3]', 'anomalies at [1, 2, 3]'), + ] + for input_text, expected in test_cases: + assert _clean_response(input_text) == expected + + +def test_parse_list_response_valid_cases(): + test_cases = [ + ('[1, 2, 3]', '1,2,3'), + (' [1, 2, 3] ', '1,2,3'), + ('Anomalies found at [1, 2, 3]', '1,2,3'), + ('[1,2,3]', '1,2,3'), + ('[1, 2, 3, 4, 5]', '1,2,3,4,5'), + ] + for input_text, expected in test_cases: + assert _parse_list_response(input_text) == expected + + +def test_parse_list_response_invalid_cases(): + test_cases = [ + 'no anomalies', + '[]', + '[ ]', + 'text with [no numbers]', + 'text with [letters, and, symbols]', + ' ', + ] + for text in test_cases: + assert _parse_list_response(text) == '' + + +def test_parse_list_response_edge_cases(): + test_cases = [ + ('[1,2,3,]', '1,2,3'), # trailing comma + ('[1,,2,3]', '1,2,3'), # double comma + ('[1, 2, 3], [5]', '1,2,3'), # two lists + ] + for input_text, expected in test_cases: + assert _parse_list_response(input_text) == expected + + +def test_parse_interval_response_valid_cases(): + test_cases = [ + ('[[1, 3]]', [1, 2, 3]), + (' [[1, 3]] ', [1, 2, 3]), + ('Anomalies found at [[1, 3]]', [1, 2, 3]), + ('[[1, 3], [5, 7]]', [1, 2, 3, 5, 6, 7]), + ('[[1, 3], [5, 7], [8, 9]]', [1, 2, 3, 5, 6, 7, 8, 9]), + ('[[1, 3], [4, 6],]', [1, 2, 3, 4, 5, 6]), + ('[[1, 2], [3]]', [1, 2]), + ('[[1,,3]]', [1, 2, 3]), + ('[[0, 10]]', list(range(11))), + ] + for input_text, expected in test_cases: + assert _parse_interval_response(input_text) == expected + + +def test_parse_interval_response_invalid_cases(): + test_cases = [ + '[]', + '[[]]', + 'text with [no numbers]', + '[[1]]', # single number instead of pair + '[[1, 2, 3]]', # triple instead of pair + ] + for text in test_cases: + assert _parse_interval_response(text) == [] + + +def test_parse_interval_response_multiple_matches(): + test_cases = [ + ('Found [[1, 3]] and [[5, 7]]', [1, 2, 3, 5, 6, 7]), + ('[[1, 2]] in first part and [[3, 4]] in second', [1, 2, 3, 4]), + ('Multiple intervals: [[1, 3]], [[4, 6]], [[7, 9]]', [1, 2, 3, 4, 5, 6, 7, 8, 9]), + ('[[1, 2]] and [[1, 2]] and [[1, 2]]', [1, 2, 1, 2, 1, 2]), + ] + for input_text, expected in test_cases: + assert _parse_interval_response(input_text) == expected + + +class ParseAnomalyResponseTest(unittest.TestCase): + def test_no_anomalies(self): + data = [['Answer: no anomalies'], ['Answer: no anomaly'], ['no anomaly, with extra']] + expected = [[''], [''], ['']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_single_anomaly(self): + data = [['Answer: [123]'], ['Answer: [456]', 'answer: [789]']] + expected = [['123'], ['456', '789']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_multiple_anomalies(self): + data = [['Answer: [123, 456, 789]'], ['Answer: [111, 222, 333]']] + expected = [['123,456,789'], ['111,222,333']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_mixed_responses(self): + data = [['Answer: no anomalies', 'Answer: [123, 456]'], ['Answer: [789]', 'no anomaly']] + expected = [['', '123,456'], ['789', '']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_different_formats(self): + data = [ + ['Answer: [123, 456]', 'Answer: [ 789 , 101 ]'], + ['Answer: [1,2,3]', 'Answer: [ 4 , 5 , 6 ]'], + ] + expected = [['123,456', '789,101'], ['1,2,3', '4,5,6']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_empty_responses(self): + data = [[''], ['Answer: no anomalies'], ['answer'], ['no anomly']] + expected = [[''], [''], [''], ['']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) + + def test_invalid_format(self): + data = [['Answer: invalid format'], ['Answer: [123, abc]']] + expected = [[''], ['']] + + output = parse_anomaly_response(data) + self.assertEqual(output, expected) diff --git a/tests/primitives/test_transformation.py b/tests/primitives/test_transformation.py index 538ceef..29052fd 100644 --- a/tests/primitives/test_transformation.py +++ b/tests/primitives/test_transformation.py @@ -9,7 +9,6 @@ _from_string_to_integer, format_as_integer, format_as_string, - parse_anomaly_response, ) @@ -342,57 +341,3 @@ def test_float2scalar_scalar2float_integration(): output = scalar2float.transform(transformed, minimum, decimal) np.testing.assert_allclose(output, expected, rtol=1e-2) - - -class ParseAnomalyResponseTest(unittest.TestCase): - def test_no_anomalies(self): - data = [['Answer: no anomalies'], ['Answer: no anomaly'], ['no anomaly, with extra']] - expected = [[''], [''], ['']] - - output = parse_anomaly_response(data) - self.assertEqual(output, expected) - - def test_single_anomaly(self): - data = [['Answer: [123]'], ['Answer: [456]', 'answer: [789]']] - expected = [['123'], ['456', '789']] - - output = parse_anomaly_response(data) - self.assertEqual(output, expected) - - def test_multiple_anomalies(self): - data = [['Answer: [123, 456, 789]'], ['Answer: [111, 222, 333]']] - expected = [['123,456,789'], ['111,222,333']] - - output = parse_anomaly_response(data) - self.assertEqual(output, expected) - - def test_mixed_responses(self): - data = [['Answer: no anomalies', 'Answer: [123, 456]'], ['Answer: [789]', 'no anomaly']] - expected = [['', '123,456'], ['789', '']] - - output = parse_anomaly_response(data) - self.assertEqual(output, expected) - - def test_different_formats(self): - data = [ - ['Answer: [123, 456]', 'Answer: [ 789 , 101 ]'], - ['Answer: [1,2,3]', 'Answer: [ 4 , 5 , 6 ]'], - ] - expected = [['123,456', '789,101'], ['1,2,3', '4,5,6']] - - output = parse_anomaly_response(data) - self.assertEqual(output, expected) - - def test_empty_responses(self): - data = [[''], ['Answer: no anomalies'], ['answer'], ['no anomly']] - expected = [[''], [''], [''], ['']] - - output = parse_anomaly_response(data) - self.assertEqual(output, expected) - - def test_invalid_format(self): - data = [['Answer: invalid format'], ['Answer: [123, abc]']] - expected = [[''], ['']] - - output = parse_anomaly_response(data) - self.assertEqual(output, expected)