diff --git a/AIDojoCoordinator/coordinator.py b/AIDojoCoordinator/coordinator.py index 523987c4..69cf9977 100644 --- a/AIDojoCoordinator/coordinator.py +++ b/AIDojoCoordinator/coordinator.py @@ -4,11 +4,10 @@ import asyncio from datetime import datetime import signal -from AIDojoCoordinator.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus +from AIDojoCoordinator.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, ProtocolConfig from AIDojoCoordinator.global_defender import GlobalDefender from AIDojoCoordinator.utils.utils import observation_as_dict, get_str_hash, ConfigParser import os - from aiohttp import ClientSession from cyst.api.environment.environment import Environment @@ -52,7 +51,7 @@ async def send_data_to_agent(writer, data: str) -> None: try: while True: # Step 1: Read data from the agent - data = await reader.read(500) + data = await reader.read(ProtocolConfig.BUFFER_SIZE) if not data: self.logger.info(f"Agent {addr} disconnected.") quit_message = Action(ActionType.QuitGame, parameters={}).to_json() @@ -64,14 +63,15 @@ async def send_data_to_agent(writer, data: str) -> None: # Step 2: Forward the message to the Coordinator await self.actions_queue.put((addr, raw_message)) - # await asyncio.sleep(0) + # await asyncio.sleep(0) # Step 3: Get a matching response from the answers queue response_queue = self.answers_queues[addr] response = await response_queue.get() self.logger.info(f"Sending response to agent {addr}: {response}") # Step 4: Send the response to the agent - writer.write(bytes(str(response).encode())) + response = str(response).encode() + ProtocolConfig.END_OF_MESSAGE + writer.write(response) await writer.drain() except asyncio.CancelledError: self.logger.debug("Terminating by KeyboardInterrupt") @@ -382,7 +382,7 @@ async def run_game(self): self._spawn_task(self._process_quit_game_action, agent_addr) case ActionType.ResetGame: self.logger.debug(f"Start
processing of ActionType.ResetGame by {agent_addr}") - self._spawn_task(self._process_reset_game_action, agent_addr) + self._spawn_task(self._process_reset_game_action, agent_addr, action) case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService: self.logger.debug(f"Start processing of {action.type} by {agent_addr}") self._spawn_task(self._process_game_action, agent_addr, action) @@ -468,7 +468,7 @@ async def _process_quit_game_action(self, agent_addr: tuple)->None: finally: self.logger.debug(f"Cleaning up after QuitGame for {agent_addr}.") - async def _process_reset_game_action(self, agent_addr: tuple)->None: + async def _process_reset_game_action(self, agent_addr: tuple, reset_action:Action)->None: """ Method for processing Action of type ActionType.ResetGame Inputs: @@ -499,6 +499,10 @@ async def _process_reset_game_action(self, agent_addr: tuple)->None: "configuration_hash": self._CONFIG_FILE_HASH }, } + # extend the message with last trajectory + if "request_trajectory" in reset_action.parameters and reset_action.parameters["request_trajectory"]: + output_message_dict["message"]["last_trajectory"] = self._agent_trajectories[agent_addr] + self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr) response_msg_json = self.convert_msg_dict_to_json(output_message_dict) await self._agent_response_queues[agent_addr].put(response_msg_json) @@ -548,9 +552,9 @@ async def _process_game_action(self, agent_addr: tuple, action:Action)->None: async with self._episode_rewards_condition: await self._episode_rewards_condition.wait() # append step to the trajectory if needed - if self.task_config.get_store_trajectories() or self._global_defender: - async with self._agents_lock: - self._add_step_to_trajectory(agent_addr, action, self._agent_rewards[agent_addr], new_state,end_reason=None) + + async with self._agents_lock: + self._add_step_to_trajectory(agent_addr, action, 
self._agent_rewards[agent_addr], new_state,end_reason=None) # add information to 'info' field if needed info = {} if self._agent_status[agent_addr] not in [AgentStatus.Playing, AgentStatus.PlayingWithTimeout]: @@ -645,8 +649,8 @@ async def _reset_game(self): self._agent_status[agent] = AgentStatus.PlayingWithTimeout else: self._agent_status[agent] = AgentStatus.Playing - if self.task_config.get_store_trajectories() or self._global_defender: - self._agent_trajectories[agent] = self._reset_trajectory(agent) + # if self.task_config.get_store_trajectories() or self._global_defender: + # self._agent_trajectories[agent] = self._reset_trajectory(agent) self._reset_event.clear() # notify all waiting agents async with self._reset_done_condition: @@ -670,8 +674,7 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState self._agent_status[agent_addr] = AgentStatus.PlayingWithTimeout else: self._agent_status[agent_addr] = AgentStatus.Playing - if self.task_config.get_store_trajectories() or self._global_defender: - self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr) + self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr) self.logger.info(f"\tAgent {agent_name} ({agent_addr}), registred as {agent_role}") return Observation(self._agent_states[agent_addr], 0, False, {}) diff --git a/AIDojoCoordinator/docs/Components.md b/AIDojoCoordinator/docs/Components.md index f1da6d14..1ac63276 100644 --- a/AIDojoCoordinator/docs/Components.md +++ b/AIDojoCoordinator/docs/Components.md @@ -53,7 +53,7 @@ Actions are the objects sent by the agents to the environment. Each action is ev In all cases, when an agent sends an action to AIDojo, it is given a response. ### Action format -The Action class is defined in `env.game_components.py`. It has two basic parts: +The Action class is defined in `game_components.py`. It has two basic parts: 1. ActionType:Enum 2. 
parameters:dict @@ -61,7 +61,7 @@ ActionType is unique Enum that determines what kind of action is agent playing. ### List of actions - **JoinGame**, params={`agent_info`:AgentInfo(\, \)}: Used to register agent in a game with a given \. - **QuitGame**, params={}: Used for termination of agent's interaction. -- **ResetGame**, params={}: Used for requesting reset of the game to it's initial position. +- **ResetGame**, params={`request_trajectory`:`bool`}: Used for requesting reset of the game to its initial position. If `request_trajectory = True`, the coordinator will send back the complete trajectory of the previous run in the next message. --- - **ScanNetwork**, params{`source_host`:\, `target_network`:\}: Scans the given \ from a specified source host. Discovers ALL hosts in a network that are accessible from \. If successful, returns set of discovered \ objects. - **FindServices**, params={`source_host`:\, `target_host`:\}: Used to discover ALL services running in the `target_host` if the host is accessible from `source_host`. If successful, returns a set of all discovered \ objects. diff --git a/AIDojoCoordinator/game_components.py b/AIDojoCoordinator/game_components.py index 2c4acab2..bab5326c 100755 --- a/AIDojoCoordinator/game_components.py +++ b/AIDojoCoordinator/game_components.py @@ -237,6 +237,8 @@ def from_dict(cls, data_dict: Dict[str, Any]) -> "Action": params[k] = Data.from_dict(v) case "agent_info": params[k] = AgentInfo.from_dict(v) + case "request_trajectory": + params[k] = bool(v) case _: raise ValueError(f"Unsupported value in {k}: {v}") return cls(action_type=action_type, parameters=params) @@ -470,4 +472,9 @@ def from_string(cls, name): try: return cls[name] except KeyError: - raise ValueError(f"Invalid AgentStatus: {name}") \ No newline at end of file + raise ValueError(f"Invalid AgentStatus: {name}") + +@dataclass(frozen=True) +class ProtocolConfig: + END_OF_MESSAGE = b"EOF" + BUFFER_SIZE = 8192 \ No newline at end of file