diff --git a/env/worlds/network_security_game.py b/env/worlds/network_security_game.py index b4056b94..c2bd3532 100755 --- a/env/worlds/network_security_game.py +++ b/env/worlds/network_security_game.py @@ -27,6 +27,7 @@ def __init__(self, task_config_file, world_name="NetSecEnv") -> None: self._services = {} # Dict of all services in the environment. Keys: hostname (`str`), values: `set` of `Service` objetcs. self._data = {} # Dict of all services in the environment. Keys: hostname (`str`), values `set` of `Service` objetcs. self._firewall = {} # dict of all the allowed connections in the environment. Keys `IP` ,values: `set` of `IP` objects. + self._fw_blocks = {} self._data_content = {} #content of each datapoint from self._data # All exploits in the environment self._exploits = {} @@ -76,6 +77,7 @@ def __init__(self, task_config_file, world_name="NetSecEnv") -> None: # Make a copy of data placements so it is possible to reset to it when episode ends self._data_original = copy.deepcopy(self._data) self._data_content_original = copy.deepcopy(self._data_content) + self._firewall_original = copy.deepcopy(self._firewall) self._actions_played = [] self.logger.info("Environment initialization finished") @@ -457,6 +459,16 @@ def _get_data_in_host(self, host_ip:str, controlled_hosts:set)->set: else: self.logger.debug("\t\t\tCan't get data in host. The host is not controlled.") return data + + def _get_known_blocks_in_host(self, host_ip:str, controlled_hosts:set)->set: + known_blocks = set() + if host_ip in controlled_hosts: #only return data if the agent controls the host + if host_ip in self._ip_to_hostname: + if host_ip in self._fw_blocks: + known_blocks = self._fw_blocks[host_ip] + else: + self.logger.debug("\t\t\tCan't get data in host. The host is not controlled.") + return known_blocks def _get_data_content(self, host_ip:str, data_id:str)->str: """ @@ -579,6 +591,13 @@ def _execute_find_data_action(self, current:components.GameState, action:compone next_data[action.parameters["target_host"]] = new_data else: next_data[action.parameters["target_host"]] = next_data[action.parameters["target_host"]].union(new_data) + # ADD KNOWN FW BLOCKS + new_blocks = self._get_known_blocks_in_host(action.parameters["target_host"], current.controlled_hosts) + if len(new_blocks) > 0: + if action.parameters["target_host"] not in next_blocked.keys(): + next_blocked[action.parameters["target_host"]] = new_blocks + else: + next_blocked[action.parameters["target_host"]] = next_blocked[action.parameters["target_host"]].union(new_blocks) else: self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. Skipping") else: @@ -680,41 +699,48 @@ def _execute_block_ip_action(self, current_state:components.GameState, action:co - Add the rule to the FW list - Update the state """ - blocked_host = action.parameters['blocked_host'] - next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state) - self.logger.info(f"\t\tBlockIP {action.parameters['target_host']}") # Is the src in the controlled hosts? if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts: # Is the target in the controlled hosts? if "target_host" in action.parameters.keys() and action.parameters["target_host"] in current_state.controlled_hosts: # For now there is only one FW in the main router, but this should change in the future. # This means we ignore the 'target_host' that would be the router where this is applied. - - # Stop the blocked host to connect _to_ any other IP - try: - self._firewall[blocked_host] = set() - self.logger.debug(f"Removing all allowed connections from {blocked_host}") - except KeyError: - # The blocked_host host was not in the list - pass - # Stop the other hosts to connect _to the blocked_host_ - for host in self._firewall.keys(): - try: - self._firewall[host].remove(blocked_host) - self.logger.debug(f"Removing {blocked_host} from allowed connections from {host}") - except KeyError: - # The blocked_host host was not in the list - pass - # Update the state of blocked ips. It is a dict with key target_host and a set with blocked hosts inside - new_blocked = set() - # Store the blocked host IP in the set of blocked hosts - new_blocked.add(action.parameters["blocked_host"]) - if len(new_blocked) > 0: - if action.parameters["target_host"] not in next_blocked.keys(): - next_blocked[action.parameters["target_host"]] = new_blocked + if self._firewall_check(action.parameters["source_host"], action.parameters["target_host"]): + if action.parameters["target_host"] != action.parameters['blocked_host']: + self.logger.info(f"\t\tBlockConnection {action.parameters['target_host']} <-> {action.parameters['blocked_host']}") + try: + #remove connection target_host -> blocked_host + self._firewall[action.parameters["target_host"]].discard(action.parameters["blocked_host"]) + self.logger.debug(f"\t\t\t Removed rule:'{action.parameters['target_host']}' -> {action.parameters['blocked_host']}") + except KeyError: + pass + try: + #remove blocked_host -> target_host + self._firewall[action.parameters["blocked_host"]].discard(action.parameters["target_host"]) + self.logger.debug(f"\t\t\t Removed rule:'{action.parameters['blocked_host']}' -> {action.parameters['target_host']}") + except KeyError: + pass + + #Update the FW_Rules visible to agents + if action.parameters["target_host"] not in self._fw_blocks.keys(): + self._fw_blocks[action.parameters["target_host"]] = set() + self._fw_blocks[action.parameters["target_host"]].add(action.parameters["blocked_host"]) + if action.parameters["blocked_host"] not in self._fw_blocks.keys(): + self._fw_blocks[action.parameters["blocked_host"]] = set() + self._fw_blocks[action.parameters["blocked_host"]].add(action.parameters["target_host"]) + + # update the state + if action.parameters["target_host"] not in next_blocked.keys(): + next_blocked[action.parameters["target_host"]] = set() + if action.parameters["blocked_host"] not in next_blocked.keys(): + next_blocked[action.parameters["blocked_host"]] = set() + next_blocked[action.parameters["target_host"]].add(action.parameters["blocked_host"]) + next_blocked[action.parameters["blocked_host"]].add(action.parameters["target_host"]) else: - next_blocked[action.parameters["target_host"]] = next_blocked[action.parameters["target_host"]].union(new_blocked) + self.logger.info(f"\t\t\t Cant block connection form :'{action.parameters['target_host']}' to '{action.parameters['blocked_host']}'") + else: + self.logger.debug(f"\t\t\t Connection from '{action.parameters['source_host']}->'{action.parameters['target_host']} is blocked blocked by FW") else: self.logger.info(f"\t\t\t Invalid target_host:'{action.parameters['target_host']}'") else: @@ -870,6 +896,8 @@ def reset(self)->None: self._data = copy.deepcopy(self._data_original) # reset self._data_content to orignal state self._data_content_original = copy.deepcopy(self._data_content_original) + self._firewall = copy.deepcopy(self._firewall_original) + self._fw_blocks = {} self._actions_played = [] diff --git a/tests/test_actions.py b/tests/test_actions.py index c0a3e66d..e3864207 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -92,6 +92,21 @@ def env_obs_found_data2(env_obs_exploited_service2): new_state = env.step(state=state, action=action, agent_id=None) return (env, new_state) +@pytest.fixture +def env_obs_blocked_connection(env_obs_exploited_service): + "After blocking" + env, state = env_obs_exploited_service + source_host = components.IP('192.168.2.2') + target_host = components.IP('192.168.1.3') + blocked_host = components.IP('192.168.2.2') + parameters = { + "target_host":target_host, + "source_host":source_host, + "blocked_host": blocked_host} + action = components.Action(components.ActionType.BlockIP, parameters) + new_state = env.step(state=state, action=action, agent_id=None) + return (env, new_state) + class TestActionsNoDefender: def test_scan_network_not_exist(self, env_obs): """ @@ -303,3 +318,47 @@ def test_exploit_service_witout_find_service_in_host(self, env_obs_scan): new_state = env.step(state=state, action=action, agent_id=None) assert state == new_state assert components.IP('192.168.1.3') not in new_state.known_services + + def test_block_ip_same_host(self, env_obs_exploited_service): + env, state = env_obs_exploited_service + target_host = components.IP('192.168.2.2') + blocked_host = components.IP("1.1.1.1") + parameters = { + "target_host":target_host, + "source_host":target_host, + "blocked_host": blocked_host} + action = components.Action(components.ActionType.BlockIP, parameters) + new_state = env.step(state=state, action=action, agent_id=None) + assert target_host in new_state.known_blocks.keys() + assert blocked_host in new_state.known_blocks[target_host] + assert target_host in env._fw_blocks.keys() + assert blocked_host in env._fw_blocks[target_host] + + def test_block_ip_same_different_source(self, env_obs_exploited_service): + env, state = env_obs_exploited_service + source_host = components.IP('192.168.2.2') + target_host = components.IP("192.168.1.3") + blocked_host = components.IP("1.1.1.1") + parameters = { + "target_host":target_host, + "source_host":source_host, + "blocked_host": blocked_host} + action = components.Action(components.ActionType.BlockIP, parameters) + new_state = env.step(state=state, action=action, agent_id=None) + assert target_host in new_state.known_blocks.keys() + assert blocked_host in new_state.known_blocks[target_host] + assert target_host in env._fw_blocks.keys() + assert blocked_host in env._fw_blocks[target_host] + + + def test_block_ip_self_block(self, env_obs_exploited_service): + env, state = env_obs_exploited_service + target_host = components.IP('192.168.2.2') + parameters = { + "target_host":target_host, + "source_host":components.IP('192.168.2.2'), + "blocked_host": target_host} + action = components.Action(components.ActionType.BlockIP, parameters) + new_state = env.step(state=state, action=action, agent_id=None) + assert target_host not in new_state.known_blocks.keys() + assert target_host not in env._fw_blocks.keys()