diff --git a/AIDojoCoordinator/utils/utils.py b/AIDojoCoordinator/utils/utils.py index 4bf042f4..a1e5510e 100644 --- a/AIDojoCoordinator/utils/utils.py +++ b/AIDojoCoordinator/utils/utils.py @@ -178,7 +178,7 @@ def read_agents_known_data(self, type_agent: str, type_data: str) -> dict: known_data[known_data_host] = set() for datum in data: if not isinstance(datum, list) and datum.lower() == "random": - known_data[known_data_host] = "random" + known_data[known_data_host].add("random") else: known_data_content_str_user = datum[0] known_data_content_str_data = datum[1] @@ -219,15 +219,21 @@ def read_agents_known_services(self, type_agent: str, type_data: str) -> dict: # Check the host is a good ip _ = netaddr.IPAddress(ip) known_services_host = IP(ip) - if data.lower() == "random": - known_services[known_services_host] = "random" - name = data[0] - type = data[1] - version = data[2] - is_local = data[3] - - known_services[known_services_host] = Service(name, type, version, is_local) - + known_services[known_services_host] = [] + for service in data: # process each item in the list + if isinstance(service, list): # Service defined as list + name = service[0] + type = service[1] + version = service[2] + is_local = service[3] + known_services[known_services_host].append(Service(name, type, version, is_local)) + elif isinstance(service, str): # keyword + if service.lower() == "random": + known_services[known_services_host].append("random") + else: + logging.warning(f"Unsupported values in agent known_services{ip}:{service}") + else: + logging.warning(f"Unsupported values in agent known_services{ip}:{service}") except (ValueError, netaddr.AddrFormatError): known_services = {} return known_services diff --git a/AIDojoCoordinator/worlds/NSEGameCoordinator.py b/AIDojoCoordinator/worlds/NSEGameCoordinator.py index 3410b116..d95aa40f 100644 --- a/AIDojoCoordinator/worlds/NSEGameCoordinator.py +++ b/AIDojoCoordinator/worlds/NSEGameCoordinator.py @@ -5,10 +5,11 @@ import random import numpy as np import copy -from faker import Faker -from pathlib import Path import netaddr, re import json +from faker import Faker +from pathlib import Path +from typing import Iterable from AIDojoCoordinator.game_components import GameState, Action, ActionType, IP, Network, Data, Service from AIDojoCoordinator.coordinator import GameCoordinator @@ -54,21 +55,14 @@ def _initialize(self)->None: self._data_original = copy.deepcopy(self._data) self._firewall_original = copy.deepcopy(self._firewall) self.logger.info("Environment initialization finished") - - def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->GameState: + + def _get_controlled_hosts_from_view(self, view_controlled_hosts:Iterable)->set: """ - Builds a GameState from given view. - If there is a keyword 'random' used, it is replaced by a valid option at random. - - Currently, we artificially extend the knonw_networks with +- 1 in the third octet. + Parses view and translates all keywords. Produces set of controlled host (IP) """ - self.logger.info(f'Generating state from view:{view}') - # re-map all networks based on current mapping in self._network_mapping - known_networks = set([self._network_mapping[net] for net in view["known_networks"]]) - controlled_hosts = set() # controlled_hosts - for host in view['controlled_hosts']: + for host in view_controlled_hosts: if isinstance(host, IP): controlled_hosts.add(self._ip_mapping[host]) self.logger.debug(f'\tThe attacker has control of host {self._ip_mapping[host]}.') @@ -85,11 +79,60 @@ def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->Ga controlled_hosts = controlled_hosts.union(self._get_all_local_ips()) else: self.logger.error(f"Unsupported value encountered in start_position['controlled_hosts']: {host}") - # re-map all known based on current mapping in self._ip_mapping + return controlled_hosts + + def _get_services_from_view(self, view_known_services:dict)->dict: + known_services ={} + for ip, service_list in view_known_services.items(): + if self._ip_mapping[ip] not in known_services: + known_services[self._ip_mapping[ip]] = set() + for s in service_list: + if isinstance(s, Service): + known_services[self._ip_mapping[ip]].add(s) + elif isinstance(s, str): + if s == "random": # randomly select the service + self.logger.info(f"\tSelecting service randomly in {self._ip_mapping[ip]}") + # select candidates that are not explicitly listed + service_candidates = [s for s in self._services[self._ip_to_hostname[ip]] if s not in known_services[self._ip_mapping[ip]]] + # randomly select from candidates + known_services[self._ip_mapping[ip]].add(random.choice(service_candidates)) + return known_services + + def _get_data_from_view(self, view_known_data:dict)->dict: + known_data = {} + for ip, data_list in view_known_data.items(): + if self._ip_mapping[ip] not in known_data: + known_data[self._ip_mapping[ip]] = set() + for datum in data_list: + if isinstance(datum, Data): + known_data[self._ip_mapping[ip]].add(datum) + elif isinstance(datum, str): + if datum == "random": # randomly select the data + self.logger.info(f"\tSelecting data randomly in {self._ip_mapping[ip]}") + # select candidates that are not explicitly listed + data_candidates = [d for d in self._data[self._ip_to_hostname[ip]] if d not in known_data[self._ip_mapping[ip]]] + if len(data_candidates) > 0: + # randomly select from candidates + known_data[self._ip_mapping[ip]].add(random.choice(data_candidates)) + else: + self.logger.warning("\tNo available data. Skipping") + return known_data + + def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->GameState: + """ + Builds a GameState from given view. + If there is a keyword 'random' used, it is replaced by a valid option at random. + + Currently, we artificially extend the knonw_networks with +- 1 in the third octet. + """ + self.logger.info(f'Generating state from view:{view}') + # re-map all networks based on current mapping in self._network_mapping + known_networks = set([self._network_mapping[net] for net in view["known_networks"]]) + # parse controlled hosts + controlled_hosts = self._get_controlled_hosts_from_view(view["controlled_hosts"]) known_hosts = set([self._ip_mapping[ip] for ip in view["known_hosts"]]) # Add all controlled hosts to known_hosts known_hosts = known_hosts.union(controlled_hosts) - if add_neighboring_nets: # Extend the known networks with the neighbouring networks # This is to solve in the env (and not in the agent) the problem @@ -113,12 +156,10 @@ def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->Ga known_networks.add(ip) #return value back to the original net_obj.value += 256 - known_services ={} - for ip, service_list in view["known_services"]: - known_services[self._ip_mapping[ip]] = service_list - known_data = {} - for ip, data_list in view["known_data"]: - known_data[self._ip_mapping[ip]] = data_list + # parse known services + known_services = self._get_services_from_view(view["known_services"]) + # parse known data + known_data = self._get_data_from_view(view["known_data"]) game_state = GameState(controlled_hosts, known_hosts, known_services, known_data, known_networks) self.logger.info(f"Generated GameState:{game_state}") return game_state