diff --git a/AIDojoCoordinator/coordinator.py b/AIDojoCoordinator/coordinator.py index 96fbd3f4..6850972b 100644 --- a/AIDojoCoordinator/coordinator.py +++ b/AIDojoCoordinator/coordinator.py @@ -172,6 +172,7 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por self._agent_steps = {} # reset request per agent_addr (bool) self._reset_requests = {} + self._randomize_topology_requests = {} self._agent_status = {} self._episode_ends = {} self._agent_observations = {} @@ -547,8 +548,10 @@ async def _process_reset_game_action(self, agent_addr: tuple, reset_action:Actio """ self.logger.debug("Beginning the _process_reset_game_action.") async with self._reset_lock: - # add reset request for this agent + # add reset request for this agent self._reset_requests[agent_addr] = True + # register if the agent wants to randomize the topology + self._randomize_topology_requests[agent_addr] = reset_action.parameters.get("randomize_topology", False) if all(self._reset_requests.values()): # all agents want reset - reset the world self.logger.debug(f"All agents requested reset, setting the event") @@ -724,6 +727,7 @@ async def _reset_game(self): self._agent_observations[agent] = new_observation self._episode_ends[agent] = False self._reset_requests[agent] = False + self._randomize_topology_requests[agent] = False self._agent_rewards[agent] = 0 self._agent_steps[agent] = 0 self._agent_false_positives[agent] = 0 @@ -788,6 +792,9 @@ async def _remove_agent_from_game(self, agent_addr): agent_info["agent_status"] = self._agent_status.pop(agent_addr) agent_info["false_positives"] = self._agent_false_positives.pop(agent_addr) async with self._reset_lock: + # remove agent from topology reset requests + agent_info["topology_reset_request"] = self._randomize_topology_requests.pop(agent_addr, False) + # remove agent from reset requests agent_info["reset_request"] = self._reset_requests.pop(agent_addr) # check if this agent was not preventing reset if any(self._reset_requests.values()): diff --git a/AIDojoCoordinator/game_components.py b/AIDojoCoordinator/game_components.py index fcde7614..98881f13 100755 --- a/AIDojoCoordinator/game_components.py +++ b/AIDojoCoordinator/game_components.py @@ -397,6 +397,8 @@ def as_dict(self) -> Dict[str, Any]: for k, v in self.parameters.items(): if hasattr(v, '__dict__'): # Handle custom objects like Service, Data, AgentInfo params[k] = asdict(v) + elif isinstance(v, bool): # Handle boolean values + params[k] = v else: params[k] = str(v) return {"action_type": str(self.action_type), "parameters": params} @@ -448,8 +450,11 @@ def from_dict(cls, data_dict: Dict[str, Any]) -> "Action": params[k] = Data.from_dict(v) case "agent_info": params[k] = AgentInfo.from_dict(v) - case "request_trajectory": - params[k] = ast.literal_eval(v) + case "request_trajectory" | "randomize_topology": + if isinstance(v, bool): + params[k] = v + else: + params[k] = ast.literal_eval(v) case _: raise ValueError(f"Unsupported value in {k}: {v}") return cls(action_type=action_type, parameters=params) diff --git a/AIDojoCoordinator/netsecenv_conf.yaml b/AIDojoCoordinator/netsecenv_conf.yaml index 6fd52eb1..0c161d30 100644 --- a/AIDojoCoordinator/netsecenv_conf.yaml +++ b/AIDojoCoordinator/netsecenv_conf.yaml @@ -99,7 +99,7 @@ env: # random_seed: 42 scenario: 'scenario1' use_global_defender: False - use_dynamic_addresses: False + use_dynamic_addresses: True use_firewall: True save_trajectories: False required_players: 1 diff --git a/AIDojoCoordinator/worlds/NSEGameCoordinator.py b/AIDojoCoordinator/worlds/NSEGameCoordinator.py index dd64106b..4ed71c3b 100644 --- a/AIDojoCoordinator/worlds/NSEGameCoordinator.py +++ b/AIDojoCoordinator/worlds/NSEGameCoordinator.py @@ -19,7 +19,7 @@ class NSGCoordinator(GameCoordinator): - def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attacker", "Defender", "Benign"], seed=42): + def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attacker", "Defender", "Benign"], seed=None): super().__init__(game_host, game_port, service_host=None, service_port=None, allowed_roles=allowed_roles, task_config_file=task_config) # Internal data structure of the NSG @@ -44,7 +44,17 @@ def __init__(self, game_host, game_port, task_config:str, allowed_roles=["Attack self._seed = seed self.logger.info(f'Setting env seed to {seed}') - def _initialize(self)->None: + def _initialize(self) -> None: + """ + Initializes the NetSecGame environment. + + Loads the CYST configuration, sets up dynamic IP and network address generation if enabled, + and stores original copies of environment data structures for later resets. Also seeds the + random number generator for reproducibility and logs the completion of initialization. + + Returns: + None + """ # Load CYST configuration self._process_cyst_config(self._cyst_objects) # Check if dynamic network and ip adddresses are required @@ -84,7 +94,16 @@ def _get_controlled_hosts_from_view(self, view_controlled_hosts:Iterable)->set: return controlled_hosts def _get_services_from_view(self, view_known_services:dict)->dict: - known_services ={} + """ + Parses view and translates all keywords. Produces dict of known services {IP: set(Service)} + + Args: + view_known_services (dict): The view containing known services information. + + Returns: + dict: A dictionary mapping IP addresses to sets of known services. + """ + known_services = {} for ip, service_list in view_known_services.items(): if self._ip_mapping[ip] not in known_services: known_services[self._ip_mapping[ip]] = set() @@ -101,6 +120,15 @@ def _get_services_from_view(self, view_known_services:dict)->dict: return known_services def _get_data_from_view(self, view_known_data:dict)->dict: + """ + Parses view and translates all keywords. Produces dict of known data {IP: set(Data)} + + Args: + view_known_data (dict): The view containing known data information. + + Returns: + dict: A dictionary mapping IP addresses to sets of known data. + """ known_data = {} for ip, data_list in view_known_data.items(): if self._ip_mapping[ip] not in known_data: @@ -920,7 +948,11 @@ async def reset(self)->bool: self.logger.info('--- Reseting NSG Environment to its initial state ---') # change IPs if needed if self.task_config.get_use_dynamic_addresses(): - self._create_new_network_mapping() + if all(self._randomize_topology_requests.values()): + self.logger.info("All agents requested reset with randomized topology.") + self._create_new_network_mapping() + else: + self.logger.info("Not all agents requested a topology randomization. Keeping the current one.") # reset self._data to orignal state self._data = copy.deepcopy(self._data_original) # reset self._data_content to orignal state @@ -977,6 +1009,16 @@ async def reset(self)->bool: default="netsecenv_conf.yaml", ) + parser.add_argument( + "-s", + "--seed", + help="Random seed for the environment", + action="store", + required=False, + type=int, + default=42, + ) + args = parser.parse_args() print(args) # Set the logging @@ -994,7 +1036,7 @@ async def reset(self)->bool: datefmt="%Y-%m-%d %H:%M:%S", level=pass_level, ) - - game_server = NSGCoordinator(args.game_host, args.game_port, args.task_config) + + game_server = NSGCoordinator(args.game_host, args.game_port, args.task_config, seed=args.seed) # Run it! game_server.run() \ No newline at end of file diff --git a/tests/components/test_action.py b/tests/components/test_action.py index 16f98b62..b1c75526 100644 --- a/tests/components/test_action.py +++ b/tests/components/test_action.py @@ -426,7 +426,18 @@ def test_action_to_dict_reset_game(self): assert action == new_action assert action_dict["action_type"] == str(action.type) assert len(action_dict["parameters"]) == 0 - + action = Action( + action_type=ActionType.ResetGame, + parameters={"request_trajectory": True, "randomize_topology": False} + ) + action_dict = action.as_dict + new_action = Action.from_dict(action_dict) + assert action == new_action + assert action_dict["action_type"] == str(action.type) + assert len(action_dict["parameters"]) == 2 + assert action_dict["parameters"]["request_trajectory"] is True + assert action_dict["parameters"]["randomize_topology"] is False + def test_action_to_dict_quit_game(self): action = Action( action_type=ActionType.QuitGame,