Skip to content

Commit 8a36e1b

Browse files
committed
support load_normal
1 parent baaa05a commit 8a36e1b

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

sigllm/benchmark.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
from orion.progress import TqdmLogger
2828

2929
from sigllm import SigLLM
30-
31-
# from sigllm.data import load_normal
30+
from sigllm.data import load_normal
3231

3332
warnings.simplefilter('ignore')
3433

@@ -56,9 +55,9 @@
5655

5756
PIPELINES = {
5857
'mistral_prompter_restricted': 'mistral_prompter',
59-
# 'mistral_prompter_0shot': 'mistral_prompter_0shot',
60-
# 'mistral_prompter_1shot': 'mistral_prompter_1shot',
61-
} # ['mistral_prompter_0shot', 'mistral_prompter_1shot']
58+
'mistral_prompter_0shot': 'mistral_prompter_0shot',
59+
'mistral_prompter_1shot': 'mistral_prompter_1shot',
60+
}
6261

6362

6463
def _get_pipeline_directory(pipeline_name):
@@ -122,8 +121,8 @@ def _evaluate_signal(
122121
truth = load_anomalies(signal)
123122

124123
normal = None
125-
# if few_shot:
126-
# normal = load_normal(signal)
124+
if few_shot:
125+
normal = load_normal(signal)
127126

128127
try:
129128
LOGGER.info(

tests/test_benchmark.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,18 @@ def test__evaluate_signal_fail(self, mock_sigllm, mock_load_signal, mock_load_an
324324
self.pipeline_name, hyperparameters=self.hyperparameters
325325
)
326326

327+
@patch('sigllm.benchmark.load_normal')
327328
@patch('sigllm.benchmark.load_anomalies')
328329
@patch('sigllm.benchmark._load_signal')
329330
@patch('sigllm.benchmark.SigLLM')
330331
def test__evaluate_signal_with_few_shot(
331-
self, mock_sigllm, mock_load_signal, mock_load_anomalies
332+
self,
333+
mock_sigllm,
334+
mock_load_signal,
335+
mock_load_anomalies,
336+
mock_load_normal,
332337
):
338+
mock_load_normal.return_value = self.test_data
333339
mock_load_signal.return_value = (None, self.test_data)
334340
mock_load_anomalies.return_value = self.truth_data
335341

@@ -349,8 +355,9 @@ def test__evaluate_signal_with_few_shot(
349355

350356
assert isinstance(result, dict)
351357
self.assertEqual(result['status'], 'OK')
352-
# TODO: fix this call to make normal a pandas dataframe
353-
mock_pipeline.detect.assert_called_once_with(self.test_data, normal=None)
358+
359+
mock_pipeline.detect.assert_called_once_with(self.test_data, normal=self.test_data)
360+
mock_load_normal.assert_called_once_with(self.signal_name)
354361
mock_load_signal.assert_called_once_with(self.signal_name, self.test_split)
355362
mock_load_anomalies.assert_called_once_with(self.signal_name)
356363
mock_sigllm.assert_called_once_with(

0 commit comments

Comments
 (0)