Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 55 additions & 27 deletions env/worlds/network_security_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, task_config_file, world_name="NetSecEnv") -> None:
self._services = {} # Dict of all services in the environment. Keys: hostname (`str`), values: `set` of `Service` objetcs.
self._data = {} # Dict of all services in the environment. Keys: hostname (`str`), values `set` of `Service` objetcs.
self._firewall = {} # dict of all the allowed connections in the environment. Keys `IP` ,values: `set` of `IP` objects.
self._fw_blocks = {}
self._data_content = {} #content of each datapoint from self._data
# All exploits in the environment
self._exploits = {}
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(self, task_config_file, world_name="NetSecEnv") -> None:
# Make a copy of data placements so it is possible to reset to it when episode ends
self._data_original = copy.deepcopy(self._data)
self._data_content_original = copy.deepcopy(self._data_content)
self._firewall_original = copy.deepcopy(self._firewall)

self._actions_played = []
self.logger.info("Environment initialization finished")
Expand Down Expand Up @@ -457,6 +459,16 @@ def _get_data_in_host(self, host_ip:str, controlled_hosts:set)->set:
else:
self.logger.debug("\t\t\tCan't get data in host. The host is not controlled.")
return data

def _get_known_blocks_in_host(self, host_ip:str, controlled_hosts:set)->set:
known_blocks = set()
if host_ip in controlled_hosts: #only return data if the agent controls the host
if host_ip in self._ip_to_hostname:
if host_ip in self._fw_blocks:
known_blocks = self._fw_blocks[host_ip]
else:
self.logger.debug("\t\t\tCan't get data in host. The host is not controlled.")
return known_blocks

def _get_data_content(self, host_ip:str, data_id:str)->str:
"""
Expand Down Expand Up @@ -579,6 +591,13 @@ def _execute_find_data_action(self, current:components.GameState, action:compone
next_data[action.parameters["target_host"]] = new_data
else:
next_data[action.parameters["target_host"]] = next_data[action.parameters["target_host"]].union(new_data)
# ADD KNOWN FW BLOCKS
new_blocks = self._get_known_blocks_in_host(action.parameters["target_host"], current.controlled_hosts)
if len(new_blocks) > 0:
if action.parameters["target_host"] not in next_blocked.keys():
next_blocked[action.parameters["target_host"]] = new_blocks
else:
next_blocked[action.parameters["target_host"]] = next_blocked[action.parameters["target_host"]].union(new_blocks)
else:
self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. Skipping")
else:
Expand Down Expand Up @@ -680,41 +699,48 @@ def _execute_block_ip_action(self, current_state:components.GameState, action:co
- Add the rule to the FW list
- Update the state
"""
blocked_host = action.parameters['blocked_host']

next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state)
self.logger.info(f"\t\tBlockIP {action.parameters['target_host']}")
# Is the src in the controlled hosts?
if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts:
# Is the target in the controlled hosts?
if "target_host" in action.parameters.keys() and action.parameters["target_host"] in current_state.controlled_hosts:
# For now there is only one FW in the main router, but this should change in the future.
# This means we ignore the 'target_host' that would be the router where this is applied.

# Stop the blocked host to connect _to_ any other IP
try:
self._firewall[blocked_host] = set()
self.logger.debug(f"Removing all allowed connections from {blocked_host}")
except KeyError:
# The blocked_host host was not in the list
pass
# Stop the other hosts to connect _to the blocked_host_
for host in self._firewall.keys():
try:
self._firewall[host].remove(blocked_host)
self.logger.debug(f"Removing {blocked_host} from allowed connections from {host}")
except KeyError:
# The blocked_host host was not in the list
pass
# Update the state of blocked ips. It is a dict with key target_host and a set with blocked hosts inside
new_blocked = set()
# Store the blocked host IP in the set of blocked hosts
new_blocked.add(action.parameters["blocked_host"])
if len(new_blocked) > 0:
if action.parameters["target_host"] not in next_blocked.keys():
next_blocked[action.parameters["target_host"]] = new_blocked
if self._firewall_check(action.parameters["source_host"], action.parameters["target_host"]):
if action.parameters["target_host"] != action.parameters['blocked_host']:
self.logger.info(f"\t\tBlockConnection {action.parameters['target_host']} <-> {action.parameters['blocked_host']}")
try:
#remove connection target_host -> blocked_host
self._firewall[action.parameters["target_host"]].discard(action.parameters["blocked_host"])
self.logger.debug(f"\t\t\t Removed rule:'{action.parameters['target_host']}' -> {action.parameters['blocked_host']}")
except KeyError:
pass
try:
#remove blocked_host -> target_host
self._firewall[action.parameters["blocked_host"]].discard(action.parameters["target_host"])
self.logger.debug(f"\t\t\t Removed rule:'{action.parameters['blocked_host']}' -> {action.parameters['target_host']}")
except KeyError:
pass

#Update the FW_Rules visible to agents
if action.parameters["target_host"] not in self._fw_blocks.keys():
self._fw_blocks[action.parameters["target_host"]] = set()
self._fw_blocks[action.parameters["target_host"]].add(action.parameters["blocked_host"])
if action.parameters["blocked_host"] not in self._fw_blocks.keys():
self._fw_blocks[action.parameters["blocked_host"]] = set()
self._fw_blocks[action.parameters["blocked_host"]].add(action.parameters["target_host"])

# update the state
if action.parameters["target_host"] not in next_blocked.keys():
next_blocked[action.parameters["target_host"]] = set()
if action.parameters["blocked_host"] not in next_blocked.keys():
next_blocked[action.parameters["blocked_host"]] = set()
next_blocked[action.parameters["target_host"]].add(action.parameters["blocked_host"])
next_blocked[action.parameters["blocked_host"]].add(action.parameters["target_host"])
else:
next_blocked[action.parameters["target_host"]] = next_blocked[action.parameters["target_host"]].union(new_blocked)
self.logger.info(f"\t\t\t Cant block connection form :'{action.parameters['target_host']}' to '{action.parameters['blocked_host']}'")
else:
self.logger.debug(f"\t\t\t Connection from '{action.parameters['source_host']}->'{action.parameters['target_host']} is blocked blocked by FW")
else:
self.logger.info(f"\t\t\t Invalid target_host:'{action.parameters['target_host']}'")
else:
Expand Down Expand Up @@ -870,6 +896,8 @@ def reset(self)->None:
self._data = copy.deepcopy(self._data_original)
# reset self._data_content to orignal state
self._data_content_original = copy.deepcopy(self._data_content_original)
self._firewall = copy.deepcopy(self._firewall_original)
self._fw_blocks = {}


self._actions_played = []
Expand Down
59 changes: 59 additions & 0 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ def env_obs_found_data2(env_obs_exploited_service2):
new_state = env.step(state=state, action=action, agent_id=None)
return (env, new_state)

@pytest.fixture
def env_obs_blocked_connection(env_obs_exploited_service):
"After blocking"
env, state = env_obs_exploited_service
source_host = components.IP('192.168.2.2')
target_host = components.IP('192.168.1.3')
blocked_host = components.IP('192.168.2.2')
parameters = {
"target_host":target_host,
"source_host":source_host,
"blocked_host": blocked_host}
action = components.Action(components.ActionType.BlockIP, parameters)
new_state = env.step(state=state, action=action, agent_id=None)
return (env, new_state)

class TestActionsNoDefender:
def test_scan_network_not_exist(self, env_obs):
"""
Expand Down Expand Up @@ -303,3 +318,47 @@ def test_exploit_service_witout_find_service_in_host(self, env_obs_scan):
new_state = env.step(state=state, action=action, agent_id=None)
assert state == new_state
assert components.IP('192.168.1.3') not in new_state.known_services

def test_block_ip_same_host(self, env_obs_exploited_service):
env, state = env_obs_exploited_service
target_host = components.IP('192.168.2.2')
blocked_host = components.IP("1.1.1.1")
parameters = {
"target_host":target_host,
"source_host":target_host,
"blocked_host": blocked_host}
action = components.Action(components.ActionType.BlockIP, parameters)
new_state = env.step(state=state, action=action, agent_id=None)
assert target_host in new_state.known_blocks.keys()
assert blocked_host in new_state.known_blocks[target_host]
assert target_host in env._fw_blocks.keys()
assert blocked_host in env._fw_blocks[target_host]

def test_block_ip_same_different_source(self, env_obs_exploited_service):
env, state = env_obs_exploited_service
source_host = components.IP('192.168.2.2')
target_host = components.IP("192.168.1.3")
blocked_host = components.IP("1.1.1.1")
parameters = {
"target_host":target_host,
"source_host":source_host,
"blocked_host": blocked_host}
action = components.Action(components.ActionType.BlockIP, parameters)
new_state = env.step(state=state, action=action, agent_id=None)
assert target_host in new_state.known_blocks.keys()
assert blocked_host in new_state.known_blocks[target_host]
assert target_host in env._fw_blocks.keys()
assert blocked_host in env._fw_blocks[target_host]


def test_block_ip_self_block(self, env_obs_exploited_service):
env, state = env_obs_exploited_service
target_host = components.IP('192.168.2.2')
parameters = {
"target_host":target_host,
"source_host":components.IP('192.168.2.2'),
"blocked_host": target_host}
action = components.Action(components.ActionType.BlockIP, parameters)
new_state = env.step(state=state, action=action, agent_id=None)
assert target_host not in new_state.known_blocks.keys()
assert target_host not in env._fw_blocks.keys()
Loading