class AgentServer(asyncio.Protocol):
    """
    Connection handler serving the agents that connect to the game run by the GameCoordinator.

    The instance is callable with ``(reader, writer)``. For each connection it forwards every
    received message to ``actions_queue`` as an ``(addr, message)`` tuple and then sends back
    the next item taken from the per-agent queue stored in ``answers_queues[addr]``.
    """

    def __init__(self, actions_queue, agent_response_queues, max_connections):
        """
        Args:
            actions_queue: queue on which ``(addr, message)`` tuples are put.
            agent_response_queues: dict mapping agent address -> queue with responses.
            max_connections: connections above this count are closed immediately.
        """
        self.actions_queue = actions_queue
        self.answers_queues = agent_response_queues
        self.max_connections = max_connections
        self.current_connections = 0
        self.logger = logging.getLogger("AIDojo-AgentServer")

    async def _announce_quit(self, addr) -> None:
        """Put a QuitGame action on the actions queue on behalf of a vanished agent."""
        quit_message = Action(ActionType.QuitGame, parameters={}).to_json()
        await self.actions_queue.put((addr, quit_message))

    async def handle_new_agent(self, reader, writer):
        """
        Serve one agent connection until it disconnects or the task is cancelled.

        A closed or reset connection is reported to the coordinator as a QuitGame action.
        The per-agent response queue is removed and the writer closed on every exit path.

        Raises:
            asyncio.CancelledError: re-raised after cleanup when the task is cancelled.
        """
        addr = writer.get_extra_info("peername")

        # Check if the maximum number of connections has been reached
        if self.current_connections >= self.max_connections:
            self.logger.info(
                f"Max connections reached. Rejecting new connection from {addr}"
            )
            writer.close()
            return

        # Increment the count of current connections
        self.current_connections += 1
        self.logger.info(f"New agent connected: {addr}")

        # Ensure a queue exists for this agent
        if addr not in self.answers_queues:
            self.answers_queues[addr] = asyncio.Queue(maxsize=2)
            self.logger.info(f"Created queue for agent {addr}")

        try:
            while True:
                # Step 1: Read data from the agent.
                # NOTE(review): a single read returns at most 500 bytes, so a longer
                # message is delivered to the coordinator in pieces.
                data = await reader.read(500)
                if not data:
                    self.logger.info(f"Agent {addr} disconnected.")
                    await self._announce_quit(addr)
                    break

                raw_message = data.decode().strip()
                self.logger.debug(f"Handler received from {addr}: {raw_message}")

                # Step 2: Forward the message to the Coordinator
                await self.actions_queue.put((addr, raw_message))

                # Step 3: Get a matching response from the answers queue
                response = await self.answers_queues[addr].get()
                self.logger.info(f"Sending response to agent {addr}: {response}")

                # Step 4: Send the response to the agent
                writer.write(str(response).encode())
                await writer.drain()
        except ConnectionError as e:
            # A reset/broken pipe is a disconnect too: without the QuitGame the
            # coordinator would keep the agent registered forever.
            self.logger.info(f"Agent {addr} disconnected.")
            self.logger.debug(f"Connection to {addr} lost: {e}")
            await self._announce_quit(addr)
        except asyncio.CancelledError:
            self.logger.debug("Terminating by KeyboardInterrupt")
            raise
        finally:
            # Decrement the count of current connections
            self.current_connections -= 1
            if addr in self.answers_queues:
                self.answers_queues.pop(addr)
                self.logger.info(f"Removed queue for agent {addr}")
            else:
                self.logger.warning(f"Queue for agent {addr} not found during cleanup.")
            writer.close()

    async def __call__(self, reader, writer):
        """Entry point used as the client-connected callback; delegates to handle_new_agent."""
        await self.handle_new_agent(reader, writer)
def __init__(self, game_host: str, game_port: int, service_host:str, service_port:int, allowed_roles=["Attacker", "Defender", "Benign"], task_config_file:str=None) -> None:
    """
    Set up the coordinator's synchronisation primitives, queues and per-agent bookkeeping.

    Args:
        game_host: host on which the agent-facing TCP server listens.
        game_port: port on which the agent-facing TCP server listens.
        service_host: host of the REST service providing the configuration (may be falsy).
        service_port: port of the REST service providing the configuration.
        allowed_roles: roles an agent may register with; the list is copied.
        task_config_file: path of a local task configuration file (used when no service host).
    """
    self.host = game_host
    self.port = game_port
    self.logger = logging.getLogger("AIDojo-GameCoordinator")

    self._tasks = set()
    self.shutdown_flag = asyncio.Event()
    self._reset_event = asyncio.Event()
    self._episode_end_event = asyncio.Event()
    self._episode_rewards_condition = asyncio.Condition()
    self._reset_done_condition = asyncio.Condition()
    self._reset_lock = asyncio.Lock()
    self._agents_lock = asyncio.Lock()
    self._semaphore = asyncio.Semaphore(2)

    # for accessing configuration remotely
    self._service_host = service_host
    self._service_port = service_port
    # for reading configuration locally
    self._task_config_file = task_config_file
    # Copy the roles so that neither the shared default list nor the caller's list
    # is aliased by this instance.
    self.ALLOWED_ROLES = list(allowed_roles)
    self._cyst_objects = None
    self._cyst_object_string = None
    # Filled in later while the configuration is loaded; declared here so the
    # attributes always exist.
    self.task_config = None
    self._CONFIG_FILE_HASH = None
    self._global_defender = None

    # prepare agent communication
    self._agent_action_queue = asyncio.Queue()
    self._agent_response_queues = {}

    # agent information {agent_addr: (agent_name, agent_role)}
    self.agents = {}
    # step counter per agent_addr (int)
    self._agent_steps = {}
    # reset request per agent_addr (bool)
    self._reset_requests = {}
    # status per agent_addr (AgentStatus)
    self._agent_status = {}
    # episode end flag per agent_addr (bool)
    self._episode_ends = {}
    # last observation per agent_addr (Observation)
    self._agent_observations = {}
    # starting per agent_addr (dict)
    self._agent_starting_position = {}
    # current state per agent_addr (GameState)
    self._agent_states = {}
    # last action played by agent (Action)
    self._agent_last_action = {}
    # reward per agent_addr
    self._agent_rewards = {}
    # trajectories per agent_addr
    self._agent_trajectories = {}
def convert_msg_dict_to_json(self, msg_dict:dict)->str:
    """
    Serialize a message dictionary into the JSON text used in the Agent-Game communication.

    Args:
        msg_dict: dictionary to serialize with ``json.dumps``.

    Returns:
        The JSON string.

    Raises:
        Exception: whatever ``json.dumps`` raised; it is logged and re-raised.
    """
    try:
        serialized = json.dumps(msg_dict)
    except Exception as err:
        self.logger.error(f"Error when converting msg to JSON:{err}")
        raise err
    else:
        # Caller is responsible for putting the text on the answer queue
        return serialized
async def _fetch_initialization_objects(self):
    """
    Fetch the CYST initialization objects over REST and load them.

    Sends a GET request to ``/cyst_init_objects`` on the configured service host/port.
    On HTTP 200 the JSON payload is hashed into ``self._CONFIG_FILE_HASH`` and loaded
    into ``self._cyst_objects``. Any other status, or any exception, is only logged.
    """
    async with ClientSession() as session:
        try:
            endpoint = f"http://{self._service_host}:{self._service_port}/cyst_init_objects"
            async with session.get(endpoint) as response:
                if response.status != 200:
                    self.logger.error(f"Failed to fetch initialization objects. Status: {response.status}")
                    return
                payload = await response.json()
                self.logger.debug(payload)
                environment = Environment.create()
                self._CONFIG_FILE_HASH = get_str_hash(payload)
                self._cyst_objects = environment.configuration.general.load_configuration(payload)
                self.logger.debug(f"Initialization objects received:{self._cyst_objects}")
                # NOTE(review): self.task_config is not set on this path
                # (the original assignment from payload["task_configuration"] is disabled).
        except Exception as err:
            self.logger.error(f"Error fetching initialization objects: {err}")
async def start_tcp_server(self):
    """
    Start the TCP server for the agent communication and keep it alive until shutdown.

    The server is created with ``asyncio.start_server`` using an ``AgentServer`` as the
    connection callback. The task polls ``self.shutdown_flag`` once per second and closes
    the server on exit. Cancellation and other errors are logged, not propagated.
    """
    # Bound before the try so the finally clause is safe when start_server itself fails
    # (e.g. the port is already in use).
    server = None
    try:
        self.logger.info("Starting the server listening for agents")
        server = await asyncio.start_server(
            AgentServer(
                self._agent_action_queue,
                self._agent_response_queues,
                max_connections=2
            ),
            self.host,
            self.port
        )
        addrs = ", ".join(str(sock.getsockname()) for sock in server.sockets)
        self.logger.info(f"\tServing on {addrs}")
        while not self.shutdown_flag.is_set():
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        self.logger.debug("\tStopping TCP server task.")
    except Exception as e:
        self.logger.error(f"TCP server failed: {e}")
    finally:
        if server is not None:
            server.close()
            await server.wait_closed()
        self.logger.info("\tTCP server task stopped")
self._steps_limit_per_role = self._get_max_steps_per_role() + self.logger.debug(f"Timeouts set to:{self._steps_limit_per_role}") + if self.task_config.get_use_global_defender(): + self._global_defender = GlobalDefender() + else: + self._global_defender = None + self._use_dynamic_ips = self.task_config.get_use_dynamic_addresses() + self._rewards = self.task_config.get_rewards(["step", "sucess", "fail"]) + self.logger.debug(f"Rewards set to:{self._rewards}") + + # start server for agent communication + self._spawn_task(self.start_tcp_server) + + # start episode rewards task + self._spawn_task(self._assign_rewards_episode_end) + + # start episode rewards task + self._spawn_task(self._reset_game) + + # start action processing task + self._spawn_task(self.run_game) + + while not self.shutdown_flag.is_set(): + # just wait until user terminates + await asyncio.sleep(1) + self.logger.debug("Final cleanup started") + # make sure there are no running tasks left + for task in self._tasks: + task.cancel() # Cancel each active task + await asyncio.gather(*self._tasks, return_exceptions=True) # Wait for all tasks to finish + self.logger.info("All tasks shut down.") + + async def run_game(self): + """ + Task responsible for reading messages from the agent queue and processing them based on the ActionType. 
+ """ + while not self.shutdown_flag.is_set(): + # Read message from the queue + agent_addr, message = await self._agent_action_queue.get() + if message is not None: + self.logger.info(f"Coordinator received: {message}.") + try: # Convert message to Action + action = Action.from_json(message) + self.logger.debug(f"\tConverted to: {action}.") + except Exception as e: + self.logger.error( + f"Error when converting msg to Action using Action.from_json():{e}, {message}" + ) + match action.type: # process action based on its type + case ActionType.JoinGame: + self.logger.debug(f"Start processing of ActionType.JoinGame by {agent_addr}") + self.logger.debug(f"{action.type}, {action.type.value}, {action.type == ActionType.JoinGame}") + self._spawn_task(self._process_join_game_action, agent_addr, action) + case ActionType.QuitGame: + self.logger.debug(f"Start processing of ActionType.QuitGame by {agent_addr}") + self._spawn_task(self._process_quit_game_action, agent_addr) + case ActionType.ResetGame: + self.logger.debug(f"Start processing of ActionType.ResetGame by {agent_addr}") + self._spawn_task(self._process_reset_game_action, agent_addr) + case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService: + self.logger.debug(f"Start processing of {action.type} by {agent_addr}") + self._spawn_task(self._process_game_action, agent_addr, action) + case _: + self.logger.warning(f"Unsupported action type: {action}!") + self.logger.info("\tAction processing task stopped.") + + async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None: + """ + Method for processing Action of type ActionType.JoinGame + Inputs: + - agent_addr (tuple) + - JoinGame Action + Outputs: None (Method stores reposnse in the agent's response queue) + """ + try: + async with self._semaphore: + self.logger.info(f"New Join request by {agent_addr}.") + if agent_addr not in self.agents: + agent_name = 
action.parameters["agent_info"].name + agent_role = action.parameters["agent_info"].role + if agent_role in self.ALLOWED_ROLES: + # add agent to the world + new_agent_game_state = await self.register_agent(agent_addr, agent_role, self._starting_positions_per_role[agent_role]) + if new_agent_game_state: # successful registration + async with self._agents_lock: + self.agents[agent_addr] = (agent_name, agent_role) + observation = self._initialize_new_player(agent_addr, new_agent_game_state) + self._agent_observations[agent_addr] = observation + output_message_dict = { + "to_agent": agent_addr, + "status": str(GameStatus.CREATED), + "observation": observation_as_dict(observation), + "message": { + "message": f"Welcome {agent_name}, registred as {agent_role}", + "max_steps": self._steps_limit_per_role[agent_role], + "goal_description": self._goal_description_per_role[agent_role], + "actions": [str(a) for a in ActionType], + "configuration_hash": self._CONFIG_FILE_HASH + }, + } + await self._agent_response_queues[agent_addr].put(self.convert_msg_dict_to_json(output_message_dict)) + else: + self.logger.info( + f"\tError in registration, unknown agent role: {agent_role}!" 
async def _process_quit_game_action(self, agent_addr: tuple)->None:
    """
    Method for processing Action of type ActionType.QuitGame
    Inputs:
        - agent_addr (tuple)
    Outputs: None

    The agent is removed from the world (``remove_agent``) only when the coordinator holds
    a state for it. A QuitGame from an address that never joined (or already quit) is
    therefore handled without raising ``KeyError``.
    """
    try:
        if agent_addr in self._agent_states:
            await self.remove_agent(agent_addr, self._agent_states[agent_addr])
        else:
            self.logger.debug(f"Agent {agent_addr} has no state in the game; skipping removal from the world.")
        agent_info = await self._remove_agent_from_game(agent_addr)
        self.logger.info(f"Agent {agent_addr} removed from the game. {agent_info}")
    except asyncio.CancelledError:
        self.logger.debug(f"Proccessing QuitAction of agent {agent_addr} interrupted")
        raise  # Ensure the exception propagates
    finally:
        self.logger.debug(f"Cleaning up after QuitGame for {agent_addr}.")
+ new_observation = Observation( + current_observation.state, + reward=reward, + end=True, + info={'end_reason': end_reason, "info":"Episode ended. Request reset for starting new episode."}) + output_message_dict = { + "to_agent": agent_addr, + "observation": observation_as_dict(new_observation), + "status": str(GameStatus.FORBIDDEN), + } + else: + async with self._agents_lock: + self._agent_last_action[agent_addr] = action + self._agent_steps[agent_addr] += 1 + # wait for the new state from the world + new_state = await self.step(agent_id=agent_addr, agent_state=self._agent_states[agent_addr], action=action) + + # update agent's values + async with self._agents_lock: + # store new state of the agent + self._agent_states[agent_addr] = new_state + + # store new state of the agent using the new state + self._agent_status[agent_addr] = self._update_agent_status(agent_addr) + + # add reward for step (other rewards are added at the end of the episode) + self._agent_rewards[agent_addr] = self._rewards["step"] + + # check if the episode ends for this agent + self._episode_ends[agent_addr] = self._update_agent_episode_end(agent_addr) + + # check if the episode ends + if all(self._episode_ends.values()): + self._episode_end_event.set() + if self._episode_ends[agent_addr]: + # episode ended for this agent - wait for the others to finish + async with self._episode_rewards_condition: + await self._episode_rewards_condition.wait() + # append step to the trajectory if needed + if self.task_config.get_store_trajectories() or self._global_defender: + async with self._agents_lock: + self._add_step_to_trajectory(agent_addr, action, self._agent_rewards[agent_addr], new_state,end_reason=None) + # add information to 'info' field if needed + info = {} + if self._agent_status[agent_addr] not in [AgentStatus.Playing, AgentStatus.PlayingWithTimeout]: + info["end_reason"] = str(self._agent_status[agent_addr]) + new_observation = Observation(self._agent_states[agent_addr], 
self._agent_rewards[agent_addr], self._episode_ends[agent_addr], info=info) + self._agent_observations[agent_addr] = new_observation + output_message_dict = { + "to_agent": agent_addr, + "observation": observation_as_dict(new_observation), + "status": str(GameStatus.OK), + } + response_msg_json = self.convert_msg_dict_to_json(output_message_dict) + await self._agent_response_queues[agent_addr].put(response_msg_json) + + async def _assign_rewards_episode_end(self): + """Task that waits for all agents to finish and assigns rewards.""" + self.logger.debug("Starting task for episode end reward assigning.") + while not self.shutdown_flag.is_set(): + # wait until episode is finished by all agents + done, pending = await asyncio.wait( + [asyncio.create_task(self._episode_end_event.wait()), + asyncio.create_task(self.shutdown_flag.wait())], + return_when=asyncio.FIRST_COMPLETED, + ) + # Check if shutdown_flag was set + if self.shutdown_flag.is_set(): + self.logger.debug("\tExiting reward assignment task.") + break + self.logger.info("Episode finished. 
async def _reset_game(self):
    """
    Task that waits for all agents to request resets and then resets the game.

    On each ``self._reset_event`` it resets the world (``self.reset``), stores trajectories
    if configured, resets every registered agent and its bookkeeping, clears the event and
    notifies the waiters on ``self._reset_done_condition``. It exits when
    ``self.shutdown_flag`` is set.
    """
    self.logger.debug("Starting task for game reset handelling.")
    while not self.shutdown_flag.is_set():
        # wait until a reset is requested or the coordinator shuts down
        waiters = [
            asyncio.create_task(self._reset_event.wait()),
            asyncio.create_task(self.shutdown_flag.wait()),
        ]
        try:
            await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # asyncio.wait does not cancel the still-pending helper task; do it here,
            # otherwise one pending task would be leaked per loop iteration.
            for waiter in waiters:
                waiter.cancel()
            await asyncio.gather(*waiters, return_exceptions=True)
        # Check if shutdown_flag was set
        if self.shutdown_flag.is_set():
            self.logger.debug("\tExiting reset_game task.")
            break
        self.logger.info("Resetting game to initial state.")
        await self.reset()
        # Iterate over a snapshot: agents may join or quit while this task awaits.
        for agent in list(self.agents):
            if agent not in self.agents:
                continue
            if self.task_config.get_store_trajectories() or self._global_defender:
                async with self._agents_lock:
                    self._store_trajectory_to_file(agent)
            self.logger.debug(f"Resetting agent {agent}")
            new_state = await self.reset_agent(agent, self.agents[agent][1], self._agent_starting_position[agent])
            # The observation must carry the freshly reset state, not the state
            # from the end of the previous episode.
            new_observation = Observation(new_state, 0, False, {})
            async with self._agents_lock:
                if agent not in self.agents:
                    # agent quit while it was being reset
                    continue
                self._agent_states[agent] = new_state
                self._agent_observations[agent] = new_observation
                self._episode_ends[agent] = False
                self._reset_requests[agent] = False
                self._agent_rewards[agent] = 0
                self._agent_steps[agent] = 0
                if self.agents[agent][1].lower() == "attacker":
                    self._agent_status[agent] = AgentStatus.PlayingWithTimeout
                else:
                    self._agent_status[agent] = AgentStatus.Playing
                if self.task_config.get_store_trajectories() or self._global_defender:
                    self._agent_trajectories[agent] = self._reset_trajectory(agent)
        self._reset_event.clear()
        # notify all waiting agents
        async with self._reset_done_condition:
            self._reset_done_condition.notify_all()
    self.logger.info("\tReset game task stopped.")
async def _remove_agent_from_game(self, agent_addr):
    """
    Removes player from the game. Should be called AFTER QuitGame action was processed by the world.

    Pops all per-agent bookkeeping for ``agent_addr``. If the departing agent was the only
    one blocking a reset (all *remaining* agents requested it) or the episode end (all
    *remaining* agents finished), the corresponding event is set.

    Returns:
        dict with the removed values (``state``, ``num_steps``, ``agent_status``,
        ``reset_request``, ``episode_end``, ``end_reward``, ``agent_info``);
        empty when the agent is not registered.
    """
    self.logger.info(f"Removing player {agent_addr} from the GameCoordinator")
    agent_info = {}
    async with self._agents_lock:
        if agent_addr in self.agents:
            agent_info["state"] = self._agent_states.pop(agent_addr)
            agent_info["num_steps"] = self._agent_steps.pop(agent_addr)
            agent_info["agent_status"] = self._agent_status.pop(agent_addr)
            async with self._reset_lock:
                agent_info["reset_request"] = self._reset_requests.pop(agent_addr)
                # check if this agent was not preventing reset: every remaining
                # agent must have asked for it (a single request is not enough)
                if self._reset_requests and all(self._reset_requests.values()):
                    self._reset_event.set()
            agent_info["episode_end"] = self._episode_ends.pop(agent_addr)
            # check if this agent was not preventing episode end for the remaining agents
            if self._episode_ends and all(self._episode_ends.values()):
                self._episode_end_event.set()
            agent_info["end_reward"] = self._agent_rewards.pop(agent_addr, None)
            agent_info["agent_info"] = self.agents.pop(agent_addr)
            # drop the remaining per-agent records so they do not pile up
            self._agent_observations.pop(agent_addr, None)
            self._agent_starting_position.pop(agent_addr, None)
            self._agent_last_action.pop(agent_addr, None)
            self._agent_trajectories.pop(agent_addr, None)
            self.logger.debug(f"\t{agent_info}")
        else:
            self.logger.info(f"\t Player {agent_addr} not present in the game!")
    return agent_info
def is_timeout(self, agent:tuple)->bool:
    """
    Tell whether the agent has used up the step limit configured for its role.

    Returns False when the limit for the role is falsy (e.g. ``None`` or ``0``);
    otherwise True once the agent's step counter reaches the limit.
    """
    role = self.agents[agent][1]
    step_limit = self._steps_limit_per_role[role]
    if not step_limit:
        # no limit configured for this role
        return False
    if self._agent_steps[agent] >= step_limit:
        return True
    return False
def _reset_trajectory(self, agent_addr:tuple)->dict:
    """
    Build an empty trajectory record for the agent.

    The record starts with the agent's current state (its ``as_dict`` attribute) as the
    only entry in ``states``, empty ``actions`` and ``rewards`` lists, no end reason, and
    the agent's role and name.
    """
    name, role = self.agents[agent_addr]
    self.logger.debug(f"Resetting trajectory of {agent_addr}")
    initial_state = self._agent_states[agent_addr].as_dict
    steps = {
        "states": [initial_state],
        "actions": [],
        "rewards": [],
    }
    record = {"trajectory": steps}
    record["end_reason"] = None
    record["agent_role"] = role
    record["agent_name"] = name
    return record
+ """ + if agent_addr in self._agent_trajectories: + self.logger.debug(f"Adding step to trajectory of {agent_addr}") + self._agent_trajectories[agent_addr]["trajectory"]["actions"].append(action.as_dict) + self._agent_trajectories[agent_addr]["trajectory"]["rewards"].append(reward) + self._agent_trajectories[agent_addr]["trajectory"]["states"].append(next_state.as_dict) + if end_reason: + self._agent_trajectories[agent_addr]["end_reason"] = end_reason + + def _store_trajectory_to_file(self, agent_addr:tuple, location="./trajectories")-> None: + self.logger.debug(f"Storing Trajectory of {agent_addr}in file") + if agent_addr in self._agent_trajectories: + agent_name, agent_role = self.agents[agent_addr] + filename = os.path.join(location, f"{datetime.now():%Y-%m-%d}_{agent_name}_{agent_role}.jsonl") + with jsonlines.open(filename, "a") as writer: + writer.write(self._agent_trajectories[agent_addr]) + self.logger.info(f"Trajectory of {agent_addr} strored in {filename}") \ No newline at end of file diff --git a/docs/Architecture.md b/AIDojoCoordinator/docs/Architecture.md similarity index 100% rename from docs/Architecture.md rename to AIDojoCoordinator/docs/Architecture.md diff --git a/docs/Components.md b/AIDojoCoordinator/docs/Components.md similarity index 100% rename from docs/Components.md rename to AIDojoCoordinator/docs/Components.md diff --git a/docs/Coordinator.md b/AIDojoCoordinator/docs/Coordinator.md similarity index 83% rename from docs/Coordinator.md rename to AIDojoCoordinator/docs/Coordinator.md index 6ac5aadc..bd7af93b 100644 --- a/docs/Coordinator.md +++ b/AIDojoCoordinator/docs/Coordinator.md @@ -12,11 +12,8 @@ Coordinator is the centerpiece of the game orchestration. It provides an interfa ## Connction to other game components Coordinator, having the role of the middle man in all communication between the agent and the world uses several queues for massing passing and handelling. -1. 
`Actions queue` is a queue in which the agents submit their actions. It provides N:1 communication channel in which the coordinator receives the inputs. -2. `Answer queue` is a separeate queue **per agent** in which the results of the actions are send to the agent. -3. `World action queue` is a queue used for sending the acions from coordinator to the AI Dojo world -4. `World response queue` is a channel used for wolrd -> coordinator communicaiton (responses to the agents' action) -Message passing overview +1. `Action queue` is a queue in which the agents submit their actions. It provides N:1 communication channel in which the coordinator receives the inputs. +2. `Answer queues` is a separeate queue **per agent** in which the results of the actions are send to the agent. ## Main components of the coordinator @@ -45,3 +42,7 @@ Coordinator, having the role of the middle man in all communication between the `self._agent_statuses`: status of each agent. One of AgentStatus `self._agent_rewards`: dictionary of final reward of each agent in the current episod. Only agent's which can't participate in the ongoing episode are listed. `self._agent_trajectories`: complete trajectories for each agent in the ongoing episode + + +## Episode +The episode starts with sufficient amount of agents registering in the game. Each agent role has a maximum allowed number of steps defined in the task configuration. 
An episode ends if all agents reach the goal \ No newline at end of file diff --git a/docs/Trajectory_analysis.md b/AIDojoCoordinator/docs/Trajectory_analysis.md similarity index 100% rename from docs/Trajectory_analysis.md rename to AIDojoCoordinator/docs/Trajectory_analysis.md diff --git a/docs/figures/architecture_diagram.jpg b/AIDojoCoordinator/docs/figures/architecture_diagram.jpg similarity index 100% rename from docs/figures/architecture_diagram.jpg rename to AIDojoCoordinator/docs/figures/architecture_diagram.jpg diff --git a/docs/figures/message_passing_coordinator.jpg b/AIDojoCoordinator/docs/figures/message_passing_coordinator.jpg similarity index 100% rename from docs/figures/message_passing_coordinator.jpg rename to AIDojoCoordinator/docs/figures/message_passing_coordinator.jpg diff --git a/docs/figures/scenarios/scenario 1_small.png b/AIDojoCoordinator/docs/figures/scenarios/scenario 1_small.png similarity index 100% rename from docs/figures/scenarios/scenario 1_small.png rename to AIDojoCoordinator/docs/figures/scenarios/scenario 1_small.png diff --git a/docs/figures/scenarios/scenario_1.png b/AIDojoCoordinator/docs/figures/scenarios/scenario_1.png similarity index 100% rename from docs/figures/scenarios/scenario_1.png rename to AIDojoCoordinator/docs/figures/scenarios/scenario_1.png diff --git a/docs/figures/scenarios/scenario_1_tiny.png b/AIDojoCoordinator/docs/figures/scenarios/scenario_1_tiny.png similarity index 100% rename from docs/figures/scenarios/scenario_1_tiny.png rename to AIDojoCoordinator/docs/figures/scenarios/scenario_1_tiny.png diff --git a/docs/figures/scenarios/three_nets.png b/AIDojoCoordinator/docs/figures/scenarios/three_nets.png similarity index 100% rename from docs/figures/scenarios/three_nets.png rename to AIDojoCoordinator/docs/figures/scenarios/three_nets.png diff --git a/env/game_components.py b/AIDojoCoordinator/game_components.py similarity index 60% rename from env/game_components.py rename to 
AIDojoCoordinator/game_components.py index 1c1a8c05..b82f94aa 100755 --- a/env/game_components.py +++ b/AIDojoCoordinator/game_components.py @@ -1,6 +1,7 @@ # Author Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz # Library of helpful functions and objects to play the net sec game -from dataclasses import dataclass, field +from dataclasses import dataclass, field, asdict +from typing import Dict, Any import dataclasses from collections import namedtuple import json @@ -16,9 +17,13 @@ class Service(): Service represents the service object in the NetSecGame """ name: str - type: str - version: str - is_local: bool + type: str = "unknown" + version: str = "unknown" + is_local: bool = True + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) """ IP represents the ip address object in the NetSecGame @@ -42,6 +47,11 @@ def __post_init__(self): def __repr__(self): return self.ip + def __eq__(self, other): + if not isinstance(other, IP): + return NotImplemented + return self.ip == other.ip + def is_private(self): """ Return if the IP is private or not @@ -54,6 +64,12 @@ def is_private(self): if self.ip != 'external': return True return False + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + def __hash__(self): + return hash(self.ip) @dataclass(frozen=True, eq=True) class Network(): @@ -96,6 +112,10 @@ def is_private(self): except ipaddress.AddressValueError: # If we are dealing with strings, assume they are local networks return True + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) """ Data represents the data object in the NetSecGame @@ -114,66 +134,48 @@ class Data(): def __hash__(self) -> int: return hash((self.owner, self.id, self.size, self.type)) + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + @enum.unique class ActionType(enum.Enum): - """ - ActionType represents generic action for attacker in the game. 
Each transition has a default probability - of success and probability of detection (if the defender is present). - Currently 5 action types are implemented: - - ScanNetwork - - FindServices - - FindData - - ExploitService - - ExfiltrateData - - BlockIP - - JoinGame - - QuitGame - """ - - #override the __new__ method to enable multiple parameters - def __new__(cls, *args, **kwargs): - value = len(cls.__members__) + 1 - obj = object.__new__(cls) - obj._value_ = value - return obj + ScanNetwork = "ScanNetwork" + FindServices = "FindServices" + FindData = "FindData" + ExploitService = "ExploitService" + ExfiltrateData = "ExfiltrateData" + BlockIP = "BlockIP" + JoinGame = "JoinGame" + QuitGame = "QuitGame" + ResetGame = "ResetGame" + + def to_string(self): + """Convert enum to string.""" + return self.value + + def __eq__(self, other): + # Compare with another ActionType + if isinstance(other, ActionType): + return self.value == other.value + # Compare with a string + elif isinstance(other, str): + return self.value == other + return False - def __init__(self, default_success_p: float): - self.default_success_p = default_success_p + def __hash__(self): + # Use the hash of the value for consistent behavior + return hash(self.value) @classmethod - def from_string(cls, string:str): - match string: - case "ActionType.ExploitService": - return ActionType.ExploitService - case "ActionType.ScanNetwork": - return ActionType.ScanNetwork - case "ActionType.FindServices": - return ActionType.FindServices - case "ActionType.FindData": - return ActionType.FindData - case "ActionType.ExfiltrateData": - return ActionType.ExfiltrateData - case "ActionType.BlockIP": - return ActionType.BlockIP - case "ActionType.JoinGame": - return ActionType.JoinGame - case "ActionType.ResetGame": - return ActionType.ResetGame - case "ActionType.QuitGame": - return ActionType.QuitGame - case _: - raise ValueError("Uknown Action Type") - - #ActionTypes - ScanNetwork = 0.9 - FindServices = 0.9 - 
FindData = 0.8 - ExploitService = 0.7 - ExfiltrateData = 0.8 - BlockIP = 1 - JoinGame = 1 - QuitGame = 1 - ResetGame = 1 + def from_string(cls, name): + """Convert string to enum, stripping 'ActionType.' if present.""" + if name.startswith("ActionType."): + name = name.split("ActionType.")[1] + try: + return cls[name] + except KeyError: + raise ValueError(f"Invalid ActionType: {name}") @dataclass(frozen=True, eq=True, order=True) class AgentInfo(): @@ -186,132 +188,83 @@ class AgentInfo(): def __repr__(self): return f"{self.name}({self.role})" -#Actions -class Action(): + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + +@dataclass(frozen=True) +class Action: """ - Actions are composed of the action type (see ActionTupe) and additional parameters listed in dictionary - - ScanNetwork {"target_network": Network object, "source_host": IP object} - - FindServices {"target_host": IP object, "source_host": IP object,} - - FindData {"target_host": IP object, "source_host": IP object} - - ExploitService {"target_host": IP object, "target_service": Service object, "source_host": IP object} - - ExfiltrateData {"target_host": IP object, "source_host": IP object, "data": Data object} - - BlockIP("target_host": IP object, "source_host": IP object, "blocked_host": IP object) + Immutable dataclass representing an Action. 
""" - def __init__(self, action_type: ActionType, params: dict={}) -> None: - self._type = action_type - self._parameters = params + action_type: ActionType + parameters: Dict[str, Any] = field(default_factory=dict) @property - def type(self) -> ActionType: - return self._type - @property - def parameters(self)->dict: - return self._parameters - - @property - def as_dict(self)->dict: + def as_dict(self) -> Dict[str, Any]: + """Return a dictionary representation of the Action.""" params = {} - for k,v in self.parameters.items(): - if isinstance(v, Service): - params[k] = vars(v) - elif isinstance(v, Data): - params[k] = vars(v) - elif isinstance(v, AgentInfo): - params[k] = vars(v) + for k, v in self.parameters.items(): + if hasattr(v, '__dict__'): # Handle custom objects like Service, Data, AgentInfo + params[k] = asdict(v) else: params[k] = str(v) - return {"type": str(self.type), "params": params} + return {"action_type": str(self.action_type), "parameters": params} + @property + def type(self): + return self.action_type + + def to_json(self) -> str: + """Serialize the Action to a JSON string.""" + return json.dumps(self.as_dict) + @classmethod - def from_dict(cls, data_dict:dict): - action_type = ActionType.from_string(data_dict["type"]) + def from_dict(cls, data_dict: Dict[str, Any]) -> "Action": + """Create an Action from a dictionary.""" + action_type = ActionType.from_string(data_dict["action_type"]) params = {} - for k,v in data_dict["params"].items(): + for k, v in data_dict["parameters"].items(): match k: - case "source_host": - params[k] = IP(v) - case "target_host": - params[k] = IP(v) - case "blocked_host": - params[k] = IP(v) + case "source_host" | "target_host" | "blocked_host": + params[k] = IP.from_dict(v) case "target_network": - net,mask = v.split("/") - params[k] = Network(net ,int(mask)) + params[k] = Network.from_dict(v) case "target_service": - params[k] = Service(**v) + params[k] = Service.from_dict(v) case "data": - params[k] = Data(**v) + 
params[k] = Data.from_dict(v) case "agent_info": - params[k] = AgentInfo(**v) + params[k] = AgentInfo.from_dict(v) case _: - raise ValueError(f"Unsupported Value in {k}:{v}") - action = Action(action_type=action_type, params=params) - return action - - def __repr__(self) -> str: - return f"Action <{self._type}|{self._parameters}>" + raise ValueError(f"Unsupported value in {k}: {v}") + return cls(action_type=action_type, parameters=params) - def __str__(self) -> str: - return f"Action <{self._type}|{self._parameters}>" + @classmethod + def from_json(cls, json_string: str) -> "Action": + """Create an Action from a JSON string.""" + data_dict = json.loads(json_string) + return cls.from_dict(data_dict) - def __eq__(self, __o: object) -> bool: - if isinstance(__o, Action): - return self._type == __o.type and self.parameters == __o.parameters - return False + def __repr__(self) -> str: + return f"Action <{self.action_type}|{self.parameters}>" - def __hash__(self) -> int: - sorted_params = sorted(self._parameters.items(), key= lambda x: x[0]) - sorted_params = [f"{x}{str(y)}" for x,y in sorted_params] - return hash(self._type) + hash("".join(sorted_params)) - - def as_json(self)->str: - ret_dict = {"action_type":str(self.type)} - ret_dict["parameters"] = {k:dataclasses.asdict(v) for k,v in self.parameters.items()} - return json.dumps(ret_dict) + def __str__(self) -> str: + return f"Action <{self.action_type}|{self.parameters}>" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Action): + return NotImplemented + return ( + self.action_type == other.action_type and + self.parameters == other.parameters + ) - - - @classmethod - def from_json(cls, json_string:str): - """ - Classmethod to ccreate Action object from json string representation - """ - parameters_dict = json.loads(json_string) - action_type = ActionType.from_string(parameters_dict["action_type"]) - parameters = {} - parameters_dict = parameters_dict["parameters"] - match action_type: - case 
ActionType.ScanNetwork: - parameters = {"source_host": IP(parameters_dict["source_host"]["ip"]),"target_network": Network(parameters_dict["target_network"]["ip"], parameters_dict["target_network"]["mask"])} - case ActionType.FindServices: - parameters = {"source_host": IP(parameters_dict["source_host"]["ip"]), "target_host": IP(parameters_dict["target_host"]["ip"])} - case ActionType.FindData: - parameters = {"source_host": IP(parameters_dict["source_host"]["ip"]), "target_host": IP(parameters_dict["target_host"]["ip"])} - case ActionType.ExploitService: - parameters = {"target_host": IP(parameters_dict["target_host"]["ip"]), - "target_service": Service(parameters_dict["target_service"]["name"], - parameters_dict["target_service"]["type"], - parameters_dict["target_service"]["version"], - parameters_dict["target_service"]["is_local"]), - "source_host": IP(parameters_dict["source_host"]["ip"])} - case ActionType.ExfiltrateData: - parameters = {"target_host": IP(parameters_dict["target_host"]["ip"]), - "source_host": IP(parameters_dict["source_host"]["ip"]), - "data": Data(parameters_dict["data"]["owner"],parameters_dict["data"]["id"])} - case ActionType.BlockIP: - parameters = {"target_host": IP(parameters_dict["target_host"]["ip"]), - "source_host": IP(parameters_dict["source_host"]["ip"]), - "blocked_host": IP(parameters_dict["blocked_host"]["ip"])} - case ActionType.JoinGame: - parameters = {"agent_info":AgentInfo(parameters_dict["agent_info"]["name"], parameters_dict["agent_info"]["role"])} - case ActionType.QuitGame: - parameters = {} - case ActionType.ResetGame: - parameters = {} - case _: - raise ValueError(f"Unknown Action type:{action_type}") - action = Action(action_type=action_type, params=parameters) - return action + def __hash__(self) -> int: + # Convert parameters to a sorted tuple of key-value pairs for consistency + sorted_params = tuple(sorted((k, hash(v)) for k, v in self.parameters.items())) + return hash((self.action_type, sorted_params)) 
@dataclass(frozen=True) class GameState(): @@ -481,14 +434,40 @@ def from_string(cls, string:str): return GameStatus.RESET_DONE def __repr__(self) -> str: return str(self) -if __name__ == "__main__": - pass - #data1 = Data(owner="test", id="test_data", content="content", type="db") - #data2 = Data(owner="test", id="test_data", content="content", type="db") - # print(data) - # print(data.size) - - # s = set() - # s.add(data) - # s.add( Data("test", "test_data", content="new_content", type="db")) - # print(s) \ No newline at end of file + + +@enum.unique +class AgentStatus(enum.Enum): + Playing = "Playing" + PlayingWithTimeout = "PlayingWithTimeout" + TimeoutReached = "TimeoutReached" + ResetRequested = "ResetRequested" + Success = "Success" + Fail = "Fail" + + def to_string(self): + """Convert enum to string.""" + return self.value + + def __eq__(self, other): + # Compare with another ActionType + if isinstance(other, AgentStatus): + return self.value == other.value + # Compare with a string + elif isinstance(other, str): + return self.value == other + return False + + def __hash__(self): + # Use the hash of the value for consistent behavior + return hash(self.value) + + @classmethod + def from_string(cls, name): + """Convert string to enum, stripping 'AgentStatus.' 
if present.""" + if name.startswith("AgentStatus."): + name = name.split("AgentStatus.")[1] + try: + return cls[name] + except KeyError: + raise ValueError(f"Invalid AgentStatus: {name}") \ No newline at end of file diff --git a/AIDojoCoordinator/global_defender.py b/AIDojoCoordinator/global_defender.py new file mode 100644 index 00000000..5937ce58 --- /dev/null +++ b/AIDojoCoordinator/global_defender.py @@ -0,0 +1,89 @@ +# Author: Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz +from itertools import groupby +from AIDojoCoordinator.game_components import ActionType, Action +from random import random + + +class GlobalDefender: + + def __init__(self): + + # The probability of detecting an action is defined by the following dictionary + self._DEFAULT_DETECTION_PROBS = { + ActionType.ScanNetwork: 0.05, + ActionType.FindServices: 0.075, + ActionType.ExploitService: 0.1, + ActionType.FindData: 0.025, + ActionType.ExfiltrateData: 0.025, + ActionType.BlockIP: 0.01 + } + + + # Ratios of action types in the time window (TW) for each action type. The ratio should be higher than the defined value to trigger a detection check + self._TW_TYPE_RATIOS_THRESHOLD = { + ActionType.ScanNetwork: 0.25, + ActionType.FindServices: 0.3, + ActionType.ExploitService: 0.25, + ActionType.FindData: 0.5, + ActionType.ExfiltrateData: 0.25, + ActionType.BlockIP: 1 + } + + # Thresholds for consecutive actions of the same type in the TW. Only if the threshold is crossed, the detection check is triggered + self._TW_CONSECUTIVE_TYPE_THRESHOLD = { + ActionType.ScanNetwork: 2, + ActionType.FindServices: 3, + ActionType.ExfiltrateData: 2 + } + + # Thresholds for repeated actions in the episode. 
Only if the threshold is crossed, the detection check is triggered + self._EPISODE_REPEATED_ACTION_THRESHOLD = { + ActionType.ExploitService: 2, + ActionType.FindData: 2, + } + + def stochastic(self, action_type:ActionType)->bool: + """ + Simple random detection based on predefied probability and ActionType + """ + roll = random() + if roll < self._DEFAULT_DETECTION_PROBS[action_type]: + return True + else: + return False + + def stochastic_with_threshold(self, action: Action, episode_actions:list, tw_size:int=5)-> bool: + """ + Only detect based on set probabilities if pre-defined thresholds are crossed. + """ + # extend the episode with the latest action + # We need to copy the list before the copying, so we avoid modifying it when it is returned. Modifycation of passed list is the default behavior in Python + temp_episode_actions = episode_actions.copy() + temp_episode_actions.append(action.as_dict) + if len(temp_episode_actions) >= tw_size: + last_n_actions = temp_episode_actions[-tw_size:] + last_n_action_types = [action['type'] for action in last_n_actions] + # compute ratio of action type in the TW + tw_ratio = last_n_action_types.count(str(action.type))/tw_size + # Count how many times this exact (parametrized) action was played in episode + repeats_in_episode = temp_episode_actions.count(action.as_dict) + # compute the highest consecutive number of action type in TW + max_consecutive_action_type = max(sum(1 for item in grouped if item == str(action.type)) + for _, grouped in groupby(last_n_action_types)) + + if action.type in self._TW_CONSECUTIVE_TYPE_THRESHOLD.keys(): + # ScanNetwork, FindServices, ExfiltrateData + if tw_ratio < self._TW_TYPE_RATIOS_THRESHOLD[action.type] and max_consecutive_action_type < self._TW_CONSECUTIVE_TYPE_THRESHOLD[action.type]: + return False + else: + return self.stochastic(action.type) + elif action.type in self._EPISODE_REPEATED_ACTION_THRESHOLD.keys(): + # FindData, Exploit service + if tw_ratio < 
self._TW_TYPE_RATIOS_THRESHOLD[action.type] and repeats_in_episode < self._EPISODE_REPEATED_ACTION_THRESHOLD[action.type]: + return False + else: + return self.stochastic(action.type) + else: #Other actions - Do not detect + return False + else: + return False diff --git a/env/netsecenv_conf.yaml b/AIDojoCoordinator/netsecenv_conf.yaml similarity index 98% rename from env/netsecenv_conf.yaml rename to AIDojoCoordinator/netsecenv_conf.yaml index e24a5424..ba93ae9c 100644 --- a/env/netsecenv_conf.yaml +++ b/AIDojoCoordinator/netsecenv_conf.yaml @@ -4,7 +4,7 @@ coordinator: # Define the conf for all the 'attacker' agents. agents: Attacker: - max_steps: 50 + max_steps: 20 goal: description: "Exfiltrate data to '213.47.23.195'" is_any_part_of_goal_random: True @@ -102,9 +102,10 @@ env: use_dynamic_addresses: False use_firewall: True save_trajectories: False - goal_reward: 100 - detection_reward: -5 - step_reward: -1 + rewards: + win: 100 + step: -1 + loss: -10 actions: scan_network: prob_success: 1.0 diff --git a/AIDojoCoordinator/netsecevn_conf_cyst_integration.yaml b/AIDojoCoordinator/netsecevn_conf_cyst_integration.yaml new file mode 100644 index 00000000..92217768 --- /dev/null +++ b/AIDojoCoordinator/netsecevn_conf_cyst_integration.yaml @@ -0,0 +1,120 @@ +# Configuration file for the NetSecGame environment + +coordinator: + # Define the conf for all the 'attacker' agents. 
+ agents: + Attacker: + max_steps: 10 + goal: + description: "Take control of the host '192.168.0.3'" + is_any_part_of_goal_random: True + known_networks: [] + #known_networks: [192.168.1.0/24, 192.168.3.0/24] + known_hosts: [] + #known_hosts: [192.168.1.1, 192.168.1.2] + controlled_hosts: ['192.168.0.3'] + #controlled_hosts: [213.47.23.195, 192.168.1.3] + # Services are defined as a target host where the service must be, and then a description in the form 'name,type,version,is_local' + known_services: {} + #known_services: {192.168.1.3: [Local system, lanman server, 10.0.19041, False], 192.168.1.4: [Other system, SMB server, 21.2.39421, False]} + # In data, put the target host that must have the data and which data in format user,data + # Example to fix the data in one host + known_data: {} + # Example to fix two data in one host + #known_data: {213.47.23.195: [[User1,DataFromServer1], [User5,DataFromServer5]]} + # Example to fix the data in two host + #known_data: {213.47.23.195: [User1,DataFromServer1], 192.168.3.1: [User3,Data3FromServer3]} + # Example to ask a random data in a specific server. Putting 'random' in the data, forces the env to randomly choose where the goal data is + # known_data: {213.47.23.195: [random]} + known_blocks: {} + # Example of known blocks. 
In the host 192.168.2.2, block all connections coming or going to 192.168.1.3 + # known_blocks: {192.168.2.2: {192.168.1.3}} + start_position: + known_networks: [] + known_hosts: [] + # The attacker must always at least control the CC if the goal is to exfiltrate there + # Example of fixing the starting point of the agent in a local host + controlled_hosts: [random] + # Example of asking a random position to start the agent + # controlled_hosts: [213.47.23.195, random] + # Services are defined as a target host where the service must be, and then a description in the form 'name,type,version,is_local' + known_services: {} + # known_services: {192.168.1.3: [Local system, lanman server, 10.0.19041, False], 192.168.1.4: [Other system, SMB server, 21.2.39421, False]} + # Same format as before + known_data: {} + known_blocks: {} + # Example of known blocks to start with. In the host 192.168.2.2, block all connections coming or going to 192.168.1.3 + # known_blocks: {192.168.2.2: {192.168.1.3}} + + Defender: + goal: + description: "Block all attackers" + is_any_part_of_goal_random: False + known_networks: [] + # Example + #known_networks: [192.168.1.0/24, 192.168.3.0/24] + known_hosts: [] + # Example + #known_hosts: [192.168.1.1, 192.168.1.2] + controlled_hosts: [] + # Example + #controlled_hosts: [213.47.23.195, 192.168.1.3] + # Services are defined as a target host where the service must be, and then a description in the form 'name,type,version,is_local' + known_services: {} + # Example + #known_services: {192.168.1.3: [Local system, lanman server, 10.0.19041, False], 192.168.1.4: [Other system, SMB server, 21.2.39421, False]} + # In data, put the target host that must have the data and which data in format user,data + # Example to fix the data in one host + known_data: {} + # Example to fix two data in one host + #known_data: {213.47.23.195: [[User1,DataFromServer1], [User5,DataFromServer5]]} + # Example to fix the data in two host + #known_data: {213.47.23.195: 
[User1,DataFromServer1], 192.168.3.1: [User3,Data3FromServer3]} + # Example to ask a random data in a specific server. Putting 'random' in the data, forces the env to randomly choose where the goal data is + # known_data: {213.47.23.195: [random]} + known_blocks: {192.168.0.3: 'all_attackers'} + # Example of known blocks. In the host 192.168.2.2, block all connections coming or going to 192.168.1.3 + # known_blocks: {192.168.2.2: {192.168.1.3}} + # You can also use the wildcard string 'all_routers', and 'all_attackers', to mean that all the controlled hosts of all the attackers should be in this list in order to win + + start_position: + # should be empty for defender - will be extracted from controlled hosts + known_networks: [] + # should be empty for defender - will be extracted from controlled hosts + known_hosts: [] + # list of controlled hosts, wildard "all_local" can be used to include all local IPs + controlled_hosts: [all_local] + known_services: {} + known_data: {} + # Blocked IPs + blocked_ips: {} + known_blocks: {} + # Example of known blocks to start with. 
In the host 192.168.2.2, block all connections coming or going to 192.168.1.3 + # known_blocks: {192.168.2.2: {192.168.1.3}} + +env: + # random means to choose the seed in a random way, so it is not fixed + random_seed: 'random' + # Or you can fix the seed + # random_seed: 42 + scenario: 'scenario1_tiny' + use_global_defender: False + use_dynamic_addresses: False + use_firewall: True + save_trajectories: False + goal_reward: 100 + detection_reward: -5 + step_reward: -1 + actions: + scan_network: + prob_success: 1.0 + find_services: + prob_success: 1.0 + exploit_service: + prob_success: 1.0 + find_data: + prob_success: 1.0 + exfiltrate_data: + prob_success: 1.0 + block_ip: + prob_success: 1.0 \ No newline at end of file diff --git a/env/scenarios/__init__.py b/AIDojoCoordinator/scenarios/__init__.py similarity index 100% rename from env/scenarios/__init__.py rename to AIDojoCoordinator/scenarios/__init__.py diff --git a/env/scenarios/scenario_configuration.py b/AIDojoCoordinator/scenarios/scenario_configuration.py similarity index 96% rename from env/scenarios/scenario_configuration.py rename to AIDojoCoordinator/scenarios/scenario_configuration.py index e2dbdcc6..d8fedb80 100644 --- a/env/scenarios/scenario_configuration.py +++ b/AIDojoCoordinator/scenarios/scenario_configuration.py @@ -1,9 +1,9 @@ # This file defines the hosts and their characteristics, the services they run, the users they have and their security levels, the data they have, and in the router/FW all the rules of which host can access what import cyst.api.configuration as cyst_cfg -#from cyst.api.configuration import * from cyst.api.configuration.network.elements import RouteConfig from cyst.api.logic.access import AuthenticationProviderType, AuthenticationTokenType, AuthenticationTokenSecurity - +from cyst.api.configuration import ExploitConfig, VulnerableServiceConfig +from cyst.api.logic.exploit import ExploitLocality, ExploitCategory ''' 
-------------------------------------------------------------------------------------------------------------------- A template for local password authentication. ''' @@ -28,7 +28,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="microsoft-ds", + name="microsoft-ds", owner="Local system", version="10.0.19041", local=False, @@ -66,7 +66,7 @@ ] ), cyst_cfg.PassiveServiceConfig( - type="ms-wbt-server", + name="ms-wbt-server", owner="Local system", version="10.0.19041", local=False, @@ -94,7 +94,7 @@ ] ), cyst_cfg.PassiveServiceConfig( - type="windows login", + name="windows login", owner="Administrator", version="10.0.19041", local=True, @@ -102,7 +102,7 @@ authentication_providers=[local_password_auth("windows login")] ), cyst_cfg.PassiveServiceConfig( - type="powershell", + name="powershell", owner="Local system", version="10.0.19041", local=True, @@ -130,7 +130,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -156,7 +156,7 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="postgresql", + name="postgresql", owner="postgresql", version="14.3.0", private_data=[ @@ -168,7 +168,7 @@ access_level=cyst_cfg.AccessLevel.LIMITED ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -195,7 +195,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="http", + name="http", owner="lighttpd", version="1.4.54", local=False, @@ -211,7 +211,7 @@ access_schemes=[] ), cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -237,7 +237,7 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -261,7 +261,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -282,7 
+282,7 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -306,7 +306,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -327,7 +327,7 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -362,7 +362,7 @@ ], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ms-wbt-server", + name="ms-wbt-server", owner="Local system", version="10.0.19041", local=False, @@ -386,14 +386,14 @@ ] ), cyst_cfg.PassiveServiceConfig( - type="powershell", + name="powershell", owner="Local system", version="10.0.19041", local=True, access_level=cyst_cfg.AccessLevel.LIMITED ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -419,7 +419,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ms-wbt-server", + name="ms-wbt-server", owner="Local system", version="10.0.19041", local=False, @@ -443,14 +443,14 @@ ] ), cyst_cfg.PassiveServiceConfig( - type="powershell", + name="powershell", owner="Local system", version="10.0.19041", local=True, access_level=cyst_cfg.AccessLevel.LIMITED ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -476,7 +476,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -498,14 +498,14 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, access_level=cyst_cfg.AccessLevel.LIMITED ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -531,7 +531,7 @@ active_services=[], 
passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -553,14 +553,14 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, access_level=cyst_cfg.AccessLevel.LIMITED ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -586,7 +586,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -604,7 +604,7 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -703,14 +703,14 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, access_level=cyst_cfg.AccessLevel.LIMITED ), cyst_cfg.PassiveServiceConfig( - type="listener", + name="listener", owner="attacker", version="1.0.0", local=False, @@ -746,16 +746,16 @@ - There exists only one for windows lanman server (SMB) and enables data exfiltration. Add others as needed... ''' exploits = [ - cyst_cfg.ExploitConfig( + ExploitConfig( services=[ - cyst_cfg.VulnerableServiceConfig( - name="microsoft-ds", - min_version="10.0.19041", + VulnerableServiceConfig( + service="microsoft-ds", + min_version="10.0. 
19041", max_version="10.0.19041" ) ], - locality=cyst_cfg.ExploitLocality.REMOTE, - category=cyst_cfg.ExploitCategory.DATA_MANIPULATION, + locality=ExploitLocality.REMOTE, + category=ExploitCategory.DATA_MANIPULATION, id="smb_exploit" ) ] diff --git a/env/scenarios/smaller_scenario_configuration.py b/AIDojoCoordinator/scenarios/smaller_scenario_configuration.py similarity index 95% rename from env/scenarios/smaller_scenario_configuration.py rename to AIDojoCoordinator/scenarios/smaller_scenario_configuration.py index a4f6d725..215b4aa7 100644 --- a/env/scenarios/smaller_scenario_configuration.py +++ b/AIDojoCoordinator/scenarios/smaller_scenario_configuration.py @@ -1,8 +1,10 @@ # This file defines the hosts and their characteristics, the services they run, the users they have and their security levels, the data they have, and in the router/FW all the rules of which host can access what import cyst.api.configuration as cyst_cfg +#from cyst.api.configuration import * from cyst.api.configuration.network.elements import RouteConfig from cyst.api.logic.access import AuthenticationProviderType, AuthenticationTokenType, AuthenticationTokenSecurity - +from cyst.api.configuration import ExploitConfig, VulnerableServiceConfig +from cyst.api.logic.exploit import ExploitLocality, ExploitCategory ''' -------------------------------------------------------------------------------------------------------------------- A template for local password authentication. ''' @@ -21,12 +23,13 @@ - the only windows server. 
It does not connect to the AD - access schemes for remote desktop and file sharing are kept separate, but can be integrated into one if needed +- Service types should be derived from nmap services https://svn.nmap.org/nmap/nmap-services ''' smb_server = cyst_cfg.NodeConfig( active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="lanman server", + name="microsoft-ds", owner="Local system", version="10.0.19041", local=False, @@ -64,7 +67,7 @@ ] ), cyst_cfg.PassiveServiceConfig( - type="remote desktop service", + name="ms-wbt-server", owner="Local system", version="10.0.19041", local=False, @@ -92,7 +95,7 @@ ] ), cyst_cfg.PassiveServiceConfig( - type="windows login", + name="windows login", owner="Administrator", version="10.0.19041", local=True, @@ -100,7 +103,7 @@ authentication_providers=[local_password_auth("windows login")] ), cyst_cfg.PassiveServiceConfig( - type="powershell", + name="powershell", owner="Local system", version="10.0.19041", local=True, @@ -128,7 +131,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="openssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -154,7 +157,7 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="postgresql", + name="postgresql", owner="postgresql", version="14.3.0", private_data=[ @@ -166,7 +169,7 @@ access_level=cyst_cfg.AccessLevel.LIMITED ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -193,7 +196,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="lighttpd", + name="http", owner="lighttpd", version="1.4.54", local=False, @@ -209,7 +212,7 @@ access_schemes=[] ), cyst_cfg.PassiveServiceConfig( - type="openssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -235,7 +238,7 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -259,7 +262,7 @@ active_services=[], passive_services=[ 
cyst_cfg.PassiveServiceConfig( - type="openssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -280,7 +283,7 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -304,7 +307,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="openssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -325,7 +328,7 @@ )] ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -360,7 +363,7 @@ ], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="remote desktop service", + name="ms-wbt-server", owner="Local system", version="10.0.19041", local=False, @@ -384,14 +387,14 @@ ] ), cyst_cfg.PassiveServiceConfig( - type="powershell", + name="powershell", owner="Local system", version="10.0.19041", local=True, access_level=cyst_cfg.AccessLevel.LIMITED ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -490,14 +493,14 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, access_level=cyst_cfg.AccessLevel.LIMITED ), cyst_cfg.PassiveServiceConfig( - type="listener", + name="listener", owner="attacker", version="1.0.0", local=False, @@ -505,7 +508,7 @@ ) ], traffic_processors=[], - interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("213.47.23.195"), cyst_cfg.IPNetwork("213.47.23.195/26"))], + interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("213.47.23.195"), cyst_cfg.IPNetwork("213.47.23.192/26"))], shell="bash", id="outside_node" ) @@ -520,7 +523,7 @@ cyst_cfg.ConnectionConfig("other_server_1", 0, "router1", 3), cyst_cfg.ConnectionConfig("other_server_2", 0, "router1", 4), cyst_cfg.ConnectionConfig("client_1", 0, "router1", 5), - cyst_cfg.ConnectionConfig("internet", 0, "router1", 6), + 
cyst_cfg.ConnectionConfig("internet", 0, "router1", 10), cyst_cfg.ConnectionConfig("internet", 1, "outside_node", 0) ] @@ -529,16 +532,16 @@ - There exists only one for windows lanman server (SMB) and enables data exfiltration. Add others as needed... ''' exploits = [ - cyst_cfg.ExploitConfig( + ExploitConfig( services=[ - cyst_cfg.VulnerableServiceConfig( - name="lanman server", - min_version="10.0.19041", + VulnerableServiceConfig( + service="microsoft-ds", + min_version="10.0. 19041", max_version="10.0.19041" ) ], - locality=cyst_cfg.ExploitLocality.REMOTE, - category=cyst_cfg.ExploitCategory.DATA_MANIPULATION, + locality=ExploitLocality.REMOTE, + category=ExploitCategory.DATA_MANIPULATION, id="smb_exploit" ) ] diff --git a/env/scenarios/test_scenario_configuration.py b/AIDojoCoordinator/scenarios/test_scenario_configuration.py similarity index 100% rename from env/scenarios/test_scenario_configuration.py rename to AIDojoCoordinator/scenarios/test_scenario_configuration.py diff --git a/env/scenarios/three_net_scenario.py b/AIDojoCoordinator/scenarios/three_net_scenario.py similarity index 97% rename from env/scenarios/three_net_scenario.py rename to AIDojoCoordinator/scenarios/three_net_scenario.py index d4a14f00..b4bd6f95 100644 --- a/env/scenarios/three_net_scenario.py +++ b/AIDojoCoordinator/scenarios/three_net_scenario.py @@ -34,7 +34,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="microsoft-ds", + name="microsoft-ds", owner="Local system", version="10.0.19041", local=False, @@ -75,7 +75,7 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="ms-wbt-server", + name="ms-wbt-server", owner="Local system", version="10.0.19041", local=False, @@ -118,7 +118,7 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="windows login", + name="windows login", owner="Administrator", version="10.0.19041", local=True, @@ -126,7 +126,7 @@ authentication_providers=[local_password_auth("windows login")], ), cyst_cfg.PassiveServiceConfig( - 
type="powershell", + name="powershell", owner="Local system", version="10.0.19041", local=True, @@ -157,7 +157,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="http", + name="http", owner="lighttpd", version="1.4.54", local=False, @@ -170,7 +170,7 @@ access_schemes=[], ), cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -213,7 +213,7 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -241,7 +241,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -271,7 +271,7 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -305,7 +305,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="active-directory", + name="active-directory", owner="Local system", version="10.0.19041", local=False, @@ -329,7 +329,7 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="windows login", + name="windows login", owner="Administrator", version="10.0.19041", local=True, @@ -363,7 +363,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -406,7 +406,7 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="postgresql", + name="postgresql", owner="postgresql", version="14.3.0", private_data=[ @@ -416,7 +416,7 @@ access_level=cyst_cfg.AccessLevel.LIMITED, ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -455,7 +455,7 @@ ], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ms-wbt-server", + name="ms-wbt-server", owner="Local system", version="10.0.19041", local=False, @@ -486,14 +486,14 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="powershell", + name="powershell", owner="Local system", 
version="10.0.19041", local=True, access_level=cyst_cfg.AccessLevel.LIMITED, ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -523,7 +523,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ms-wbt-server", + name="ms-wbt-server", owner="Local system", version="10.0.19041", local=False, @@ -554,14 +554,14 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="powershell", + name="powershell", owner="Local system", version="10.0.19041", local=True, access_level=cyst_cfg.AccessLevel.LIMITED, ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -591,7 +591,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -622,14 +622,14 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, access_level=cyst_cfg.AccessLevel.LIMITED, ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -659,7 +659,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="ssh", + name="ssh", owner="openssh", version="8.1.0", local=False, @@ -690,14 +690,14 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, access_level=cyst_cfg.AccessLevel.LIMITED, ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + name="can_attack_start_here", owner="Local system", version="1", local=True, @@ -726,7 +726,7 @@ active_services=[], passive_services=[ cyst_cfg.PassiveServiceConfig( - type="bash", + name="bash", owner="root", version="5.0.0", local=True, @@ -750,7 +750,7 @@ ], ), cyst_cfg.PassiveServiceConfig( - type="can_attack_start_here", + 
"""Tiny scenario: one target node, two attacker nodes and a single router.

The module only declares configuration objects; the final
``configuration_objects`` list collects all of them for the consumer.
"""
import cyst.api.configuration as cyst_cfg


# --- target -----------------------------------------------------------------
# Node with a local shell service and a non-local lighttpd service.
_target_services = [
    cyst_cfg.PassiveServiceConfig(
        name="bash",
        owner="root",
        version="8.1.0",
        access_level=cyst_cfg.AccessLevel.LIMITED,
        local=True,
    ),
    cyst_cfg.PassiveServiceConfig(
        name="lighttpd",
        owner="www",
        version="1.4.62",
        access_level=cyst_cfg.AccessLevel.LIMITED,
        local=False,
    ),
]

target = cyst_cfg.NodeConfig(
    active_services=[],
    passive_services=_target_services,
    shell="bash",
    traffic_processors=[],
    interfaces=[],
    name="target",
)

# --- attackers --------------------------------------------------------------
# Template of the active service placed on every attacker node.
attacker_service = cyst_cfg.ActiveServiceConfig(
    type="netsecenv_agent",
    name="attacker",
    owner="attacker",
    access_level=cyst_cfg.AccessLevel.LIMITED,
    ref="attacker_service",
)


def _attacker_node(node_name):
    """Build an attacker node whose only service is derived from ``attacker_service``."""
    return cyst_cfg.NodeConfig(
        active_services=[attacker_service()],
        passive_services=[],
        interfaces=[],
        shell="",
        traffic_processors=[],
        name=node_name,
    )


attacker = _attacker_node("attacker_node")
attacker2 = _attacker_node("attacker_node_2")


# --- router -----------------------------------------------------------------
def _router_interface(port_index):
    """Build one router interface; all of them use the same address and network."""
    return cyst_cfg.InterfaceConfig(
        ip=cyst_cfg.IPAddress("192.168.0.1"),
        net=cyst_cfg.IPNetwork("192.168.0.1/24"),
        index=port_index,
    )


# Firewall that allows everything: ALLOW default policy and an empty,
# ALLOW-policy FORWARD chain.
_allow_all_firewall = cyst_cfg.FirewallConfig(
    default_policy=cyst_cfg.FirewallPolicy.ALLOW,
    chains=[
        cyst_cfg.FirewallChainConfig(
            type=cyst_cfg.FirewallChainType.FORWARD,
            policy=cyst_cfg.FirewallPolicy.ALLOW,
            rules=[],
        )
    ],
)

router = cyst_cfg.RouterConfig(
    interfaces=[_router_interface(0), _router_interface(1), _router_interface(2)],
    traffic_processors=[_allow_all_firewall],
    id="router",
)

# --- exploit ----------------------------------------------------------------
# Remote code-execution exploit matching exactly the lighttpd version on target.
exploit1 = cyst_cfg.ExploitConfig(
    services=[
        cyst_cfg.VulnerableServiceConfig(
            service="lighttpd",
            min_version="1.4.62",
            max_version="1.4.62",
        )
    ],
    locality=cyst_cfg.ExploitLocality.REMOTE,
    category=cyst_cfg.ExploitCategory.CODE_EXECUTION,
    id="http_exploit",
)


# --- connections ------------------------------------------------------------
def _link_to_router(node, router_port):
    """Connect ``node`` (source port -1) to the given port of ``router``."""
    return cyst_cfg.ConnectionConfig(
        src_ref=node,
        src_port=-1,
        dst_ref=router,
        dst_port=router_port,
    )


connection1 = _link_to_router(target, 0)
connection2 = _link_to_router(attacker, 1)
connection3 = _link_to_router(attacker2, 2)

# Order kept exactly as in the original listing (connection2 precedes connection1).
configuration_objects = [
    target,
    attacker_service,
    attacker,
    attacker2,
    router,
    exploit1,
    connection2,
    connection1,
    connection3,
]
rename to AIDojoCoordinator/utils/action_plots_readme.md diff --git a/utils/actions_parser.py b/AIDojoCoordinator/utils/actions_parser.py similarity index 100% rename from utils/actions_parser.py rename to AIDojoCoordinator/utils/actions_parser.py diff --git a/utils/gamaplay_graphs.py b/AIDojoCoordinator/utils/gamaplay_graphs.py similarity index 60% rename from utils/gamaplay_graphs.py rename to AIDojoCoordinator/utils/gamaplay_graphs.py index c42e7947..a2f6f80e 100644 --- a/utils/gamaplay_graphs.py +++ b/AIDojoCoordinator/utils/gamaplay_graphs.py @@ -1,17 +1,16 @@ from trajectory_analysis import read_json import numpy as np -import sys import os import utils import argparse import matplotlib.pyplot as plt -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__) ))) -from env.game_components import GameState, Action +from AIDojoCoordinator.game_components import GameState, Action class TrajectoryGraph: def __init__(self)->None: self._checkpoints = {} + self._checkpoint_size = {} self._checkpoint_edges = {} self._checkpoint_simple_edges = {} self._wins_per_checkpoint = {} @@ -54,6 +53,7 @@ def add_checkpoint(self, trajectories:list, end_reason=None)->None: wins = [] edges = {} simple_edges = {} + self._checkpoint_size[self.num_checkpoints] = len(trajectories) for play in trajectories: if end_reason and play["end_reason"] not in end_reason: continue @@ -63,11 +63,13 @@ def add_checkpoint(self, trajectories:list, end_reason=None)->None: wins.append(1) else: wins.append(0) + # get the id of the first state state_id = self.get_state_id(GameState.from_dict(play["trajectory"]["states"][0])) - #print(f'Trajectory len: {len(play["trajectory"]["actions"])}') - for i in range(1, len(play["trajectory"]["actions"])): + # iterate over the trajectory + assert len(play["trajectory"]["states"]) == len(play["trajectory"]["actions"]) +1 + for i in range(1, len(play["trajectory"]["states"])): next_state_id = 
self.get_state_id(GameState.from_dict(play["trajectory"]["states"][i])) - action_id = self.get_action_id(Action.from_dict((play["trajectory"]["actions"][i]))) + action_id = self.get_action_id(Action.from_dict((play["trajectory"]["actions"][i-1]))) # fullgraph if (state_id, next_state_id, action_id) not in edges: edges[state_id, next_state_id, action_id] = 0 @@ -155,19 +157,95 @@ def get_graph_structure_progress(self)->dict: return super_graph def get_graph_structure_probabilistic_progress(self)->dict: - + # collect all edeges from all checkpoints all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values())) + # prepare data straucture for the probabiliites per edge super_graph = {key:np.zeros(self.num_checkpoints) for key in all_edges} - for i, edge_list in self._checkpoint_edges.items(): - total_out_edges_use = {} - for (src, _, _), frequency in edge_list.items(): - if src not in total_out_edges_use: - total_out_edges_use[src] = 0 - total_out_edges_use[src] += frequency - for (src,dst,edge), value in edge_list.items(): - super_graph[(src,dst,edge)][i] = value/total_out_edges_use[src] + for i, edges in self._checkpoint_edges.items(): + total_edge_count_cp = sum(edges.values()) + # print(f"Processing timestamp {i}") + # # total_out_edges_use = {} + # # for (src, _, _), frequency in edge_list.items(): + # # if src not in total_out_edges_use: + # # total_out_edges_use[src] = 0 + # # total_out_edges_use[src] += frequency + # # for (src,dst,edge), value in edge_list.items(): + # # super_graph[(src,dst,edge)][i] = value/total_out_edges_use[src] + # src_nodes = set([x for (x, _, _ ) in edge_list.keys()]) + # print(f"\t{len(src_nodes)} source nodes") + # num_outgoing_edges = {} + # for node in src_nodes: + # num_outgoing_edges[node] = sum([v for k,v in edge_list.items() if k[0] == node]) + # print(f"\t{num_outgoing_edges[node]}") + # for (src,dst,action), occurence in edge_list.items(): + # 
def calculate_source_node_likelihoods(self) -> dict:
    """
    Compute, per checkpoint, how often each edge is taken relative to all
    edges leaving the same source node.

    Returns:
        dict: maps every edge ``(source, destination, action)`` seen in any
        checkpoint to a NumPy array of length ``self.num_checkpoints``.
        Position ``i`` holds ``count / total count of edges with the same
        source`` in checkpoint ``i``; it stays 0 where the edge is absent.
    """
    # Union of the edge keys of all checkpoints
    known_edges = set()
    for cp_edges in self._checkpoint_edges.values():
        known_edges.update(cp_edges.keys())

    # One zero-initialised slot per checkpoint for every known edge
    likelihoods = {edge: np.zeros(self.num_checkpoints) for edge in known_edges}

    for cp_index, cp_edges in self._checkpoint_edges.items():
        # First pass: total usage of outgoing edges per source node
        outgoing_usage = {}
        for (src, _dst, _act), uses in cp_edges.items():
            outgoing_usage[src] = outgoing_usage.get(src, 0) + uses

        # Second pass: share of each edge among its source's outgoing usage
        for edge, uses in cp_edges.items():
            src, _dst, _act = edge
            src_total = outgoing_usage[src]
            if src_total > 0:
                likelihoods[edge][cp_index] = uses / src_total
    return likelihoods
def get_graph_entropies(self) -> dict:
    """
    Compute the entropy of every edge's probability profile across checkpoints.

    The per-edge arrays come from
    ``self.get_graph_structure_probabilistic_progress()``. Each array is
    normalised to sum to 1 and its Shannon entropy (natural logarithm) is
    computed.

    Returns:
        dict: maps each edge key to its entropy as a float. An edge whose
        values are all zero gets entropy 0.0.
    """
    def compute_entropy(probs, epsilon_value=1e-10) -> float:
        probs = np.asarray(probs, dtype=float)
        total = np.sum(probs)
        if total <= 0:
            # Nothing to normalise; avoid 0/0 producing NaN.
            return 0.0
        normalized_probs = probs / total
        # epsilon keeps log() finite for zero entries
        return float(-np.sum(normalized_probs * np.log(normalized_probs + epsilon_value)))

    # The method must be *called*, and the resulting dict iterated as
    # (edge, values) pairs.
    probabilistic_graph = self.get_graph_structure_probabilistic_progress()
    return {
        edge: compute_entropy(probs)
        for edge, probs in probabilistic_graph.items()
    }
read_json(args.t2, max_lines=args.n_trajectories) - # states = {} - # actions = {} - - # graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason) - # graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason) - - # state_to_id = {v:k for k,v in states.items()} - # action_to_id = {v:k for k,v in states.items()} - - # print(f"Trajectory 1: {args.t1}") - # print(f"WR={t1_wr_mean}±{t1_wr_std}") - # get_graph_stats(graph_t1, state_to_id, action_to_id) - # print(f"Trajectory 2: {args.t2}") - # print(f"WR={t2_wr_mean}±{t2_wr_std}") - # get_graph_stats(graph_t2, state_to_id, action_to_id) - - # a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2) - # print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}") - # # print("positions of same states:") - # # for node in node_set(graph_t1).intersection(node_set(graph_t2)): - # # print(g1_timestaps[node], g2_timestaps[node]) - # # print("-----------------------") - # tg_no_blocks = TrajectoryGraph() - - # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-2000.jsonl",max_lines=args.n_trajectories)) - # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-4000.jsonl",max_lines=args.n_trajectories)) - # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-6000.jsonl",max_lines=args.n_trajectories)) - # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-8000.jsonl",max_lines=args.n_trajectories)) - # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories)) - # # 
tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories)) - - # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_no_blocks.jsonl",max_lines=args.n_trajectories)) - # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_no_blocks.jsonl",max_lines=args.n_trajectories)) - # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_no_blocks.jsonl",max_lines=args.n_trajectories)) - # tg_no_blocks.plot_graph_stats_progress() - + tg_blocks = TrajectoryGraph() - tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_blocks.jsonl",max_lines=args.n_trajectories)) - tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_blocks.jsonl",max_lines=args.n_trajectories)) - tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_blocks.jsonl",max_lines=args.n_trajectories)) - tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-20000_blocks.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-2000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-4000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-6000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-8000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-10000.jsonl",max_lines=args.n_trajectories)) + 
tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-12000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-14000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-16000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-18000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-20000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-22000.jsonl",max_lines=args.n_trajectories)) + tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-24000.jsonl",max_lines=args.n_trajectories)) tg_blocks.plot_graph_stats_progress() - super_graph = tg_blocks.get_graph_structure_probabilistic_progress() - print(len(super_graph)) - edges_present_everycheckpoint = [k for k,v in super_graph.items() if np.min(v) > 0] - print(len(edges_present_everycheckpoint)) \ No newline at end of file + #edge_probabilies_per_cp = tg_blocks.get_graph_structure_probabilistic_progress() + #edge_probabilies_per_cp = tg_blocks.calculate_source_node_likelihoods() + edge_play_likelihoods = tg_blocks.calculate_edge_play_likelihoods() + for k,v in edge_play_likelihoods.items(): + print(f"{k}, {','.join(map(str, v.tolist()))}") \ No newline at end of file diff --git a/utils/log_parser.py b/AIDojoCoordinator/utils/log_parser.py similarity index 100% rename from utils/log_parser.py rename to AIDojoCoordinator/utils/log_parser.py diff --git a/utils/trajectory_analysis.py b/AIDojoCoordinator/utils/trajectory_analysis.py similarity index 99% rename from utils/trajectory_analysis.py rename to AIDojoCoordinator/utils/trajectory_analysis.py index 
f6034f42..1bed8d3d 100644 --- a/utils/trajectory_analysis.py +++ b/AIDojoCoordinator/utils/trajectory_analysis.py @@ -1,6 +1,5 @@ import jsonlines import numpy as np -import sys import os import utils import matplotlib.pyplot as plt @@ -10,10 +9,7 @@ import plotly.graph_objects as go from sklearn.preprocessing import StandardScaler - - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__) ))) -from env.game_components import GameState, Action, ActionType +from AIDojoCoordinator.game_components import GameState, Action, ActionType diff --git a/utils/utils.py b/AIDojoCoordinator/utils/utils.py similarity index 86% rename from utils/utils.py rename to AIDojoCoordinator/utils/utils.py index 7a53d276..b43a9509 100644 --- a/utils/utils.py +++ b/AIDojoCoordinator/utils/utils.py @@ -1,23 +1,21 @@ # Utility functions for then env and for the agents # Author: Sebastian Garcia. sebastian.garcia@agents.fel.cvut.cz # Author: Ondrej Lukas, ondrej.lukas@aic.fel.cvut.cz -#import configparser + import yaml -import sys -from os import path # This is used so the agent can see the environment and game components -sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) -from env.scenarios import scenario_configuration -from env.scenarios import smaller_scenario_configuration -from env.scenarios import tiny_scenario_configuration -from env.scenarios import three_net_scenario -from env.game_components import IP, Data, Network, Service, GameState, Action, Observation +from AIDojoCoordinator.scenarios import scenario_configuration +from AIDojoCoordinator.scenarios import smaller_scenario_configuration +from AIDojoCoordinator.scenarios import tiny_scenario_configuration +from AIDojoCoordinator.scenarios import three_net_scenario +from AIDojoCoordinator.game_components import IP, Data, Network, Service, GameState, Action, Observation import netaddr import logging import csv from random import randint import json import hashlib +from 
def get_str_hash(string, hash_func='sha256', chunk_size=4096):
    """
    Computes hash of a given string.

    Args:
        string: Text to hash. It is encoded as UTF-8 before hashing; bytes-like
            input (bytes, bytearray) is hashed as-is.
        hash_func: Name of any algorithm accepted by hashlib.new().
        chunk_size: Unused. Kept only so the signature matches get_file_hash()
            and existing callers passing it keep working.

    Returns:
        Hexadecimal digest of the input.

    Raises:
        ValueError: If hash_func is not an algorithm known to hashlib.
    """
    hash_algorithm = hashlib.new(hash_func)
    # Only text needs encoding; calling .encode() on bytes would raise AttributeError.
    payload = string if isinstance(string, (bytes, bytearray)) else string.encode('utf-8')
    hash_algorithm.update(payload)
    return hash_algorithm.hexdigest()
"all_routers": - # known_blocks["all_routers"] = dict_blocked_hosts - # except (ValueError): - # known_blocks = {} return known_blocks def read_agents_known_services(self, type_agent: str, type_data: str) -> dict: @@ -276,7 +275,7 @@ def read_agents_controlled_hosts(self, type_agent: str, type_data: str) -> dict: self.logger.error(f'Configuration problem with the controlled hosts: {e}') return controlled_hosts - def get_player_win_conditions(self, type_of_player): + def get_player_win_conditions(self, type_of_player:str): """ Get the goal of the player type_of_player: Can be 'attackers' or 'defenders' @@ -312,7 +311,7 @@ def get_player_win_conditions(self, type_of_player): return player_goal - def get_player_start_position(self, type_of_player): + def get_player_start_position(self, type_of_player:str): """ Generate the starting position of an attacking agent type_of_player: Can be 'attackers' or 'defenders' @@ -341,7 +340,7 @@ def get_player_start_position(self, type_of_player): return player_start_position - def get_start_position(self, agent_role): + def get_start_position(self, agent_role:str): match agent_role: case "Attacker": return self.get_player_start_position(agent_role) @@ -391,7 +390,6 @@ def get_max_steps(self, role=str)->int: self.logger.warning(f"Unsupported value in 'coordinator.agents.{role}.max_steps': {e}. Setting value to default=None (no step limit)") return max_steps - def get_goal_description(self, agent_role)->dict: """ Get goal description per role @@ -412,46 +410,17 @@ def get_goal_description(self, agent_role)->dict: case _: raise ValueError(f"Unsupported agent role: {agent_role}") return description - - def get_goal_reward(self)->float: - """ - Reads what is the reward for reaching the goal. 
def get_rewards(self, reward_names:list, default_value=0)->dict:
    """
    Reads the configured reward for every case listed in 'reward_names'.

    Values are looked up in self.config['env']['rewards']. A name without a
    configured value gets 'default_value' and a warning is logged.

    Args:
        reward_names: Names of the rewards to read.
        default_value: Reward used for every name that is not configured.

    Returns:
        dict mapping each requested name to its reward.
    """
    # Resolve the rewards section once. It may be missing, or present but
    # empty/None (e.g. an empty 'rewards:' key), which must not crash the lookup.
    try:
        configured = self.config['env']["rewards"]
    except (KeyError, TypeError):
        configured = None
    rewards = {}
    for name in reward_names:
        try:
            rewards[name] = configured[name]
        except (KeyError, TypeError):
            # TypeError covers a section that is not a mapping (None, list, ...)
            self.logger.warning(f"No reward value found for '{name}'. Usinng default reward({name})={default_value}")
            rewards[name] = default_value
    return rewards
Service +from AIDojoCoordinator.coordinator import GameCoordinator + +from AIDojoCoordinator.utils.utils import get_starting_position_from_cyst_config, get_logging_level + +class CYSTCoordinator(GameCoordinator): + + def __init__(self, game_host:str, game_port:int, service_host:str, service_port:int, allowed_roles=["Attacker", "Defender", "Benign"]): + super().__init__(game_host, game_port, service_host, service_port, allowed_roles) + self._id_to_cystid = {} + self._cystid_to_id = {} + self._known_agent_roles = {} + self._last_state_per_agent = {} + self._last_action_per_agent = {} + self._last_msg_per_agent = {} + self._starting_positions = None + self._availabe_cyst_agents = None + + def get_cyst_id(self, agent_role:str): + """ + Returns ID of the CYST agent based on the agent's role. + """ + try: + cyst_id = self._availabe_cyst_agents[agent_role].pop() + except KeyError: + cyst_id = None + return cyst_id + + async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState: + self.logger.debug(f"Registering agent {agent_id} in the world.") + agent_role = "Attacker" + if not self._starting_positions: + self._starting_positions = get_starting_position_from_cyst_config(self._cyst_objects) + self._availabe_cyst_agents = {"Attacker":set([k for k in self._starting_positions.keys()])} + async with self._agents_lock: + cyst_id = self.get_cyst_id(agent_role) + if cyst_id: + self._cystid_to_id[cyst_id] = agent_id + self._id_to_cystid[agent_id] = cyst_id + self._known_agent_roles[agent_id] = agent_role + kh = self._starting_positions[cyst_id]["known_hosts"] + kn = self._starting_positions[cyst_id]["known_networks"] + return GameState(controlled_hosts=kh, known_hosts=kh, known_networks=kn) + else: + return None + + async def remove_agent(self, agent_id, agent_state:GameState)->bool: + print(f"Removing agent {agent_id} from the CYST World") + async with self._agents_lock: + try: + agent_role = self._known_agent_roles[agent_id] + cyst_id = 
self._id_to_cystid[agent_id] + # remove agent's IDs + self._id_to_cystid.pop(agent_id) + self._cystid_to_id.pop(cyst_id) + # make cyst_agent avaiable again + self._availabe_cyst_agents[agent_role].add(cyst_id) + return True + except KeyError: + self.logger.error(f"Unknown agent ID: {agent_id}!") + return False + + async def step(self, agent_id:tuple, agent_state:GameState, action:Action)->GameState: + self.logger.info(f"Processing {action} from {agent_id}({self._id_to_cystid[agent_id]})") + next_state = None + match action.type: + case ActionType.ScanNetwork: + next_state = await self._execute_scan_network_action(agent_id, agent_state, action) + case ActionType.FindServices: + next_state = await self._execute_find_services_action(agent_id, agent_state, action) + case ActionType.FindData: + next_state = await self._execute_find_data_action(agent_id, agent_state, action) + case ActionType.ExploitService: + next_state = await self._execute_exploit_service_action(agent_id, agent_state, action) + case ActionType.ExfiltrateData: + next_state = await self._execute_exfiltrate_data_action(agent_id, agent_state, action) + case ActionType.BlockIP: + next_state = await self._execute_block_ip_action(agent_id, agent_state, action) + case _: + raise ValueError(f"Unknown Action type or other error: '{action.type}'") + return next_state + + async def _cyst_request(self, cyst_id:str, msg:dict)->tuple: + url = f"http://localhost:8282/execute/{cyst_id}/" # Replace with your server's URL + data = msg # The JSON data you want to send + self.logger.info(f"Sedning request {msg} to {url}") + try: + # Send the POST request with JSON data + response = requests.post(url, json=data) + + # Print the response from the server + self.logger.debug(f'Status code:{response.status_code}') + self.logger.debug(f'Response body:{response.text}') + return int(response.status_code), json.loads(response.text) + except requests.exceptions.RequestException as e: + print(f'An error occurred: {e}') + + async def 
_execute_scan_network_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState: + action_dict = { + "action":"dojo:scan_network", + "params": + { + "dst_ip":str(action.parameters["source_host"]), + "dst_service":"", + "to_network":str(action.parameters["target_network"]) + } + } + cyst_rsp_status, cyst_rsp_data = await self._cyst_request(self._id_to_cystid[agent_id], action_dict) + extended_kh = copy.deepcopy(agent_state.known_hosts) + extended_kn = copy.deepcopy(agent_state.known_networks) + extended_ch = copy.deepcopy(agent_state.controlled_hosts) + extended_ks = copy.deepcopy(agent_state.known_services) + extended_kd = copy.deepcopy(agent_state.known_data) + extended_kb = copy.deepcopy(agent_state.known_blocks) + + if cyst_rsp_status == 200: + self.logger.debug("Valid response from CYST") + data = ast.literal_eval(cyst_rsp_data["result"][1]["content"]) + for ip_str in data: + ip = IP(ip_str) + self.logger.debug(f"Adding {ip} to known_hosts") + extended_kh.add(ip) + return GameState(extended_ch, extended_kh, extended_ks, extended_kd, extended_kn, extended_kb) + + async def _execute_find_services_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState: + action_dict = { + "action":"dojo:find_services", + "params": + { + "dst_ip":str(action.parameters["target_host"]), + "dst_service":"" + } + } + cyst_rsp_status, cyst_rsp_data = await self._cyst_request(self._id_to_cystid[agent_id], action_dict) + extended_kh = copy.deepcopy(agent_state.known_hosts) + extended_kn = copy.deepcopy(agent_state.known_networks) + extended_ch = copy.deepcopy(agent_state.controlled_hosts) + extended_ks = copy.deepcopy(agent_state.known_services) + extended_kd = copy.deepcopy(agent_state.known_data) + extended_kb = copy.deepcopy(agent_state.known_blocks) + + if cyst_rsp_status == 200: + self.logger.debug("Valid response from CYST") + data = ast.literal_eval(cyst_rsp_data["result"][1]["content"]) + self.logger.warning(data) + for item in data: + 
ip = IP(item["ip"]) + # Add IP in case it was discovered by the scan + extended_kh.add(ip) + if len(item["services"]) > 0: + if ip not in extended_ks.keys(): + extended_ks[ip] = set() + for service_dict in item["services"]: + service = Service.from_dict(service_dict) + extended_ks[ip].add(service) + return GameState(extended_ch, extended_kh, extended_ks, extended_kd, extended_kn, extended_kb) + + async def _execute_find_data_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState: + raise NotImplementedError + + async def _execute_exploit_service_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState: + raise NotImplementedError + + async def _execute_exfiltrate_data_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState: + raise NotImplementedError + + async def _execute_block_ip_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState: + raise NotImplementedError + + async def reset_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState: + cyst_id = self._id_to_cystid[agent_id] + kh = self._starting_positions[cyst_id]["known_hosts"] + kn = self._starting_positions[cyst_id]["known_networks"] + return GameState(controlled_hosts=kh, known_hosts=kh, known_networks=kn) + + async def reset(self)->bool: + return True + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="CYST-NetSecGame Coordinator Server Author: Ondrej Lukas ondrej.lukas@aic.fel.cvut.cz", + usage="%(prog)s [options]", + ) + + parser.add_argument( + "-l", + "--debug_level", + help="Define the debug level for the logs. 
if __name__ == "__main__":
    # Command-line entry point: parse options, configure file logging and
    # start the CYST coordinator.
    cli = argparse.ArgumentParser(
        description="CYST-NetSecGame Coordinator Server Author: Ondrej Lukas ondrej.lukas@aic.fel.cvut.cz",
        usage="%(prog)s [options]",
    )

    # (short flag, long flag, help text, value type, default)
    option_specs = (
        ("-l", "--debug_level", "Define the debug level for the logs. DEBUG, INFO, WARNING, ERROR, CRITICAL", str, "DEBUG"),
        ("-gh", "--game_host", "host where to run the game server", str, "127.0.0.1"),
        ("-gp", "--game_port", "Port where to run the game server", int, "9000"),
        ("-sh", "--service_host", "Host where to run the config server", str, "127.0.0.1"),
        ("-sp", "--service_port", "Port where to listen for cyst config", int, "9009"),
    )
    for short_flag, long_flag, help_text, value_type, default in option_specs:
        cli.add_argument(
            short_flag,
            long_flag,
            help=help_text,
            action="store",
            required=False,
            type=value_type,
            default=default,
        )

    args = cli.parse_args()
    print(args)

    # Set the logging; create the log directory first if it does not exist
    log_filename = Path("CYST_coordinator.log")
    log_dir = log_filename.parent
    if not log_dir.exists():
        os.makedirs(log_dir)

    logging.basicConfig(
        filename=log_filename,
        filemode="w",
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        # Convert the logging level in the args to the level to use
        level=get_logging_level(args.debug_level),
    )

    game_server = CYSTCoordinator(args.game_host, args.game_port, args.service_host , args.service_port)
    # Run it!
    game_server.run()
def __init__(self, game_host, game_port, task_config:str, allowed_roles=None, seed=42):
    """
    Creates the coordinator and prepares empty containers for the environment.

    The containers are filled later by self._process_cyst_config().

    Args:
        game_host: Forwarded to the parent coordinator.
        game_port: Forwarded to the parent coordinator.
        task_config: Forwarded to the parent coordinator as 'task_config_file'.
        allowed_roles: Roles forwarded to the parent coordinator.
            Defaults to ["Attacker", "Defender", "Benign"].
        seed: Seed applied to both numpy and the 'random' module.
    """
    # A list literal as the default would be one object shared by all instances.
    if allowed_roles is None:
        allowed_roles = ["Attacker", "Defender", "Benign"]
    super().__init__(game_host, game_port, service_host=None, service_port=None, allowed_roles=allowed_roles, task_config_file=task_config)

    # Internal data structure of the NSG
    self._ip_to_hostname = {} # Mapping of `IP`:`host_name`(str) of all nodes in the environment
    self._networks = {} # Networks present in the environment. Keys: `Network` objects, values: collection of `IP` objects.
    self._services = {} # All services in the environment. Keys: hostname (`str`), values: collection of `Service` objects.
    self._data = {} # All data in the environment. Keys: hostname (`str`), values: `set` of `Data` objects.
    self._firewall = {} # dict of all the allowed connections in the environment. Keys `IP` ,values: `set` of `IP` objects.
    self._fw_blocks = {}
    # Content of each datapoint from self._data, keyed by (hostname, data id).
    # It is written while node configs are processed and read by
    # self._get_data_content(), so it has to exist before the CYST config is
    # processed: the AttributeError of a missing attribute would be swallowed
    # by the handler for services without private data and data loading would
    # silently stop after the first datapoint.
    # NOTE(review): assumes the parent class does not already define it - confirm.
    self._data_content = {}
    # All exploits in the environment
    self._exploits = {}
    # A list of all the hosts where the attacker can start in a random start
    self.hosts_to_start = []
    self._network_mapping = {}
    self._ip_mapping = {}

    np.random.seed(seed)
    random.seed(seed)
    self._seed = seed
    self.logger.info(f'Setting env seed to {seed}')
def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->GameState:
    """
    Builds a GameState from given view.
    If there is a keyword 'random' used, it is replaced by a valid option at random.

    Currently, we artificially extend the known_networks with +- 1 in the third octet.

    Args:
        view: dict with keys 'known_networks', 'known_hosts', 'controlled_hosts',
            'known_services' and 'known_data'. Networks and IPs are translated
            through self._network_mapping / self._ip_mapping.
        add_neighboring_nets: If True, the private networks of every controlled
            host and their private neighbours are added to known_networks.

    Returns:
        GameState built from the (re-mapped) view.
    """
    self.logger.info(f'Generating state from view:{view}')
    # re-map all networks based on current mapping in self._network_mapping
    known_networks = set([self._network_mapping[net] for net in view["known_networks"]])

    controlled_hosts = set()
    # controlled_hosts
    for host in view['controlled_hosts']:
        if isinstance(host, IP):
            controlled_hosts.add(self._ip_mapping[host])
            self.logger.debug(f'\tThe attacker has control of host {self._ip_mapping[host]}.')
        elif host == 'random':
            # Random start
            self.logger.debug('\tAdding random starting position of agent')
            self.logger.debug(f'\t\tChoosing from {self.hosts_to_start}')
            selected = random.choice(self.hosts_to_start)
            controlled_hosts.add(selected)
            self.logger.debug(f'\t\tMaking agent start in {selected}')
        elif host == "all_local":
            # all local ips
            self.logger.debug('\t\tAdding all local hosts to agent')
            controlled_hosts = controlled_hosts.union(self._get_all_local_ips())
        else:
            self.logger.error(f"Unsupported value encountered in start_position['controlled_hosts']: {host}")
    # re-map all known based on current mapping in self._ip_mapping
    known_hosts = set([self._ip_mapping[ip] for ip in view["known_hosts"]])
    # Add all controlled hosts to known_hosts
    known_hosts = known_hosts.union(controlled_hosts)

    if add_neighboring_nets:
        # Extend the known networks with the neighbouring networks
        # This is to solve in the env (and not in the agent) the problem
        # of not knowing other networks apart from the one the agent is in
        # This is wrong and should be done by the agent, not here
        # TODO remove this!
        for controlled_host in controlled_hosts:
            for net in self._get_networks_from_host(controlled_host): #TODO
                net_obj = netaddr.IPNetwork(str(net))
                if net_obj.ip.is_ipv4_private_use(): #TODO
                    known_networks.add(net)
                    # +256 on the address value moves one step up in the third octet
                    net_obj.value += 256
                    if net_obj.ip.is_ipv4_private_use():
                        neighbor = Network(str(net_obj.ip), net_obj.prefixlen)
                        self.logger.debug(f'\tAdding {neighbor} to agent')
                        known_networks.add(neighbor)
                    # two steps down from the upper neighbour = lower neighbour
                    net_obj.value -= 2*256
                    if net_obj.ip.is_ipv4_private_use():
                        neighbor = Network(str(net_obj.ip), net_obj.prefixlen)
                        self.logger.debug(f'\tAdding {neighbor} to agent')
                        known_networks.add(neighbor)
                    #return value back to the original
                    net_obj.value += 256
    # dict(...) accepts both a mapping and an iterable of (ip, values) pairs.
    # Iterating a mapping directly would yield only its keys and break the
    # two-name unpacking as soon as the view contains any service or data.
    known_services = {}
    for ip, service_list in dict(view["known_services"]).items():
        known_services[self._ip_mapping[ip]] = service_list
    known_data = {}
    for ip, data_list in dict(view["known_data"]).items():
        known_data[self._ip_mapping[ip]] = data_list
    game_state = GameState(controlled_hosts, known_hosts, known_services, known_data, known_networks)
    self.logger.info(f"Generated GameState:{game_state}")
    return game_state
== "can_attack_start_here": - self.hosts_to_start.append(gc.IP(str(interface.ip))) + if service.name == "can_attack_start_here": + self.hosts_to_start.append(IP(str(interface.ip))) continue if node_obj.id not in self._services: self._services[node_obj.id] = [] - self._services[node_obj.id].append(gc.Service(service.type, "passive", service.version, service.local)) + self._services[node_obj.id].append(Service(service.name, "passive", service.version, service.local)) #data - self.logger.info(f"\t\t\tProcessing data in node '{node_obj.id}':'{service.type}' service") + self.logger.info(f"\t\t\tProcessing data in node '{node_obj.id}':'{service.name}' service") try: for data in service.private_data: if node_obj.id not in self._data: self._data[node_obj.id] = set() - datapoint = gc.Data(data.owner, data.description) + datapoint = Data(data.owner, data.description) self._data[node_obj.id].add(datapoint) # add content self._data_content[node_obj.id, datapoint.id] = f"Content of {datapoint.id}" except AttributeError: pass #service does not contain any data - def process_router_config(router_obj:RouterConfig)->None: self.logger.info(f"\tProcessing config of router '{router_obj.id}'") # Process a router @@ -173,8 +207,8 @@ def process_router_config(router_obj:RouterConfig)->None: self.logger.info(f"\t\tProcessing interfaces in router '{router_obj.id}'") for interface in r.interfaces: net_ip, net_mask = str(interface.net).split("/") - net = gc.Network(net_ip,int(net_mask)) - ip = gc.IP(str(interface.ip)) + net = Network(net_ip,int(net_mask)) + ip = IP(str(interface.ip)) self._ip_to_hostname[ip] = router_obj.id if net not in self._networks: self._networks[net] = [] @@ -271,7 +305,7 @@ def _create_new_network_mapping(self)->tuple: if netaddr.IPNetwork(str(net)).ip.is_ipv4_private_use(): private_nets.append(net) else: - mapping_nets[net] = gc.Network(fake.ipv4_public(), net.mask) + mapping_nets[net] = Network(fake.ipv4_public(), net.mask) # for private networks, we want to keep 
the distances among them private_nets_sorted = sorted(private_nets) @@ -282,7 +316,7 @@ def _create_new_network_mapping(self)->tuple: # find the new lowest networks new_base = netaddr.IPNetwork(f"{fake.ipv4_private()}/{private_nets_sorted[0].mask}") # store its new mapping - mapping_nets[private_nets[0]] = gc.Network(str(new_base.network), private_nets_sorted[0].mask) + mapping_nets[private_nets[0]] = Network(str(new_base.network), private_nets_sorted[0].mask) base = netaddr.IPNetwork(str(private_nets_sorted[0])) is_private_net_checks = [] for i in range(1,len(private_nets_sorted)): @@ -294,7 +328,7 @@ def _create_new_network_mapping(self)->tuple: # evaluate if its still a private network is_private_net_checks.append(new_net_addr.is_ipv4_private_use()) # store the new mapping - mapping_nets[private_nets_sorted[i]] = gc.Network(str(new_net_addr), private_nets_sorted[i].mask) + mapping_nets[private_nets_sorted[i]] = Network(str(new_net_addr), private_nets_sorted[i].mask) if False not in is_private_net_checks: # verify that ALL new networks are still in the private ranges valid_valid_network_mapping = True except IndexError as e: @@ -312,7 +346,7 @@ def _create_new_network_mapping(self)->tuple: # remove broadcast and network ip from the list random.shuffle(ip_list) for i,ip in enumerate(ips): - mapping_ips[ip] = gc.IP(str(ip_list[i])) + mapping_ips[ip] = IP(str(ip_list[i])) # Always add random, in case random is selected for ips mapping_ips['random'] = 'random' self.logger.info(f"Mapping IPs done:{mapping_ips}") @@ -326,22 +360,6 @@ def _create_new_network_mapping(self)->tuple: new_self_networks[mapping_nets[net]].add(mapping_ips[ip]) self._networks = new_self_networks - # Harpo says that here there is a problem that firewall.items() do not return an ip that can be used in the mapping - # His solution is: (check) - """ - new_self_firewall = {} - for ip, dst_ips in self._firewall.items(): - if ip not in mapping_ips: - self.logger.debug(f"IP {ip} not found in 
mapping_ips") - continue # Skip this IP if it's not found in the mapping - - new_self_firewall[mapping_ips[ip]] = set() - - for dst_ip in dst_ips: - new_self_firewall[mapping_ips[ip]].add(mapping_ips[dst_ip]) - self._firewall = new_self_firewall - """ - #self._firewall new_self_firewall = {} for ip, dst_ips in self._firewall.items(): @@ -431,12 +449,12 @@ def _get_data_content(self, host_ip:str, data_id:str)->str: if (hostname, data_id) in self._data_content: content = self._data_content[hostname,data_id] else: - self.logger.info(f"\tData '{data_id}' not found in host '{hostname}'({host_ip})") + self.logger.debug(f"\tData '{data_id}' not found in host '{hostname}'({host_ip})") else: self.logger.debug("Data content not found because target IP does not exists.") return content - def _execute_action(self, current_state:gc.GameState, action:gc.Action, agent_id)-> gc.GameState: + def _execute_action(self, current_state:GameState, action:Action)-> GameState: """ Execute the action and update the values in the state Before this function it was checked if the action was successful @@ -449,23 +467,23 @@ def _execute_action(self, current_state:gc.GameState, action:gc.Action, agent_id """ next_state = None match action.type: - case gc.ActionType.ScanNetwork: + case ActionType.ScanNetwork: next_state = self._execute_scan_network_action(current_state, action) - case gc.ActionType.FindServices: + case ActionType.FindServices: next_state = self._execute_find_services_action(current_state, action) - case gc.ActionType.FindData: + case ActionType.FindData: next_state = self._execute_find_data_action(current_state, action) - case gc.ActionType.ExploitService: + case ActionType.ExploitService: next_state = self._execute_exploit_service_action(current_state, action) - case gc.ActionType.ExfiltrateData: + case ActionType.ExfiltrateData: next_state = self._execute_exfiltrate_data_action(current_state, action) - case gc.ActionType.BlockIP: + case ActionType.BlockIP: next_state = 
self._execute_block_ip_action(current_state, action) case _: raise ValueError(f"Unknown Action type or other error: '{action.type}'") return next_state - def _state_parts_deep_copy(self, current:gc.GameState)->tuple: + def _state_parts_deep_copy(self, current:GameState)->tuple: next_nets = copy.deepcopy(current.known_networks) next_known_h = copy.deepcopy(current.known_hosts) next_controlled_h = copy.deepcopy(current.controlled_hosts) @@ -474,7 +492,7 @@ def _state_parts_deep_copy(self, current:gc.GameState)->tuple: next_blocked = copy.deepcopy(current.known_blocks) return next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked - def _firewall_check(self, src_ip:gc.IP, dst_ip:gc.IP)->bool: + def _firewall_check(self, src_ip:IP, dst_ip:IP)->bool: """Checks if firewall allows connection from 'src_ip to ''dst_ip'""" try: connection_allowed = dst_ip in self._firewall[src_ip] @@ -482,12 +500,12 @@ def _firewall_check(self, src_ip:gc.IP, dst_ip:gc.IP)->bool: connection_allowed = False return connection_allowed - def _execute_scan_network_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState: + def _execute_scan_network_action(self, current_state:GameState, action:Action)->GameState: """ Executes the ScanNetwork action in the environment """ next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state) - self.logger.info(f"\t\tScanning {action.parameters['target_network']}") + self.logger.debug(f"\t\tScanning {action.parameters['target_network']}") if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts: new_ips = set() for ip in self._ip_to_hostname.keys(): #check if IP exists @@ -500,15 +518,15 @@ def _execute_scan_network_action(self, current_state:gc.GameState, action:gc.Act self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {ip} blocked by FW. 
Skipping") next_known_h = next_known_h.union(new_ips) else: - self.logger.info(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") - return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) + self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") + return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_find_services_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState: + def _execute_find_services_action(self, current_state:GameState, action:Action)->GameState: """ Executes the FindServices action in the environment """ next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state) - self.logger.info(f"\t\tSearching for services in {action.parameters['target_host']}") + self.logger.debug(f"\t\tSearching for services in {action.parameters['target_host']}") if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts: if self._firewall_check(action.parameters["source_host"], action.parameters['target_host']): found_services = self._get_services_from_host(action.parameters["target_host"], current_state.controlled_hosts) @@ -525,14 +543,14 @@ def _execute_find_services_action(self, current_state:gc.GameState, action:gc.Ac self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. 
Skipping") else: self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") - return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) + return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_find_data_action(self, current:gc.GameState, action:gc.Action)->gc.GameState: + def _execute_find_data_action(self, current:GameState, action:Action)->GameState: """ Executes the FindData action in the environment """ next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current) - self.logger.info(f"\t\tSearching for data in {action.parameters['target_host']}") + self.logger.debug(f"\t\tSearching for data in {action.parameters['target_host']}") if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current.controlled_hosts: if self._firewall_check(action.parameters["source_host"], action.parameters['target_host']): new_data = self._get_data_in_host(action.parameters["target_host"], current.controlled_hosts) @@ -553,9 +571,9 @@ def _execute_find_data_action(self, current:gc.GameState, action:gc.Action)->gc. self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. 
Skipping") else: self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") - return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) + return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_exfiltrate_data_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState: + def _execute_exfiltrate_data_action(self, current_state:GameState, action:Action)->GameState: """ Executes the ExfiltrateData action in the environment """ @@ -597,9 +615,9 @@ def _execute_exfiltrate_data_action(self, current_state:gc.GameState, action:gc. self.logger.debug("\t\t\tCan not exfiltrate. Source host is not controlled.") else: self.logger.debug("\t\t\tCan not exfiltrate. Target host is not controlled.") - return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) + return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_exploit_service_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState: + def _execute_exploit_service_action(self, current_state:GameState, action:Action)->GameState: """ Executes the ExploitService action in the environment """ @@ -634,9 +652,9 @@ def _execute_exploit_service_action(self, current_state:gc.GameState, action:gc. self.logger.debug("\t\t\tCan not exploit. 
Target host does not exist.") else: self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") - return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) + return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_block_ip_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState: + def _execute_block_ip_action(self, current_state:GameState, action:Action)->GameState: """ Executes the BlockIP action - The action has BlockIP("target_host": IP object, "source_host": IP object, "blocked_host": IP object) @@ -689,14 +707,14 @@ def _execute_block_ip_action(self, current_state:gc.GameState, action:gc.Action) next_blocked[action.parameters["target_host"]].add(action.parameters["blocked_host"]) next_blocked[action.parameters["blocked_host"]].add(action.parameters["target_host"]) else: - self.logger.info(f"\t\t\t Cant block connection form :'{action.parameters['target_host']}' to '{action.parameters['blocked_host']}'") + self.logger.debug(f"\t\t\t Cant block connection form :'{action.parameters['target_host']}' to '{action.parameters['blocked_host']}'") else: self.logger.debug(f"\t\t\t Connection from '{action.parameters['source_host']}->'{action.parameters['target_host']} is blocked blocked by FW") else: - self.logger.info(f"\t\t\t Invalid target_host:'{action.parameters['target_host']}'") + self.logger.debug(f"\t\t\t Invalid target_host:'{action.parameters['target_host']}'") else: - self.logger.info(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") - return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) + self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'") + return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) def _get_all_local_ips(self)->set: local_ips = set() @@ -706,179 
+724,105 @@ def _get_all_local_ips(self)->set: local_ips.add(self._ip_mapping[ip]) self.logger.info(f"\t\t\tLocal ips: {local_ips}") return local_ips - - def create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->gc.GameState: - """ - Builds a GameState from given view. - If there is a keyword 'random' used, it is replaced by a valid option at random. - - Currently, we artificially extend the knonw_networks with +- 1 in the third octet. - """ - self.logger.info(f'Generating state from view:{view}') - # re-map all networks based on current mapping in self._network_mapping - known_networks = set([self._network_mapping[net] for net in view["known_networks"]]) - - - controlled_hosts = set() - # controlled_hosts - for host in view['controlled_hosts']: - if isinstance(host, gc.IP): - controlled_hosts.add(self._ip_mapping[host]) - self.logger.info(f'\tThe attacker has control of host {self._ip_mapping[host]}.') - elif host == 'random': - # Random start - self.logger.info('\tAdding random starting position of agent') - self.logger.info(f'\t\tChoosing from {self.hosts_to_start}') - selected = random.choice(self.hosts_to_start) - controlled_hosts.add(selected) - self.logger.info(f'\t\tMaking agent start in {selected}') - elif host == "all_local": - # all local ips - self.logger.info('\t\tAdding all local hosts to agent') - controlled_hosts = controlled_hosts.union(self._get_all_local_ips()) - else: - self.logger.error(f"Unsupported value encountered in start_position['controlled_hosts']: {host}") - # re-map all known based on current mapping in self._ip_mapping - known_hosts = set([self._ip_mapping[ip] for ip in view["known_hosts"]]) - # Add all controlled hosts to known_hosts - known_hosts = known_hosts.union(controlled_hosts) - - if add_neighboring_nets: - # Extend the known networks with the neighbouring networks - # This is to solve in the env (and not in the agent) the problem - # of not knowing other networks appart from the one the agent is in - # 
This is wrong and should be done by the agent, not here - # TODO remove this! - for controlled_host in controlled_hosts: - for net in self._get_networks_from_host(controlled_host): #TODO - net_obj = netaddr.IPNetwork(str(net)) - if net_obj.ip.is_ipv4_private_use(): #TODO - known_networks.add(net) - net_obj.value += 256 - if net_obj.ip.is_ipv4_private_use(): - ip = gc.Network(str(net_obj.ip), net_obj.prefixlen) - self.logger.info(f'\tAdding {ip} to agent') - known_networks.add(ip) - net_obj.value -= 2*256 - if net_obj.ip.is_ipv4_private_use(): - ip = gc.Network(str(net_obj.ip), net_obj.prefixlen) - self.logger.info(f'\tAdding {ip} to agent') - known_networks.add(ip) - #return value back to the original - net_obj.value += 256 - known_services ={} - for ip, service_list in view["known_services"]: - known_services[self._ip_mapping[ip]] = service_list - known_data = {} - for ip, data_list in view["known_data"]: - known_data[self._ip_mapping[ip]] = data_list - game_state = gc.GameState(controlled_hosts, known_hosts, known_services, known_data, known_networks) - self.logger.info(f"Generated GameState:{game_state}") + + async def register_agent(self, agent_id, agent_role, agent_initial_view)->GameState: + if len(self._networks) == 0: + self._initialize() + game_state = self._create_state_from_view(agent_initial_view) return game_state + + async def remove_agent(self, agent_id, agent_state)->bool: + # No action is required + return True + + async def step(self, agent_id, agent_state, action)->GameState: + return self._execute_action(agent_state, action) + + async def reset_agent(self, agent_id, agent_role, agent_initial_view)->GameState: + game_state = self._create_state_from_view(agent_initial_view) + return game_state - def update_goal_dict(self, goal_dict:dict)->dict: - """ - Updates goal dict based on the current values - in self._network_mapping and self._ip_mapping. 
- """ - new_dict = { - "known_networks":set(), - "known_hosts":set(), - "controlled_hosts":set(), - "known_services": {}, - "known_data": {}, - "known_blocks": {} - } - for net in goal_dict["known_networks"]: - if net in self._network_mapping: - new_dict["known_networks"].add(self._network_mapping[net]) - else: - # Unknown net, do not map - new_dict["known_networks"].add(net) - for host in goal_dict["known_hosts"]: - if host in self._ip_mapping: - new_dict["known_hosts"].add(self._ip_mapping[host]) - else: - # Unknown IP, do not map - new_dict["known_hosts"].add(host) - for host in goal_dict["controlled_hosts"]: - if host in self._ip_mapping: - new_dict["controlled_hosts"].add(self._ip_mapping[host]) - else: - # Unknown IP, do not map - new_dict["controlled_hosts"].add(host) - for host, items in goal_dict["known_services"].items(): - if host in self._ip_mapping: - new_dict["known_services"][self._ip_mapping[host]] = items - else: - # Unknown IP, do not map - new_dict["known_services"][host] = items - for host, items in goal_dict["known_data"].items(): - if host in self._ip_mapping: - new_dict["known_data"][self._ip_mapping[host]] = items - else: - # Unknown IP, do not map - new_dict["known_data"][host] = items - for host, items in goal_dict["known_blocks"].items(): - if host in self._ip_mapping: - new_dict["known_blocks"][self._ip_mapping[host]] = items - else: - # Unknown IP, do not map - new_dict["known_blocks"][host] = items - return new_dict - - def update_goal_descriptions(self, goal_description:str)->str: - new_description = goal_description - for ip in self._ip_mapping: - new_description = new_description.replace(str(ip), str(self._ip_mapping[ip])) - return new_description - - def reset(self)->None: + async def reset(self)->bool: """ Function to reset the state of the game and prepare for a new episode """ # write all steps in the episode replay buffer in the file - self.logger.info('--- Reseting env to its initial state ---') + self.logger.info('--- 
Reseting NSG Environment to its initial state ---') # change IPs if needed if self.task_config.get_use_dynamic_addresses(): self._create_new_network_mapping() # reset self._data to orignal state self._data = copy.deepcopy(self._data_original) # reset self._data_content to orignal state - self._data_content_original = copy.deepcopy(self._data_content_original) self._firewall = copy.deepcopy(self._firewall_original) self._fw_blocks = {} - - - self._actions_played = [] - - def step(self, state:gc.GameState, action:gc.Action, agent_id:tuple)-> gc.GameState: - """ - Take a step in the environment given an action - in: action - out: observation of the state of the env - """ - self.logger.info(f"Agent {agent_id}. Action: {action}") - # Reward for taking an action - reward = self._rewards["step"] - - # 1. Perform the action - self._actions_played.append(action) - if random.random() <= action.type.default_success_p: - next_state = self._execute_action(state, action, agent_id) - else: - self.logger.info("\tAction NOT sucessful") - next_state = state - - - # Make the state we just got into, our current state - current_state = state - self.logger.info(f'New state: {next_state} ') - - - # Save the transition to the episode replay buffer if there is any - if self._episode_replay_buffer is not None: - self._episode_replay_buffer.append((current_state, action, reward, next_state)) - # Return an observation - return next_state \ No newline at end of file + return True + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NetSecGame Coordinator Server Author: Ondrej Lukas ondrej.lukas@aic.fel.cvut.cz", + usage="%(prog)s [options]", + ) + + parser.add_argument( + "-l", + "--debug_level", + help="Define the debug level for the logs. 
DEBUG, INFO, WARNING, ERROR, CRITICAL", + action="store", + required=False, + type=str, + default="INFO", + ) + + parser.add_argument( + "-gh", + "--game_host", + help="host where to run the game server", + action="store", + required=False, + type=str, + default="127.0.0.1", + ) + + parser.add_argument( + "-gp", + "--game_port", + help="Port where to run the game server", + action="store", + required=False, + type=int, + default="9000", + ) + + parser.add_argument( + "-c", + "--task_config", + help="File with the task configuration", + action="store", + required=True, + type=str, + default="netsecenv_conf.yaml", + ) + + args = parser.parse_args() + print(args) + # Set the logging + log_filename = Path("NSG_coordinator.log") + if not log_filename.parent.exists(): + os.makedirs(log_filename.parent) + + # Convert the logging level in the args to the level to use + pass_level = get_logging_level(args.debug_level) + + logging.basicConfig( + filename=log_filename, + filemode="w", + format="%(asctime)s %(name)s %(levelname)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=pass_level, + ) + + game_server = NSGCoordinator(args.game_host, args.game_port, args.task_config) + # Run it! + game_server.run() \ No newline at end of file diff --git a/env/worlds/network_security_game_real_world.py b/AIDojoCoordinator/worlds/NSGRealWorldCoordinator.py similarity index 62% rename from env/worlds/network_security_game_real_world.py rename to AIDojoCoordinator/worlds/NSGRealWorldCoordinator.py index d7276510..763326ab 100644 --- a/env/worlds/network_security_game_real_world.py +++ b/AIDojoCoordinator/worlds/NSGRealWorldCoordinator.py @@ -2,22 +2,49 @@ # Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz # Sebastian Garcia. 
sebastian.garcia@agents.fel.cvut.cz -import env.game_components as components -from env.worlds.network_security_game import NetworkSecurityEnvironment import subprocess import xml.etree.ElementTree as ElementTree +import logging +import argparse +import os +from pathlib import Path +from AIDojoCoordinator.utils.utils import get_logging_level +from AIDojoCoordinator.game_components import GameState, Action, ActionType, Service,IP +from AIDojoCoordinator.worlds.NSEGameCoordinator import NSGCoordinator -class NetworkSecurityEnvironmentRealWorld(NetworkSecurityEnvironment): - """ - Class to manage the whole network security game in the real world (current network) - It uses some Cyst libraries for the network topology - It presents a env environment to play - """ - def __init__(self, task_config_file, world_name="NetSecEnvRealWorld") -> None: - super().__init__(task_config_file, world_name) +class NSERealWorldGameCoordinator(NSGCoordinator): + + def _execute_action(self, current_state:GameState, action:Action)-> GameState: + """ + Execute the action and update the values in the state + Before this function it was checked if the action was successful + So in here all actions were already successful. 
- def _execute_scan_network_action_real_world(self, current_state:components.GameState, action:components.Action)->components.GameState: + - actions_type: Define if the action is simulated in netsecenv or in the real world + - agent_id: is the name or type of agent that requested the action + + Returns: A new GameState + """ + next_state = None + match action.type: + case ActionType.ScanNetwork: + next_state = self._execute_scan_network_action_real_world(current_state, action) + case ActionType.FindServices: + next_state = self._execute_find_services_action_real_world(current_state, action) + case ActionType.FindData: + next_state = self._execute_find_data_action(current_state, action) + case ActionType.ExploitService: + next_state = self._execute_exploit_service_action(current_state, action) + case ActionType.ExfiltrateData: + next_state = self._execute_exfiltrate_data_action(current_state, action) + case ActionType.BlockIP: + next_state = self._execute_block_ip_action(current_state, action) + case _: + raise ValueError(f"Unknown Action type or other error: '{action.type}'") + return next_state + + def _execute_scan_network_action_real_world(self, current_state:GameState, action:Action)->GameState: """ Executes the ScanNetwork action in the the real world """ @@ -38,7 +65,7 @@ def _execute_scan_network_action_real_world(self, current_state:components.GameS status = "" ip_elem = host.find('./address[@addrtype="ipv4"]') if ip_elem is not None: - ip = components.IP(str(ip_elem.get('addr'))) + ip = IP(str(ip_elem.get('addr'))) else: ip = "" @@ -53,9 +80,9 @@ def _execute_scan_network_action_real_world(self, current_state:components.GameS self.logger.debug(f"\t\t\tAdding {ip} to new_ips. 
{status}, {mac_address}, {vendor}") new_ips.add(ip) next_known_h = next_known_h.union(new_ips) - return components.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) + return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_find_services_action_real_world(self, current_state:components.GameState, action:components.Action)->components.GameState: + def _execute_find_services_action_real_world(self, current_state:GameState, action:Action)->GameState: """ Executes the FindServices action in the real world """ @@ -85,7 +112,7 @@ def _execute_find_services_action_real_world(self, current_state:components.Game service_elem = port.find('./service[@name]') service_name = service_elem.get('name') if service_elem is not None else "Unknown" service_fullname = f'{port_id}/{protocol}/{service_name}' - service = components.Service(name=service_fullname, type=service_name, version='', is_local=False) + service = Service(name=service_fullname, type=service_name, version='', is_local=False) found_services.add(service) next_services[action.parameters["target_host"]] = found_services @@ -95,66 +122,72 @@ def _execute_find_services_action_real_world(self, current_state:components.Game self.logger.info(f"\t\tAdding {action.parameters['target_host']} to known_hosts") next_known_h.add(action.parameters["target_host"]) next_nets = next_nets.union({net for net, values in self._networks.items() if action.parameters["target_host"] in values}) - return components.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) + return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked) - def _execute_action(self, current_state:components.GameState, action:components.Action, agent_id)-> components.GameState: - """ - Execute the action and update the values in the state - Before this function it was checked if the action 
was successful - So in here all actions were already successful. - - - actions_type: Define if the action is simulated in netsecenv or in the real world - - agent_id: is the name or type of agent that requested the action - - Returns: A new GameState - """ - next_state = None - match action.type: - case components.ActionType.ScanNetwork: - next_state = self._execute_scan_network_action_real_world(current_state, action) - case components.ActionType.FindServices: - next_state = self._execute_find_services_action_real_world(current_state, action) - case components.ActionType.FindData: - # This Action type is not implemente in real world - use the simualtion - next_state = self._execute_find_data_action(current_state, action) - case components.ActionType.ExploitService: - # This Action type is not implemente in real world - use the simualtion - next_state = self._execute_exploit_service_action(current_state, action) - case components.ActionType.ExfiltrateData: - # This Action type is not implemente in real world - use the simualtion - next_state = self._execute_exfiltrate_data_action(current_state, action) - case components.ActionType.BlockIP: - # This Action type is not implemente in real world - use the simualtion - next_state = self._execute_block_ip_action(current_state, action) - case _: - raise ValueError(f"Unknown Action type or other error: '{action.type}'") - return next_state - - def step(self, state:components.GameState, action:components.Action, agent_id:tuple)-> components.GameState: - """ - Take a step in the environment given an action - in: action - out: observation of the state of the env - """ - self.logger.info(f"Agent {agent_id}. 
Action: {action}") - # Reward for taking an action - reward = self._rewards["step"] +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="NetSecGame Coordinator Server (Real World) Author: Ondrej Lukas ondrej.lukas@aic.fel.cvut.cz; sebastian.garcia@agents.fel.cvut.cz", + usage="%(prog)s [options]", + ) + + parser.add_argument( + "-l", + "--debug_level", + help="Define the debug level for the logs. DEBUG, INFO, WARNING, ERROR, CRITICAL", + action="store", + required=False, + type=str, + default="INFO", + ) + + parser.add_argument( + "-gh", + "--game_host", + help="host where to run the game server", + action="store", + required=False, + type=str, + default="127.0.0.1", + ) + + parser.add_argument( + "-gp", + "--game_port", + help="Port where to run the game server", + action="store", + required=False, + type=int, + default="9000", + ) - # 1. Perform the action - self._actions_played.append(action) - - # No randomness in action success - we are playing in real world - next_state = self._execute_action(state, action, agent_id) - + parser.add_argument( + "-c", + "--task_config", + help="File with the task configuration", + action="store", + required=True, + type=str, + default="netsecenv_conf.yaml", + ) - - # Make the state we just got into, our current state - current_state = state - self.logger.info(f'New state: {next_state} ') + args = parser.parse_args() + print(args) + # Set the logging + log_filename = Path("NSG_real_world_coordinator.log") + if not log_filename.parent.exists(): + os.makedirs(log_filename.parent) + # Convert the logging level in the args to the level to use + pass_level = get_logging_level(args.debug_level) - # Save the transition to the episode replay buffer if there is any - if self._episode_replay_buffer is not None: - self._episode_replay_buffer.append((current_state, action, reward, next_state)) - # Return an observation - return next_state \ No newline at end of file + logging.basicConfig( + filename=log_filename, + 
filemode="w", + format="%(asctime)s %(name)s %(levelname)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=pass_level, + ) + + game_server = NSERealWorldGameCoordinator(args.game_host, args.game_port, args.task_config) + # Run it! + game_server.run() \ No newline at end of file diff --git a/utils/__init__.py b/AIDojoCoordinator/worlds/__init__.py similarity index 100% rename from utils/__init__.py rename to AIDojoCoordinator/worlds/__init__.py diff --git a/NetSecGameAgents b/NetSecGameAgents index d15cf5c6..5c966fd7 160000 --- a/NetSecGameAgents +++ b/NetSecGameAgents @@ -1 +1 @@ -Subproject commit d15cf5c663c76479a7ce9845e56e46c28016c80d +Subproject commit 5c966fd7f8560a2632dd4279392440219b4bb486 diff --git a/README.md b/README.md index 619ddd48..c944e5c7 100755 --- a/README.md +++ b/README.md @@ -19,17 +19,17 @@ python -m venv ai-dojo-venv- source ai-dojo-venv/bin/activate ``` -- Install the requirements with +- Install using pip by running following in the **root** directory ```bash -python3 -m pip install -r requirements.txt +pip install -e . ``` - If you use conda use ```bash conda create --name aidojo python==3.10 conda activate aidojo -python3 -m pip install -r requirements.txt +pip install -e . ``` ## Architecture @@ -149,8 +149,13 @@ This approach ensures that only repeated or excessive behavior is flagged, reduc ### Starting the game -The environment should be created before starting the agents. The properties of the environment can be defined in a YAML file. The game server can be started by running: -```python3 coordinator.py``` +The environment should be created before starting the agents. The properties of the game, the task and the topology can be either read from a local file or via REST request to the GameDashboard. 
+ +#### To start the game with local configuration file +```python3 -m AIDojoCoordinator.worlds.NSEGameCoordinator --task_config=``` + +#### To start the game with remotely defined configuration +```python3 -m AIDojoCoordinator.worlds.CYSTCoordinator --service_host= --service_port= ``` When created, the environment: 1. reads the configuration file diff --git a/coordinator.conf b/coordinator.conf deleted file mode 100644 index 4758d9b6..00000000 --- a/coordinator.conf +++ /dev/null @@ -1,7 +0,0 @@ -{ -"host": "127.0.0.1", -"port": 9000, -"start_reward": 0, -"max_steps": 1500, -"world_type": "netsecenv" -} diff --git a/coordinator.py b/coordinator.py deleted file mode 100644 index 31dc475b..00000000 --- a/coordinator.py +++ /dev/null @@ -1,950 +0,0 @@ -#!/usr/bin/env python -# Server for the Aidojo project, coordinator -# Author: sebastian garcia, sebastian.garcia@agents.fel.cvut.cz -# Author: Ondrej Lukas, ondrej.lukas@aic.fel.cvut.cz -import jsonlines -import argparse -import logging -import json -import asyncio -import enum -from datetime import datetime -from env.worlds.network_security_game import NetworkSecurityEnvironment -from env.worlds.network_security_game_real_world import NetworkSecurityEnvironmentRealWorld -from env.worlds.aidojo_world import AIDojoWorld -from env.game_components import Action, Observation, ActionType, GameStatus, GameState -from utils.utils import observation_as_dict, get_logging_level, get_file_hash -from pathlib import Path -import os -import signal -from env.global_defender import stochastic_with_threshold -from utils.utils import ConfigParser - -@enum.unique -class AgentStatus(enum.Enum): - """ - Class representing the current status for each agent connected to the coordinator - """ - JoinRequested = 0 - Ready = 1 - Playing = 2 - PlayingActive = 3 - FinishedMaxSteps = 4 - FinishedBlocked = 5 - FinishedGoalReached = 6 - FinishedGameLost = 7 - ResetRequested = 8 - Quitting = 9 - - -class AIDojo: - def __init__(self, host: str, port: 
int, net_sec_config: str, world_type) -> None: - self.host = host - self.port = port - self.logger = logging.getLogger("AIDojo-main") - self._agent_action_queue = asyncio.Queue() - self._agent_response_queues = {} - self._coordinator = Coordinator( - self._agent_action_queue, - self._agent_response_queues, - net_sec_config, - allowed_roles=["Attacker", "Defender", "Benign"], - world_type = world_type, - ) - - async def create_agent_queue(self, addr): - """ - Create a queue for the given agent address if it doesn't already exist. - """ - if addr not in self._agent_response_queues: - self._agent_response_queues[addr] = asyncio.Queue() - self.logger.info(f"Created queue for agent {addr}. {len(self._agent_response_queues)} queues in total.") - - def run(self)->None: - """ - Wrapper for ayncio run function. Starts all tasks in AIDojo - """ - asyncio.run(self.start_tasks()) - - async def start_tasks(self): - """ - High level funciton to start all the other asynchronous tasks. - - Reads the conf of the coordinator - - Creates queues - - Start the main part of the coordinator - - Start a server that listens for agents - """ - self.logger.info("Starting all tasks") - loop = asyncio.get_running_loop() - - self.logger.info("Starting Coordinator taks") - coordinator_task = asyncio.create_task(self._coordinator.run()) - - self.logger.info("Starting the server listening for agents") - running_server = await asyncio.start_server( - ConnectionLimitProtocol( - self._agent_action_queue, - self._agent_response_queues, - max_connections=2 - ), - self.host, - self.port - ) - addrs = ", ".join(str(sock.getsockname()) for sock in running_server.sockets) - self.logger.info(f"\tServing on {addrs}") - - # prepare the stopping event for keyboard interrupt - stop = loop.create_future() - - # register the signal handler to the stopping event - loop.add_signal_handler(signal.SIGINT, stop.set_result, None) - - await stop # Event that triggers stopping the AIDojo - # Stop the server - 
self.logger.info("Initializing server shutdown") - running_server.close() - await running_server.wait_closed() - self.logger.info("\tServer stopped") - # Stop coordinator taks - self.logger.info("Initializing coordinator shutdown") - coordinator_task.cancel() - await asyncio.gather(coordinator_task, return_exceptions=True) - self.logger.info("\tCoordinator stopped") - # Everything stopped correctly, terminate - self.logger.info("AIDojo terminating") - -class ConnectionLimitProtocol(asyncio.Protocol): - def __init__(self, actions_queue, agent_response_queues, max_connections): - self.actions_queue = actions_queue - self.answers_queues = agent_response_queues - self.max_connections = max_connections - self.current_connections = 0 - self.logger = logging.getLogger("AIDojo-Server") - self._stop = False - - def close(self)->None: - self.logger.info( - "Stopping server" - ) - self._stop = True - - async def handle_new_agent(self, reader, writer): - async def send_data_to_agent(writer, data: str) -> None: - """ - Send the world to the agent - """ - writer.write(bytes(str(data).encode())) - - # Check if the maximum number of connections has been reached - if self.current_connections >= self.max_connections: - self.logger.info( - f"Max connections reached. 
Rejecting new connection from {writer.get_extra_info('peername')}" - ) - writer.close() - return - - # Increment the count of current connections - self.current_connections += 1 - - # Handle the new agent - addr = writer.get_extra_info("peername") - self.logger.info(f"New agent connected: {addr}") - # Ensure a queue exists for this agent - if addr not in self.answers_queues: - self.answers_queues[addr] = asyncio.Queue(maxsize=2) - self.logger.info(f"Created queue for agent {addr}") - - try: - while not self._stop: - # Step 1: Read data from the agent - data = await reader.read(500) - if not data: - self.logger.info(f"Agent {addr} disconnected.") - quit_message = Action(ActionType.QuitGame, params={}).as_json() - await self.actions_queue.put((addr, quit_message)) - break - - raw_message = data.decode().strip() - self.logger.debug(f"Handler received from {addr}: {raw_message}") - - # Step 2: Forward the message to the Coordinator - await self.actions_queue.put((addr, raw_message)) - await asyncio.sleep(0) - # Step 3: Get a matching response from the answers queue - response_queue = self.answers_queues[addr] - response = await response_queue.get() - self.logger.info(f"Sending response to agent {addr}: {response}") - - # Step 4: Send the response to the agent - writer.write(bytes(str(response).encode())) - await writer.drain() - except KeyboardInterrupt: - self.logger.debug("Terminating by KeyboardInterrupt") - raise SystemExit - except Exception as e: - self.logger.error(f"Exception in handle_new_agent(): {e}") - finally: - # Decrement the count of current connections - self.current_connections -= 1 - if addr in self.answers_queues: - self.answers_queues.pop(addr) - self.logger.info(f"Removed queue for agent {addr}") - else: - self.logger.warning(f"Queue for agent {addr} not found during cleanup.") - writer.close() - return - - async def __call__(self, reader, writer): - await self.handle_new_agent(reader, writer) - -class Coordinator: - def __init__(self, 
actions_queue, answers_queues, net_sec_config, allowed_roles, world_type="netsecenv"): - # communication channels for asyncio - # agents -> coordinator - self._actions_queue = actions_queue - # coordinator -> agent (separate queue per agent) - self._answers_queues = answers_queues - # coordinator -> world - self._world_action_queue = asyncio.Queue() - # world -> coordinator - self._world_response_queue = asyncio.Queue() - - self.task_config = ConfigParser(net_sec_config) - self.ALLOWED_ROLES = allowed_roles - self.logger = logging.getLogger("AIDojo-Coordinator") - - # world definition - match world_type: - case "netsecenv": - self._world = NetworkSecurityEnvironment(net_sec_config,self._world_action_queue, self._world_response_queue) - case "netsecenv-real-world": - self._world = NetworkSecurityEnvironmentRealWorld(net_sec_config, self._world_action_queue, self._world_response_queue) - case _: - self._world = AIDojoWorld(net_sec_config, self._world_action_queue, self._world_response_queue) - self.world_type = world_type - self._CONFIG_FILE_HASH = get_file_hash(net_sec_config) - self._starting_positions_per_role = self._get_starting_position_per_role() - self._win_conditions_per_role = self._get_win_condition_per_role() - self._goal_description_per_role = self._get_goal_description_per_role() - self._steps_limit_per_role = self._get_max_steps_per_role() - self._use_global_defender = self.task_config.get_use_global_defender() - - # player information - self.agents = {} - # step counter per agent_addr (int) - self._agent_steps = {} - # reset request per agent_addr (bool) - self._reset_requests = {} - self._agent_observations = {} - # starting per agent_addr (dict) - self._agent_starting_position = {} - # current state per agent_addr (GameState) - self._agent_states = {} - # last action played by agent (Action) - self._agent_last_action = {} - # agent status dict {agent_addr: AgentStatus} - self._agent_statuses = {} - # agent status dict {agent_addr: int} - 
self._agent_rewards = {} - # trajectories per agent_addr - self._agent_trajectories = {} - - @property - def episode_end(self)->bool: - # Episode ends ONLY IF all agents with defined max_steps reached the end fo the episode - exists_active_player = any(status is AgentStatus.PlayingActive for status in self._agent_statuses.values()) - self.logger.debug(f"End evaluation: {self._agent_statuses.items()} - Episode end:{not exists_active_player}") - return not exists_active_player - - @property - def config_file_hash(self): - return self._CONFIG_FILE_HASH - - def convert_msg_dict_to_json(self, msg_dict)->str: - try: - # Convert message into string representation - output_message = json.dumps(msg_dict) - except Exception as e: - self.logger.error(f"Error when converting msg to Json:{e}") - raise e - # Send to anwer_queue - return output_message - - async def run(self): - """ - Main method to be run for coordinating the agent's interaction with the game engine. - - Reads messages from action queue - - processes actions based on their type - - Forwards actions in the game engine - - Evaluates states and checks for goal - - Forwards responses the agents - """ - try: - self.logger.info("Main coordinator started.") - # Start World Response handler task - world_response_task = asyncio.create_task(self._handle_world_responses()) - world_processing_task = asyncio.create_task(self._world.handle_incoming_action()) - while True: - # Read message from the queue - agent_addr, message = await self._actions_queue.get() - if message is not None: - self.logger.debug(f"Coordinator received: {message}.") - try: # Convert message to Action - action = Action.from_json(message) - except Exception as e: - self.logger.error( - f"Error when converting msg to Action using Action.from_json():{e}, {message}" - ) - match action.type: # process action based on its type - case ActionType.JoinGame: - self.logger.debug(f"Start processing of ActionType.JoinGame by {agent_addr}") - await 
self._process_join_game_action(agent_addr, action) - case ActionType.QuitGame: - self.logger.info(f"Coordinator received from QUIT message from agent {agent_addr}") - # update agent status - self._agent_statuses[agent_addr] = AgentStatus.Quitting - # forward the message to the world - await self._world_action_queue.put((agent_addr, action, self._agent_states[agent_addr])) - case ActionType.ResetGame: - self._reset_requests[agent_addr] = True - self._agent_statuses[agent_addr] = AgentStatus.ResetRequested - self.logger.info(f"Coordinator received from RESET request from agent {agent_addr} ({self.agents[agent_addr]})") - if all(self._reset_requests.values()): - # should we discard the queue here? - self.logger.info("All active agents requested reset") - # send WORLD reset request to the world - await self._world_action_queue.put(("world", Action(ActionType.ResetGame, params={}), None)) - # send request for each of the agents (to get new initial state) - for agent in self._reset_requests: - await self._world_action_queue.put((agent, Action(ActionType.ResetGame, params={}), self._agent_starting_position[agent])) - else: - self.logger.info("\t Waiting for other agents to request reset") - case _: - self.logger.debug(f"Agent {self.agents[agent_addr]}, played Action {action}.") - # actions in the game - await self._process_generic_action(agent_addr, action) - await asyncio.sleep(0) - except asyncio.CancelledError: - world_response_task.cancel() - world_processing_task.cancel() - asyncio.gather(world_processing_task, world_response_task, return_exceptions=True) - self.logger.info("\tTerminating by CancelledError") - except Exception as e: - self.logger.error(f"Exception in Class coordinator(): {e}") - raise e - - def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState) -> Observation: - """ - Method to initialize new player upon joining the game. 
- Returns initial observation for the agent based on the agent's role - """ - self.logger.info(f"\tInitializing new player{agent_addr}") - agent_name, agent_role = self.agents[agent_addr] - self._agent_steps[agent_addr] = 0 - self._reset_requests[agent_addr] = False - self._agent_starting_position[agent_addr] = self._starting_positions_per_role[agent_role] - self._agent_statuses[agent_addr] = AgentStatus.PlayingActive if agent_role == "Attacker" else AgentStatus.Playing - self._agent_states[agent_addr] = agent_current_state - - - #self._agent_states[agent_addr] = self._world.create_state_from_view(self._agent_starting_position[agent_addr]) - - # if self._steps_limit_per_role[agent_role]: - # # This agent can force episode end (has timeout and goal defined) - # self._agent_statuses[agent_addr] = AgentStatus.PlayingActive - # else: - # # This agent can NOT force episode end (does NOT timeout or goal defined) - # self._agent_statuses[agent_addr] = AgentStatus.Playing - - if self.task_config.get_store_trajectories() or self._use_global_defender: - self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr) - self.logger.info(f"\tAgent {agent_name} ({agent_addr}), registred as {agent_role}") - # Initializeing the player, also sets the end_episode to false for the Observation of the agent - end_episode = False - return Observation(self._agent_states[agent_addr], 0, end_episode, {}) - - def _remove_player(self, agent_addr:tuple)->dict: - """ - Removes player from the game. Should be called AFTER QuitGame action was processed by the world. 
- """ - self.logger.info(f"Removing player {agent_addr} from the Coordinator") - agent_info = {} - if agent_addr in self.agents: - agent_info["state"] = self._agent_states.pop(agent_addr) - agent_info["status"] = self._agent_statuses.pop(agent_addr) - agent_info["num_steps"] = self._agent_steps.pop(agent_addr) - agent_info["reset_request"] = self._reset_requests.pop(agent_addr) - agent_info["end_reward"] = self._agent_rewards.pop(agent_addr, None) - agent_info["agent_info"] = self.agents.pop(agent_addr) - self.logger.debug(f"\t{agent_info}") - else: - self.logger.info(f"\t Player {agent_addr} not present in the game!") - return agent_info - - def _get_starting_position_per_role(self)->dict: - """ - Method for finding starting position for each agent role in the game. - """ - starting_positions = {} - for agent_role in self.ALLOWED_ROLES: - try: - starting_positions[agent_role] = self.task_config.get_start_position(agent_role=agent_role) - self.logger.info(f"Starting position for role '{agent_role}': {starting_positions[agent_role]}") - except KeyError: - starting_positions[agent_role] = {} - return starting_positions - - def _get_win_condition_per_role(self)-> dict: - """ - Method for finding wininng conditions for each agent role in the game. - """ - win_conditions = {} - for agent_role in self.ALLOWED_ROLES: - try: - win_conditions[agent_role] = self._world.update_goal_dict( - self.task_config.get_win_conditions(agent_role=agent_role) - ) - except KeyError: - win_conditions[agent_role] = {} - self.logger.info(f"Win condition for role '{agent_role}': {win_conditions[agent_role]}") - return win_conditions - - def _get_goal_description_per_role(self)->dict: - """ - Method for finding goal description for each agent role in the game. 
- """ - goal_descriptions ={} - for agent_role in self.ALLOWED_ROLES: - try: - goal_descriptions[agent_role] = self._world.update_goal_descriptions( - self.task_config.get_goal_description(agent_role=agent_role) - ) - except KeyError: - goal_descriptions[agent_role] = "" - self.logger.info(f"Goal description for role '{agent_role}': {goal_descriptions[agent_role]}") - return goal_descriptions - - def _get_max_steps_per_role(self)->dict: - """ - Method for finding max amount of steps in 1 episode for each agent role in the game. - """ - max_steps = {} - for agent_role in self.ALLOWED_ROLES: - try: - max_steps[agent_role] = self.task_config.get_max_steps(agent_role) - except KeyError: - max_steps[agent_role] = None - self.logger.info(f"Max steps in episode for '{agent_role}': {max_steps[agent_role]}") - return max_steps - - async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None: - """ " - Method for processing Action of type ActionType.JoinGame - """ - self.logger.info(f"New Join request by {agent_addr}.") - if agent_addr not in self.agents: - agent_name = action.parameters["agent_info"].name - agent_role = action.parameters["agent_info"].role - if agent_role in self.ALLOWED_ROLES: - self.agents[agent_addr] = (agent_name, agent_role) - self._agent_statuses[agent_addr] = AgentStatus.JoinRequested - self.logger.debug(f"Sending JoinRequest by {agent_addr} to the world_action_queue") - await self._world_action_queue.put((agent_addr, action, self._starting_positions_per_role[agent_role])) - else: - self.logger.info( - f"\tError in registration, unknown agent role: {agent_role}!" 
- ) - output_message_dict = { - "to_agent": agent_addr, - "status": str(GameStatus.BAD_REQUEST), - "message": f"Incorrect agent_role {agent_role}", - } - response_msg_json = self.convert_msg_dict_to_json(output_message_dict) - await self._answers_queues[agent_addr].put(response_msg_json) - else: - self.logger.info("\tError in registration, agent already exists!") - output_message_dict = { - "to_agent": agent_addr, - "status": str(GameStatus.BAD_REQUEST), - "message": "Agent already exists.", - } - response_msg_json = self.convert_msg_dict_to_json(output_message_dict) - await self._answers_queues[agent_addr].put(response_msg_json) - - def _create_response_to_reset_game_action(self, agent_addr: tuple) -> dict: - """ " - Method for generatating answers to Action of type ActionType.ResetGame after all agents requested reset - """ - self.logger.info( - f"Coordinator responding to RESET request from agent {agent_addr}" - ) - # store trajectory in file if needed - if self.task_config.get_store_trajectories(): - self._store_trajectory_to_file(agent_addr) - new_observation = Observation(self._agent_states[agent_addr], 0, self.episode_end, {}) - # reset trajectory - self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr) - output_message_dict = { - "to_agent": agent_addr, - "status": str(GameStatus.RESET_DONE), - "observation": observation_as_dict(new_observation), - "message": { - "message": "Resetting Game and starting again.", - "max_steps": self._steps_limit_per_role[self.agents[agent_addr][1]], - "goal_description": self._goal_description_per_role[self.agents[agent_addr][1]], - "configuration_hash": self._CONFIG_FILE_HASH - }, - } - return output_message_dict - - def _add_step_to_trajectory(self, agent_addr:tuple, action:Action, reward:float, next_state:GameState, end_reason:str)-> None: - """ - Method for adding one step to the agent trajectory. 
- """ - if agent_addr in self._agent_trajectories: - self.logger.debug(f"Adding step to trajectory of {agent_addr}") - self._agent_trajectories[agent_addr]["trajectory"]["actions"].append(action.as_dict) - self._agent_trajectories[agent_addr]["trajectory"]["rewards"].append(reward) - self._agent_trajectories[agent_addr]["trajectory"]["states"].append(next_state.as_dict) - if end_reason: - self._agent_trajectories[agent_addr]["end_reason"] = end_reason - - def _store_trajectory_to_file(self, agent_addr:tuple, location="./trajectories")-> None: - self.logger.debug(f"Storing Trajectory of {agent_addr}in file") - if agent_addr in self._agent_trajectories: - agent_name, agent_role = self.agents[agent_addr] - filename = os.path.join(location, f"{datetime.now():%Y-%m-%d}_{agent_name}_{agent_role}.jsonl") - with jsonlines.open(filename, "a") as writer: - writer.write(self._agent_trajectories[agent_addr]) - self.logger.info(f"Trajectory of {agent_addr} strored in {filename}") - - def _reset_trajectory(self,agent_addr)->dict: - agent_name, agent_role = self.agents[agent_addr] - self.logger.debug(f"Resetting trajectory of {agent_addr}") - return { - "trajectory":{ - "states":[self._agent_states[agent_addr].as_dict], - "actions":[], - "rewards":[], - }, - "end_reason":None, - "agent_role":agent_role, - "agent_name":agent_name - } - - async def _process_generic_action(self, agent_addr: tuple, action: Action) ->None: - """ - Method processing the Actions relevant to the environment - """ - self.logger.info(f"Processing {action} from {agent_addr}") - if not self.episode_end: - self._agent_last_action[agent_addr] = action - await self._world_action_queue.put((agent_addr, action, self._agent_states[agent_addr])) - else: - # Episode finished, just send back the rewards and final episode info - self._assign_end_rewards() - self.logger.info(f"{self.episode_end}, {self._agent_statuses[agent_addr]}") - output_message_dict = self._generate_episode_end_message(agent_addr) - 
response_msg_json = self.convert_msg_dict_to_json(output_message_dict) - await self._answers_queues[agent_addr].put(response_msg_json) - - def _generate_episode_end_message(self, agent_addr:tuple)->dict: - """ - Method for generating response when agent attemps to make a step after episode ended. - """ - # There is a case when a defender agent connects first and is alone that it doesnt receive an observation because the game may not have started. - current_observation = self._agent_observations.get(agent_addr, Observation(self._agent_states[agent_addr], 0, True, {})) - reward = self._agent_rewards[agent_addr] - end_reason = str(self._agent_statuses[agent_addr]) - new_observation = Observation( - current_observation.state, - reward=reward, - end=True, - info={'end_reason': end_reason, "info":"Episode ended. Request reset for starting new episode."}) - output_message_dict = { - "to_agent": agent_addr, - "observation": observation_as_dict(new_observation), - "status": str(GameStatus.FORBIDDEN), - } - return output_message_dict - - def _goal_reached(self, agent_addr:tuple)->bool: - """ - Determines if and agent reached a goal state - """ - self.logger.info(f"Goal check for {agent_addr}({self.agents[agent_addr][1]})") - agents_state = self._agent_states[agent_addr] - agent_role = self.agents[agent_addr][1] - win_condition = self._world.update_goal_dict(self._win_conditions_per_role[agent_role]) - goal_check = self._check_goal(agents_state, win_condition) - if goal_check: - self.logger.info("\tGoal reached!") - else: - self.logger.info("\tGoal not reached!") - return goal_check - - def _check_goal(self, state:GameState, goal_conditions:dict)->bool: - """ - Check if the goal conditons were satisfied in a given game state - """ - def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool: - """ - Helper function for checking if a goal dictionary condition is satisfied - """ - # check if we have all IPs that should have some values (are keys in goal_dict) - if 
goal_dict.keys() <= known_dict.keys(): - try: - # Check if values (sets) for EACH key (host) in goal_dict are subsets of known_dict, keep matching_keys - matching_keys = [host for host in goal_dict.keys() if goal_dict[host]<= known_dict[host]] - # Check we have the amount of mathing keys as in the goal_dict - if len(matching_keys) == len(goal_dict.keys()): - return True - except KeyError: - # some keys are missing in the known_dict - return False - return False - - # For each part of the state of the game, check if the conditions are met - goal_reached = {} - goal_reached["networks"] = set(goal_conditions["known_networks"]) <= set(state.known_networks) - goal_reached["known_hosts"] = set(goal_conditions["known_hosts"]) <= set(state.known_hosts) - goal_reached["controlled_hosts"] = set(goal_conditions["controlled_hosts"]) <= set(state.controlled_hosts) - goal_reached["services"] = goal_dict_satistfied(goal_conditions["known_services"], state.known_services) - goal_reached["data"] = goal_dict_satistfied(goal_conditions["known_data"], state.known_data) - goal_reached["known_blocks"] = goal_dict_satistfied(goal_conditions["known_blocks"], state.known_blocks) - self.logger.debug(f"\t{goal_reached}") - return all(goal_reached.values()) - - def _check_detection(self, agent_addr:tuple, last_action:Action)->bool: - self.logger.info(f"Detection check for {agent_addr}({self.agents[agent_addr][1]})") - detection = False - if last_action: - if self._use_global_defender: - self.logger.warning("Global defender - ONLY use for backward compatibility!") - episode_actions = None - if (agent_addr in self._agent_trajectories and "trajectory" in self._agent_trajectories[agent_addr] and "actions" in self._agent_trajectories[agent_addr]["trajectory"]): - episode_actions = self._agent_trajectories[agent_addr]["trajectory"]["actions"] - detection = stochastic_with_threshold(last_action, episode_actions) - if detection: - self.logger.info("\tDetected!") - else: - self.logger.info("\tNot 
detected!") - return detection - - def _max_steps_reached(self, agent_addr:tuple) ->bool: - """ - Checks if the agent reached the max allowed steps. Only applies to role 'Attacker' - """ - self.logger.debug(f"Checking timout for {self.agents[agent_addr]}") - agent_role = self.agents[agent_addr][1] - if self._steps_limit_per_role[agent_role]: - if self._agent_steps[agent_addr] >= self._steps_limit_per_role[agent_role]: - self.logger.info(f"Timeout reached by {self.agents[agent_addr]}!") - return True - else: - self.logger.debug(f"No max steps defined for role {agent_role}") - return False - - def _assign_end_rewards(self)->None: - """ - Method which assings rewards to each agent which has finished playing - """ - self.logger.debug("Assigning rewards") - is_episode_over = self.episode_end - for agent, status in self._agent_statuses.items(): - if agent not in self._agent_rewards.keys(): # reward has not been assigned yet - agent_name, agent_role = self.agents[agent] - if agent_role == "Attacker": - match status: - case AgentStatus.FinishedGoalReached: - self._agent_rewards[agent] = self._world._rewards["goal"] - case AgentStatus.FinishedMaxSteps: - self._agent_rewards[agent] = 0 - case AgentStatus.FinishedBlocked: - self._agent_rewards[agent] = self._world._rewards["detection"] - self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}") - elif agent_role == "Defender": - if self._agent_statuses[agent] is AgentStatus.FinishedMaxSteps: #defender was responsible for the end - raise NotImplementedError - self._agent_rewards[agent] = 0 - else: - if is_episode_over: #only assign defender's reward when episode ends - sucessful_attacks = list(self._agent_statuses.values()).count("goal_reached") - if sucessful_attacks > 0: - self._agent_rewards[agent] = sucessful_attacks*self._world._rewards["detection"] - self._agent_statuses[agent] = "game_lost" - else: #no successful attacker - self._agent_rewards[agent] = 
self._world._rewards["goal"] - self._agent_statuses[agent] = "goal_reached" - self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}") - else: - if is_episode_over: - self._agent_rewards[agent] = 0 - self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}") - - async def _handle_world_responses(self)-> None: - """ - Continuously processes responses from the AIDojo World, evaluates them and sends messages to agents - """ - try: - self.logger.info("\tStarting task to handle AIDojo World responses") - while True: - try: - # Get a response from the World Response Queue - agent_id, response = await self._world_response_queue.get() - self.logger.info(f"Received response for agent {agent_id}: {response}") - - # Processing of the response - response_msg_json = self._process_world_response(agent_id, response) - # Notify the agent if there is message - if len(response_msg_json) > 2: # we have NON EMPTY JSON (len('{}') = 2) - self.logger.info(f"Generated response for agent {agent_id} ({self.agents[agent_id]}): {response_msg_json}") - await self._answers_queues[agent_id].put(response_msg_json) - self.logger.info(f"Placed response in answers queue for agent {agent_id}") - else: - self.logger.info(f"Empty response for agent {agent_id}: {response_msg_json}. 
Skipping") - await asyncio.sleep(0) - except Exception as e: - self.logger.error(f"Error handling world response: {e}") - except asyncio.CancelledError: - self.logger.info("\tTerminating by CancelledError") - - def _process_world_response(self, agent_addr:tuple, response:tuple)-> str: - """ - Method for generation of messages to the agent based on the world response - """ - agent_new_state, game_status = response - output_message_dict = {} - try: - agent_status = self._agent_statuses[agent_addr] - if agent_status is AgentStatus.JoinRequested: - output_message_dict = self._process_world_response_created(agent_addr, game_status, agent_new_state) - elif agent_status is AgentStatus.ResetRequested: - output_message_dict = self._process_world_response_reset_done(agent_addr, game_status, agent_new_state) - elif agent_status is AgentStatus.Quitting: - if game_status is GameStatus.OK: - self.logger.debug(f"Agent {agent_addr} removed successfuly from the world") - else: - self.logger.warning(f"Error when removing Agent {agent_addr} from the world") - self._remove_player(agent_addr) - elif agent_status in [AgentStatus.Ready, AgentStatus.Playing, AgentStatus.PlayingActive]: - output_message_dict = self._process_world_response_step(agent_addr, game_status, agent_new_state) - elif agent_status in [AgentStatus.FinishedBlocked, AgentStatus.FinishedGameLost, AgentStatus.FinishedGoalReached, AgentStatus.FinishedMaxSteps]: # This if does not make sense. Put together with the previous (sebas) - output_message_dict = self._process_world_response_step(agent_addr, game_status, agent_new_state) - else: - self.logger.error(f"Unsupported value '{agent_status}'!") - - msg_json = self.convert_msg_dict_to_json(output_message_dict) - return msg_json - except KeyError as e : - self.logger.error(f"Agent {agent_addr} not found! 
{e}") - - def _process_world_response_created(self, agent_addr:tuple, game_status:GameStatus, new_agent_game_state:GameState)->dict: - """ - Handles reply to Action.JoinGame for agent based on the response of the AIDojo World - """ - # is agent correctly started in the world - if game_status is GameStatus.CREATED: - observation = self._initialize_new_player(agent_addr, new_agent_game_state) - agent_name, agent_role = self.agents[agent_addr] - output_message_dict = { - "to_agent": agent_addr, - "status": str(game_status), - "observation": observation_as_dict(observation), - "message": { - "message": f"Welcome {agent_name}, registred as {agent_role}", - "max_steps": self._steps_limit_per_role[agent_role], - "goal_description": self._goal_description_per_role[agent_role], - "actions": [str(a) for a in ActionType], - "configuration_hash": self._CONFIG_FILE_HASH - }, - } - else: - # remove traces of agent from the game - self._remove_player(agent_addr) - output_message_dict = { - "to_agent": agent_addr, - "status": str(game_status), - "message": f"Error when initializing the agent {agent_name}({agent_role})", - } - return output_message_dict - - def _process_world_response_reset_done(self, agent_addr:tuple, game_status:GameStatus, agent_new_state:GameState)->dict: - """ - Handles reply to Action.JoinGame for agent based on the response of the AIDojo World - """ - if game_status is GameStatus.RESET_DONE: - self._reset_requests[agent_addr] = False - self._agent_steps[agent_addr] = 0 - self._agent_states[agent_addr] = agent_new_state - self._agent_rewards.pop(agent_addr, None) - if self._steps_limit_per_role[self.agents[agent_addr][1]]: - # This agent can force episode end (has timeout and goal defined) - self._agent_statuses[agent_addr] = AgentStatus.PlayingActive - else: - # This agent can NOT force episode end (does NOT timeout or goal defined) - self._agent_statuses[agent_addr] = AgentStatus.Playing - output_message_dict = 
self._create_response_to_reset_game_action(agent_addr) - else: - # remove traces of agent from the game - agent_name, agent_role = self.agents - self._remove_player(agent_addr) - output_message_dict = { - "to_agent": agent_addr, - "status": str(game_status), - "message": f"Error when resetting the agent {agent_name} ({agent_role})", - } - return output_message_dict - - def _process_world_response_step(self, agent_addr:tuple, game_status:GameStatus, agent_new_state:GameState)->dict: - """ - Handles reply for agent based on the response of the AIDojo World. Covers followinf ActionTypes: - - ActionType.ScanNetwork - - ActionType.FindServices - - ActionType.FindData - - ActionType.ExfiltrateData - - ActionType.ExploitService - - ActionType.BlockIP - """ - if game_status is GameStatus.OK: - if not self.episode_end: - # increase the action counter - self._agent_steps[agent_addr] += 1 - self.logger.info(f"Agent {agent_addr} ({self.agents[agent_addr]}) did #steps: {self._agent_steps[agent_addr]}") - # register the new state - self._agent_states[agent_addr] = agent_new_state - # load the action which lead to the new state - last_action = self._agent_last_action[agent_addr] - # check timeout - if self._max_steps_reached(agent_addr): - self._agent_statuses[agent_addr] = AgentStatus.FinishedMaxSteps - # check detection - if self._check_detection(agent_addr, last_action): - self._agent_statuses[agent_addr] = AgentStatus.FinishedBlocked - # check goal - if self._goal_reached(agent_addr): - self._agent_statuses[agent_addr] = AgentStatus.FinishedGoalReached - # add reward for taking a step - reward = self._world._rewards["step"] - - obs_info = {} - end_reason = None - if self._agent_statuses[agent_addr] is AgentStatus.FinishedGoalReached: - self._assign_end_rewards() - reward += self._agent_rewards[agent_addr] - end_reason = "goal_reached" - obs_info = {'end_reason': "goal_reached"} - elif self._agent_statuses[agent_addr] is AgentStatus.FinishedMaxSteps: - 
self._assign_end_rewards() - reward += self._agent_rewards[agent_addr] - obs_info = {"end_reason": "max_steps"} - end_reason = "max_steps" - elif self._agent_statuses[agent_addr] is AgentStatus.FinishedBlocked: - self._assign_end_rewards() - reward += self._agent_rewards[agent_addr] - obs_info = {"end_reason": "blocked"} - - # record step in trajecory - self._add_step_to_trajectory(agent_addr, last_action, reward,self._agent_states[agent_addr], end_reason) - new_observation = Observation(self._agent_states[agent_addr], reward, self.episode_end, info=obs_info) - - self._agent_observations[agent_addr] = new_observation - - output_message_dict = { - "to_agent": agent_addr, - "observation": observation_as_dict(new_observation), - "status": str(GameStatus.OK), - } - else: - self._assign_end_rewards() - output_message_dict = self._generate_episode_end_message(agent_addr) - else: - output_message_dict = { - "to_agent": agent_addr, - "status": str(game_status), - "message": f"Error when playing action {last_action}", - } - return output_message_dict - -__version__ = "v0.2.2" - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=f"NetSecGame Coordinator Server version {__version__}. Author: Sebastian Garcia, sebastian.garcia@agents.fel.cvut.cz", - usage="%(prog)s [options]", - ) - parser.add_argument( - "-v", - "--verbose", - help="Verbosity level. This shows more info about the results.", - action="store", - required=False, - type=int, - ) - parser.add_argument( - "-c", - "--configfile", - help="Configuration file.", - action="store", - required=False, - type=str, - default="coordinator.conf", - ) - parser.add_argument( - "-t", - "--task_config", - help="Task configuration file.", - action="store", - required=False, - type=str, - default="env/netsecenv_conf.yaml", - ) - parser.add_argument( - "-l", - "--debug_level", - help="Define the debug level for the logs. 
DEBUG, INFO, WARNING, ERROR, CRITICAL", - action="store", - required=False, - type=str, - default="DEBUG", - ) - - args = parser.parse_args() - print(args) - # Set the logging - log_filename = Path("coordinator.log") - if not log_filename.parent.exists(): - os.makedirs(log_filename.parent) - - # Convert the logging level in the args to the level to use - pass_level = get_logging_level(args.debug_level) - - logging.basicConfig( - filename=log_filename, - filemode="w", - format="%(asctime)s %(name)s %(levelname)s %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - level=pass_level, - ) - - # load config for coordinator - with open(args.configfile, "r") as jfile: - confjson = json.load(jfile) - - host = confjson.get("host", None) - port = confjson.get("port", None) - world_type = confjson.get('world_type', 'netsecgame') - - # prioritize task config from CLI - if args.task_config: - task_config_file = args.task_config - else: - # Try to use task config from coordinator.conf - task_config_file = confjson.get("task_config", None) - if task_config_file is None: - raise KeyError("Task configuration must be provided to start the coordinator! Use -h for more details.") - # Create AI Dojo - ai_dojo = AIDojo(host, port, task_config_file, world_type) - # Run it! 
- ai_dojo.run() \ No newline at end of file diff --git a/env/global_defender.py b/env/global_defender.py deleted file mode 100644 index aba2da41..00000000 --- a/env/global_defender.py +++ /dev/null @@ -1,85 +0,0 @@ -# Author: Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz -from itertools import groupby -from .game_components import ActionType, Action -from random import random - - -# The probability of detecting an action is defined by the following dictionary -DEFAULT_DETECTION_PROBS = { - ActionType.ScanNetwork: 0.05, - ActionType.FindServices: 0.075, - ActionType.ExploitService: 0.1, - ActionType.FindData: 0.025, - ActionType.ExfiltrateData: 0.025, - ActionType.BlockIP: 0.01 -} - -# Ratios of action types in the time window (TW) for each action type. The ratio should be higher than the defined value to trigger a detection check -TW_TYPE_RATIOS_THRESHOLD = { - ActionType.ScanNetwork: 0.25, - ActionType.FindServices: 0.3, - ActionType.ExploitService: 0.25, - ActionType.FindData: 0.5, - ActionType.ExfiltrateData: 0.25, - ActionType.BlockIP: 1 -} - -# Thresholds for consecutive actions of the same type in the TW. Only if the threshold is crossed, the detection check is triggered -TW_CONSECUTIVE_TYPE_THRESHOLD = { - ActionType.ScanNetwork: 2, - ActionType.FindServices: 3, - ActionType.ExfiltrateData: 2 -} - -# Thresholds for repeated actions in the episode. Only if the threshold is crossed, the detection check is triggered -EPISODE_REPEATED_ACTION_THRESHOLD = { - ActionType.ExploitService: 2, - ActionType.FindData: 2, -} - -def stochastic(action_type:ActionType)->bool: - """ - Simple random detection based on predefied probability and ActionType - """ - roll = random() - if roll < DEFAULT_DETECTION_PROBS[action_type]: - return True - else: - return False - -def stochastic_with_threshold(action: Action, episode_actions:list, tw_size:int=5)-> bool: - """ - Only detect based on set probabilities if pre-defined thresholds are crossed. 
- """ - # extend the episode with the latest action - # We need to copy the list before the copying, so we avoid modifying it when it is returned. Modifycation of passed list is the default behavior in Python - temp_episode_actions = episode_actions.copy() - temp_episode_actions.append(action.as_dict) - if len(temp_episode_actions) >= tw_size: - last_n_actions = temp_episode_actions[-tw_size:] - last_n_action_types = [action['type'] for action in last_n_actions] - # compute ratio of action type in the TW - tw_ratio = last_n_action_types.count(str(action.type))/tw_size - # Count how many times this exact (parametrized) action was played in episode - repeats_in_episode = temp_episode_actions.count(action.as_dict) - # compute the highest consecutive number of action type in TW - max_consecutive_action_type = max(sum(1 for item in grouped if item == str(action.type)) - for _, grouped in groupby(last_n_action_types)) - - if action.type in TW_CONSECUTIVE_TYPE_THRESHOLD.keys(): - # ScanNetwork, FindServices, ExfiltrateData - if tw_ratio < TW_TYPE_RATIOS_THRESHOLD[action.type] and max_consecutive_action_type < TW_CONSECUTIVE_TYPE_THRESHOLD[action.type]: - return False - else: - return stochastic(action.type) - elif action.type in EPISODE_REPEATED_ACTION_THRESHOLD.keys(): - # FindData, Exploit service - if tw_ratio < TW_TYPE_RATIOS_THRESHOLD[action.type] and repeats_in_episode < EPISODE_REPEATED_ACTION_THRESHOLD[action.type]: - return False - else: - return stochastic(action.type) - else: #Other actions - Do not detect - return False - - else: - return False diff --git a/env/scenarios/tiny_scenario_configuration.py b/env/scenarios/tiny_scenario_configuration.py deleted file mode 100644 index c3e9f4dd..00000000 --- a/env/scenarios/tiny_scenario_configuration.py +++ /dev/null @@ -1,256 +0,0 @@ -import cyst.api.configuration as cyst_cfg -from cyst.api.logic.access import AuthenticationProviderType, AuthenticationTokenType, AuthenticationTokenSecurity - -''' 
-------------------------------------------------------------------------------------------------------------------- -A template for local password authentication. -''' -local_password_auth = cyst_cfg.AuthenticationProviderConfig( - provider_type=AuthenticationProviderType.LOCAL, - token_type=AuthenticationTokenType.PASSWORD, - token_security=AuthenticationTokenSecurity.SEALED, - timeout=30 -) - -''' -------------------------------------------------------------------------------------------------------------------- -Server 1: -- SMB/File sharing (It is vulnerable to some remote exploit) -- Remote Desktop -- Can go to router and internet - -- the only windows server. It does not connect to the AD -- access schemes for remote desktop and file sharing are kept separate, but can be integrated into one if needed -''' -smb_server = cyst_cfg.NodeConfig( - active_services=[], - passive_services=[ - cyst_cfg.PassiveServiceConfig( - type="lanman server", - owner="Local system", - version="10.0.19041", - local=False, - private_data=[ - cyst_cfg.DataConfig( - owner="User1", - description="DataFromServer1" - ) - ], - access_level=cyst_cfg.AccessLevel.LIMITED, - authentication_providers=[], - access_schemes=[ - cyst_cfg.AccessSchemeConfig( - authentication_providers=["windows login"], - authorization_domain=cyst_cfg.AuthorizationDomainConfig( - type=cyst_cfg.AuthorizationDomainType.LOCAL, - authorizations=[ - cyst_cfg.AuthorizationConfig("User1", cyst_cfg.AccessLevel.LIMITED), - cyst_cfg.AuthorizationConfig("User2", cyst_cfg.AccessLevel.LIMITED), - cyst_cfg.AuthorizationConfig("User3", cyst_cfg.AccessLevel.LIMITED), - cyst_cfg.AuthorizationConfig("User4", cyst_cfg.AccessLevel.LIMITED), - cyst_cfg.AuthorizationConfig("User5", cyst_cfg.AccessLevel.LIMITED), - cyst_cfg.AuthorizationConfig("Administrator", cyst_cfg.AccessLevel.ELEVATED) - ] - ) - ) - ] - ), - cyst_cfg.PassiveServiceConfig( - type="windows login", - owner="Administrator", - version="10.0.19041", - local=True, - 
access_level=cyst_cfg.AccessLevel.ELEVATED, - authentication_providers=[local_password_auth("windows login")] - ), - cyst_cfg.PassiveServiceConfig( - type="powershell", - owner="Local system", - version="10.0.19041", - local=True, - access_level=cyst_cfg.AccessLevel.LIMITED - ) - ], - traffic_processors=[], - interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("192.168.1.2"), cyst_cfg.IPNetwork("192.168.1.0/24"))], - shell="powershell", - id="smb_server" -) - - -''' -------------------------------------------------------------------------------------------------------------------- -Client 1 - -- Remote Desktop -- Accounts --- Local admin --- User1 -- Can go to server 1, 2, 3, router and internet -- Has the attacker -''' -client_1 = cyst_cfg.NodeConfig( - active_services=[ - cyst_cfg.ActiveServiceConfig( - type="scripted_actor", - name="attacker", - owner="attacker", - access_level=cyst_cfg.AccessLevel.LIMITED, - id="attacker_service" - ) - ], - passive_services=[ - cyst_cfg.PassiveServiceConfig( - type="remote desktop service", - owner="Local system", - version="10.0.19041", - local=False, - access_level=cyst_cfg.AccessLevel.ELEVATED, - parameters=[ - (cyst_cfg.ServiceParameter.ENABLE_SESSION, True), - (cyst_cfg.ServiceParameter.SESSION_ACCESS_LEVEL, cyst_cfg.AccessLevel.LIMITED) - ], - authentication_providers=[local_password_auth("client_1_windows_login")], - access_schemes=[ - cyst_cfg.AccessSchemeConfig( - authentication_providers=["client_1_windows_login"], - authorization_domain=cyst_cfg.AuthorizationDomainConfig( - type=cyst_cfg.AuthorizationDomainType.LOCAL, - authorizations=[ - cyst_cfg.AuthorizationConfig("User1", cyst_cfg.AccessLevel.LIMITED), - cyst_cfg.AuthorizationConfig("Administrator", cyst_cfg.AccessLevel.ELEVATED) - ] - ) - ) - ] - ), - cyst_cfg.PassiveServiceConfig( - type="powershell", - owner="Local system", - version="10.0.19041", - local=True, - access_level=cyst_cfg.AccessLevel.LIMITED - ), - cyst_cfg.PassiveServiceConfig( - 
type="can_attack_start_here", - owner="Local system", - version="1", - local=True, - access_level=cyst_cfg.AccessLevel.LIMITED - ) - ], - traffic_processors=[], - interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("192.168.2.2"), cyst_cfg.IPNetwork("192.168.2.0/24"))], - shell="powershell", - id="client_1" -) - -router1 = cyst_cfg.RouterConfig( - interfaces=[ - cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("192.168.1.1"), cyst_cfg.IPNetwork("192.168.1.0/24"), index=2), - cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("192.168.2.1"), cyst_cfg.IPNetwork("192.168.2.0/24"), index=3), - ], - routing_table=[ - # Push everything not-infrastructure to the internet - cyst_cfg.RouteConfig(cyst_cfg.IPNetwork("0.0.0.0/0"), 10) - ], - # Firewall FORWARD policy specifies inter-network routes that are enabled - # Firewall INPUT policy specifies who can connect directly to the router. In this scenario, everyone can. - traffic_processors=[ - cyst_cfg.FirewallConfig( - default_policy=cyst_cfg.FirewallPolicy.DENY, - chains=[ - cyst_cfg.FirewallChainConfig( - type=cyst_cfg.FirewallChainType.INPUT, - policy=cyst_cfg.FirewallPolicy.DENY, - rules=[ - cyst_cfg.FirewallRule(cyst_cfg.IPNetwork("192.168.1.0/24"), cyst_cfg.IPNetwork("192.168.1.1/32"), "*", cyst_cfg.FirewallPolicy.ALLOW), - cyst_cfg.FirewallRule(cyst_cfg.IPNetwork("192.168.2.0/24"), cyst_cfg.IPNetwork("192.168.2.1/32"), "*", cyst_cfg.FirewallPolicy.ALLOW) - ] - ), - cyst_cfg.FirewallChainConfig( - type=cyst_cfg.FirewallChainType.FORWARD, - policy=cyst_cfg.FirewallPolicy.DENY, - rules=[ - # Client 1 can go to server 1 - cyst_cfg.FirewallRule(cyst_cfg.IPNetwork("192.168.2.2/32"), cyst_cfg.IPNetwork("192.168.1.2/32"), "*", cyst_cfg.FirewallPolicy.ALLOW), - ] - ) - ] - ) - ], - id="router1" -) - -''' -------------------------------------------------------------------------------------------------------------------- -Internet - -- Represented as a router outside the scenario network 192.168.0.0/16 -''' -internet = 
cyst_cfg.RouterConfig( - interfaces=[ - cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("213.47.23.193"), cyst_cfg.IPNetwork("213.47.23.192/26"), index=0) - ], - routing_table=[ - cyst_cfg.RouteConfig(cyst_cfg.IPNetwork("192.168.0.0/16"), 0) - ], - traffic_processors=[], - id="internet" -) - -''' -------------------------------------------------------------------------------------------------------------------- -Outside node - -- A machine that sits in the internet, controlled by the attacker, used for data exfiltration. -''' -outside_node = cyst_cfg.NodeConfig( - active_services=[], - passive_services=[ - cyst_cfg.PassiveServiceConfig( - type="bash", - owner="root", - version="5.0.0", - local=True, - access_level=cyst_cfg.AccessLevel.LIMITED - ), - cyst_cfg.PassiveServiceConfig( - type="listener", - owner="attacker", - version="1.0.0", - local=False, - access_level=cyst_cfg.AccessLevel.ELEVATED - ) - ], - traffic_processors=[], - interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("213.47.23.195"), cyst_cfg.IPNetwork("213.47.23.195/26"))], - shell="bash", - id="outside_node" -) - -''' -------------------------------------------------------------------------------------------------------------------- -Connections -''' -connections = [ - cyst_cfg.ConnectionConfig("smb_server", 0, "router1", 0), - cyst_cfg.ConnectionConfig("client_1", 0, "router1", 5), - cyst_cfg.ConnectionConfig("internet", 0, "router1", 10), - cyst_cfg.ConnectionConfig("internet", 1, "outside_node", 0) -] - -''' -------------------------------------------------------------------------------------------------------------------- -Exploits -- There exists only one for windows lanman server (SMB) and enables data exfiltration. Add others as needed... 
-''' -exploits = [ - cyst_cfg.ExploitConfig( - services=[ - cyst_cfg.VulnerableServiceConfig( - name="lanman server", - min_version="10.0.19041", - max_version="10.0.19041" - ) - ], - locality=cyst_cfg.ExploitLocality.REMOTE, - category=cyst_cfg.ExploitCategory.DATA_MANIPULATION, - id="smb_exploit" - ) -] - -configuration_objects = [smb_server, client_1, router1, internet, outside_node, *connections, *exploits] \ No newline at end of file diff --git a/env/worlds/aidojo_world.py b/env/worlds/aidojo_world.py deleted file mode 100644 index 4650710e..00000000 --- a/env/worlds/aidojo_world.py +++ /dev/null @@ -1,107 +0,0 @@ -# Author Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz -# Template of world to be used in AI Dojo -import sys -import os -import asyncio - -sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) -import logging -from utils.utils import ConfigParser -from env.game_components import GameState, Action, GameStatus, ActionType - -""" -Basic class for worlds to be used in the AI Dojo. -Every world (environment) used in AI Dojo should extend this class and implement -all its methods to be compatible with the game server and game coordinator. -""" -class AIDojoWorld(object): - def __init__(self, task_config_file:str,action_queue:asyncio.Queue, response_queue:asyncio.Queue, world_name:str="BasicAIDojoWorld")->None: - self.task_config = ConfigParser(task_config_file) - self.logger = logging.getLogger(world_name) - self._action_queue = action_queue - self._response_queue = response_queue - self._world_name = world_name - - @property - def world_name(self)->str: - return self._world_name - - def step(self, current_state:GameState, action:Action, agent_id:tuple)-> GameState: - """ - Executes given action in a current state of the environment and produces new GameState. 
- """ - raise NotImplementedError - - def create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->GameState: - """ - Produces a GameState based on the view of the world. - """ - raise NotImplementedError - - def reset()->None: - """ - Resets the world to its initial state. - """ - raise NotImplementedError - - def update_goal_descriptions(self, goal_description:str)->str: - """ - Takes the existing goal description (text) and updates it with respect to the world. - """ - raise NotImplementedError - - def update_goal_dict(self, goal_dict:dict)->dict: - """ - Takes the existing goal dict and updates it with respect to the world. - """ - raise NotImplementedError - - async def handle_incoming_action(self)->None: - """ - Asynchronously handles incoming actions from agents and processes them accordingly. - - This method continuously listens for actions from the `_action_queue`, processes them based on their type, - and sends the appropriate response to the `_response_queue`. It handles different types of actions such as - joining a game, quitting a game, and resetting the game. For other actions, it updates the game state by - calling the `step` method. - - Raises: - asyncio.CancelledError: If the task is cancelled, it logs the termination message. - - Action Types: - - ActionType.JoinGame: Creates a new game state and sends a CREATED status. - - ActionType.QuitGame: Sends an OK status with an empty game state. - - ActionType.ResetGame: Resets the world if the agent is "world", otherwise resets the game state and sends a RESET_DONE status. - - Other: Updates the game state using the `step` method and sends an OK status. - - Logging: - - Logs the start of the task. - - Logs received actions and game states from agents. - - Logs the messages being sent to agents. - - Logs termination due to `asyncio.CancelledError`. 
- """ - try: - self.logger.info(f"\tStaring {self.world_name} task.") - while True: - agent_id, action, game_state = await self._action_queue.get() - self.logger.debug(f"Received from{agent_id}: {action}, {game_state}.") - match action.type: - case ActionType.JoinGame: - msg = (agent_id, (self.create_state_from_view(game_state), GameStatus.CREATED)) - case ActionType.QuitGame: - msg = (agent_id, (GameState(),GameStatus.OK)) - case ActionType.ResetGame: - if agent_id == "world": #reset the world - self.reset() - continue - else: - msg = (agent_id, (self.create_state_from_view(game_state), GameStatus.RESET_DONE)) - case _: - new_state = self.step(game_state, action,agent_id) - msg = (agent_id, (new_state, GameStatus.OK)) - # new_state = self.step(state, action, agent_id) - self.logger.debug(f"Sending to {agent_id}: {msg}") - await self._response_queue.put(msg) - await asyncio.sleep(0) - except asyncio.CancelledError: - self.logger.info(f"\t{self.world_name} Terminating by CancelledError") \ No newline at end of file diff --git a/env/worlds/cyst_wrapper.py b/env/worlds/cyst_wrapper.py deleted file mode 100644 index c28f5797..00000000 --- a/env/worlds/cyst_wrapper.py +++ /dev/null @@ -1,54 +0,0 @@ -# Author Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz - -import sys -import os - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -import game_components as components -from worlds.aidojo_world import AIDojoWorld - -class CYSTWrapper(AIDojoWorld): - """ - Class for connection CYST with the coordinator of AI Dojo - """ - def __init__(self, task_config_file, world_name="CYST") -> None: - super().__init__(task_config_file, world_name) - self.logger.info("Initializing CYST environment") - - - def step(self, current_state:components.GameState, action:components.Action, agent_id:tuple)-> components.GameState: - """ - Executes given action in a current state of the environment and produces new GameState. 
- """ - raise NotImplementedError - - def create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->components.GameState: - """ - Produces a GameState based on the view of the world. - """ - raise NotImplementedError - - def reset()->None: - """ - Resets the world to its initial state. - """ - raise NotImplementedError - - def update_goal_descriptions(self, goal_description:str)->str: - """ - Takes the existing goal description (text) and updates it with respect to the world. - """ - raise NotImplementedError - - def update_goal_dict(self, goal_dict:dict)->dict: - """ - Takes the existing goal dict and updates it with respect to the world. - """ - raise NotImplementedError - - -if __name__ == "__main__": - cyst_wrapper = CYSTWrapper("env/netsecenv_conf.yaml") - objects = cyst_wrapper.task_config.get_scenario() - print(objects) - #e = Environment.create().configure(target, attacker, router, exploit1, connection1, connection2) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..f4a78157 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,57 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["."] +exclude = ["tests*"] + +[project] +name = "AIDojoGameCoordinator" +version = "0.1.0" +description = "A package for coordinating AI-driven network simulation games." 
+readme = "README.md" +license = { file = "LICENSE" } +authors = [ + { name = "Ondrej Lukas", email = "ondrej.lukas@aic.fel.cvut.cz" }, + { name = "Sebastian Garcia", email = "sebastian.garcia@agents.fel.cvut.cz" }, + { name = "Maria Rigaki", email = "maria.rigaki@aic.fel.cvut.cz" } +] +dependencies = [ + "aiohttp==3.11.8", + "attrs==23.2.0", + "beartype==0.19.0", + "cachetools==5.5.0", + "casefy==0.1.7", + "cyst==0.3.4", + "dictionaries==0.0.2", + "Faker==23.2.1", + "Jinja2==3.1.4", + "jsonlines==4.0.0", + "jsonpickle==3.3.0", + "kaleido==0.2.1", + "MarkupSafe==3.0.2", + "matplotlib==3.9.1", + "netaddr==1.3.0", + "networkx==3.4.2", + "numpy==1.26.4", + "pandas==2.2.2", + "plotly==5.22.0", + "pyserde==0.21.0", + "python-dateutil==2.8.2", + "PyYAML==6.0.1", + "redis==3.5.3", + "requests==2.32.3", + "scikit-learn==1.5.1", + "scipy==1.14.0", + "tenacity==8.5.0", + "typing-inspect==0.9.0", + "typing_extensions==4.12.2" +] +requires-python = ">=3.12" + +[project.optional-dependencies] +dev = [ + "pytest", + "ruff", +] \ No newline at end of file diff --git a/tests/manual/three_nets/test_three_net_scenario.py b/tests/manual/three_nets/test_three_net_scenario.py index eff14236..d1bec3ea 100644 --- a/tests/manual/three_nets/test_three_net_scenario.py +++ b/tests/manual/three_nets/test_three_net_scenario.py @@ -5,7 +5,7 @@ PATH = path.dirname( path.dirname( path.dirname( path.dirname( path.abspath(__file__) ) ) )) sys.path.append(path.dirname( path.dirname( path.dirname( path.dirname( path.abspath(__file__) ) ) ))) from NetSecGameAgents.agents import base_agent -from env.game_components import Action, ActionType, IP, Network, Service, Data +from AIDojoCoordinator.game_components import Action, ActionType, IP, Network, Service, Data if __name__ == "__main__": diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 19b049e1..348288f7 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -3,8 +3,9 @@ # run all unit tests, -n *5 means distribute tests 
on 5 different process # -s to see print statements as they are executed -python3 -m pytest tests/test_actions.py -p no:warnings -vvvv -s --full-trace +#python3 -m pytest tests/test_actions.py -p no:warnings -vvvv -s --full-trace python3 -m pytest tests/test_components.py -p no:warnings -vvvv -s --full-trace +python3 -m pytest tests/test_game_coordinator.py -p no:warnings -vvvv -s --full-trace # Coordinator tesst #python3 -m pytest tests/test_coordinator.py -p no:warnings -vvvv -s --full-trace diff --git a/tests/test_actions.py b/tests/test_actions.py index 50c81a54..443c80bf 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -5,8 +5,8 @@ import sys from os import path sys.path.append( path.dirname(path.dirname( path.abspath(__file__) ) )) -from env.worlds.network_security_game import NetworkSecurityEnvironment -import env.game_components as components +from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment +import AIDojoCoordinator.game_components as components import pytest # Fixture are used to hold the current state and the environment diff --git a/tests/test_components.py b/tests/test_components.py index a30988f5..140d38f4 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -2,11 +2,8 @@ Tests related to the game components in the Network Security Game Environment Author: Maria Rigaki - maria.rigaki@fel.cvut.cz """ -import sys import json -from os import path -sys.path.append( path.dirname(path.dirname( path.abspath(__file__) ) )) -from env.game_components import ActionType, Action, IP, Data, Network, Service, GameState, AgentInfo +from AIDojoCoordinator.game_components import ActionType, Action, IP, Data, Network, Service, GameState, AgentInfo class TestComponentsIP: """ @@ -197,7 +194,7 @@ def test_create_find_data(self): """ Test the creation of the FindData action """ - action = Action(action_type=ActionType.FindData, 
params={"source_host":IP("192.168.12.12"),"target_host":IP("192.168.12.12")}) + action = Action(action_type=ActionType.FindData, parameters={"source_host":IP("192.168.12.12"),"target_host":IP("192.168.12.12")}) assert action.type == ActionType.FindData assert action.parameters["target_host"] == IP("192.168.12.12") assert action.parameters["source_host"] == IP("192.168.12.12") @@ -206,14 +203,14 @@ def test_create_find_data_str(self): """ Test the string representation of the FindData action """ - action = Action(action_type=ActionType.FindData, params={"source_host":IP("192.168.12.12"), "target_host":IP("192.168.12.12")}) + action = Action(action_type=ActionType.FindData, parameters={"source_host":IP("192.168.12.12"), "target_host":IP("192.168.12.12")}) assert str(action) == "Action " def test_create_find_data_repr(self): """ Test the repr of the FindData action """ - action = Action(action_type=ActionType.FindData, params={"source_host":IP("192.168.12.12"), "target_host":IP("192.168.12.12")}) + action = Action(action_type=ActionType.FindData, parameters={"source_host":IP("192.168.12.12"), "target_host":IP("192.168.12.12")}) assert repr(action) == "Action " def test_action_find_services(self): @@ -221,7 +218,7 @@ def test_action_find_services(self): Test the creation of the FindServices action """ action = Action(action_type=ActionType.FindServices, - params={"source_host":IP("192.168.12.11"), "target_host":IP("192.168.12.12")}) + parameters={"source_host":IP("192.168.12.11"), "target_host":IP("192.168.12.12")}) assert action.type == ActionType.FindServices assert action.parameters["target_host"] == IP("192.168.12.12") assert action.parameters["source_host"] == IP("192.168.12.11") @@ -231,7 +228,7 @@ def test_action_scan_network(self): Test the creation of the ScanNetwork action """ action = Action(action_type=ActionType.ScanNetwork, - params={"source_host":IP("192.168.12.11"), "target_network":Network("172.16.1.12", 24)}) + 
parameters={"source_host":IP("192.168.12.11"), "target_network":Network("172.16.1.12", 24)}) assert action.type == ActionType.ScanNetwork assert action.parameters["target_network"] == Network("172.16.1.12", 24) assert action.parameters["source_host"] == IP("192.168.12.11") @@ -241,7 +238,7 @@ def test_action_exploit_services(self): Test the creation of the ExploitService action """ action = Action(action_type=ActionType.ExploitService, - params={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.12"), + parameters={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.12"), "target_service":Service("ssh", "passive", "0.23", False)}) assert action.type == ActionType.ExploitService assert action.parameters["target_host"] == IP("172.16.1.12") @@ -255,19 +252,19 @@ def test_action_equal(self): Test that two actions with the same parameters are equal """ action = Action(action_type=ActionType.FindServices, - params={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11")}) + parameters={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11")}) action2 = Action(action_type=ActionType.FindServices, - params={"target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")}) + parameters={"target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")}) assert action == action2 - def test_action_equal_params_order(self): + def test_action_equal_parameters_order(self): """ Test that two actions with the same parameters are equal """ action = Action(action_type=ActionType.ExploitService, - params={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11"),"target_service": Service("ssh", "passive", "0.23", False)}) + parameters={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11"),"target_service": Service("ssh", "passive", "0.23", False)}) action2 = Action(action_type=ActionType.ExploitService, - params={"target_service": Service("ssh", "passive", "0.23", False), "target_host":IP("172.16.1.22"), 
"source_host":IP("192.168.12.11")}) + parameters={"target_service": Service("ssh", "passive", "0.23", False), "target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")}) assert action == action2 def test_action_not_equal_different_target(self): @@ -275,9 +272,9 @@ def test_action_not_equal_different_target(self): Test that two actions with different parameters are not equal """ action = Action(action_type=ActionType.FindServices, - params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")}) + parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")}) action2 = Action(action_type=ActionType.FindServices, - params={"source_host":IP("192.168.12.11"), "target_host":IP("172.15.1.22")}) + parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.15.1.22")}) assert action != action2 def test_action_not_equal_different_source(self): @@ -285,9 +282,9 @@ def test_action_not_equal_different_source(self): Test that two actions with different parameters are not equal """ action = Action(action_type=ActionType.FindServices, - params={"source_host":IP("192.168.12.12"), "target_host":IP("172.16.1.22")}) + parameters={"source_host":IP("192.168.12.12"), "target_host":IP("172.16.1.22")}) action2 = Action(action_type=ActionType.FindServices, - params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")}) + parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")}) assert action != action2 def test_action_not_equal_different_action_type(self): @@ -295,27 +292,27 @@ def test_action_not_equal_different_action_type(self): Test that two actions with different parameters are not equal """ action = Action(action_type=ActionType.FindServices, - params={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")}) + parameters={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")}) action2 = Action(action_type=ActionType.FindData, - 
params={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")}) + parameters={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")}) assert action != action2 def test_action_hash(self): action = Action( action_type=ActionType.FindServices, - params={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11")} + parameters={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11")} ) action2 = Action( action_type=ActionType.FindServices, - params={"target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")} + parameters={"target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")} ) action3 = Action( action_type=ActionType.FindServices, - params={"target_host":IP("172.16.13.48"), "source_host":IP("192.168.12.11")} + parameters={"target_host":IP("172.16.13.48"), "source_host":IP("192.168.12.11")} ) action4 = Action( action_type=ActionType.FindData, - params={"target_host":IP("172.16.1.25"), "source_host":IP("192.168.12.11")} + parameters={"target_host":IP("172.16.1.25"), "source_host":IP("192.168.12.11")} ) assert hash(action) == hash(action2) assert hash(action) != hash(action3) @@ -324,31 +321,31 @@ def test_action_hash(self): def test_action_set_member(self): action_set = set() action_set.add(Action(action_type=ActionType.FindServices, - params={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")})) + parameters={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")})) action_set.add(Action(action_type=ActionType.FindData, - params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24")})) + parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24")})) action_set.add(Action(action_type=ActionType.ExploitService, - params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)})) + parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24"), "target_service": 
Service("ssh", "passive", "0.23", False)})) action_set.add(Action(action_type=ActionType.ScanNetwork, - params={"source_host":IP("192.168.12.11"), "target_network":Network("172.16.1.12", 24)})) - action_set.add(Action(action_type=ActionType.ExfiltrateData, params={"target_host":IP("172.16.1.3"), + parameters={"source_host":IP("192.168.12.11"), "target_network":Network("172.16.1.12", 24)})) + action_set.add(Action(action_type=ActionType.ExfiltrateData, parameters={"target_host":IP("172.16.1.3"), "source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey")})) - assert Action(action_type=ActionType.FindServices, params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")}) in action_set - assert Action(action_type=ActionType.FindData, params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24")}) in action_set - assert Action(action_type=ActionType.ExploitService, params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)})in action_set - #reverse params order - assert Action(action_type=ActionType.ExploitService, params={"target_service": Service("ssh", "passive", "0.23", False), "target_host":IP("172.16.1.24"), "source_host":IP("192.168.12.11")})in action_set - assert Action(action_type=ActionType.ScanNetwork, params={"target_network":Network("172.16.1.12", 24), "source_host":IP("192.168.12.11")}) in action_set - assert Action(action_type=ActionType.ExfiltrateData, params={"target_host":IP("172.16.1.3"), "source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey")}) in action_set - #reverse params orders - assert Action(action_type=ActionType.ExfiltrateData, params={"source_host": IP("172.16.1.2"), "target_host":IP("172.16.1.3"), "data":Data("User2", "PublicKey")}) in action_set + assert Action(action_type=ActionType.FindServices, parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")}) in action_set + assert 
Action(action_type=ActionType.FindData, parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24")}) in action_set + assert Action(action_type=ActionType.ExploitService, parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)})in action_set + #reverse parameters order + assert Action(action_type=ActionType.ExploitService, parameters={"target_service": Service("ssh", "passive", "0.23", False), "target_host":IP("172.16.1.24"), "source_host":IP("192.168.12.11")})in action_set + assert Action(action_type=ActionType.ScanNetwork, parameters={"target_network":Network("172.16.1.12", 24), "source_host":IP("192.168.12.11")}) in action_set + assert Action(action_type=ActionType.ExfiltrateData, parameters={"target_host":IP("172.16.1.3"), "source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey")}) in action_set + #reverse parameters orders + assert Action(action_type=ActionType.ExfiltrateData, parameters={"source_host": IP("172.16.1.2"), "target_host":IP("172.16.1.3"), "data":Data("User2", "PublicKey")}) in action_set - def test_action_as_json(self): + def test_action_to_json(self): # Scan Network action = Action(action_type=ActionType.ScanNetwork, - params={"target_network":Network("172.16.1.12", 24)}) - action_json = action.as_json() + parameters={"target_network":Network("172.16.1.12", 24)}) + action_json = action.to_json() try: data = json.loads(action_json) except ValueError: @@ -359,8 +356,8 @@ def test_action_as_json(self): # Find services action = Action(action_type=ActionType.FindServices, - params={"target_host":IP("172.16.1.22")}) - action_json = action.as_json() + parameters={"target_host":IP("172.16.1.22")}) + action_json = action.to_json() try: data = json.loads(action_json) except ValueError: @@ -371,8 +368,8 @@ def test_action_as_json(self): # Find Data action = Action(action_type=ActionType.FindData, - params={"target_host":IP("172.16.1.22")}) - 
action_json = action.as_json() + parameters={"target_host":IP("172.16.1.22")}) + action_json = action.to_json() try: data = json.loads(action_json) except ValueError: @@ -383,8 +380,8 @@ def test_action_as_json(self): # Exploit Service action = Action(action_type=ActionType.ExploitService, - params={"target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)}) - action_json = action.as_json() + parameters={"target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)}) + action_json = action.to_json() try: data = json.loads(action_json) except ValueError: @@ -395,9 +392,9 @@ def test_action_as_json(self): "target_service":{"name":"ssh", "type":"passive", "version":"0.23", "is_local":False}}) in data.items() # Exfiltrate Data - action = Action(action_type=ActionType.ExfiltrateData, params={"target_host":IP("172.16.1.3"), + action = Action(action_type=ActionType.ExfiltrateData, parameters={"target_host":IP("172.16.1.3"), "source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey", size=42, type="pub")}) - action_json = action.as_json() + action_json = action.to_json() try: data = json.loads(action_json) except ValueError: @@ -410,74 +407,74 @@ def test_action_as_json(self): def test_action_scan_network_serialization(self): action = Action(action_type=ActionType.ScanNetwork, - params={"target_network":Network("172.16.1.12", 24),"source_host": IP("172.16.1.2") }) - action_json = action.as_json() + parameters={"target_network":Network("172.16.1.12", 24),"source_host": IP("172.16.1.2") }) + action_json = action.to_json() new_action = Action.from_json(action_json) assert action == new_action def test_action_find_services_serialization(self): action = Action(action_type=ActionType.FindServices, - params={"target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2")}) - action_json = action.as_json() + parameters={"target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2")}) + action_json = 
action.to_json() new_action = Action.from_json(action_json) assert action == new_action def test_action_find_data_serialization(self): action = Action(action_type=ActionType.FindData, - params={"target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2")}) - action_json = action.as_json() + parameters={"target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2")}) + action_json = action.to_json() new_action = Action.from_json(action_json) assert action == new_action def test_action_exploit_service_serialization(self): action = Action(action_type=ActionType.ExploitService, - params={"source_host": IP("172.16.1.2"), + parameters={"source_host": IP("172.16.1.2"), "target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)}) - action_json = action.as_json() + action_json = action.to_json() new_action = Action.from_json(action_json) assert action == new_action def test_action_exfiltrate_serialization(self): - action = Action(action_type=ActionType.ExfiltrateData, params={"target_host":IP("172.16.1.3"), + action = Action(action_type=ActionType.ExfiltrateData, parameters={"target_host":IP("172.16.1.3"), "source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey")}) - action_json = action.as_json() + action_json = action.to_json() new_action = Action.from_json(action_json) assert action == new_action def test_action_exfiltrate_join_game(self): action = Action( action_type=ActionType.JoinGame, - params={ + parameters={ "agent_info": AgentInfo(name="TestingAgent", role="attacker"), } ) - action_json = action.as_json() + action_json = action.to_json() new_action = Action.from_json(action_json) assert action == new_action def test_action_exfiltrate_reset_game(self): action = Action( action_type=ActionType.ResetGame, - params={} + parameters={} ) - action_json = action.as_json() + action_json = action.to_json() new_action = Action.from_json(action_json) assert action == new_action def test_action_exfiltrate_quit_game(self): action = 
Action( action_type=ActionType.QuitGame, - params={} + parameters={} ) - action_json = action.as_json() + action_json = action.to_json() new_action = Action.from_json(action_json) assert action == new_action def test_action_to_dict_scan_network(self): action = Action( action_type=ActionType.ScanNetwork, - params={ + parameters={ "target_network":Network("172.16.1.12", 24), "source_host": IP("172.16.1.2") } @@ -485,14 +482,14 @@ def test_action_to_dict_scan_network(self): action_dict = action.as_dict new_action = Action.from_dict(action_dict) assert action == new_action - assert action_dict["type"] == str(action.type) - assert action_dict["params"]["target_network"] == "172.16.1.12/24" - assert action_dict["params"]["source_host"] == "172.16.1.2" + assert action_dict["action_type"] == str(action.type) + assert action_dict["parameters"]["target_network"] == {'ip': '172.16.1.12', 'mask': 24} + assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'} def test_action_to_dict_find_services(self): action = Action( action_type=ActionType.FindServices, - params={ + parameters={ "target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2") } @@ -500,14 +497,14 @@ def test_action_to_dict_find_services(self): action_dict = action.as_dict new_action = Action.from_dict(action_dict) assert action == new_action - assert action_dict["type"] == str(action.type) - assert action_dict["params"]["target_host"] == "172.16.1.22" - assert action_dict["params"]["source_host"] == "172.16.1.2" + assert action_dict["action_type"] == str(action.type) + assert action_dict["parameters"]["target_host"] == {'ip': '172.16.1.22'} + assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'} def test_action_to_dict_find_data(self): action = Action( action_type=ActionType.FindData, - params={ + parameters={ "target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2") } @@ -515,14 +512,14 @@ def test_action_to_dict_find_data(self): action_dict = action.as_dict 
new_action = Action.from_dict(action_dict) assert action == new_action - assert action_dict["type"] == str(action.type) - assert action_dict["params"]["target_host"] == "172.16.1.22" - assert action_dict["params"]["source_host"] == "172.16.1.2" + assert action_dict["action_type"] == str(action.type) + assert action_dict["parameters"]["target_host"] == {'ip': '172.16.1.22'} + assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'} def test_action_to_dict_exploit_service(self): action = Action( action_type=ActionType.ExploitService, - params={ + parameters={ "source_host": IP("172.16.1.2"), "target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False) @@ -531,18 +528,18 @@ def test_action_to_dict_exploit_service(self): action_dict = action.as_dict new_action = Action.from_dict(action_dict) assert action == new_action - assert action_dict["type"] == str(action.type) - assert action_dict["params"]["target_host"] == "172.16.1.24" - assert action_dict["params"]["source_host"] == "172.16.1.2" - assert action_dict["params"]["target_service"]["name"] == "ssh" - assert action_dict["params"]["target_service"]["type"] == "passive" - assert action_dict["params"]["target_service"]["version"] == "0.23" - assert action_dict["params"]["target_service"]["is_local"] is False + assert action_dict["action_type"] == str(action.type) + assert action_dict["parameters"]["target_host"] == {'ip': '172.16.1.24'} + assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'} + assert action_dict["parameters"]["target_service"]["name"] == "ssh" + assert action_dict["parameters"]["target_service"]["type"] == "passive" + assert action_dict["parameters"]["target_service"]["version"] == "0.23" + assert action_dict["parameters"]["target_service"]["is_local"] is False def test_action_to_dict_exfiltrate_data(self): action = Action( action_type=ActionType.ExfiltrateData, - params={ + parameters={ "target_host":IP("172.16.1.3"), "source_host": 
IP("172.16.1.2"), "data":Data("User2", "PublicKey") @@ -551,16 +548,16 @@ def test_action_to_dict_exfiltrate_data(self): action_dict = action.as_dict new_action = Action.from_dict(action_dict) assert action == new_action - assert action_dict["type"] == str(action.type) - assert action_dict["params"]["target_host"] == "172.16.1.3" - assert action_dict["params"]["source_host"] == "172.16.1.2" - assert action_dict["params"]["data"]["owner"] == "User2" - assert action_dict["params"]["data"]["id"] == "PublicKey" + assert action_dict["action_type"] == str(action.type) + assert action_dict["parameters"]["target_host"] == {'ip': '172.16.1.3'} + assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'} + assert action_dict["parameters"]["data"]["owner"] == "User2" + assert action_dict["parameters"]["data"]["id"] == "PublicKey" def test_action_to_dict_join_game(self): action = Action( action_type=ActionType.JoinGame, - params={ + parameters={ "agent_info": AgentInfo(name="TestingAgent", role="attacker"), "source_host": IP("172.16.1.2") } @@ -568,31 +565,31 @@ def test_action_to_dict_join_game(self): action_dict = action.as_dict new_action = Action.from_dict(action_dict) assert action == new_action - assert action_dict["type"] == str(action.type) - assert action_dict["params"]["agent_info"]["name"] == "TestingAgent" - assert action_dict["params"]["agent_info"]["role"] == "attacker" + assert action_dict["action_type"] == str(action.type) + assert action_dict["parameters"]["agent_info"]["name"] == "TestingAgent" + assert action_dict["parameters"]["agent_info"]["role"] == "attacker" def test_action_to_dict_reset_game(self): action = Action( action_type=ActionType.ResetGame, - params={} + parameters={} ) action_dict = action.as_dict new_action = Action.from_dict(action_dict) assert action == new_action - assert action_dict["type"] == str(action.type) - assert len(action_dict["params"]) == 0 + assert action_dict["action_type"] == str(action.type) + assert 
len(action_dict["parameters"]) == 0 def test_action_to_dict_quit_game(self): action = Action( action_type=ActionType.QuitGame, - params={} + parameters={} ) action_dict = action.as_dict new_action = Action.from_dict(action_dict) assert action == new_action - assert action_dict["type"] == str(action.type) - assert len(action_dict["params"]) == 0 + assert action_dict["action_type"] == str(action.type) + assert len(action_dict["parameters"]) == 0 class TestGameState: """ diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 63a36533..965f4e16 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -236,8 +236,8 @@ import pytest from unittest.mock import AsyncMock, MagicMock -from coordinator import Coordinator, AgentStatus, Action, ActionType -from env.game_components import AgentInfo, Network, IP +from AIDojoCoordinator.coordinator import Coordinator, AgentStatus, Action, ActionType +from AIDojoCoordinator.game_components import AgentInfo, Network, IP CONFIG_FILE = "tests/netsecenv-task-for-testing.yaml" ALLOWED_ROLES = ["Attacker", "Defender", "Benign"] diff --git a/tests/test_game_coordinator.py b/tests/test_game_coordinator.py new file mode 100644 index 00000000..f2f75b66 --- /dev/null +++ b/tests/test_game_coordinator.py @@ -0,0 +1,133 @@ +import asyncio +import pytest +from unittest.mock import AsyncMock, Mock +from AIDojoCoordinator.coordinator import AgentServer, GameCoordinator + +def test_game_coordinator_initialization(): + """ + Test that the GameCoordinator is initialized correctly with the expected properties. 
+ """ + # Test input + game_host = "localhost" + game_port = 8000 + service_host = "localhost" + service_port = 8080 + allowed_roles = ["Attacker", "Defender", "Benign"] + task_config_file = "test_config.json" + + # Create an instance of GameCoordinator + coordinator = GameCoordinator( + game_host=game_host, + game_port=game_port, + service_host=service_host, + service_port=service_port, + allowed_roles=allowed_roles, + task_config_file=task_config_file, + ) + + # Assertions for basic initialization + assert coordinator.host == game_host, "GameCoordinator host should be set correctly." + assert coordinator.port == game_port, "GameCoordinator port should be set correctly." + assert coordinator._service_host == service_host, "Service host should be set correctly." + assert coordinator._service_port == service_port, "Service port should be set correctly." + assert coordinator.ALLOWED_ROLES == allowed_roles, "Allowed roles should match the input." + assert coordinator._task_config_file == task_config_file, "Task config file should match the input." + + # Assertions for events and locks + assert isinstance(coordinator.shutdown_flag, asyncio.Event), "shutdown_flag should be an asyncio.Event." + assert isinstance(coordinator._reset_event, asyncio.Event), "reset_event should be an asyncio.Event." + assert isinstance(coordinator._episode_end_event, asyncio.Event), "episode_end_event should be an asyncio.Event." + assert isinstance(coordinator._reset_lock, asyncio.Lock), "reset_lock should be an asyncio.Lock." + assert isinstance(coordinator._agents_lock, asyncio.Lock), "agents_lock should be an asyncio.Lock." + assert isinstance(coordinator._semaphore, asyncio.Semaphore), "semaphore should be an asyncio.Semaphore." + + # Assertions for agent-related data structures + assert isinstance(coordinator._agent_action_queue, asyncio.Queue), "agent_action_queue should be an asyncio.Queue." 
+ assert isinstance(coordinator._agent_response_queues, dict), "agent_response_queues should be a dictionary." + assert isinstance(coordinator.agents, dict), "agents should be a dictionary." + assert isinstance(coordinator._agent_steps, dict), "agent_steps should be a dictionary." + assert isinstance(coordinator._reset_requests, dict), "reset_requests should be a dictionary." + assert isinstance(coordinator._agent_status, dict), "agent_status should be a dictionary." + assert isinstance(coordinator._agent_observations, dict), "agent_observations should be a dictionary." + assert isinstance(coordinator._agent_rewards, dict), "agent_rewards should be a dictionary." + assert isinstance(coordinator._agent_trajectories, dict), "agent_trajectories should be a dictionary." + + # Assertions for tasks + assert isinstance(coordinator._tasks, set), "tasks should be a set." + + # Assertions for configuration + assert coordinator._cyst_objects is None, "cyst_objects should be None at initialization." + assert coordinator._cyst_object_string is None, "cyst_object_string should be None at initialization." + + # Assertions for logging + assert coordinator.logger.name == "AIDojo-GameCoordinator", "Logger should be initialized with the correct name." 
def test_starting_positions():
    """
    Check the per-role starting positions produced by the coordinator.

    The task configuration is replaced by a mock whose get_start_position
    hands back the same position for any role, so every allowed role is
    expected to map to that position.
    """
    def fixed_position(agent_role):
        # Same answer regardless of the role being asked about
        return {"x": 0, "y": 0}

    coordinator = GameCoordinator(
        game_host="localhost",
        game_port=8000,
        service_host="localhost",
        service_port=8080,
        allowed_roles=["Attacker", "Defender", "Benign"],
    )
    coordinator.task_config = Mock()  # Mock the task_config
    coordinator.task_config.get_start_position.side_effect = fixed_position

    starting_positions = coordinator._get_starting_position_per_role()

    for role in ("Attacker", "Defender", "Benign"):
        assert starting_positions[role] == {"x": 0, "y": 0}


# Agent Server
def test_agent_server_initialization():
    """
    Verify that a freshly built AgentServer keeps the objects it was given
    and starts with zero active connections.
    """
    # Constructor arguments
    shared_actions = asyncio.Queue()
    response_queues = {}
    connection_limit = 5

    agent_server = AgentServer(shared_actions, response_queues, connection_limit)

    # The queues must be the very same objects, not copies
    assert agent_server.actions_queue is shared_actions, "AgentServer's actions_queue should be set correctly."
    assert agent_server.answers_queues is response_queues, "AgentServer's answers_queues should be set correctly."
    assert agent_server.max_connections == connection_limit, "AgentServer's max_connections should match the input."
    assert agent_server.current_connections == 0, "AgentServer's current_connections should be initialized to 0."

    # Logger identity
    assert agent_server.logger.name == "AIDojo-AgentServer", "Logger should be initialized with the correct name."
@pytest.mark.asyncio
async def test_handle_new_agent_max_connections():
    """
    Test that a new agent connection is rejected when max_connections is reached.

    With the connection counter already at the limit, handle_new_agent is
    expected to close the writer and return without changing the counter,
    creating a response queue, reading from the agent, or forwarding
    anything to the shared actions queue.
    """
    # Test setup
    actions_queue = asyncio.Queue()
    agent_response_queues = {}
    max_connections = 1
    server = AgentServer(actions_queue, agent_response_queues, max_connections)

    # Mock reader and writer; the peername is the key used for per-agent queues
    agent_addr = ("127.0.0.1", 12345)
    reader_mock = AsyncMock()
    writer_mock = Mock()
    writer_mock.get_extra_info.return_value = agent_addr

    # Simulate max connections
    server.current_connections = max_connections

    # Run handle_new_agent
    await server.handle_new_agent(reader_mock, writer_mock)

    # Assertions on server state
    assert server.current_connections == max_connections, "Connection count should remain unchanged."
    assert agent_addr not in agent_response_queues, "Queue should not be created for rejected agent."
    assert actions_queue.empty(), "No action should be forwarded for the rejected agent."

    # Mock assert_* helpers raise AssertionError themselves; a trailing
    # `, "message"` would only build a discarded tuple, so the intent is
    # stated in comments instead.
    # The rejected agent must never be read from.
    reader_mock.read.assert_not_awaited()
    # No data should be sent to the rejected agent.
    writer_mock.write.assert_not_called()
    # Connection should be closed for the rejected agent.
    writer_mock.close.assert_called_once()