diff --git a/env/__init__.py b/AIDojoCoordinator/__init__.py
similarity index 100%
rename from env/__init__.py
rename to AIDojoCoordinator/__init__.py
diff --git a/AIDojoCoordinator/coordinator.py b/AIDojoCoordinator/coordinator.py
new file mode 100644
index 00000000..657292ce
--- /dev/null
+++ b/AIDojoCoordinator/coordinator.py
@@ -0,0 +1,844 @@
+import jsonlines
+import logging
+import json
+import asyncio
+from datetime import datetime
+import signal
+from AIDojoCoordinator.game_components import Action, Observation, ActionType, GameStatus, GameState, AgentStatus
+from AIDojoCoordinator.global_defender import GlobalDefender
+from AIDojoCoordinator.utils.utils import observation_as_dict, get_str_hash, ConfigParser
+import os
+
+from aiohttp import ClientSession
+from cyst.api.environment.environment import Environment
+
+class AgentServer(asyncio.Protocol):
+    """
+    Serves the agents connecting to the game run by the GameCoordinator.
+
+    Each received message is put on the shared actions queue as (address, message);
+    the answer is then taken from the agent's own queue and written back.
+    """
+    def __init__(self, actions_queue, agent_response_queues, max_connections):
+        self.actions_queue = actions_queue
+        self.answers_queues = agent_response_queues
+        self.max_connections = max_connections
+        self.current_connections = 0
+        self.logger = logging.getLogger("AIDojo-AgentServer")
+
+    async def handle_new_agent(self, reader, writer):
+        """
+        Serve a single agent connection until it disconnects or is cancelled.
+
+        Connections above `max_connections` are closed immediately. When the
+        agent closes the connection, a QuitGame action is queued on its behalf.
+        Exceptions (including CancelledError) propagate after the connection
+        counter and the agent's answer queue have been cleaned up.
+        """
+        addr = writer.get_extra_info("peername")
+        if self.current_connections >= self.max_connections:
+            self.logger.info(
+                f"Max connections reached. Rejecting new connection from {addr}"
+            )
+            writer.close()
+            return
+
+        # Register the connection
+        self.current_connections += 1
+        self.logger.info(f"New agent connected: {addr}")
+        # Ensure a queue exists for this agent
+        if addr not in self.answers_queues:
+            self.answers_queues[addr] = asyncio.Queue(maxsize=2)
+            self.logger.info(f"Created queue for agent {addr}")
+
+        try:
+            while True:
+                # Step 1: Read data from the agent
+                # NOTE(review): no framing - a message over 500 bytes arrives in pieces.
+                data = await reader.read(500)
+                if not data:
+                    self.logger.info(f"Agent {addr} disconnected.")
+                    quit_message = Action(ActionType.QuitGame, parameters={}).to_json()
+                    await self.actions_queue.put((addr, quit_message))
+                    break
+
+                raw_message = data.decode().strip()
+                self.logger.debug(f"Handler received from {addr}: {raw_message}")
+
+                # Step 2: Forward the message to the Coordinator
+                await self.actions_queue.put((addr, raw_message))
+                # Step 3: Get a matching response from the answers queue
+                response = await self.answers_queues[addr].get()
+                self.logger.info(f"Sending response to agent {addr}: {response}")
+
+                # Step 4: Send the response to the agent
+                writer.write(str(response).encode())
+                await writer.drain()
+        except asyncio.CancelledError:
+            self.logger.debug("Terminating by KeyboardInterrupt")
+            raise
+        finally:
+            # No `return` here: returning from `finally` would discard the
+            # CancelledError re-raised above and any other in-flight exception.
+            self.current_connections -= 1
+            if self.answers_queues.pop(addr, None) is not None:
+                self.logger.info(f"Removed queue for agent {addr}")
+            else:
+                self.logger.warning(f"Queue for agent {addr} not found during cleanup.")
+            writer.close()
+
+    async def __call__(self, reader, writer):
+        await self.handle_new_agent(reader, writer)
+
+class GameCoordinator:
+    """
+    Class for creation, and management of agent interactions in AI Dojo.
+    """
+    def __init__(self, game_host: str, game_port: int, service_host:str, service_port:int, allowed_roles=None, task_config_file:str=None) -> None:
+        """
+        Store the connection settings and create empty per-agent bookkeeping.
+
+        The task configuration is fetched from `service_host:service_port` when
+        a service host is given, otherwise it is read from `task_config_file`.
+        `allowed_roles` defaults to ["Attacker", "Defender", "Benign"]; a fresh
+        list is created per instance instead of sharing a mutable default.
+        """
+        self.host = game_host
+        self.port = game_port
+        self.logger = logging.getLogger("AIDojo-GameCoordinator")
+
+        self._tasks = set()
+        self.shutdown_flag = asyncio.Event()
+        self._reset_event = asyncio.Event()
+        self._episode_end_event = asyncio.Event()
+        self._episode_rewards_condition = asyncio.Condition()
+        self._reset_done_condition = asyncio.Condition()
+        self._reset_lock = asyncio.Lock()
+        self._agents_lock = asyncio.Lock()
+        self._semaphore = asyncio.Semaphore(2)
+
+        # for accessing configuration remotely
+        self._service_host = service_host
+        self._service_port = service_port
+        # for reading configuration locally
+        self._task_config_file = task_config_file
+        self.ALLOWED_ROLES = ["Attacker", "Defender", "Benign"] if allowed_roles is None else list(allowed_roles)
+        self._cyst_objects = None
+        self._cyst_object_string = None
+
+        # prepare agent communication
+        self._agent_action_queue = asyncio.Queue()
+        self._agent_response_queues = {}
+
+        # per-agent bookkeeping, every dictionary is keyed by agent_addr
+        self.agents = {}  # (agent_name, agent_role)
+        self._agent_steps = {}  # step counter (int)
+        self._reset_requests = {}  # reset requested (bool)
+        self._agent_status = {}  # AgentStatus
+        self._episode_ends = {}  # episode ended (bool)
+        self._agent_observations = {}  # last Observation
+        self._agent_starting_position = {}  # starting position (dict)
+        self._agent_states = {}  # current GameState
+        self._agent_last_action = {}  # last Action played
+        self._agent_rewards = {}  # reward
+        self._agent_trajectories = {}  # trajectory record (dict)
+
+    def _spawn_task(self, coroutine, *args, **kwargs)->asyncio.Task:
+        """
+        Create a task from `coroutine(*args, **kwargs)` and track it in `self._tasks` until it is done.
+        """
+        new_task = asyncio.create_task(coroutine(*args, **kwargs))
+        self._tasks.add(new_task)
+        new_task.add_done_callback(self._tasks.discard)  # forget the task once it is done
+        return new_task
+
+    async def shutdown_signal_handler(self):
+        """Log the shutdown request and set `shutdown_flag` so that the running tasks can stop."""
+        self.logger.info("Shutdown signal received. Setting shutdown flag.")
+        self.shutdown_flag.set()
+
+    async def create_agent_queue(self, agent_addr:tuple)->None:
+        """Make sure `agent_addr` has a response queue (an existing one is kept)."""
+        queues = self._agent_response_queues
+        if agent_addr in queues:
+            return
+        queues[agent_addr] = asyncio.Queue()
+        self.logger.info(f"Created queue for agent {agent_addr}. {len(self._agent_response_queues)} queues in total.")
+
+    def convert_msg_dict_to_json(self, msg_dict:dict)->str:
+        """
+        Serialize a message dictionary into the JSON text used in the
+        Agent-Game communication.
+
+        A failed serialization is logged and the error is raised again.
+        """
+        try:
+            return json.dumps(msg_dict)
+        except Exception as error:
+            self.logger.error(f"Error when converting msg to JSON:{error}")
+            raise error
+
+    def run(self)->None:
+        """
+        Blocking wrapper around asyncio.run(self.start_tasks()); unexpected errors are logged with traceback, not raised.
+        """
+        try:
+            asyncio.run(self.start_tasks())
+        except Exception as e:
+            self.logger.exception(f"Unexpected error: {e}")
+        finally:
+            self.logger.info(f"{__class__.__name__} has exited.")
+
+    async def _fetch_initialization_objects(self):
+        """Fetch the CYST initialization objects from the REST service and store them (and their hash) on the instance; failures are only logged."""
+        async with ClientSession() as session:
+            try:
+                async with session.get(f"http://{self._service_host}:{self._service_port}/cyst_init_objects") as response:
+                    if response.status == 200:
+                        response = await response.json()
+                        self.logger.debug(response)
+                        env = Environment.create()
+                        self._CONFIG_FILE_HASH = get_str_hash(response)
+                        self._cyst_objects = env.configuration.general.load_configuration(response)
+                        self.logger.debug(f"Initialization objects received:{self._cyst_objects}")
+                        # NOTE(review): self.task_config is not set on this path (ConfigParser(config_dict=response["task_configuration"]) is disabled) although start_tasks reads it - confirm.
+                    else:
+                        self.logger.error(f"Failed to fetch initialization objects. Status: {response.status}")
+            except Exception as e:
+                self.logger.error(f"Error fetching initialization objects: {e}")
+
+    def _load_initialization_objects(self)->None:
+        """
+        Loads the task configuration from the local file and stores the scenario objects and their hash.
+        """
+        self.task_config = ConfigParser(self._task_config_file)
+        self._cyst_objects = self.task_config.get_scenario()
+        self._CONFIG_FILE_HASH = get_str_hash(str(self._cyst_objects))
+
+    def _get_starting_position_per_role(self)->dict:
+        """Map every allowed role to its starting position from the task configuration.
+        Roles whose lookup raises KeyError get an empty dict."""
+        positions = {}
+        for role in self.ALLOWED_ROLES:
+            try:
+                positions[role] = self.task_config.get_start_position(agent_role=role)
+            except KeyError:
+                positions[role] = {}
+            else:
+                self.logger.info(f"Starting position for role '{role}': {positions[role]}")
+        return positions
+
+    def _get_win_condition_per_role(self)-> dict:
+        """Map every allowed role to its win conditions from the task configuration.
+        Roles whose lookup raises KeyError get an empty dict."""
+        conditions = {}
+        for role in self.ALLOWED_ROLES:
+            try:
+                role_conditions = self.task_config.get_win_conditions(agent_role=role)
+            except KeyError:
+                role_conditions = {}
+            conditions[role] = role_conditions
+            self.logger.info(f"Win condition for role '{role}': {role_conditions}")
+        return conditions
+
+    def _get_goal_description_per_role(self)->dict:
+        """Map every allowed role to its goal description from the task configuration.
+        Roles whose lookup raises KeyError get an empty string."""
+        descriptions = {}
+        for role in self.ALLOWED_ROLES:
+            try:
+                role_description = self.task_config.get_goal_description(agent_role=role)
+            except KeyError:
+                role_description = ""
+            descriptions[role] = role_description
+            self.logger.info(f"Goal description for role '{role}': {role_description}")
+        return descriptions
+
+    def _get_max_steps_per_role(self)->dict:
+        """
+        Map every allowed role to the maximum number of steps it may play in one
+        episode, as returned by the task configuration.
+        """
+        return {role: self.task_config.get_max_steps(role) for role in self.ALLOWED_ROLES}
+
+    async def start_tcp_server(self):
+        """
+        Starts the TCP server for the agent communication and keeps it open
+        until `shutdown_flag` is set or the task is cancelled.
+
+        A failed start is logged; the server is closed only if it was created.
+        """
+        server = None  # stays None when asyncio.start_server fails
+        try:
+            self.logger.info("Starting the server listening for agents")
+            agent_server = AgentServer(self._agent_action_queue, self._agent_response_queues, max_connections=2)
+            server = await asyncio.start_server(agent_server, self.host, self.port)
+            addrs = ", ".join(str(sock.getsockname()) for sock in server.sockets)
+            self.logger.info(f"\tServing on {addrs}")
+            # keep the server open until shutdown is requested
+            while not self.shutdown_flag.is_set():
+                await asyncio.sleep(1)
+        except asyncio.CancelledError:
+            self.logger.debug("\tStopping TCP server task.")
+        except Exception as e:
+            self.logger.error(f"TCP server failed: {e}")
+        finally:
+            # guard: an unconditional close would raise and mask the start-up error
+            if server is not None:
+                server.close()
+                await server.wait_closed()
+            self.logger.info("\tTCP server task stopped")
+
+    async def start_tasks(self):
+        """
+        High level function to start all the other asynchronous tasks.
+        - Reads the conf of the coordinator
+        - Creates queues
+        - Start the main part of the coordinator
+        - Start a server that listens for agents
+        Runs until `shutdown_flag` is set (SIGINT/SIGTERM), then cancels the spawned tasks.
+        Raises ValueError when neither a service host nor a task config file is set.
+        """
+        loop = asyncio.get_running_loop()
+
+        # Set up signal handlers for graceful shutdown. The handler is started via
+        # _spawn_task so that a reference to its task is kept while it runs.
+        for shutdown_signal in (signal.SIGINT, signal.SIGTERM):
+            loop.add_signal_handler(
+                shutdown_signal, lambda: self._spawn_task(self.shutdown_signal_handler)
+            )
+
+        # initialize the game objects
+        if self._service_host: #get the task config using REST API
+            self.logger.info(f"Fetching task configuration from {self._service_host}:{self._service_port}")
+            await self._fetch_initialization_objects()
+        elif self._task_config_file: # load task config locally from a file
+            self.logger.info(f"Loading task configuration from file: {self._task_config_file}")
+            self._load_initialization_objects()
+        else:
+            raise ValueError("Task configuration not specified")
+
+        # Read configuration
+        self._starting_positions_per_role = self._get_starting_position_per_role()
+        self._win_conditions_per_role = self._get_win_condition_per_role()
+        self._goal_description_per_role = self._get_goal_description_per_role()
+        self._steps_limit_per_role = self._get_max_steps_per_role()
+        self.logger.debug(f"Timeouts set to:{self._steps_limit_per_role}")
+        if self.task_config.get_use_global_defender():
+            self._global_defender = GlobalDefender()
+        else:
+            self._global_defender = None
+        self._use_dynamic_ips = self.task_config.get_use_dynamic_addresses()
+        self._rewards = self.task_config.get_rewards(["step", "sucess", "fail"])
+        self.logger.debug(f"Rewards set to:{self._rewards}")
+
+        # start server for agent communication
+        self._spawn_task(self.start_tcp_server)
+
+        # start episode rewards task
+        self._spawn_task(self._assign_rewards_episode_end)
+
+        # start game reset task
+        self._spawn_task(self._reset_game)
+
+        # start action processing task
+        self._spawn_task(self.run_game)
+
+        while not self.shutdown_flag.is_set():
+            # just wait until user terminates
+            await asyncio.sleep(1)
+        self.logger.debug("Final cleanup started")
+        # cancel a snapshot of the tasks: finished tasks remove themselves from self._tasks
+        remaining_tasks = list(self._tasks)
+        for task in remaining_tasks:
+            task.cancel()
+        await asyncio.gather(*remaining_tasks, return_exceptions=True)  # Wait for all tasks to finish
+        self.logger.info("All tasks shut down.")
+
+    async def run_game(self):
+        """
+        Task responsible for reading messages from the agent queue and processing them based on the ActionType.
+        A message that cannot be converted to an Action is answered with GameStatus.BAD_REQUEST and skipped.
+        """
+        while not self.shutdown_flag.is_set():
+            agent_addr, message = await self._agent_action_queue.get()
+            if message is None:
+                continue
+            self.logger.info(f"Coordinator received: {message}.")
+            try:  # Convert message to Action
+                action = Action.from_json(message)
+            except Exception as e:
+                self.logger.error(f"Error when converting msg to Action using Action.from_json():{e}, {message}")
+                response_queue = self._agent_response_queues.get(agent_addr)
+                if response_queue is not None:  # the agent's handler waits for an answer
+                    bad_request = {"to_agent": agent_addr, "status": str(GameStatus.BAD_REQUEST), "message": "Invalid action format."}
+                    await response_queue.put(self.convert_msg_dict_to_json(bad_request))
+                continue
+            self.logger.debug(f"\tConverted to: {action}.")
+            self.logger.debug(f"Start processing of {action.type} by {agent_addr}")
+            match action.type:  # process action based on its type
+                case ActionType.JoinGame:
+                    self._spawn_task(self._process_join_game_action, agent_addr, action)
+                case ActionType.QuitGame:
+                    self._spawn_task(self._process_quit_game_action, agent_addr)
+                case ActionType.ResetGame:
+                    self._spawn_task(self._process_reset_game_action, agent_addr)
+                case ActionType.ExfiltrateData | ActionType.FindData | ActionType.ScanNetwork | ActionType.FindServices | ActionType.ExploitService:
+                    self._spawn_task(self._process_game_action, agent_addr, action)
+                case _:
+                    self.logger.warning(f"Unsupported action type: {action}!")
+        self.logger.info("\tAction processing task stopped.")
+
+    async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None:
+        """
+        Method for processing Action of type ActionType.JoinGame
+        Inputs:
+            - agent_addr (tuple)
+            - JoinGame Action
+        Outputs: None (Method stores response in the agent's response queue)
+
+        The agent is answered with GameStatus.CREATED and its first observation,
+        or with GameStatus.BAD_REQUEST when it is already registered, asks for
+        an unknown role or cannot be registered in the world.
+        """
+        try:
+            async with self._semaphore:
+                self.logger.info(f"New Join request by {agent_addr}.")
+                # every branch below fills in exactly one answer for the agent
+                output_message_dict = {
+                    "to_agent": agent_addr,
+                    "status": str(GameStatus.BAD_REQUEST),
+                }
+                if agent_addr in self.agents:
+                    self.logger.info("\tError in registration, agent already exists!")
+                    output_message_dict["message"] = "Agent already exists."
+                else:
+                    agent_name = action.parameters["agent_info"].name
+                    agent_role = action.parameters["agent_info"].role
+                    if agent_role not in self.ALLOWED_ROLES:
+                        self.logger.info(
+                            f"\tError in registration, unknown agent role: {agent_role}!"
+                        )
+                        output_message_dict["message"] = f"Incorrect agent_role {agent_role}"
+                    else:
+                        # add agent to the world
+                        new_agent_game_state = await self.register_agent(agent_addr, agent_role, self._starting_positions_per_role[agent_role])
+                        if not new_agent_game_state:
+                            # failed registration must be answered as well, otherwise the agent waits forever
+                            self.logger.warning(f"\tError in registration, agent {agent_addr} could not be registered!")
+                            output_message_dict["message"] = "Registration failed."
+                        else:
+                            async with self._agents_lock:
+                                self.agents[agent_addr] = (agent_name, agent_role)
+                                observation = self._initialize_new_player(agent_addr, new_agent_game_state)
+                                self._agent_observations[agent_addr] = observation
+                            output_message_dict = {
+                                "to_agent": agent_addr,
+                                "status": str(GameStatus.CREATED),
+                                "observation": observation_as_dict(observation),
+                                "message": {
+                                    "message": f"Welcome {agent_name}, registred as {agent_role}",
+                                    "max_steps": self._steps_limit_per_role[agent_role],
+                                    "goal_description": self._goal_description_per_role[agent_role],
+                                    "actions": [str(a) for a in ActionType],
+                                    "configuration_hash": self._CONFIG_FILE_HASH
+                                },
+                            }
+                await self._agent_response_queues[agent_addr].put(self.convert_msg_dict_to_json(output_message_dict))
+        except asyncio.CancelledError:
+            self.logger.debug(f"Proccessing JoinAction of agent {agent_addr} interrupted")
+            raise # Ensure the exception propagates
+        finally:
+            self.logger.debug(f"Cleaning up after JoinGame for {agent_addr}.")
+    async def _process_quit_game_action(self, agent_addr: tuple)->None:
+        """Method for processing Action of type ActionType.QuitGame for `agent_addr`.
+        An agent that is not in the game (e.g. it disconnected before joining) is ignored."""
+        try:
+            # the AgentServer queues QuitGame for every closed connection, joined or not
+            if agent_addr not in self._agent_states:
+                self.logger.info(f"Agent {agent_addr} is not in the game, nothing to remove.")
+                return
+            await self.remove_agent(agent_addr, self._agent_states[agent_addr])
+            agent_info = await self._remove_agent_from_game(agent_addr)
+            self.logger.info(f"Agent {agent_addr} removed from the game. {agent_info}")
+        except asyncio.CancelledError:
+            self.logger.debug(f"Proccessing QuitAction of agent {agent_addr} interrupted")
+            raise # Ensure the exception propagates
+        finally:
+            self.logger.debug(f"Cleaning up after QuitGame for {agent_addr}.")
+
+    async def _process_reset_game_action(self, agent_addr: tuple)->None:
+        """
+        Method for processing Action of type ActionType.ResetGame: registers the reset request,
+        triggers the reset once all agents asked for it, waits until it is done and answers with
+        the new observation. An agent that has not joined is answered with GameStatus.BAD_REQUEST.
+        """
+        self.logger.debug("Beginning the _process_reset_game_action.")
+        if agent_addr not in self.agents:
+            self.logger.warning(f"Reset requested by {agent_addr} which is not in the game.")
+            response = {"to_agent": agent_addr, "status": str(GameStatus.BAD_REQUEST), "message": "Agent is not in the game."}
+            await self._agent_response_queues[agent_addr].put(self.convert_msg_dict_to_json(response))
+            return
+        async with self._reset_lock:
+            self._reset_requests[agent_addr] = True
+            if all(self._reset_requests.values()):
+                self.logger.debug("All agents requested reset, setting the event")
+                self._reset_event.set()
+        # wait until reset is done
+        async with self._reset_done_condition:
+            await self._reset_done_condition.wait()
+        async with self._agents_lock:
+            output_message_dict = {
+                "to_agent": agent_addr,
+                "status": str(GameStatus.RESET_DONE),
+                "observation": observation_as_dict(self._agent_observations[agent_addr]),
+                "message": {
+                    "message": "Resetting Game and starting again.",
+                    "max_steps": self._steps_limit_per_role[self.agents[agent_addr][1]],
+                    "goal_description": self._goal_description_per_role[self.agents[agent_addr][1]],
+                    "configuration_hash": self._CONFIG_FILE_HASH
+                },
+            }
+        await self._agent_response_queues[agent_addr].put(self.convert_msg_dict_to_json(output_message_dict))
+
+    async def _process_game_action(self, agent_addr: tuple, action:Action)->None:
+        """
+        Method for processing the game-playing actions (scan, find, exploit, exfiltrate).
+        The action is passed to `step`; the agent's state, status, reward and episode-end
+        flag are updated and the resulting observation is sent to the agent. An agent whose
+        episode just ended is answered only after the final rewards were assigned. Actions
+        played after the end of the episode are answered with GameStatus.FORBIDDEN, actions
+        of an agent that has not joined the game with GameStatus.BAD_REQUEST.
+        """
+        if agent_addr not in self._episode_ends:
+            self.logger.warning(f"Agent {agent_addr} is attempting to play action {action} without joining the game!")
+            output_message_dict = {"to_agent": agent_addr, "status": str(GameStatus.BAD_REQUEST), "message": "Agent is not in the game."}
+        elif self._episode_ends[agent_addr]:
+            self.logger.warning(f"Agent {agent_addr}({self.agents[agent_addr]}) is attempting to play action {action} after the end of the episode!")
+            # agent can't play any more actions in the game
+            end_reason = str(self._agent_status[agent_addr])
+            new_observation = Observation(
+                self._agent_observations[agent_addr].state,
+                reward=self._agent_rewards[agent_addr],
+                end=True,
+                info={'end_reason': end_reason, "info":"Episode ended. Request reset for starting new episode."})
+            output_message_dict = {
+                "to_agent": agent_addr,
+                "observation": observation_as_dict(new_observation),
+                "status": str(GameStatus.FORBIDDEN),
+            }
+        else:
+            async with self._agents_lock:
+                self._agent_last_action[agent_addr] = action
+                self._agent_steps[agent_addr] += 1
+            # wait for the new state from the world
+            new_state = await self.step(agent_id=agent_addr, agent_state=self._agent_states[agent_addr], action=action)
+            async with self._agents_lock:
+                self._agent_states[agent_addr] = new_state
+                # derive the agent's status from the new state
+                self._agent_status[agent_addr] = self._update_agent_status(agent_addr)
+                # reward for the step (other rewards are added at the end of the episode)
+                self._agent_rewards[agent_addr] = self._rewards["step"]
+                # check if the episode ends for this agent
+                self._episode_ends[agent_addr] = self._update_agent_episode_end(agent_addr)
+                # check if the episode ends for all agents
+                if all(self._episode_ends.values()):
+                    self._episode_end_event.set()
+            if self._episode_ends[agent_addr]:
+                # episode ended for this agent - wait for the others to finish
+                async with self._episode_rewards_condition:
+                    await self._episode_rewards_condition.wait()
+            # append step to the trajectory if needed
+            if self.task_config.get_store_trajectories() or self._global_defender:
+                async with self._agents_lock:
+                    self._add_step_to_trajectory(agent_addr, action, self._agent_rewards[agent_addr], new_state,end_reason=None)
+            info = {}
+            if self._agent_status[agent_addr] not in [AgentStatus.Playing, AgentStatus.PlayingWithTimeout]:
+                info["end_reason"] = str(self._agent_status[agent_addr])
+            new_observation = Observation(self._agent_states[agent_addr], self._agent_rewards[agent_addr], self._episode_ends[agent_addr], info=info)
+            self._agent_observations[agent_addr] = new_observation
+            output_message_dict = {
+                "to_agent": agent_addr,
+                "observation": observation_as_dict(new_observation),
+                "status": str(GameStatus.OK),
+            }
+        await self._agent_response_queues[agent_addr].put(self.convert_msg_dict_to_json(output_message_dict))
+
+    async def _assign_rewards_episode_end(self):
+        """
+        Task that waits until the episode is finished by all agents and assigns the final rewards:
+        a successful attacker gets "sucess", other attackers "fail"; defenders succeed if no attack did.
+        """
+        self.logger.debug("Starting task for episode end reward assigning.")
+        while not self.shutdown_flag.is_set():
+            # wait until episode is finished by all agents
+            waiters = [asyncio.create_task(self._episode_end_event.wait()), asyncio.create_task(self.shutdown_flag.wait())]
+            _, pending = await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
+            for waiter in pending:
+                waiter.cancel()  # otherwise one waiter task is leaked per episode
+            if self.shutdown_flag.is_set():
+                self.logger.debug("\tExiting reward assignment task.")
+                break
+            self.logger.info("Episode finished. Assigning final rewards to agents.")
+            async with self._agents_lock:
+                attackers = [a for a,(_, a_role) in self.agents.items() if a_role.lower() == "attacker"]
+                defenders = [a for a,(_, a_role) in self.agents.items() if a_role.lower() == "defender"]
+                successful_attack = False
+                # award attackers
+                for agent in attackers:
+                    self.logger.debug(f"Processing reward for agent {agent}")
+                    if self._agent_status[agent] is AgentStatus.Success:
+                        self._agent_rewards[agent] += self._rewards["sucess"]
+                        successful_attack = True
+                    else:
+                        self._agent_rewards[agent] += self._rewards["fail"]
+                # award defenders
+                for agent in defenders:
+                    self.logger.debug(f"Processing reward for agent {agent}")
+                    if not successful_attack:
+                        self._agent_rewards[agent] += self._rewards["sucess"]
+                        self._agent_status[agent] = AgentStatus.Success
+                    else:
+                        self._agent_rewards[agent] += self._rewards["fail"]
+                        self._agent_status[agent] = AgentStatus.Fail
+                    # TODO Add penalty for False positives
+            # clear the episode end event
+            self._episode_end_event.clear()
+            # notify all waiting agents
+            async with self._episode_rewards_condition:
+                self._episode_rewards_condition.notify_all()
+        self.logger.info("\tReward assignment task stopped.")
+
+    async def _reset_game(self):
+        """
+        Task that waits until all agents requested a reset, resets the world and every agent
+        and notifies the agents waiting on `_reset_done_condition`.
+        """
+        self.logger.debug("Starting task for game reset handelling.")
+        while not self.shutdown_flag.is_set():
+            waiters = [asyncio.create_task(self._reset_event.wait()), asyncio.create_task(self.shutdown_flag.wait())]
+            _, pending = await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
+            for waiter in pending:
+                waiter.cancel()  # otherwise one waiter task is leaked per reset
+            if self.shutdown_flag.is_set():
+                self.logger.debug("\tExiting reset_game task.")
+                break
+            self.logger.info("Resetting game to initial state.")
+            await self.reset()
+            # iterate over a snapshot: agents may join or quit while this task awaits
+            for agent in list(self.agents):
+                if self.task_config.get_store_trajectories() or self._global_defender:
+                    async with self._agents_lock:
+                        self._store_trajectory_to_file(agent)
+                self.logger.debug(f"Resetting agent {agent}")
+                new_state = await self.reset_agent(agent, self.agents[agent][1], self._agent_starting_position[agent])
+                # the observation must carry the fresh state, not the one from the finished episode
+                new_observation = Observation(new_state, 0, False, {})
+                async with self._agents_lock:
+                    self._agent_states[agent] = new_state
+                    self._agent_observations[agent] = new_observation
+                    self._episode_ends[agent] = False
+                    self._reset_requests[agent] = False
+                    self._agent_rewards[agent] = 0
+                    self._agent_steps[agent] = 0
+                    if self.agents[agent][1].lower() == "attacker":
+                        self._agent_status[agent] = AgentStatus.PlayingWithTimeout
+                    else:
+                        self._agent_status[agent] = AgentStatus.Playing
+                    if self.task_config.get_store_trajectories() or self._global_defender:
+                        self._agent_trajectories[agent] = self._reset_trajectory(agent)
+            self._reset_event.clear()
+            async with self._reset_done_condition:
+                self._reset_done_condition.notify_all()
+        self.logger.info("\tReset game task stopped.")
+
+    def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState) -> Observation:
+        """
+        Prepare the bookkeeping of an agent that has just joined the game: step counter,
+        reset and episode-end flags, starting position, state, reward, status and, when
+        trajectories are stored or the global defender is used, an empty trajectory.
+        Returns the initial observation of the agent.
+        """
+        self.logger.info(f"\tInitializing new player{agent_addr}")
+        agent_name, agent_role = self.agents[agent_addr]
+        self._agent_steps[agent_addr] = 0
+        self._reset_requests[agent_addr] = False
+        self._episode_ends[agent_addr] = False
+        self._agent_starting_position[agent_addr] = self._starting_positions_per_role[agent_role]
+        self._agent_states[agent_addr] = agent_current_state
+        self._agent_rewards[agent_addr] = 0
+        is_attacker = agent_role.lower() == "attacker"
+        self._agent_status[agent_addr] = AgentStatus.PlayingWithTimeout if is_attacker else AgentStatus.Playing
+        if self.task_config.get_store_trajectories() or self._global_defender:
+            self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
+        self.logger.info(f"\tAgent {agent_name} ({agent_addr}), registred as {agent_role}")
+        return Observation(agent_current_state, 0, False, {})
+
+    async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState:
+        """
+        Domain specific hook (must be overridden): creates the initial state of the agent; a falsy result means a failed registration.
+        """
+        raise NotImplementedError
+
+    async def remove_agent(self, agent_id:tuple, agent_state:GameState)->bool:
+        """
+        Domain specific hook (must be overridden): removes the agent from the environment; awaited when the agent quits.
+        """
+        raise NotImplementedError
+
+    async def reset_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState:
+        raise NotImplementedError  # domain specific hook (must be overridden): its result becomes the agent's state after a reset
+
+    async def _remove_agent_from_game(self, agent_addr):
+        """
+        Removes player from the game (call AFTER QuitGame was processed by the world) and returns its removed records, {} if unknown.
+        """
+        self.logger.info(f"Removing player {agent_addr} from the GameCoordinator")
+        agent_info = {}
+        async with self._agents_lock:
+            if agent_addr in self.agents:
+                agent_info["state"] = self._agent_states.pop(agent_addr)
+                agent_info["num_steps"] = self._agent_steps.pop(agent_addr)
+                agent_info["agent_status"] = self._agent_status.pop(agent_addr)
+                # a pending reset / episode end is fired only if this agent was the last one blocking it
+                async with self._reset_lock:
+                    agent_info["reset_request"] = self._reset_requests.pop(agent_addr)
+                    if not agent_info["reset_request"] and self._reset_requests and all(self._reset_requests.values()):
+                        self._reset_event.set()
+                agent_info["episode_end"] = self._episode_ends.pop(agent_addr)
+                if not agent_info["episode_end"] and self._episode_ends and all(self._episode_ends.values()):
+                    self._episode_end_event.set()
+                agent_info["end_reward"] = self._agent_rewards.pop(agent_addr, None)
+                agent_info["agent_info"] = self.agents.pop(agent_addr)
+                for records in (self._agent_observations, self._agent_starting_position, self._agent_last_action, self._agent_trajectories):
+                    records.pop(agent_addr, None)  # drop the remaining per-agent records so they do not leak
+                self.logger.debug(f"\t{agent_info}")
+            else:
+                self.logger.info(f"\t Player {agent_addr} not present in the game!")
+        return agent_info
+
+    async def step(self, agent_id:tuple, agent_state:GameState, action:Action):
+        raise NotImplementedError  # domain specific hook (must be overridden): its result is stored as the agent's new state
+
+    async def reset(self):
+        raise NotImplementedError  # domain specific hook resetting the world; NotImplemented is a value, not an exception
+
+    def goal_check(self, agent_addr:tuple)->bool:
+        """
+        Check if the goal conditions of the agent's role are satisfied in the
+        current game state of the agent.
+
+        Known networks, known hosts and controlled hosts of the goal must be
+        subsets of the state; known services, data and blocks are compared
+        per host. Returns True only if every part is satisfied.
+        """
+        def _dict_goal_met(goal_dict:dict, known_dict:dict)->bool:
+            """
+            True when every host (key) of goal_dict is present in known_dict
+            and its goal values are a subset of the known values.
+            """
+            if not goal_dict.keys() <= known_dict.keys():
+                return False
+            return all(goal_dict[host] <= known_dict[host] for host in goal_dict)
+
+        self.logger.debug(f"Checking goal for agent {agent_addr}.")
+        agent_role = self.agents[agent_addr][1]
+        goal_conditions = self._win_conditions_per_role[agent_role]
+        state = self._agent_states[agent_addr]
+        # compare each part of the state with the matching goal condition
+        goal_reached = {
+            "networks": set(goal_conditions["known_networks"]) <= set(state.known_networks),
+            "known_hosts": set(goal_conditions["known_hosts"]) <= set(state.known_hosts),
+            "controlled_hosts": set(goal_conditions["controlled_hosts"]) <= set(state.controlled_hosts),
+            "services": _dict_goal_met(goal_conditions["known_services"], state.known_services),
+            "data": _dict_goal_met(goal_conditions["known_data"], state.known_data),
+            "known_blocks": _dict_goal_met(goal_conditions["known_blocks"], state.known_blocks),
+        }
+        self.logger.debug(f"\t{goal_reached}")
+        return all(goal_reached.values())
+
+    def is_detected(self, agent:tuple)->bool:
+        """Return the verdict of the GlobalDefender on the agent's last action; False if no defender is used."""
+        if not self._global_defender:
+            # No global defender
+            return False
+        detection = self._global_defender.stochastic_with_threshold(self._agent_last_action[agent], self._agent_trajectories[agent]["trajectory"]["actions"])
+        self.logger.debug(f"Global Detection result: {detection}")
+        return detection
+
+    def is_timeout(self, agent:tuple)->bool:
+        """True when the agent's role has a (non-zero) step limit and the agent has reached it."""
+        step_limit = self._steps_limit_per_role[self.agents[agent][1]]
+        if not step_limit:
+            return False
+        return self._agent_steps[agent] >= step_limit
+
+    def _update_agent_status(self, agent:tuple)->AgentStatus:
+        """
+        Derive the status of an agent after its step: Success when the goal is reached, Fail when
+        detected by the global defender, TimeoutReached on timeout (checked in this order);
+        otherwise the current status is kept.
+        """
+        current_status = self._agent_status[agent]
+        if self.goal_check(agent):
+            self.logger.info(f"Agent {agent}{self.agents[agent]} reached the goal!")
+            return AgentStatus.Success
+        # Detection by Global Defender
+        if self.is_detected(agent):
+            self.logger.info(f"Agent {agent}{self.agents[agent]} detected by GlobalDefender!")
+            return AgentStatus.Fail
+        # Timeout reached
+        if self.is_timeout(agent):
+            self.logger.info(f"Agent {agent}{self.agents[agent]} reached timeout ({self._agent_steps[agent]} steps).")
+            return AgentStatus.TimeoutReached
+        return current_status
+
+    def _update_agent_episode_end(self, agent:tuple)->bool:
+        """
+        The episode ends for the agent when it reached a final status (Success, Fail,
+        TimeoutReached) or when no agent is playing with a timeout any more.
+        """
+        final_statuses = [AgentStatus.Success, AgentStatus.Fail, AgentStatus.TimeoutReached]
+        if self._agent_status[agent] in final_statuses:
+            # agent reached goal, timeout or was detected
+            return True
+        if any(status == AgentStatus.PlayingWithTimeout for status in self._agent_status.values()):
+            return False
+        # all attackers have finished - terminate episode
+        self.logger.info(f"Stopping episode for {agent} because the is no ACTIVE agent playing.")
+        return True
+
+    def _reset_trajectory(self, agent_addr:tuple)->dict:
+        """
+        Build a new trajectory record of the agent: its current state only, no actions, rewards or end reason.
+        """
+        agent_name, agent_role = self.agents[agent_addr]
+        self.logger.debug(f"Resetting trajectory of {agent_addr}")
+        initial_state = self._agent_states[agent_addr].as_dict
+        return {
+            "trajectory": {"states": [initial_state], "actions": [], "rewards": []},
+            "end_reason": None,
+            "agent_role": agent_role,
+            "agent_name": agent_name,
+        }
+
+    def _add_step_to_trajectory(self, agent_addr:tuple, action:Action, reward:float, next_state:GameState, end_reason:str=None)-> None:
+        """Method for adding one step to the agent trajectory; agents without a trajectory record are skipped."""
+        if agent_addr not in self._agent_trajectories:
+            return
+        self.logger.debug(f"Adding step to trajectory of {agent_addr}")
+        record = self._agent_trajectories[agent_addr]
+        record["trajectory"]["actions"].append(action.as_dict)
+        record["trajectory"]["rewards"].append(reward)
+        record["trajectory"]["states"].append(next_state.as_dict)
+        if end_reason:
+            record["end_reason"] = end_reason
+
+    def _store_trajectory_to_file(self, agent_addr:tuple, location="./trajectories")-> None:
+        self.logger.debug(f"Storing Trajectory of {agent_addr}in file")  # one JSON line is appended to <location>/<date>_<name>_<role>.jsonl
+        if agent_addr in self._agent_trajectories:
+            os.makedirs(location, exist_ok=True)  # the output directory may not exist yet
+            filename = os.path.join(location, f"{datetime.now():%Y-%m-%d}_{'_'.join(self.agents[agent_addr])}.jsonl")
+            with jsonlines.open(filename, "a") as writer:
+                writer.write(self._agent_trajectories[agent_addr])
+            self.logger.info(f"Trajectory of {agent_addr} strored in {filename}")
\ No newline at end of file
diff --git a/docs/Architecture.md b/AIDojoCoordinator/docs/Architecture.md
similarity index 100%
rename from docs/Architecture.md
rename to AIDojoCoordinator/docs/Architecture.md
diff --git a/docs/Components.md b/AIDojoCoordinator/docs/Components.md
similarity index 100%
rename from docs/Components.md
rename to AIDojoCoordinator/docs/Components.md
diff --git a/docs/Coordinator.md b/AIDojoCoordinator/docs/Coordinator.md
similarity index 83%
rename from docs/Coordinator.md
rename to AIDojoCoordinator/docs/Coordinator.md
index 6ac5aadc..bd7af93b 100644
--- a/docs/Coordinator.md
+++ b/AIDojoCoordinator/docs/Coordinator.md
@@ -12,11 +12,8 @@ Coordinator is the centerpiece of the game orchestration. It provides an interfa
## Connction to other game components
Coordinator, having the role of the middle man in all communication between the agent and the world uses several queues for massing passing and handelling.
-1. `Actions queue` is a queue in which the agents submit their actions. It provides N:1 communication channel in which the coordinator receives the inputs.
-2. `Answer queue` is a separeate queue **per agent** in which the results of the actions are send to the agent.
-3. `World action queue` is a queue used for sending the acions from coordinator to the AI Dojo world
-4. `World response queue` is a channel used for wolrd -> coordinator communicaiton (responses to the agents' action)
-
+1. `Action queue` is a queue in which the agents submit their actions. It provides an N:1 communication channel through which the coordinator receives the inputs.
+2. `Answer queues` are separate queues, one **per agent**, through which the results of the actions are sent to the agent.
## Main components of the coordinator
@@ -45,3 +42,7 @@ Coordinator, having the role of the middle man in all communication between the
`self._agent_statuses`: status of each agent. One of AgentStatus
`self._agent_rewards`: dictionary of final reward of each agent in the current episod. Only agent's which can't participate in the ongoing episode are listed.
`self._agent_trajectories`: complete trajectories for each agent in the ongoing episode
+
+
+## Episode
+The episode starts once a sufficient number of agents have registered in the game. Each agent role has a maximum allowed number of steps defined in the task configuration. An episode ends if all agents reach the goal.
\ No newline at end of file
diff --git a/docs/Trajectory_analysis.md b/AIDojoCoordinator/docs/Trajectory_analysis.md
similarity index 100%
rename from docs/Trajectory_analysis.md
rename to AIDojoCoordinator/docs/Trajectory_analysis.md
diff --git a/docs/figures/architecture_diagram.jpg b/AIDojoCoordinator/docs/figures/architecture_diagram.jpg
similarity index 100%
rename from docs/figures/architecture_diagram.jpg
rename to AIDojoCoordinator/docs/figures/architecture_diagram.jpg
diff --git a/docs/figures/message_passing_coordinator.jpg b/AIDojoCoordinator/docs/figures/message_passing_coordinator.jpg
similarity index 100%
rename from docs/figures/message_passing_coordinator.jpg
rename to AIDojoCoordinator/docs/figures/message_passing_coordinator.jpg
diff --git a/docs/figures/scenarios/scenario 1_small.png b/AIDojoCoordinator/docs/figures/scenarios/scenario 1_small.png
similarity index 100%
rename from docs/figures/scenarios/scenario 1_small.png
rename to AIDojoCoordinator/docs/figures/scenarios/scenario 1_small.png
diff --git a/docs/figures/scenarios/scenario_1.png b/AIDojoCoordinator/docs/figures/scenarios/scenario_1.png
similarity index 100%
rename from docs/figures/scenarios/scenario_1.png
rename to AIDojoCoordinator/docs/figures/scenarios/scenario_1.png
diff --git a/docs/figures/scenarios/scenario_1_tiny.png b/AIDojoCoordinator/docs/figures/scenarios/scenario_1_tiny.png
similarity index 100%
rename from docs/figures/scenarios/scenario_1_tiny.png
rename to AIDojoCoordinator/docs/figures/scenarios/scenario_1_tiny.png
diff --git a/docs/figures/scenarios/three_nets.png b/AIDojoCoordinator/docs/figures/scenarios/three_nets.png
similarity index 100%
rename from docs/figures/scenarios/three_nets.png
rename to AIDojoCoordinator/docs/figures/scenarios/three_nets.png
diff --git a/env/game_components.py b/AIDojoCoordinator/game_components.py
similarity index 60%
rename from env/game_components.py
rename to AIDojoCoordinator/game_components.py
index 1c1a8c05..b82f94aa 100755
--- a/env/game_components.py
+++ b/AIDojoCoordinator/game_components.py
@@ -1,6 +1,7 @@
# Author Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz
# Library of helpful functions and objects to play the net sec game
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, asdict
+from typing import Dict, Any
import dataclasses
from collections import namedtuple
import json
@@ -16,9 +17,13 @@ class Service():
Service represents the service object in the NetSecGame
"""
name: str
- type: str
- version: str
- is_local: bool
+ type: str = "unknown"
+ version: str = "unknown"
+ is_local: bool = True
+
+ @classmethod
+ def from_dict(cls, data: dict):
+ return cls(**data)
"""
IP represents the ip address object in the NetSecGame
@@ -42,6 +47,11 @@ def __post_init__(self):
def __repr__(self):
return self.ip
+ def __eq__(self, other):
+ if not isinstance(other, IP):
+ return NotImplemented
+ return self.ip == other.ip
+
def is_private(self):
"""
Return if the IP is private or not
@@ -54,6 +64,12 @@ def is_private(self):
if self.ip != 'external':
return True
return False
+ @classmethod
+ def from_dict(cls, data: dict):
+ return cls(**data)
+
+ def __hash__(self):
+ return hash(self.ip)
@dataclass(frozen=True, eq=True)
class Network():
@@ -96,6 +112,10 @@ def is_private(self):
except ipaddress.AddressValueError:
# If we are dealing with strings, assume they are local networks
return True
+
+ @classmethod
+ def from_dict(cls, data: dict):
+ return cls(**data)
"""
Data represents the data object in the NetSecGame
@@ -114,66 +134,48 @@ class Data():
def __hash__(self) -> int:
return hash((self.owner, self.id, self.size, self.type))
+ @classmethod
+ def from_dict(cls, data: dict):
+ return cls(**data)
+
@enum.unique
class ActionType(enum.Enum):
- """
- ActionType represents generic action for attacker in the game. Each transition has a default probability
- of success and probability of detection (if the defender is present).
- Currently 5 action types are implemented:
- - ScanNetwork
- - FindServices
- - FindData
- - ExploitService
- - ExfiltrateData
- - BlockIP
- - JoinGame
- - QuitGame
- """
-
- #override the __new__ method to enable multiple parameters
- def __new__(cls, *args, **kwargs):
- value = len(cls.__members__) + 1
- obj = object.__new__(cls)
- obj._value_ = value
- return obj
+ ScanNetwork = "ScanNetwork"
+ FindServices = "FindServices"
+ FindData = "FindData"
+ ExploitService = "ExploitService"
+ ExfiltrateData = "ExfiltrateData"
+ BlockIP = "BlockIP"
+ JoinGame = "JoinGame"
+ QuitGame = "QuitGame"
+ ResetGame = "ResetGame"
+
+ def to_string(self):
+ """Convert enum to string."""
+ return self.value
+
+ def __eq__(self, other):
+ # Compare with another ActionType
+ if isinstance(other, ActionType):
+ return self.value == other.value
+ # Compare with a string
+ elif isinstance(other, str):
+ return self.value == other
+ return False
- def __init__(self, default_success_p: float):
- self.default_success_p = default_success_p
+ def __hash__(self):
+ # Use the hash of the value for consistent behavior
+ return hash(self.value)
@classmethod
- def from_string(cls, string:str):
- match string:
- case "ActionType.ExploitService":
- return ActionType.ExploitService
- case "ActionType.ScanNetwork":
- return ActionType.ScanNetwork
- case "ActionType.FindServices":
- return ActionType.FindServices
- case "ActionType.FindData":
- return ActionType.FindData
- case "ActionType.ExfiltrateData":
- return ActionType.ExfiltrateData
- case "ActionType.BlockIP":
- return ActionType.BlockIP
- case "ActionType.JoinGame":
- return ActionType.JoinGame
- case "ActionType.ResetGame":
- return ActionType.ResetGame
- case "ActionType.QuitGame":
- return ActionType.QuitGame
- case _:
- raise ValueError("Uknown Action Type")
-
- #ActionTypes
- ScanNetwork = 0.9
- FindServices = 0.9
- FindData = 0.8
- ExploitService = 0.7
- ExfiltrateData = 0.8
- BlockIP = 1
- JoinGame = 1
- QuitGame = 1
- ResetGame = 1
+ def from_string(cls, name):
+ """Convert string to enum, stripping 'ActionType.' if present."""
+ if name.startswith("ActionType."):
+ name = name.split("ActionType.")[1]
+ try:
+ return cls[name]
+ except KeyError:
+ raise ValueError(f"Invalid ActionType: {name}")
@dataclass(frozen=True, eq=True, order=True)
class AgentInfo():
@@ -186,132 +188,83 @@ class AgentInfo():
def __repr__(self):
return f"{self.name}({self.role})"
-#Actions
-class Action():
+
+ @classmethod
+ def from_dict(cls, data: dict):
+ return cls(**data)
+
+@dataclass(frozen=True)
+class Action:
"""
- Actions are composed of the action type (see ActionTupe) and additional parameters listed in dictionary
- - ScanNetwork {"target_network": Network object, "source_host": IP object}
- - FindServices {"target_host": IP object, "source_host": IP object,}
- - FindData {"target_host": IP object, "source_host": IP object}
- - ExploitService {"target_host": IP object, "target_service": Service object, "source_host": IP object}
- - ExfiltrateData {"target_host": IP object, "source_host": IP object, "data": Data object}
- - BlockIP("target_host": IP object, "source_host": IP object, "blocked_host": IP object)
+ Immutable dataclass representing an Action.
"""
- def __init__(self, action_type: ActionType, params: dict={}) -> None:
- self._type = action_type
- self._parameters = params
+ action_type: ActionType
+ parameters: Dict[str, Any] = field(default_factory=dict)
@property
- def type(self) -> ActionType:
- return self._type
- @property
- def parameters(self)->dict:
- return self._parameters
-
- @property
- def as_dict(self)->dict:
+ def as_dict(self) -> Dict[str, Any]:
+ """Return a dictionary representation of the Action."""
params = {}
- for k,v in self.parameters.items():
- if isinstance(v, Service):
- params[k] = vars(v)
- elif isinstance(v, Data):
- params[k] = vars(v)
- elif isinstance(v, AgentInfo):
- params[k] = vars(v)
+ for k, v in self.parameters.items():
+ if hasattr(v, '__dict__'): # Handle custom objects like Service, Data, AgentInfo
+ params[k] = asdict(v)
else:
params[k] = str(v)
- return {"type": str(self.type), "params": params}
+ return {"action_type": str(self.action_type), "parameters": params}
+ @property
+ def type(self):
+ return self.action_type
+
+ def to_json(self) -> str:
+ """Serialize the Action to a JSON string."""
+ return json.dumps(self.as_dict)
+
@classmethod
- def from_dict(cls, data_dict:dict):
- action_type = ActionType.from_string(data_dict["type"])
+ def from_dict(cls, data_dict: Dict[str, Any]) -> "Action":
+ """Create an Action from a dictionary."""
+ action_type = ActionType.from_string(data_dict["action_type"])
params = {}
- for k,v in data_dict["params"].items():
+ for k, v in data_dict["parameters"].items():
match k:
- case "source_host":
- params[k] = IP(v)
- case "target_host":
- params[k] = IP(v)
- case "blocked_host":
- params[k] = IP(v)
+ case "source_host" | "target_host" | "blocked_host":
+ params[k] = IP.from_dict(v)
case "target_network":
- net,mask = v.split("/")
- params[k] = Network(net ,int(mask))
+ params[k] = Network.from_dict(v)
case "target_service":
- params[k] = Service(**v)
+ params[k] = Service.from_dict(v)
case "data":
- params[k] = Data(**v)
+ params[k] = Data.from_dict(v)
case "agent_info":
- params[k] = AgentInfo(**v)
+ params[k] = AgentInfo.from_dict(v)
case _:
- raise ValueError(f"Unsupported Value in {k}:{v}")
- action = Action(action_type=action_type, params=params)
- return action
-
- def __repr__(self) -> str:
- return f"Action <{self._type}|{self._parameters}>"
+ raise ValueError(f"Unsupported value in {k}: {v}")
+ return cls(action_type=action_type, parameters=params)
- def __str__(self) -> str:
- return f"Action <{self._type}|{self._parameters}>"
+ @classmethod
+ def from_json(cls, json_string: str) -> "Action":
+ """Create an Action from a JSON string."""
+ data_dict = json.loads(json_string)
+ return cls.from_dict(data_dict)
- def __eq__(self, __o: object) -> bool:
- if isinstance(__o, Action):
- return self._type == __o.type and self.parameters == __o.parameters
- return False
+ def __repr__(self) -> str:
+ return f"Action <{self.action_type}|{self.parameters}>"
- def __hash__(self) -> int:
- sorted_params = sorted(self._parameters.items(), key= lambda x: x[0])
- sorted_params = [f"{x}{str(y)}" for x,y in sorted_params]
- return hash(self._type) + hash("".join(sorted_params))
-
- def as_json(self)->str:
- ret_dict = {"action_type":str(self.type)}
- ret_dict["parameters"] = {k:dataclasses.asdict(v) for k,v in self.parameters.items()}
- return json.dumps(ret_dict)
+ def __str__(self) -> str:
+ return f"Action <{self.action_type}|{self.parameters}>"
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, Action):
+ return NotImplemented
+ return (
+ self.action_type == other.action_type and
+ self.parameters == other.parameters
+ )
-
-
- @classmethod
- def from_json(cls, json_string:str):
- """
- Classmethod to ccreate Action object from json string representation
- """
- parameters_dict = json.loads(json_string)
- action_type = ActionType.from_string(parameters_dict["action_type"])
- parameters = {}
- parameters_dict = parameters_dict["parameters"]
- match action_type:
- case ActionType.ScanNetwork:
- parameters = {"source_host": IP(parameters_dict["source_host"]["ip"]),"target_network": Network(parameters_dict["target_network"]["ip"], parameters_dict["target_network"]["mask"])}
- case ActionType.FindServices:
- parameters = {"source_host": IP(parameters_dict["source_host"]["ip"]), "target_host": IP(parameters_dict["target_host"]["ip"])}
- case ActionType.FindData:
- parameters = {"source_host": IP(parameters_dict["source_host"]["ip"]), "target_host": IP(parameters_dict["target_host"]["ip"])}
- case ActionType.ExploitService:
- parameters = {"target_host": IP(parameters_dict["target_host"]["ip"]),
- "target_service": Service(parameters_dict["target_service"]["name"],
- parameters_dict["target_service"]["type"],
- parameters_dict["target_service"]["version"],
- parameters_dict["target_service"]["is_local"]),
- "source_host": IP(parameters_dict["source_host"]["ip"])}
- case ActionType.ExfiltrateData:
- parameters = {"target_host": IP(parameters_dict["target_host"]["ip"]),
- "source_host": IP(parameters_dict["source_host"]["ip"]),
- "data": Data(parameters_dict["data"]["owner"],parameters_dict["data"]["id"])}
- case ActionType.BlockIP:
- parameters = {"target_host": IP(parameters_dict["target_host"]["ip"]),
- "source_host": IP(parameters_dict["source_host"]["ip"]),
- "blocked_host": IP(parameters_dict["blocked_host"]["ip"])}
- case ActionType.JoinGame:
- parameters = {"agent_info":AgentInfo(parameters_dict["agent_info"]["name"], parameters_dict["agent_info"]["role"])}
- case ActionType.QuitGame:
- parameters = {}
- case ActionType.ResetGame:
- parameters = {}
- case _:
- raise ValueError(f"Unknown Action type:{action_type}")
- action = Action(action_type=action_type, params=parameters)
- return action
+ def __hash__(self) -> int:
+ # Convert parameters to a sorted tuple of key-value pairs for consistency
+ sorted_params = tuple(sorted((k, hash(v)) for k, v in self.parameters.items()))
+ return hash((self.action_type, sorted_params))
@dataclass(frozen=True)
class GameState():
@@ -481,14 +434,40 @@ def from_string(cls, string:str):
return GameStatus.RESET_DONE
def __repr__(self) -> str:
return str(self)
-if __name__ == "__main__":
- pass
- #data1 = Data(owner="test", id="test_data", content="content", type="db")
- #data2 = Data(owner="test", id="test_data", content="content", type="db")
- # print(data)
- # print(data.size)
-
- # s = set()
- # s.add(data)
- # s.add( Data("test", "test_data", content="new_content", type="db"))
- # print(s)
\ No newline at end of file
+
+
+@enum.unique
+class AgentStatus(enum.Enum):
+    """
+    Status of an agent in the game.
+
+    Members compare equal both to the same AgentStatus member and to the
+    plain string holding their value.
+    """
+    Playing = "Playing"
+    PlayingWithTimeout = "PlayingWithTimeout"
+    TimeoutReached = "TimeoutReached"
+    ResetRequested = "ResetRequested"
+    Success = "Success"
+    Fail = "Fail"
+
+    def to_string(self):
+        """Return the string value of the status."""
+        return self.value
+
+    def __eq__(self, other):
+        # Reduce another AgentStatus to its value, then compare as strings
+        if isinstance(other, AgentStatus):
+            other = other.value
+        elif not isinstance(other, str):
+            # Neither an AgentStatus nor a string
+            return False
+        return self.value == other
+
+    def __hash__(self):
+        # Hash of the value, so the hash agrees with the string-aware __eq__
+        return hash(self.value)
+
+    @classmethod
+    def from_string(cls, name):
+        """
+        Convert string to enum, stripping 'AgentStatus.' if present.
+
+        Raises ValueError if the resulting name is not a member.
+        """
+        prefix = "AgentStatus."
+        key = name.split(prefix)[1] if name.startswith(prefix) else name
+        try:
+            return cls[key]
+        except KeyError:
+            raise ValueError(f"Invalid AgentStatus: {key}")
\ No newline at end of file
diff --git a/AIDojoCoordinator/global_defender.py b/AIDojoCoordinator/global_defender.py
new file mode 100644
index 00000000..5937ce58
--- /dev/null
+++ b/AIDojoCoordinator/global_defender.py
@@ -0,0 +1,89 @@
+# Author: Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz
+from itertools import groupby
+from AIDojoCoordinator.game_components import ActionType, Action
+from random import random
+
+
+class GlobalDefender:
+    """
+    Defender that decides stochastically whether an agent's action is detected.
+
+    Detection probabilities are defined per ActionType. `stochastic_with_threshold`
+    additionally requires threshold conditions over the agent's recent actions
+    to be met before the random detection check is made.
+    """
+
+    def __init__(self):
+
+        # The probability of detecting an action is defined by the following dictionary
+        self._DEFAULT_DETECTION_PROBS = {
+            ActionType.ScanNetwork: 0.05,
+            ActionType.FindServices: 0.075,
+            ActionType.ExploitService: 0.1,
+            ActionType.FindData: 0.025,
+            ActionType.ExfiltrateData: 0.025,
+            ActionType.BlockIP: 0.01
+        }
+
+
+        # Ratios of action types in the time window (TW) for each action type. The ratio should be higher than the defined value to trigger a detection check
+        self._TW_TYPE_RATIOS_THRESHOLD = {
+            ActionType.ScanNetwork: 0.25,
+            ActionType.FindServices: 0.3,
+            ActionType.ExploitService: 0.25,
+            ActionType.FindData: 0.5,
+            ActionType.ExfiltrateData: 0.25,
+            ActionType.BlockIP: 1
+        }
+
+        # Thresholds for consecutive actions of the same type in the TW. Only if the threshold is crossed, the detection check is triggered
+        self._TW_CONSECUTIVE_TYPE_THRESHOLD = {
+            ActionType.ScanNetwork: 2,
+            ActionType.FindServices: 3,
+            ActionType.ExfiltrateData: 2
+        }
+
+        # Thresholds for repeated actions in the episode. Only if the threshold is crossed, the detection check is triggered
+        self._EPISODE_REPEATED_ACTION_THRESHOLD = {
+            ActionType.ExploitService: 2,
+            ActionType.FindData: 2,
+        }
+
+    def stochastic(self, action_type:ActionType)->bool:
+        """
+        Simple random detection based on predefined probability and ActionType.
+
+        Returns True when a random roll falls below the detection probability
+        configured for `action_type`. Raises KeyError for action types that
+        have no configured detection probability.
+        """
+        roll = random()
+        return roll < self._DEFAULT_DETECTION_PROBS[action_type]
+
+    def stochastic_with_threshold(self, action: Action, episode_actions:list, tw_size:int=5)-> bool:
+        """
+        Only detect based on set probabilities if pre-defined thresholds are crossed.
+
+        `action` is the action being evaluated, `episode_actions` is the list of
+        dictionary forms (`Action.as_dict`) of the actions played so far, and
+        `tw_size` is the size of the time window (TW) taken from the end of the
+        episode. Returns False while the episode, including `action`, is shorter
+        than the TW, and for action types that have no threshold defined.
+        """
+        # Extend the episode with the latest action.
+        # Work on a copy so that the list passed by the caller is not modified.
+        temp_episode_actions = episode_actions.copy()
+        temp_episode_actions.append(action.as_dict)
+        if len(temp_episode_actions) < tw_size:
+            return False
+
+        action_type_str = str(action.type)
+        last_n_actions = temp_episode_actions[-tw_size:]
+        # Dictionaries produced by Action.as_dict carry the type under the "action_type" key
+        last_n_action_types = [played["action_type"] for played in last_n_actions]
+        # compute ratio of action type in the TW
+        tw_ratio = last_n_action_types.count(action_type_str)/tw_size
+        # Count how many times this exact (parametrized) action was played in episode
+        repeats_in_episode = temp_episode_actions.count(action.as_dict)
+        # compute the highest consecutive number of action type in TW
+        max_consecutive_action_type = max(
+            sum(1 for item in grouped if item == action_type_str)
+            for _, grouped in groupby(last_n_action_types)
+        )
+
+        if action.type in self._TW_CONSECUTIVE_TYPE_THRESHOLD:
+            # ScanNetwork, FindServices, ExfiltrateData
+            if tw_ratio < self._TW_TYPE_RATIOS_THRESHOLD[action.type] and max_consecutive_action_type < self._TW_CONSECUTIVE_TYPE_THRESHOLD[action.type]:
+                return False
+            return self.stochastic(action.type)
+        if action.type in self._EPISODE_REPEATED_ACTION_THRESHOLD:
+            # FindData, ExploitService
+            if tw_ratio < self._TW_TYPE_RATIOS_THRESHOLD[action.type] and repeats_in_episode < self._EPISODE_REPEATED_ACTION_THRESHOLD[action.type]:
+                return False
+            return self.stochastic(action.type)
+        # Other actions - Do not detect
+        return False
diff --git a/env/netsecenv_conf.yaml b/AIDojoCoordinator/netsecenv_conf.yaml
similarity index 98%
rename from env/netsecenv_conf.yaml
rename to AIDojoCoordinator/netsecenv_conf.yaml
index e24a5424..ba93ae9c 100644
--- a/env/netsecenv_conf.yaml
+++ b/AIDojoCoordinator/netsecenv_conf.yaml
@@ -4,7 +4,7 @@ coordinator:
# Define the conf for all the 'attacker' agents.
agents:
Attacker:
- max_steps: 50
+ max_steps: 20
goal:
description: "Exfiltrate data to '213.47.23.195'"
is_any_part_of_goal_random: True
@@ -102,9 +102,10 @@ env:
use_dynamic_addresses: False
use_firewall: True
save_trajectories: False
- goal_reward: 100
- detection_reward: -5
- step_reward: -1
+ rewards:
+ win: 100
+ step: -1
+ loss: -10
actions:
scan_network:
prob_success: 1.0
diff --git a/AIDojoCoordinator/netsecevn_conf_cyst_integration.yaml b/AIDojoCoordinator/netsecevn_conf_cyst_integration.yaml
new file mode 100644
index 00000000..92217768
--- /dev/null
+++ b/AIDojoCoordinator/netsecevn_conf_cyst_integration.yaml
@@ -0,0 +1,120 @@
+# Configuration file for the NetSecGame environment
+
+coordinator:
+ # Define the conf for all the 'attacker' agents.
+ agents:
+ Attacker:
+ max_steps: 10
+ goal:
+ description: "Take control of the host '192.168.0.3'"
+ is_any_part_of_goal_random: True
+ known_networks: []
+ #known_networks: [192.168.1.0/24, 192.168.3.0/24]
+ known_hosts: []
+ #known_hosts: [192.168.1.1, 192.168.1.2]
+ controlled_hosts: ['192.168.0.3']
+ #controlled_hosts: [213.47.23.195, 192.168.1.3]
+ # Services are defined as a target host where the service must be, and then a description in the form 'name,type,version,is_local'
+ known_services: {}
+ #known_services: {192.168.1.3: [Local system, lanman server, 10.0.19041, False], 192.168.1.4: [Other system, SMB server, 21.2.39421, False]}
+ # In data, put the target host that must have the data and which data in format user,data
+ # Example to fix the data in one host
+ known_data: {}
+ # Example to fix two data in one host
+ #known_data: {213.47.23.195: [[User1,DataFromServer1], [User5,DataFromServer5]]}
+ # Example to fix the data in two host
+ #known_data: {213.47.23.195: [User1,DataFromServer1], 192.168.3.1: [User3,Data3FromServer3]}
+ # Example to ask a random data in a specific server. Putting 'random' in the data, forces the env to randomly choose where the goal data is
+ # known_data: {213.47.23.195: [random]}
+ known_blocks: {}
+ # Example of known blocks. In the host 192.168.2.2, block all connections coming or going to 192.168.1.3
+ # known_blocks: {192.168.2.2: {192.168.1.3}}
+ start_position:
+ known_networks: []
+ known_hosts: []
+ # The attacker must always at least control the CC if the goal is to exfiltrate there
+ # Example of fixing the starting point of the agent in a local host
+ controlled_hosts: [random]
+ # Example of asking a random position to start the agent
+ # controlled_hosts: [213.47.23.195, random]
+ # Services are defined as a target host where the service must be, and then a description in the form 'name,type,version,is_local'
+ known_services: {}
+ # known_services: {192.168.1.3: [Local system, lanman server, 10.0.19041, False], 192.168.1.4: [Other system, SMB server, 21.2.39421, False]}
+ # Same format as before
+ known_data: {}
+ known_blocks: {}
+ # Example of known blocks to start with. In the host 192.168.2.2, block all connections coming or going to 192.168.1.3
+ # known_blocks: {192.168.2.2: {192.168.1.3}}
+
+ Defender:
+ goal:
+ description: "Block all attackers"
+ is_any_part_of_goal_random: False
+ known_networks: []
+ # Example
+ #known_networks: [192.168.1.0/24, 192.168.3.0/24]
+ known_hosts: []
+ # Example
+ #known_hosts: [192.168.1.1, 192.168.1.2]
+ controlled_hosts: []
+ # Example
+ #controlled_hosts: [213.47.23.195, 192.168.1.3]
+ # Services are defined as a target host where the service must be, and then a description in the form 'name,type,version,is_local'
+ known_services: {}
+ # Example
+ #known_services: {192.168.1.3: [Local system, lanman server, 10.0.19041, False], 192.168.1.4: [Other system, SMB server, 21.2.39421, False]}
+ # In data, put the target host that must have the data and which data in format user,data
+ # Example to fix the data in one host
+ known_data: {}
+ # Example to fix two data in one host
+ #known_data: {213.47.23.195: [[User1,DataFromServer1], [User5,DataFromServer5]]}
+ # Example to fix the data in two host
+ #known_data: {213.47.23.195: [User1,DataFromServer1], 192.168.3.1: [User3,Data3FromServer3]}
+ # Example to ask a random data in a specific server. Putting 'random' in the data, forces the env to randomly choose where the goal data is
+ # known_data: {213.47.23.195: [random]}
+ known_blocks: {192.168.0.3: 'all_attackers'}
+ # Example of known blocks. In the host 192.168.2.2, block all connections coming or going to 192.168.1.3
+ # known_blocks: {192.168.2.2: {192.168.1.3}}
+ # You can also use the wildcard string 'all_routers', and 'all_attackers', to mean that all the controlled hosts of all the attackers should be in this list in order to win
+
+ start_position:
+ # should be empty for defender - will be extracted from controlled hosts
+ known_networks: []
+ # should be empty for defender - will be extracted from controlled hosts
+ known_hosts: []
+ # list of controlled hosts, wildcard "all_local" can be used to include all local IPs
+ controlled_hosts: [all_local]
+ known_services: {}
+ known_data: {}
+ # Blocked IPs
+ blocked_ips: {}
+ known_blocks: {}
+ # Example of known blocks to start with. In the host 192.168.2.2, block all connections coming or going to 192.168.1.3
+ # known_blocks: {192.168.2.2: {192.168.1.3}}
+
+env:
+ # random means to choose the seed in a random way, so it is not fixed
+ random_seed: 'random'
+ # Or you can fix the seed
+ # random_seed: 42
+ scenario: 'scenario1_tiny'
+ use_global_defender: False
+ use_dynamic_addresses: False
+ use_firewall: True
+ save_trajectories: False
+ goal_reward: 100
+ detection_reward: -5
+ step_reward: -1
+ actions:
+ scan_network:
+ prob_success: 1.0
+ find_services:
+ prob_success: 1.0
+ exploit_service:
+ prob_success: 1.0
+ find_data:
+ prob_success: 1.0
+ exfiltrate_data:
+ prob_success: 1.0
+ block_ip:
+ prob_success: 1.0
\ No newline at end of file
diff --git a/env/scenarios/__init__.py b/AIDojoCoordinator/scenarios/__init__.py
similarity index 100%
rename from env/scenarios/__init__.py
rename to AIDojoCoordinator/scenarios/__init__.py
diff --git a/env/scenarios/scenario_configuration.py b/AIDojoCoordinator/scenarios/scenario_configuration.py
similarity index 96%
rename from env/scenarios/scenario_configuration.py
rename to AIDojoCoordinator/scenarios/scenario_configuration.py
index e2dbdcc6..d8fedb80 100644
--- a/env/scenarios/scenario_configuration.py
+++ b/AIDojoCoordinator/scenarios/scenario_configuration.py
@@ -1,9 +1,9 @@
# This file defines the hosts and their characteristics, the services they run, the users they have and their security levels, the data they have, and in the router/FW all the rules of which host can access what
import cyst.api.configuration as cyst_cfg
-#from cyst.api.configuration import *
from cyst.api.configuration.network.elements import RouteConfig
from cyst.api.logic.access import AuthenticationProviderType, AuthenticationTokenType, AuthenticationTokenSecurity
-
+from cyst.api.configuration import ExploitConfig, VulnerableServiceConfig
+from cyst.api.logic.exploit import ExploitLocality, ExploitCategory
''' --------------------------------------------------------------------------------------------------------------------
A template for local password authentication.
'''
@@ -28,7 +28,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="microsoft-ds",
+ name="microsoft-ds",
owner="Local system",
version="10.0.19041",
local=False,
@@ -66,7 +66,7 @@
]
),
cyst_cfg.PassiveServiceConfig(
- type="ms-wbt-server",
+ name="ms-wbt-server",
owner="Local system",
version="10.0.19041",
local=False,
@@ -94,7 +94,7 @@
]
),
cyst_cfg.PassiveServiceConfig(
- type="windows login",
+ name="windows login",
owner="Administrator",
version="10.0.19041",
local=True,
@@ -102,7 +102,7 @@
authentication_providers=[local_password_auth("windows login")]
),
cyst_cfg.PassiveServiceConfig(
- type="powershell",
+ name="powershell",
owner="Local system",
version="10.0.19041",
local=True,
@@ -130,7 +130,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -156,7 +156,7 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="postgresql",
+ name="postgresql",
owner="postgresql",
version="14.3.0",
private_data=[
@@ -168,7 +168,7 @@
access_level=cyst_cfg.AccessLevel.LIMITED
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -195,7 +195,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="http",
+ name="http",
owner="lighttpd",
version="1.4.54",
local=False,
@@ -211,7 +211,7 @@
access_schemes=[]
),
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -237,7 +237,7 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -261,7 +261,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -282,7 +282,7 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -306,7 +306,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -327,7 +327,7 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -362,7 +362,7 @@
],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ms-wbt-server",
+ name="ms-wbt-server",
owner="Local system",
version="10.0.19041",
local=False,
@@ -386,14 +386,14 @@
]
),
cyst_cfg.PassiveServiceConfig(
- type="powershell",
+ name="powershell",
owner="Local system",
version="10.0.19041",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -419,7 +419,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ms-wbt-server",
+ name="ms-wbt-server",
owner="Local system",
version="10.0.19041",
local=False,
@@ -443,14 +443,14 @@
]
),
cyst_cfg.PassiveServiceConfig(
- type="powershell",
+ name="powershell",
owner="Local system",
version="10.0.19041",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -476,7 +476,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -498,14 +498,14 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -531,7 +531,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -553,14 +553,14 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -586,7 +586,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -604,7 +604,7 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -703,14 +703,14 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED
),
cyst_cfg.PassiveServiceConfig(
- type="listener",
+ name="listener",
owner="attacker",
version="1.0.0",
local=False,
@@ -746,16 +746,16 @@
- There exists only one for windows lanman server (SMB) and enables data exfiltration. Add others as needed...
'''
exploits = [
- cyst_cfg.ExploitConfig(
+ ExploitConfig(
services=[
- cyst_cfg.VulnerableServiceConfig(
- name="microsoft-ds",
- min_version="10.0.19041",
+ VulnerableServiceConfig(
+ service="microsoft-ds",
+ min_version="10.0.19041", # min == max: only this exact version is exploitable
max_version="10.0.19041"
)
],
- locality=cyst_cfg.ExploitLocality.REMOTE,
- category=cyst_cfg.ExploitCategory.DATA_MANIPULATION,
+ locality=ExploitLocality.REMOTE,
+ category=ExploitCategory.DATA_MANIPULATION,
id="smb_exploit"
)
]
diff --git a/env/scenarios/smaller_scenario_configuration.py b/AIDojoCoordinator/scenarios/smaller_scenario_configuration.py
similarity index 95%
rename from env/scenarios/smaller_scenario_configuration.py
rename to AIDojoCoordinator/scenarios/smaller_scenario_configuration.py
index a4f6d725..215b4aa7 100644
--- a/env/scenarios/smaller_scenario_configuration.py
+++ b/AIDojoCoordinator/scenarios/smaller_scenario_configuration.py
@@ -1,8 +1,10 @@
# This file defines the hosts and their characteristics, the services they run, the users they have and their security levels, the data they have, and in the router/FW all the rules of which host can access what
import cyst.api.configuration as cyst_cfg
+#from cyst.api.configuration import *
from cyst.api.configuration.network.elements import RouteConfig
from cyst.api.logic.access import AuthenticationProviderType, AuthenticationTokenType, AuthenticationTokenSecurity
-
+from cyst.api.configuration import ExploitConfig, VulnerableServiceConfig
+from cyst.api.logic.exploit import ExploitLocality, ExploitCategory
''' --------------------------------------------------------------------------------------------------------------------
A template for local password authentication.
'''
@@ -21,12 +23,13 @@
- the only windows server. It does not connect to the AD
- access schemes for remote desktop and file sharing are kept separate, but can be integrated into one if needed
+- Service types should be derived from nmap services https://svn.nmap.org/nmap/nmap-services
'''
smb_server = cyst_cfg.NodeConfig(
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="lanman server",
+ name="microsoft-ds",
owner="Local system",
version="10.0.19041",
local=False,
@@ -64,7 +67,7 @@
]
),
cyst_cfg.PassiveServiceConfig(
- type="remote desktop service",
+ name="ms-wbt-server",
owner="Local system",
version="10.0.19041",
local=False,
@@ -92,7 +95,7 @@
]
),
cyst_cfg.PassiveServiceConfig(
- type="windows login",
+ name="windows login",
owner="Administrator",
version="10.0.19041",
local=True,
@@ -100,7 +103,7 @@
authentication_providers=[local_password_auth("windows login")]
),
cyst_cfg.PassiveServiceConfig(
- type="powershell",
+ name="powershell",
owner="Local system",
version="10.0.19041",
local=True,
@@ -128,7 +131,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="openssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -154,7 +157,7 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="postgresql",
+ name="postgresql",
owner="postgresql",
version="14.3.0",
private_data=[
@@ -166,7 +169,7 @@
access_level=cyst_cfg.AccessLevel.LIMITED
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -193,7 +196,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="lighttpd",
+ name="http",
owner="lighttpd",
version="1.4.54",
local=False,
@@ -209,7 +212,7 @@
access_schemes=[]
),
cyst_cfg.PassiveServiceConfig(
- type="openssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -235,7 +238,7 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -259,7 +262,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="openssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -280,7 +283,7 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -304,7 +307,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="openssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -325,7 +328,7 @@
)]
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -360,7 +363,7 @@
],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="remote desktop service",
+ name="ms-wbt-server",
owner="Local system",
version="10.0.19041",
local=False,
@@ -384,14 +387,14 @@
]
),
cyst_cfg.PassiveServiceConfig(
- type="powershell",
+ name="powershell",
owner="Local system",
version="10.0.19041",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -490,14 +493,14 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED
),
cyst_cfg.PassiveServiceConfig(
- type="listener",
+ name="listener",
owner="attacker",
version="1.0.0",
local=False,
@@ -505,7 +508,7 @@
)
],
traffic_processors=[],
- interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("213.47.23.195"), cyst_cfg.IPNetwork("213.47.23.195/26"))],
+ interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("213.47.23.195"), cyst_cfg.IPNetwork("213.47.23.192/26"))],
shell="bash",
id="outside_node"
)
@@ -520,7 +523,7 @@
cyst_cfg.ConnectionConfig("other_server_1", 0, "router1", 3),
cyst_cfg.ConnectionConfig("other_server_2", 0, "router1", 4),
cyst_cfg.ConnectionConfig("client_1", 0, "router1", 5),
- cyst_cfg.ConnectionConfig("internet", 0, "router1", 6),
+ cyst_cfg.ConnectionConfig("internet", 0, "router1", 10),
cyst_cfg.ConnectionConfig("internet", 1, "outside_node", 0)
]
@@ -529,16 +532,16 @@
- There exists only one for windows lanman server (SMB) and enables data exfiltration. Add others as needed...
'''
exploits = [
- cyst_cfg.ExploitConfig(
+ ExploitConfig(
services=[
- cyst_cfg.VulnerableServiceConfig(
- name="lanman server",
- min_version="10.0.19041",
+ VulnerableServiceConfig(
+ service="microsoft-ds",
+ min_version="10.0.19041", # min == max: only this exact version is exploitable
max_version="10.0.19041"
)
],
- locality=cyst_cfg.ExploitLocality.REMOTE,
- category=cyst_cfg.ExploitCategory.DATA_MANIPULATION,
+ locality=ExploitLocality.REMOTE,
+ category=ExploitCategory.DATA_MANIPULATION,
id="smb_exploit"
)
]
diff --git a/env/scenarios/test_scenario_configuration.py b/AIDojoCoordinator/scenarios/test_scenario_configuration.py
similarity index 100%
rename from env/scenarios/test_scenario_configuration.py
rename to AIDojoCoordinator/scenarios/test_scenario_configuration.py
diff --git a/env/scenarios/three_net_scenario.py b/AIDojoCoordinator/scenarios/three_net_scenario.py
similarity index 97%
rename from env/scenarios/three_net_scenario.py
rename to AIDojoCoordinator/scenarios/three_net_scenario.py
index d4a14f00..b4bd6f95 100644
--- a/env/scenarios/three_net_scenario.py
+++ b/AIDojoCoordinator/scenarios/three_net_scenario.py
@@ -34,7 +34,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="microsoft-ds",
+ name="microsoft-ds",
owner="Local system",
version="10.0.19041",
local=False,
@@ -75,7 +75,7 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="ms-wbt-server",
+ name="ms-wbt-server",
owner="Local system",
version="10.0.19041",
local=False,
@@ -118,7 +118,7 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="windows login",
+ name="windows login",
owner="Administrator",
version="10.0.19041",
local=True,
@@ -126,7 +126,7 @@
authentication_providers=[local_password_auth("windows login")],
),
cyst_cfg.PassiveServiceConfig(
- type="powershell",
+ name="powershell",
owner="Local system",
version="10.0.19041",
local=True,
@@ -157,7 +157,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="http",
+ name="http",
owner="lighttpd",
version="1.4.54",
local=False,
@@ -170,7 +170,7 @@
access_schemes=[],
),
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -213,7 +213,7 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -241,7 +241,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -271,7 +271,7 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -305,7 +305,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="active-directory",
+ name="active-directory",
owner="Local system",
version="10.0.19041",
local=False,
@@ -329,7 +329,7 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="windows login",
+ name="windows login",
owner="Administrator",
version="10.0.19041",
local=True,
@@ -363,7 +363,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -406,7 +406,7 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="postgresql",
+ name="postgresql",
owner="postgresql",
version="14.3.0",
private_data=[
@@ -416,7 +416,7 @@
access_level=cyst_cfg.AccessLevel.LIMITED,
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -455,7 +455,7 @@
],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ms-wbt-server",
+ name="ms-wbt-server",
owner="Local system",
version="10.0.19041",
local=False,
@@ -486,14 +486,14 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="powershell",
+ name="powershell",
owner="Local system",
version="10.0.19041",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED,
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -523,7 +523,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ms-wbt-server",
+ name="ms-wbt-server",
owner="Local system",
version="10.0.19041",
local=False,
@@ -554,14 +554,14 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="powershell",
+ name="powershell",
owner="Local system",
version="10.0.19041",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED,
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -591,7 +591,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -622,14 +622,14 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED,
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -659,7 +659,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="ssh",
+ name="ssh",
owner="openssh",
version="8.1.0",
local=False,
@@ -690,14 +690,14 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED,
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -726,7 +726,7 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
@@ -750,7 +750,7 @@
],
),
cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
+ name="can_attack_start_here",
owner="Local system",
version="1",
local=True,
@@ -1122,14 +1122,14 @@
active_services=[],
passive_services=[
cyst_cfg.PassiveServiceConfig(
- type="bash",
+ name="bash",
owner="root",
version="5.0.0",
local=True,
access_level=cyst_cfg.AccessLevel.LIMITED,
),
cyst_cfg.PassiveServiceConfig(
- type="listener",
+ name="listener",
owner="attacker",
version="1.0.0",
local=False,
@@ -1173,7 +1173,7 @@
cyst_cfg.ExploitConfig(
services=[
cyst_cfg.VulnerableServiceConfig(
- name="microsoft-ds", min_version="10.0.19041", max_version="10.0.19041"
+ service="microsoft-ds", min_version="10.0.19041", max_version="10.0.19041"
)
],
locality=cyst_cfg.ExploitLocality.REMOTE,
diff --git a/AIDojoCoordinator/scenarios/tiny_scenario_configuration.py b/AIDojoCoordinator/scenarios/tiny_scenario_configuration.py
new file mode 100644
index 00000000..afc3af09
--- /dev/null
+++ b/AIDojoCoordinator/scenarios/tiny_scenario_configuration.py
@@ -0,0 +1,122 @@
+import cyst.api.configuration as cyst_cfg
+# Tiny scenario: one target node, two attacker nodes and a single router connecting them.
+
+target = cyst_cfg.NodeConfig( # node running the services the attackers can act on
+ active_services=[],
+ passive_services=[
+ cyst_cfg.PassiveServiceConfig(
+ name="bash",
+ owner="root",
+ version="8.1.0",
+ access_level=cyst_cfg.AccessLevel.LIMITED,
+ local=True,
+ ),
+ cyst_cfg.PassiveServiceConfig(
+ name="lighttpd", # same service name and version that exploit1 below targets
+ owner="www",
+ version="1.4.62",
+ access_level=cyst_cfg.AccessLevel.LIMITED,
+ local=False,
+ )
+ ],
+ shell="bash",
+ traffic_processors=[],
+ interfaces=[], # NOTE(review): no interfaces declared; presumably derived from the connections below - confirm
+ name="target"
+)
+
+attacker_service = cyst_cfg.ActiveServiceConfig( # agent service definition used by both attacker nodes
+ type="netsecenv_agent",
+ name="attacker",
+ owner="attacker",
+ access_level=cyst_cfg.AccessLevel.LIMITED,
+ ref="attacker_service"
+)
+
+attacker = cyst_cfg.NodeConfig(
+ active_services=[attacker_service()], # NOTE(review): calling the config presumably yields a per-node copy - confirm against CYST
+ passive_services=[],
+ interfaces=[],
+ shell="",
+ traffic_processors=[],
+ name="attacker_node"
+)
+
+attacker2 = cyst_cfg.NodeConfig( # second attacker node, configured like the first one
+ active_services=[attacker_service()],
+ passive_services=[],
+ interfaces=[],
+ shell="",
+ traffic_processors=[],
+ name="attacker_node_2"
+)
+
+router = cyst_cfg.RouterConfig(
+ interfaces=[ # NOTE(review): all three interfaces declare the same ip and net - confirm this is intended
+ cyst_cfg.InterfaceConfig(
+ ip=cyst_cfg.IPAddress("192.168.0.1"),
+ net=cyst_cfg.IPNetwork("192.168.0.1/24"),
+ index=0
+ ),
+ cyst_cfg.InterfaceConfig(
+ ip=cyst_cfg.IPAddress("192.168.0.1"),
+ net=cyst_cfg.IPNetwork("192.168.0.1/24"),
+ index=1
+ ),
+ cyst_cfg.InterfaceConfig(
+ ip=cyst_cfg.IPAddress("192.168.0.1"),
+ net=cyst_cfg.IPNetwork("192.168.0.1/24"),
+ index=2
+ )
+ ],
+ traffic_processors=[
+ cyst_cfg.FirewallConfig( # permissive firewall: default ALLOW and a FORWARD chain with ALLOW policy and no rules
+ default_policy=cyst_cfg.FirewallPolicy.ALLOW,
+ chains=[
+ cyst_cfg.FirewallChainConfig(
+ type=cyst_cfg.FirewallChainType.FORWARD,
+ policy=cyst_cfg.FirewallPolicy.ALLOW,
+ rules=[]
+ )
+ ]
+ )
+ ],
+ id="router"
+)
+
+exploit1 = cyst_cfg.ExploitConfig( # remote code-execution exploit for lighttpd 1.4.62 only (min == max)
+ services=[
+ cyst_cfg.VulnerableServiceConfig(
+ service="lighttpd",
+ min_version="1.4.62",
+ max_version="1.4.62"
+ )
+ ],
+ locality=cyst_cfg.ExploitLocality.REMOTE,
+ category=cyst_cfg.ExploitCategory.CODE_EXECUTION,
+ id="http_exploit"
+)
+
+connection1 = cyst_cfg.ConnectionConfig( # target -> router port 0
+ src_ref=target,
+ src_port=-1, # NOTE(review): -1 presumably lets CYST choose the source port - confirm
+ dst_ref=router,
+ dst_port=0
+)
+
+connection2 = cyst_cfg.ConnectionConfig( # attacker -> router port 1
+ src_ref=attacker,
+ src_port=-1,
+ dst_ref=router,
+ dst_port=1
+)
+
+connection3 = cyst_cfg.ConnectionConfig( # attacker2 -> router port 2
+ src_ref=attacker2,
+ src_port=-1,
+ dst_ref=router,
+ dst_port=2
+)
+
+# Every configuration object defined in this module, collected in a single list.
+configuration_objects = [target, attacker_service, attacker, attacker2, router, exploit1, connection2, connection1, connection3]
diff --git a/env/worlds/__init__.py b/AIDojoCoordinator/utils/__init__.py
similarity index 100%
rename from env/worlds/__init__.py
rename to AIDojoCoordinator/utils/__init__.py
diff --git a/utils/action_plots.r b/AIDojoCoordinator/utils/action_plots.r
similarity index 100%
rename from utils/action_plots.r
rename to AIDojoCoordinator/utils/action_plots.r
diff --git a/utils/action_plots_readme.md b/AIDojoCoordinator/utils/action_plots_readme.md
similarity index 100%
rename from utils/action_plots_readme.md
rename to AIDojoCoordinator/utils/action_plots_readme.md
diff --git a/utils/actions_parser.py b/AIDojoCoordinator/utils/actions_parser.py
similarity index 100%
rename from utils/actions_parser.py
rename to AIDojoCoordinator/utils/actions_parser.py
diff --git a/utils/gamaplay_graphs.py b/AIDojoCoordinator/utils/gamaplay_graphs.py
similarity index 60%
rename from utils/gamaplay_graphs.py
rename to AIDojoCoordinator/utils/gamaplay_graphs.py
index c42e7947..a2f6f80e 100644
--- a/utils/gamaplay_graphs.py
+++ b/AIDojoCoordinator/utils/gamaplay_graphs.py
@@ -1,17 +1,16 @@
from trajectory_analysis import read_json
import numpy as np
-import sys
import os
import utils
import argparse
import matplotlib.pyplot as plt
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__) )))
-from env.game_components import GameState, Action
+from AIDojoCoordinator.game_components import GameState, Action
class TrajectoryGraph:
def __init__(self)->None:
self._checkpoints = {}
+ self._checkpoint_size = {}
self._checkpoint_edges = {}
self._checkpoint_simple_edges = {}
self._wins_per_checkpoint = {}
@@ -54,6 +53,7 @@ def add_checkpoint(self, trajectories:list, end_reason=None)->None:
wins = []
edges = {}
simple_edges = {}
+ self._checkpoint_size[self.num_checkpoints] = len(trajectories)
for play in trajectories:
if end_reason and play["end_reason"] not in end_reason:
continue
@@ -63,11 +63,13 @@ def add_checkpoint(self, trajectories:list, end_reason=None)->None:
wins.append(1)
else:
wins.append(0)
+ # get the id of the first state
state_id = self.get_state_id(GameState.from_dict(play["trajectory"]["states"][0]))
- #print(f'Trajectory len: {len(play["trajectory"]["actions"])}')
- for i in range(1, len(play["trajectory"]["actions"])):
+ # iterate over the trajectory
+ assert len(play["trajectory"]["states"]) == len(play["trajectory"]["actions"]) +1
+ for i in range(1, len(play["trajectory"]["states"])):
next_state_id = self.get_state_id(GameState.from_dict(play["trajectory"]["states"][i]))
- action_id = self.get_action_id(Action.from_dict((play["trajectory"]["actions"][i])))
+ action_id = self.get_action_id(Action.from_dict((play["trajectory"]["actions"][i-1])))
# fullgraph
if (state_id, next_state_id, action_id) not in edges:
edges[state_id, next_state_id, action_id] = 0
@@ -155,19 +157,95 @@ def get_graph_structure_progress(self)->dict:
return super_graph
def get_graph_structure_probabilistic_progress(self)->dict:
-
+ # collect all edges from all checkpoints
all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values()))
+ # prepare data structure for the probabilities per edge
super_graph = {key:np.zeros(self.num_checkpoints) for key in all_edges}
- for i, edge_list in self._checkpoint_edges.items():
- total_out_edges_use = {}
- for (src, _, _), frequency in edge_list.items():
- if src not in total_out_edges_use:
- total_out_edges_use[src] = 0
- total_out_edges_use[src] += frequency
- for (src,dst,edge), value in edge_list.items():
- super_graph[(src,dst,edge)][i] = value/total_out_edges_use[src]
+ for i, edges in self._checkpoint_edges.items():
+ total_edge_count_cp = sum(edges.values())
+ # print(f"Processing timestamp {i}")
+ # # total_out_edges_use = {}
+ # # for (src, _, _), frequency in edge_list.items():
+ # # if src not in total_out_edges_use:
+ # # total_out_edges_use[src] = 0
+ # # total_out_edges_use[src] += frequency
+ # # for (src,dst,edge), value in edge_list.items():
+ # # super_graph[(src,dst,edge)][i] = value/total_out_edges_use[src]
+ # src_nodes = set([x for (x, _, _ ) in edge_list.keys()])
+ # print(f"\t{len(src_nodes)} source nodes")
+ # num_outgoing_edges = {}
+ # for node in src_nodes:
+ # num_outgoing_edges[node] = sum([v for k,v in edge_list.items() if k[0] == node])
+ # print(f"\t{num_outgoing_edges[node]}")
+ # for (src,dst,action), occurence in edge_list.items():
+ # print(f"\tedge:{(src,dst,self._id_to_action[action])}, occurence={occurence}, prob={occurence/num_outgoing_edges[src]}")
+ # super_graph[(src,dst,action)][i] = occurence/num_outgoing_edges[src]
+ for edge, occurence in edges.items():
+ super_graph[edge][i] = occurence/total_edge_count_cp
return super_graph
+ def calculate_source_node_likelihoods(self) -> dict:
+ """
+ Calculates, for every checkpoint, the share of each edge among all edge
+ occurrences that leave the same source node (count / total for that source).
+
+ Returns:
+ source_likelihoods (dict): Maps each edge (source, destination, action) to a numpy array of
+ length self.num_checkpoints, indexed by checkpoint; entries stay 0 where the edge is absent.
+ """
+ all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values()))
+ # prepare data structure for the probabilities per edge
+ source_likelihoods = {key:np.zeros(self.num_checkpoints) for key in all_edges}
+
+ for checkpoint, edges in self._checkpoint_edges.items():
+ # Map to store total occurrences of edges for each source node
+ source_totals = {}
+
+ # Calculate total occurrences for each source node
+ for (source, destination, action), count in edges.items():
+ if source not in source_totals:
+ source_totals[source] = 0
+ source_totals[source] += count
+
+ # Calculate likelihoods for each edge (edge count relative to its source node total)
+
+ for (source, destination, action), count in edges.items():
+ if source_totals[source] > 0:
+ source_likelihoods[(source, destination, action)][checkpoint] = count / source_totals[source]
+ return source_likelihoods
+
+ def calculate_edge_play_likelihoods(self) -> dict:
+ """
+ Calculates, per checkpoint, edge occurrence count / number of trajectories in the checkpoint, capped at 1.
+ Returns:
+ play_likelihoods (dict): Maps each edge to a numpy array of length self.num_checkpoints,
+ indexed by checkpoint; entries stay 0 for absent edges and for checkpoints without trajectories.
+ """
+ all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values()))
+ # prepare data structure for the probabilities per edge
+ play_likelihoods = {key:np.zeros(self.num_checkpoints) for key in all_edges}
+ for checkpoint, edges in self._checkpoint_edges.items():
+ # Get the total number of trajectories (plays) in this checkpoint
+ total_plays = self._checkpoint_size.get(checkpoint, 0)
+ if total_plays == 0:
+ continue # Skip checkpoints with no trajectories
+
+ # Calculate likelihood for each edge; min(..., 1) caps edges that occur more often than there are plays
+ for edge, count in edges.items():
+ play_likelihoods[edge][checkpoint] = min(count / total_plays, 1)
+
+ return play_likelihoods
+    def get_graph_entropies(self)->dict:
+        """Return a dict mapping each edge to the entropy of its per-checkpoint probabilities."""
+        def compute_entropy(probs, epsilon_value=1e-10):
+            # Normalize into a distribution over checkpoints; epsilon_value avoids log(0)
+            normalized_probs = np.array(probs) / np.sum(probs)
+            return -np.sum(normalized_probs * np.log(normalized_probs + epsilon_value))
+        # edge -> array of per-checkpoint probabilities, as built by the probabilistic progress graph
+        probabilistic_graph = self.get_graph_structure_probabilistic_progress()
+        edge_entropy = {e: compute_entropy(probs) for e, probs in probabilistic_graph.items()}
+        return edge_entropy
+
def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple:
edges = {}
nodes_timestamps = {}
@@ -255,55 +333,27 @@ def get_graph_modificiation(edge_list1, edge_list2):
# parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True)
# parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True)
parser.add_argument("--end_reason", help="Filter options for trajectories", default=None, type=str, action='store', required=False)
- parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=10000, required=False)
+ parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=2000, required=False)
args = parser.parse_args()
- # trajectories1 = read_json(args.t1, max_lines=args.n_trajectories)
- # trajectories2 = read_json(args.t2, max_lines=args.n_trajectories)
- # states = {}
- # actions = {}
-
- # graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason)
- # graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason)
-
- # state_to_id = {v:k for k,v in states.items()}
- # action_to_id = {v:k for k,v in states.items()}
-
- # print(f"Trajectory 1: {args.t1}")
- # print(f"WR={t1_wr_mean}±{t1_wr_std}")
- # get_graph_stats(graph_t1, state_to_id, action_to_id)
- # print(f"Trajectory 2: {args.t2}")
- # print(f"WR={t2_wr_mean}±{t2_wr_std}")
- # get_graph_stats(graph_t2, state_to_id, action_to_id)
-
- # a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2)
- # print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}")
- # # print("positions of same states:")
- # # for node in node_set(graph_t1).intersection(node_set(graph_t2)):
- # # print(g1_timestaps[node], g2_timestaps[node])
- # # print("-----------------------")
- # tg_no_blocks = TrajectoryGraph()
-
- # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-2000.jsonl",max_lines=args.n_trajectories))
- # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-4000.jsonl",max_lines=args.n_trajectories))
- # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-6000.jsonl",max_lines=args.n_trajectories))
- # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-8000.jsonl",max_lines=args.n_trajectories))
- # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories))
- # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories))
-
- # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_no_blocks.jsonl",max_lines=args.n_trajectories))
- # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_no_blocks.jsonl",max_lines=args.n_trajectories))
- # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_no_blocks.jsonl",max_lines=args.n_trajectories))
- # tg_no_blocks.plot_graph_stats_progress()
-
+
tg_blocks = TrajectoryGraph()
- tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_blocks.jsonl",max_lines=args.n_trajectories))
- tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_blocks.jsonl",max_lines=args.n_trajectories))
- tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_blocks.jsonl",max_lines=args.n_trajectories))
- tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-20000_blocks.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-2000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-4000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-6000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-8000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-10000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-12000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-14000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-16000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-18000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-20000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-22000.jsonl",max_lines=args.n_trajectories))
+ tg_blocks.add_checkpoint(read_json("./trajectories/2025-01-16_experimentsarsa_005-episodes-24000.jsonl",max_lines=args.n_trajectories))
tg_blocks.plot_graph_stats_progress()
- super_graph = tg_blocks.get_graph_structure_probabilistic_progress()
- print(len(super_graph))
- edges_present_everycheckpoint = [k for k,v in super_graph.items() if np.min(v) > 0]
- print(len(edges_present_everycheckpoint))
\ No newline at end of file
+ #edge_probabilies_per_cp = tg_blocks.get_graph_structure_probabilistic_progress()
+ #edge_probabilies_per_cp = tg_blocks.calculate_source_node_likelihoods()
+ edge_play_likelihoods = tg_blocks.calculate_edge_play_likelihoods()
+ for k,v in edge_play_likelihoods.items():
+ print(f"{k}, {','.join(map(str, v.tolist()))}")
\ No newline at end of file
diff --git a/utils/log_parser.py b/AIDojoCoordinator/utils/log_parser.py
similarity index 100%
rename from utils/log_parser.py
rename to AIDojoCoordinator/utils/log_parser.py
diff --git a/utils/trajectory_analysis.py b/AIDojoCoordinator/utils/trajectory_analysis.py
similarity index 99%
rename from utils/trajectory_analysis.py
rename to AIDojoCoordinator/utils/trajectory_analysis.py
index f6034f42..1bed8d3d 100644
--- a/utils/trajectory_analysis.py
+++ b/AIDojoCoordinator/utils/trajectory_analysis.py
@@ -1,6 +1,5 @@
import jsonlines
import numpy as np
-import sys
import os
import utils
import matplotlib.pyplot as plt
@@ -10,10 +9,7 @@
import plotly.graph_objects as go
from sklearn.preprocessing import StandardScaler
-
-
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__) )))
-from env.game_components import GameState, Action, ActionType
+from AIDojoCoordinator.game_components import GameState, Action, ActionType
diff --git a/utils/utils.py b/AIDojoCoordinator/utils/utils.py
similarity index 86%
rename from utils/utils.py
rename to AIDojoCoordinator/utils/utils.py
index 7a53d276..b43a9509 100644
--- a/utils/utils.py
+++ b/AIDojoCoordinator/utils/utils.py
@@ -1,23 +1,21 @@
# Utility functions for then env and for the agents
# Author: Sebastian Garcia. sebastian.garcia@agents.fel.cvut.cz
# Author: Ondrej Lukas, ondrej.lukas@aic.fel.cvut.cz
-#import configparser
+
import yaml
-import sys
-from os import path
# This is used so the agent can see the environment and game components
-sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
-from env.scenarios import scenario_configuration
-from env.scenarios import smaller_scenario_configuration
-from env.scenarios import tiny_scenario_configuration
-from env.scenarios import three_net_scenario
-from env.game_components import IP, Data, Network, Service, GameState, Action, Observation
+from AIDojoCoordinator.scenarios import scenario_configuration
+from AIDojoCoordinator.scenarios import smaller_scenario_configuration
+from AIDojoCoordinator.scenarios import tiny_scenario_configuration
+from AIDojoCoordinator.scenarios import three_net_scenario
+from AIDojoCoordinator.game_components import IP, Data, Network, Service, GameState, Action, Observation
import netaddr
import logging
import csv
from random import randint
import json
import hashlib
+from cyst.api.configuration.network.node import NodeConfig
def get_file_hash(filepath, hash_func='sha256', chunk_size=4096):
"""
@@ -31,6 +29,14 @@ def get_file_hash(filepath, hash_func='sha256', chunk_size=4096):
chunk = file.read(chunk_size)
return hash_algorithm.hexdigest()
+def get_str_hash(string, hash_func='sha256', chunk_size=4096):
+    """
+    Computes the hash of a given string.
+
+    The string is encoded as UTF-8 and hashed with the algorithm named by
+    'hash_func'. The hexadecimal digest is returned.
+    'chunk_size' is accepted but not used by this function.
+    """
+    encoded = string.encode('utf-8')
+    return hashlib.new(hash_func, encoded).hexdigest()
+
def read_replay_buffer_from_csv(csvfile:str)->list:
"""
Function to read steps from a CSV file
@@ -112,14 +118,19 @@ class ConfigParser():
"""
Class to deal with the configuration file
"""
- def __init__(self, task_config_file):
+ def __init__(self, task_config_file:str=None, config_dict:dict=None):
"""
- Init the class
+ Initializes the configuration parser from either a path to a configuration file or a dict with the configuration.
+ If both are given, 'task_config_file' takes precedence and 'config_dict' is ignored.
"""
self.logger = logging.getLogger('configparser')
- self.read_config_file(task_config_file)
+ if task_config_file:
+ self.read_config_file(task_config_file)
+ elif config_dict:
+ self.config = config_dict
+ else:
+ # NOTE(review): only an error is logged here; this branch (also taken for an empty dict) does not assign self.config.
+ self.logger.error("You must provide either the configuration file or a dictionary with the configuration!")
- def read_config_file(self, conf_file_name):
+ def read_config_file(self, conf_file_name:str):
"""
reads configuration file
"""
@@ -182,18 +193,6 @@ def read_agents_known_blocks(self, type_agent: str, type_data: str) -> dict:
known_blocks[target_host] = block_list
else:
raise ValueError(f"Unsupported value in 'known_blocks': {known_blocks_conf}")
- # try:
- # # Check the host is a good ip
- # _ = netaddr.IPAddress(target_host)
- # target_host_ip = IP(target_host)
- # for known_blocked_host in dict_blocked_hosts.values():
- # known_blocked_host_ip = IP(known_blocked_host)
- # known_blocks[target_host_ip].append(known_blocked_host_ip)
- # except (ValueError, netaddr.AddrFormatError):
- # if target_host.lower() == "all_routers":
- # known_blocks["all_routers"] = dict_blocked_hosts
- # except (ValueError):
- # known_blocks = {}
return known_blocks
def read_agents_known_services(self, type_agent: str, type_data: str) -> dict:
@@ -276,7 +275,7 @@ def read_agents_controlled_hosts(self, type_agent: str, type_data: str) -> dict:
self.logger.error(f'Configuration problem with the controlled hosts: {e}')
return controlled_hosts
- def get_player_win_conditions(self, type_of_player):
+ def get_player_win_conditions(self, type_of_player:str):
"""
Get the goal of the player
type_of_player: Can be 'attackers' or 'defenders'
@@ -312,7 +311,7 @@ def get_player_win_conditions(self, type_of_player):
return player_goal
- def get_player_start_position(self, type_of_player):
+ def get_player_start_position(self, type_of_player:str):
"""
Generate the starting position of an attacking agent
type_of_player: Can be 'attackers' or 'defenders'
@@ -341,7 +340,7 @@ def get_player_start_position(self, type_of_player):
return player_start_position
- def get_start_position(self, agent_role):
+ def get_start_position(self, agent_role:str):
match agent_role:
case "Attacker":
return self.get_player_start_position(agent_role)
@@ -391,7 +390,6 @@ def get_max_steps(self, role=str)->int:
self.logger.warning(f"Unsupported value in 'coordinator.agents.{role}.max_steps': {e}. Setting value to default=None (no step limit)")
return max_steps
-
def get_goal_description(self, agent_role)->dict:
"""
Get goal description per role
@@ -412,46 +410,17 @@ def get_goal_description(self, agent_role)->dict:
case _:
raise ValueError(f"Unsupported agent role: {agent_role}")
return description
-
- def get_goal_reward(self)->float:
- """
- Reads what is the reward for reaching the goal.
- default: 100
- """
- try:
- goal_reward = self.config['env']['goal_reward']
- return float(goal_reward)
- except KeyError:
- return 100
- except ValueError:
- return 100
-
- def get_detection_reward(self)->float:
- """
- Reads what is the reward for detection.
- default: -50
- """
- try:
- detection_reward = self.config['env']['detection_reward']
- return float(detection_reward)
- except KeyError:
- return -50
- except ValueError:
- return -50
-
- def get_step_reward(self)->float:
- """
- Reads what is the reward for detection.
- default: -1
- """
- try:
- step_reward = self.config['env']['step_reward']
- return float(step_reward)
- except KeyError:
- return -1
- except ValueError:
- return -1
+    def get_rewards(self, reward_names:list, default_value=0)->dict:
+        """
+        Reads the reward configured for each case listed in 'reward_names'.
+
+        Values are looked up in self.config['env']['rewards'][<name>]. When the
+        value (or one of the enclosing sections) is missing, a warning is
+        logged and 'default_value' is used for that name.
+
+        Returns a dict mapping every requested name to its reward.
+        """
+        rewards = {}
+        for name in reward_names:
+            try:
+                rewards[name] = self.config['env']["rewards"][name]
+            except (KeyError, TypeError):
+                # TypeError: a section that exists but is empty (e.g. 'rewards:' with
+                # no entries in YAML) is loaded as None and cannot be indexed
+                self.logger.warning(f"No reward value found for '{name}'. Usinng default reward({name})={default_value}")
+                rewards[name] = default_value
+        return rewards
def get_use_dynamic_addresses(self)->bool:
"""
@@ -525,10 +494,10 @@ def get_use_firewall(self)->bool:
def get_use_global_defender(self)->bool:
+ """
+ Reads self.config['env']['use_global_defender']; returns False when the key is missing.
+ """
try:
- use_firewall = self.config['env']['use_global_defender']
+ use_global_defender = self.config['env']['use_global_defender']
except KeyError:
- use_firewall = False
- return use_firewall
+ use_global_defender = False
+ return use_global_defender
def get_logging_level(debug_level):
"""
@@ -545,6 +514,23 @@ def get_logging_level(debug_level):
level = log_levels.get(debug_level.upper(), logging.ERROR)
return level
+def get_starting_position_from_cyst_config(cyst_objects):
+    """
+    Collects the starting positions of agents from CYST configuration objects.
+
+    Every active service of type "netsecenv_agent" found on a NodeConfig yields
+    one entry keyed by "<node id>.<service name>". The entry holds the IPs of
+    all interfaces of that node ("known_hosts") and the networks of those
+    interfaces ("known_networks"). Objects that are not NodeConfig instances
+    are ignored.
+
+    Returns a dict {key: {"known_hosts": set of IP, "known_networks": set of Network}}.
+    """
+    logger = logging.getLogger(__name__)
+    starting_positions = {}
+    for obj in cyst_objects:
+        if not isinstance(obj, NodeConfig):
+            continue
+        for active_service in obj.active_services:
+            if active_service.type != "netsecenv_agent":
+                continue
+            agent_key = f"{obj.id}.{active_service.name}"
+            # progress message goes to the log instead of stdout
+            logger.debug(f"startig processing {agent_key}")
+            # fresh sets for every agent so that entries never share mutable state
+            hosts = set()
+            networks = set()
+            for interface in obj.interfaces:
+                hosts.add(IP(str(interface.ip)))
+                net_ip, net_mask = str(interface.net).split("/")
+                networks.add(Network(net_ip, int(net_mask)))
+            starting_positions[agent_key] = {"known_hosts": hosts, "known_networks": networks}
+    return starting_positions
+
+
if __name__ == "__main__":
state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)},
known_hosts={IP("192.168.1.2"), IP("192.168.1.3")}, controlled_hosts={IP("192.168.1.2")},
diff --git a/AIDojoCoordinator/worlds/CYSTCoordinator.py b/AIDojoCoordinator/worlds/CYSTCoordinator.py
new file mode 100644
index 00000000..f9d65c74
--- /dev/null
+++ b/AIDojoCoordinator/worlds/CYSTCoordinator.py
@@ -0,0 +1,266 @@
+# Author Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz
+
+import os
+import requests
+import json
+import copy
+import ast
+import logging
+import argparse
+from pathlib import Path
+from AIDojoCoordinator.game_components import GameState, Action, ActionType, IP, Service
+from AIDojoCoordinator.coordinator import GameCoordinator
+
+from AIDojoCoordinator.utils.utils import get_starting_position_from_cyst_config, get_logging_level
+
+class CYSTCoordinator(GameCoordinator):
+    """
+    Game coordinator which forwards actions of the agents to a running CYST simulation.
+
+    Connecting agents are paired with CYST agents (active services of type
+    "netsecenv_agent" listed in the CYST configuration). Actions are sent to
+    CYST in HTTP requests and the responses are merged into the GameState of
+    the agent. Only ScanNetwork and FindServices are implemented so far.
+    """
+
+    def __init__(self, game_host:str, game_port:int, service_host:str, service_port:int, allowed_roles=None):
+        """
+        Creates the coordinator.
+        allowed_roles: roles accepted in the game, by default ["Attacker", "Defender", "Benign"].
+        """
+        if allowed_roles is None:
+            # created per instance so that no mutable default list is shared between calls
+            allowed_roles = ["Attacker", "Defender", "Benign"]
+        super().__init__(game_host, game_port, service_host, service_port, allowed_roles)
+        self._id_to_cystid = {} # agent_id -> ID of the paired CYST agent
+        self._cystid_to_id = {} # ID of the CYST agent -> agent_id
+        self._known_agent_roles = {} # agent_id -> role
+        self._last_state_per_agent = {}
+        self._last_action_per_agent = {}
+        self._last_msg_per_agent = {}
+        self._starting_positions = None # filled during the first registration
+        self._availabe_cyst_agents = None # role -> set of IDs of free CYST agents
+
+    def get_cyst_id(self, agent_role:str):
+        """
+        Takes the ID of a free CYST agent for the given role out of the pool.
+        Returns None if there is no CYST agent available for the role.
+        """
+        try:
+            return self._availabe_cyst_agents[agent_role].pop()
+        except (KeyError, TypeError):
+            # KeyError: unknown role or empty pool, TypeError: pool was not created yet
+            return None
+
+    def _initial_state(self, cyst_id:str)->GameState:
+        """
+        Builds the starting GameState of the CYST agent 'cyst_id'.
+        """
+        position = self._starting_positions[cyst_id]
+        # copies keep the stored starting position independent of the returned state
+        return GameState(
+            controlled_hosts=set(position["known_hosts"]),
+            known_hosts=set(position["known_hosts"]),
+            known_networks=set(position["known_networks"]),
+        )
+
+    @staticmethod
+    def _copy_state_parts(agent_state:GameState)->tuple:
+        """
+        Deep-copies all parts of 'agent_state' so they can be extended without touching the original.
+        Returns (controlled_hosts, known_hosts, known_services, known_data, known_networks, known_blocks).
+        """
+        return (
+            copy.deepcopy(agent_state.controlled_hosts),
+            copy.deepcopy(agent_state.known_hosts),
+            copy.deepcopy(agent_state.known_services),
+            copy.deepcopy(agent_state.known_data),
+            copy.deepcopy(agent_state.known_networks),
+            copy.deepcopy(agent_state.known_blocks),
+        )
+
+    async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState:
+        """
+        Pairs a new agent with a free CYST agent.
+        Returns the starting GameState of the agent or None if no CYST agent is free.
+        """
+        self.logger.debug(f"Registering agent {agent_id} in the world.")
+        # The requested role is overridden: the pool created below only holds "Attacker" agents.
+        agent_role = "Attacker"
+        if not self._starting_positions:
+            self._starting_positions = get_starting_position_from_cyst_config(self._cyst_objects)
+            self._availabe_cyst_agents = {"Attacker": set(self._starting_positions)}
+        async with self._agents_lock:
+            cyst_id = self.get_cyst_id(agent_role)
+            if not cyst_id:
+                return None
+            self._cystid_to_id[cyst_id] = agent_id
+            self._id_to_cystid[agent_id] = cyst_id
+            self._known_agent_roles[agent_id] = agent_role
+            return self._initial_state(cyst_id)
+
+    async def remove_agent(self, agent_id, agent_state:GameState)->bool:
+        """
+        Removes the pairing of the agent and makes its CYST agent available again.
+        Returns True on success, False if the agent is not known.
+        """
+        self.logger.info(f"Removing agent {agent_id} from the CYST World")
+        async with self._agents_lock:
+            try:
+                agent_role = self._known_agent_roles[agent_id]
+                cyst_id = self._id_to_cystid[agent_id]
+                # remove agent's IDs
+                self._id_to_cystid.pop(agent_id)
+                self._cystid_to_id.pop(cyst_id)
+                # make cyst_agent available again
+                self._availabe_cyst_agents[agent_role].add(cyst_id)
+                return True
+            except KeyError:
+                self.logger.error(f"Unknown agent ID: {agent_id}!")
+                return False
+
+    async def step(self, agent_id:tuple, agent_state:GameState, action:Action)->GameState:
+        """
+        Executes 'action' of the agent and returns the resulting GameState.
+        Raises ValueError for an unknown action type and NotImplementedError
+        for action types which are not implemented yet.
+        """
+        self.logger.info(f"Processing {action} from {agent_id}({self._id_to_cystid[agent_id]})")
+        next_state = None
+        match action.type:
+            case ActionType.ScanNetwork:
+                next_state = await self._execute_scan_network_action(agent_id, agent_state, action)
+            case ActionType.FindServices:
+                next_state = await self._execute_find_services_action(agent_id, agent_state, action)
+            case ActionType.FindData:
+                next_state = await self._execute_find_data_action(agent_id, agent_state, action)
+            case ActionType.ExploitService:
+                next_state = await self._execute_exploit_service_action(agent_id, agent_state, action)
+            case ActionType.ExfiltrateData:
+                next_state = await self._execute_exfiltrate_data_action(agent_id, agent_state, action)
+            case ActionType.BlockIP:
+                next_state = await self._execute_block_ip_action(agent_id, agent_state, action)
+            case _:
+                raise ValueError(f"Unknown Action type or other error: '{action.type}'")
+        return next_state
+
+    async def _cyst_request(self, cyst_id:str, msg:dict)->tuple:
+        """
+        Sends 'msg' as JSON in a POST request to the CYST agent 'cyst_id'.
+
+        Returns a tuple (HTTP status code, decoded JSON body). If the request
+        itself fails, the error is logged and (None, None) is returned, so the
+        callers handle it the same way as any other non-200 response.
+        """
+        import asyncio # local import, asyncio is needed only here
+        url = f"http://localhost:8282/execute/{cyst_id}/" # address of the CYST server
+        self.logger.info(f"Sedning request {msg} to {url}")
+        try:
+            # requests is blocking: run it in a worker thread to keep the event loop responsive
+            response = await asyncio.to_thread(requests.post, url, json=msg)
+        except requests.exceptions.RequestException as e:
+            self.logger.error(f'An error occurred: {e}')
+            return None, None
+        self.logger.debug(f'Status code:{response.status_code}')
+        self.logger.debug(f'Response body:{response.text}')
+        return int(response.status_code), json.loads(response.text)
+
+    async def _execute_scan_network_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState:
+        """
+        Runs the ScanNetwork action in CYST and adds the reported IPs to known_hosts.
+        If CYST does not respond with status 200, an unchanged copy of the state is returned.
+        """
+        action_dict = {
+            "action":"dojo:scan_network",
+            "params":
+                {
+                    # NOTE(review): 'dst_ip' is filled with the source host of the action - confirm against the CYST API
+                    "dst_ip":str(action.parameters["source_host"]),
+                    "dst_service":"",
+                    "to_network":str(action.parameters["target_network"])
+                }
+        }
+        cyst_rsp_status, cyst_rsp_data = await self._cyst_request(self._id_to_cystid[agent_id], action_dict)
+        extended_ch, extended_kh, extended_ks, extended_kd, extended_kn, extended_kb = self._copy_state_parts(agent_state)
+
+        if cyst_rsp_status == 200:
+            self.logger.debug("Valid response from CYST")
+            # the content is parsed as a Python literal (iterated below as IP strings)
+            data = ast.literal_eval(cyst_rsp_data["result"][1]["content"])
+            for ip_str in data:
+                ip = IP(ip_str)
+                self.logger.debug(f"Adding {ip} to known_hosts")
+                extended_kh.add(ip)
+        return GameState(extended_ch, extended_kh, extended_ks, extended_kd, extended_kn, extended_kb)
+
+    async def _execute_find_services_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState:
+        """
+        Runs the FindServices action in CYST and adds the reported hosts and services to the state.
+        If CYST does not respond with status 200, an unchanged copy of the state is returned.
+        """
+        action_dict = {
+            "action":"dojo:find_services",
+            "params":
+                {
+                    "dst_ip":str(action.parameters["target_host"]),
+                    "dst_service":""
+                }
+        }
+        cyst_rsp_status, cyst_rsp_data = await self._cyst_request(self._id_to_cystid[agent_id], action_dict)
+        extended_ch, extended_kh, extended_ks, extended_kd, extended_kn, extended_kb = self._copy_state_parts(agent_state)
+
+        if cyst_rsp_status == 200:
+            self.logger.debug("Valid response from CYST")
+            # the content is parsed as a Python literal (items with keys "ip" and "services")
+            data = ast.literal_eval(cyst_rsp_data["result"][1]["content"])
+            self.logger.debug(data)
+            for item in data:
+                ip = IP(item["ip"])
+                # Add IP in case it was discovered by the scan
+                extended_kh.add(ip)
+                if item["services"]:
+                    # hosts without services get no entry in known_services
+                    host_services = extended_ks.setdefault(ip, set())
+                    for service_dict in item["services"]:
+                        host_services.add(Service.from_dict(service_dict))
+        return GameState(extended_ch, extended_kh, extended_ks, extended_kd, extended_kn, extended_kb)
+
+    async def _execute_find_data_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState:
+        """Not implemented yet."""
+        raise NotImplementedError
+
+    async def _execute_exploit_service_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState:
+        """Not implemented yet."""
+        raise NotImplementedError
+
+    async def _execute_exfiltrate_data_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState:
+        """Not implemented yet."""
+        raise NotImplementedError
+
+    async def _execute_block_ip_action(self, agent_id:tuple, agent_state: GameState, action:Action)->GameState:
+        """Not implemented yet."""
+        raise NotImplementedError
+
+    async def reset_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState:
+        """
+        Returns the starting GameState of the CYST agent paired with 'agent_id'.
+        """
+        return self._initial_state(self._id_to_cystid[agent_id])
+
+    async def reset(self)->bool:
+        """
+        Nothing is reset on the side of the coordinator; always returns True.
+        """
+        return True
+
+if __name__ == "__main__":
+    # Command line entry point: parse the options, set up logging and run the coordinator.
+    parser = argparse.ArgumentParser(
+        description="CYST-NetSecGame Coordinator Server Author: Ondrej Lukas ondrej.lukas@aic.fel.cvut.cz",
+        usage="%(prog)s [options]",
+    )
+
+    parser.add_argument(
+        "-l",
+        "--debug_level",
+        help="Define the debug level for the logs. DEBUG, INFO, WARNING, ERROR, CRITICAL",
+        action="store",
+        required=False,
+        type=str,
+        default="DEBUG",
+    )
+
+    parser.add_argument(
+        "-gh",
+        "--game_host",
+        help="host where to run the game server",
+        action="store",
+        required=False,
+        type=str,
+        default="127.0.0.1",
+    )
+
+    # integer options use integer defaults (argparse would otherwise convert the strings)
+    parser.add_argument(
+        "-gp",
+        "--game_port",
+        help="Port where to run the game server",
+        action="store",
+        required=False,
+        type=int,
+        default=9000,
+    )
+
+    parser.add_argument(
+        "-sh",
+        "--service_host",
+        help="Host where to run the config server",
+        action="store",
+        required=False,
+        type=str,
+        default="127.0.0.1",
+    )
+
+    parser.add_argument(
+        "-sp",
+        "--service_port",
+        help="Port where to listen for cyst config",
+        action="store",
+        required=False,
+        type=int,
+        default=9009,
+    )
+
+
+    args = parser.parse_args()
+    print(args)
+    # Set the logging
+    log_filename = Path("CYST_coordinator.log")
+    # create the directory of the log file if needed (no-op when it already exists)
+    log_filename.parent.mkdir(parents=True, exist_ok=True)
+
+    # Convert the logging level in the args to the level to use
+    pass_level = get_logging_level(args.debug_level)
+
+    logging.basicConfig(
+        filename=log_filename,
+        filemode="w",
+        format="%(asctime)s %(name)s %(levelname)s %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        level=pass_level,
+    )
+
+    game_server = CYSTCoordinator(args.game_host, args.game_port, args.service_host , args.service_port)
+    # Run it!
+    game_server.run()
\ No newline at end of file
diff --git a/env/worlds/network_security_game.py b/AIDojoCoordinator/worlds/NSEGameCoordinator.py
old mode 100755
new mode 100644
similarity index 76%
rename from env/worlds/network_security_game.py
rename to AIDojoCoordinator/worlds/NSEGameCoordinator.py
index 1e39b232..f5ab0c5a
--- a/env/worlds/network_security_game.py
+++ b/AIDojoCoordinator/worlds/NSEGameCoordinator.py
@@ -1,94 +1,126 @@
-#Authors
-# Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz
-# Sebastian Garcia. sebastian.garcia@agents.fel.cvut.cz
-
-import netaddr
-import env.game_components as gc
+# Author Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz
+import os
+import logging
+import argparse
import random
-import copy
-from cyst.api.configuration import NodeConfig, RouterConfig, ConnectionConfig, ExploitConfig, FirewallPolicy
import numpy as np
+import copy
from faker import Faker
-from env.worlds.aidojo_world import AIDojoWorld
-
-class NetworkSecurityEnvironment(AIDojoWorld):
- """
- Class to manage the whole network security game
- It uses some Cyst libraries for the network topology
- It presents a env environment to play
- """
- def __init__(self, task_config_file, action_queue, response_queue, world_name="NetSecEnv") -> None:
- super().__init__(task_config_file, action_queue, response_queue, world_name)
- self.logger.info("Initializing NetSetGame environment")
- # Prepare data structures for all environment components (to be filled in self._process_cyst_config())
+from pathlib import Path
+import netaddr
+
+from AIDojoCoordinator.game_components import GameState, Action, ActionType, IP, Network, Data, Service
+from AIDojoCoordinator.coordinator import GameCoordinator
+from cyst.api.configuration import NodeConfig, RouterConfig, ConnectionConfig, ExploitConfig, FirewallPolicy
+
+from AIDojoCoordinator.utils.utils import get_logging_level
+
+class NSGCoordinator(GameCoordinator):
+
+ def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attacker", "Defender", "Benign"], seed=42):
+ """
+ Passes the connection parameters and the task configuration to GameCoordinator (without service host/port),
+ creates empty containers for the parts of the environment and seeds numpy and random with 'seed'.
+ """
+ super().__init__(game_host, game_port, service_host=None, service_port=None, allowed_roles=allowed_roles, task_config_file=task_config)
+
+ # Internal data structure of the NSG
self._ip_to_hostname = {} # Mapping of `IP`:`host_name`(str) of all nodes in the environment
self._networks = {} # A `dict` of the networks present in the environment. Keys: `Network` objects, values `set` of `IP` objects.
self._services = {} # Dict of all services in the environment. Keys: hostname (`str`), values: `set` of `Service` objetcs.
self._data = {} # Dict of all services in the environment. Keys: hostname (`str`), values `set` of `Service` objetcs.
self._firewall = {} # dict of all the allowed connections in the environment. Keys `IP` ,values: `set` of `IP` objects.
self._fw_blocks = {}
- self._data_content = {} #content of each datapoint from self._data
+ # NOTE(review): self._data_content is no longer created here, but it is still written in _process_cyst_config
+ # (inside a try/except AttributeError) and read in _get_data_content - confirm it is initialized elsewhere.
# All exploits in the environment
self._exploits = {}
# A list of all the hosts where the attacker can start in a random start
self.hosts_to_start = []
self._network_mapping = {}
self._ip_mapping = {}
- # Load CYST configuration
- self._process_cyst_config(self.task_config.get_scenario())
- # Set the seed
- seed = self.task_config.get_seed('env')
+
np.random.seed(seed)
random.seed(seed)
self._seed = seed
self.logger.info(f'Setting env seed to {seed}')
-
- # Set rewards for goal/detection/step
- self._rewards = {
- "goal": self.task_config.get_goal_reward(),
- "detection": self.task_config.get_detection_reward(),
- "step": self.task_config.get_step_reward()
- }
- self.logger.info(f"\tSetting rewards - {self._rewards}")
-
- # Set the default parameters of all actionss
- # if the values of the actions were updated in the configuration file
- gc.ActionType.ScanNetwork.default_success_p = self.task_config.read_env_action_data('scan_network')
- gc.ActionType.FindServices.default_success_p = self.task_config.read_env_action_data('find_services')
- gc.ActionType.ExploitService.default_success_p = self.task_config.read_env_action_data('exploit_service')
- gc.ActionType.FindData.default_success_p = self.task_config.read_env_action_data('find_data')
- gc.ActionType.ExfiltrateData.default_success_p = self.task_config.read_env_action_data('exfiltrate_data')
- gc.ActionType.BlockIP.default_success_p = self.task_config.read_env_action_data('block_ip')
-
- # At this point all 'random' values should be assigned to something
- # Check if dynamic network and ip adddresses are required
- if self.task_config.get_use_dynamic_addresses():
+ def _initialize(self)->None:
+ """
+ Builds the environment from self._cyst_objects, prepares Faker (seeded with self._seed) if dynamic IPs are used
+ and stores deep copies of self._data and self._firewall as their initial values.
+ """
+ # Load CYST configuration
+ self._process_cyst_config(self._cyst_objects)
+ # Check if dynamic network and ip addresses are required
+ if self._use_dynamic_ips:
self.logger.info("Dynamic change of the IP and network addresses enabled")
self._faker_object = Faker()
- Faker.seed(seed)
- self._episode_replay_buffer = None
-
- # Make a copy of data placements so it is possible to reset to it when episode ends
+ Faker.seed(self._seed)
+ # store initial values for parts which are modified during the game
self._data_original = copy.deepcopy(self._data)
- self._data_content_original = copy.deepcopy(self._data_content)
self._firewall_original = copy.deepcopy(self._firewall)
-
- self._actions_played = []
self.logger.info("Environment initialization finished")
- @property
- def seed(self)->int:
+    def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->GameState:
         """
-        Can be used by agents to use the same random seed as the environment
+        Builds a GameState from given view.
+
+        Networks, hosts and the hosts of known services/data are translated with the current
+        mappings (self._network_mapping, self._ip_mapping). In 'controlled_hosts' the keyword
+        'random' is replaced by a host chosen at random from self.hosts_to_start and
+        'all_local' by all local IPs.
+
+        If 'add_neighboring_nets' is set, the known_networks are artificially extended with the
+        private networks of the controlled hosts and with their private neighbours shifted by
+        +-256 addresses (+-1 in the third octet).
         """
-        return self._seed
-
-    @property
-    def num_actions(self)->int:
-        return len(self.get_all_actions())
-
+        def as_pairs(mapping_or_pairs):
+            # a dict has to be iterated through .items() to get (ip, values) pairs;
+            # anything else is expected to be an iterable of pairs already
+            if isinstance(mapping_or_pairs, dict):
+                return mapping_or_pairs.items()
+            return mapping_or_pairs
+
+        self.logger.info(f'Generating state from view:{view}')
+        # re-map all networks based on current mapping in self._network_mapping
+        known_networks = {self._network_mapping[net] for net in view["known_networks"]}
+
+        controlled_hosts = set()
+        # controlled_hosts
+        for host in view['controlled_hosts']:
+            if isinstance(host, IP):
+                controlled_hosts.add(self._ip_mapping[host])
+                self.logger.debug(f'\tThe attacker has control of host {self._ip_mapping[host]}.')
+            elif host == 'random':
+                # Random start
+                self.logger.debug('\tAdding random starting position of agent')
+                self.logger.debug(f'\t\tChoosing from {self.hosts_to_start}')
+                selected = random.choice(self.hosts_to_start)
+                controlled_hosts.add(selected)
+                self.logger.debug(f'\t\tMaking agent start in {selected}')
+            elif host == "all_local":
+                # all local ips
+                self.logger.debug('\t\tAdding all local hosts to agent')
+                controlled_hosts = controlled_hosts.union(self._get_all_local_ips())
+            else:
+                self.logger.error(f"Unsupported value encountered in start_position['controlled_hosts']: {host}")
+        # re-map all known based on current mapping in self._ip_mapping
+        known_hosts = {self._ip_mapping[ip] for ip in view["known_hosts"]}
+        # Add all controlled hosts to known_hosts
+        known_hosts = known_hosts.union(controlled_hosts)
+
+        if add_neighboring_nets:
+            # Extend the known networks with the neighbouring networks
+            # This is to solve in the env (and not in the agent) the problem
+            # of not knowing other networks apart from the one the agent is in
+            # This is wrong and should be done by the agent, not here
+            # TODO remove this!
+            for controlled_host in controlled_hosts:
+                for net in self._get_networks_from_host(controlled_host): #TODO
+                    net_obj = netaddr.IPNetwork(str(net))
+                    if net_obj.ip.is_ipv4_private_use(): #TODO
+                        known_networks.add(net)
+                        net_obj.value += 256
+                        if net_obj.ip.is_ipv4_private_use():
+                            ip = Network(str(net_obj.ip), net_obj.prefixlen)
+                            self.logger.debug(f'\tAdding {ip} to agent')
+                            known_networks.add(ip)
+                        net_obj.value -= 2*256
+                        if net_obj.ip.is_ipv4_private_use():
+                            ip = Network(str(net_obj.ip), net_obj.prefixlen)
+                            self.logger.debug(f'\tAdding {ip} to agent')
+                            known_networks.add(ip)
+                        #return value back to the original
+                        net_obj.value += 256
+        known_services = {}
+        for ip, service_list in as_pairs(view["known_services"]):
+            known_services[self._ip_mapping[ip]] = service_list
+        known_data = {}
+        for ip, data_list in as_pairs(view["known_data"]):
+            known_data[self._ip_mapping[ip]] = data_list
+        game_state = GameState(controlled_hosts, known_hosts, known_services, known_data, known_networks)
+        self.logger.info(f"Generated GameState:{game_state}")
+        return game_state
+
def _process_cyst_config(self, configuration_objects:list)-> None:
"""
Process the cyst configuration file
@@ -122,8 +154,11 @@ def process_node_config(node_obj:NodeConfig) -> None:
self.logger.info(f"\t\tProcessing interfaces in node '{node_obj.id}'")
for interface in node_obj.interfaces:
net_ip, net_mask = str(interface.net).split("/")
- net = gc.Network(net_ip,int(net_mask))
- ip = gc.IP(str(interface.ip))
+ net = Network(net_ip,int(net_mask))
+ ip = IP(str(interface.ip))
+ if len(node_obj.active_services)>0:
+ self.logger.info(f"\tAdding as potential start point")
+ self.hosts_to_start.append(ip)
self._ip_to_hostname[ip] = node_obj.id
if net not in self._networks:
self._networks[net] = []
@@ -136,27 +171,26 @@ def process_node_config(node_obj:NodeConfig) -> None:
for service in node_obj.passive_services:
# Check if it is a candidate for random start
# Becareful, it will add all the IPs for this node
- if service.type == "can_attack_start_here":
- self.hosts_to_start.append(gc.IP(str(interface.ip)))
+ if service.name == "can_attack_start_here":
+ self.hosts_to_start.append(IP(str(interface.ip)))
continue
if node_obj.id not in self._services:
self._services[node_obj.id] = []
- self._services[node_obj.id].append(gc.Service(service.type, "passive", service.version, service.local))
+ self._services[node_obj.id].append(Service(service.name, "passive", service.version, service.local))
#data
- self.logger.info(f"\t\t\tProcessing data in node '{node_obj.id}':'{service.type}' service")
+ self.logger.info(f"\t\t\tProcessing data in node '{node_obj.id}':'{service.name}' service")
try:
for data in service.private_data:
if node_obj.id not in self._data:
self._data[node_obj.id] = set()
- datapoint = gc.Data(data.owner, data.description)
+ datapoint = Data(data.owner, data.description)
self._data[node_obj.id].add(datapoint)
# add content
self._data_content[node_obj.id, datapoint.id] = f"Content of {datapoint.id}"
except AttributeError:
pass
#service does not contain any data
-
def process_router_config(router_obj:RouterConfig)->None:
self.logger.info(f"\tProcessing config of router '{router_obj.id}'")
# Process a router
@@ -173,8 +207,8 @@ def process_router_config(router_obj:RouterConfig)->None:
self.logger.info(f"\t\tProcessing interfaces in router '{router_obj.id}'")
for interface in r.interfaces:
net_ip, net_mask = str(interface.net).split("/")
- net = gc.Network(net_ip,int(net_mask))
- ip = gc.IP(str(interface.ip))
+ net = Network(net_ip,int(net_mask))
+ ip = IP(str(interface.ip))
self._ip_to_hostname[ip] = router_obj.id
if net not in self._networks:
self._networks[net] = []
@@ -271,7 +305,7 @@ def _create_new_network_mapping(self)->tuple:
if netaddr.IPNetwork(str(net)).ip.is_ipv4_private_use():
private_nets.append(net)
else:
- mapping_nets[net] = gc.Network(fake.ipv4_public(), net.mask)
+ mapping_nets[net] = Network(fake.ipv4_public(), net.mask)
# for private networks, we want to keep the distances among them
private_nets_sorted = sorted(private_nets)
@@ -282,7 +316,7 @@ def _create_new_network_mapping(self)->tuple:
# find the new lowest networks
new_base = netaddr.IPNetwork(f"{fake.ipv4_private()}/{private_nets_sorted[0].mask}")
# store its new mapping
- mapping_nets[private_nets[0]] = gc.Network(str(new_base.network), private_nets_sorted[0].mask)
+ mapping_nets[private_nets[0]] = Network(str(new_base.network), private_nets_sorted[0].mask)
base = netaddr.IPNetwork(str(private_nets_sorted[0]))
is_private_net_checks = []
for i in range(1,len(private_nets_sorted)):
@@ -294,7 +328,7 @@ def _create_new_network_mapping(self)->tuple:
# evaluate if its still a private network
is_private_net_checks.append(new_net_addr.is_ipv4_private_use())
# store the new mapping
- mapping_nets[private_nets_sorted[i]] = gc.Network(str(new_net_addr), private_nets_sorted[i].mask)
+ mapping_nets[private_nets_sorted[i]] = Network(str(new_net_addr), private_nets_sorted[i].mask)
if False not in is_private_net_checks: # verify that ALL new networks are still in the private ranges
valid_valid_network_mapping = True
except IndexError as e:
@@ -312,7 +346,7 @@ def _create_new_network_mapping(self)->tuple:
# remove broadcast and network ip from the list
random.shuffle(ip_list)
for i,ip in enumerate(ips):
- mapping_ips[ip] = gc.IP(str(ip_list[i]))
+ mapping_ips[ip] = IP(str(ip_list[i]))
# Always add random, in case random is selected for ips
mapping_ips['random'] = 'random'
self.logger.info(f"Mapping IPs done:{mapping_ips}")
@@ -326,22 +360,6 @@ def _create_new_network_mapping(self)->tuple:
new_self_networks[mapping_nets[net]].add(mapping_ips[ip])
self._networks = new_self_networks
- # Harpo says that here there is a problem that firewall.items() do not return an ip that can be used in the mapping
- # His solution is: (check)
- """
- new_self_firewall = {}
- for ip, dst_ips in self._firewall.items():
- if ip not in mapping_ips:
- self.logger.debug(f"IP {ip} not found in mapping_ips")
- continue # Skip this IP if it's not found in the mapping
-
- new_self_firewall[mapping_ips[ip]] = set()
-
- for dst_ip in dst_ips:
- new_self_firewall[mapping_ips[ip]].add(mapping_ips[dst_ip])
- self._firewall = new_self_firewall
- """
-
#self._firewall
new_self_firewall = {}
for ip, dst_ips in self._firewall.items():
@@ -431,12 +449,12 @@ def _get_data_content(self, host_ip:str, data_id:str)->str:
if (hostname, data_id) in self._data_content:
content = self._data_content[hostname,data_id]
else:
- self.logger.info(f"\tData '{data_id}' not found in host '{hostname}'({host_ip})")
+ self.logger.debug(f"\tData '{data_id}' not found in host '{hostname}'({host_ip})")
else:
self.logger.debug("Data content not found because target IP does not exists.")
return content
- def _execute_action(self, current_state:gc.GameState, action:gc.Action, agent_id)-> gc.GameState:
+ def _execute_action(self, current_state:GameState, action:Action)-> GameState:
"""
Execute the action and update the values in the state
Before this function it was checked if the action was successful
@@ -449,23 +467,23 @@ def _execute_action(self, current_state:gc.GameState, action:gc.Action, agent_id
"""
next_state = None
match action.type:
- case gc.ActionType.ScanNetwork:
+ case ActionType.ScanNetwork:
next_state = self._execute_scan_network_action(current_state, action)
- case gc.ActionType.FindServices:
+ case ActionType.FindServices:
next_state = self._execute_find_services_action(current_state, action)
- case gc.ActionType.FindData:
+ case ActionType.FindData:
next_state = self._execute_find_data_action(current_state, action)
- case gc.ActionType.ExploitService:
+ case ActionType.ExploitService:
next_state = self._execute_exploit_service_action(current_state, action)
- case gc.ActionType.ExfiltrateData:
+ case ActionType.ExfiltrateData:
next_state = self._execute_exfiltrate_data_action(current_state, action)
- case gc.ActionType.BlockIP:
+ case ActionType.BlockIP:
next_state = self._execute_block_ip_action(current_state, action)
case _:
raise ValueError(f"Unknown Action type or other error: '{action.type}'")
return next_state
- def _state_parts_deep_copy(self, current:gc.GameState)->tuple:
+ def _state_parts_deep_copy(self, current:GameState)->tuple:
next_nets = copy.deepcopy(current.known_networks)
next_known_h = copy.deepcopy(current.known_hosts)
next_controlled_h = copy.deepcopy(current.controlled_hosts)
@@ -474,7 +492,7 @@ def _state_parts_deep_copy(self, current:gc.GameState)->tuple:
next_blocked = copy.deepcopy(current.known_blocks)
return next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked
- def _firewall_check(self, src_ip:gc.IP, dst_ip:gc.IP)->bool:
+ def _firewall_check(self, src_ip:IP, dst_ip:IP)->bool:
"""Checks if firewall allows connection from 'src_ip to ''dst_ip'"""
try:
connection_allowed = dst_ip in self._firewall[src_ip]
@@ -482,12 +500,12 @@ def _firewall_check(self, src_ip:gc.IP, dst_ip:gc.IP)->bool:
connection_allowed = False
return connection_allowed
- def _execute_scan_network_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState:
+ def _execute_scan_network_action(self, current_state:GameState, action:Action)->GameState:
"""
Executes the ScanNetwork action in the environment
"""
next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state)
- self.logger.info(f"\t\tScanning {action.parameters['target_network']}")
+ self.logger.debug(f"\t\tScanning {action.parameters['target_network']}")
if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts:
new_ips = set()
for ip in self._ip_to_hostname.keys(): #check if IP exists
@@ -500,15 +518,15 @@ def _execute_scan_network_action(self, current_state:gc.GameState, action:gc.Act
self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {ip} blocked by FW. Skipping")
next_known_h = next_known_h.union(new_ips)
else:
- self.logger.info(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'")
- return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
+ self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'")
+ return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
- def _execute_find_services_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState:
+ def _execute_find_services_action(self, current_state:GameState, action:Action)->GameState:
"""
Executes the FindServices action in the environment
"""
next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state)
- self.logger.info(f"\t\tSearching for services in {action.parameters['target_host']}")
+ self.logger.debug(f"\t\tSearching for services in {action.parameters['target_host']}")
if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts:
if self._firewall_check(action.parameters["source_host"], action.parameters['target_host']):
found_services = self._get_services_from_host(action.parameters["target_host"], current_state.controlled_hosts)
@@ -525,14 +543,14 @@ def _execute_find_services_action(self, current_state:gc.GameState, action:gc.Ac
self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. Skipping")
else:
self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'")
- return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
+ return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
- def _execute_find_data_action(self, current:gc.GameState, action:gc.Action)->gc.GameState:
+ def _execute_find_data_action(self, current:GameState, action:Action)->GameState:
"""
Executes the FindData action in the environment
"""
next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current)
- self.logger.info(f"\t\tSearching for data in {action.parameters['target_host']}")
+ self.logger.debug(f"\t\tSearching for data in {action.parameters['target_host']}")
if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current.controlled_hosts:
if self._firewall_check(action.parameters["source_host"], action.parameters['target_host']):
new_data = self._get_data_in_host(action.parameters["target_host"], current.controlled_hosts)
@@ -553,9 +571,9 @@ def _execute_find_data_action(self, current:gc.GameState, action:gc.Action)->gc.
self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. Skipping")
else:
self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'")
- return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
+ return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
- def _execute_exfiltrate_data_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState:
+ def _execute_exfiltrate_data_action(self, current_state:GameState, action:Action)->GameState:
"""
Executes the ExfiltrateData action in the environment
"""
@@ -597,9 +615,9 @@ def _execute_exfiltrate_data_action(self, current_state:gc.GameState, action:gc.
self.logger.debug("\t\t\tCan not exfiltrate. Source host is not controlled.")
else:
self.logger.debug("\t\t\tCan not exfiltrate. Target host is not controlled.")
- return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
+ return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
- def _execute_exploit_service_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState:
+ def _execute_exploit_service_action(self, current_state:GameState, action:Action)->GameState:
"""
Executes the ExploitService action in the environment
"""
@@ -634,9 +652,9 @@ def _execute_exploit_service_action(self, current_state:gc.GameState, action:gc.
self.logger.debug("\t\t\tCan not exploit. Target host does not exist.")
else:
self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'")
- return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
+ return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
- def _execute_block_ip_action(self, current_state:gc.GameState, action:gc.Action)->gc.GameState:
+ def _execute_block_ip_action(self, current_state:GameState, action:Action)->GameState:
"""
Executes the BlockIP action
- The action has BlockIP("target_host": IP object, "source_host": IP object, "blocked_host": IP object)
@@ -689,14 +707,14 @@ def _execute_block_ip_action(self, current_state:gc.GameState, action:gc.Action)
next_blocked[action.parameters["target_host"]].add(action.parameters["blocked_host"])
next_blocked[action.parameters["blocked_host"]].add(action.parameters["target_host"])
else:
- self.logger.info(f"\t\t\t Cant block connection form :'{action.parameters['target_host']}' to '{action.parameters['blocked_host']}'")
+ self.logger.debug(f"\t\t\t Cant block connection form :'{action.parameters['target_host']}' to '{action.parameters['blocked_host']}'")
else:
self.logger.debug(f"\t\t\t Connection from '{action.parameters['source_host']}->'{action.parameters['target_host']} is blocked blocked by FW")
else:
- self.logger.info(f"\t\t\t Invalid target_host:'{action.parameters['target_host']}'")
+ self.logger.debug(f"\t\t\t Invalid target_host:'{action.parameters['target_host']}'")
else:
- self.logger.info(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'")
- return gc.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
+ self.logger.debug(f"\t\t\t Invalid source_host:'{action.parameters['source_host']}'")
+ return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
def _get_all_local_ips(self)->set:
local_ips = set()
@@ -706,179 +724,105 @@ def _get_all_local_ips(self)->set:
local_ips.add(self._ip_mapping[ip])
self.logger.info(f"\t\t\tLocal ips: {local_ips}")
return local_ips
-
- def create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->gc.GameState:
- """
- Builds a GameState from given view.
- If there is a keyword 'random' used, it is replaced by a valid option at random.
-
- Currently, we artificially extend the knonw_networks with +- 1 in the third octet.
- """
- self.logger.info(f'Generating state from view:{view}')
- # re-map all networks based on current mapping in self._network_mapping
- known_networks = set([self._network_mapping[net] for net in view["known_networks"]])
-
-
- controlled_hosts = set()
- # controlled_hosts
- for host in view['controlled_hosts']:
- if isinstance(host, gc.IP):
- controlled_hosts.add(self._ip_mapping[host])
- self.logger.info(f'\tThe attacker has control of host {self._ip_mapping[host]}.')
- elif host == 'random':
- # Random start
- self.logger.info('\tAdding random starting position of agent')
- self.logger.info(f'\t\tChoosing from {self.hosts_to_start}')
- selected = random.choice(self.hosts_to_start)
- controlled_hosts.add(selected)
- self.logger.info(f'\t\tMaking agent start in {selected}')
- elif host == "all_local":
- # all local ips
- self.logger.info('\t\tAdding all local hosts to agent')
- controlled_hosts = controlled_hosts.union(self._get_all_local_ips())
- else:
- self.logger.error(f"Unsupported value encountered in start_position['controlled_hosts']: {host}")
- # re-map all known based on current mapping in self._ip_mapping
- known_hosts = set([self._ip_mapping[ip] for ip in view["known_hosts"]])
- # Add all controlled hosts to known_hosts
- known_hosts = known_hosts.union(controlled_hosts)
-
- if add_neighboring_nets:
- # Extend the known networks with the neighbouring networks
- # This is to solve in the env (and not in the agent) the problem
- # of not knowing other networks appart from the one the agent is in
- # This is wrong and should be done by the agent, not here
- # TODO remove this!
- for controlled_host in controlled_hosts:
- for net in self._get_networks_from_host(controlled_host): #TODO
- net_obj = netaddr.IPNetwork(str(net))
- if net_obj.ip.is_ipv4_private_use(): #TODO
- known_networks.add(net)
- net_obj.value += 256
- if net_obj.ip.is_ipv4_private_use():
- ip = gc.Network(str(net_obj.ip), net_obj.prefixlen)
- self.logger.info(f'\tAdding {ip} to agent')
- known_networks.add(ip)
- net_obj.value -= 2*256
- if net_obj.ip.is_ipv4_private_use():
- ip = gc.Network(str(net_obj.ip), net_obj.prefixlen)
- self.logger.info(f'\tAdding {ip} to agent')
- known_networks.add(ip)
- #return value back to the original
- net_obj.value += 256
- known_services ={}
- for ip, service_list in view["known_services"]:
- known_services[self._ip_mapping[ip]] = service_list
- known_data = {}
- for ip, data_list in view["known_data"]:
- known_data[self._ip_mapping[ip]] = data_list
- game_state = gc.GameState(controlled_hosts, known_hosts, known_services, known_data, known_networks)
- self.logger.info(f"Generated GameState:{game_state}")
+
+ async def register_agent(self, agent_id, agent_role, agent_initial_view)->GameState:
+ if len(self._networks) == 0:
+ self._initialize()
+ game_state = self._create_state_from_view(agent_initial_view)
return game_state
+
+ async def remove_agent(self, agent_id, agent_state)->bool:
+ # No action is required
+ return True
+
+ async def step(self, agent_id, agent_state, action)->GameState:
+ return self._execute_action(agent_state, action)
+
+ async def reset_agent(self, agent_id, agent_role, agent_initial_view)->GameState:
+ game_state = self._create_state_from_view(agent_initial_view)
+ return game_state
- def update_goal_dict(self, goal_dict:dict)->dict:
- """
- Updates goal dict based on the current values
- in self._network_mapping and self._ip_mapping.
- """
- new_dict = {
- "known_networks":set(),
- "known_hosts":set(),
- "controlled_hosts":set(),
- "known_services": {},
- "known_data": {},
- "known_blocks": {}
- }
- for net in goal_dict["known_networks"]:
- if net in self._network_mapping:
- new_dict["known_networks"].add(self._network_mapping[net])
- else:
- # Unknown net, do not map
- new_dict["known_networks"].add(net)
- for host in goal_dict["known_hosts"]:
- if host in self._ip_mapping:
- new_dict["known_hosts"].add(self._ip_mapping[host])
- else:
- # Unknown IP, do not map
- new_dict["known_hosts"].add(host)
- for host in goal_dict["controlled_hosts"]:
- if host in self._ip_mapping:
- new_dict["controlled_hosts"].add(self._ip_mapping[host])
- else:
- # Unknown IP, do not map
- new_dict["controlled_hosts"].add(host)
- for host, items in goal_dict["known_services"].items():
- if host in self._ip_mapping:
- new_dict["known_services"][self._ip_mapping[host]] = items
- else:
- # Unknown IP, do not map
- new_dict["known_services"][host] = items
- for host, items in goal_dict["known_data"].items():
- if host in self._ip_mapping:
- new_dict["known_data"][self._ip_mapping[host]] = items
- else:
- # Unknown IP, do not map
- new_dict["known_data"][host] = items
- for host, items in goal_dict["known_blocks"].items():
- if host in self._ip_mapping:
- new_dict["known_blocks"][self._ip_mapping[host]] = items
- else:
- # Unknown IP, do not map
- new_dict["known_blocks"][host] = items
- return new_dict
-
- def update_goal_descriptions(self, goal_description:str)->str:
- new_description = goal_description
- for ip in self._ip_mapping:
- new_description = new_description.replace(str(ip), str(self._ip_mapping[ip]))
- return new_description
-
- def reset(self)->None:
+ async def reset(self)->bool:
"""
Function to reset the state of the game
and prepare for a new episode
"""
# write all steps in the episode replay buffer in the file
- self.logger.info('--- Reseting env to its initial state ---')
+ self.logger.info('--- Reseting NSG Environment to its initial state ---')
# change IPs if needed
if self.task_config.get_use_dynamic_addresses():
self._create_new_network_mapping()
# reset self._data to orignal state
self._data = copy.deepcopy(self._data_original)
# reset self._data_content to orignal state
- self._data_content_original = copy.deepcopy(self._data_content_original)
self._firewall = copy.deepcopy(self._firewall_original)
self._fw_blocks = {}
-
-
- self._actions_played = []
-
- def step(self, state:gc.GameState, action:gc.Action, agent_id:tuple)-> gc.GameState:
- """
- Take a step in the environment given an action
- in: action
- out: observation of the state of the env
- """
- self.logger.info(f"Agent {agent_id}. Action: {action}")
- # Reward for taking an action
- reward = self._rewards["step"]
-
- # 1. Perform the action
- self._actions_played.append(action)
- if random.random() <= action.type.default_success_p:
- next_state = self._execute_action(state, action, agent_id)
- else:
- self.logger.info("\tAction NOT sucessful")
- next_state = state
-
-
- # Make the state we just got into, our current state
- current_state = state
- self.logger.info(f'New state: {next_state} ')
-
-
- # Save the transition to the episode replay buffer if there is any
- if self._episode_replay_buffer is not None:
- self._episode_replay_buffer.append((current_state, action, reward, next_state))
- # Return an observation
- return next_state
\ No newline at end of file
+ return True
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="NetSecGame Coordinator Server Author: Ondrej Lukas ondrej.lukas@aic.fel.cvut.cz",
+ usage="%(prog)s [options]",
+ )
+
+ parser.add_argument(
+ "-l",
+ "--debug_level",
+ help="Define the debug level for the logs. DEBUG, INFO, WARNING, ERROR, CRITICAL",
+ action="store",
+ required=False,
+ type=str,
+ default="INFO",
+ )
+
+ parser.add_argument(
+ "-gh",
+ "--game_host",
+ help="host where to run the game server",
+ action="store",
+ required=False,
+ type=str,
+ default="127.0.0.1",
+ )
+
+ parser.add_argument(
+ "-gp",
+ "--game_port",
+ help="Port where to run the game server",
+ action="store",
+ required=False,
+ type=int,
+ default="9000",
+ )
+
+ parser.add_argument(
+ "-c",
+ "--task_config",
+ help="File with the task configuration",
+ action="store",
+ required=True,
+ type=str,
+ default="netsecenv_conf.yaml",
+ )
+
+ args = parser.parse_args()
+ print(args)
+ # Set the logging
+ log_filename = Path("NSG_coordinator.log")
+ if not log_filename.parent.exists():
+ os.makedirs(log_filename.parent)
+
+ # Convert the logging level in the args to the level to use
+ pass_level = get_logging_level(args.debug_level)
+
+ logging.basicConfig(
+ filename=log_filename,
+ filemode="w",
+ format="%(asctime)s %(name)s %(levelname)s %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=pass_level,
+ )
+
+ game_server = NSGCoordinator(args.game_host, args.game_port, args.task_config)
+ # Run it!
+ game_server.run()
\ No newline at end of file
diff --git a/env/worlds/network_security_game_real_world.py b/AIDojoCoordinator/worlds/NSGRealWorldCoordinator.py
similarity index 62%
rename from env/worlds/network_security_game_real_world.py
rename to AIDojoCoordinator/worlds/NSGRealWorldCoordinator.py
index d7276510..763326ab 100644
--- a/env/worlds/network_security_game_real_world.py
+++ b/AIDojoCoordinator/worlds/NSGRealWorldCoordinator.py
@@ -2,22 +2,49 @@
# Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz
# Sebastian Garcia. sebastian.garcia@agents.fel.cvut.cz
-import env.game_components as components
-from env.worlds.network_security_game import NetworkSecurityEnvironment
import subprocess
import xml.etree.ElementTree as ElementTree
+import logging
+import argparse
+import os
+from pathlib import Path
+from AIDojoCoordinator.utils.utils import get_logging_level
+from AIDojoCoordinator.game_components import GameState, Action, ActionType, Service,IP
+from AIDojoCoordinator.worlds.NSEGameCoordinator import NSGCoordinator
-class NetworkSecurityEnvironmentRealWorld(NetworkSecurityEnvironment):
- """
- Class to manage the whole network security game in the real world (current network)
- It uses some Cyst libraries for the network topology
- It presents a env environment to play
- """
- def __init__(self, task_config_file, world_name="NetSecEnvRealWorld") -> None:
- super().__init__(task_config_file, world_name)
+class NSERealWorldGameCoordinator(NSGCoordinator):
+
+ def _execute_action(self, current_state:GameState, action:Action)-> GameState:
+ """
+ Execute the action and update the values in the state
+ Before this function it was checked if the action was successful
+ So in here all actions were already successful.
- def _execute_scan_network_action_real_world(self, current_state:components.GameState, action:components.Action)->components.GameState:
+ - ScanNetwork and FindServices are executed in the real world
+ - all other action types use the inherited (simulated) implementation
+
+ Returns: A new GameState
+ """
+ next_state = None
+ match action.type:
+ case ActionType.ScanNetwork:
+ next_state = self._execute_scan_network_action_real_world(current_state, action)
+ case ActionType.FindServices:
+ next_state = self._execute_find_services_action_real_world(current_state, action)
+ case ActionType.FindData:
+ next_state = self._execute_find_data_action(current_state, action)
+ case ActionType.ExploitService:
+ next_state = self._execute_exploit_service_action(current_state, action)
+ case ActionType.ExfiltrateData:
+ next_state = self._execute_exfiltrate_data_action(current_state, action)
+ case ActionType.BlockIP:
+ next_state = self._execute_block_ip_action(current_state, action)
+ case _:
+ raise ValueError(f"Unknown Action type or other error: '{action.type}'")
+ return next_state
+
+ def _execute_scan_network_action_real_world(self, current_state:GameState, action:Action)->GameState:
"""
Executes the ScanNetwork action in the the real world
"""
@@ -38,7 +65,7 @@ def _execute_scan_network_action_real_world(self, current_state:components.GameS
status = ""
ip_elem = host.find('./address[@addrtype="ipv4"]')
if ip_elem is not None:
- ip = components.IP(str(ip_elem.get('addr')))
+ ip = IP(str(ip_elem.get('addr')))
else:
ip = ""
@@ -53,9 +80,9 @@ def _execute_scan_network_action_real_world(self, current_state:components.GameS
self.logger.debug(f"\t\t\tAdding {ip} to new_ips. {status}, {mac_address}, {vendor}")
new_ips.add(ip)
next_known_h = next_known_h.union(new_ips)
- return components.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
+ return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
- def _execute_find_services_action_real_world(self, current_state:components.GameState, action:components.Action)->components.GameState:
+ def _execute_find_services_action_real_world(self, current_state:GameState, action:Action)->GameState:
"""
Executes the FindServices action in the real world
"""
@@ -85,7 +112,7 @@ def _execute_find_services_action_real_world(self, current_state:components.Game
service_elem = port.find('./service[@name]')
service_name = service_elem.get('name') if service_elem is not None else "Unknown"
service_fullname = f'{port_id}/{protocol}/{service_name}'
- service = components.Service(name=service_fullname, type=service_name, version='', is_local=False)
+ service = Service(name=service_fullname, type=service_name, version='', is_local=False)
found_services.add(service)
next_services[action.parameters["target_host"]] = found_services
@@ -95,66 +122,72 @@ def _execute_find_services_action_real_world(self, current_state:components.Game
self.logger.info(f"\t\tAdding {action.parameters['target_host']} to known_hosts")
next_known_h.add(action.parameters["target_host"])
next_nets = next_nets.union({net for net, values in self._networks.items() if action.parameters["target_host"] in values})
- return components.GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
+ return GameState(next_controlled_h, next_known_h, next_services, next_data, next_nets, next_blocked)
- def _execute_action(self, current_state:components.GameState, action:components.Action, agent_id)-> components.GameState:
- """
- Execute the action and update the values in the state
- Before this function it was checked if the action was successful
- So in here all actions were already successful.
-
- - actions_type: Define if the action is simulated in netsecenv or in the real world
- - agent_id: is the name or type of agent that requested the action
-
- Returns: A new GameState
- """
- next_state = None
- match action.type:
- case components.ActionType.ScanNetwork:
- next_state = self._execute_scan_network_action_real_world(current_state, action)
- case components.ActionType.FindServices:
- next_state = self._execute_find_services_action_real_world(current_state, action)
- case components.ActionType.FindData:
- # This Action type is not implemente in real world - use the simualtion
- next_state = self._execute_find_data_action(current_state, action)
- case components.ActionType.ExploitService:
- # This Action type is not implemente in real world - use the simualtion
- next_state = self._execute_exploit_service_action(current_state, action)
- case components.ActionType.ExfiltrateData:
- # This Action type is not implemente in real world - use the simualtion
- next_state = self._execute_exfiltrate_data_action(current_state, action)
- case components.ActionType.BlockIP:
- # This Action type is not implemente in real world - use the simualtion
- next_state = self._execute_block_ip_action(current_state, action)
- case _:
- raise ValueError(f"Unknown Action type or other error: '{action.type}'")
- return next_state
-
- def step(self, state:components.GameState, action:components.Action, agent_id:tuple)-> components.GameState:
- """
- Take a step in the environment given an action
- in: action
- out: observation of the state of the env
- """
- self.logger.info(f"Agent {agent_id}. Action: {action}")
- # Reward for taking an action
- reward = self._rewards["step"]
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="NetSecGame Coordinator Server (Real World) Author: Ondrej Lukas ondrej.lukas@aic.fel.cvut.cz; sebastian.garcia@agents.fel.cvut.cz",
+ usage="%(prog)s [options]",
+ )
+
+ parser.add_argument(
+ "-l",
+ "--debug_level",
+ help="Define the debug level for the logs. DEBUG, INFO, WARNING, ERROR, CRITICAL",
+ action="store",
+ required=False,
+ type=str,
+ default="INFO",
+ )
+
+ parser.add_argument(
+ "-gh",
+ "--game_host",
+ help="host where to run the game server",
+ action="store",
+ required=False,
+ type=str,
+ default="127.0.0.1",
+ )
+
+ parser.add_argument(
+ "-gp",
+ "--game_port",
+ help="Port where to run the game server",
+ action="store",
+ required=False,
+ type=int,
+ default="9000",
+ )
- # 1. Perform the action
- self._actions_played.append(action)
-
- # No randomness in action success - we are playing in real world
- next_state = self._execute_action(state, action, agent_id)
-
+ parser.add_argument(
+ "-c",
+ "--task_config",
+ help="File with the task configuration",
+ action="store",
+ required=True,
+ type=str,
+ default="netsecenv_conf.yaml",
+ )
-
- # Make the state we just got into, our current state
- current_state = state
- self.logger.info(f'New state: {next_state} ')
+ args = parser.parse_args()
+ print(args)
+ # Set the logging
+ log_filename = Path("NSG_real_world_coordinator.log")
+ if not log_filename.parent.exists():
+ os.makedirs(log_filename.parent)
+ # Convert the logging level in the args to the level to use
+ pass_level = get_logging_level(args.debug_level)
- # Save the transition to the episode replay buffer if there is any
- if self._episode_replay_buffer is not None:
- self._episode_replay_buffer.append((current_state, action, reward, next_state))
- # Return an observation
- return next_state
\ No newline at end of file
+ logging.basicConfig(
+ filename=log_filename,
+ filemode="w",
+ format="%(asctime)s %(name)s %(levelname)s %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=pass_level,
+ )
+
+ game_server = NSERealWorldGameCoordinator(args.game_host, args.game_port, args.task_config)
+ # Run it!
+ game_server.run()
\ No newline at end of file
diff --git a/utils/__init__.py b/AIDojoCoordinator/worlds/__init__.py
similarity index 100%
rename from utils/__init__.py
rename to AIDojoCoordinator/worlds/__init__.py
diff --git a/NetSecGameAgents b/NetSecGameAgents
index d15cf5c6..5c966fd7 160000
--- a/NetSecGameAgents
+++ b/NetSecGameAgents
@@ -1 +1 @@
-Subproject commit d15cf5c663c76479a7ce9845e56e46c28016c80d
+Subproject commit 5c966fd7f8560a2632dd4279392440219b4bb486
diff --git a/README.md b/README.md
index 619ddd48..c944e5c7 100755
--- a/README.md
+++ b/README.md
@@ -19,17 +19,17 @@ python -m venv ai-dojo-venv-
source ai-dojo-venv/bin/activate
```
-- Install the requirements with
+- Install using pip by running the following in the **root** directory:
```bash
-python3 -m pip install -r requirements.txt
+pip install -e .
```
- If you use conda use
```bash
conda create --name aidojo python==3.10
conda activate aidojo
-python3 -m pip install -r requirements.txt
+pip install -e .
```
## Architecture
@@ -149,8 +149,13 @@ This approach ensures that only repeated or excessive behavior is flagged, reduc
### Starting the game
-The environment should be created before starting the agents. The properties of the environment can be defined in a YAML file. The game server can be started by running:
-```python3 coordinator.py```
+The environment should be created before starting the agents. The properties of the game, the task, and the topology can be read either from a local file or via a REST request to the GameDashboard.
+
+#### To start the game with local configuration file
+```python3 -m AIDojoCoordinator.worlds.NSEGameCoordinator --task_config=<PATH_TO_TASK_CONFIG_FILE>```
+
+#### To start the game with remotely defined configuration
+```python3 -m AIDojoCoordinator.worlds.CYSTCoordinator --service_host=<SERVICE_HOST> --service_port=<SERVICE_PORT>```
When created, the environment:
1. reads the configuration file
diff --git a/coordinator.conf b/coordinator.conf
deleted file mode 100644
index 4758d9b6..00000000
--- a/coordinator.conf
+++ /dev/null
@@ -1,7 +0,0 @@
-{
-"host": "127.0.0.1",
-"port": 9000,
-"start_reward": 0,
-"max_steps": 1500,
-"world_type": "netsecenv"
-}
diff --git a/coordinator.py b/coordinator.py
deleted file mode 100644
index 31dc475b..00000000
--- a/coordinator.py
+++ /dev/null
@@ -1,950 +0,0 @@
-#!/usr/bin/env python
-# Server for the Aidojo project, coordinator
-# Author: sebastian garcia, sebastian.garcia@agents.fel.cvut.cz
-# Author: Ondrej Lukas, ondrej.lukas@aic.fel.cvut.cz
-import jsonlines
-import argparse
-import logging
-import json
-import asyncio
-import enum
-from datetime import datetime
-from env.worlds.network_security_game import NetworkSecurityEnvironment
-from env.worlds.network_security_game_real_world import NetworkSecurityEnvironmentRealWorld
-from env.worlds.aidojo_world import AIDojoWorld
-from env.game_components import Action, Observation, ActionType, GameStatus, GameState
-from utils.utils import observation_as_dict, get_logging_level, get_file_hash
-from pathlib import Path
-import os
-import signal
-from env.global_defender import stochastic_with_threshold
-from utils.utils import ConfigParser
-
-@enum.unique
-class AgentStatus(enum.Enum):
- """
- Class representing the current status for each agent connected to the coordinator
- """
- JoinRequested = 0
- Ready = 1
- Playing = 2
- PlayingActive = 3
- FinishedMaxSteps = 4
- FinishedBlocked = 5
- FinishedGoalReached = 6
- FinishedGameLost = 7
- ResetRequested = 8
- Quitting = 9
-
-
-class AIDojo:
- def __init__(self, host: str, port: int, net_sec_config: str, world_type) -> None:
- self.host = host
- self.port = port
- self.logger = logging.getLogger("AIDojo-main")
- self._agent_action_queue = asyncio.Queue()
- self._agent_response_queues = {}
- self._coordinator = Coordinator(
- self._agent_action_queue,
- self._agent_response_queues,
- net_sec_config,
- allowed_roles=["Attacker", "Defender", "Benign"],
- world_type = world_type,
- )
-
- async def create_agent_queue(self, addr):
- """
- Create a queue for the given agent address if it doesn't already exist.
- """
- if addr not in self._agent_response_queues:
- self._agent_response_queues[addr] = asyncio.Queue()
- self.logger.info(f"Created queue for agent {addr}. {len(self._agent_response_queues)} queues in total.")
-
- def run(self)->None:
- """
- Wrapper for ayncio run function. Starts all tasks in AIDojo
- """
- asyncio.run(self.start_tasks())
-
- async def start_tasks(self):
- """
- High level funciton to start all the other asynchronous tasks.
- - Reads the conf of the coordinator
- - Creates queues
- - Start the main part of the coordinator
- - Start a server that listens for agents
- """
- self.logger.info("Starting all tasks")
- loop = asyncio.get_running_loop()
-
- self.logger.info("Starting Coordinator taks")
- coordinator_task = asyncio.create_task(self._coordinator.run())
-
- self.logger.info("Starting the server listening for agents")
- running_server = await asyncio.start_server(
- ConnectionLimitProtocol(
- self._agent_action_queue,
- self._agent_response_queues,
- max_connections=2
- ),
- self.host,
- self.port
- )
- addrs = ", ".join(str(sock.getsockname()) for sock in running_server.sockets)
- self.logger.info(f"\tServing on {addrs}")
-
- # prepare the stopping event for keyboard interrupt
- stop = loop.create_future()
-
- # register the signal handler to the stopping event
- loop.add_signal_handler(signal.SIGINT, stop.set_result, None)
-
- await stop # Event that triggers stopping the AIDojo
- # Stop the server
- self.logger.info("Initializing server shutdown")
- running_server.close()
- await running_server.wait_closed()
- self.logger.info("\tServer stopped")
- # Stop coordinator taks
- self.logger.info("Initializing coordinator shutdown")
- coordinator_task.cancel()
- await asyncio.gather(coordinator_task, return_exceptions=True)
- self.logger.info("\tCoordinator stopped")
- # Everything stopped correctly, terminate
- self.logger.info("AIDojo terminating")
-
-class ConnectionLimitProtocol(asyncio.Protocol):
- def __init__(self, actions_queue, agent_response_queues, max_connections):
- self.actions_queue = actions_queue
- self.answers_queues = agent_response_queues
- self.max_connections = max_connections
- self.current_connections = 0
- self.logger = logging.getLogger("AIDojo-Server")
- self._stop = False
-
- def close(self)->None:
- self.logger.info(
- "Stopping server"
- )
- self._stop = True
-
- async def handle_new_agent(self, reader, writer):
- async def send_data_to_agent(writer, data: str) -> None:
- """
- Send the world to the agent
- """
- writer.write(bytes(str(data).encode()))
-
- # Check if the maximum number of connections has been reached
- if self.current_connections >= self.max_connections:
- self.logger.info(
- f"Max connections reached. Rejecting new connection from {writer.get_extra_info('peername')}"
- )
- writer.close()
- return
-
- # Increment the count of current connections
- self.current_connections += 1
-
- # Handle the new agent
- addr = writer.get_extra_info("peername")
- self.logger.info(f"New agent connected: {addr}")
- # Ensure a queue exists for this agent
- if addr not in self.answers_queues:
- self.answers_queues[addr] = asyncio.Queue(maxsize=2)
- self.logger.info(f"Created queue for agent {addr}")
-
- try:
- while not self._stop:
- # Step 1: Read data from the agent
- data = await reader.read(500)
- if not data:
- self.logger.info(f"Agent {addr} disconnected.")
- quit_message = Action(ActionType.QuitGame, params={}).as_json()
- await self.actions_queue.put((addr, quit_message))
- break
-
- raw_message = data.decode().strip()
- self.logger.debug(f"Handler received from {addr}: {raw_message}")
-
- # Step 2: Forward the message to the Coordinator
- await self.actions_queue.put((addr, raw_message))
- await asyncio.sleep(0)
- # Step 3: Get a matching response from the answers queue
- response_queue = self.answers_queues[addr]
- response = await response_queue.get()
- self.logger.info(f"Sending response to agent {addr}: {response}")
-
- # Step 4: Send the response to the agent
- writer.write(bytes(str(response).encode()))
- await writer.drain()
- except KeyboardInterrupt:
- self.logger.debug("Terminating by KeyboardInterrupt")
- raise SystemExit
- except Exception as e:
- self.logger.error(f"Exception in handle_new_agent(): {e}")
- finally:
- # Decrement the count of current connections
- self.current_connections -= 1
- if addr in self.answers_queues:
- self.answers_queues.pop(addr)
- self.logger.info(f"Removed queue for agent {addr}")
- else:
- self.logger.warning(f"Queue for agent {addr} not found during cleanup.")
- writer.close()
- return
-
- async def __call__(self, reader, writer):
- await self.handle_new_agent(reader, writer)
-
-class Coordinator:
- def __init__(self, actions_queue, answers_queues, net_sec_config, allowed_roles, world_type="netsecenv"):
- # communication channels for asyncio
- # agents -> coordinator
- self._actions_queue = actions_queue
- # coordinator -> agent (separate queue per agent)
- self._answers_queues = answers_queues
- # coordinator -> world
- self._world_action_queue = asyncio.Queue()
- # world -> coordinator
- self._world_response_queue = asyncio.Queue()
-
- self.task_config = ConfigParser(net_sec_config)
- self.ALLOWED_ROLES = allowed_roles
- self.logger = logging.getLogger("AIDojo-Coordinator")
-
- # world definition
- match world_type:
- case "netsecenv":
- self._world = NetworkSecurityEnvironment(net_sec_config,self._world_action_queue, self._world_response_queue)
- case "netsecenv-real-world":
- self._world = NetworkSecurityEnvironmentRealWorld(net_sec_config, self._world_action_queue, self._world_response_queue)
- case _:
- self._world = AIDojoWorld(net_sec_config, self._world_action_queue, self._world_response_queue)
- self.world_type = world_type
- self._CONFIG_FILE_HASH = get_file_hash(net_sec_config)
- self._starting_positions_per_role = self._get_starting_position_per_role()
- self._win_conditions_per_role = self._get_win_condition_per_role()
- self._goal_description_per_role = self._get_goal_description_per_role()
- self._steps_limit_per_role = self._get_max_steps_per_role()
- self._use_global_defender = self.task_config.get_use_global_defender()
-
- # player information
- self.agents = {}
- # step counter per agent_addr (int)
- self._agent_steps = {}
- # reset request per agent_addr (bool)
- self._reset_requests = {}
- self._agent_observations = {}
- # starting per agent_addr (dict)
- self._agent_starting_position = {}
- # current state per agent_addr (GameState)
- self._agent_states = {}
- # last action played by agent (Action)
- self._agent_last_action = {}
- # agent status dict {agent_addr: AgentStatus}
- self._agent_statuses = {}
- # agent status dict {agent_addr: int}
- self._agent_rewards = {}
- # trajectories per agent_addr
- self._agent_trajectories = {}
-
- @property
- def episode_end(self)->bool:
- # Episode ends ONLY IF all agents with defined max_steps reached the end fo the episode
- exists_active_player = any(status is AgentStatus.PlayingActive for status in self._agent_statuses.values())
- self.logger.debug(f"End evaluation: {self._agent_statuses.items()} - Episode end:{not exists_active_player}")
- return not exists_active_player
-
- @property
- def config_file_hash(self):
- return self._CONFIG_FILE_HASH
-
- def convert_msg_dict_to_json(self, msg_dict)->str:
- try:
- # Convert message into string representation
- output_message = json.dumps(msg_dict)
- except Exception as e:
- self.logger.error(f"Error when converting msg to Json:{e}")
- raise e
- # Send to anwer_queue
- return output_message
-
- async def run(self):
- """
- Main method to be run for coordinating the agent's interaction with the game engine.
- - Reads messages from action queue
- - processes actions based on their type
- - Forwards actions in the game engine
- - Evaluates states and checks for goal
- - Forwards responses the agents
- """
- try:
- self.logger.info("Main coordinator started.")
- # Start World Response handler task
- world_response_task = asyncio.create_task(self._handle_world_responses())
- world_processing_task = asyncio.create_task(self._world.handle_incoming_action())
- while True:
- # Read message from the queue
- agent_addr, message = await self._actions_queue.get()
- if message is not None:
- self.logger.debug(f"Coordinator received: {message}.")
- try: # Convert message to Action
- action = Action.from_json(message)
- except Exception as e:
- self.logger.error(
- f"Error when converting msg to Action using Action.from_json():{e}, {message}"
- )
- match action.type: # process action based on its type
- case ActionType.JoinGame:
- self.logger.debug(f"Start processing of ActionType.JoinGame by {agent_addr}")
- await self._process_join_game_action(agent_addr, action)
- case ActionType.QuitGame:
- self.logger.info(f"Coordinator received from QUIT message from agent {agent_addr}")
- # update agent status
- self._agent_statuses[agent_addr] = AgentStatus.Quitting
- # forward the message to the world
- await self._world_action_queue.put((agent_addr, action, self._agent_states[agent_addr]))
- case ActionType.ResetGame:
- self._reset_requests[agent_addr] = True
- self._agent_statuses[agent_addr] = AgentStatus.ResetRequested
- self.logger.info(f"Coordinator received from RESET request from agent {agent_addr} ({self.agents[agent_addr]})")
- if all(self._reset_requests.values()):
- # should we discard the queue here?
- self.logger.info("All active agents requested reset")
- # send WORLD reset request to the world
- await self._world_action_queue.put(("world", Action(ActionType.ResetGame, params={}), None))
- # send request for each of the agents (to get new initial state)
- for agent in self._reset_requests:
- await self._world_action_queue.put((agent, Action(ActionType.ResetGame, params={}), self._agent_starting_position[agent]))
- else:
- self.logger.info("\t Waiting for other agents to request reset")
- case _:
- self.logger.debug(f"Agent {self.agents[agent_addr]}, played Action {action}.")
- # actions in the game
- await self._process_generic_action(agent_addr, action)
- await asyncio.sleep(0)
- except asyncio.CancelledError:
- world_response_task.cancel()
- world_processing_task.cancel()
- asyncio.gather(world_processing_task, world_response_task, return_exceptions=True)
- self.logger.info("\tTerminating by CancelledError")
- except Exception as e:
- self.logger.error(f"Exception in Class coordinator(): {e}")
- raise e
-
- def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState) -> Observation:
- """
- Method to initialize new player upon joining the game.
- Returns initial observation for the agent based on the agent's role
- """
- self.logger.info(f"\tInitializing new player{agent_addr}")
- agent_name, agent_role = self.agents[agent_addr]
- self._agent_steps[agent_addr] = 0
- self._reset_requests[agent_addr] = False
- self._agent_starting_position[agent_addr] = self._starting_positions_per_role[agent_role]
- self._agent_statuses[agent_addr] = AgentStatus.PlayingActive if agent_role == "Attacker" else AgentStatus.Playing
- self._agent_states[agent_addr] = agent_current_state
-
-
- #self._agent_states[agent_addr] = self._world.create_state_from_view(self._agent_starting_position[agent_addr])
-
- # if self._steps_limit_per_role[agent_role]:
- # # This agent can force episode end (has timeout and goal defined)
- # self._agent_statuses[agent_addr] = AgentStatus.PlayingActive
- # else:
- # # This agent can NOT force episode end (does NOT timeout or goal defined)
- # self._agent_statuses[agent_addr] = AgentStatus.Playing
-
- if self.task_config.get_store_trajectories() or self._use_global_defender:
- self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
- self.logger.info(f"\tAgent {agent_name} ({agent_addr}), registred as {agent_role}")
- # Initializeing the player, also sets the end_episode to false for the Observation of the agent
- end_episode = False
- return Observation(self._agent_states[agent_addr], 0, end_episode, {})
-
- def _remove_player(self, agent_addr:tuple)->dict:
- """
- Removes player from the game. Should be called AFTER QuitGame action was processed by the world.
- """
- self.logger.info(f"Removing player {agent_addr} from the Coordinator")
- agent_info = {}
- if agent_addr in self.agents:
- agent_info["state"] = self._agent_states.pop(agent_addr)
- agent_info["status"] = self._agent_statuses.pop(agent_addr)
- agent_info["num_steps"] = self._agent_steps.pop(agent_addr)
- agent_info["reset_request"] = self._reset_requests.pop(agent_addr)
- agent_info["end_reward"] = self._agent_rewards.pop(agent_addr, None)
- agent_info["agent_info"] = self.agents.pop(agent_addr)
- self.logger.debug(f"\t{agent_info}")
- else:
- self.logger.info(f"\t Player {agent_addr} not present in the game!")
- return agent_info
-
- def _get_starting_position_per_role(self)->dict:
- """
- Method for finding starting position for each agent role in the game.
- """
- starting_positions = {}
- for agent_role in self.ALLOWED_ROLES:
- try:
- starting_positions[agent_role] = self.task_config.get_start_position(agent_role=agent_role)
- self.logger.info(f"Starting position for role '{agent_role}': {starting_positions[agent_role]}")
- except KeyError:
- starting_positions[agent_role] = {}
- return starting_positions
-
- def _get_win_condition_per_role(self)-> dict:
- """
- Method for finding wininng conditions for each agent role in the game.
- """
- win_conditions = {}
- for agent_role in self.ALLOWED_ROLES:
- try:
- win_conditions[agent_role] = self._world.update_goal_dict(
- self.task_config.get_win_conditions(agent_role=agent_role)
- )
- except KeyError:
- win_conditions[agent_role] = {}
- self.logger.info(f"Win condition for role '{agent_role}': {win_conditions[agent_role]}")
- return win_conditions
-
- def _get_goal_description_per_role(self)->dict:
- """
- Method for finding goal description for each agent role in the game.
- """
- goal_descriptions ={}
- for agent_role in self.ALLOWED_ROLES:
- try:
- goal_descriptions[agent_role] = self._world.update_goal_descriptions(
- self.task_config.get_goal_description(agent_role=agent_role)
- )
- except KeyError:
- goal_descriptions[agent_role] = ""
- self.logger.info(f"Goal description for role '{agent_role}': {goal_descriptions[agent_role]}")
- return goal_descriptions
-
- def _get_max_steps_per_role(self)->dict:
- """
- Method for finding max amount of steps in 1 episode for each agent role in the game.
- """
- max_steps = {}
- for agent_role in self.ALLOWED_ROLES:
- try:
- max_steps[agent_role] = self.task_config.get_max_steps(agent_role)
- except KeyError:
- max_steps[agent_role] = None
- self.logger.info(f"Max steps in episode for '{agent_role}': {max_steps[agent_role]}")
- return max_steps
-
- async def _process_join_game_action(self, agent_addr: tuple, action: Action)->None:
- """ "
- Method for processing Action of type ActionType.JoinGame
- """
- self.logger.info(f"New Join request by {agent_addr}.")
- if agent_addr not in self.agents:
- agent_name = action.parameters["agent_info"].name
- agent_role = action.parameters["agent_info"].role
- if agent_role in self.ALLOWED_ROLES:
- self.agents[agent_addr] = (agent_name, agent_role)
- self._agent_statuses[agent_addr] = AgentStatus.JoinRequested
- self.logger.debug(f"Sending JoinRequest by {agent_addr} to the world_action_queue")
- await self._world_action_queue.put((agent_addr, action, self._starting_positions_per_role[agent_role]))
- else:
- self.logger.info(
- f"\tError in registration, unknown agent role: {agent_role}!"
- )
- output_message_dict = {
- "to_agent": agent_addr,
- "status": str(GameStatus.BAD_REQUEST),
- "message": f"Incorrect agent_role {agent_role}",
- }
- response_msg_json = self.convert_msg_dict_to_json(output_message_dict)
- await self._answers_queues[agent_addr].put(response_msg_json)
- else:
- self.logger.info("\tError in registration, agent already exists!")
- output_message_dict = {
- "to_agent": agent_addr,
- "status": str(GameStatus.BAD_REQUEST),
- "message": "Agent already exists.",
- }
- response_msg_json = self.convert_msg_dict_to_json(output_message_dict)
- await self._answers_queues[agent_addr].put(response_msg_json)
-
- def _create_response_to_reset_game_action(self, agent_addr: tuple) -> dict:
- """ "
- Method for generatating answers to Action of type ActionType.ResetGame after all agents requested reset
- """
- self.logger.info(
- f"Coordinator responding to RESET request from agent {agent_addr}"
- )
- # store trajectory in file if needed
- if self.task_config.get_store_trajectories():
- self._store_trajectory_to_file(agent_addr)
- new_observation = Observation(self._agent_states[agent_addr], 0, self.episode_end, {})
- # reset trajectory
- self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
- output_message_dict = {
- "to_agent": agent_addr,
- "status": str(GameStatus.RESET_DONE),
- "observation": observation_as_dict(new_observation),
- "message": {
- "message": "Resetting Game and starting again.",
- "max_steps": self._steps_limit_per_role[self.agents[agent_addr][1]],
- "goal_description": self._goal_description_per_role[self.agents[agent_addr][1]],
- "configuration_hash": self._CONFIG_FILE_HASH
- },
- }
- return output_message_dict
-
- def _add_step_to_trajectory(self, agent_addr:tuple, action:Action, reward:float, next_state:GameState, end_reason:str)-> None:
- """
- Method for adding one step to the agent trajectory.
- """
- if agent_addr in self._agent_trajectories:
- self.logger.debug(f"Adding step to trajectory of {agent_addr}")
- self._agent_trajectories[agent_addr]["trajectory"]["actions"].append(action.as_dict)
- self._agent_trajectories[agent_addr]["trajectory"]["rewards"].append(reward)
- self._agent_trajectories[agent_addr]["trajectory"]["states"].append(next_state.as_dict)
- if end_reason:
- self._agent_trajectories[agent_addr]["end_reason"] = end_reason
-
- def _store_trajectory_to_file(self, agent_addr:tuple, location="./trajectories")-> None:
- self.logger.debug(f"Storing Trajectory of {agent_addr}in file")
- if agent_addr in self._agent_trajectories:
- agent_name, agent_role = self.agents[agent_addr]
- filename = os.path.join(location, f"{datetime.now():%Y-%m-%d}_{agent_name}_{agent_role}.jsonl")
- with jsonlines.open(filename, "a") as writer:
- writer.write(self._agent_trajectories[agent_addr])
- self.logger.info(f"Trajectory of {agent_addr} strored in {filename}")
-
- def _reset_trajectory(self,agent_addr)->dict:
- agent_name, agent_role = self.agents[agent_addr]
- self.logger.debug(f"Resetting trajectory of {agent_addr}")
- return {
- "trajectory":{
- "states":[self._agent_states[agent_addr].as_dict],
- "actions":[],
- "rewards":[],
- },
- "end_reason":None,
- "agent_role":agent_role,
- "agent_name":agent_name
- }
-
- async def _process_generic_action(self, agent_addr: tuple, action: Action) ->None:
- """
- Method processing the Actions relevant to the environment
- """
- self.logger.info(f"Processing {action} from {agent_addr}")
- if not self.episode_end:
- self._agent_last_action[agent_addr] = action
- await self._world_action_queue.put((agent_addr, action, self._agent_states[agent_addr]))
- else:
- # Episode finished, just send back the rewards and final episode info
- self._assign_end_rewards()
- self.logger.info(f"{self.episode_end}, {self._agent_statuses[agent_addr]}")
- output_message_dict = self._generate_episode_end_message(agent_addr)
- response_msg_json = self.convert_msg_dict_to_json(output_message_dict)
- await self._answers_queues[agent_addr].put(response_msg_json)
-
- def _generate_episode_end_message(self, agent_addr:tuple)->dict:
- """
- Method for generating response when agent attemps to make a step after episode ended.
- """
- # There is a case when a defender agent connects first and is alone that it doesnt receive an observation because the game may not have started.
- current_observation = self._agent_observations.get(agent_addr, Observation(self._agent_states[agent_addr], 0, True, {}))
- reward = self._agent_rewards[agent_addr]
- end_reason = str(self._agent_statuses[agent_addr])
- new_observation = Observation(
- current_observation.state,
- reward=reward,
- end=True,
- info={'end_reason': end_reason, "info":"Episode ended. Request reset for starting new episode."})
- output_message_dict = {
- "to_agent": agent_addr,
- "observation": observation_as_dict(new_observation),
- "status": str(GameStatus.FORBIDDEN),
- }
- return output_message_dict
-
- def _goal_reached(self, agent_addr:tuple)->bool:
- """
- Determines if and agent reached a goal state
- """
- self.logger.info(f"Goal check for {agent_addr}({self.agents[agent_addr][1]})")
- agents_state = self._agent_states[agent_addr]
- agent_role = self.agents[agent_addr][1]
- win_condition = self._world.update_goal_dict(self._win_conditions_per_role[agent_role])
- goal_check = self._check_goal(agents_state, win_condition)
- if goal_check:
- self.logger.info("\tGoal reached!")
- else:
- self.logger.info("\tGoal not reached!")
- return goal_check
-
- def _check_goal(self, state:GameState, goal_conditions:dict)->bool:
- """
- Check if the goal conditons were satisfied in a given game state
- """
- def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool:
- """
- Helper function for checking if a goal dictionary condition is satisfied
- """
- # check if we have all IPs that should have some values (are keys in goal_dict)
- if goal_dict.keys() <= known_dict.keys():
- try:
- # Check if values (sets) for EACH key (host) in goal_dict are subsets of known_dict, keep matching_keys
- matching_keys = [host for host in goal_dict.keys() if goal_dict[host]<= known_dict[host]]
- # Check we have the amount of mathing keys as in the goal_dict
- if len(matching_keys) == len(goal_dict.keys()):
- return True
- except KeyError:
- # some keys are missing in the known_dict
- return False
- return False
-
- # For each part of the state of the game, check if the conditions are met
- goal_reached = {}
- goal_reached["networks"] = set(goal_conditions["known_networks"]) <= set(state.known_networks)
- goal_reached["known_hosts"] = set(goal_conditions["known_hosts"]) <= set(state.known_hosts)
- goal_reached["controlled_hosts"] = set(goal_conditions["controlled_hosts"]) <= set(state.controlled_hosts)
- goal_reached["services"] = goal_dict_satistfied(goal_conditions["known_services"], state.known_services)
- goal_reached["data"] = goal_dict_satistfied(goal_conditions["known_data"], state.known_data)
- goal_reached["known_blocks"] = goal_dict_satistfied(goal_conditions["known_blocks"], state.known_blocks)
- self.logger.debug(f"\t{goal_reached}")
- return all(goal_reached.values())
-
- def _check_detection(self, agent_addr:tuple, last_action:Action)->bool:
- self.logger.info(f"Detection check for {agent_addr}({self.agents[agent_addr][1]})")
- detection = False
- if last_action:
- if self._use_global_defender:
- self.logger.warning("Global defender - ONLY use for backward compatibility!")
- episode_actions = None
- if (agent_addr in self._agent_trajectories and "trajectory" in self._agent_trajectories[agent_addr] and "actions" in self._agent_trajectories[agent_addr]["trajectory"]):
- episode_actions = self._agent_trajectories[agent_addr]["trajectory"]["actions"]
- detection = stochastic_with_threshold(last_action, episode_actions)
- if detection:
- self.logger.info("\tDetected!")
- else:
- self.logger.info("\tNot detected!")
- return detection
-
- def _max_steps_reached(self, agent_addr:tuple) ->bool:
- """
- Checks if the agent reached the max allowed steps. Only applies to role 'Attacker'
- """
- self.logger.debug(f"Checking timout for {self.agents[agent_addr]}")
- agent_role = self.agents[agent_addr][1]
- if self._steps_limit_per_role[agent_role]:
- if self._agent_steps[agent_addr] >= self._steps_limit_per_role[agent_role]:
- self.logger.info(f"Timeout reached by {self.agents[agent_addr]}!")
- return True
- else:
- self.logger.debug(f"No max steps defined for role {agent_role}")
- return False
-
- def _assign_end_rewards(self)->None:
- """
- Method which assings rewards to each agent which has finished playing
- """
- self.logger.debug("Assigning rewards")
- is_episode_over = self.episode_end
- for agent, status in self._agent_statuses.items():
- if agent not in self._agent_rewards.keys(): # reward has not been assigned yet
- agent_name, agent_role = self.agents[agent]
- if agent_role == "Attacker":
- match status:
- case AgentStatus.FinishedGoalReached:
- self._agent_rewards[agent] = self._world._rewards["goal"]
- case AgentStatus.FinishedMaxSteps:
- self._agent_rewards[agent] = 0
- case AgentStatus.FinishedBlocked:
- self._agent_rewards[agent] = self._world._rewards["detection"]
- self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}")
- elif agent_role == "Defender":
- if self._agent_statuses[agent] is AgentStatus.FinishedMaxSteps: #defender was responsible for the end
- raise NotImplementedError
- self._agent_rewards[agent] = 0
- else:
- if is_episode_over: #only assign defender's reward when episode ends
- sucessful_attacks = list(self._agent_statuses.values()).count("goal_reached")
- if sucessful_attacks > 0:
- self._agent_rewards[agent] = sucessful_attacks*self._world._rewards["detection"]
- self._agent_statuses[agent] = "game_lost"
- else: #no successful attacker
- self._agent_rewards[agent] = self._world._rewards["goal"]
- self._agent_statuses[agent] = "goal_reached"
- self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}")
- else:
- if is_episode_over:
- self._agent_rewards[agent] = 0
- self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}")
-
- async def _handle_world_responses(self)-> None:
- """
- Continuously processes responses from the AIDojo World, evaluates them and sends messages to agents
- """
- try:
- self.logger.info("\tStarting task to handle AIDojo World responses")
- while True:
- try:
- # Get a response from the World Response Queue
- agent_id, response = await self._world_response_queue.get()
- self.logger.info(f"Received response for agent {agent_id}: {response}")
-
- # Processing of the response
- response_msg_json = self._process_world_response(agent_id, response)
- # Notify the agent if there is message
- if len(response_msg_json) > 2: # we have NON EMPTY JSON (len('{}') = 2)
- self.logger.info(f"Generated response for agent {agent_id} ({self.agents[agent_id]}): {response_msg_json}")
- await self._answers_queues[agent_id].put(response_msg_json)
- self.logger.info(f"Placed response in answers queue for agent {agent_id}")
- else:
- self.logger.info(f"Empty response for agent {agent_id}: {response_msg_json}. Skipping")
- await asyncio.sleep(0)
- except Exception as e:
- self.logger.error(f"Error handling world response: {e}")
- except asyncio.CancelledError:
- self.logger.info("\tTerminating by CancelledError")
-
- def _process_world_response(self, agent_addr:tuple, response:tuple)-> str:
- """
- Method for generation of messages to the agent based on the world response
- """
- agent_new_state, game_status = response
- output_message_dict = {}
- try:
- agent_status = self._agent_statuses[agent_addr]
- if agent_status is AgentStatus.JoinRequested:
- output_message_dict = self._process_world_response_created(agent_addr, game_status, agent_new_state)
- elif agent_status is AgentStatus.ResetRequested:
- output_message_dict = self._process_world_response_reset_done(agent_addr, game_status, agent_new_state)
- elif agent_status is AgentStatus.Quitting:
- if game_status is GameStatus.OK:
- self.logger.debug(f"Agent {agent_addr} removed successfuly from the world")
- else:
- self.logger.warning(f"Error when removing Agent {agent_addr} from the world")
- self._remove_player(agent_addr)
- elif agent_status in [AgentStatus.Ready, AgentStatus.Playing, AgentStatus.PlayingActive]:
- output_message_dict = self._process_world_response_step(agent_addr, game_status, agent_new_state)
- elif agent_status in [AgentStatus.FinishedBlocked, AgentStatus.FinishedGameLost, AgentStatus.FinishedGoalReached, AgentStatus.FinishedMaxSteps]: # This if does not make sense. Put together with the previous (sebas)
- output_message_dict = self._process_world_response_step(agent_addr, game_status, agent_new_state)
- else:
- self.logger.error(f"Unsupported value '{agent_status}'!")
-
- msg_json = self.convert_msg_dict_to_json(output_message_dict)
- return msg_json
- except KeyError as e :
- self.logger.error(f"Agent {agent_addr} not found! {e}")
-
- def _process_world_response_created(self, agent_addr:tuple, game_status:GameStatus, new_agent_game_state:GameState)->dict:
- """
- Handles reply to Action.JoinGame for agent based on the response of the AIDojo World
- """
- # is agent correctly started in the world
- if game_status is GameStatus.CREATED:
- observation = self._initialize_new_player(agent_addr, new_agent_game_state)
- agent_name, agent_role = self.agents[agent_addr]
- output_message_dict = {
- "to_agent": agent_addr,
- "status": str(game_status),
- "observation": observation_as_dict(observation),
- "message": {
- "message": f"Welcome {agent_name}, registred as {agent_role}",
- "max_steps": self._steps_limit_per_role[agent_role],
- "goal_description": self._goal_description_per_role[agent_role],
- "actions": [str(a) for a in ActionType],
- "configuration_hash": self._CONFIG_FILE_HASH
- },
- }
- else:
- # remove traces of agent from the game
- self._remove_player(agent_addr)
- output_message_dict = {
- "to_agent": agent_addr,
- "status": str(game_status),
- "message": f"Error when initializing the agent {agent_name}({agent_role})",
- }
- return output_message_dict
-
- def _process_world_response_reset_done(self, agent_addr:tuple, game_status:GameStatus, agent_new_state:GameState)->dict:
- """
- Handles reply to Action.JoinGame for agent based on the response of the AIDojo World
- """
- if game_status is GameStatus.RESET_DONE:
- self._reset_requests[agent_addr] = False
- self._agent_steps[agent_addr] = 0
- self._agent_states[agent_addr] = agent_new_state
- self._agent_rewards.pop(agent_addr, None)
- if self._steps_limit_per_role[self.agents[agent_addr][1]]:
- # This agent can force episode end (has timeout and goal defined)
- self._agent_statuses[agent_addr] = AgentStatus.PlayingActive
- else:
- # This agent can NOT force episode end (does NOT timeout or goal defined)
- self._agent_statuses[agent_addr] = AgentStatus.Playing
- output_message_dict = self._create_response_to_reset_game_action(agent_addr)
- else:
- # remove traces of agent from the game
- agent_name, agent_role = self.agents
- self._remove_player(agent_addr)
- output_message_dict = {
- "to_agent": agent_addr,
- "status": str(game_status),
- "message": f"Error when resetting the agent {agent_name} ({agent_role})",
- }
- return output_message_dict
-
- def _process_world_response_step(self, agent_addr:tuple, game_status:GameStatus, agent_new_state:GameState)->dict:
- """
- Handles reply for agent based on the response of the AIDojo World. Covers followinf ActionTypes:
- - ActionType.ScanNetwork
- - ActionType.FindServices
- - ActionType.FindData
- - ActionType.ExfiltrateData
- - ActionType.ExploitService
- - ActionType.BlockIP
- """
- if game_status is GameStatus.OK:
- if not self.episode_end:
- # increase the action counter
- self._agent_steps[agent_addr] += 1
- self.logger.info(f"Agent {agent_addr} ({self.agents[agent_addr]}) did #steps: {self._agent_steps[agent_addr]}")
- # register the new state
- self._agent_states[agent_addr] = agent_new_state
- # load the action which lead to the new state
- last_action = self._agent_last_action[agent_addr]
- # check timeout
- if self._max_steps_reached(agent_addr):
- self._agent_statuses[agent_addr] = AgentStatus.FinishedMaxSteps
- # check detection
- if self._check_detection(agent_addr, last_action):
- self._agent_statuses[agent_addr] = AgentStatus.FinishedBlocked
- # check goal
- if self._goal_reached(agent_addr):
- self._agent_statuses[agent_addr] = AgentStatus.FinishedGoalReached
- # add reward for taking a step
- reward = self._world._rewards["step"]
-
- obs_info = {}
- end_reason = None
- if self._agent_statuses[agent_addr] is AgentStatus.FinishedGoalReached:
- self._assign_end_rewards()
- reward += self._agent_rewards[agent_addr]
- end_reason = "goal_reached"
- obs_info = {'end_reason': "goal_reached"}
- elif self._agent_statuses[agent_addr] is AgentStatus.FinishedMaxSteps:
- self._assign_end_rewards()
- reward += self._agent_rewards[agent_addr]
- obs_info = {"end_reason": "max_steps"}
- end_reason = "max_steps"
- elif self._agent_statuses[agent_addr] is AgentStatus.FinishedBlocked:
- self._assign_end_rewards()
- reward += self._agent_rewards[agent_addr]
- obs_info = {"end_reason": "blocked"}
-
- # record step in trajecory
- self._add_step_to_trajectory(agent_addr, last_action, reward,self._agent_states[agent_addr], end_reason)
- new_observation = Observation(self._agent_states[agent_addr], reward, self.episode_end, info=obs_info)
-
- self._agent_observations[agent_addr] = new_observation
-
- output_message_dict = {
- "to_agent": agent_addr,
- "observation": observation_as_dict(new_observation),
- "status": str(GameStatus.OK),
- }
- else:
- self._assign_end_rewards()
- output_message_dict = self._generate_episode_end_message(agent_addr)
- else:
- output_message_dict = {
- "to_agent": agent_addr,
- "status": str(game_status),
- "message": f"Error when playing action {last_action}",
- }
- return output_message_dict
-
-__version__ = "v0.2.2"
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(
- description=f"NetSecGame Coordinator Server version {__version__}. Author: Sebastian Garcia, sebastian.garcia@agents.fel.cvut.cz",
- usage="%(prog)s [options]",
- )
- parser.add_argument(
- "-v",
- "--verbose",
- help="Verbosity level. This shows more info about the results.",
- action="store",
- required=False,
- type=int,
- )
- parser.add_argument(
- "-c",
- "--configfile",
- help="Configuration file.",
- action="store",
- required=False,
- type=str,
- default="coordinator.conf",
- )
- parser.add_argument(
- "-t",
- "--task_config",
- help="Task configuration file.",
- action="store",
- required=False,
- type=str,
- default="env/netsecenv_conf.yaml",
- )
- parser.add_argument(
- "-l",
- "--debug_level",
- help="Define the debug level for the logs. DEBUG, INFO, WARNING, ERROR, CRITICAL",
- action="store",
- required=False,
- type=str,
- default="DEBUG",
- )
-
- args = parser.parse_args()
- print(args)
- # Set the logging
- log_filename = Path("coordinator.log")
- if not log_filename.parent.exists():
- os.makedirs(log_filename.parent)
-
- # Convert the logging level in the args to the level to use
- pass_level = get_logging_level(args.debug_level)
-
- logging.basicConfig(
- filename=log_filename,
- filemode="w",
- format="%(asctime)s %(name)s %(levelname)s %(message)s",
- datefmt="%Y-%m-%d %H:%M:%S",
- level=pass_level,
- )
-
- # load config for coordinator
- with open(args.configfile, "r") as jfile:
- confjson = json.load(jfile)
-
- host = confjson.get("host", None)
- port = confjson.get("port", None)
- world_type = confjson.get('world_type', 'netsecgame')
-
- # prioritize task config from CLI
- if args.task_config:
- task_config_file = args.task_config
- else:
- # Try to use task config from coordinator.conf
- task_config_file = confjson.get("task_config", None)
- if task_config_file is None:
- raise KeyError("Task configuration must be provided to start the coordinator! Use -h for more details.")
- # Create AI Dojo
- ai_dojo = AIDojo(host, port, task_config_file, world_type)
- # Run it!
- ai_dojo.run()
\ No newline at end of file
diff --git a/env/global_defender.py b/env/global_defender.py
deleted file mode 100644
index aba2da41..00000000
--- a/env/global_defender.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Author: Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz
-from itertools import groupby
-from .game_components import ActionType, Action
-from random import random
-
-
-# The probability of detecting an action is defined by the following dictionary
-DEFAULT_DETECTION_PROBS = {
- ActionType.ScanNetwork: 0.05,
- ActionType.FindServices: 0.075,
- ActionType.ExploitService: 0.1,
- ActionType.FindData: 0.025,
- ActionType.ExfiltrateData: 0.025,
- ActionType.BlockIP: 0.01
-}
-
-# Ratios of action types in the time window (TW) for each action type. The ratio should be higher than the defined value to trigger a detection check
-TW_TYPE_RATIOS_THRESHOLD = {
- ActionType.ScanNetwork: 0.25,
- ActionType.FindServices: 0.3,
- ActionType.ExploitService: 0.25,
- ActionType.FindData: 0.5,
- ActionType.ExfiltrateData: 0.25,
- ActionType.BlockIP: 1
-}
-
-# Thresholds for consecutive actions of the same type in the TW. Only if the threshold is crossed, the detection check is triggered
-TW_CONSECUTIVE_TYPE_THRESHOLD = {
- ActionType.ScanNetwork: 2,
- ActionType.FindServices: 3,
- ActionType.ExfiltrateData: 2
-}
-
-# Thresholds for repeated actions in the episode. Only if the threshold is crossed, the detection check is triggered
-EPISODE_REPEATED_ACTION_THRESHOLD = {
- ActionType.ExploitService: 2,
- ActionType.FindData: 2,
-}
-
-def stochastic(action_type:ActionType)->bool:
- """
- Simple random detection based on predefied probability and ActionType
- """
- roll = random()
- if roll < DEFAULT_DETECTION_PROBS[action_type]:
- return True
- else:
- return False
-
-def stochastic_with_threshold(action: Action, episode_actions:list, tw_size:int=5)-> bool:
- """
- Only detect based on set probabilities if pre-defined thresholds are crossed.
- """
- # extend the episode with the latest action
- # We need to copy the list before the copying, so we avoid modifying it when it is returned. Modifycation of passed list is the default behavior in Python
- temp_episode_actions = episode_actions.copy()
- temp_episode_actions.append(action.as_dict)
- if len(temp_episode_actions) >= tw_size:
- last_n_actions = temp_episode_actions[-tw_size:]
- last_n_action_types = [action['type'] for action in last_n_actions]
- # compute ratio of action type in the TW
- tw_ratio = last_n_action_types.count(str(action.type))/tw_size
- # Count how many times this exact (parametrized) action was played in episode
- repeats_in_episode = temp_episode_actions.count(action.as_dict)
- # compute the highest consecutive number of action type in TW
- max_consecutive_action_type = max(sum(1 for item in grouped if item == str(action.type))
- for _, grouped in groupby(last_n_action_types))
-
- if action.type in TW_CONSECUTIVE_TYPE_THRESHOLD.keys():
- # ScanNetwork, FindServices, ExfiltrateData
- if tw_ratio < TW_TYPE_RATIOS_THRESHOLD[action.type] and max_consecutive_action_type < TW_CONSECUTIVE_TYPE_THRESHOLD[action.type]:
- return False
- else:
- return stochastic(action.type)
- elif action.type in EPISODE_REPEATED_ACTION_THRESHOLD.keys():
- # FindData, Exploit service
- if tw_ratio < TW_TYPE_RATIOS_THRESHOLD[action.type] and repeats_in_episode < EPISODE_REPEATED_ACTION_THRESHOLD[action.type]:
- return False
- else:
- return stochastic(action.type)
- else: #Other actions - Do not detect
- return False
-
- else:
- return False
diff --git a/env/scenarios/tiny_scenario_configuration.py b/env/scenarios/tiny_scenario_configuration.py
deleted file mode 100644
index c3e9f4dd..00000000
--- a/env/scenarios/tiny_scenario_configuration.py
+++ /dev/null
@@ -1,256 +0,0 @@
-import cyst.api.configuration as cyst_cfg
-from cyst.api.logic.access import AuthenticationProviderType, AuthenticationTokenType, AuthenticationTokenSecurity
-
-''' --------------------------------------------------------------------------------------------------------------------
-A template for local password authentication.
-'''
-local_password_auth = cyst_cfg.AuthenticationProviderConfig(
- provider_type=AuthenticationProviderType.LOCAL,
- token_type=AuthenticationTokenType.PASSWORD,
- token_security=AuthenticationTokenSecurity.SEALED,
- timeout=30
-)
-
-''' --------------------------------------------------------------------------------------------------------------------
-Server 1:
-- SMB/File sharing (It is vulnerable to some remote exploit)
-- Remote Desktop
-- Can go to router and internet
-
-- the only windows server. It does not connect to the AD
-- access schemes for remote desktop and file sharing are kept separate, but can be integrated into one if needed
-'''
-smb_server = cyst_cfg.NodeConfig(
- active_services=[],
- passive_services=[
- cyst_cfg.PassiveServiceConfig(
- type="lanman server",
- owner="Local system",
- version="10.0.19041",
- local=False,
- private_data=[
- cyst_cfg.DataConfig(
- owner="User1",
- description="DataFromServer1"
- )
- ],
- access_level=cyst_cfg.AccessLevel.LIMITED,
- authentication_providers=[],
- access_schemes=[
- cyst_cfg.AccessSchemeConfig(
- authentication_providers=["windows login"],
- authorization_domain=cyst_cfg.AuthorizationDomainConfig(
- type=cyst_cfg.AuthorizationDomainType.LOCAL,
- authorizations=[
- cyst_cfg.AuthorizationConfig("User1", cyst_cfg.AccessLevel.LIMITED),
- cyst_cfg.AuthorizationConfig("User2", cyst_cfg.AccessLevel.LIMITED),
- cyst_cfg.AuthorizationConfig("User3", cyst_cfg.AccessLevel.LIMITED),
- cyst_cfg.AuthorizationConfig("User4", cyst_cfg.AccessLevel.LIMITED),
- cyst_cfg.AuthorizationConfig("User5", cyst_cfg.AccessLevel.LIMITED),
- cyst_cfg.AuthorizationConfig("Administrator", cyst_cfg.AccessLevel.ELEVATED)
- ]
- )
- )
- ]
- ),
- cyst_cfg.PassiveServiceConfig(
- type="windows login",
- owner="Administrator",
- version="10.0.19041",
- local=True,
- access_level=cyst_cfg.AccessLevel.ELEVATED,
- authentication_providers=[local_password_auth("windows login")]
- ),
- cyst_cfg.PassiveServiceConfig(
- type="powershell",
- owner="Local system",
- version="10.0.19041",
- local=True,
- access_level=cyst_cfg.AccessLevel.LIMITED
- )
- ],
- traffic_processors=[],
- interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("192.168.1.2"), cyst_cfg.IPNetwork("192.168.1.0/24"))],
- shell="powershell",
- id="smb_server"
-)
-
-
-''' --------------------------------------------------------------------------------------------------------------------
-Client 1
-
-- Remote Desktop
-- Accounts
--- Local admin
--- User1
-- Can go to server 1, 2, 3, router and internet
-- Has the attacker
-'''
-client_1 = cyst_cfg.NodeConfig(
- active_services=[
- cyst_cfg.ActiveServiceConfig(
- type="scripted_actor",
- name="attacker",
- owner="attacker",
- access_level=cyst_cfg.AccessLevel.LIMITED,
- id="attacker_service"
- )
- ],
- passive_services=[
- cyst_cfg.PassiveServiceConfig(
- type="remote desktop service",
- owner="Local system",
- version="10.0.19041",
- local=False,
- access_level=cyst_cfg.AccessLevel.ELEVATED,
- parameters=[
- (cyst_cfg.ServiceParameter.ENABLE_SESSION, True),
- (cyst_cfg.ServiceParameter.SESSION_ACCESS_LEVEL, cyst_cfg.AccessLevel.LIMITED)
- ],
- authentication_providers=[local_password_auth("client_1_windows_login")],
- access_schemes=[
- cyst_cfg.AccessSchemeConfig(
- authentication_providers=["client_1_windows_login"],
- authorization_domain=cyst_cfg.AuthorizationDomainConfig(
- type=cyst_cfg.AuthorizationDomainType.LOCAL,
- authorizations=[
- cyst_cfg.AuthorizationConfig("User1", cyst_cfg.AccessLevel.LIMITED),
- cyst_cfg.AuthorizationConfig("Administrator", cyst_cfg.AccessLevel.ELEVATED)
- ]
- )
- )
- ]
- ),
- cyst_cfg.PassiveServiceConfig(
- type="powershell",
- owner="Local system",
- version="10.0.19041",
- local=True,
- access_level=cyst_cfg.AccessLevel.LIMITED
- ),
- cyst_cfg.PassiveServiceConfig(
- type="can_attack_start_here",
- owner="Local system",
- version="1",
- local=True,
- access_level=cyst_cfg.AccessLevel.LIMITED
- )
- ],
- traffic_processors=[],
- interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("192.168.2.2"), cyst_cfg.IPNetwork("192.168.2.0/24"))],
- shell="powershell",
- id="client_1"
-)
-
-router1 = cyst_cfg.RouterConfig(
- interfaces=[
- cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("192.168.1.1"), cyst_cfg.IPNetwork("192.168.1.0/24"), index=2),
- cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("192.168.2.1"), cyst_cfg.IPNetwork("192.168.2.0/24"), index=3),
- ],
- routing_table=[
- # Push everything not-infrastructure to the internet
- cyst_cfg.RouteConfig(cyst_cfg.IPNetwork("0.0.0.0/0"), 10)
- ],
- # Firewall FORWARD policy specifies inter-network routes that are enabled
- # Firewall INPUT policy specifies who can connect directly to the router. In this scenario, everyone can.
- traffic_processors=[
- cyst_cfg.FirewallConfig(
- default_policy=cyst_cfg.FirewallPolicy.DENY,
- chains=[
- cyst_cfg.FirewallChainConfig(
- type=cyst_cfg.FirewallChainType.INPUT,
- policy=cyst_cfg.FirewallPolicy.DENY,
- rules=[
- cyst_cfg.FirewallRule(cyst_cfg.IPNetwork("192.168.1.0/24"), cyst_cfg.IPNetwork("192.168.1.1/32"), "*", cyst_cfg.FirewallPolicy.ALLOW),
- cyst_cfg.FirewallRule(cyst_cfg.IPNetwork("192.168.2.0/24"), cyst_cfg.IPNetwork("192.168.2.1/32"), "*", cyst_cfg.FirewallPolicy.ALLOW)
- ]
- ),
- cyst_cfg.FirewallChainConfig(
- type=cyst_cfg.FirewallChainType.FORWARD,
- policy=cyst_cfg.FirewallPolicy.DENY,
- rules=[
- # Client 1 can go to server 1
- cyst_cfg.FirewallRule(cyst_cfg.IPNetwork("192.168.2.2/32"), cyst_cfg.IPNetwork("192.168.1.2/32"), "*", cyst_cfg.FirewallPolicy.ALLOW),
- ]
- )
- ]
- )
- ],
- id="router1"
-)
-
-''' --------------------------------------------------------------------------------------------------------------------
-Internet
-
-- Represented as a router outside the scenario network 192.168.0.0/16
-'''
-internet = cyst_cfg.RouterConfig(
- interfaces=[
- cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("213.47.23.193"), cyst_cfg.IPNetwork("213.47.23.192/26"), index=0)
- ],
- routing_table=[
- cyst_cfg.RouteConfig(cyst_cfg.IPNetwork("192.168.0.0/16"), 0)
- ],
- traffic_processors=[],
- id="internet"
-)
-
-''' --------------------------------------------------------------------------------------------------------------------
-Outside node
-
-- A machine that sits in the internet, controlled by the attacker, used for data exfiltration.
-'''
-outside_node = cyst_cfg.NodeConfig(
- active_services=[],
- passive_services=[
- cyst_cfg.PassiveServiceConfig(
- type="bash",
- owner="root",
- version="5.0.0",
- local=True,
- access_level=cyst_cfg.AccessLevel.LIMITED
- ),
- cyst_cfg.PassiveServiceConfig(
- type="listener",
- owner="attacker",
- version="1.0.0",
- local=False,
- access_level=cyst_cfg.AccessLevel.ELEVATED
- )
- ],
- traffic_processors=[],
- interfaces=[cyst_cfg.InterfaceConfig(cyst_cfg.IPAddress("213.47.23.195"), cyst_cfg.IPNetwork("213.47.23.195/26"))],
- shell="bash",
- id="outside_node"
-)
-
-''' --------------------------------------------------------------------------------------------------------------------
-Connections
-'''
-connections = [
- cyst_cfg.ConnectionConfig("smb_server", 0, "router1", 0),
- cyst_cfg.ConnectionConfig("client_1", 0, "router1", 5),
- cyst_cfg.ConnectionConfig("internet", 0, "router1", 10),
- cyst_cfg.ConnectionConfig("internet", 1, "outside_node", 0)
-]
-
-''' --------------------------------------------------------------------------------------------------------------------
-Exploits
-- There exists only one for windows lanman server (SMB) and enables data exfiltration. Add others as needed...
-'''
-exploits = [
- cyst_cfg.ExploitConfig(
- services=[
- cyst_cfg.VulnerableServiceConfig(
- name="lanman server",
- min_version="10.0.19041",
- max_version="10.0.19041"
- )
- ],
- locality=cyst_cfg.ExploitLocality.REMOTE,
- category=cyst_cfg.ExploitCategory.DATA_MANIPULATION,
- id="smb_exploit"
- )
-]
-
-configuration_objects = [smb_server, client_1, router1, internet, outside_node, *connections, *exploits]
\ No newline at end of file
diff --git a/env/worlds/aidojo_world.py b/env/worlds/aidojo_world.py
deleted file mode 100644
index 4650710e..00000000
--- a/env/worlds/aidojo_world.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Author Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz
-# Template of world to be used in AI Dojo
-import sys
-import os
-import asyncio
-
-sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
-import logging
-from utils.utils import ConfigParser
-from env.game_components import GameState, Action, GameStatus, ActionType
-
-"""
-Basic class for worlds to be used in the AI Dojo.
-Every world (environment) used in AI Dojo should extend this class and implement
-all its methods to be compatible with the game server and game coordinator.
-"""
-class AIDojoWorld(object):
- def __init__(self, task_config_file:str,action_queue:asyncio.Queue, response_queue:asyncio.Queue, world_name:str="BasicAIDojoWorld")->None:
- self.task_config = ConfigParser(task_config_file)
- self.logger = logging.getLogger(world_name)
- self._action_queue = action_queue
- self._response_queue = response_queue
- self._world_name = world_name
-
- @property
- def world_name(self)->str:
- return self._world_name
-
- def step(self, current_state:GameState, action:Action, agent_id:tuple)-> GameState:
- """
- Executes given action in a current state of the environment and produces new GameState.
- """
- raise NotImplementedError
-
- def create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->GameState:
- """
- Produces a GameState based on the view of the world.
- """
- raise NotImplementedError
-
- def reset()->None:
- """
- Resets the world to its initial state.
- """
- raise NotImplementedError
-
- def update_goal_descriptions(self, goal_description:str)->str:
- """
- Takes the existing goal description (text) and updates it with respect to the world.
- """
- raise NotImplementedError
-
- def update_goal_dict(self, goal_dict:dict)->dict:
- """
- Takes the existing goal dict and updates it with respect to the world.
- """
- raise NotImplementedError
-
- async def handle_incoming_action(self)->None:
- """
- Asynchronously handles incoming actions from agents and processes them accordingly.
-
- This method continuously listens for actions from the `_action_queue`, processes them based on their type,
- and sends the appropriate response to the `_response_queue`. It handles different types of actions such as
- joining a game, quitting a game, and resetting the game. For other actions, it updates the game state by
- calling the `step` method.
-
- Raises:
- asyncio.CancelledError: If the task is cancelled, it logs the termination message.
-
- Action Types:
- - ActionType.JoinGame: Creates a new game state and sends a CREATED status.
- - ActionType.QuitGame: Sends an OK status with an empty game state.
- - ActionType.ResetGame: Resets the world if the agent is "world", otherwise resets the game state and sends a RESET_DONE status.
- - Other: Updates the game state using the `step` method and sends an OK status.
-
- Logging:
- - Logs the start of the task.
- - Logs received actions and game states from agents.
- - Logs the messages being sent to agents.
- - Logs termination due to `asyncio.CancelledError`.
- """
- try:
- self.logger.info(f"\tStaring {self.world_name} task.")
- while True:
- agent_id, action, game_state = await self._action_queue.get()
- self.logger.debug(f"Received from{agent_id}: {action}, {game_state}.")
- match action.type:
- case ActionType.JoinGame:
- msg = (agent_id, (self.create_state_from_view(game_state), GameStatus.CREATED))
- case ActionType.QuitGame:
- msg = (agent_id, (GameState(),GameStatus.OK))
- case ActionType.ResetGame:
- if agent_id == "world": #reset the world
- self.reset()
- continue
- else:
- msg = (agent_id, (self.create_state_from_view(game_state), GameStatus.RESET_DONE))
- case _:
- new_state = self.step(game_state, action,agent_id)
- msg = (agent_id, (new_state, GameStatus.OK))
- # new_state = self.step(state, action, agent_id)
- self.logger.debug(f"Sending to {agent_id}: {msg}")
- await self._response_queue.put(msg)
- await asyncio.sleep(0)
- except asyncio.CancelledError:
- self.logger.info(f"\t{self.world_name} Terminating by CancelledError")
\ No newline at end of file
diff --git a/env/worlds/cyst_wrapper.py b/env/worlds/cyst_wrapper.py
deleted file mode 100644
index c28f5797..00000000
--- a/env/worlds/cyst_wrapper.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Author Ondrej Lukas - ondrej.lukas@aic.fel.cvut.cz
-
-import sys
-import os
-
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-import game_components as components
-from worlds.aidojo_world import AIDojoWorld
-
-class CYSTWrapper(AIDojoWorld):
- """
- Class for connection CYST with the coordinator of AI Dojo
- """
- def __init__(self, task_config_file, world_name="CYST") -> None:
- super().__init__(task_config_file, world_name)
- self.logger.info("Initializing CYST environment")
-
-
- def step(self, current_state:components.GameState, action:components.Action, agent_id:tuple)-> components.GameState:
- """
- Executes given action in a current state of the environment and produces new GameState.
- """
- raise NotImplementedError
-
- def create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->components.GameState:
- """
- Produces a GameState based on the view of the world.
- """
- raise NotImplementedError
-
- def reset()->None:
- """
- Resets the world to its initial state.
- """
- raise NotImplementedError
-
- def update_goal_descriptions(self, goal_description:str)->str:
- """
- Takes the existing goal description (text) and updates it with respect to the world.
- """
- raise NotImplementedError
-
- def update_goal_dict(self, goal_dict:dict)->dict:
- """
- Takes the existing goal dict and updates it with respect to the world.
- """
- raise NotImplementedError
-
-
-if __name__ == "__main__":
- cyst_wrapper = CYSTWrapper("env/netsecenv_conf.yaml")
- objects = cyst_wrapper.task_config.get_scenario()
- print(objects)
- #e = Environment.create().configure(target, attacker, router, exploit1, connection1, connection2)
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..f4a78157
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,57 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+where = ["."]
+exclude = ["tests*"]
+
+[project]
+name = "AIDojoGameCoordinator"
+version = "0.1.0"
+description = "A package for coordinating AI-driven network simulation games."
+readme = "README.md"
+license = { file = "LICENSE" }
+authors = [
+ { name = "Ondrej Lukas", email = "ondrej.lukas@aic.fel.cvut.cz" },
+ { name = "Sebastian Garcia", email = "sebastian.garcia@agents.fel.cvut.cz" },
+ { name = "Maria Rigaki", email = "maria.rigaki@aic.fel.cvut.cz" }
+]
+dependencies = [
+ "aiohttp==3.11.8",
+ "attrs==23.2.0",
+ "beartype==0.19.0",
+ "cachetools==5.5.0",
+ "casefy==0.1.7",
+ "cyst==0.3.4",
+ "dictionaries==0.0.2",
+ "Faker==23.2.1",
+ "Jinja2==3.1.4",
+ "jsonlines==4.0.0",
+ "jsonpickle==3.3.0",
+ "kaleido==0.2.1",
+ "MarkupSafe==3.0.2",
+ "matplotlib==3.9.1",
+ "netaddr==1.3.0",
+ "networkx==3.4.2",
+ "numpy==1.26.4",
+ "pandas==2.2.2",
+ "plotly==5.22.0",
+ "pyserde==0.21.0",
+ "python-dateutil==2.8.2",
+ "PyYAML==6.0.1",
+ "redis==3.5.3",
+ "requests==2.32.3",
+ "scikit-learn==1.5.1",
+ "scipy==1.14.0",
+ "tenacity==8.5.0",
+ "typing-inspect==0.9.0",
+ "typing_extensions==4.12.2"
+]
+requires-python = ">=3.12"
+
+[project.optional-dependencies]
+dev = [
+ "pytest",
+ "ruff",
+]
\ No newline at end of file
diff --git a/tests/manual/three_nets/test_three_net_scenario.py b/tests/manual/three_nets/test_three_net_scenario.py
index eff14236..d1bec3ea 100644
--- a/tests/manual/three_nets/test_three_net_scenario.py
+++ b/tests/manual/three_nets/test_three_net_scenario.py
@@ -5,7 +5,7 @@
PATH = path.dirname( path.dirname( path.dirname( path.dirname( path.abspath(__file__) ) ) ))
sys.path.append(path.dirname( path.dirname( path.dirname( path.dirname( path.abspath(__file__) ) ) )))
from NetSecGameAgents.agents import base_agent
-from env.game_components import Action, ActionType, IP, Network, Service, Data
+from AIDojoCoordinator.game_components import Action, ActionType, IP, Network, Service, Data
if __name__ == "__main__":
diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh
index 19b049e1..348288f7 100755
--- a/tests/run_all_tests.sh
+++ b/tests/run_all_tests.sh
@@ -3,8 +3,9 @@
# run all unit tests, -n *5 means distribute tests on 5 different process
# -s to see print statements as they are executed
-python3 -m pytest tests/test_actions.py -p no:warnings -vvvv -s --full-trace
+#python3 -m pytest tests/test_actions.py -p no:warnings -vvvv -s --full-trace
python3 -m pytest tests/test_components.py -p no:warnings -vvvv -s --full-trace
+python3 -m pytest tests/test_game_coordinator.py -p no:warnings -vvvv -s --full-trace
# Coordinator tesst
#python3 -m pytest tests/test_coordinator.py -p no:warnings -vvvv -s --full-trace
diff --git a/tests/test_actions.py b/tests/test_actions.py
index 50c81a54..443c80bf 100644
--- a/tests/test_actions.py
+++ b/tests/test_actions.py
@@ -5,8 +5,8 @@
import sys
from os import path
sys.path.append( path.dirname(path.dirname( path.abspath(__file__) ) ))
-from env.worlds.network_security_game import NetworkSecurityEnvironment
-import env.game_components as components
+from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
+import AIDojoCoordinator.game_components as components
import pytest
# Fixture are used to hold the current state and the environment
diff --git a/tests/test_components.py b/tests/test_components.py
index a30988f5..140d38f4 100644
--- a/tests/test_components.py
+++ b/tests/test_components.py
@@ -2,11 +2,8 @@
Tests related to the game components in the Network Security Game Environment
Author: Maria Rigaki - maria.rigaki@fel.cvut.cz
"""
-import sys
import json
-from os import path
-sys.path.append( path.dirname(path.dirname( path.abspath(__file__) ) ))
-from env.game_components import ActionType, Action, IP, Data, Network, Service, GameState, AgentInfo
+from AIDojoCoordinator.game_components import ActionType, Action, IP, Data, Network, Service, GameState, AgentInfo
class TestComponentsIP:
"""
@@ -197,7 +194,7 @@ def test_create_find_data(self):
"""
Test the creation of the FindData action
"""
- action = Action(action_type=ActionType.FindData, params={"source_host":IP("192.168.12.12"),"target_host":IP("192.168.12.12")})
+ action = Action(action_type=ActionType.FindData, parameters={"source_host":IP("192.168.12.12"),"target_host":IP("192.168.12.12")})
assert action.type == ActionType.FindData
assert action.parameters["target_host"] == IP("192.168.12.12")
assert action.parameters["source_host"] == IP("192.168.12.12")
@@ -206,14 +203,14 @@ def test_create_find_data_str(self):
"""
Test the string representation of the FindData action
"""
- action = Action(action_type=ActionType.FindData, params={"source_host":IP("192.168.12.12"), "target_host":IP("192.168.12.12")})
+ action = Action(action_type=ActionType.FindData, parameters={"source_host":IP("192.168.12.12"), "target_host":IP("192.168.12.12")})
assert str(action) == "Action "
def test_create_find_data_repr(self):
"""
Test the repr of the FindData action
"""
- action = Action(action_type=ActionType.FindData, params={"source_host":IP("192.168.12.12"), "target_host":IP("192.168.12.12")})
+ action = Action(action_type=ActionType.FindData, parameters={"source_host":IP("192.168.12.12"), "target_host":IP("192.168.12.12")})
assert repr(action) == "Action "
def test_action_find_services(self):
@@ -221,7 +218,7 @@ def test_action_find_services(self):
Test the creation of the FindServices action
"""
action = Action(action_type=ActionType.FindServices,
- params={"source_host":IP("192.168.12.11"), "target_host":IP("192.168.12.12")})
+ parameters={"source_host":IP("192.168.12.11"), "target_host":IP("192.168.12.12")})
assert action.type == ActionType.FindServices
assert action.parameters["target_host"] == IP("192.168.12.12")
assert action.parameters["source_host"] == IP("192.168.12.11")
@@ -231,7 +228,7 @@ def test_action_scan_network(self):
Test the creation of the ScanNetwork action
"""
action = Action(action_type=ActionType.ScanNetwork,
- params={"source_host":IP("192.168.12.11"), "target_network":Network("172.16.1.12", 24)})
+ parameters={"source_host":IP("192.168.12.11"), "target_network":Network("172.16.1.12", 24)})
assert action.type == ActionType.ScanNetwork
assert action.parameters["target_network"] == Network("172.16.1.12", 24)
assert action.parameters["source_host"] == IP("192.168.12.11")
@@ -241,7 +238,7 @@ def test_action_exploit_services(self):
Test the creation of the ExploitService action
"""
action = Action(action_type=ActionType.ExploitService,
- params={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.12"),
+ parameters={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.12"),
"target_service":Service("ssh", "passive", "0.23", False)})
assert action.type == ActionType.ExploitService
assert action.parameters["target_host"] == IP("172.16.1.12")
@@ -255,19 +252,19 @@ def test_action_equal(self):
Test that two actions with the same parameters are equal
"""
action = Action(action_type=ActionType.FindServices,
- params={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11")})
+ parameters={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11")})
action2 = Action(action_type=ActionType.FindServices,
- params={"target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")})
+ parameters={"target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")})
assert action == action2
- def test_action_equal_params_order(self):
+ def test_action_equal_parameters_order(self):
"""
Test that two actions with the same parameters are equal
"""
action = Action(action_type=ActionType.ExploitService,
- params={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11"),"target_service": Service("ssh", "passive", "0.23", False)})
+ parameters={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11"),"target_service": Service("ssh", "passive", "0.23", False)})
action2 = Action(action_type=ActionType.ExploitService,
- params={"target_service": Service("ssh", "passive", "0.23", False), "target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")})
+ parameters={"target_service": Service("ssh", "passive", "0.23", False), "target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")})
assert action == action2
def test_action_not_equal_different_target(self):
@@ -275,9 +272,9 @@ def test_action_not_equal_different_target(self):
Test that two actions with different parameters are not equal
"""
action = Action(action_type=ActionType.FindServices,
- params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")})
+ parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")})
action2 = Action(action_type=ActionType.FindServices,
- params={"source_host":IP("192.168.12.11"), "target_host":IP("172.15.1.22")})
+ parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.15.1.22")})
assert action != action2
def test_action_not_equal_different_source(self):
@@ -285,9 +282,9 @@ def test_action_not_equal_different_source(self):
Test that two actions with different parameters are not equal
"""
action = Action(action_type=ActionType.FindServices,
- params={"source_host":IP("192.168.12.12"), "target_host":IP("172.16.1.22")})
+ parameters={"source_host":IP("192.168.12.12"), "target_host":IP("172.16.1.22")})
action2 = Action(action_type=ActionType.FindServices,
- params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")})
+ parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")})
assert action != action2
def test_action_not_equal_different_action_type(self):
@@ -295,27 +292,27 @@ def test_action_not_equal_different_action_type(self):
Test that two actions with different parameters are not equal
"""
action = Action(action_type=ActionType.FindServices,
- params={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")})
+ parameters={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")})
action2 = Action(action_type=ActionType.FindData,
- params={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")})
+ parameters={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")})
assert action != action2
def test_action_hash(self):
action = Action(
action_type=ActionType.FindServices,
- params={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11")}
+ parameters={"target_host":IP("172.16.1.22"),"source_host":IP("192.168.12.11")}
)
action2 = Action(
action_type=ActionType.FindServices,
- params={"target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")}
+ parameters={"target_host":IP("172.16.1.22"), "source_host":IP("192.168.12.11")}
)
action3 = Action(
action_type=ActionType.FindServices,
- params={"target_host":IP("172.16.13.48"), "source_host":IP("192.168.12.11")}
+ parameters={"target_host":IP("172.16.13.48"), "source_host":IP("192.168.12.11")}
)
action4 = Action(
action_type=ActionType.FindData,
- params={"target_host":IP("172.16.1.25"), "source_host":IP("192.168.12.11")}
+ parameters={"target_host":IP("172.16.1.25"), "source_host":IP("192.168.12.11")}
)
assert hash(action) == hash(action2)
assert hash(action) != hash(action3)
@@ -324,31 +321,31 @@ def test_action_hash(self):
def test_action_set_member(self):
action_set = set()
action_set.add(Action(action_type=ActionType.FindServices,
- params={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")}))
+ parameters={"source_host":IP("192.168.12.11"),"target_host":IP("172.16.1.22")}))
action_set.add(Action(action_type=ActionType.FindData,
- params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24")}))
+ parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24")}))
action_set.add(Action(action_type=ActionType.ExploitService,
- params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)}))
+ parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)}))
action_set.add(Action(action_type=ActionType.ScanNetwork,
- params={"source_host":IP("192.168.12.11"), "target_network":Network("172.16.1.12", 24)}))
- action_set.add(Action(action_type=ActionType.ExfiltrateData, params={"target_host":IP("172.16.1.3"),
+ parameters={"source_host":IP("192.168.12.11"), "target_network":Network("172.16.1.12", 24)}))
+ action_set.add(Action(action_type=ActionType.ExfiltrateData, parameters={"target_host":IP("172.16.1.3"),
"source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey")}))
- assert Action(action_type=ActionType.FindServices, params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")}) in action_set
- assert Action(action_type=ActionType.FindData, params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24")}) in action_set
- assert Action(action_type=ActionType.ExploitService, params={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)})in action_set
- #reverse params order
- assert Action(action_type=ActionType.ExploitService, params={"target_service": Service("ssh", "passive", "0.23", False), "target_host":IP("172.16.1.24"), "source_host":IP("192.168.12.11")})in action_set
- assert Action(action_type=ActionType.ScanNetwork, params={"target_network":Network("172.16.1.12", 24), "source_host":IP("192.168.12.11")}) in action_set
- assert Action(action_type=ActionType.ExfiltrateData, params={"target_host":IP("172.16.1.3"), "source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey")}) in action_set
- #reverse params orders
- assert Action(action_type=ActionType.ExfiltrateData, params={"source_host": IP("172.16.1.2"), "target_host":IP("172.16.1.3"), "data":Data("User2", "PublicKey")}) in action_set
+ assert Action(action_type=ActionType.FindServices, parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.22")}) in action_set
+ assert Action(action_type=ActionType.FindData, parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24")}) in action_set
+ assert Action(action_type=ActionType.ExploitService, parameters={"source_host":IP("192.168.12.11"), "target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)})in action_set
+ #reverse parameters order
+ assert Action(action_type=ActionType.ExploitService, parameters={"target_service": Service("ssh", "passive", "0.23", False), "target_host":IP("172.16.1.24"), "source_host":IP("192.168.12.11")})in action_set
+ assert Action(action_type=ActionType.ScanNetwork, parameters={"target_network":Network("172.16.1.12", 24), "source_host":IP("192.168.12.11")}) in action_set
+ assert Action(action_type=ActionType.ExfiltrateData, parameters={"target_host":IP("172.16.1.3"), "source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey")}) in action_set
+ #reverse parameters orders
+ assert Action(action_type=ActionType.ExfiltrateData, parameters={"source_host": IP("172.16.1.2"), "target_host":IP("172.16.1.3"), "data":Data("User2", "PublicKey")}) in action_set
- def test_action_as_json(self):
+ def test_action_to_json(self):
# Scan Network
action = Action(action_type=ActionType.ScanNetwork,
- params={"target_network":Network("172.16.1.12", 24)})
- action_json = action.as_json()
+ parameters={"target_network":Network("172.16.1.12", 24)})
+ action_json = action.to_json()
try:
data = json.loads(action_json)
except ValueError:
@@ -359,8 +356,8 @@ def test_action_as_json(self):
# Find services
action = Action(action_type=ActionType.FindServices,
- params={"target_host":IP("172.16.1.22")})
- action_json = action.as_json()
+ parameters={"target_host":IP("172.16.1.22")})
+ action_json = action.to_json()
try:
data = json.loads(action_json)
except ValueError:
@@ -371,8 +368,8 @@ def test_action_as_json(self):
# Find Data
action = Action(action_type=ActionType.FindData,
- params={"target_host":IP("172.16.1.22")})
- action_json = action.as_json()
+ parameters={"target_host":IP("172.16.1.22")})
+ action_json = action.to_json()
try:
data = json.loads(action_json)
except ValueError:
@@ -383,8 +380,8 @@ def test_action_as_json(self):
# Exploit Service
action = Action(action_type=ActionType.ExploitService,
- params={"target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)})
- action_json = action.as_json()
+ parameters={"target_host":IP("172.16.1.24"), "target_service": Service("ssh", "passive", "0.23", False)})
+ action_json = action.to_json()
try:
data = json.loads(action_json)
except ValueError:
@@ -395,9 +392,9 @@ def test_action_as_json(self):
"target_service":{"name":"ssh", "type":"passive", "version":"0.23", "is_local":False}}) in data.items()
# Exfiltrate Data
- action = Action(action_type=ActionType.ExfiltrateData, params={"target_host":IP("172.16.1.3"),
+ action = Action(action_type=ActionType.ExfiltrateData, parameters={"target_host":IP("172.16.1.3"),
"source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey", size=42, type="pub")})
- action_json = action.as_json()
+ action_json = action.to_json()
try:
data = json.loads(action_json)
except ValueError:
@@ -410,74 +407,74 @@ def test_action_as_json(self):
def test_action_scan_network_serialization(self):
action = Action(action_type=ActionType.ScanNetwork,
- params={"target_network":Network("172.16.1.12", 24),"source_host": IP("172.16.1.2") })
- action_json = action.as_json()
+ parameters={"target_network":Network("172.16.1.12", 24),"source_host": IP("172.16.1.2") })
+ action_json = action.to_json()
new_action = Action.from_json(action_json)
assert action == new_action
def test_action_find_services_serialization(self):
action = Action(action_type=ActionType.FindServices,
- params={"target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2")})
- action_json = action.as_json()
+ parameters={"target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2")})
+ action_json = action.to_json()
new_action = Action.from_json(action_json)
assert action == new_action
def test_action_find_data_serialization(self):
action = Action(action_type=ActionType.FindData,
- params={"target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2")})
- action_json = action.as_json()
+ parameters={"target_host":IP("172.16.1.22"), "source_host": IP("172.16.1.2")})
+ action_json = action.to_json()
new_action = Action.from_json(action_json)
assert action == new_action
def test_action_exploit_service_serialization(self):
action = Action(action_type=ActionType.ExploitService,
- params={"source_host": IP("172.16.1.2"),
+ parameters={"source_host": IP("172.16.1.2"),
"target_host":IP("172.16.1.24"),
"target_service": Service("ssh", "passive", "0.23", False)})
- action_json = action.as_json()
+ action_json = action.to_json()
new_action = Action.from_json(action_json)
assert action == new_action
def test_action_exfiltrate_serialization(self):
- action = Action(action_type=ActionType.ExfiltrateData, params={"target_host":IP("172.16.1.3"),
+ action = Action(action_type=ActionType.ExfiltrateData, parameters={"target_host":IP("172.16.1.3"),
"source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey")})
- action_json = action.as_json()
+ action_json = action.to_json()
new_action = Action.from_json(action_json)
assert action == new_action
def test_action_exfiltrate_join_game(self):
action = Action(
action_type=ActionType.JoinGame,
- params={
+ parameters={
"agent_info": AgentInfo(name="TestingAgent", role="attacker"),
}
)
- action_json = action.as_json()
+ action_json = action.to_json()
new_action = Action.from_json(action_json)
assert action == new_action
def test_action_exfiltrate_reset_game(self):
action = Action(
action_type=ActionType.ResetGame,
- params={}
+ parameters={}
)
- action_json = action.as_json()
+ action_json = action.to_json()
new_action = Action.from_json(action_json)
assert action == new_action
def test_action_exfiltrate_quit_game(self):
action = Action(
action_type=ActionType.QuitGame,
- params={}
+ parameters={}
)
- action_json = action.as_json()
+ action_json = action.to_json()
new_action = Action.from_json(action_json)
assert action == new_action
def test_action_to_dict_scan_network(self):
action = Action(
action_type=ActionType.ScanNetwork,
- params={
+ parameters={
"target_network":Network("172.16.1.12", 24),
"source_host": IP("172.16.1.2")
}
@@ -485,14 +482,14 @@ def test_action_to_dict_scan_network(self):
action_dict = action.as_dict
new_action = Action.from_dict(action_dict)
assert action == new_action
- assert action_dict["type"] == str(action.type)
- assert action_dict["params"]["target_network"] == "172.16.1.12/24"
- assert action_dict["params"]["source_host"] == "172.16.1.2"
+ assert action_dict["action_type"] == str(action.type)
+ assert action_dict["parameters"]["target_network"] == {'ip': '172.16.1.12', 'mask': 24}
+ assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'}
def test_action_to_dict_find_services(self):
action = Action(
action_type=ActionType.FindServices,
- params={
+ parameters={
"target_host":IP("172.16.1.22"),
"source_host": IP("172.16.1.2")
}
@@ -500,14 +497,14 @@ def test_action_to_dict_find_services(self):
action_dict = action.as_dict
new_action = Action.from_dict(action_dict)
assert action == new_action
- assert action_dict["type"] == str(action.type)
- assert action_dict["params"]["target_host"] == "172.16.1.22"
- assert action_dict["params"]["source_host"] == "172.16.1.2"
+ assert action_dict["action_type"] == str(action.type)
+ assert action_dict["parameters"]["target_host"] == {'ip': '172.16.1.22'}
+ assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'}
def test_action_to_dict_find_data(self):
action = Action(
action_type=ActionType.FindData,
- params={
+ parameters={
"target_host":IP("172.16.1.22"),
"source_host": IP("172.16.1.2")
}
@@ -515,14 +512,14 @@ def test_action_to_dict_find_data(self):
action_dict = action.as_dict
new_action = Action.from_dict(action_dict)
assert action == new_action
- assert action_dict["type"] == str(action.type)
- assert action_dict["params"]["target_host"] == "172.16.1.22"
- assert action_dict["params"]["source_host"] == "172.16.1.2"
+ assert action_dict["action_type"] == str(action.type)
+ assert action_dict["parameters"]["target_host"] == {'ip': '172.16.1.22'}
+ assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'}
def test_action_to_dict_exploit_service(self):
action = Action(
action_type=ActionType.ExploitService,
- params={
+ parameters={
"source_host": IP("172.16.1.2"),
"target_host":IP("172.16.1.24"),
"target_service": Service("ssh", "passive", "0.23", False)
@@ -531,18 +528,18 @@ def test_action_to_dict_exploit_service(self):
action_dict = action.as_dict
new_action = Action.from_dict(action_dict)
assert action == new_action
- assert action_dict["type"] == str(action.type)
- assert action_dict["params"]["target_host"] == "172.16.1.24"
- assert action_dict["params"]["source_host"] == "172.16.1.2"
- assert action_dict["params"]["target_service"]["name"] == "ssh"
- assert action_dict["params"]["target_service"]["type"] == "passive"
- assert action_dict["params"]["target_service"]["version"] == "0.23"
- assert action_dict["params"]["target_service"]["is_local"] is False
+ assert action_dict["action_type"] == str(action.type)
+ assert action_dict["parameters"]["target_host"] == {'ip': '172.16.1.24'}
+ assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'}
+ assert action_dict["parameters"]["target_service"]["name"] == "ssh"
+ assert action_dict["parameters"]["target_service"]["type"] == "passive"
+ assert action_dict["parameters"]["target_service"]["version"] == "0.23"
+ assert action_dict["parameters"]["target_service"]["is_local"] is False
def test_action_to_dict_exfiltrate_data(self):
action = Action(
action_type=ActionType.ExfiltrateData,
- params={
+ parameters={
"target_host":IP("172.16.1.3"),
"source_host": IP("172.16.1.2"),
"data":Data("User2", "PublicKey")
@@ -551,16 +548,16 @@ def test_action_to_dict_exfiltrate_data(self):
action_dict = action.as_dict
new_action = Action.from_dict(action_dict)
assert action == new_action
- assert action_dict["type"] == str(action.type)
- assert action_dict["params"]["target_host"] == "172.16.1.3"
- assert action_dict["params"]["source_host"] == "172.16.1.2"
- assert action_dict["params"]["data"]["owner"] == "User2"
- assert action_dict["params"]["data"]["id"] == "PublicKey"
+ assert action_dict["action_type"] == str(action.type)
+ assert action_dict["parameters"]["target_host"] == {'ip': '172.16.1.3'}
+ assert action_dict["parameters"]["source_host"] == {'ip': '172.16.1.2'}
+ assert action_dict["parameters"]["data"]["owner"] == "User2"
+ assert action_dict["parameters"]["data"]["id"] == "PublicKey"
def test_action_to_dict_join_game(self):
action = Action(
action_type=ActionType.JoinGame,
- params={
+ parameters={
"agent_info": AgentInfo(name="TestingAgent", role="attacker"),
"source_host": IP("172.16.1.2")
}
@@ -568,31 +565,31 @@ def test_action_to_dict_join_game(self):
action_dict = action.as_dict
new_action = Action.from_dict(action_dict)
assert action == new_action
- assert action_dict["type"] == str(action.type)
- assert action_dict["params"]["agent_info"]["name"] == "TestingAgent"
- assert action_dict["params"]["agent_info"]["role"] == "attacker"
+ assert action_dict["action_type"] == str(action.type)
+ assert action_dict["parameters"]["agent_info"]["name"] == "TestingAgent"
+ assert action_dict["parameters"]["agent_info"]["role"] == "attacker"
def test_action_to_dict_reset_game(self):
action = Action(
action_type=ActionType.ResetGame,
- params={}
+ parameters={}
)
action_dict = action.as_dict
new_action = Action.from_dict(action_dict)
assert action == new_action
- assert action_dict["type"] == str(action.type)
- assert len(action_dict["params"]) == 0
+ assert action_dict["action_type"] == str(action.type)
+ assert len(action_dict["parameters"]) == 0
def test_action_to_dict_quit_game(self):
action = Action(
action_type=ActionType.QuitGame,
- params={}
+ parameters={}
)
action_dict = action.as_dict
new_action = Action.from_dict(action_dict)
assert action == new_action
- assert action_dict["type"] == str(action.type)
- assert len(action_dict["params"]) == 0
+ assert action_dict["action_type"] == str(action.type)
+ assert len(action_dict["parameters"]) == 0
class TestGameState:
"""
diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py
index 63a36533..965f4e16 100644
--- a/tests/test_coordinator.py
+++ b/tests/test_coordinator.py
@@ -236,8 +236,8 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
-from coordinator import Coordinator, AgentStatus, Action, ActionType
-from env.game_components import AgentInfo, Network, IP
+from AIDojoCoordinator.coordinator import Coordinator, AgentStatus, Action, ActionType
+from AIDojoCoordinator.game_components import AgentInfo, Network, IP
CONFIG_FILE = "tests/netsecenv-task-for-testing.yaml"
ALLOWED_ROLES = ["Attacker", "Defender", "Benign"]
diff --git a/tests/test_game_coordinator.py b/tests/test_game_coordinator.py
new file mode 100644
index 00000000..f2f75b66
--- /dev/null
+++ b/tests/test_game_coordinator.py
@@ -0,0 +1,133 @@
+import asyncio
+import pytest
+from unittest.mock import AsyncMock, Mock
+from AIDojoCoordinator.coordinator import AgentServer, GameCoordinator
+
+def test_game_coordinator_initialization():
+ """
+ Test that the GameCoordinator constructor stores its arguments and creates the expected events, locks, queues, dictionaries, task set and logger.
+ """
+ # Test input: each value is compared with the matching coordinator attribute below
+ game_host = "localhost"
+ game_port = 8000
+ service_host = "localhost"
+ service_port = 8080
+ allowed_roles = ["Attacker", "Defender", "Benign"]
+ task_config_file = "test_config.json" # NOTE(review): presumably not opened by the constructor - confirm
+
+ # Create an instance of GameCoordinator
+ coordinator = GameCoordinator(
+ game_host=game_host,
+ game_port=game_port,
+ service_host=service_host,
+ service_port=service_port,
+ allowed_roles=allowed_roles,
+ task_config_file=task_config_file,
+ )
+
+ # Assertions for basic initialization (constructor arguments are stored unchanged)
+ assert coordinator.host == game_host, "GameCoordinator host should be set correctly."
+ assert coordinator.port == game_port, "GameCoordinator port should be set correctly."
+ assert coordinator._service_host == service_host, "Service host should be set correctly."
+ assert coordinator._service_port == service_port, "Service port should be set correctly."
+ assert coordinator.ALLOWED_ROLES == allowed_roles, "Allowed roles should match the input."
+ assert coordinator._task_config_file == task_config_file, "Task config file should match the input."
+
+ # Assertions for events, locks and the semaphore (only their types are checked, not their state)
+ assert isinstance(coordinator.shutdown_flag, asyncio.Event), "shutdown_flag should be an asyncio.Event."
+ assert isinstance(coordinator._reset_event, asyncio.Event), "reset_event should be an asyncio.Event."
+ assert isinstance(coordinator._episode_end_event, asyncio.Event), "episode_end_event should be an asyncio.Event."
+ assert isinstance(coordinator._reset_lock, asyncio.Lock), "reset_lock should be an asyncio.Lock."
+ assert isinstance(coordinator._agents_lock, asyncio.Lock), "agents_lock should be an asyncio.Lock."
+ assert isinstance(coordinator._semaphore, asyncio.Semaphore), "semaphore should be an asyncio.Semaphore."
+
+ # Assertions for agent-related data structures (only the container types are checked, not their contents)
+ assert isinstance(coordinator._agent_action_queue, asyncio.Queue), "agent_action_queue should be an asyncio.Queue."
+ assert isinstance(coordinator._agent_response_queues, dict), "agent_response_queues should be a dictionary."
+ assert isinstance(coordinator.agents, dict), "agents should be a dictionary."
+ assert isinstance(coordinator._agent_steps, dict), "agent_steps should be a dictionary."
+ assert isinstance(coordinator._reset_requests, dict), "reset_requests should be a dictionary."
+ assert isinstance(coordinator._agent_status, dict), "agent_status should be a dictionary."
+ assert isinstance(coordinator._agent_observations, dict), "agent_observations should be a dictionary."
+ assert isinstance(coordinator._agent_rewards, dict), "agent_rewards should be a dictionary."
+ assert isinstance(coordinator._agent_trajectories, dict), "agent_trajectories should be a dictionary."
+
+ # Assertions for tasks
+ assert isinstance(coordinator._tasks, set), "tasks should be a set."
+
+ # Assertions for configuration: the CYST objects must still be unset right after construction
+ assert coordinator._cyst_objects is None, "cyst_objects should be None at initialization."
+ assert coordinator._cyst_object_string is None, "cyst_object_string should be None at initialization."
+
+ # Assertions for logging
+ assert coordinator.logger.name == "AIDojo-GameCoordinator", "Logger should be initialized with the correct name."
+
+
+def test_starting_positions():
+ """Test that _get_starting_position_per_role returns the configured start position for every allowed role."""
+ coordinator = GameCoordinator(
+ game_host="localhost",
+ game_port=8000,
+ service_host="localhost",
+ service_port=8080,
+ allowed_roles=["Attacker", "Defender", "Benign"],
+ )
+ coordinator.task_config = Mock() # Replace task_config with a Mock; get_start_position is stubbed on the next line
+ coordinator.task_config.get_start_position.side_effect = lambda agent_role: {"x": 0, "y": 0} # same stub position for every role
+
+ starting_positions = coordinator._get_starting_position_per_role()
+
+ assert starting_positions["Attacker"] == {"x": 0, "y": 0}
+ assert starting_positions["Defender"] == {"x": 0, "y": 0}
+ assert starting_positions["Benign"] == {"x": 0, "y": 0}
+
+
+
+# Agent Server
+def test_agent_server_initialization():
+ """
+ Test that the AgentServer constructor keeps references to the queues it is given and starts with zero connections.
+ """
+ # Test inputs
+ actions_queue = asyncio.Queue()
+ agent_response_queues = {}
+ max_connections = 5
+
+ # Create an instance of AgentServer
+ server = AgentServer(actions_queue, agent_response_queues, max_connections)
+
+ # Assertions for basic attributes (the queue objects are compared by identity, not by equality)
+ assert server.actions_queue is actions_queue, "AgentServer's actions_queue should be set correctly."
+ assert server.answers_queues is agent_response_queues, "AgentServer's answers_queues should be set correctly."
+ assert server.max_connections == max_connections, "AgentServer's max_connections should match the input."
+ assert server.current_connections == 0, "AgentServer's current_connections should be initialized to 0."
+
+ # Assertions for logging
+ assert server.logger.name == "AIDojo-AgentServer", "Logger should be initialized with the correct name."
+
+
+@pytest.mark.asyncio
+async def test_handle_new_agent_max_connections():
+ """Test that handle_new_agent closes a new connection without registering it when the server is full."""
+ # Test setup: a server that accepts at most one connection
+ actions_queue = asyncio.Queue()
+ agent_response_queues = {}
+ max_connections = 1
+ server = AgentServer(actions_queue, agent_response_queues, max_connections)
+
+ # Mock reader and writer; the rejection path only calls synchronous writer methods, so a plain Mock suffices
+ reader_mock = AsyncMock()
+ writer_mock = Mock()
+ writer_mock.get_extra_info.return_value = ("127.0.0.1", 12345)
+
+ # Simulate max connections so the next connection attempt is rejected
+ server.current_connections = max_connections
+
+ # Run handle_new_agent
+ await server.handle_new_agent(reader_mock, writer_mock)
+
+ # Assertions (the Mock assert_* helpers raise AssertionError themselves, so their explanations are plain comments)
+ assert server.current_connections == max_connections, "Connection count should remain unchanged."
+ assert ("127.0.0.1", 12345) not in agent_response_queues, "Queue should not be created for rejected agent."
+ writer_mock.write.assert_not_called() # No data should be sent to the rejected agent.
+ writer_mock.close.assert_called_once() # Connection should be closed for the rejected agent.