diff --git a/AIDojoCoordinator/coordinator.py b/AIDojoCoordinator/coordinator.py index dd14ad29..40e771c2 100644 --- a/AIDojoCoordinator/coordinator.py +++ b/AIDojoCoordinator/coordinator.py @@ -1,4 +1,3 @@ -import jsonlines import logging import json import asyncio @@ -7,7 +6,7 @@ from AIDojoCoordinator.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, ProtocolConfig from AIDojoCoordinator.global_defender import GlobalDefender -from AIDojoCoordinator.utils.utils import observation_as_dict, get_str_hash, ConfigParser +from AIDojoCoordinator.utils.utils import observation_as_dict, get_str_hash, ConfigParser, store_trajectories_to_jsonl import os from aiohttp import ClientSession from cyst.api.environment.environment import Environment @@ -136,8 +135,44 @@ async def __call__(self, reader, writer): class GameCoordinator: """ Class for creation, and management of agent interactions in AI Dojo. + + Attributes: + host (str): Host address for the game server. + port (int): Port number for the game server. + logger (logging.Logger): Logger for the GameCoordinator. + _tasks (set): Set of active asyncio tasks. + shutdown_flag (asyncio.Event): Event to signal shutdown. + _reset_event (asyncio.Event): Event to signal game reset. + _episode_end_event (asyncio.Event): Event to signal episode end. + _episode_start_event (asyncio.Event): Event to signal episode start. + _episode_rewards_condition (asyncio.Condition): Condition for episode rewards assignment. + _reset_done_condition (asyncio.Condition): Condition for reset completion. + _reset_lock (asyncio.Lock): Lock for reset operations. + _agents_lock (asyncio.Lock): Lock for agent operations. + _service_host (str): Host for remote configuration service. + _service_port (int): Port for remote configuration service. + _task_config_file (str): Path to local task configuration file. + ALLOWED_ROLES (list): List of allowed agent roles. + _cyst_objects: CYST simulator initialization objects. + _cyst_object_string: String representation of CYST objects. + _agent_action_queue (asyncio.Queue): Queue for agent actions. + _agent_response_queues (dict): Mapping of agent addresses to their response queues. + agents (dict): Mapping of agent addresses to their information. + _agent_steps (dict): Step counters per agent address. + _reset_requests (dict): Reset requests per agent address. + _randomize_topology_requests (dict): Topology randomization requests per agent address. + _agent_status (dict): Status of each agent. + _episode_ends (dict): Episode end flags per agent address. + _agent_observations (dict): Observations per agent address. + _agent_starting_position (dict): Starting positions per agent address. + _agent_states (dict): Current states per agent address. + _agent_goal_states (dict): Goal states per agent address. + _agent_last_action (dict): Last actions played by agents. + _agent_false_positives (dict): False positives per agent. + _agent_rewards (dict): Rewards per agent address. + _agent_trajectories (dict): Trajectories per agent address. """ - def __init__(self, game_host: str, game_port: int, service_host:str, service_port:int, allowed_roles=["Attacker", "Defender", "Benign"], task_config_file:str=None) -> None: + def __init__(self, game_host: str, game_port: int, service_host:str, service_port:int, task_config_file:str,allowed_roles=["Attacker", "Defender", "Benign"]) -> None: self.host = game_host self.port = game_port self.logger = logging.getLogger("AIDojo-GameCoordinator") @@ -190,8 +225,6 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por self._agent_rewards = {} # trajectories per agent_addr self._agent_trajectories = {} - # false_positives per agent_addr - self._agent_false_positives = {} def _spawn_task(self, coroutine, *args, **kwargs)->asyncio.Task: "Helper function to make sure all tasks are registered for proper termination" @@ -317,6 +350,7 @@ async def start_tcp_server(self): """ Starts TPC sever for the agent communication. """ + server = None try: self.logger.info("Starting the server listening for agents") server = await asyncio.start_server( @@ -337,8 +371,9 @@ async def start_tcp_server(self): except Exception as e: self.logger.error(f"TCP server failed: {e}") finally: - server.close() - await server.wait_closed() + if server: + server.close() + await server.wait_closed() self.logger.info("\tTCP server task stopped") async def start_tasks(self): @@ -421,32 +456,33 @@ async def run_game(self): agent_addr, message = await self._agent_action_queue.get() if message is not None: self.logger.info(f"Coordinator received from agent {agent_addr}: {message}.") + try: # Convert message to Action action = Action.from_json(message) self.logger.debug(f"\tConverted to: {action}.") + match action.type: # process action based on its type + case ActionType.JoinGame: + self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.JoinGame by {agent_addr}") + self.logger.debug(f"{action.type}, {action.type.value}, {action.type == ActionType.JoinGame}") + self._spawn_task(self._process_join_game_action, agent_addr, action) + case ActionType.QuitGame: + self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.QuitGame by {agent_addr}") + self._spawn_task(self._process_quit_game_action, agent_addr) + case ActionType.ResetGame: + self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.ResetGame by {agent_addr}") + self._spawn_task(self._process_reset_game_action, agent_addr, action) + case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService: + self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}") + self._spawn_task(self._process_game_action, agent_addr, action) + case ActionType.BlockIP: + self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}") + self._spawn_task(self._process_game_action, agent_addr, action) + case _: + self.logger.warning(f"About agent {agent_addr}. Unsupported action type: {action}!") except Exception as e: self.logger.error( f"Error when converting msg to Action using Action.from_json():{e}, {message}" ) - match action.type: # process action based on its type - case ActionType.JoinGame: - self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.JoinGame by {agent_addr}") - self.logger.debug(f"{action.type}, {action.type.value}, {action.type == ActionType.JoinGame}") - self._spawn_task(self._process_join_game_action, agent_addr, action) - case ActionType.QuitGame: - self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.QuitGame by {agent_addr}") - self._spawn_task(self._process_quit_game_action, agent_addr) - case ActionType.ResetGame: - self.logger.debug(f"About agent {agent_addr}. Start processing of ActionType.ResetGame by {agent_addr}") - self._spawn_task(self._process_reset_game_action, agent_addr, action) - case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService: - self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}") - self._spawn_task(self._process_game_action, agent_addr, action) - case ActionType.BlockIP: - self.logger.debug(f"About agent {agent_addr}. Start processing of {action.type} by {agent_addr}") - self._spawn_task(self._process_game_action, agent_addr, action) - case _: - self.logger.warning(f"About agent {agent_addr}. Unsupported action type: {action}!") self.logger.info("\tAction processing task stopped.") async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None: @@ -945,7 +981,7 @@ def _reset_trajectory(self, agent_addr:tuple)->dict: "agent_name":agent_name } - def _add_step_to_trajectory(self, agent_addr:tuple, action:Action, reward:float, next_state:GameState, end_reason:str=None)-> None: + def _add_step_to_trajectory(self, agent_addr:tuple, action:Action, reward:float, next_state:GameState, end_reason:str|None=None)-> None: """ Method for adding one step to the agent trajectory. """ @@ -958,16 +994,17 @@ def _add_step_to_trajectory(self, agent_addr:tuple, action:Action, reward:float, self._agent_trajectories[agent_addr]["end_reason"] = end_reason def _store_trajectory_to_file(self, agent_addr:tuple, location="./logs/trajectories")-> None: - if not os.path.exists(location): - os.makedirs(location) - self.logger.debug(f"Created directory for storing trajectories: {location}") - self.logger.debug(f"Storing Trajectory of {agent_addr}in file") - if agent_addr in self._agent_trajectories: - agent_name, agent_role = self.agents[agent_addr] - filename = os.path.join(location, f"{datetime.now():%Y-%m-%d}_{agent_name}_{agent_role}.jsonl") - with jsonlines.open(filename, "a") as writer: - writer.write(self._agent_trajectories[agent_addr]) - self.logger.info(f"Trajectory of {agent_addr} strored in {filename}") + """ + Method for storing the agent trajectory to a file. + """ + if agent_addr in self.agents: + agent_name, agent_role = self.agents[agent_addr] + filename =f"{datetime.now():%Y-%m-%d}_{agent_name}_{agent_role}" + trajectories = self._agent_trajectories[agent_addr] + store_trajectories_to_jsonl(trajectories, location, filename) + self.logger.info(f"Trajectories of {agent_addr} strored in {os.path.join(location, filename)}.jsonl") + else: + self.logger.warning(f"Agent {agent_addr} not found in agents list, can't store trajectory to file.") def is_agent_benign(self, agent_addr:tuple)->bool: """ diff --git a/AIDojoCoordinator/utils/utils.py b/AIDojoCoordinator/utils/utils.py index c92b030c..face37cf 100644 --- a/AIDojoCoordinator/utils/utils.py +++ b/AIDojoCoordinator/utils/utils.py @@ -9,6 +9,8 @@ import netaddr import logging import csv +import os +import jsonlines from random import randint import json import hashlib @@ -565,6 +567,22 @@ def get_starting_position_from_cyst_config(cyst_objects): starting_positions[f"{obj.id}.{active_service.name}"] = {"known_hosts":hosts, "known_networks":networks} return starting_positions +def store_trajectories_to_jsonl(trajectories:list, dir:str, filename:str)->None: + """ + Store trajectories to a JSONL file. + Args: + trajectories (list): List of trajectory data to store. + dir (str): Directory where the file will be stored. + filename (str): Name of the file (without extension). + """ + # make sure the directory exists + if not os.path.exists(dir): + os.makedirs(dir) + # construct the full file name + filename = os.path.join(dir, f"{filename.rstrip('jsonl')}.jsonl") + # store the trajectories + with jsonlines.open(filename, "a") as writer: + writer.write(trajectories) if __name__ == "__main__": state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)},