Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions AIDojoCoordinator/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
import asyncio
from datetime import datetime
import signal
from AIDojoCoordinator.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus
from AIDojoCoordinator.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus, ProtocolConfig
from AIDojoCoordinator.global_defender import GlobalDefender
from AIDojoCoordinator.utils.utils import observation_as_dict, get_str_hash, ConfigParser
import os

from aiohttp import ClientSession
from cyst.api.environment.environment import Environment

Expand Down Expand Up @@ -52,7 +51,7 @@ async def send_data_to_agent(writer, data: str) -> None:
try:
while True:
# Step 1: Read data from the agent
data = await reader.read(500)
data = await reader.read(ProtocolConfig.BUFFER_SIZE)
if not data:
self.logger.info(f"Agent {addr} disconnected.")
quit_message = Action(ActionType.QuitGame, parameters={}).to_json()
Expand All @@ -64,14 +63,15 @@ async def send_data_to_agent(writer, data: str) -> None:

# Step 2: Forward the message to the Coordinator
await self.actions_queue.put((addr, raw_message))
# await asyncio.sleep(0)
# await asyncio.sleep(0)w
# Step 3: Get a matching response from the answers queue
response_queue = self.answers_queues[addr]
response = await response_queue.get()
self.logger.info(f"Sending response to agent {addr}: {response}")

# Step 4: Send the response to the agent
writer.write(bytes(str(response).encode()))
response = str(response).encode() + ProtocolConfig.END_OF_MESSAGE
writer.write(response)
await writer.drain()
except asyncio.CancelledError:
self.logger.debug("Terminating by KeyboardInterrupt")
Expand Down Expand Up @@ -382,7 +382,7 @@ async def run_game(self):
self._spawn_task(self._process_quit_game_action, agent_addr)
case ActionType.ResetGame:
self.logger.debug(f"Start processing of ActionType.ResetGame by {agent_addr}")
self._spawn_task(self._process_reset_game_action, agent_addr)
self._spawn_task(self._process_reset_game_action, agent_addr, action)
case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService:
self.logger.debug(f"Start processing of {action.type} by {agent_addr}")
self._spawn_task(self._process_game_action, agent_addr, action)
Expand Down Expand Up @@ -468,7 +468,7 @@ async def _process_quit_game_action(self, agent_addr: tuple)->None:
finally:
self.logger.debug(f"Cleaning up after QuitGame for {agent_addr}.")

async def _process_reset_game_action(self, agent_addr: tuple)->None:
async def _process_reset_game_action(self, agent_addr: tuple, reset_action:Action)->None:
"""
Method for processing Action of type ActionType.ResetGame
Inputs:
Expand Down Expand Up @@ -499,6 +499,10 @@ async def _process_reset_game_action(self, agent_addr: tuple)->None:
"configuration_hash": self._CONFIG_FILE_HASH
},
}
# extend the message with last trajectory
if "request_trajectory" in reset_action.parameters and reset_action.parameters["request_trajectory"]:
output_message_dict["message"]["last_trajectory"] = self._agent_trajectories[agent_addr]
self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
response_msg_json = self.convert_msg_dict_to_json(output_message_dict)
await self._agent_response_queues[agent_addr].put(response_msg_json)

Expand Down Expand Up @@ -548,9 +552,9 @@ async def _process_game_action(self, agent_addr: tuple, action:Action)->None:
async with self._episode_rewards_condition:
await self._episode_rewards_condition.wait()
# append step to the trajectory if needed
if self.task_config.get_store_trajectories() or self._global_defender:
async with self._agents_lock:
self._add_step_to_trajectory(agent_addr, action, self._agent_rewards[agent_addr], new_state,end_reason=None)

async with self._agents_lock:
self._add_step_to_trajectory(agent_addr, action, self._agent_rewards[agent_addr], new_state,end_reason=None)
# add information to 'info' field if needed
info = {}
if self._agent_status[agent_addr] not in [AgentStatus.Playing, AgentStatus.PlayingWithTimeout]:
Expand Down Expand Up @@ -645,8 +649,8 @@ async def _reset_game(self):
self._agent_status[agent] = AgentStatus.PlayingWithTimeout
else:
self._agent_status[agent] = AgentStatus.Playing
if self.task_config.get_store_trajectories() or self._global_defender:
self._agent_trajectories[agent] = self._reset_trajectory(agent)
# if self.task_config.get_store_trajectories() or self._global_defender:
# self._agent_trajectories[agent] = self._reset_trajectory(agent)
self._reset_event.clear()
# notify all waiting agents
async with self._reset_done_condition:
Expand All @@ -670,8 +674,7 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState
self._agent_status[agent_addr] = AgentStatus.PlayingWithTimeout
else:
self._agent_status[agent_addr] = AgentStatus.Playing
if self.task_config.get_store_trajectories() or self._global_defender:
self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
self.logger.info(f"\tAgent {agent_name} ({agent_addr}), registred as {agent_role}")
return Observation(self._agent_states[agent_addr], 0, False, {})

Expand Down
4 changes: 2 additions & 2 deletions AIDojoCoordinator/docs/Components.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ Actions are the objects sent by the agents to the environment. Each action is ev

In all cases, when an agent sends an action to AIDojo, it is given a response.
### Action format
The Action class is defined in `env.game_components.py`. It has two basic parts:
The Action class is defined in `game_components.py`. It has two basic parts:
1. ActionType:Enum
2. parameters:dict

ActionType is a unique Enum that determines what kind of action the agent is playing. Parameters are passed in a dictionary as follows.
### List of actions
- **JoinGame**, params={`agent_info`:AgentInfo(\<name\>, \<role\>)}: Used to register an agent in a game with a given \<role\>.
- **QuitGame**, params={}: Used for termination of agent's interaction.
- **ResetGame**, params={}: Used for requesting a reset of the game to its initial position.
- **ResetGame**, params={`request_trajectory`:`bool`}: Used for requesting a reset of the game to its initial position. If `request_trajectory = True`, the coordinator will send back the complete trajectory of the previous run in the next message.
---
- **ScanNetwork**, params={`source_host`:\<IP\>, `target_network`:\<Network\>}: Scans the given \<Network\> from a specified source host. Discovers ALL hosts in a network that are accessible from \<IP\>. If successful, returns a set of discovered \<IP\> objects.
- **FindServices**, params={`source_host`:\<IP\>, `target_host`:\<IP\>}: Used to discover ALL services running in the `target_host` if the host is accessible from `source_host`. If successful, returns a set of all discovered \<Service\> objects.
Expand Down
9 changes: 8 additions & 1 deletion AIDojoCoordinator/game_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def from_dict(cls, data_dict: Dict[str, Any]) -> "Action":
params[k] = Data.from_dict(v)
case "agent_info":
params[k] = AgentInfo.from_dict(v)
case "request_trajectory":
params[k] = bool(v)
case _:
raise ValueError(f"Unsupported value in {k}: {v}")
return cls(action_type=action_type, parameters=params)
Expand Down Expand Up @@ -470,4 +472,9 @@ def from_string(cls, name):
try:
return cls[name]
except KeyError:
raise ValueError(f"Invalid AgentStatus: {name}")
raise ValueError(f"Invalid AgentStatus: {name}")

@dataclass(frozen=True)
class ProtocolConfig:
    """Constants describing the byte-level message exchange with agents.

    The attributes are deliberately left without type annotations: an
    unannotated class attribute is not turned into a dataclass field, so
    these stay plain class-level constants that are read directly from the
    class (``ProtocolConfig.BUFFER_SIZE``) and do not become ``__init__``
    parameters.
    """

    # Byte marker the coordinator appends to each encoded response before
    # writing it to an agent's stream.
    END_OF_MESSAGE = b"EOF"
    # Upper bound, in bytes, passed to a single ``reader.read()`` call when
    # receiving data from an agent.
    BUFFER_SIZE = 8192