@@ -180,6 +180,8 @@ def __init__(self, game_host: str, game_port: int, service_host:str, service_por
180180 self ._agent_starting_position = {}
181181 # current state per agent_addr (GameState)
182182 self ._agent_states = {}
183+ # goal state per agent_addr (GameState)
184+ self ._agent_goal_states = {}
183185 # last action played by agent (Action)
184186 self ._agent_last_action = {}
185187 # False positives per agent (due to added blocks)
@@ -462,11 +464,11 @@ async def _process_join_game_action(self, agent_addr: tuple, action: Action)->No
462464 agent_role = action .parameters ["agent_info" ].role
463465 if agent_role in self .ALLOWED_ROLES :
464466 # add agent to the world
465- new_agent_game_state = await self .register_agent (agent_addr , agent_role , self ._starting_positions_per_role [agent_role ])
467+ new_agent_game_state , new_agent_goal_state = await self .register_agent (agent_addr , agent_role , self ._starting_positions_per_role [ agent_role ], self . _win_conditions_per_role [agent_role ])
466468 if new_agent_game_state : # successful registration
467469 async with self ._agents_lock :
468470 self .agents [agent_addr ] = (agent_name , agent_role )
469- observation = self ._initialize_new_player (agent_addr , new_agent_game_state )
471+ observation = self ._initialize_new_player (agent_addr , new_agent_game_state , new_agent_goal_state )
470472 self ._agent_observations [agent_addr ] = observation
471473 #if len(self.agents) == self._min_required_players:
472474 if sum (1 for v in self ._agent_status .values () if v == AgentStatus .PlayingWithTimeout ) >= self ._min_required_players :
@@ -720,10 +722,13 @@ async def _reset_game(self):
720722 async with self ._agents_lock :
721723 self ._store_trajectory_to_file (agent )
722724 self .logger .debug (f"Resetting agent { agent } " )
723- new_state = await self .reset_agent (agent , self .agents [agent ][1 ], self ._agent_starting_position [agent ])
725+ agent_role = self .agents [agent ][1 ]
726+ # reset the agent in the world
727+ new_state , new_goal_state = await self .reset_agent (agent , agent_role , self ._starting_positions_per_role [agent_role ], self ._win_conditions_per_role [agent_role ])
724728 new_observation = Observation (new_state , 0 , False , {})
725729 async with self ._agents_lock :
726730 self ._agent_states [agent ] = new_state
731+ self ._agent_goal_states [agent ] = new_goal_state
727732 self ._agent_observations [agent ] = new_observation
728733 self ._episode_ends [agent ] = False
729734 self ._reset_requests [agent ] = False
@@ -741,7 +746,7 @@ async def _reset_game(self):
741746 self ._reset_done_condition .notify_all ()
742747 self .logger .info ("\t Reset game task stopped." )
743748
744- def _initialize_new_player (self , agent_addr :tuple , agent_current_state :GameState ) -> Observation :
749+ def _initialize_new_player (self , agent_addr :tuple , agent_current_state :GameState , agent_current_goal_state : GameState ) -> Observation :
745750 """
746751 Method to initialize new player upon joining the game.
747752 Returns initial observation for the agent based on the agent's role
@@ -753,6 +758,8 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState
753758 self ._episode_ends [agent_addr ] = False
754759 self ._agent_starting_position [agent_addr ] = self ._starting_positions_per_role [agent_role ]
755760 self ._agent_states [agent_addr ] = agent_current_state
761+ self ._agent_goal_states [agent_addr ] = agent_current_goal_state
762+ self ._agent_last_action [agent_addr ] = None
756763 self ._agent_rewards [agent_addr ] = 0
757764 self ._agent_false_positives [agent_addr ] = 0
758765 if agent_role .lower () == "attacker" :
@@ -764,7 +771,7 @@ def _initialize_new_player(self, agent_addr:tuple, agent_current_state:GameState
764771 # create initial observation
765772 return Observation (self ._agent_states [agent_addr ], 0 , False , {})
766773
767- async def register_agent (self , agent_id :tuple , agent_role :str , agent_initial_view :dict )-> GameState :
774+ async def register_agent (self , agent_id :tuple , agent_role :str , agent_initial_view :dict , agent_win_condition_view : dict )-> tuple [ GameState , GameState ] :
768775 """
769776 Domain specific method of the environment. Creates the initial state of the agent.
770777 """
@@ -775,8 +782,8 @@ async def remove_agent(self, agent_id:tuple, agent_state:GameState)->bool:
775782 Domain specific method of the environment. Removes the agent from the environment.
776783 """
777784 raise NotImplementedError
778-
779- async def reset_agent (self , agent_id :tuple , agent_role :str , agent_initial_view :dict )-> GameState :
785+
async def reset_agent(
    self,
    agent_id: tuple,
    agent_role: str,
    agent_initial_view: dict,
    agent_win_condition_view: dict,
) -> tuple[GameState, GameState]:
    """
    Domain specific method of the environment; not implemented at this level.

    The caller in `_reset_game` unpacks the awaited result as
    `(new_state, new_goal_state)`, so an implementation is expected to
    return the agent's fresh state followed by its goal state.

    Raises:
        NotImplementedError: always, in this implementation.
    """
    raise NotImplementedError
781788
782789 async def _remove_agent_from_game (self , agent_addr ):
@@ -788,6 +795,7 @@ async def _remove_agent_from_game(self, agent_addr):
788795 async with self ._agents_lock :
789796 if agent_addr in self .agents :
790797 agent_info ["state" ] = self ._agent_states .pop (agent_addr )
798+ agent_info ["goal_state" ] = self ._agent_goal_states .pop (agent_addr )
791799 agent_info ["num_steps" ] = self ._agent_steps .pop (agent_addr )
792800 agent_info ["agent_status" ] = self ._agent_status .pop (agent_addr )
793801 agent_info ["false_positives" ] = self ._agent_false_positives .pop (agent_addr )
@@ -816,7 +824,7 @@ async def _remove_agent_from_game(self, agent_addr):
async def step(
    self,
    agent_id: tuple,
    agent_state: GameState,
    action: Action,
):
    """
    Domain specific method of the environment; not implemented at this level.

    Raises:
        NotImplementedError: always, in this implementation.
    """
    raise NotImplementedError
818826
819- async def reset (self ):
async def reset(self) -> bool:
    """
    Domain specific method of the environment; not implemented at this level.

    Raises:
        NotImplementedError: always, in this implementation.
    """
    # `NotImplemented` is the sentinel for binary-operator dispatch, not a
    # marker for unimplemented methods. Returning it would hand callers a
    # non-bool value that is truthy, contradicting the `-> bool` annotation.
    # Raise instead, as the other unimplemented stubs in this class do.
    raise NotImplementedError
821829
822830 def _initialize (self ):
@@ -846,16 +854,17 @@ def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool:
846854 return False
847855 return False
848856 self .logger .debug (f"Checking goal for agent { agent_addr } ." )
849- goal_conditions = self ._win_conditions_per_role [self .agents [agent_addr ][1 ]]
850857 state = self ._agent_states [agent_addr ]
851858 # For each part of the state of the game, check if the conditions are met
859+ target_goal_state = self ._agent_goal_states [agent_addr ]
860+ self .logger .debug (f"\t Goal conditions: { target_goal_state } ." )
852861 goal_reached = {}
853- goal_reached ["networks" ] = set ( goal_conditions [ " known_networks" ]) <= set ( state .known_networks )
854- goal_reached ["known_hosts" ] = set ( goal_conditions [ " known_hosts" ]) <= set ( state .known_hosts )
855- goal_reached ["controlled_hosts" ] = set ( goal_conditions [ " controlled_hosts" ]) <= set ( state .controlled_hosts )
856- goal_reached ["services" ] = goal_dict_satistfied (goal_conditions [ " known_services" ] , state .known_services )
857- goal_reached ["data" ] = goal_dict_satistfied (goal_conditions [ " known_data" ] , state .known_data )
858- goal_reached ["known_blocks" ] = goal_dict_satistfied (goal_conditions [ " known_blocks" ] , state .known_blocks )
862+ goal_reached ["networks" ] = target_goal_state . known_networks <= state .known_networks
863+ goal_reached ["known_hosts" ] = target_goal_state . known_hosts <= state .known_hosts
864+ goal_reached ["controlled_hosts" ] = target_goal_state . controlled_hosts <= state .controlled_hosts
865+ goal_reached ["services" ] = goal_dict_satistfied (target_goal_state . known_services , state .known_services )
866+ goal_reached ["data" ] = goal_dict_satistfied (target_goal_state . known_data , state .known_data )
867+ goal_reached ["known_blocks" ] = goal_dict_satistfied (target_goal_state . known_blocks , state .known_blocks )
859868 self .logger .debug (f"\t { goal_reached } " )
860869 return all (goal_reached .values ())
861870
0 commit comments