diff --git a/AIDojoCoordinator/global_defender.py b/AIDojoCoordinator/global_defender.py index 5937ce58..c784b11a 100644 --- a/AIDojoCoordinator/global_defender.py +++ b/AIDojoCoordinator/global_defender.py @@ -62,7 +62,7 @@ def stochastic_with_threshold(self, action: Action, episode_actions:list, tw_siz temp_episode_actions.append(action.as_dict) if len(temp_episode_actions) >= tw_size: last_n_actions = temp_episode_actions[-tw_size:] - last_n_action_types = [action['type'] for action in last_n_actions] + last_n_action_types = [action['action_type'] for action in last_n_actions] # compute ratio of action type in the TW tw_ratio = last_n_action_types.count(str(action.type))/tw_size # Count how many times this exact (parametrized) action was played in episode diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 348288f7..0ab7c851 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -6,7 +6,7 @@ #python3 -m pytest tests/test_actions.py -p no:warnings -vvvv -s --full-trace python3 -m pytest tests/test_components.py -p no:warnings -vvvv -s --full-trace python3 -m pytest tests/test_game_coordinator.py -p no:warnings -vvvv -s --full-trace -# Coordinator tesst +python3 -m pytest tests/test_global_defender.py -p no:warnings -vvvv -s --full-trace #python3 -m pytest tests/test_coordinator.py -p no:warnings -vvvv -s --full-trace # run ruff check as well diff --git a/tests/test_global_defender.py b/tests/test_global_defender.py new file mode 100644 index 00000000..e47eb233 --- /dev/null +++ b/tests/test_global_defender.py @@ -0,0 +1,62 @@ +import pytest +from AIDojoCoordinator.game_components import ActionType, Action +from AIDojoCoordinator.global_defender import GlobalDefender +from unittest.mock import patch + +@pytest.fixture +def defender(): + return GlobalDefender() + +@pytest.fixture +def episode_actions(): + """Mock episode actions list.""" + return [ + Action(ActionType.ScanNetwork, {}).as_dict, + Action(ActionType.FindServices, {}).as_dict, + Action(ActionType.ScanNetwork, {}).as_dict, + Action(ActionType.FindServices, {}).as_dict, + ] + +def test_short_episode_does_not_detect(defender, episode_actions): + """Test when the episode action list is too short to make a decision.""" + action = Action(ActionType.ScanNetwork, {}) + assert not defender.stochastic_with_threshold(action, episode_actions[:2], tw_size=5) + +def test_below_threshold_does_not_trigger_detection(defender, episode_actions): + """Test when action thresholds are NOT exceeded (should return False).""" + action = Action(ActionType.ScanNetwork, {}) + assert not defender.stochastic_with_threshold(action, episode_actions, tw_size=5) + +def test_exceeding_threshold_triggers_stochastic(defender, episode_actions): + """Test when thresholds are exceeded and stochastic is triggered.""" + action = Action(ActionType.ScanNetwork, {}) + episode_actions += [action.as_dict] * 3 # Exceed threshold + + with patch.object(defender, "stochastic", return_value=True) as mock_stochastic: + result = defender.stochastic_with_threshold(action, episode_actions, tw_size=5) + mock_stochastic.assert_called_once_with("ScanNetwork") # Ensure stochastic was called + assert result # Expecting True since stochastic is triggered + +def test_repeated_episode_action_threshold(defender, episode_actions): + """Test when an action exceeds the episode repeated action threshold.""" + action = Action(ActionType.FindData, {}) + episode_actions += [action.as_dict] * 3 # Exceed repeat threshold + + with patch.object(defender, "stochastic", return_value=True) as mock_stochastic: + result = defender.stochastic_with_threshold(action, episode_actions, tw_size=5) + mock_stochastic.assert_called_once_with(ActionType.FindData) # Ensure stochastic was called + assert result # Expecting True since stochastic is triggered + +def test_other_actions_never_detected(defender, episode_actions): + """Test that actions not in any threshold lists always return False.""" + action = Action(ActionType.JoinGame, {}) + assert not defender.stochastic_with_threshold(action, episode_actions, tw_size=5) + +def test_mock_stochastic_probabilities(defender, episode_actions): + """Test stochastic function is only called when thresholds are crossed.""" + action = Action(ActionType.ScanNetwork, {}) + episode_actions += [{"action_type": str(ActionType.ScanNetwork)}] * 4 # Exceed threshold + + with patch("AIDojoCoordinator.global_defender.random", return_value=0.01): # Force detection probability + result = defender.stochastic_with_threshold(action, episode_actions, tw_size=5) + assert result # Should be True since we forced a low probability value \ No newline at end of file