diff --git a/.github/workflows/python-checks.yml b/.github/workflows/python-checks.yml index 00ff5af1..7d478401 100644 --- a/.github/workflows/python-checks.yml +++ b/.github/workflows/python-checks.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.12"] + python-version: ["3.12"] steps: - uses: actions/checkout@v3 @@ -22,7 +22,7 @@ jobs: run: | python -m pip install --upgrade pip pip install ruff pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + if [ -f pyproject.toml ]; then pip install .[dev] ; fi - name: Test with pytest run: | tests/run_all_tests.sh diff --git a/AIDojoCoordinator/coordinator.py b/AIDojoCoordinator/coordinator.py index 8fb51fb7..0d474d37 100644 --- a/AIDojoCoordinator/coordinator.py +++ b/AIDojoCoordinator/coordinator.py @@ -4,6 +4,7 @@ import asyncio from datetime import datetime import signal + from AIDojoCoordinator.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, ProtocolConfig from AIDojoCoordinator.global_defender import GlobalDefender from AIDojoCoordinator.utils.utils import observation_as_dict, get_str_hash, ConfigParser @@ -22,71 +23,84 @@ def __init__(self, actions_queue, agent_response_queues, max_connections): self.current_connections = 0 self.logger = logging.getLogger("AIDojo-AgentServer") + async def handle_agent_quit(self, peername:tuple): + """ + Helper function to handle agent disconnection. + """ + # Send a quit message to the Coordinator + self.logger.info(f"\tHandling agent quit for {peername}.") + quit_message = Action(ActionType.QuitGame, parameters={}).to_json() + await self.actions_queue.put((peername, quit_message)) + async def handle_new_agent(self, reader, writer): - async def send_data_to_agent(writer, data: str) -> None: - """ - Send the world to the agent - """ - writer.write(bytes(str(data).encode())) - - # Check if the maximum number of connections has been reached - if self.current_connections >= self.max_connections: - self.logger.info( - f"Max connections reached. Rejecting new connection from {writer.get_extra_info('peername')}" - ) - writer.close() - return - - # Increment the count of current connections - self.current_connections += 1 - - # Handle the new agent - addr = writer.get_extra_info("peername") - self.logger.info(f"New agent connected: {addr}") - # Ensure a queue exists for this agent - if addr not in self.answers_queues: - self.answers_queues[addr] = asyncio.Queue(maxsize=2) - self.logger.info(f"Created queue for agent {addr}") - + """ + Handle a new agent connection. + """ + # get the peername of the writer + peername = writer.get_extra_info("peername") + queue_created = False try: - while True: - # Step 1: Read data from the agent - data = await reader.read(ProtocolConfig.BUFFER_SIZE) - if not data: - self.logger.info(f"Agent {addr} disconnected.") - quit_message = Action(ActionType.QuitGame, parameters={}).to_json() - await self.actions_queue.put((addr, quit_message)) - break - - raw_message = data.decode().strip() - self.logger.debug(f"Handler received from {addr}: {raw_message}") - - # Step 2: Forward the message to the Coordinator - await self.actions_queue.put((addr, raw_message)) - # await asyncio.sleep(0)w - # Step 3: Get a matching response from the answers queue - response_queue = self.answers_queues[addr] - response = await response_queue.get() - self.logger.info(f"Sending response to agent {addr}: {response}") - - # Step 4: Send the response to the agent - response = str(response).encode() + ProtocolConfig.END_OF_MESSAGE - writer.write(response) - await writer.drain() + self.logger.info(f"New connection from {peername}") + # Check if the maximum number of connections has been reached + if self.current_connections < self.max_connections: + # increment the count of current connections + self.current_connections += 1 + self.logger.info(f"New agent connected: {peername}. Current connections: {self.current_connections}") + # Ensure a queue exists for this agent + if peername not in self.answers_queues: + self.answers_queues[peername] = asyncio.Queue(maxsize=2) + queue_created = True + self.logger.info(f"Created queue for agent {peername}") + # Handle the new agent + while True: + # Step 1: Read data from the agent + data = await reader.read(ProtocolConfig.BUFFER_SIZE) + if not data: + self.logger.info(f"Agent {peername} disconnected.") + await self.handle_agent_quit(peername) + break + + raw_message = data.decode().strip() + self.logger.debug(f"Handler received from {peername}: {raw_message}") + + # Step 2: Forward the message to the Coordinator + await self.actions_queue.put((peername, raw_message)) + + # Step 3: Get a matching response from the answers queue + response_queue = self.answers_queues[peername] + response = await response_queue.get() + self.logger.info(f"Sending response to agent {peername}: {response}") + + # Step 4: Send the response to the agent + response = str(response).encode() + ProtocolConfig.END_OF_MESSAGE + writer.write(response) + await writer.drain() + else: + self.logger.warning(f"Queue for agent {peername} already exists. Closing connection.") + else: + self.logger.info(f"Max connections reached. Rejecting new connection from {writer.get_extra_info('peername')}") + except ConnectionResetError: + self.logger.warning(f"Connection reset by {peername}") + await self.handle_agent_quit(peername) except asyncio.CancelledError: - self.logger.debug("Terminating by KeyboardInterrupt") + self.logger.debug("Connection handling cancelled.") + raise # Ensure the exception propagates + except Exception as e: + self.logger.error(f"Unexpected error with client {peername}: {e}") raise finally: - # Decrement the count of current connections - self.current_connections -= 1 - if addr in self.answers_queues: - self.answers_queues.pop(addr) - self.logger.info(f"Removed queue for agent {addr}") - else: - self.logger.warning(f"Queue for agent {addr} not found during cleanup.") - writer.close() - return - + try: + if peername in self.answers_queues: + # If the queue was created, remove it + if queue_created: + self.answers_queues.pop(peername) + self.logger.info(f"Removed queue for agent {peername}") + self.current_connections = max(0, self.current_connections - 1) + writer.close() + await writer.wait_closed() + except Exception: + # swallow exceptions on close to avoid crash on cleanup + pass async def __call__(self, reader, writer): await self.handle_new_agent(reader, writer) @@ -138,10 +152,14 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por self._agent_states = {} # last action played by agent (Action) self._agent_last_action = {} + # False positives per agent (due to added blocks) + self._agent_false_positives = {} # agent status dict {agent_addr: int} self._agent_rewards = {} # trajectories per agent_addr self._agent_trajectories = {} + # false_positives per agent_addr + self._agent_false_positives = {} def _spawn_task(self, coroutine, *args, **kwargs)->asyncio.Task: "Helper function to make sure all tasks are registered for proper termination" @@ -331,7 +349,7 @@ async def start_tasks(self): else: self._global_defender = None self._use_dynamic_ips = self.task_config.get_use_dynamic_addresses() - self._rewards = self.task_config.get_rewards(["step", "success", "fail"]) + self._rewards = self.task_config.get_rewards(["step", "success", "fail", "false_positive"]) self.logger.info(f"Rewards set to:{self._rewards}") self._min_required_players = self.task_config.get_required_num_players() self.logger.info(f"Min player requirement set to:{self._min_required_players}") @@ -365,7 +383,7 @@ async def run_game(self): # Read message from the queue agent_addr, message = await self._agent_action_queue.get() if message is not None: - self.logger.info(f"Coordinator received: {message}.") + self.logger.info(f"Coordinator received from agent {agent_addr}: {message}.") try: # Convert message to Action action = Action.from_json(message) self.logger.debug(f"\tConverted to: {action}.") @@ -375,20 +393,23 @@ async def run_game(self): ) match action.type: # process action based on its type case ActionType.JoinGame: - self.logger.debug(f"Start processing of ActionType.JoinGame by {agent_addr}") + self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.JoinGame by {agent_addr}") self.logger.debug(f"{action.type}, {action.type.value}, {action.type == ActionType.JoinGame}") self._spawn_task(self._process_join_game_action, agent_addr, action) case ActionType.QuitGame: - self.logger.debug(f"Start processing of ActionType.QuitGame by {agent_addr}") + self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.QuitGame by {agent_addr}") self._spawn_task(self._process_quit_game_action, agent_addr) case ActionType.ResetGame: - self.logger.debug(f"Start processing of ActionType.ResetGame by {agent_addr}") + self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.ResetGame by {agent_addr}") self._spawn_task(self._process_reset_game_action, agent_addr, action) case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService: - self.logger.debug(f"Start processing of {action.type} by {agent_addr}") + self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}") + self._spawn_task(self._process_game_action, agent_addr, action) + case ActionType.BlockIP: + self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}") self._spawn_task(self._process_game_action, agent_addr, action) case _: - self.logger.warning(f"Unsupported action type: {action}!") + self.logger.warning(f"About agent {agent_addr}. Unsupported action type: {action}!") self.logger.info("\tAction processing task stopped.") async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None: @@ -412,7 +433,8 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No self.agents[agent_addr] = (agent_name, agent_role) observation = self._initialize_new_player(agent_addr, new_agent_game_state) self._agent_observations[agent_addr] = observation - if len(self.agents) == self._min_required_players: + #if len(self.agents) == self._min_required_players: + if sum(1 for v in self._agent_status.values() if v == AgentStatus.PlayingWithTimeout) >= self._min_required_players: # set the event so the episde can start self._episode_start_event.set() self.logger.info("Enough players joined. Starting the episode.") @@ -467,7 +489,10 @@ async def _process_quit_game_action(self, agent_addr: tuple)->None: Outputs: None """ try: - await self.remove_agent(agent_addr, self._agent_states[agent_addr]) + if agent_addr in self._agent_states: + await self.remove_agent(agent_addr, self._agent_states[agent_addr]) + else: + self.logger.warning(f"Agent address {agent_addr} not found in _agent_states. Skipping removal.") agent_info = await self._remove_agent_from_game(agent_addr) self.logger.info(f"Agent {agent_addr} removed from the game. {agent_info}") except asyncio.CancelledError: @@ -517,6 +542,13 @@ async def _process_reset_game_action(self, agent_addr: tuple, reset_action:Actio await self._agent_response_queues[agent_addr].put(response_msg_json) async def _process_game_action(self, agent_addr: tuple, action:Action)->None: + """ + Method for processing Action of type ActionType.GameAction + Inputs: + - agent_addr (tuple) + - action (Action) + Outputs: None + """ if self._episode_ends[agent_addr]: self.logger.warning(f"Agent {agent_addr}({self.agents[agent_addr]}) is attempting to play action {action} after the end of the episode!") # agent can't play any more actions in the game @@ -616,7 +648,9 @@ async def _assign_rewards_episode_end(self): else: self._agent_rewards[agent] += self._rewards["fail"] self._agent_status[agent] = AgentStatus.Fail - # TODO Add penalty for False positives + # dicrease the reward for false positives + self.logger.debug(f"Processing false positives for agent {agent}: {self._agent_false_positives[agent]}") + self._agent_rewards[agent] -= self._agent_false_positives[agent] * self._rewards["false_positive"] # clear the episode end event self._episode_end_event.clear() # notify all waiting agents @@ -655,6 +689,7 @@ async def _reset_game(self): self._reset_requests[agent] = False self._agent_rewards[agent] = 0 self._agent_steps[agent] = 0 + self._agent_false_positives[agent] = 0 if self.agents[agent][1].lower() == "attacker": self._agent_status[agent] = AgentStatus.PlayingWithTimeout else: @@ -678,6 +713,7 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState self._agent_starting_position[agent_addr] = self._starting_positions_per_role[agent_role] self._agent_states[agent_addr] = agent_current_state self._agent_rewards[agent_addr] = 0 + self._agent_false_positives[agent_addr] = 0 if agent_role.lower() == "attacker": self._agent_status[agent_addr] = AgentStatus.PlayingWithTimeout else: @@ -712,6 +748,7 @@ async def _remove_agent_from_game(self, agent_addr): agent_info["state"] = self._agent_states.pop(agent_addr) agent_info["num_steps"] = self._agent_steps.pop(agent_addr) agent_info["agent_status"] = self._agent_status.pop(agent_addr) + agent_info["false_positives"] = self._agent_false_positives.pop(agent_addr) async with self._reset_lock: agent_info["reset_request"] = self._reset_requests.pop(agent_addr) # check if this agent was not preventing reset @@ -787,6 +824,17 @@ def is_timeout(self, agent:tuple)->bool: timeout_reached = True return timeout_reached + def add_false_positive(self, agent:tuple)->None: + """ + Method for adding false positive to the agent. + """ + self.logger.debug(f"Adding false positive to {agent}") + if agent in self._agent_false_positives: + self._agent_false_positives[agent] += 1 + else: + self._agent_false_positives[agent] = 1 + self.logger.debug(f"False positives for {agent}: {self._agent_false_positives[agent]}") + def _update_agent_status(self, agent:tuple)->AgentStatus: """ Update the status of an agent based on reaching the goal, timeout or detection. @@ -855,4 +903,12 @@ def _store_trajectory_to_file(self, agent_addr:tuple, location="./trajectories") filename = os.path.join(location, f"{datetime.now():%Y-%m-%d}_{agent_name}_{agent_role}.jsonl") with jsonlines.open(filename, "a") as writer: writer.write(self._agent_trajectories[agent_addr]) - self.logger.info(f"Trajectory of {agent_addr} strored in {filename}") \ No newline at end of file + self.logger.info(f"Trajectory of {agent_addr} strored in {filename}") + + def is_agent_benign(self, agent_addr:tuple)->bool: + """ + Check if the agent is benign (defender, normal) + """ + if agent_addr not in self.agents: + return False + return self.agents[agent_addr][1].lower() in ["defender", "benign"] \ No newline at end of file diff --git a/AIDojoCoordinator/game_components.py b/AIDojoCoordinator/game_components.py index 7f775c91..a42c923e 100755 --- a/AIDojoCoordinator/game_components.py +++ b/AIDojoCoordinator/game_components.py @@ -130,11 +130,12 @@ class Data(): """ owner: str id: str - size: int = 0 + size: int = field(compare=False, hash=False, default=0) type: str = "" - + content: str = field(compare=False, hash=False, repr=False, default_factory=str) + def __hash__(self) -> int: - return hash((self.owner, self.id, self.size, self.type)) + return hash((self.owner, self.id, self.type)) @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -379,7 +380,7 @@ def from_dict(cls, data_dict:dict): controlled_hosts = {IP(x["ip"]) for x in data_dict["controlled_hosts"]}, known_services = {IP(k):{Service(s["name"], s["type"], s["version"], s["is_local"]) for s in services} for k,services in data_dict["known_services"].items()}, - known_data = {IP(k):{Data(v["owner"], v["id"]) for v in values} for k,values in data_dict["known_data"].items()}, + known_data = {IP(k):{Data(v["owner"], v["id"], v["size"], v["type"], v["content"]) for v in values} for k,values in data_dict["known_data"].items()}, known_blocks = known_blocks ) return state @@ -396,7 +397,7 @@ def from_json(cls, json_string): controlled_hosts = {IP(x["ip"]) for x in json_data["controlled_hosts"]}, known_services = {IP(k):{Service(s["name"], s["type"], s["version"], s["is_local"]) for s in services} for k,services in json_data["known_services"].items()}, - known_data = {IP(k):{Data(v["owner"], v["id"]) for v in values} for k,values in json_data["known_data"].items()}, + known_data = {IP(k):{Data(v["owner"], v["id"], v["size"], v["type"], v["content"]) for v in values} for k,values in json_data["known_data"].items()}, known_blocks = {IP(target_host):{IP(blocked_host) for blocked_host in blocked_hosts} for target_host, blocked_hosts in json_data["known_blocks"].items()} ) return state diff --git a/AIDojoCoordinator/netsecenv_conf.yaml b/AIDojoCoordinator/netsecenv_conf.yaml index 87c87c28..04f3b0e4 100644 --- a/AIDojoCoordinator/netsecenv_conf.yaml +++ b/AIDojoCoordinator/netsecenv_conf.yaml @@ -102,10 +102,12 @@ env: use_dynamic_addresses: False use_firewall: True save_trajectories: False + required_players: 1 rewards: success: 100 step: -1 fail: -10 + false_positive: -5 actions: scan_network: prob_success: 1.0 diff --git a/AIDojoCoordinator/utils/aidojo_log_colorizer.py b/AIDojoCoordinator/utils/aidojo_log_colorizer.py new file mode 100644 index 00000000..57cb6f78 --- /dev/null +++ b/AIDojoCoordinator/utils/aidojo_log_colorizer.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +""" +A simple log processor that colorizes log entries by agent and component, +interprets JSON payloads into readable summaries, +pretty-prints key fields, and highlights actions & agents. + +Usage: + python log_processor.py path/to/logfile.log +If no file is given, reads from stdin. +""" +import sys +import re +import json +from itertools import cycle +from rich import print +from rich.console import Console +from rich.text import Text + +# Force ANSI colors even when output is piped +console = Console(force_terminal=True, color_system="truecolor") + +# Cycle of colors to assign to different agents (background colors to avoid clashes) +# We'll use text in black on these backgrounds so they're distinct from other log colors +COLOR_CYCLE = ["cyan", "magenta", "green", "yellow", "blue", "red"] +agent_colors = {} +color_picker = cycle(COLOR_CYCLE) +agent_names = {} # map ip:port -> agent name + +# Regex patterns +timestamp_re = re.compile(r"^(?P\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})") +line_re = re.compile( + r"^(?P\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}) (?P\S+) (?P\S+)\s+(?P.*)$" +) +agent_id_re = re.compile(r"\('(?P[\d\.]+)', (?P\d+)\)") +agent_reg_re = re.compile( + r"Agent (?P\S+) \(\('(?P[\d\.]+)', (?P\d+)\)\)" +) +action_re = re.compile(r"ActionType\.[A-Za-z]+") + + +def get_agent_color(agent_id: str) -> str: + """Assign or retrieve a consistent background style for an agent.""" + if agent_id not in agent_colors: + base = next(color_picker) + agent_colors[agent_id] = f"black on {base}" + return agent_colors[agent_id] + + +def summarize_json(ts, source_styled, level_styled, prefix: str, parsed: dict): + """Interpret and display key fields from a JSON payload.""" + console.print(f"[bold]{ts}[/bold] ", source_styled, level_styled, f"{prefix}:" ) + indent = " " + # Status + status = parsed.get('status') + if status is not None: + console.print(Text(f"{indent}Status: {status}", style="bold green")) + + # To agent + to_agent = parsed.get('to_agent') + if to_agent: + console.print(Text(f"{indent}To Agent: {to_agent[0]}:{to_agent[1]}", style="bold cyan")) + + # Message payload (optional) + msg_block = parsed.get('message') + if isinstance(msg_block, dict): + text = msg_block.get('message') + if text: + console.print(Text(f"{indent}Message: {text}", style="bold")) + max_steps = msg_block.get('max_steps') + if max_steps is not None: + console.print(f"{indent}Max Steps: {max_steps}") + goal = msg_block.get('goal_description') + if goal: + console.print(f"{indent}Goal: {goal}") + actions = msg_block.get('actions') + if actions: + console.print(f"{indent}Actions: {', '.join(actions)}") + conf = msg_block.get('configuration_hash') + if conf: + console.print(f"{indent}Config Hash: {conf}") + + # Observation state: networks, hosts, services, data, blocks + obs = parsed.get('observation', {}) + state = obs.get('state', {}) + if state: + nets = state.get('known_networks', []) + if nets: + items = [f"{n['ip']}/{n['mask']}" for n in nets] + console.print(f"{indent}Known Networks: {', '.join(items)}") + hosts = state.get('known_hosts', []) + if hosts: + items = [h['ip'] for h in hosts] + console.print(f"{indent}Known Hosts: {', '.join(items)}") + ctrl = state.get('controlled_hosts', []) + if ctrl: + items = [h['ip'] for h in ctrl] + console.print(f"{indent}Controlled Hosts: {', '.join(items)}") + services = state.get('known_services', {}) + if services: + for host, svcs in services.items(): + names = [s['name'] for s in svcs] + console.print(f"{indent}Services on {host}: {', '.join(names)}") + data = state.get('known_data', {}) + if data: + for host, entries in data.items(): + ids = [e.get('id') for e in entries] + console.print(f"{indent}Data on {host}: {', '.join(ids)}") + # Known blocks + blocks = state.get('known_blocks', {}) + if isinstance(blocks, dict): + ips = list(blocks.keys()) + elif isinstance(blocks, list): + ips = [b.get('ip') for b in blocks] + else: + ips = [] + if ips: + console.print(f"{indent}Blocked Hosts: {', '.join(ips)}") + else: + console.print(f"{indent}Blocked Hosts: None") + + # Reward & End: top-level or under observation + reward = parsed.get('reward', obs.get('reward')) + if reward is not None: + console.print(Text(f"{indent}Reward: {reward}", style="bold magenta")) + end = parsed.get('end', obs.get('end')) + if end is not None: + console.print(Text(f"{indent}End: {end}", style="bold red")) + + # End Reason: after End + info = parsed.get('info', {}) + obs_info = obs.get('info', {}) + end_reason = info.get('end_reason') or obs_info.get('end_reason') + if end_reason: + console.print(Text(f"{indent}End Reason: {end_reason}", style="bold yellow")) + + +def process_line(line: str): + raw = line.rstrip("\n") + m = line_re.match(raw) + if not m: + print(raw) + return + + ts = m.group('ts') + source = m.group('source') + level = m.group('level') + msg = m.group('msg').strip() + + # Capture agent registration + reg = agent_reg_re.search(msg) + if reg: + aid = f"{reg.group('ip')}:{reg.group('port')}" + agent_names[aid] = reg.group('name') + + # Style source + if 'GameCoordinator' in source: + source_styled = Text(source, style='bold red') + elif 'AgentServer' in source: + source_styled = Text(source, style='bold blue') + else: + source_styled = Text(source, style='bold white') + + # Style level + level_style = 'dim white' if level == 'INFO' else 'bold red' + level_styled = Text(level, style=level_style) + + # Highlight actions + msg = action_re.sub(lambda m: f"[bold magenta]{m.group(0)}[/bold magenta]", msg) + + # Annotate agents + def repl_agent(m): + aid = f"{m.group('ip')}:{m.group('port')}" + style = get_agent_color(aid) + name = agent_names.get(aid) + label = f"{name} {aid}" if name else aid + # Use background style for agent label + return f"[{style}]{label}[/{style}]" + msg_markup = agent_id_re.sub(repl_agent, msg) + + # Detect JSON and interpret + if '{' in msg and (msg.strip().startswith('{') or ': {' in msg): + idx = msg.find('{') + prefix = msg[:idx].rstrip(': ') + json_part = msg[idx:] + try: + parsed = json.loads(json_part) + summarize_json(ts, source_styled, level_styled, prefix, parsed) + return + except json.JSONDecodeError: + pass + + # Default print + console.print(Text(ts, style='bold'), source_styled, level_styled, Text.from_markup(msg_markup)) + + +def main(): + if len(sys.argv) > 1: + with open(sys.argv[1], 'r') as f: + for line in f: + process_line(line) + else: + for line in sys.stdin: + process_line(line) + + +if __name__ == '__main__': + main() diff --git a/AIDojoCoordinator/utils/utils.py b/AIDojoCoordinator/utils/utils.py index 6b4a92e5..b6305b9f 100644 --- a/AIDojoCoordinator/utils/utils.py +++ b/AIDojoCoordinator/utils/utils.py @@ -8,7 +8,7 @@ from AIDojoCoordinator.scenarios import smaller_scenario_configuration from AIDojoCoordinator.scenarios import tiny_scenario_configuration from AIDojoCoordinator.scenarios import three_net_scenario -from AIDojoCoordinator.game_components import IP, Data, Network, Service, GameState, Action, Observation +from AIDojoCoordinator.game_components import IP, Data, Network, Service, GameState, Action, Observation, ActionType import netaddr import logging import csv @@ -114,6 +114,22 @@ def observation_as_dict(observation:Observation)->dict: } return observation_dict +def parse_log_content(log_content:str)->list: + try: + logs = [] + data = json.loads(log_content) + for item in data: + ip = IP(item["source_host"]) + action_type = ActionType.from_string(item["action_type"]) + logs.append({"source_host":ip, "action_type":action_type}) + return logs + except json.JSONDecodeError as e: + print(f"Error decoding JSON: {e}") + return None + except TypeError as e: + print(f"Error decoding JSON: {e}") + return None + class ConfigParser(): """ Class to deal with the configuration file diff --git a/AIDojoCoordinator/worlds/NSEGameCoordinator.py b/AIDojoCoordinator/worlds/NSEGameCoordinator.py index c1f53d13..1bb6e066 100644 --- a/AIDojoCoordinator/worlds/NSEGameCoordinator.py +++ b/AIDojoCoordinator/worlds/NSEGameCoordinator.py @@ -8,6 +8,7 @@ from faker import Faker from pathlib import Path import netaddr, re +import json from AIDojoCoordinator.game_components import GameState, Action, ActionType, IP, Network, Data, Service from AIDojoCoordinator.coordinator import GameCoordinator @@ -27,6 +28,7 @@ def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attack self._data = {} # Dict of all services in the environment. Keys: hostname (`str`), values `set` of `Service` objetcs. self._firewall = {} # dict of all the allowed connections in the environment. Keys `IP` ,values: `set` of `IP` objects. self._fw_blocks = {} + self._agent_fw_rules = {} # All exploits in the environment self._exploits = {} # A list of all the hosts where the attacker can start in a random start @@ -97,15 +99,15 @@ def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->Ga for controlled_host in controlled_hosts: for net in self._get_networks_from_host(controlled_host): #TODO net_obj = netaddr.IPNetwork(str(net)) - if net_obj.ip.is_ipv4_private_use(): #TODO + if net_obj.ip.is_private(): #TODO known_networks.add(net) net_obj.value += 256 - if net_obj.ip.is_ipv4_private_use(): + if net_obj.ip.is_private(): ip = Network(str(net_obj.ip), net_obj.prefixlen) self.logger.debug(f'\tAdding {ip} to agent') known_networks.add(ip) net_obj.value -= 2*256 - if net_obj.ip.is_ipv4_private_use(): + if net_obj.ip.is_private(): ip = Network(str(net_obj.ip), net_obj.prefixlen) self.logger.debug(f'\tAdding {ip} to agent') known_networks.add(ip) @@ -232,7 +234,7 @@ def process_firewall()->dict: # LOCAL NETWORKS for net, ips in self._networks.items(): # IF net is local, allow connection between all nodes in it - if netaddr.IPNetwork(str(net)).ip.is_ipv4_private_use(): + if netaddr.IPNetwork(str(net)).ip.is_private(): for src in ips: for dst in ips: firewall[src].add(dst) @@ -240,9 +242,9 @@ def process_firewall()->dict: # LOCAL TO INTERNET for net, ips in self._networks.items(): # IF net is local, allow connection between all nodes in it - if netaddr.IPNetwork(str(net)).ip.is_ipv4_private_use(): + if netaddr.IPNetwork(str(net)).ip.is_private(): for public_net, public_ips in self._networks.items(): - if not netaddr.IPNetwork(str(public_net)).ip.is_ipv4_private_use(): + if not netaddr.IPNetwork(str(public_net)).ip.is_private(): for src in ips: for dst in public_ips: firewall[src].add(dst) @@ -281,6 +283,13 @@ def process_firewall()->dict: #exploits self._exploits = exploits + + # create logfile in each nodes + for node in nodes + routers: + if node.id not in self._data: + self._data[node.id] = set() + self.logger.info(f"\tAdding logfile to node {node.id}") + self._data[node.id].add(Data(owner="system", id="logfile", type="log", size=0)) #create initial mapping self.logger.info("\tCreating initial mapping of IPs and Networks") for net in self._networks.keys(): @@ -302,7 +311,7 @@ def _create_new_network_mapping(self)->tuple: # generate mapping for networks private_nets = [] for net in self._networks.keys(): - if netaddr.IPNetwork(str(net)).ip.is_ipv4_private_use(): + if netaddr.IPNetwork(str(net)).ip.is_private(): private_nets.append(net) else: mapping_nets[net] = Network(fake.ipv4_public(), net.mask) @@ -326,7 +335,7 @@ def _create_new_network_mapping(self)->tuple: # find the new mapping new_net_addr = netaddr.IPNetwork(str(mapping_nets[private_nets_sorted[0]])).ip + diff_ip # evaluate if its still a private network - is_private_net_checks.append(new_net_addr.is_ipv4_private_use()) + is_private_net_checks.append(new_net_addr.is_private()) # store the new mapping mapping_nets[private_nets_sorted[i]] = Network(str(new_net_addr), private_nets_sorted[i].mask) if False not in is_private_net_checks: # verify that ALL new networks are still in the private ranges @@ -508,7 +517,7 @@ def _get_data_content(self, host_ip:str, data_id:str)->str: self.logger.debug("Data content not found because target IP does not exists.") return content - def _execute_action(self, current_state:GameState, action:Action)-> GameState: + def _execute_action(self, current_state:GameState, action:Action, agent_id:tuple)-> GameState: """ Execute the action and update the values in the state Before this function it was checked if the action was successful @@ -522,21 +531,37 @@ def _execute_action(self, current_state:GameState, action:Action)-> GameState: next_state = None match action.type: case ActionType.ScanNetwork: - next_state = self._execute_scan_network_action(current_state, action) + next_state = self._execute_scan_network_action(current_state, action, agent_id) case ActionType.FindServices: - next_state = self._execute_find_services_action(current_state, action) + next_state = self._execute_find_services_action(current_state, action, agent_id) case ActionType.FindData: - next_state = self._execute_find_data_action(current_state, action) + next_state = self._execute_find_data_action(current_state, action, agent_id) case ActionType.ExploitService: - next_state = self._execute_exploit_service_action(current_state, action) + next_state = self._execute_exploit_service_action(current_state, action, agent_id) case ActionType.ExfiltrateData: - next_state = self._execute_exfiltrate_data_action(current_state, action) + next_state = self._execute_exfiltrate_data_action(current_state, action, agent_id) case ActionType.BlockIP: - next_state = self._execute_block_ip_action(current_state, action) + next_state = self._execute_block_ip_action(current_state, action, agent_id) case _: raise ValueError(f"Unknown Action type or other error: '{action.type}'") return next_state + def _record_false_positive(self, src_ip:IP, dst_ip:IP, agent_id:tuple)->bool: + # only record false positive if the agent is benign + if self.is_agent_benign(agent_id): + # find agent(s) that created the rule + src_host = src_ip + dst_host = dst_ip + if (src_host, dst_host) in self._agent_fw_rules: + # check if this connection is actively blocked + for author_agent in self._agent_fw_rules[(src_host, dst_host)]: + self.logger.info(f"Adding false positive for blocking {src_host} -> {dst_host} by {author_agent}") + if author_agent not in self._agent_false_positives: + self._agent_false_positives[author_agent] = 0 + self._agent_false_positives[author_agent] += 1 + else: + self.logger.debug(f"False positive for blocking {src_host} -> {dst_host} caused by the system configuration.") + def _state_parts_deep_copy(self, current:GameState)->tuple: next_nets = copy.deepcopy(current.known_networks) next_known_h = copy.deepcopy(current.known_hosts) @@ -554,7 +579,7 @@ def _firewall_check(self, src_ip:IP, dst_ip:IP)->bool: connection_allowed = False return connection_allowed - def _execute_scan_network_action(self, current_state:GameState, action:Action)->GameState: + def _execute_scan_network_action(self, current_state:GameState, action:Action, agent_id:tuple)->GameState: """ Executes the ScanNetwork action in the environment """ @@ -568,14 +593,16 @@ def _execute_scan_network_action(self, current_state:GameState, action:Action)-> if self._firewall_check(action.parameters["source_host"], ip): self.logger.debug(f"\t\t\tAdding {ip} to new_ips") new_ips.add(ip) + self.update_log_file(next_data,action, ip) else: + self._record_false_positive(action.parameters["source_host"], ip, agent_id) self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {ip} blocked by FW. Skipping") next_known_h = next_known_h.union(new_ips) else: self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_find_services_action(self, current_state:GameState, action:Action)->GameState: + def _execute_find_services_action(self, current_state:GameState, action:Action, agent_id:tuple)->GameState: """ Executes the FindServices action in the environment """ @@ -593,13 +620,16 @@ def _execute_find_services_action(self, current_state:GameState, action:Action)- self.logger.debug(f"\t\tAdding {action.parameters['target_host']} to known_hosts") next_known_h.add(action.parameters["target_host"]) next_nets = next_nets.union({net for net, values in self._networks.items() if action.parameters["target_host"] in values}) + # update logs + self.update_log_file(next_data,action, action.parameters['target_host']) else: + self._record_false_positive(action.parameters["source_host"], action.parameters["target_host"], agent_id) self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. Skipping") else: self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_find_data_action(self, current:GameState, action:Action)->GameState: + def _execute_find_data_action(self, current:GameState, action:Action, agent_id:tuple)->GameState: """ Executes the FindData action in the environment """ @@ -607,6 +637,8 @@ def _execute_find_data_action(self, current:GameState, action:Action)->GameState self.logger.debug(f"\t\tSearching for data in {action.parameters['target_host']}") if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current.controlled_hosts: if self._firewall_check(action.parameters["source_host"], action.parameters['target_host']): + # update logs before getting the data so this action is listed there + self.update_log_file(next_data,action, action.parameters['target_host']) new_data = self._get_data_in_host(action.parameters["target_host"], current.controlled_hosts) self.logger.debug(f"\t\t\t Found {len(new_data)}: {new_data}") if len(new_data) > 0: @@ -622,12 +654,13 @@ def _execute_find_data_action(self, current:GameState, action:Action)->GameState else: next_blocked[action.parameters["target_host"]] = next_blocked[action.parameters["target_host"]].union(new_blocks) else: + self._record_false_positive(action.parameters["source_host"], action.parameters["target_host"], agent_id) self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. Skipping") else: self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_exfiltrate_data_action(self, current_state:GameState, action:Action)->GameState: + def _execute_exfiltrate_data_action(self, current_state:GameState, action:Action, agent_id:tuple)->GameState: """ Executes the ExfiltrateData action in the environment """ @@ -643,6 +676,8 @@ def _execute_exfiltrate_data_action(self, current_state:GameState, action:Action # Does the current state for THIS source already know about this data? if self._firewall_check(action.parameters["source_host"], action.parameters['target_host']): if action.parameters['source_host'] in current_state.known_data.keys() and action.parameters["data"] in current_state.known_data[action.parameters["source_host"]]: + # update logs + self.update_log_file(next_data,action, action.parameters['target_host']) # Does the source host have any data? if self._ip_to_hostname[action.parameters["source_host"]] in self._data.keys(): # Does the source host have this data? @@ -664,6 +699,7 @@ def _execute_exfiltrate_data_action(self, current_state:GameState, action:Action else: self.logger.debug("\t\t\tCan not exfiltrate. Agent did not find this data yet.") else: + self._record_false_positive(action.parameters["source_host"], action.parameters["target_host"], agent_id) self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. Skipping") else: self.logger.debug("\t\t\tCan not exfiltrate. Source host is not controlled.") @@ -671,7 +707,7 @@ def _execute_exfiltrate_data_action(self, current_state:GameState, action:Action self.logger.debug("\t\t\tCan not exfiltrate. Target host is not controlled.") return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_exploit_service_action(self, current_state:GameState, action:Action)->GameState: + def _execute_exploit_service_action(self, current_state:GameState, action:Action, agent_id:tuple)->GameState: """ Executes the ExploitService action in the environment """ @@ -700,7 +736,10 @@ def _execute_exploit_service_action(self, current_state:GameState, action:Action self.logger.debug("\t\t\tCan not exploit. Target host does not the service that was attempted.") else: self.logger.debug("\t\t\tCan not exploit. Target host does not have any services.") + # update logs + self.update_log_file(next_data,action, action.parameters['target_host']) else: + self._record_false_positive(action.parameters["source_host"], action.parameters["target_host"], agent_id) self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. Skipping") else: self.logger.debug("\t\t\tCan not exploit. Target host does not exist.") @@ -708,7 +747,7 @@ def _execute_exploit_service_action(self, current_state:GameState, action:Action self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_block_ip_action(self, current_state:GameState, action:Action)->GameState: + def _execute_block_ip_action(self, current_state:GameState, action:Action, agent_id:tuple)->GameState: """ Executes the BlockIP action - The action has BlockIP("target_host": IP object, "source_host": IP object, "blocked_host": IP object) @@ -732,6 +771,14 @@ def _execute_block_ip_action(self, current_state:GameState, action:Action)->Game if self._firewall_check(action.parameters["source_host"], action.parameters["target_host"]): if action.parameters["target_host"] != action.parameters['blocked_host']: self.logger.info(f"\t\tBlockConnection {action.parameters['target_host']} <-> {action.parameters['blocked_host']}") + # record which agent is adding the blocking rule + if (action.parameters["target_host"], action.parameters["blocked_host"]) not in self._agent_fw_rules: + self._agent_fw_rules[(action.parameters["target_host"], action.parameters["blocked_host"])] = set() + self._agent_fw_rules[(action.parameters["target_host"], action.parameters["blocked_host"])].add(agent_id) + # both directions are blocked + if (action.parameters["blocked_host"], action.parameters["target_host"]) not in self._agent_fw_rules: + self._agent_fw_rules[(action.parameters["blocked_host"], action.parameters["target_host"])] = set() + self._agent_fw_rules[(action.parameters["blocked_host"], action.parameters["target_host"])].add(agent_id) try: #remove connection target_host -> blocked_host self._firewall[action.parameters["target_host"]].discard(action.parameters["blocked_host"]) @@ -762,7 +809,10 @@ def _execute_block_ip_action(self, current_state:GameState, action:Action)->Game next_blocked[action.parameters["blocked_host"]].add(action.parameters["target_host"]) else: self.logger.debug(f"\t\t\t Cant block connection form :'{action.parameters['target_host']}' to '{action.parameters['blocked_host']}'") + # update logs + self.update_log_file(next_data,action, action.parameters['target_host']) else: + self._record_false_positive(action.parameters["source_host"], action.parameters["target_host"], agent_id) self.logger.debug(f"\t\t\t Connection from '{action.parameters['source_host']}->'{action.parameters['target_host']} is blocked blocked by FW") else: self.logger.debug(f"\t\t\t Invalid target_host:'{action.parameters['target_host']}'") @@ -773,12 +823,30 @@ def _execute_block_ip_action(self, current_state:GameState, action:Action)->Game def _get_all_local_ips(self)->set: local_ips = set() for net, ips in self._networks.items(): - if netaddr.IPNetwork(str(net)).ip.is_ipv4_private_use(): + if netaddr.IPNetwork(str(net)).ip.is_private(): for ip in ips: local_ips.add(self._ip_mapping[ip]) self.logger.info(f"\t\t\tLocal ips: {local_ips}") return local_ips - + + def update_log_file(self, known_data:set, action, target_host:IP): + hostaname = self._ip_to_hostname[target_host] + self.logger.debug(f"Updating log file in host {hostaname}") + try: + current_log_file = list(filter(lambda x: x.owner == "system" and x.type == "log",self._data[hostaname]))[0] + self._data[hostaname].discard(current_log_file) + if current_log_file.size == 0: + content = [] + else: + content = json.loads(current_log_file.content) + content.append({'source_host': str(action.parameters["source_host"]), 'action_type': str(action.type)}) + new_content = json.dumps(content) + except KeyError: + self.logger.debug(f"\t\t\tLog not found in host {hostaname}. Creating new one.") + new_content = [{'source_host': str(action.parameters["source_host"]), 'action_type': str(action.type)}] + new_content = json.dumps(new_content) + self._data[hostaname].add(Data(owner="system", id="logfile", type="log", size=len(new_content) , content= new_content)) + async def register_agent(self, agent_id, agent_role, agent_initial_view)->GameState: if len(self._networks) == 0: self._initialize() @@ -790,7 +858,7 @@ async def remove_agent(self, agent_id, agent_state)->bool: return True async def step(self, agent_id, agent_state, action)->GameState: - return self._execute_action(agent_state, action) + return self._execute_action(agent_state, action, agent_id) async def reset_agent(self, agent_id, agent_role, agent_initial_view)->GameState: game_state = self._create_state_from_view(agent_initial_view) @@ -809,8 +877,10 @@ async def reset(self)->bool: # reset self._data to orignal state self._data = copy.deepcopy(self._data_original) # reset self._data_content to orignal state - self._firewall = copy.deepcopy(self._firewall) + # reset all firewall related data structure + self._firewall = copy.deepcopy(self._firewall_original) self._fw_blocks = {} + self._agent_fw_rules = {} return True if __name__ == "__main__": diff --git a/NetSecGameAgents b/NetSecGameAgents index 5c966fd7..3a46c12d 160000 --- a/NetSecGameAgents +++ b/NetSecGameAgents @@ -1 +1 @@ -Subproject commit 5c966fd7f8560a2632dd4279392440219b4bb486 +Subproject commit 3a46c12d64334303210cdf6b373177febe642079 diff --git a/README.md b/README.md index 6ad2803c..1fcb26da 100755 --- a/README.md +++ b/README.md @@ -389,7 +389,7 @@ It is advised after every change you test if the env is running correctly by doi ```bash tests/run_all_tests.sh ``` -This will load and run the unit tests in the `tests` folder. +This will load and run the unit tests in the `tests` folder. After passing all tests, linting and formatting is checked with ruff. ## Code adaptation for new configurations The code can be adapted to new configurations of games and for new agents. See [Agent repository](https://github.com/stratosphereips/NetSecGameAgents/tree/main) for more details. diff --git a/pyproject.toml b/pyproject.toml index 56f869e4..c09d7c7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,10 +2,6 @@ requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" -[tool.setuptools.packages.find] -where = ["."] -exclude = ["tests*"] - [project] name = "AIDojoGameCoordinator" version = "0.1.0" @@ -55,4 +51,29 @@ requires-python = ">=3.12" dev = [ "pytest", "ruff", -] \ No newline at end of file + "pytest-asyncio" +] + +[project.urls] +Homepage = "https://github.com/stratosphereips/NetSecGame" +Repository = "https://github.com/stratosphereips/NetSecGame" +Documentation = "https://github.com/stratosphereips/NetSecGame" +Issues = "https://github.com/stratosphereips/NetSecGame/issues" + +[tool.setuptools.packages.find] +where = ["."] +exclude = ["tests*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +asyncio_mode = "auto" +addopts = "-p no:warnings -vvvv -s --full-trace" + +[tool.ruff] +lint.select = ["E9", "F4", "F6", "F7", "F8", "N8"] +lint.ignore = ["F405"] +output-format = "github" +target-version = "py312" +line-length = 120 +exclude = ["NetSecGameAgents"] \ No newline at end of file diff --git a/tests/test_actions.py b/tests/OLD_test_actions.py similarity index 100% rename from tests/test_actions.py rename to tests/OLD_test_actions.py diff --git a/tests/test_components.py b/tests/components/test_action.py similarity index 58% rename from tests/test_components.py rename to tests/components/test_action.py index 140d38f4..16f98b62 100644 --- a/tests/test_components.py +++ b/tests/components/test_action.py @@ -1,192 +1,39 @@ -""" -Tests related to the game components in the Network Security Game Environment -Author: Maria Rigaki - maria.rigaki@fel.cvut.cz -""" +# Authors: Maria Rigaki - maria.rigaki@aic.fel.cvut.cz +# Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz import json -from AIDojoCoordinator.game_components import ActionType, Action, IP, Data, Network, Service, GameState, AgentInfo +from AIDojoCoordinator.game_components import Action, ActionType, IP, Network, Data, Service, AgentInfo -class TestComponentsIP: +class TestComponentActionType: """ - Tests related to the IP datclass + Test cases for the ActionType enum """ - def test_ip_repr(self): - """Test the object representation""" - ip_1 = IP("192.168.1.15") - assert repr(ip_1) == "192.168.1.15" - - def test_ip_equal(self): - """Test that two IP objects with the same IP are equal""" - ip_1 = IP("192.168.1.15") - ip_2 = IP("192.168.1.15") - assert ip_1 == ip_2 - - def test_ip_not_equal(self): - """Test that two IP objects with different IPs are not equal""" - ip_1 = IP("192.168.1.15") - ip_2 = IP("192.168.2.15") - assert ip_1 != ip_2 - - def test_ip_not_str(self): - """Test that the IP object is not equal to a string""" - ip_1 = IP("192.168.1.15") - ip_2 = "192.168.2.15" - assert ip_1 != ip_2 - - def test_ip_is_private(self): - ip_1 = IP("192.168.1.15") - assert ip_1.is_private() is True - - def test_ip_is_not_private(self): - ip_1 = IP("192.143.1.15") - assert ip_1.is_private() is False - -class TestServices: - """ - Tests related to the Service dataclass - """ - def test_service_creation(self): - """ - Test that the service is created and all elements can be accessed - """ - service = Service("rdp", "passive", "1.067", True) - assert service.name == "rdp" - assert service.type == "passive" - assert service.version == "1.067" - assert service.is_local - - def test_services_equal(self): - """ - Test that two services with the same parameters are equal - """ - service_1 = Service("rdp", "passive", "1.067", True) - service_2 = Service("rdp", "passive", "1.067", True) - assert service_1 == service_2 - assert service_1 is not service_2 - - def test_services_not_equal(self): - """ - Test that two services with different parameters are not equal - """ - service_1 = Service("rdp", "passive", "1.067", True) - service_2 = Service("sql", "passive", "5.0", True) - assert service_1 != service_2 - -class TestNetwork: - """ - Test cases for the Network dataclass - """ - def test_net_creation(self): - """ - Test that the network is created and all elements can be accessed - """ - net = Network("125.36.21.3", 16) - assert net.ip == "125.36.21.3" - assert net.mask == 16 - - def test_net_str(self): - """ - Test the string representaion of the network - """ - net = Network("125.36.21.3", 16) - assert str(net) == "125.36.21.3/16" - - def test_net_repr(self): - """ - Test the repr of the Network - """ - net = Network("125.36.21.3", 16) - assert repr(net) == "125.36.21.3/16" - - def test_net_equal(self): - """ - Test that two network objects with the same paramters are equal - """ - net_1 = Network("125.36.21.3", 16) - net_2 = Network("125.36.21.3", 16) - assert net_1 == net_2 - - def test_net_not_equal(self): - """ - Test that two network objects with different paramters are not equal - """ - net_1 = Network("125.36.21.3", 16) - net_2 = Network("192.168.1.3", 16) - assert net_1 != net_2 - - def test_net_is_not_private(self): - net_1 = Network("125.36.21.3", 16) - assert net_1.is_private() is False - - def test_net_is_private(self): - net_1 = Network("192.168.1.0", 16) - assert net_1.is_private() is True - -class TestData: - """ - Test cases for the Data class - """ - def test_create_data_minimal(self): - """ - Test that the data object is created with ONLY required fields (using default for the rest) - """ - data = Data(owner="Ondra", id="Password") - assert data.owner == "Ondra" - assert data.id == "Password" - assert data.type == "" - assert data.size == 0 + def test_action_type_str(self): + """ + Test that the string representation of the ActionType enum is correct + """ + assert str(ActionType.FindData) == "ActionType.FindData" + assert str(ActionType.FindServices) == "ActionType.FindServices" + assert str(ActionType.ScanNetwork) == "ActionType.ScanNetwork" + assert str(ActionType.ExploitService) == "ActionType.ExploitService" + assert str(ActionType.ExfiltrateData) == "ActionType.ExfiltrateData" + assert str(ActionType.JoinGame) == "ActionType.JoinGame" + assert str(ActionType.ResetGame) == "ActionType.ResetGame" + assert str(ActionType.QuitGame) == "ActionType.QuitGame" - def test_create_data_all(self): + def test_action_type_hash(self): """ - Test that the data object is created with ALL fields (using default for the rest) + Test that the hash of the ActionType enum is correct """ - data = Data(owner="Ondra", id="Password",size=42, type="txt") - assert data.owner == "Ondra" - assert data.id == "Password" - assert data.type == "txt" - assert data.size == 42 + assert hash(ActionType.FindData) == hash("FindData") + assert hash(ActionType.FindServices) == hash("FindServices") + assert hash(ActionType.ScanNetwork) == hash("ScanNetwork") + assert hash(ActionType.ExploitService) == hash("ExploitService") + assert hash(ActionType.ExfiltrateData) == hash("ExfiltrateData") + assert hash(ActionType.JoinGame) == hash("JoinGame") + assert hash(ActionType.ResetGame) == hash("ResetGame") + assert hash(ActionType.QuitGame) == hash("QuitGame") - def test_data_equal(self): - """ - Test that two data objects with the same required parameters are equal - """ - data = Data("Ondra", "Password") - data2 = Data("Ondra", "Password") - # test equality with all fields used - data3 = Data(owner="Ondra", id="Password",size=42, type="txt") - data4 = Data(owner="Ondra", id="Password", size=42, type="txt") - assert data == data2 - assert data3 == data4 - - def test_data_not_equal(self): - """ - Test that two data objects with different required parameters are NOT equal - """ - data = Data("Ondra", "Password") - data2 = Data("ChuckNorris", "Password") - data3 = Data(owner="Ondra", id="Password",size=42, type="txt") - data4 = Data(owner="Ondra", id="DifferentPassword",size=41, type="rsa") - assert data != data2 - assert data3 != data4 - - def test_data_hash_equal(self): - data = Data("Ondra", "Password") - data2 = Data("Ondra", "Password") - # test equality with all fields used - data3 = Data(owner="Ondra", id="Password",size=42, type="txt") - data4 = Data(owner="Ondra", id="Password",size=42, type="txt") - assert hash(data) == hash(data2) - assert hash(data3) == hash(data4) - - def test_data_hash_not_equal(self): - data = Data("Ondra", "Password") - data2 = Data("Ondra", "NewPassword") - # test equality with all fields used - data3 = Data(owner="Ondra", id="Password",size=42, type="txt") - data4 = Data(owner="Ondra", id="Password",size=41, type="rsa") - assert hash(data) != hash(data2) - assert hash(data3) != hash(data4) - -class TestAction: +class TestComponentAction: """ Test cases for the Action class """ @@ -403,7 +250,7 @@ def test_action_to_json(self): assert "ActionType.ExfiltrateData" in data["action_type"] assert ("parameters", {"target_host": {"ip": "172.16.1.3"}, "source_host" : {"ip": "172.16.1.2"}, - "data":{"owner":"User2", "id":"PublicKey", "size":42 ,"type":"pub"}}) in data.items() + "data":{"owner":"User2", "id":"PublicKey", "size":42 ,"type":"pub", "content":""}}) in data.items() def test_action_scan_network_serialization(self): action = Action(action_type=ActionType.ScanNetwork, @@ -589,152 +436,4 @@ def test_action_to_dict_quit_game(self): new_action = Action.from_dict(action_dict) assert action == new_action assert action_dict["action_type"] == str(action.type) - assert len(action_dict["parameters"]) == 0 - -class TestGameState: - """ - Test cases related to the GameState class - """ - def test_create_game_state(self): - """ - Test the correct creation of the GameState class - """ - game_state = GameState(controlled_hosts={IP("192.168.1.1")}, - known_hosts={IP("192.168.1.1"), IP("8.8.8.8")}, - known_services=set(), - known_data={Data("User2", "PublicKey")}, - known_networks={Network('192.168.1.0', 24)}) - - assert isinstance(game_state.controlled_hosts, set) - assert len(game_state.known_hosts) == 2 - assert IP("192.168.1.1") in game_state.known_hosts - assert IP("192.168.1.1") in game_state.controlled_hosts - assert Data("User2", "PublicKey") in game_state.known_data - assert Network("192.168.1.0", 24) in game_state.known_networks - - def test_state_equal(self): - """ - Test that two game states with the same parameters are equal - """ - game_state = GameState(controlled_hosts={IP("192.168.1.1")}, - known_hosts={IP("192.168.1.1"), IP("8.8.8.8")}, - known_services=set(), - known_data={Data("User2", "PublicKey")}, - known_networks={Network('192.168.1.0', 24)}) - game_state2 = GameState(controlled_hosts={IP("192.168.1.1")}, - known_hosts={IP("192.168.1.1"), IP("8.8.8.8")}, - known_services=set(), - known_data={Data("User2", "PublicKey")}, - known_networks={Network('192.168.1.0', 24)}) - - assert game_state == game_state2 - - def test_state_not_equal_diff_control(self): - """ - Test that two game states with diffrent parameters are not equal. - Different controlled hosts. - """ - game_state = GameState(controlled_hosts={IP("192.168.1.1")}, - known_hosts={IP("192.168.1.1"), IP("8.8.8.8")}, - known_services=set(), - known_data={Data("User2", "PublicKey")}, - known_networks={Network('192.168.1.0', 24)}) - game_state2 = GameState(controlled_hosts={IP("172.16.1.1")}, - known_hosts={IP("192.168.1.1"), IP("8.8.8.8")}, - known_services=set(), - known_data={Data("User2", "PublicKey")}, - known_networks={Network('192.168.1.0', 24)}) - - assert game_state != game_state2 - - def test_state_not_equal_diff_known(self): - """ - Test that two game states with diffrent parameters are not equal - Different known hosts. - """ - game_state = GameState(controlled_hosts={IP("192.168.1.1")}, - known_hosts={IP("192.168.1.1"), IP("8.8.8.8")}, - known_services=set(), - known_data={Data("User2", "PublicKey")}, - known_networks={Network('192.168.1.0', 24)}) - - game_state2 = GameState(controlled_hosts={IP("192.168.1.1")}, - known_hosts={IP("8.8.8.8")}, - known_services=set(), - known_data={Data("User2", "PublicKey")}, - known_networks={Network('192.168.1.0', 24)}) - - assert game_state != game_state2 - - def test_state_not_equal_diff_data(self): - """ - Test that two game states with diffrent parameters are not equal. - Different data. - """ - game_state = GameState(controlled_hosts={IP("192.168.1.1")}, - known_hosts={IP("192.168.1.1"), IP("8.8.8.8")}, - known_services=set(), - known_data={Data("User2", "PublicKey")}) - game_state2 = GameState(controlled_hosts={IP("192.168.1.1")}, - known_hosts={IP("192.168.1.1"), IP("8.8.8.8")}, - known_services=set(), - known_data={Data("User", "PublicKey")}) - assert game_state != game_state2 - - - def test_game_state_as_json(self): - game_state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)}, - known_hosts={IP("192.168.1.2"), IP("192.168.1.3")}, controlled_hosts={IP("192.168.1.2")}, - known_services={IP("192.168.1.3"):{Service("service1", "public", "1.01", True)}}, - known_data={IP("192.168.1.3"):{Data("ChuckNorris", "data1"), Data("ChuckNorris", "data2")}, - IP("192.168.1.2"):{Data("McGiver", "data2", 42, "txt")}}) - game_json = game_state.as_json() - try: - data = json.loads(game_json) - except ValueError: - data = None - assert data is not None - assert {"ip": "1.1.1.1", "mask": 24} in data["known_networks"] - assert {"ip": "192.168.1.3"} in data["known_hosts"] - assert {"ip": "192.168.1.2"} in data["controlled_hosts"] - assert ("192.168.1.3", [{"name": "service1", "type": "public", "version": "1.01", "is_local": True}]) in data["known_services"].items() - assert {"owner": "ChuckNorris", "id": "data1", "size":0, "type":""} in data["known_data"]["192.168.1.3"] - assert {"owner": "ChuckNorris", "id": "data2", "size":0, "type":""} in data["known_data"]["192.168.1.3"] - assert {"owner": "McGiver", "id": "data2", "size":42, "type":"txt"} in data["known_data"]["192.168.1.2"] - - def test_game_state_json_deserialized(self): - game_state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)}, - known_hosts={IP("192.168.1.2"), IP("192.168.1.3")}, controlled_hosts={IP("192.168.1.2")}, - known_services={IP("192.168.1.3"):{Service("service1", "public", "1.01", True)}}, - known_data={IP("192.168.1.3"):{Data("ChuckNorris", "data1"), Data("ChuckNorris", "data2")}, - IP("192.168.1.2"):{Data("McGiver", "data2")}}) - state_json = game_state.as_json() - deserialized_state = GameState.from_json(state_json) - assert game_state is not deserialized_state - assert game_state == deserialized_state - - def test_game_state_as_dict(self): - game_state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)}, - known_hosts={IP("192.168.1.2"), IP("192.168.1.3")}, controlled_hosts={IP("192.168.1.2")}, - known_services={IP("192.168.1.3"):{Service("service1", "public", "1.01", True)}}, - known_data={IP("192.168.1.3"):{Data("ChuckNorris", "data1"), Data("ChuckNorris", "data2")}, - IP("192.168.1.2"):{Data("McGiver", "data2")}}) - game_dict = game_state.as_dict - assert game_dict is not None - assert {"ip": "1.1.1.1", "mask": 24} in game_dict["known_networks"] - assert {"ip": "192.168.1.3"} in game_dict["known_hosts"] - assert {"ip": "192.168.1.2"} in game_dict["controlled_hosts"] - assert ("192.168.1.3", [{"name": "service1", "type": "public", "version": "1.01", "is_local": True}]) in game_dict["known_services"].items() - assert {"owner": "ChuckNorris", "id": "data1", "size":0, "type":""} in game_dict["known_data"]["192.168.1.3"] - assert {"owner": "ChuckNorris", "id": "data2", "size":0, "type":""} in game_dict["known_data"]["192.168.1.3"] - - def test_game_state_from_dict(self): - game_state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)}, - known_hosts={IP("192.168.1.2"), IP("192.168.1.3")}, controlled_hosts={IP("192.168.1.2")}, - known_services={IP("192.168.1.3"):{Service("service1", "public", "1.01", True)}}, - known_data={IP("192.168.1.3"):{Data("ChuckNorris", "data1"), Data("ChuckNorris", "data2")}, - IP("192.168.1.2"):{Data("McGiver", "data2")}}) - game_dict = game_state.as_dict - deserialized_state = GameState.from_dict(game_dict) - assert game_state is not deserialized_state - assert game_state == deserialized_state \ No newline at end of file + assert len(action_dict["parameters"]) == 0 \ No newline at end of file diff --git a/tests/components/test_data.py b/tests/components/test_data.py new file mode 100644 index 00000000..b3946f94 --- /dev/null +++ b/tests/components/test_data.py @@ -0,0 +1,107 @@ +# Authors: Maria Rigaki - maria.rigaki@aic.fel.cvut.cz +# Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz +import pytest +import dataclasses +from AIDojoCoordinator.game_components import Data + +@pytest.fixture +def sample_data_minimal(): + """Fixture to provide a sample Data object with minimal fields""" + return Data(owner="User", id="Password") +@pytest.fixture +def sample_data_minimal_copy(): + """Fixture to provide a sample Data object with minimal fields same as sample_data_minimal""" + return Data(owner="User", id="Password") +@pytest.fixture +def sample_data_minimal2(): + """Fixture to provide a sample Data object with minimal fields different from sample_data_minimal""" + return Data(owner="User2", id="Password") + +@pytest.fixture +def sample_data_all(): + """Fixture to provide a sample Data object with all fields""" + return Data(owner="User", id="Password", size=42, type="txt") + +@pytest.fixture +def sample_data_all_copy(): + """Fixture to provide a sample Data object with all fields same as sample_data_all""" + return Data(owner="User", id="Password", size=42, type="txt") + +@pytest.fixture +def sample_data_all2(): + """Fixture to provide a sample Data object with all fields different from sample_data_all""" + return Data(owner="User2", id="Password", size=42, type="txt") + +def test_create_data_minimal(sample_data_minimal): + """ + Test that the data object is created with ONLY required fields (using default for the rest) + """ + data = sample_data_minimal + assert data.owner == "User" + assert data.id == "Password" + assert data.type == "" + assert data.size == 0 + +def test_create_data_all(sample_data_all): + """ + Test that the data object is created with ALL fields (using default for the rest) + """ + data = sample_data_all + assert data.owner == "User" + assert data.id == "Password" + assert data.type == "txt" + assert data.size == 42 + +def test_data_equal(sample_data_all, sample_data_all_copy, sample_data_minimal, sample_data_minimal_copy): + """ + Test that two data objects with the same required parameters are equal + """ + data = sample_data_all + data2 = sample_data_all_copy + # test equality with all fields used + data3 = sample_data_minimal + data4 = sample_data_minimal_copy + assert data == data2 + assert data3 == data4 + +def test_data_not_equal(sample_data_all, sample_data_all2, sample_data_minimal, sample_data_minimal2): + """ + Test that two data objects with different required parameters are NOT equal + """ + data = sample_data_minimal + data2 = sample_data_minimal2 + data3 = sample_data_all + data4 = sample_data_all2 + assert data != data2 + assert data3 != data4 + +def test_data_hash_equal(sample_data_all, sample_data_all_copy, sample_data_minimal, sample_data_minimal_copy): + """ + Test that the hash of two data objects with the same required parameters is equal + """ + data = sample_data_minimal + data2 = sample_data_minimal_copy + # test equality with all fields used + data3 = sample_data_all + data4 = sample_data_all_copy + assert hash(data) == hash(data2) + assert hash(data3) == hash(data4) + +def test_data_hash_not_equal(sample_data_all, sample_data_all2, sample_data_minimal, sample_data_minimal2): + data = sample_data_minimal + data2 = sample_data_minimal2 + # test equality with all fields used + data3 = sample_data_all + data4 = sample_data_all2 + assert hash(data) != hash(data2) + assert hash(data3) != hash(data4) + +def test_data_from_dict(sample_data_all): + d = dataclasses.asdict(sample_data_all) + data = Data.from_dict(d) + assert isinstance(data, Data) + assert data.owner == "User" + assert data.id == "Password" + assert data.size == 42 + assert data.type == "txt" + assert data == sample_data_all \ No newline at end of file diff --git a/tests/components/test_game_state.py b/tests/components/test_game_state.py new file mode 100644 index 00000000..770efb43 --- /dev/null +++ b/tests/components/test_game_state.py @@ -0,0 +1,303 @@ +# Authors: Maria Rigaki - maria.rigaki@aic.fel.cvut.cz +# Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz +import json +import pytest +from AIDojoCoordinator.game_components import GameState, IP, Network, Data, Service + +# pytest fixtures for creating sample objects +@pytest.fixture +def sample_ip(): + """Fixture to provide a sample IP object""" + return IP("192.168.1.1") + +@pytest.fixture +def sample_ip2(): + """Fixture to provide a sample IP object""" + return IP("192.168.1.2") + +@pytest.fixture +def sample_network(): + """Fixture to provide a sample Network object""" + return Network("192.168.1.0", 24) + +@pytest.fixture +def sample_service(): + """Fixture to provide a sample Service object""" + return Service(name="rdp", type="passive", version="1.067", is_local=True) + +@pytest.fixture +def sample_data(): + """Fixture to provide a sample Data object""" + return Data(owner="User", id="Password", size=42, type="txt") + + +# Test cases for the GameState class +def test_create_game_state(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + """ + Test the correct creation of the GameState class + """ + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:{sample_service}}, + known_data={sample_ip:{sample_data}}, + known_networks={sample_network} + ) + + assert len(game_state.known_hosts) == 2 + assert len(game_state.known_services) == 1 + assert len(game_state.known_data) == 1 + assert len(game_state.known_networks) == 1 + assert sample_ip in game_state.controlled_hosts + assert sample_ip2 in game_state.known_hosts + assert sample_ip in game_state.known_hosts + assert sample_ip in game_state.known_services + assert sample_service in game_state.known_services[sample_ip] + assert sample_ip in game_state.known_data + assert sample_data in game_state.known_data[sample_ip] + assert sample_network in game_state.known_networks + assert isinstance(game_state.known_hosts, set) + assert isinstance(game_state.known_services, dict) + assert isinstance(game_state.known_data, dict) + assert isinstance(game_state.known_networks, set) + assert isinstance(game_state.controlled_hosts, set) + assert isinstance(game_state, GameState) + +def test_create_game_state_empty(): + """ + Test the correct creation of the GameState class with empty parameters + """ + game_state = GameState( + controlled_hosts=set(), + known_hosts=set(), + known_services=dict(), + known_data=dict(), + known_networks=set() + ) + + assert isinstance(game_state.controlled_hosts, set) + assert isinstance(game_state.known_hosts, set) + assert isinstance(game_state.known_services, dict) + assert isinstance(game_state.known_data, dict) + assert isinstance(game_state.known_networks, set) + assert len(game_state.controlled_hosts) == 0 + assert len(game_state.known_hosts) == 0 + assert len(game_state.known_services) == 0 + assert len(game_state.known_data) == 0 + assert len(game_state.known_networks) == 0 + assert isinstance(game_state, GameState) + +def test_state_equal(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + """ + Test that two game states with the same parameters are equal + """ + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:[sample_service]}, + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + + game_state2 = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:[sample_service]}, + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + + assert game_state == game_state2 + assert game_state is not game_state2 # Ensure they are different instances + +def test_state_not_equal_diff_control(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + """ + Test that two game states with diffrent parameters are not equal. + Different controlled hosts. + """ + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:[sample_service]}, + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + + game_state2 = GameState( + controlled_hosts={sample_ip2}, # Different controlled hosts + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:[sample_service]}, + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + + + assert game_state != game_state2 + +def test_state_not_equal_diff_known(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + """ + Test that two game states with diffrent parameters are not equal + Different known hosts. + """ + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:[sample_service]}, + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + + game_state2 = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip}, # Different known hosts + known_services={sample_ip:[sample_service]}, + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + assert game_state != game_state2 + +def test_state_not_equal_diff_data(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + """ + Test that two game states with diffrent parameters are not equal. + Different data. + """ + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:[sample_service]}, + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + + game_state2 = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services=set(), + known_data={}, # Different data + known_networks={sample_network} + ) + assert game_state != game_state2 + +def test_state_not_equal_diff_service(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + """ + Test that two game states with diffrent parameters are not equal. + """ + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:[sample_service]}, + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + + game_state2 = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip2:[sample_service]}, # Different services + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + + assert game_state != game_state2 + + +def test_game_state_as_json(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + """Test the serialization of the GameState class to JSON format""" + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:[sample_service]}, + known_data={sample_ip:[sample_data]}, + known_networks={sample_network} + ) + + game_json = game_state.as_json() + try: + data = json.loads(game_json) + except ValueError: + data = None + # Check if the JSON data is correctly deserialized + assert data is not None + assert isinstance(data, dict) + # Check if the expected keys are present in the JSON data + assert "known_networks" in data + assert "known_hosts" in data + assert "controlled_hosts" in data + assert "known_services" in data + assert "known_data" in data + # Check if the types of the values are correct + assert isinstance(data["known_networks"], list) + assert isinstance(data["known_hosts"], list) + assert isinstance(data["controlled_hosts"], list) + assert isinstance(data["known_services"], dict) + assert isinstance(data["known_data"], dict) + # Check if the values in the JSON data match the original game state + assert {"ip": "192.168.1.0", "mask": 24} in data["known_networks"] + assert {"ip": "192.168.1.1"} in data["known_hosts"] + assert {"ip": "192.168.1.2"} in data["known_hosts"] + assert {"ip": "192.168.1.1"} in data["controlled_hosts"] + assert "192.168.1.1" in data["known_services"] + assert data["known_services"]["192.168.1.1"] == [{"name": "rdp", "type": "passive", "version": "1.067", "is_local": True}] + assert "192.168.1.1" in data["known_data"] + assert data["known_data"]["192.168.1.1"] == [{"owner": "User", "id": "Password", "size": 42, "type": "txt", "content": ""}] + + +def test_game_state_json_deserialized(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:{sample_service}}, + known_data={sample_ip:{sample_data}}, + known_networks={sample_network} + ) + state_json = game_state.as_json() + deserialized_state = GameState.from_json(state_json) + assert game_state is not deserialized_state + assert game_state == deserialized_state + +def test_game_state_as_dict(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:{sample_service}}, + known_data={sample_ip:{sample_data}}, + known_networks={sample_network} + ) + game_dict = game_state.as_dict + # Check if the dictionary is correctly created + assert isinstance(game_dict, dict) + # Check if the expected keys are present in the dictionary + assert "known_networks" in game_dict + assert "known_hosts" in game_dict + assert "controlled_hosts" in game_dict + assert "known_services" in game_dict + assert "known_data" in game_dict + # Check if the types of the values are correct + assert isinstance(game_dict["known_networks"], list) + assert isinstance(game_dict["known_hosts"], list) + assert isinstance(game_dict["controlled_hosts"], list) + assert isinstance(game_dict["known_services"], dict) + assert isinstance(game_dict["known_data"], dict) + # Check if the values in the dictionary match the original game state + assert {"ip": "192.168.1.0", "mask": 24} in game_dict["known_networks"] + assert {"ip": "192.168.1.1"} in game_dict["known_hosts"] + assert {"ip": "192.168.1.2"} in game_dict["known_hosts"] + assert {"ip": "192.168.1.1"} in game_dict["controlled_hosts"] + assert "192.168.1.1" in game_dict["known_services"] + assert game_dict["known_services"]["192.168.1.1"] == [{"name": "rdp", "type": "passive", "version": "1.067", "is_local": True}] + assert "192.168.1.1" in game_dict["known_data"] + assert game_dict["known_data"]["192.168.1.1"] == [{"owner": "User", "id": "Password", "size": 42, "type": "txt", "content": ""}] + + +def test_game_state_from_dict(sample_ip, sample_ip2, sample_network, sample_service, sample_data): + game_state = GameState( + controlled_hosts={sample_ip}, + known_hosts={sample_ip, sample_ip2}, + known_services={sample_ip:{sample_service}}, + known_data={sample_ip:{sample_data}}, + known_networks={sample_network} + ) + game_dict = game_state.as_dict + deserialized_state = GameState.from_dict(game_dict) + assert game_state is not deserialized_state + assert game_state == deserialized_state \ No newline at end of file diff --git a/tests/components/test_ip.py b/tests/components/test_ip.py new file mode 100644 index 00000000..58fc8870 --- /dev/null +++ b/tests/components/test_ip.py @@ -0,0 +1,88 @@ +# Authors: Maria Rigaki - maria.rigaki@aic.fel.cvut.cz +# Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz +import pytest +import dataclasses +from AIDojoCoordinator.game_components import IP + +# Pytest fixtures for creating sample IP objects +@pytest.fixture +def sample_private_ip1(): + """Fixture to provide a sample IP object""" + ip_str = "192.168.1.15" + return IP(ip_str), ip_str + +@pytest.fixture +def sample_private_ip1_copy(): + """Fixture to provide a sample IP object""" + ip_str = "192.168.1.15" + return IP(ip_str), ip_str + +@pytest.fixture +def sample_private_ip2(): + """Fixture to provide a sample IP object different from sample_private_ip1""" + ip_str = "192.168.2.15" + return IP(ip_str), ip_str + +@pytest.fixture +def sample_public_ip(): + """Fixture to provide a sample public IP object""" + ip_str = "8.8.8.8" + return IP(ip_str), ip_str + +# Test cases for the IP class +def test_ip_repr(sample_private_ip1): + """Test the object representation""" + ip_1, ip_str = sample_private_ip1 + assert repr(ip_1) == ip_str + +def test_ip_equal(sample_private_ip1, sample_private_ip1_copy): + """Test that two IP objects with the same IP are equal""" + ip_1, _ = sample_private_ip1 + ip_2, _ = sample_private_ip1_copy + assert ip_1 == ip_2 + +def test_ip_not_equal(sample_private_ip1, sample_private_ip2): + """Test that two IP objects with different IPs are not equal""" + ip_1, _ = sample_private_ip1 + ip_2, _ = sample_private_ip2 + assert ip_1 != ip_2 + +def test_ip_not_str(sample_private_ip1): + """Test that the IP object is not equal to a string""" + ip_1, _ = sample_private_ip1 + ip_2 = "192.168.2.15" + assert ip_1 != ip_2 + +def test_ip_is_private(sample_private_ip1): + ip_1, _ = sample_private_ip1 + assert ip_1.is_private() is True + +def test_ip_is_not_private(sample_public_ip): + ip_1, _ = sample_public_ip + assert ip_1.is_private() is False + +def test_ip_from_dict(sample_private_ip1): + """Test creating an IP object from a dictionary""" + ip_1, ip1_str = sample_private_ip1 + d = dataclasses.asdict(ip_1) + assert isinstance(d, dict) + assert d["ip"] == ip1_str + # Create IP object from dictionary + ip = IP.from_dict(d) + assert isinstance(ip, IP) + assert ip.ip == ip1_str + assert ip == ip_1 + +def test_ip_from_dict_invalid(): + """Test creating an IP object from an invalid dictionary""" + d = {"ip": "invalid_ip"} + try: + _ = IP.from_dict(d) + assert False, "Expected ValueError for invalid IP" + except ValueError: + pass +def test_ip_hash(sample_private_ip1, sample_private_ip1_copy): + """Test that the hash of two IP objects with the same IP is equal""" + ip_1, _ = sample_private_ip1 + ip_2, _ = sample_private_ip1_copy + assert hash(ip_1) == hash(ip_2) \ No newline at end of file diff --git a/tests/components/test_network.py b/tests/components/test_network.py new file mode 100644 index 00000000..5042e3c8 --- /dev/null +++ b/tests/components/test_network.py @@ -0,0 +1,83 @@ +# Authors: Maria Rigaki - maria.rigaki@aic.fel.cvut.cz +# Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz +import pytest +import dataclasses +from AIDojoCoordinator.game_components import Network + +# Pytest fixture for creating a sample Network object +@pytest.fixture +def sample_private_network1(): + """Fixture to provide a sample Network object with private IP""" + return Network("192.168.1.0", 24) +@pytest.fixture +def sample_private_network1_copy(): + """Fixture to provide a sample Network object with private IP""" + return Network("192.168.1.0", 24) + +@pytest.fixture +def sample_private_network2(): + """Fixture to provide a sample Network object with private IP different from sample_private_network1""" + return Network("192.168.2.0", 24) +@pytest.fixture +def sample_public_network(): + """Fixture to provide a sample Network object with public IP""" + return Network("8.8.8.8", 24) + +# Test cases for the Network class +def test_net_creation(sample_private_network1): + """ + Test that the network is created and all elements can be accessed + """ + net = sample_private_network1 + assert net.ip == "192.168.1.0" + assert net.mask == 24 + +def test_net_str(sample_private_network1): + """ + Test the string representaion of the network + """ + net = sample_private_network1 + assert str(net) == "192.168.1.0/24" + +def test_net_repr(sample_private_network1): + """ + Test the repr of the Network + """ + net = sample_private_network1 + assert repr(net) == "192.168.1.0/24" + +def test_net_equal(sample_private_network1, sample_private_network1_copy): + """ + Test that two network objects with the same paramters are equal + """ + net_1 = sample_private_network1 + net_2 = sample_private_network1_copy + assert net_1 == net_2 + +def test_net_not_equal(sample_private_network1, sample_public_network): + """ + Test that two network objects with different paramters are not equal + """ + net_1 = sample_private_network1 + net_2 = sample_public_network + assert net_1 != net_2 + +def test_net_is_not_private(sample_public_network): + net_1 = sample_public_network + assert net_1.is_private() is False + +def test_net_is_private(sample_private_network1): + net_1 = sample_private_network1 + assert net_1.is_private() is True + +def test_net_from_dict(sample_private_network1): + """ + Test creating a Network object from a dictionary + """ + d = dataclasses.asdict(sample_private_network1) + net = Network.from_dict(d) + assert isinstance(net, Network) + assert net.ip == "192.168.1.0" + assert net.mask == 24 + assert net == sample_private_network1 + assert net is not sample_private_network1 \ No newline at end of file diff --git a/tests/components/test_service.py b/tests/components/test_service.py new file mode 100644 index 00000000..69f87e12 --- /dev/null +++ b/tests/components/test_service.py @@ -0,0 +1,51 @@ +# Authors: Maria Rigaki - maria.rigaki@aic.fel.cvut.cz +# Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz +import pytest +import dataclasses +from AIDojoCoordinator.game_components import Service + +# Fixtures for Service objects +@pytest.fixture +def sample_service1(): + """Fixture to provide a sample Service object with minimal fields""" + return Service(name="rdp", type="passive", version="1.067", is_local=True) +@pytest.fixture +def sample_service1_copy(): + """Fixture to provide a sample Service object with minimal fields same as sample_service1""" + return Service(name="rdp", type="passive", version="1.067", is_local=True) +@pytest.fixture +def sample_service2(): + """Fixture to provide a sample Service object with different fields from sample_service1""" + return Service(name="sql", type="passive", version="5.0", is_local=True) + + +# Test cases for Service class +def test_service_creation(sample_service1): + """ + Test that the service is created and all elements can be accessed + """ + assert sample_service1.name == "rdp" + assert sample_service1.type == "passive" + assert sample_service1.version == "1.067" + assert sample_service1.is_local + +def test_services_equal(sample_service1, sample_service1_copy): + """ + Test that two services with the same parameters are equal + """ + assert sample_service1 == sample_service1_copy + assert sample_service1 is not sample_service1_copy + +def test_services_not_equal(sample_service1, sample_service2): + """ + Test that two services with different parameters are not equal + """ + assert sample_service1 != sample_service2 + +def test_service_from_dict(sample_service1): + """ + Test creating a Service object from a dictionary + """ + service_dict = dataclasses.asdict(sample_service1) + service_from_dict = Service.from_dict(service_dict) + assert service_from_dict == sample_service1 \ No newline at end of file diff --git a/tests/coordinator/test_agent_server.py b/tests/coordinator/test_agent_server.py new file mode 100644 index 00000000..99c2f651 --- /dev/null +++ b/tests/coordinator/test_agent_server.py @@ -0,0 +1,302 @@ +# Authors: Ondřej Lukas - ondrej.lukas@aic.fel.cvut.cz +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock +from contextlib import suppress +from AIDojoCoordinator.coordinator import AgentServer +from AIDojoCoordinator.game_components import Action, ActionType, ProtocolConfig + +# ----------------------- +# Fixtures +# ----------------------- + +@pytest.fixture +def mock_writer(): + writer = AsyncMock() + writer.get_extra_info = MagicMock(return_value=('127.0.0.1', 12345)) # ✅ Sync method + writer.write = MagicMock() # ✅ Sync method + writer.drain = AsyncMock() # ✅ Async method + writer.close = AsyncMock() # ✅ Async method + return writer + +@pytest.fixture +def mock_reader_empty(): + reader = AsyncMock() + reader.read = AsyncMock(return_value=b'') # Simulates client disconnect + return reader + +@pytest.fixture +def mock_reader_with_data(): + reader = AsyncMock() + reader.read = AsyncMock(side_effect=[ + b'{"some":"message"}', # first message + b'' # then disconnect + ]) + return reader + +@pytest.fixture +def response_queue(): + q = asyncio.Queue() + q.put_nowait('{"response":"ok"}') + return q + +@pytest.fixture +def agent_server(): + actions_queue = asyncio.Queue() + answers_queues = {} + max_connections = 3 + return AgentServer(actions_queue, answers_queues, max_connections) + +@pytest.fixture +def make_writer_with_peer(): + def _make(ip: str, port: int): + writer = AsyncMock() + writer.get_extra_info = MagicMock(return_value=(ip, port)) # get_extra_info is sync + writer.write = MagicMock() # write is sync + writer.drain = AsyncMock() # drain is async + writer.close = AsyncMock() # close is async + return writer + return _make + +# ----------------------- +# Connection Handling Tests +# ----------------------- + +@pytest.mark.asyncio +async def test_rejects_connection_when_max_connections_reached(agent_server, mock_reader_empty, mock_writer): + agent_server.current_connections = agent_server.max_connections + await agent_server.handle_new_agent(mock_reader_empty, mock_writer) + assert ('127.0.0.1', 12345) not in agent_server.answers_queues + +@pytest.mark.asyncio +async def test_accepts_connection_under_max_connections(agent_server, mock_reader_empty, mock_writer): + peername = ('127.0.0.1', 12345) + mock_writer.get_extra_info = MagicMock(return_value=peername) + + await agent_server.handle_new_agent(mock_reader_empty, mock_writer) + + # incremented and decremented → back to zero + assert agent_server.current_connections == 0 + mock_writer.close.assert_called_once() + assert peername not in agent_server.answers_queues + +@pytest.mark.asyncio +async def test_accepts_multiple_connections_up_to_limit(agent_server, mock_reader_empty, make_writer_with_peer): + for i in range(agent_server.max_connections): + peername = (f'10.0.0.{i}', 1000 + i) + writer = make_writer_with_peer(*peername) + await agent_server.handle_new_agent(mock_reader_empty, writer) + writer.close.assert_called_once() + assert peername not in agent_server.answers_queues + +@pytest.mark.asyncio +async def test_prevents_simultaneous_duplicate_peername_connections(agent_server, make_writer_with_peer): + peername = ('192.168.1.10', 5555) + writer1 = make_writer_with_peer(*peername) + writer2 = make_writer_with_peer(*peername) + + writer1.get_extra_info = MagicMock(return_value=peername) + writer2.get_extra_info = MagicMock(return_value=peername) + + # First reader hangs to simulate a long-lived connection + async def never_read(_=None): + await asyncio.Event().wait() + + reader1 = AsyncMock() + reader2 = AsyncMock() + reader1.read = AsyncMock(side_effect=never_read) + reader2.read = AsyncMock(return_value=b'') + + agent_server.actions_queue = AsyncMock() + + # Start first connection + task1 = asyncio.create_task(agent_server.handle_new_agent(reader1, writer1)) + + # Wait until queue for writer1 is created + for _ in range(100): + await asyncio.sleep(0.01) + if peername in agent_server.answers_queues: + break + else: + task1.cancel() + with suppress(asyncio.CancelledError): + await task1 + pytest.fail("answers_queues was not created in time") + + # Assert queue exists before it's potentially removed + assert list(agent_server.answers_queues.keys()).count(peername) == 1 + + # Start second connection with the same peername + await agent_server.handle_new_agent(reader2, writer2) + + # Assert it was rejected (writer2 should be closed) + writer2.close.assert_called_once() + + # Clean up task1 + task1.cancel() + with suppress(asyncio.CancelledError): + await task1 + + + +# ----------------------- +# Queue Management Tests +# ----------------------- + +@pytest.mark.asyncio +async def test_does_not_create_queue_if_one_exists(agent_server, mock_reader_empty, mock_writer): + peername = ('127.0.0.1', 12345) + mock_writer.get_extra_info = MagicMock(return_value=peername) + preexisting = asyncio.Queue() + agent_server.answers_queues = {peername: preexisting} + + await agent_server.handle_new_agent(mock_reader_empty, mock_writer) + mock_writer.close.assert_called_once() + assert agent_server.answers_queues[peername] is preexisting + +@pytest.mark.asyncio +async def test_handles_missing_queue_on_cleanup_gracefully(agent_server, mock_reader_empty, mock_writer): + peername = ('127.0.0.1', 12345) + mock_writer.get_extra_info = MagicMock(return_value=peername) + agent_server.answers_queues = {} # missing + + await agent_server.handle_new_agent(mock_reader_empty, mock_writer) + mock_writer.close.assert_called_once() + assert peername not in agent_server.answers_queues + +# ----------------------- +# Data Exchange Tests +# ----------------------- + +@pytest.mark.asyncio +async def test_quit_message_is_sent_on_disconnect(agent_server, mock_writer): + peername = ('127.0.0.1', 12345) + mock_writer.get_extra_info = MagicMock(return_value=peername) + + reader = AsyncMock(); reader.read = AsyncMock(return_value=b'') + agent_server.actions_queue = AsyncMock() + agent_server.answers_queues = {} + + await agent_server.handle_new_agent(reader, mock_writer) + + # Queue cleaned + assert peername not in agent_server.answers_queues + agent_server.actions_queue.put.assert_awaited_once() + + (addr, msg) = agent_server.actions_queue.put.call_args[0][0] + assert addr == peername + expected = Action(ActionType.QuitGame, parameters={}).to_json() + assert msg == expected + +@pytest.mark.asyncio +async def test_agent_message_is_placed_in_queue(agent_server, mock_writer): + peername = ('127.0.0.1', 12345) + mock_writer.get_extra_info = MagicMock(return_value=peername) + + action = Action( + ActionType.FindServices, + parameters={"source_host": "10.0.0.1", "target_host": "10.0.0.2"} + ).to_json() + + reader = AsyncMock() + reader.read = AsyncMock(side_effect=[action.encode(), b'']) + + agent_server.actions_queue = AsyncMock() + agent_server.answers_queues = {} + + task = asyncio.create_task(agent_server.handle_new_agent(reader, mock_writer)) + + # wait for queue creation + for _ in range(100): + await asyncio.sleep(0.01) + if peername in agent_server.answers_queues: + break + else: + task.cancel() + pytest.fail("answers_queues was not created in time") + + await agent_server.answers_queues[peername].put("dummy-response") + await task + + assert agent_server.actions_queue.put.await_count >= 1 + (addr, msg) = agent_server.actions_queue.put.call_args_list[0][0][0] + assert addr == peername + assert msg == action + +@pytest.mark.asyncio +async def test_answer_queue_response_is_sent_to_agent(agent_server, mock_writer): + peername = ('127.0.0.1', 12345) + mock_writer.get_extra_info = MagicMock(return_value=peername) + + action = Action( + ActionType.FindServices, + parameters={"source_host": "10.0.0.1", "target_host": "10.0.0.2"} + ).to_json() + + reader = AsyncMock() + reader.read = AsyncMock(side_effect=[action.encode(), b'']) + + agent_server.actions_queue = AsyncMock() + agent_server.answers_queues = {} + mock_writer.write = MagicMock() + mock_writer.drain = AsyncMock() + + task = asyncio.create_task(agent_server.handle_new_agent(reader, mock_writer)) + + for _ in range(100): + await asyncio.sleep(0.01) + if peername in agent_server.answers_queues: + break + else: + task.cancel() + pytest.fail("answers_queues was not created in time") + + response = '{"response":"ok"}' + await agent_server.answers_queues[peername].put(response) + await task + + delimiter = getattr(ProtocolConfig, "END_OF_MESSAGE", b"\n") + expected = response.encode() + delimiter + mock_writer.write.assert_any_call(expected) + mock_writer.drain.assert_called_once() + +# ----------------------- +# Error Handling Tests +# ----------------------- + +@pytest.mark.asyncio +async def test_cancelled_error_cleanup(agent_server, mock_writer): + peername = ('127.0.0.1', 12345) + mock_writer.get_extra_info = MagicMock(return_value=peername) + mock_writer.close = AsyncMock() + mock_writer.wait_closed = AsyncMock() + reader = AsyncMock() + reader.read = AsyncMock(side_effect=asyncio.CancelledError()) + + agent_server.actions_queue = AsyncMock() + agent_server.answers_queues = {} + + with pytest.raises(asyncio.CancelledError): + await agent_server.handle_new_agent(reader, mock_writer) + + assert peername not in agent_server.answers_queues + mock_writer.close.assert_called_once() + mock_writer.wait_closed.assert_awaited_once() + +@pytest.mark.asyncio +async def test_unexpected_exception_cleanup(agent_server, mock_writer): + peername = ('127.0.0.1', 12345) + mock_writer.get_extra_info = MagicMock(return_value=peername) + + reader = AsyncMock() + reader.read = AsyncMock(side_effect=[b'{"some":"data"}', b'']) + + agent_server.actions_queue = AsyncMock() + agent_server.actions_queue.put.side_effect = Exception("Unexpected error") + agent_server.answers_queues = {} + + with pytest.raises(Exception, match="Unexpected error"): + await agent_server.handle_new_agent(reader, mock_writer) + + assert peername not in agent_server.answers_queues + mock_writer.close.assert_called_once() \ No newline at end of file diff --git a/tests/coordinator/test_coordinator_core.py b/tests/coordinator/test_coordinator_core.py new file mode 100644 index 00000000..7d435bb5 --- /dev/null +++ b/tests/coordinator/test_coordinator_core.py @@ -0,0 +1,376 @@ +# Authors: Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz +import asyncio +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from types import SimpleNamespace + +from AIDojoCoordinator.coordinator import GameCoordinator +from AIDojoCoordinator.game_components import ActionType, Action, AgentStatus, GameState, Observation, GameStatus + +# ----------------------- +# Fixtures +# ----------------------- +@pytest.fixture +def empty_game_state(): + """Fixture of empty game state.""" + return GameState( + known_networks={}, + known_services={}, + known_hosts={}, + known_data={}, + known_blocks={}, + controlled_hosts= {}, + ) +@pytest.fixture +def empty_observation(empty_game_state): + """Fixture of empty observation.""" + return Observation( + state=empty_game_state, + reward=0, + end=False, + info={}, + ) + +@pytest.fixture +def test_config_file_path(): + # Path to your local test config file (adjust as needed) + return "tests/netsecenv-task-for-testing.yaml" + +@pytest.fixture +def gc_with_test_config(test_config_file_path): + return GameCoordinator( + game_host="localhost", + game_port=9999, + service_host=None, # force local config loading + service_port=0, + allowed_roles=["Attacker", "Defender", "Benign"], + task_config_file=test_config_file_path, + ) + + +@pytest.fixture +def initialized_coordinator(gc_with_test_config): + gc_with_test_config._starting_positions_per_role = {"Attacker": MagicMock()} + gc_with_test_config._goal_description_per_role = {"Attacker": "Achieve goal"} + gc_with_test_config._steps_limit_per_role = {"Attacker": 100} + gc_with_test_config._CONFIG_FILE_HASH = "dummyhash" + gc_with_test_config._min_required_players = 1 + gc_with_test_config._agent_status = {} + gc_with_test_config._rewards = {"step": 0, "success": 10, "failure": -10} + return gc_with_test_config + +@pytest.fixture +def mock_writer(): + writer = AsyncMock() + writer.get_extra_info.return_value = ("127.0.0.1", 12345) + return writer + +@pytest.fixture +def mock_reader_empty(): + reader = AsyncMock() + reader.read = AsyncMock(return_value=b"") # Simulate client disconnect + return reader + +@pytest.fixture +def agent_server(): + """Fixture for a mock agent server.""" + return GameCoordinator( + game_host="localhost", + game_port=9999, + service_host=None, + service_port=0, + allowed_roles=["Attacker", "Defender", "Benign"], + task_config_file=None, + ) + +@pytest.fixture +def make_writer_with_peer(): + def _make(ip: str, port: int): + writer = AsyncMock() + writer.get_extra_info.return_value = (ip, port) + return writer + return _make + +# ----------------------- +# GameCoordinator Tests (Config-related) +# ----------------------- + +@pytest.mark.asyncio +async def test_load_initialization_objects_loads_config(gc_with_test_config): + """Test that loading initialization objects sets up config and cyst objects.""" + gc_with_test_config._load_initialization_objects() + assert gc_with_test_config._cyst_objects is not None + assert hasattr(gc_with_test_config, "_CONFIG_FILE_HASH") + +def test_convert_msg_dict_to_json_success(gc_with_test_config): + """Test that convert_msg_dict_to_json correctly serializes a dictionary.""" + msg = {"foo": "bar"} + json_str = gc_with_test_config.convert_msg_dict_to_json(msg) + assert json_str == '{"foo": "bar"}' + + +def test_convert_msg_dict_to_json_failure(gc_with_test_config): + """Test that convert_msg_dict_to_json raises TypeError for unserializable objects.""" + class Unserializable: + pass + + with pytest.raises(TypeError): + gc_with_test_config.convert_msg_dict_to_json({"bad": Unserializable()}) + + +@pytest.mark.asyncio +async def test_create_agent_queue_adds_new_queue(gc_with_test_config): + """Test that create_agent_queue adds a new queue for the agent.""" + agent = ("127.0.0.1", 12345) + await gc_with_test_config.create_agent_queue(agent) + assert agent in gc_with_test_config._agent_response_queues + assert isinstance(gc_with_test_config._agent_response_queues[agent], asyncio.Queue) + + +@pytest.mark.asyncio +async def test_create_agent_queue_idempotent(gc_with_test_config): + """Test that create_agent_queue does not create a new queue if it already exists.""" + agent = ("127.0.0.1", 12345) + await gc_with_test_config.create_agent_queue(agent) + q1 = gc_with_test_config._agent_response_queues[agent] + await gc_with_test_config.create_agent_queue(agent) + q2 = gc_with_test_config._agent_response_queues[agent] + assert q1 is q2 + + +def test_load_initialization_objects(gc_with_test_config): + """Test that _load_initialization_objects initializes config and cyst objects.""" + gc_with_test_config._load_initialization_objects() + assert gc_with_test_config._cyst_objects is not None + assert hasattr(gc_with_test_config, "_CONFIG_FILE_HASH") + + +def test_get_starting_position_per_role(gc_with_test_config): + """Test that _get_starting_position_per_role returns positions for all roles.""" + gc_with_test_config._load_initialization_objects() + positions = gc_with_test_config._get_starting_position_per_role() + assert set(positions.keys()) == set(gc_with_test_config.ALLOWED_ROLES) + + +def test_get_goal_description_per_role(gc_with_test_config): + """Test that _get_goal_description_per_role returns descriptions for all roles.""" + gc_with_test_config._load_initialization_objects() + desc = gc_with_test_config._get_goal_description_per_role() + assert set(desc.keys()) == set(gc_with_test_config.ALLOWED_ROLES) + + +def test_get_win_condition_per_role(gc_with_test_config): + """Test that _get_win_condition_per_role returns win conditions for all roles.""" + gc_with_test_config._load_initialization_objects() + win = gc_with_test_config._get_win_condition_per_role() + assert set(win.keys()) == set(gc_with_test_config.ALLOWED_ROLES) + + +def test_get_max_steps_per_role(gc_with_test_config): + """Test that _get_max_steps_per_role returns max steps for all roles.""" + gc_with_test_config._load_initialization_objects() + steps = gc_with_test_config._get_max_steps_per_role() + assert isinstance(steps, dict) + # values can be int or None + assert all(isinstance(v, int) or v is None for v in steps.values()) + + +@pytest.mark.asyncio +async def test_shutdown_signal_handler_sets_flag(gc_with_test_config): + """Test that shutdown_signal_handler sets the shutdown flag.""" + assert not gc_with_test_config.shutdown_flag.is_set() + await gc_with_test_config.shutdown_signal_handler() + assert gc_with_test_config.shutdown_flag.is_set() + + +@pytest.mark.asyncio +async def test_spawn_task_registers_task(gc_with_test_config): + """Test that _spawn_task registers the task in _tasks.""" + async def dummy(): + await asyncio.sleep(0.01) + + task = gc_with_test_config._spawn_task(dummy) + assert task in gc_with_test_config._tasks + + await task # Make sure task completes + + +@pytest.mark.asyncio +@pytest.mark.parametrize("action_type", [ + ActionType.QuitGame, + ActionType.ResetGame, + ActionType.FindData, + ActionType.ExfiltrateData, + ActionType.BlockIP, + ActionType.ExploitService, +]) +async def test_run_game_spawns_expected_action_tasks(gc_with_test_config, action_type): + """Test that run_game spawns tasks for different action types.""" + dummy_action = MagicMock() + dummy_action.type = action_type + dummy_json = json.dumps({"type": action_type.value}) + + # Put real test message + gc_with_test_config._agent_action_queue.put_nowait((("127.0.0.1", 9999), dummy_json)) + + # Patch Action.from_json to return our dummy + with patch.object(Action, "from_json", return_value=dummy_action): + with patch.object(gc_with_test_config, "_spawn_task") as spawn_mock: + + async def stop_soon(): + await asyncio.sleep(0.01) + gc_with_test_config.shutdown_flag.set() + # Poison pill: unblock .get() after shutdown + gc_with_test_config._agent_action_queue.put_nowait((("0.0.0.0", 0), None)) + + stopper = asyncio.create_task(stop_soon()) + + await gc_with_test_config.run_game() + await stopper + + spawn_mock.assert_called_once() + assert spawn_mock.call_args[0][0].__name__.startswith("_process_") + +# ----------------------- +# GameCoordinator Tests (Action Processing) +# ----------------------- +@pytest.mark.asyncio +async def test_process_join_game_action_success(initialized_coordinator): + """Test that _process_join_game_action successfully processes a join game action.""" + agent = ("127.0.0.1", 5555) + await initialized_coordinator.create_agent_queue(agent) + + # Minimal working state + initialized_coordinator._starting_positions_per_role = {"Attacker": MagicMock()} + initialized_coordinator._goal_description_per_role = {"Attacker": "Goal"} + initialized_coordinator._steps_limit_per_role = {"Attacker": 10} + initialized_coordinator._CONFIG_FILE_HASH = "abc123" + initialized_coordinator._min_required_players = 1 + initialized_coordinator._agent_status = {agent: MagicMock()} + initialized_coordinator._episode_start_event.set() # Prevent wait + + action = MagicMock() + action.parameters = {"agent_info": MagicMock(name="AgentX", role="Attacker")} + observation = SimpleNamespace( + state=SimpleNamespace(as_dict={}), # empty dict works here + reward=0, + end=False, + info={} + ) + + with patch.object(initialized_coordinator, "register_agent", new_callable=AsyncMock, return_value=MagicMock()), \ + patch.object(initialized_coordinator, "_initialize_new_player", return_value=observation), \ + patch.object(initialized_coordinator.logger, "info"), \ + patch.object(initialized_coordinator.logger, "debug"): + await initialized_coordinator._process_join_game_action(agent, action) + assert agent in initialized_coordinator.agents + assert not initialized_coordinator._agent_response_queues[agent].empty() + +@pytest.mark.asyncio +async def test_process_quit_game_action_removal(initialized_coordinator, empty_game_state, empty_observation): + """Test that _process_quit_game_action removes an agent correctly.""" + agent = ("127.0.0.1", 5555) + initialized_coordinator._agent_states[agent] = empty_game_state + initialized_coordinator._agent_observations[agent] = empty_observation + + with patch.object(initialized_coordinator, "remove_agent", new_callable=AsyncMock) as remove_mock, \ + patch.object(initialized_coordinator, "_remove_agent_from_game", new_callable=AsyncMock) as remove_game_mock, \ + patch.object(initialized_coordinator.logger, "info") as log_info, \ + patch.object(initialized_coordinator.logger, "debug") as log_debug: + + await initialized_coordinator._process_quit_game_action(agent) + + remove_mock.assert_awaited_once_with(agent, initialized_coordinator._agent_states[agent]) + remove_game_mock.assert_awaited_once_with(agent) + log_info.assert_any_call(f"Agent {agent} removed from the game. {remove_game_mock.return_value}") + log_debug.assert_any_call(f"Cleaning up after QuitGame for {agent}.") + +@pytest.mark.asyncio +async def test_process_reset_game_action_sets_flag(initialized_coordinator, empty_observation): + """Test that _process_reset_game_action sets the reset flag""" + agent = ("127.0.0.1", 5555) + initialized_coordinator._reset_requests = {agent: False} + initialized_coordinator._agent_observations[agent] = empty_observation + initialized_coordinator._episode_start_event.set() + initialized_coordinator._goal_description_per_role = {"Attacker": "Goal"} + initialized_coordinator._steps_limit_per_role = {"Attacker": 10} + initialized_coordinator.agents[agent] = ("name", "Attacker") + initialized_coordinator._agent_trajectories[agent] = [1, 2, 3] + initialized_coordinator._CONFIG_FILE_HASH = "hash" + initialized_coordinator._reset_trajectory = lambda x: [] + + await initialized_coordinator.create_agent_queue(agent) + reset_action = MagicMock() + reset_action.parameters = {"request_trajectory": True} + + with patch.object(initialized_coordinator.logger, "debug"): + async def trigger_reset_done(): + await asyncio.sleep(0.01) + async with initialized_coordinator._reset_done_condition: + initialized_coordinator._reset_done_condition.notify_all() + + stopper = asyncio.create_task(trigger_reset_done()) + await initialized_coordinator._process_reset_game_action(agent, reset_action) + await stopper + assert not initialized_coordinator._agent_response_queues[agent].empty() + +@pytest.mark.asyncio +async def test_process_game_action_episode_ended(initialized_coordinator, empty_game_state): + agent = ("127.0.0.1", 5555) + action = Action(action_type = ActionType.FindData, parameters={}) # or any game action type + + # Setup state indicating episode ended for the agent + initialized_coordinator._episode_ends = {agent: True} + initialized_coordinator.agents = {agent: ("AgentName", "Attacker")} + initialized_coordinator._agent_observations = {agent: Observation(empty_game_state, reward=5, end=True, info={})} + initialized_coordinator._agent_rewards = {agent: 5} + initialized_coordinator._agent_status = {agent: AgentStatus.TimeoutReached} + await initialized_coordinator.create_agent_queue(agent) + + # Call the method + await initialized_coordinator._process_game_action(agent, action) + + # Check response queue got a message with FORBIDDEN status + msg_json = await initialized_coordinator._agent_response_queues[agent].get() + assert '"status": "' + str(GameStatus.FORBIDDEN) + '"' in msg_json + assert "Episode ended" in msg_json + + +@pytest.mark.asyncio +async def test_process_game_action_ongoing_episode(initialized_coordinator, empty_game_state): + agent = ("127.0.0.1", 5555) + action = Action(action_type = ActionType.FindData, parameters={}) # or any game action type + + # Setup state indicating episode ongoing for the agent + initialized_coordinator._episode_ends = {agent: False} + initialized_coordinator.agents = {agent: ("AgentName", "Attacker")} + initialized_coordinator._agent_states = {agent: empty_game_state} + initialized_coordinator._agent_last_action = {agent: None} + initialized_coordinator._agent_steps = {agent: 0} + initialized_coordinator._agent_status = {agent: AgentStatus.Playing} + initialized_coordinator._agent_rewards = {agent: 0} + initialized_coordinator._agent_observations = {agent: Observation(empty_game_state, reward=0, end=False, info={})} + await initialized_coordinator.create_agent_queue(agent) + + # Mocks and patches + initialized_coordinator.step = AsyncMock(return_value=empty_game_state) + initialized_coordinator._update_agent_status = MagicMock(return_value=AgentStatus.Playing) + initialized_coordinator._update_agent_episode_end = MagicMock(return_value=False) + initialized_coordinator._add_step_to_trajectory = MagicMock() + initialized_coordinator._episode_end_event.clear() + initialized_coordinator._episode_rewards_condition = asyncio.Condition() + initialized_coordinator._agents_lock = asyncio.Lock() + + # Call the method + await initialized_coordinator._process_game_action(agent, action) + + # Check that step was called with expected params + initialized_coordinator.step.assert_awaited_with(agent_id=agent, agent_state=empty_game_state, action=action) + + # Check response queue got a message with OK status + msg_json = await initialized_coordinator._agent_response_queues[agent].get() + assert '"status": "' + str(GameStatus.OK) + '"' in msg_json + assert '"reward": 0' in msg_json + assert '"end": false' in msg_json + assert '"info": {}' in msg_json \ No newline at end of file diff --git a/tests/test_global_defender.py b/tests/coordinator/test_global_defender.py similarity index 100% rename from tests/test_global_defender.py rename to tests/coordinator/test_global_defender.py diff --git a/tests/manual/three_nets/test_three_net_scenario.py b/tests/manual/three_nets/manual_test_three_net_scenario.py similarity index 100% rename from tests/manual/three_nets/test_three_net_scenario.py rename to tests/manual/three_nets/manual_test_three_net_scenario.py diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 0ab7c851..0256c908 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -3,13 +3,13 @@ # run all unit tests, -n *5 means distribute tests on 5 different process # -s to see print statements as they are executed -#python3 -m pytest tests/test_actions.py -p no:warnings -vvvv -s --full-trace -python3 -m pytest tests/test_components.py -p no:warnings -vvvv -s --full-trace -python3 -m pytest tests/test_game_coordinator.py -p no:warnings -vvvv -s --full-trace -python3 -m pytest tests/test_global_defender.py -p no:warnings -vvvv -s --full-trace -#python3 -m pytest tests/test_coordinator.py -p no:warnings -vvvv -s --full-trace +# #python3 -m pytest tests/test_actions.py -p no:warnings -vvvv -s --full-trace +# python3 -m pytest tests/test_components.py -p no:warnings -vvvv -s --full-trace +# python3 -m pytest tests/test_game_coordinator.py -p no:warnings -vvvv -s --full-trace +# python3 -m pytest tests/test_global_defender.py -p no:warnings -vvvv -s --full-trace -# run ruff check as well -echo "Running RUFF check: in ${PWD}" -ruff check --output-format=github --select=E9,F4,F6,F7,F8,N8 --ignore=F405 --target-version=py310 --line-length=120 . +# # run ruff check as well +# echo "Running RUFF check: in ${PWD}" +# ruff check --output-format=github --select=E9,F4,F6,F7,F8,N8 --ignore=F405 --target-version=py310 --line-length=120 --exclude NetSecGameAgents . +pytest && ruff check . \ No newline at end of file diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py deleted file mode 100644 index 965f4e16..00000000 --- a/tests/test_coordinator.py +++ /dev/null @@ -1,396 +0,0 @@ -# from coordinator import Coordinator, AgentStatus -# import pytest -# import queue -# import asyncio - -# CONFIG_FILE = "tests/netsecenv-task-for-testing.yaml" -# ALLOWED_ROLES = ["Attacker", "Defender", "Benign"] - -# import sys -# from os import path - -# sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) -# from env.game_components import Action, ActionType, AgentInfo, Network, IP, GameState, Service, Data - - -# @pytest.fixture -# async def coordinator_init(): -# """Initialize Coordinator instance for tests.""" -# actions = asyncio.Queue() -# answers = {} -# world_requests = asyncio.Queue() -# world_responses = asyncio.Queue() - -# coord = Coordinator( -# actions, answers, world_requests, world_responses, CONFIG_FILE, ALLOWED_ROLES -# ) -# return coord - -# @pytest.fixture -# async def coordinator_registered_player(coordinator_init): -# """Register a player with the Coordinator.""" -# coord = coordinator_init -# registration = Action( -# ActionType.JoinGame, -# params={"agent_info": AgentInfo(name="mari", role="Attacker")}, -# ) - -# # Process join action asynchronously -# result = await coord._process_join_game_action( -# agent_addr=("192.168.1.1", "3300"), -# action=registration, -# ) -# return coord, result -# class TestCoordinator: - -# @pytest.mark.asyncio -# async def test_class_init(): -# actions = asyncio.Queue() -# answers = {} -# world_requests = asyncio.Queue() -# world_responses = asyncio.Queue() - -# coord = Coordinator(actions, answers, world_requests, world_responses, CONFIG_FILE, ALLOWED_ROLES) - -# assert coord.ALLOWED_ROLES == ALLOWED_ROLES -# assert coord.agents == {} -# assert coord._agent_steps == {} -# assert coord._reset_requests == {} -# assert coord._agent_starting_position == {} -# assert coord._agent_observations == {} -# assert coord._agent_states == {} -# assert coord._agent_rewards == {} -# assert coord._agent_statuses == {} -# assert isinstance(coord._actions_queue, asyncio.Queue) -# assert isinstance(coord._answers_queues, dict) -# assert isinstance(coord._world_action_queue, asyncio.Queue) -# assert not isinstance(coord._world_response_queue, asyncio.Queue) - -# @pytest.mark.asyncio -# async def test_initialize_new_player(self, coordinator_init): -# coord = coordinator_init -# agent_addr = ("1.1.1.1", "4242") -# agent_name = "TestAgent" -# agent_role = "Attacker" -# new_obs = coord._initialize_new_player(agent_addr, agent_name, agent_role) - -# assert agent_addr in coord.agents -# assert coord.agents[agent_addr] == (agent_name, agent_role) -# assert coord._agent_steps[agent_addr] == 0 -# assert not coord._reset_requests[agent_addr] -# assert coord._agent_statuses[agent_addr] == AgentStatus.PlayingActive - -# assert new_obs.reward == 0 -# assert new_obs.end is False -# assert new_obs.info == {} - -# def test_join(self, coordinator_init): -# coord = coordinator_init - -# registration = Action( -# ActionType.JoinGame, -# params={"agent_info": AgentInfo(name="mari", role="Attacker")}, -# ) - -# result = coord._process_join_game_action( -# agent_addr=("192.168.1.1", "3300"), -# action=registration, -# ) -# assert result["to_agent"] == ("192.168.1.1", "3300") -# assert result["status"] == "GameStatus.CREATED" -# assert "max_steps" in result["message"].keys() -# assert "goal_description" in result["message"].keys() -# assert not result["observation"]["end"] -# assert "configuration_hash" in result["message"].keys() - -# # def test_reset(self, coordinator_registered_player): -# # coord, _ = coordinator_registered_player -# # result = coord._process_reset_game_action(("192.168.1.1", "3300")) - -# # assert result["to_agent"] == ("192.168.1.1", "3300") -# # assert "Resetting" in result["message"]["message"] -# # assert "max_steps" in result["message"].keys() -# # assert "goal_description" in result["message"].keys() -# # assert result["status"] == "GameStatus.OK" - -# # assert coord._agent_steps[("192.168.1.1", "3300")] == 0 -# # assert coord._agent_goal_reached[("192.168.1.1", "3300")] is False -# # assert coord._agent_episode_ends[("192.168.1.1", "3300")] is False -# # assert coord._reset_requests[("192.168.1.1", "3300")] is False - -# def test_generic_action(self, coordinator_registered_player): -# coord, init_result = coordinator_registered_player -# action = Action( -# ActionType.ScanNetwork, -# params={ -# "source_host": IP("192.168.2.2"), -# "target_network": Network("192.168.1.0", 24), -# }, -# ) -# result = coord._process_generic_action(("192.168.1.1", "3300"), action) - -# assert result["to_agent"] == ("192.168.1.1", "3300") -# assert result["status"] == "GameStatus.OK" -# assert init_result["observation"]["state"] != result["observation"]["state"] - -# def test_check_goal_valid(self, coordinator_init): -# game_state = GameState( -# controlled_hosts=[IP("1.1.1.1"), IP("1.1.1.2")], -# known_hosts=[IP("1.1.1.1"), IP("1.1.1.2"), IP("1.1.1.3"), IP("1.1.1.4")], -# known_services={ -# IP("1.1.1.1"):[Service("test_service1", "passive", "1.01", is_local=False)] -# }, -# known_data={ -# IP("1.1.1.1"):[Data("Joe Doe", "password", 10, "txt")] -# }, -# known_networks=[Network("1.1.1.1","24")], -# known_blocks={} - -# ) -# win_conditions = { -# "known_networks":[], -# "known_hosts":[IP("1.1.1.2")], -# "controlled_hosts":[IP("1.1.1.1")], -# "known_services":{ -# IP("1.1.1.1"):[Service("test_service1", "passive", "1.01", is_local=False)], -# }, -# "known_data":{ - -# }, -# "known_blocks":{} -# } - -# assert coordinator_init._check_goal(game_state, win_conditions) is True - -# def test_check_goal_invalid(self, coordinator_init): -# game_state = GameState( -# controlled_hosts=[IP("1.1.1.1"), IP("1.1.1.2")], -# known_hosts=[IP("1.1.1.1"), IP("1.1.1.2"), IP("1.1.1.3"), IP("1.1.1.4")], -# known_services={ -# IP("1.1.1.1"):[Service("test_service1", "passive", "1.01", is_local=False)] -# }, -# known_data={ -# IP("1.1.1.1"):[Data("Joe Doe", "password", 10, "txt")] -# }, -# known_networks=[Network("1.1.1.1","24")], -# known_blocks={} -# ) -# win_conditions = { -# "known_networks":[], -# "known_hosts":[IP("1.1.1.5")], -# "controlled_hosts":[IP("1.1.1.1")], -# "known_services":{ -# IP("1.1.1.1"):[Service("test_service1", "passive", "1.01", is_local=False)], -# }, -# "known_data":{ - -# }, -# "known_blocks":{} -# } - -# assert coordinator_init._check_goal(game_state, win_conditions) is False - -# def test_check_goal_empty(self, coordinator_init): -# game_state = GameState( -# controlled_hosts=[IP("1.1.1.1"), IP("1.1.1.2")], -# known_hosts=[IP("1.1.1.1"), IP("1.1.1.2"), IP("1.1.1.3"), IP("1.1.1.4")], -# known_services={ -# IP("1.1.1.1"):[Service("test_service1", "passive", "1.01", is_local=False)] -# }, -# known_data={ -# IP("1.1.1.1"):[Data("Joe Doe", "password", 10, "txt")] -# }, -# known_networks=[Network("1.1.1.1","24")], -# known_blocks={} -# ) -# win_conditions = { -# "known_networks":[], -# "known_hosts":[], -# "controlled_hosts":[], -# "known_services":{}, -# "known_data":{}, -# "known_blocks":{} -# } -# assert coordinator_init._check_goal(game_state, win_conditions) is True - -# def test_timeout(self, coordinator_registered_player): -# coord, init_result = coordinator_registered_player -# action = Action( -# ActionType.ScanNetwork, -# params={ -# "source_host": IP("192.168.2.2"), -# "target_network": Network("192.168.1.0", 24), -# }, -# ) -# result = init_result -# for _ in range(15): -# result = coord._process_generic_action(("192.168.1.1", "3300"), action) -# assert result["to_agent"] == ("192.168.1.1", "3300") -# assert result["status"] == "GameStatus.OK" -# assert init_result["observation"]["state"] != result["observation"]["state"] -# assert coord._agent_steps[("192.168.1.1", "3300")] == 15 -# assert coord._agent_statuses[("192.168.1.1", "3300")] == "max_steps" -# assert result["observation"]["end"] -# assert result["observation"]["info"]["end_reason"] == "max_steps" - - -import pytest -from unittest.mock import AsyncMock, MagicMock -from AIDojoCoordinator.coordinator import Coordinator, AgentStatus, Action, ActionType -from AIDojoCoordinator.game_components import AgentInfo, Network, IP - -CONFIG_FILE = "tests/netsecenv-task-for-testing.yaml" -ALLOWED_ROLES = ["Attacker", "Defender", "Benign"] - - -@pytest.fixture -def coordinator_init(): - """Initialize the Coordinator instance.""" - actions_queue = MagicMock() - answers_queues = {} - coord = Coordinator( - actions_queue, - answers_queues, - CONFIG_FILE, - ALLOWED_ROLES, - ) - return coord - - -@pytest.mark.asyncio -async def test_agent_joining_game(coordinator_init): - """Test agent successfully joining the game.""" - coord = coordinator_init - - action = Action( - ActionType.JoinGame, - params={"agent_info": AgentInfo(name="TestAgent", role="Attacker")}, - ) - agent_addr = ("192.168.1.1", "3300") - - # Mock the world reset - coord._world.reset = AsyncMock(return_value=None) - coord._world.update_goal_dict = MagicMock(return_value={}) - coord._world.update_goal_descriptions = MagicMock(return_value={}) - coord._world.create_state_from_view = MagicMock(return_value={}) - - await coord._process_join_game_action(agent_addr, action) - - assert agent_addr in coord.agents - assert coord.agents[agent_addr] == ("TestAgent", "Attacker") - assert coord._agent_statuses[agent_addr] == AgentStatus.JoinRequested - -@pytest.mark.asyncio -async def test_agent_playing_scan_network_with_mocking(coordinator_init): - """Test an agent performing the ScanNetwork action with mocked queue interactions.""" - # Arrange - coord = coordinator_init - - # Mock agent details - agent_addr = ("192.168.1.1", "3300") - agent_name = "TestAgent" - agent_role = "Attacker" - coord.agents[agent_addr] = (agent_name, agent_role) - coord._agent_statuses[agent_addr] = AgentStatus.Playing - coord._agent_states[agent_addr] = MagicMock() # Mocked GameState - coord._agent_rewards[agent_addr] = None # Initialize the reward to avoid KeyError - - # Create the ScanNetwork action - action = Action( - ActionType.ScanNetwork, - params={ - "source_host": IP("192.168.2.2"), - "target_network": Network("192.168.1.0", 24), - }, - ) - - # Mock the action queue - coord._actions_queue.get = AsyncMock(return_value=(agent_addr, action.as_json())) - coord._world_action_queue.put = AsyncMock() - coord._answers_queues[agent_addr] = AsyncMock() # Mock agent's answer queue - - # Mock `_world._rewards` to provide reward values - coord._world = MagicMock() - coord._world._rewards = {"goal": 10, "detection": -5, "step": 1} - - # Act - agent_addr, message = await coord._actions_queue.get() - action = Action.from_json(message) - await coord._process_generic_action(agent_addr, action) - - # Assert - coord._world_action_queue.put.assert_called_once_with( - (agent_addr, action, coord._agent_states[agent_addr]) - ) - coord._answers_queues[agent_addr].put.assert_not_called() # No immediate response expected - assert coord._agent_statuses[agent_addr] == AgentStatus.Playing - assert coord._agent_rewards[agent_addr] is None # No end rewards assigned yet - -@pytest.mark.asyncio -async def test_agent_playing_scan_network(coordinator_init): - """Test agent performing a scan network action.""" - coord = coordinator_init - - # Set up agent in the game - agent_addr = ("192.168.1.1", "3300") - coord.agents[agent_addr] = ("TestAgent", "Attacker") - coord._agent_statuses[agent_addr] = AgentStatus.Playing - coord._agent_states[agent_addr] = MagicMock() # Mock game state - - action = Action( - ActionType.ScanNetwork, - params={ - "source_host": IP("192.168.2.2"), - "target_network": Network("192.168.1.0", 24), - }, - ) - - # Mock the world action queue - coord._world_action_queue.put = AsyncMock() - - # Call the method under test - await coord._process_generic_action(agent_addr, action) - - # Assertions - coord._world_action_queue.put.assert_called_once_with( - (agent_addr, action, coord._agent_states[agent_addr]) - ) - assert coord._agent_statuses[agent_addr] == AgentStatus.Playing - - -@pytest.mark.asyncio -async def test_agent_requesting_reset(coordinator_init): - """Test agent requesting a reset.""" - coord = coordinator_init - - # Set up agent in the game - agent_addr = ("192.168.1.1", "3300") - coord.agents[agent_addr] = ("TestAgent", "Attacker") - coord._reset_requests[agent_addr] = False - - action = Action(ActionType.ResetGame, params={}) - coord._world.reset = AsyncMock(return_value=None) - - await coord._process_generic_action(agent_addr, action) - - assert coord._reset_requests[agent_addr] is True - coord._world_action_queue.put.assert_called_with(("world", action, None)) - - -@pytest.mark.asyncio -async def test_agent_leaving_game(coordinator_init): - """Test agent leaving the game.""" - coord = coordinator_init - - # Set up agent in the game - agent_addr = ("192.168.1.1", "3300") - coord.agents[agent_addr] = ("TestAgent", "Attacker") - coord._agent_statuses[agent_addr] = AgentStatus.Playing - - action = Action(ActionType.QuitGame, params={}) - coord._world_action_queue.put = AsyncMock(return_value=None) - - await coord._process_generic_action(agent_addr, action) - - coord._world_action_queue.put.assert_called_once_with((agent_addr, action, coord._agent_states.get(agent_addr))) - assert agent_addr not in coord.agents \ No newline at end of file diff --git a/tests/test_game_coordinator.py b/tests/test_game_coordinator.py deleted file mode 100644 index 1969b714..00000000 --- a/tests/test_game_coordinator.py +++ /dev/null @@ -1,132 +0,0 @@ -import asyncio -import pytest -from unittest.mock import AsyncMock, Mock -from AIDojoCoordinator.coordinator import AgentServer, GameCoordinator - -def test_game_coordinator_initialization(): - """ - Test that the GameCoordinator is initialized correctly with the expected properties. - """ - # Test input - game_host = "localhost" - game_port = 8000 - service_host = "localhost" - service_port = 8080 - allowed_roles = ["Attacker", "Defender", "Benign"] - task_config_file = "test_config.json" - - # Create an instance of GameCoordinator - coordinator = GameCoordinator( - game_host=game_host, - game_port=game_port, - service_host=service_host, - service_port=service_port, - allowed_roles=allowed_roles, - task_config_file=task_config_file, - ) - - # Assertions for basic initialization - assert coordinator.host == game_host, "GameCoordinator host should be set correctly." - assert coordinator.port == game_port, "GameCoordinator port should be set correctly." - assert coordinator._service_host == service_host, "Service host should be set correctly." - assert coordinator._service_port == service_port, "Service port should be set correctly." - assert coordinator.ALLOWED_ROLES == allowed_roles, "Allowed roles should match the input." - assert coordinator._task_config_file == task_config_file, "Task config file should match the input." - - # Assertions for events and locks - assert isinstance(coordinator.shutdown_flag, asyncio.Event), "shutdown_flag should be an asyncio.Event." - assert isinstance(coordinator._reset_event, asyncio.Event), "reset_event should be an asyncio.Event." - assert isinstance(coordinator._episode_end_event, asyncio.Event), "episode_end_event should be an asyncio.Event." - assert isinstance(coordinator._reset_lock, asyncio.Lock), "reset_lock should be an asyncio.Lock." - assert isinstance(coordinator._agents_lock, asyncio.Lock), "agents_lock should be an asyncio.Lock." - - # Assertions for agent-related data structures - assert isinstance(coordinator._agent_action_queue, asyncio.Queue), "agent_action_queue should be an asyncio.Queue." - assert isinstance(coordinator._agent_response_queues, dict), "agent_response_queues should be a dictionary." - assert isinstance(coordinator.agents, dict), "agents should be a dictionary." - assert isinstance(coordinator._agent_steps, dict), "agent_steps should be a dictionary." - assert isinstance(coordinator._reset_requests, dict), "reset_requests should be a dictionary." - assert isinstance(coordinator._agent_status, dict), "agent_status should be a dictionary." - assert isinstance(coordinator._agent_observations, dict), "agent_observations should be a dictionary." - assert isinstance(coordinator._agent_rewards, dict), "agent_rewards should be a dictionary." - assert isinstance(coordinator._agent_trajectories, dict), "agent_trajectories should be a dictionary." - - # Assertions for tasks - assert isinstance(coordinator._tasks, set), "tasks should be a set." - - # Assertions for configuration - assert coordinator._cyst_objects is None, "cyst_objects should be None at initialization." - assert coordinator._cyst_object_string is None, "cyst_object_string should be None at initialization." - - # Assertions for logging - assert coordinator.logger.name == "AIDojo-GameCoordinator", "Logger should be initialized with the correct name." - - -def test_starting_positions(): - """Test that starting positions are correctly initialized for each role.""" - coordinator = GameCoordinator( - game_host="localhost", - game_port=8000, - service_host="localhost", - service_port=8080, - allowed_roles=["Attacker", "Defender", "Benign"], - ) - coordinator.task_config = Mock() # Mock the task_config - coordinator.task_config.get_start_position.side_effect = lambda agent_role: {"x": 0, "y": 0} - - starting_positions = coordinator._get_starting_position_per_role() - - assert starting_positions["Attacker"] == {"x": 0, "y": 0} - assert starting_positions["Defender"] == {"x": 0, "y": 0} - assert starting_positions["Benign"] == {"x": 0, "y": 0} - - - -# Agent Server -def test_agent_server_initialization(): - """ - Test that the AgentServer is initialized correctly with the expected attributes. - """ - # Test inputs - actions_queue = asyncio.Queue() - agent_response_queues = {} - max_connections = 5 - - # Create an instance of AgentServer - server = AgentServer(actions_queue, agent_response_queues, max_connections) - - # Assertions for basic attributes - assert server.actions_queue is actions_queue, "AgentServer's actions_queue should be set correctly." - assert server.answers_queues is agent_response_queues, "AgentServer's answers_queues should be set correctly." - assert server.max_connections == max_connections, "AgentServer's max_connections should match the input." - assert server.current_connections == 0, "AgentServer's current_connections should be initialized to 0." - - # Assertions for logging - assert server.logger.name == "AIDojo-AgentServer", "Logger should be initialized with the correct name." - - -@pytest.mark.asyncio -async def test_handle_new_agent_max_connections(): - """Test that a new agent connection is rejected when max_connections is reached.""" - # Test setup - actions_queue = asyncio.Queue() - agent_response_queues = {} - max_connections = 1 - server = AgentServer(actions_queue, agent_response_queues, max_connections) - - # Mock reader and writer - reader_mock = AsyncMock() - writer_mock = Mock() - writer_mock.get_extra_info.return_value = ("127.0.0.1", 12345) - - # Simulate max connections - server.current_connections = max_connections - - # Run handle_new_agent - await server.handle_new_agent(reader_mock, writer_mock) - - # Assertions - assert server.current_connections == max_connections, "Connection count should remain unchanged." - assert ("127.0.0.1", 12345) not in agent_response_queues, "Queue should not be created for rejected agent." - writer_mock.write.assert_not_called(), "No data should be sent to the rejected agent." - writer_mock.close.assert_called_once(), "Connection should be closed for the rejected agent."