
Commit fe608da

Merge pull request #4 from trustyai-explainability/hf_detector_tests
Further tests for the HF detector class
2 parents 92217c7 + e91f909 · commit fe608da

7 files changed (+408, -3 lines changed)
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
# third-party imports
import pytest
import torch
from unittest.mock import Mock
from transformers import PreTrainedTokenizer

# relative imports
from detectors.huggingface.detector import Detector


class TestGetProbabilities:
    @pytest.fixture
    def detector(self):
        detector = Detector.__new__(Detector)
        detector.tokenizer = Mock(spec=PreTrainedTokenizer)
        return detector

    def test_normal_case(self, detector):
        # Setup
        logprobs = [
            Mock(values=torch.tensor([[0.0, -1.0]]), indices=torch.tensor([[1, 2]]))
        ]
        detector.tokenizer.convert_ids_to_tokens.side_effect = lambda x: (
            "safe" if x == 1 else "unsafe"
        )
        result = detector.get_probabilities(logprobs, "safe", "unsafe")
        assert isinstance(result, torch.Tensor)
        assert len(result) == 2
        assert torch.allclose(result.sum(), torch.tensor(1.0))
        assert result[0] > result[1]  # Safe token has higher probability

    def test_empty_logprobs(self, detector):
        result = detector.get_probabilities([], "safe", "unsafe")
        assert torch.allclose(result, torch.tensor([0.5, 0.5]))

    def test_very_small_probabilities(self, detector):
        logprobs = [
            Mock(values=torch.tensor([[-50.0, -50.0]]), indices=torch.tensor([[1, 2]]))
        ]
        detector.tokenizer.convert_ids_to_tokens.side_effect = lambda x: (
            "safe" if x == 1 else "unsafe"
        )
        result = detector.get_probabilities(logprobs, "safe", "unsafe")
        assert torch.allclose(result.sum(), torch.tensor(1.0))
        assert torch.allclose(result[0], result[1])  # Should be equal probabilities

    def test_case_sensitivity(self, detector):
        logprobs = [Mock(values=torch.tensor([[0.0]]), indices=torch.tensor([[1]]))]
        detector.tokenizer.convert_ids_to_tokens.return_value = "SAFE"
        result = detector.get_probabilities(logprobs, "safe", "unsafe")
        assert result[0] > result[1]

    def test_invalid_tokens(self, detector):
        logprobs = [Mock(values=torch.tensor([[0.0]]), indices=torch.tensor([[1]]))]
        detector.tokenizer.convert_ids_to_tokens.return_value = "invalid"
        result = detector.get_probabilities(logprobs, "safe", "unsafe")
        assert torch.allclose(result, torch.tensor([0.5, 0.5]))
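Taken together, these cases pin down the contract of get_probabilities: a two-element tensor over (safe, unsafe) that sums to 1, case-insensitive token matching, and a uniform [0.5, 0.5] fallback when no logprobs are supplied or neither token is found. The following is a minimal sketch consistent with that contract, inferred from the tests; it is not the actual implementation in detectors.huggingface.detector.

import torch

# Hypothetical sketch inferred from the tests above, NOT the actual
# Detector.get_probabilities implementation.
def get_probabilities(self, logprobs, safe_token, unsafe_token):
    safe_logprob, unsafe_logprob = None, None
    for position in logprobs:
        # Each entry carries top-k values/indices shaped (1, k), as in the mocks
        for logprob, token_id in zip(position.values[0], position.indices[0]):
            token = self.tokenizer.convert_ids_to_tokens(token_id.item())
            if token.lower() == safe_token.lower():
                safe_logprob = logprob.item()
            elif token.lower() == unsafe_token.lower():
                unsafe_logprob = logprob.item()
    if safe_logprob is None and unsafe_logprob is None:
        # No logprobs at all, or neither token observed: uniform fallback
        return torch.tensor([0.5, 0.5])
    # A missing token gets -inf so its probability is 0 after the softmax
    pair = torch.tensor([
        safe_logprob if safe_logprob is not None else float("-inf"),
        unsafe_logprob if unsafe_logprob is not None else float("-inf"),
    ])
    return torch.softmax(pair, dim=0)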

tests/test_detector_initalisation.py renamed to tests/detectors/huggingface/test_method_initialize_model.py

Lines changed: 4 additions & 2 deletions
@@ -1,4 +1,4 @@
-# global imports
+# third-party imports
 import os
 import pytest
 

@@ -12,7 +12,9 @@ def setup_environment():
     """
     Setup the required environment variable for the model directory.
     """
-    os.environ["MODEL_DIR"] = os.path.join(os.path.dirname(__file__), "dummy_models")
+    current_dir = os.path.dirname(__file__)
+    parent_dir = os.path.dirname(os.path.dirname(current_dir))
+    os.environ["MODEL_DIR"] = os.path.join(parent_dir, "dummy_models")
 
 
 # tests to check the model initialization
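The change reflects the file's move from tests/ into tests/detectors/huggingface/: the fixture now climbs two directory levels so MODEL_DIR still resolves to the shared dummy-model directory. Roughly (the concrete layout is presumed from the renames in this commit, not stated in the diff):

import os

# Presumed post-rename location of the test file:
current_dir = "tests/detectors/huggingface"                 # os.path.dirname(__file__)
parent_dir = os.path.dirname(os.path.dirname(current_dir))  # -> "tests"
print(os.path.join(parent_dir, "dummy_models"))             # tests/dummy_models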
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# third-party imports
import numpy as np
import pytest
import torch
from unittest.mock import Mock, patch
from transformers import PreTrainedTokenizer

# relative imports
from detectors.huggingface.detector import Detector, ContentAnalysisResponse


class TestDetectorParseOutput:
    @pytest.fixture
    def detector(self):
        detector = Detector.__new__(Detector)
        detector.tokenizer = Mock(spec=PreTrainedTokenizer)
        detector.get_probabilities = Mock()
        return detector

    @pytest.fixture
    def mock_output(self):
        output = Mock()
        output.sequences = torch.tensor([[1, 2, 3, 4]])
        output.scores = [torch.randn(1, 5)]
        return output

    @pytest.fixture
    def default_params(self):
        return {
            "input_len": 2,
            "nlogprobs": 5,
            "safe_token": "Safe",
            "unsafe_token": "Unsafe",
        }

    def test_parse_output_safe_classification(
        self, detector, mock_output, default_params
    ):
        """Test safe token classification with probabilities"""
        detector.tokenizer.decode.return_value = "safe"
        detector.get_probabilities.return_value = torch.tensor([0.7, 0.3])

        label, prob = detector.parse_output(output=mock_output, **default_params)

        assert label == "Safe"
        assert isinstance(prob, float)
        np.testing.assert_almost_equal(prob, 0.3, decimal=5)

    def test_parse_output_unsafe_classification(
        self, detector, mock_output, default_params
    ):
        """Test unsafe token classification with probabilities"""
        detector.tokenizer.decode.return_value = "unsafe"
        detector.get_probabilities.return_value = torch.tensor([0.3, 0.7])

        label, prob = detector.parse_output(output=mock_output, **default_params)

        assert label == "Unsafe"
        assert isinstance(prob, float)
        np.testing.assert_almost_equal(prob, 0.7, decimal=5)

    def test_parse_output_failed_classification(
        self, detector, mock_output, default_params
    ):
        """Test when decoded token doesn't match safe/unsafe"""
        detector.tokenizer.decode.return_value = "invalid"
        detector.get_probabilities.return_value = torch.tensor([0.5, 0.5])

        label, prob = detector.parse_output(output=mock_output, **default_params)

        assert label == "failed"
        assert prob == 0.5

    def test_parse_output_empty_sequence(self, detector, default_params):
        """Test with empty sequence"""
        mock_output = Mock()
        mock_output.sequences = torch.tensor([[]])
        detector.tokenizer.decode.return_value = ""

        label, prob = detector.parse_output(
            output=mock_output,
            input_len=0,
            nlogprobs=0,
            safe_token="Safe",
            unsafe_token="Unsafe",
        )

        assert label == "failed"
        assert prob is None
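Together the four cases fix parse_output's observable behaviour: the generated tail of the sequence is decoded and matched case-insensitively against the configured safe/unsafe tokens, the returned probability is the unsafe-class entry from get_probabilities, and anything unrecognised (including an empty generation with nlogprobs=0) yields the label "failed". A rough reconstruction consistent with these tests follows; the real method in detectors.huggingface.detector may differ in detail.

import torch

# Hypothetical reconstruction from the tests above, NOT the actual method.
def parse_output(self, output, input_len, nlogprobs, safe_token, unsafe_token):
    prob = None
    if nlogprobs > 0:
        # Top-k logprobs per generated position, fed to get_probabilities;
        # index 1 is the unsafe-class probability (matches the 0.3 / 0.7 asserts)
        logprobs = [torch.topk(scores, nlogprobs) for scores in output.scores]
        prob = self.get_probabilities(logprobs, safe_token, unsafe_token)[1].item()
    decoded = self.tokenizer.decode(
        output.sequences[0][input_len:], skip_special_tokens=True
    ).strip()
    if decoded.lower() == safe_token.lower():
        return safe_token, prob
    if decoded.lower() == unsafe_token.lower():
        return unsafe_token, prob
    return "failed", prob  # unmatched or empty generation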
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# third-party imports
import os
import pytest
import torch
from unittest.mock import Mock, patch

# relative imports
from detectors.huggingface.detector import Detector, ContentAnalysisResponse


class MockGraniteOutput:
    def __init__(self):
        self.sequences = torch.tensor([[1, 2, 3, 4]])
        self.scores = [torch.randn(1, 5)]


@pytest.fixture
def setup_environment():
    """Setup the required environment variable for the model directory."""
    current_dir = os.path.dirname(__file__)
    parent_dir = os.path.dirname(os.path.dirname(current_dir))
    os.environ["MODEL_DIR"] = os.path.join(parent_dir, "dummy_models")


class TestDetector:
    @pytest.fixture(autouse=True)
    def setup(self, setup_environment):
        pass

    @pytest.fixture
    def detector_instance(self):
        with patch.dict("os.environ", {"MODEL_DIR": "/dummy/path"}):
            detector = Detector.__new__(Detector)

            detector.tokenizer = Mock()
            detector.tokenizer.apply_chat_template = Mock(
                return_value=torch.tensor([[1, 2, 3]])
            )
            detector.tokenizer.decode = Mock(return_value="Yes")

            detector.model = Mock()
            detector.model.device = torch.device("cpu")
            detector.model.generate = Mock(return_value=MockGraniteOutput())

            detector.model_name = "causal_lm"
            detector.is_causal_lm = True
            detector.cuda_device = None
            detector.risk_names = ["harm", "bias"]

            return detector

    def validate_results(self, results, input_text, detector):
        """Helper method to validate the classification results"""
        assert len(results) == len(detector.risk_names)

        for result in results:
            expected_fields = [
                "start",
                "end",
                "detection",
                "detection_type",
                "score",
                "sequence_classification",
                "sequence_probability",
                "token_classifications",
                "token_probabilities",
                "text",
                "evidences",
            ]

            for field in expected_fields:
                assert hasattr(
                    result, field
                ), f"Missing '{field}' in ContentAnalysisResponse"

            assert isinstance(result, ContentAnalysisResponse)
            assert isinstance(result.start, int)
            assert isinstance(result.end, int)
            assert isinstance(result.detection, str)
            assert isinstance(result.detection_type, str)
            assert isinstance(result.score, float)
            assert isinstance(result.sequence_classification, str)
            assert isinstance(result.sequence_probability, float)
            assert isinstance(result.text, str)
            assert isinstance(result.evidences, list)

            assert 0 <= result.start <= len(input_text)
            assert 0 <= result.end <= len(input_text)
            assert 0.0 <= result.score <= 1.0
            assert 0.0 <= result.sequence_probability <= 1.0
            assert result.sequence_classification in detector.risk_names

    def test_process_causal_lm_single_short_input(self, detector_instance):
        text = "This is a test."
        results = detector_instance.process_causal_lm(text)
        self.validate_results(results, text, detector_instance)

    def test_process_causal_lm_single_long_input(self, detector_instance):
        text = "This is a test." * 1_000
        results = detector_instance.process_causal_lm(text)
        self.validate_results(results, text, detector_instance)

    def test_process_causal_lm_single_empty_input(self, detector_instance):
        text = ""
        results = detector_instance.process_causal_lm(text)
        self.validate_results(results, text, detector_instance)
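Because the tokenizer, model, and generation output are all mocked, this suite exercises process_causal_lm end to end without downloading any model, so it is cheap to run locally. A minimal runner (the test path is assumed from the renames in this commit):

import pytest

if __name__ == "__main__":
    # Run only the HuggingFace detector test suite, verbosely
    raise SystemExit(pytest.main(["-v", "tests/detectors/huggingface"]))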

tests/test_detector_process.py renamed to tests/detectors/huggingface/test_method_process_sequence_classification.py

Lines changed: 3 additions & 1 deletion
@@ -11,7 +11,9 @@ def setup_environment():
     """
     Setup the required environment variable for the model directory.
     """
-    os.environ["MODEL_DIR"] = os.path.join(os.path.dirname(__file__), "dummy_models")
+    current_dir = os.path.dirname(__file__)
+    parent_dir = os.path.dirname(os.path.dirname(current_dir))
+    os.environ["MODEL_DIR"] = os.path.join(parent_dir, "dummy_models")
 
 
 # tests to check the detector output
