Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions AIDojoCoordinator/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def read_agents_known_data(self, type_agent: str, type_data: str) -> dict:
known_data[known_data_host] = set()
for datum in data:
if not isinstance(datum, list) and datum.lower() == "random":
known_data[known_data_host] = "random"
known_data[known_data_host].add("random")
else:
known_data_content_str_user = datum[0]
known_data_content_str_data = datum[1]
Expand Down Expand Up @@ -219,15 +219,21 @@ def read_agents_known_services(self, type_agent: str, type_data: str) -> dict:
# Check the host is a good ip
_ = netaddr.IPAddress(ip)
known_services_host = IP(ip)
if data.lower() == "random":
known_services[known_services_host] = "random"
name = data[0]
type = data[1]
version = data[2]
is_local = data[3]

known_services[known_services_host] = Service(name, type, version, is_local)

known_services[known_services_host] = []
for service in data: # process each item in the list
if isinstance(service, list): # Service defined as list
name = service[0]
type = service[1]
version = service[2]
is_local = service[3]
known_services[known_services_host].append(Service(name, type, version, is_local))
elif isinstance(service, str): # keyword
if service.lower() == "random":
known_services[known_services_host].append("random")
else:
logging.warning(f"Unsupported values in agent known_services{ip}:{service}")
else:
logging.warning(f"Unsupported values in agent known_services{ip}:{service}")
except (ValueError, netaddr.AddrFormatError):
known_services = {}
return known_services
Expand Down
83 changes: 62 additions & 21 deletions AIDojoCoordinator/worlds/NSEGameCoordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import random
import numpy as np
import copy
from faker import Faker
from pathlib import Path
import netaddr, re
import json
from faker import Faker
from pathlib import Path
from typing import Iterable

from AIDojoCoordinator.game_components import GameState, Action, ActionType, IP, Network, Data, Service
from AIDojoCoordinator.coordinator import GameCoordinator
Expand Down Expand Up @@ -54,21 +55,14 @@ def _initialize(self)->None:
self._data_original = copy.deepcopy(self._data)
self._firewall_original = copy.deepcopy(self._firewall)
self.logger.info("Environment initialization finished")

def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->GameState:
def _get_controlled_hosts_from_view(self, view_controlled_hosts:Iterable)->set:
"""
Builds a GameState from given view.
If there is a keyword 'random' used, it is replaced by a valid option at random.

Currently, we artificially extend the knonw_networks with +- 1 in the third octet.
Parses view and translates all keywords. Produces set of controlled host (IP)
"""
self.logger.info(f'Generating state from view:{view}')
# re-map all networks based on current mapping in self._network_mapping
known_networks = set([self._network_mapping[net] for net in view["known_networks"]])

controlled_hosts = set()
# controlled_hosts
for host in view['controlled_hosts']:
for host in view_controlled_hosts:
if isinstance(host, IP):
controlled_hosts.add(self._ip_mapping[host])
self.logger.debug(f'\tThe attacker has control of host {self._ip_mapping[host]}.')
Expand All @@ -85,11 +79,60 @@ def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->Ga
controlled_hosts = controlled_hosts.union(self._get_all_local_ips())
else:
self.logger.error(f"Unsupported value encountered in start_position['controlled_hosts']: {host}")
# re-map all known based on current mapping in self._ip_mapping
return controlled_hosts

def _get_services_from_view(self, view_known_services:dict)->dict:
known_services ={}
for ip, service_list in view_known_services.items():
if self._ip_mapping[ip] not in known_services:
known_services[self._ip_mapping[ip]] = set()
for s in service_list:
if isinstance(s, Service):
known_services[self._ip_mapping[ip]].add(s)
elif isinstance(s, str):
if s == "random": # randomly select the service
self.logger.info(f"\tSelecting service randomly in {self._ip_mapping[ip]}")
# select candidates that are not explicitly listed
service_candidates = [s for s in self._services[self._ip_to_hostname[ip]] if s not in known_services[self._ip_mapping[ip]]]
# randomly select from candidates
known_services[self._ip_mapping[ip]].add(random.choice(service_candidates))
return known_services

def _get_data_from_view(self, view_known_data:dict)->dict:
known_data = {}
for ip, data_list in view_known_data.items():
if self._ip_mapping[ip] not in known_data:
known_data[self._ip_mapping[ip]] = set()
for datum in data_list:
if isinstance(datum, Data):
known_data[self._ip_mapping[ip]].add(datum)
elif isinstance(datum, str):
if datum == "random": # randomly select the data
self.logger.info(f"\tSelecting data randomly in {self._ip_mapping[ip]}")
# select candidates that are not explicitly listed
data_candidates = [d for d in self._data[self._ip_to_hostname[ip]] if d not in known_data[self._ip_mapping[ip]]]
if len(data_candidates) > 0:
# randomly select from candidates
known_data[self._ip_mapping[ip]].add(random.choice(data_candidates))
else:
self.logger.warning("\tNo available data. Skipping")
return known_data

def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->GameState:
"""
Builds a GameState from given view.
If there is a keyword 'random' used, it is replaced by a valid option at random.

Currently, we artificially extend the knonw_networks with +- 1 in the third octet.
"""
self.logger.info(f'Generating state from view:{view}')
# re-map all networks based on current mapping in self._network_mapping
known_networks = set([self._network_mapping[net] for net in view["known_networks"]])
# parse controlled hosts
controlled_hosts = self._get_controlled_hosts_from_view(view["controlled_hosts"])
known_hosts = set([self._ip_mapping[ip] for ip in view["known_hosts"]])
# Add all controlled hosts to known_hosts
known_hosts = known_hosts.union(controlled_hosts)

if add_neighboring_nets:
# Extend the known networks with the neighbouring networks
# This is to solve in the env (and not in the agent) the problem
Expand All @@ -113,12 +156,10 @@ def _create_state_from_view(self, view:dict, add_neighboring_nets:bool=True)->Ga
known_networks.add(ip)
#return value back to the original
net_obj.value += 256
known_services ={}
for ip, service_list in view["known_services"]:
known_services[self._ip_mapping[ip]] = service_list
known_data = {}
for ip, data_list in view["known_data"]:
known_data[self._ip_mapping[ip]] = data_list
# parse known services
known_services = self._get_services_from_view(view["known_services"])
# parse known data
known_data = self._get_data_from_view(view["known_data"])
game_state = GameState(controlled_hosts, known_hosts, known_services, known_data, known_networks)
self.logger.info(f"Generated GameState:{game_state}")
return game_state
Expand Down