Skip to content

Commit 43b85eb

Browse files
authored
Merge pull request #374 from stratosphereips/ondra-fix-goal-randomization
Ondra fix goal randomization
2 parents 787f69e + fdfcc65 commit 43b85eb

File tree

6 files changed

+327
-176
lines changed

6 files changed

+327
-176
lines changed

AIDojoCoordinator/coordinator.py

Lines changed: 24 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,8 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por
180180
self._agent_starting_position = {}
181181
# current state per agent_addr (GameState)
182182
self._agent_states = {}
183+
# goal state per agent_addr (GameState)
184+
self._agent_goal_states = {}
183185
# last action played by agent (Action)
184186
self._agent_last_action = {}
185187
# False positives per agent (due to added blocks)
@@ -462,11 +464,11 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No
462464
agent_role = action.parameters["agent_info"].role
463465
if agent_role in self.ALLOWED_ROLES:
464466
# add agent to the world
465-
new_agent_game_state = await self.register_agent(agent_addr, agent_role, self._starting_positions_per_role[agent_role])
467+
new_agent_game_state, new_agent_goal_state = await self.register_agent(agent_addr, agent_role, self._starting_positions_per_role[agent_role], self._win_conditions_per_role[agent_role])
466468
if new_agent_game_state: # successful registration
467469
async with self._agents_lock:
468470
self.agents[agent_addr] = (agent_name, agent_role)
469-
observation = self._initialize_new_player(agent_addr, new_agent_game_state)
471+
observation = self._initialize_new_player(agent_addr, new_agent_game_state, new_agent_goal_state)
470472
self._agent_observations[agent_addr] = observation
471473
#if len(self.agents) == self._min_required_players:
472474
if sum(1 for v in self._agent_status.values() if v == AgentStatus.PlayingWithTimeout) >= self._min_required_players:
@@ -720,10 +722,13 @@ async def _reset_game(self):
720722
async with self._agents_lock:
721723
self._store_trajectory_to_file(agent)
722724
self.logger.debug(f"Resetting agent {agent}")
723-
new_state = await self.reset_agent(agent, self.agents[agent][1], self._agent_starting_position[agent])
725+
agent_role = self.agents[agent][1]
726+
# reset the agent in the world
727+
new_state, new_goal_state = await self.reset_agent(agent, agent_role, self._starting_positions_per_role[agent_role], self._win_conditions_per_role[agent_role])
724728
new_observation = Observation(new_state, 0, False, {})
725729
async with self._agents_lock:
726730
self._agent_states[agent] = new_state
731+
self._agent_goal_states[agent] = new_goal_state
727732
self._agent_observations[agent] = new_observation
728733
self._episode_ends[agent] = False
729734
self._reset_requests[agent] = False
@@ -741,7 +746,7 @@ async def _reset_game(self):
741746
self._reset_done_condition.notify_all()
742747
self.logger.info("\tReset game task stopped.")
743748

744-
def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState) -> Observation:
749+
def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState, agent_current_goal_state:GameState) -> Observation:
745750
"""
746751
Method to initialize new player upon joining the game.
747752
Returns initial observation for the agent based on the agent's role
@@ -753,6 +758,8 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState
753758
self._episode_ends[agent_addr] = False
754759
self._agent_starting_position[agent_addr] = self._starting_positions_per_role[agent_role]
755760
self._agent_states[agent_addr] = agent_current_state
761+
self._agent_goal_states[agent_addr] = agent_current_goal_state
762+
self._agent_last_action[agent_addr] = None
756763
self._agent_rewards[agent_addr] = 0
757764
self._agent_false_positives[agent_addr] = 0
758765
if agent_role.lower() == "attacker":
@@ -764,7 +771,7 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState
764771
# create initial observation
765772
return Observation(self._agent_states[agent_addr], 0, False, {})
766773

767-
async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState:
774+
async def register_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]:
768775
"""
769776
Domain specific method of the environment. Creates the initial state of the agent.
770777
"""
@@ -775,8 +782,8 @@ async def remove_agent(self, agent_id:tuple, agent_state:GameState)->bool:
775782
Domain specific method of the environment. Creates the initial state of the agent.
776783
"""
777784
raise NotImplementedError
778-
779-
async def reset_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict)->GameState:
785+
786+
async def reset_agent(self, agent_id:tuple, agent_role:str, agent_initial_view:dict, agent_win_condition_view:dict)->tuple[GameState, GameState]:
780787
raise NotImplementedError
781788

782789
async def _remove_agent_from_game(self, agent_addr):
@@ -788,6 +795,7 @@ async def _remove_agent_from_game(self, agent_addr):
788795
async with self._agents_lock:
789796
if agent_addr in self.agents:
790797
agent_info["state"] = self._agent_states.pop(agent_addr)
798+
agent_info["goal_state"] = self._agent_goal_states.pop(agent_addr)
791799
agent_info["num_steps"] = self._agent_steps.pop(agent_addr)
792800
agent_info["agent_status"] = self._agent_status.pop(agent_addr)
793801
agent_info["false_positives"] = self._agent_false_positives.pop(agent_addr)
@@ -816,7 +824,7 @@ async def _remove_agent_from_game(self, agent_addr):
816824
async def step(self, agent_id:tuple, agent_state:GameState, action:Action):
817825
raise NotImplementedError
818826

819-
async def reset(self):
827+
async def reset(self)->bool:
820828
return NotImplemented
821829

822830
def _initialize(self):
@@ -846,16 +854,17 @@ def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool:
846854
return False
847855
return False
848856
self.logger.debug(f"Checking goal for agent {agent_addr}.")
849-
goal_conditions = self._win_conditions_per_role[self.agents[agent_addr][1]]
850857
state = self._agent_states[agent_addr]
851858
# For each part of the state of the game, check if the conditions are met
859+
target_goal_state = self._agent_goal_states[agent_addr]
860+
self.logger.debug(f"\tGoal conditions: {target_goal_state}.")
852861
goal_reached = {}
853-
goal_reached["networks"] = set(goal_conditions["known_networks"]) <= set(state.known_networks)
854-
goal_reached["known_hosts"] = set(goal_conditions["known_hosts"]) <= set(state.known_hosts)
855-
goal_reached["controlled_hosts"] = set(goal_conditions["controlled_hosts"]) <= set(state.controlled_hosts)
856-
goal_reached["services"] = goal_dict_satistfied(goal_conditions["known_services"], state.known_services)
857-
goal_reached["data"] = goal_dict_satistfied(goal_conditions["known_data"], state.known_data)
858-
goal_reached["known_blocks"] = goal_dict_satistfied(goal_conditions["known_blocks"], state.known_blocks)
862+
goal_reached["networks"] = target_goal_state.known_networks <= state.known_networks
863+
goal_reached["known_hosts"] = target_goal_state.known_hosts <= state.known_hosts
864+
goal_reached["controlled_hosts"] = target_goal_state.controlled_hosts <= state.controlled_hosts
865+
goal_reached["services"] = goal_dict_satistfied(target_goal_state.known_services, state.known_services)
866+
goal_reached["data"] = goal_dict_satistfied(target_goal_state.known_data, state.known_data)
867+
goal_reached["known_blocks"] = goal_dict_satistfied(target_goal_state.known_blocks, state.known_blocks)
859868
self.logger.debug(f"\t{goal_reached}")
860869
return all(goal_reached.values())
861870

AIDojoCoordinator/game_components.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -551,8 +551,8 @@ def as_graph(self)->tuple:
551551
graph_nodes = {}
552552
node_features = []
553553
controlled = []
554+
edges = []
554555
try:
555-
edges = []
556556
#add known nets
557557
for net in self.known_networks:
558558
graph_nodes[net] = len(graph_nodes)
@@ -738,6 +738,8 @@ def from_string(cls, string:str)->"GameStatus":
738738
return GameStatus.FORBIDDEN
739739
case "GameStatus.RESET_DONE":
740740
return GameStatus.RESET_DONE
741+
case _:
742+
raise ValueError(f"Invalid GameStatus string: {string}")
741743
def __repr__(self) -> str:
742744
"""
743745
Return the string representation of the GameStatus.

AIDojoCoordinator/utils/utils.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import json
1414
import hashlib
1515
from cyst.api.configuration.network.node import NodeConfig
16+
from typing import Optional
1617

1718
def get_file_hash(filepath, hash_func='sha256', chunk_size=4096):
1819
"""
@@ -111,7 +112,7 @@ def observation_as_dict(observation:Observation)->dict:
111112
}
112113
return observation_dict
113114

114-
def parse_log_content(log_content:str)->list:
115+
def parse_log_content(log_content:str)->Optional[list]:
115116
try:
116117
logs = []
117118
data = json.loads(log_content)
@@ -154,7 +155,7 @@ def read_config_file(self, conf_file_name:str):
154155
self.logger.error(f'Error loading the configuration file{e}')
155156
pass
156157

157-
def read_env_action_data(self, action_name: str) -> dict:
158+
def read_env_action_data(self, action_name: str) -> float:
158159
"""
159160
Generic function to read the known data for any agent and goal of position
160161
"""
@@ -238,7 +239,7 @@ def read_agents_known_services(self, type_agent: str, type_data: str) -> dict:
238239
known_services = {}
239240
return known_services
240241

241-
def read_agents_known_networks(self, type_agent: str, type_data: str) -> dict:
242+
def read_agents_known_networks(self, type_agent: str, type_data: str) -> set:
242243
"""
243244
Generic function to read the known networks for any agent and goal of position
244245
"""
@@ -251,10 +252,10 @@ def read_agents_known_networks(self, type_agent: str, type_data: str) -> dict:
251252
host_part, net_part = net.split('/')
252253
known_networks.add(Network(host_part, int(net_part)))
253254
except (ValueError, TypeError, netaddr.AddrFormatError):
254-
self.logger('Configuration problem with the known networks')
255+
self.logger.error('Configuration problem with the known networks')
255256
return known_networks
256257

257-
def read_agents_known_hosts(self, type_agent: str, type_data: str) -> dict:
258+
def read_agents_known_hosts(self, type_agent: str, type_data: str) -> set:
258259
"""
259260
Generic function to read the known hosts for any agent and goal of position
260261
"""
@@ -274,7 +275,7 @@ def read_agents_known_hosts(self, type_agent: str, type_data: str) -> dict:
274275
self.logger.error(f'Configuration problem with the known hosts: {e}')
275276
return known_hosts
276277

277-
def read_agents_controlled_hosts(self, type_agent: str, type_data: str) -> dict:
278+
def read_agents_controlled_hosts(self, type_agent: str, type_data: str) -> set:
278279
"""
279280
Generic function to read the controlled hosts for any agent and goal of position
280281
"""
@@ -395,7 +396,7 @@ def get_win_conditions(self, agent_role):
395396
case _:
396397
raise ValueError(f"Unsupported agent role: {agent_role}")
397398

398-
def get_max_steps(self, role=str)->int:
399+
def get_max_steps(self, role=str)->Optional[int]:
399400
"""
400401
Get the max steps based on agent's role
401402
"""
@@ -409,7 +410,7 @@ def get_max_steps(self, role=str)->int:
409410
self.logger.warning(f"Unsupported value in 'coordinator.agents.{role}.max_steps': {e}. Setting value to default=None (no step limit)")
410411
return max_steps
411412

412-
def get_goal_description(self, agent_role)->dict:
413+
def get_goal_description(self, agent_role)->str:
413414
"""
414415
Get goal description per role
415416
"""
@@ -554,7 +555,7 @@ def get_starting_position_from_cyst_config(cyst_objects):
554555
if isinstance(obj, NodeConfig):
555556
for active_service in obj.active_services:
556557
if active_service.type == "netsecenv_agent":
557-
print(f"startig processing {obj.id}.{active_service.name}")
558+
print(f"starting processing {obj.id}.{active_service.name}")
558559
hosts = set()
559560
networks = set()
560561
for interface in obj.interfaces:

0 commit comments

Comments
 (0)