
Commit 696c8fe

Merge pull request #256 from stratosphereips/assign-reward-at-episode-end
Assign reward at episode end
2 parents: 00ce5be + f23692c

File tree: 10 files changed (+178, -63 lines)


README.md

Lines changed: 6 additions & 3 deletions
@@ -134,8 +134,8 @@ env:
 ## Task configuration
 The task configuration part (section `coordinator[agents]`) defines the starting and goal position of the attacker and the type of defender that is used.
 
-### Attacker configuration (`attackers`)
-Configuration of the attacking agents. Consists of two parts:
+### Attacker configuration (`Attacker`)
+Configuration of the attacking agents. Consists of three parts:
 1. Goal definition (`goal`) which describes the `GameState` properties that must be fulfilled to award `goal_reward` to the attacker:
     - `known_networks:`(set)
     - `known_hosts`(set)
@@ -154,11 +154,14 @@ Configuration of the attacking agents. Consists of two parts:
     - `known_data`(dict)
 
 The initial network configuration must assign at least **one** controlled host to the attacker in the network. Any item in `controlled_hosts` is copied to `known_hosts`, so there is no need to include these in both sets. `known_networks` is also extended with a set of **all** networks accessible from the `controlled_hosts`
+3. Definition of maximum allowed amount of steps:
+    - `max_steps:`(int)
 
 Example attacker configuration:
 ```YAML
 agents:
   Attacker:
+    max_steps: 100
     goal:
       randomize_goal_every_episode: False
       known_networks: []
@@ -179,7 +182,7 @@ agents:
       known_data: {}
       known_blocks: {}
 ```
-### Defender configuration (`defenders`)
+### Defender configuration (`Defender`)
 Currently, the defender **is** a separate agent.
 
 If you want a defender in the game, you must connect a defender agent. For playing without a defender, leave the section empty.
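With this change the step limit moves from a single global value into each agent's section of the task configuration. As a rough illustration of how a per-role lookup over that YAML could behave, here is a minimal sketch; the `load_max_steps_per_role` helper and the exact config nesting are illustrative, not the repository's actual API:

```python
# Minimal sketch: reading a per-role max_steps value out of a parsed task
# configuration. The helper name and YAML layout are assumptions for
# illustration; the repository's own TaskConfig class may differ.
import yaml

EXAMPLE_CONFIG = """
coordinator:
  agents:
    Attacker:
      max_steps: 100
      goal:
        randomize_goal_every_episode: False
    Defender: {}
"""

def load_max_steps_per_role(config: dict, roles=("Attacker", "Defender")) -> dict:
    """Return {role: max_steps or None}; None means the role has no step limit."""
    agents = config["coordinator"]["agents"]
    limits = {}
    for role in roles:
        # Roles without a max_steps entry get None, mirroring the KeyError
        # fallback used by the coordinator in this commit.
        limits[role] = agents.get(role, {}).get("max_steps")
    return limits

if __name__ == "__main__":
    cfg = yaml.safe_load(EXAMPLE_CONFIG)
    print(load_max_steps_per_role(cfg))  # {'Attacker': 100, 'Defender': None}
```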

coordinator.py

Lines changed: 111 additions & 44 deletions
@@ -188,7 +188,7 @@ def __init__(self, actions_queue, answers_queue, net_sec_config, allowed_roles,
         self._starting_positions_per_role = self._get_starting_position_per_role()
         self._win_conditions_per_role = self._get_win_condition_per_role()
         self._goal_description_per_role = self._get_goal_description_per_role()
-        self._steps_limit = self._world.task_config.get_max_steps()
+        self._steps_limit_per_role = self._get_max_steps_per_role()
         self._use_global_defender = self._world.task_config.get_use_global_defender()
         # player information
         self.agents = {}
@@ -201,18 +201,19 @@ def __init__(self, actions_queue, answers_queue, net_sec_config, allowed_roles,
         self._agent_starting_position = {}
         # current state per agent_addr (GameState)
         self._agent_states = {}
-        # goal reach status per agent_addr (bool)
-        self._agent_goal_reached = {}
-        self._agent_episode_ends = {}
-        self._agent_detected = {}
+        # agent status dict {agent_addr: string}
+        self._agent_statuses = {}
+        # agent status dict {agent_addr: int}
+        self._agent_rewards = {}
         # trajectories per agent_addr
         self._agent_trajectories = {}
 
     @property
     def episode_end(self)->bool:
-        # Terminate episode if at least one player wins or reaches the timeout
-        self.logger.debug(f"End evaluation: {self._agent_episode_ends.values()}")
-        return all(self._agent_episode_ends.values())
+        # Episode ends ONLY IF all agents with defined max_steps reached the end fo the episode
+        exists_active_player = any(status == "playing_active" for status in self._agent_statuses.values())
+        self.logger.debug(f"End evaluation: {self._agent_statuses.items()} - Episode end:{not exists_active_player}")
+        return not exists_active_player
 
     @property
     def config_file_hash(self):
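The `episode_end` property now derives termination from the per-agent status strings instead of a dict of booleans: the episode is over only when no agent is still in the "playing_active" state. A small self-contained sketch of that rule (the status values are the ones used in this commit; the standalone function is illustrative):

```python
# Sketch of the episode-end rule introduced here: an episode runs as long as
# at least one agent whose role has a max_steps limit is still "playing_active".
def episode_end(agent_statuses: dict) -> bool:
    exists_active_player = any(
        status == "playing_active" for status in agent_statuses.values()
    )
    return not exists_active_player

# Attacker still acting -> episode keeps running
print(episode_end({("1.1.1.1", 42): "playing_active", ("2.2.2.2", 7): "playing"}))  # False
# Attacker reached its goal, defender has no step limit -> episode is over
print(episode_end({("1.1.1.1", 42): "goal_reached", ("2.2.2.2", 7): "playing"}))    # True
```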
@@ -273,8 +274,13 @@ async def run(self):
                 self._reset_requests[agent] = False
                 self._agent_steps[agent] = 0
                 self._agent_states[agent] = self._world.create_state_from_view(self._agent_starting_position[agent])
-                self._agent_goal_reached[agent] = self._goal_reached(agent)
-                self._agent_episode_ends[agent] = False
+                self._agent_rewards.pop(agent, None)
+                if self._steps_limit_per_role[self.agents[agent][1]]:
+                    # This agent can force episode end (has timeout and goal defined)
+                    self._agent_statuses[agent] = "playing_active"
+                else:
+                    # This agent can NOT force episode end (does NOT timeout or goal defined)
+                    self._agent_statuses[agent] = "playing"
                 output_message_dict = self._create_response_to_reset_game_action(agent)
                 msg_json = self.convert_msg_dict_to_json(output_message_dict)
                 # Send to anwer_queue
@@ -307,9 +313,13 @@ def _initialize_new_player(self, agent_addr:tuple, agent_name:str, agent_role:st
         self._reset_requests[agent_addr] = False
         self._agent_starting_position[agent_addr] = self._starting_positions_per_role[agent_role]
         self._agent_states[agent_addr] = self._world.create_state_from_view(self._agent_starting_position[agent_addr])
-        self._agent_goal_reached[agent_addr] = self._goal_reached(agent_addr)
-        self._agent_detected[agent_addr] = self._check_detection(agent_addr, None)
-        self._agent_episode_ends[agent_addr] = False
+
+        if self._steps_limit_per_role[agent_role]:
+            # This agent can force episode end (has timeout and goal defined)
+            self._agent_statuses[agent_addr] = "playing_active"
+        else:
+            # This agent can NOT force episode end (does NOT timeout or goal defined)
+            self._agent_statuses[agent_addr] = "playing"
         if self._world.task_config.get_store_trajectories() or self._use_global_defender:
             self._agent_trajectories[agent_addr] = self._reset_trajectory(agent_addr)
         self.logger.info(f"\tAgent {agent_name} ({agent_addr}), registred as {agent_role}")
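At registration, and again on every reset, each agent is tagged either "playing_active", if its role has a `max_steps` limit and can therefore force the episode to end, or "playing" otherwise. A condensed sketch of that decision, with an assumed per-role limits dict standing in for `self._steps_limit_per_role`:

```python
# Sketch: initial status assignment for a newly joined or reset agent.
# A limit of None means "no step limit for this role".
def initial_status(agent_role: str, steps_limit_per_role: dict) -> str:
    if steps_limit_per_role.get(agent_role):
        # This agent can force the episode end (it has a step limit defined).
        return "playing_active"
    # This agent cannot force the episode end.
    return "playing"

limits = {"Attacker": 100, "Defender": None}
print(initial_status("Attacker", limits))  # playing_active
print(initial_status("Defender", limits))  # playing
```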
@@ -323,10 +333,10 @@ def _remove_player(self, agent_addr:tuple)->dict:
         agent_info = {}
         if agent_addr in self.agents:
             agent_info["state"] = self._agent_states.pop(agent_addr)
-            agent_info["goal_reached"] = self._agent_goal_reached.pop(agent_addr)
+            agent_info["status"] = self._agent_statuses.pop(agent_addr)
             agent_info["num_steps"] = self._agent_steps.pop(agent_addr)
             agent_info["reset_request"] = self._reset_requests.pop(agent_addr)
-            agent_info["episode_end"] = self._agent_episode_ends.pop(agent_addr)
+            agent_info["end_reward"] = self._agent_rewards.pop(agent_addr, None)
             agent_info["agent_info"] = self.agents.pop(agent_addr)
             self.logger.debug(f"\t{agent_info}")
         else:
@@ -376,6 +386,19 @@ def _get_goal_description_per_role(self)->dict:
             self.logger.info(f"Goal description for role '{agent_role}': {goal_descriptions[agent_role]}")
         return goal_descriptions
 
+    def _get_max_steps_per_role(self)->dict:
+        """
+        Method for finding max amount of steps in 1 episode for each agent role in the game.
+        """
+        max_steps = {}
+        for agent_role in self.ALLOWED_ROLES:
+            try:
+                max_steps[agent_role] = self._world.task_config.get_max_steps(agent_role)
+            except KeyError:
+                max_steps[agent_role] = None
+            self.logger.info(f"Max steps in episode for '{agent_role}': {max_steps[agent_role]}")
+        return max_steps
+
     def _process_join_game_action(self, agent_addr: tuple, action: Action) -> dict:
         """ "
         Method for processing Action of type ActionType.JoinGame
@@ -386,14 +409,13 @@ def _process_join_game_action(self, agent_addr: tuple, action: Action) -> dict:
         agent_role = action.parameters["agent_info"].role
         if agent_role in self.ALLOWED_ROLES:
             initial_observation = self._initialize_new_player(agent_addr, agent_name, agent_role)
-            max_steps = self._world._max_steps if agent_role == "Attacker" else None
             output_message_dict = {
                 "to_agent": agent_addr,
                 "status": str(GameStatus.CREATED),
                 "observation": observation_as_dict(initial_observation),
                 "message": {
                     "message": f"Welcome {agent_name}, registred as {agent_role}",
-                    "max_steps": max_steps,
+                    "max_steps": self._steps_limit_per_role[agent_role],
                     "goal_description": self._goal_description_per_role[agent_role],
                     "num_actions": self._world.num_actions,
                     "configuration_hash": self._CONFIG_FILE_HASH
@@ -436,8 +458,9 @@ def _create_response_to_reset_game_action(self, agent_addr: tuple) -> dict:
             "observation": observation_as_dict(new_observation),
             "message": {
                 "message": "Resetting Game and starting again.",
-                "max_steps": self._world._max_steps,
-                "goal_description": self._goal_description_per_role[self.agents[agent_addr][1]]
+                "max_steps": self._steps_limit_per_role[self.agents[agent_addr][1]],
+                "goal_description": self._goal_description_per_role[self.agents[agent_addr][1]],
+                "configuration_hash": self._CONFIG_FILE_HASH
             },
         }
         return output_message_dict
@@ -491,24 +514,34 @@ def _process_generic_action(self, agent_addr: tuple, action: Action) -> dict:
         current_state = self._agent_states[agent_addr]
         # Build new Observation for the agent
         self._agent_states[agent_addr] = self._world.step(current_state, action, agent_addr)
-        self._agent_goal_reached[agent_addr] = self._goal_reached(agent_addr)
-
-        self._agent_detected[agent_addr] = self._check_detection(agent_addr, action)
-
+        # check timout
+        if self._max_steps_reached(agent_addr):
+            self._agent_statuses[agent_addr] = "max_steps"
+        # check detection
+        if self._check_detection(agent_addr, action):
+            self._agent_statuses[agent_addr] = "blocked"
+            self._agent_detected[agent_addr] = True
+        # check goal
+        if self._goal_reached(agent_addr):
+            self._agent_statuses[agent_addr] = "goal_reached"
+        # add reward for taking a step
         reward = self._world._rewards["step"]
+
         obs_info = {}
         end_reason = None
-        if self._agent_goal_reached[agent_addr]:
-            reward += self._world._rewards["goal"]
-            self._agent_episode_ends[agent_addr] = True
+        if self._agent_statuses[agent_addr] == "goal_reached":
+            self._assign_end_rewards()
+            reward += self._agent_rewards[agent_addr]
             end_reason = "goal_reached"
             obs_info = {'end_reason': "goal_reached"}
-        elif self._timeout_reached(agent_addr):
-            self._agent_episode_ends[agent_addr] = True
+        elif self._agent_statuses[agent_addr] == "max_steps":
+            self._assign_end_rewards()
+            reward += self._agent_rewards[agent_addr]
             obs_info = {"end_reason": "max_steps"}
             end_reason = "max_steps"
-        elif self._agent_detected[agent_addr]:
-            reward += self._world._rewards["detection"]
+        elif self._agent_statuses[agent_addr] == "blocked":
+            self._assign_end_rewards()
+            reward += self._agent_rewards[agent_addr]
             self._agent_episode_ends[agent_addr] = True
             obs_info = {"end_reason": "max_steps"}
 
@@ -524,6 +557,7 @@ def _process_generic_action(self, agent_addr: tuple, action: Action) -> dict:
                 "status": str(GameStatus.OK),
             }
         else:
+            self._assign_end_rewards()
             self.logger.error(f"{self.episode_end}, {self._agent_episode_ends}")
             output_message_dict = self._generate_episode_end_message(agent_addr)
         return output_message_dict
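Within `_process_generic_action` the three terminal conditions are now checked in a fixed order after the world step: timeout first, then detection, then goal, so a later check overwrites an earlier one and reaching the goal on the last allowed step still counts as "goal_reached". Only once the status leaves the playing states is the end reward added to the per-step reward. A compact sketch of that ordering (the standalone helper is illustrative):

```python
# Sketch of the status update order used in this commit: the last matching
# check wins, so "goal_reached" takes precedence over "max_steps"/"blocked"
# when several conditions trigger on the same step.
def update_status(max_steps_hit: bool, detected: bool, goal_done: bool,
                  current: str = "playing_active") -> str:
    status = current
    if max_steps_hit:
        status = "max_steps"
    if detected:
        status = "blocked"
    if goal_done:
        status = "goal_reached"
    return status

print(update_status(max_steps_hit=True, detected=False, goal_done=True))
# -> goal_reached (goal on the final step still wins)
```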
@@ -533,15 +567,8 @@ def _generate_episode_end_message(self, agent_addr:tuple)->dict:
         Method for generating response when agent attemps to make a step after episode ended.
         """
         current_observation = self._agent_observations[agent_addr]
-        reward = 0 # TODO
-        end_reason = ""
-        if self._agent_goal_reached[agent_addr]:
-            end_reason = "goal_reached"
-        elif self._timeout_reached(agent_addr):
-            end_reason = "max_steps"
-        else:
-            end_reason = "game_lost"
-            reward += self._world._rewards["detection"]
+        reward = self._agent_rewards[agent_addr]
+        end_reason = self._agent_statuses[agent_addr]
         new_observation = Observation(
             current_observation.state,
             reward=reward,
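After the episode is over, any further step request is answered by `_generate_episode_end_message`, which no longer recomputes the outcome: it simply reads the already-assigned end reward and the stored status string and packs them into the final observation. A rough sketch with a stand-in tuple type (the real Observation class lives elsewhere in the repository, so nothing here is its actual API):

```python
# Sketch: building the post-episode response from the stored per-agent state.
# ObservationStub is a stand-in for the project's Observation type.
from collections import namedtuple

ObservationStub = namedtuple("ObservationStub", ["state", "reward", "end", "info"])

def episode_end_observation(last_obs, agent_rewards, agent_statuses, agent):
    reward = agent_rewards[agent]        # assigned by the end-reward bookkeeping
    end_reason = agent_statuses[agent]   # e.g. "goal_reached", "max_steps", "blocked"
    return ObservationStub(last_obs.state, reward, True, {"end_reason": end_reason})

prev = ObservationStub(state={"known_hosts": []}, reward=0, end=False, info={})
print(episode_end_observation(prev, {("1.1.1.1", 42): 100},
                              {("1.1.1.1", 42): "goal_reached"}, ("1.1.1.1", 42)))
```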
@@ -586,7 +613,7 @@ def goal_dict_satistfied(goal_dict:dict, known_dict: dict)-> bool:
                 if len(matching_keys) == len(goal_dict.keys()):
                     return True
             except KeyError:
-                #some keys are missing in the known_dict
+                # some keys are missing in the known_dict
                 return False
             return False
 
@@ -615,18 +642,58 @@ def _check_detection(self, agent_addr:tuple, last_action:Action)->bool:
             self.logger.info("\tNot detected!")
         return detection
 
-    def _timeout_reached(self, agent_addr:tuple) ->bool:
+    def _max_steps_reached(self, agent_addr:tuple) ->bool:
         """
         Checks if the agent reached the max allowed steps. Only applies to role 'Attacker'
         """
         self.logger.debug(f"Checking timout for {self.agents[agent_addr]}")
-        if self.agents[agent_addr][1] == "Attacker":
-            if self._agent_steps[agent_addr] >= self._steps_limit:
+        agent_role = self.agents[agent_addr][1]
+        if self._steps_limit_per_role[agent_role]:
+            if self._agent_steps[agent_addr] >= self._steps_limit_per_role[agent_role]:
                 self.logger.info("Timeout reached by {self.agents[agent_addr]}!")
                 return True
         else:
+            self.logger.debug(f"No max steps defined for role {agent_role}")
         return False
 
+    def _assign_end_rewards(self)->None:
+        """
+        Method which assings rewards to each agent which has finished playing
+        """
+        is_episode_over = self.episode_end
+        for agent, status in self._agent_statuses.items():
+            if agent not in self._agent_rewards.keys(): # reward has not been assigned yet
+                agent_name, agent_role = self.agents[agent]
+                if agent_role == "Attacker":
+                    match status:
+                        case "goal_reached":
+                            self._agent_rewards[agent] = self._world._rewards["goal"]
+                        case "max_steps":
+                            self._agent_rewards[agent] = 0
+                        case "blocked":
+                            self._agent_rewards[agent] = self._world._rewards["detection"]
+                    self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}")
+                elif agent_role == "Defender":
+                    if self._agent_statuses[agent] == "max_steps": #defender was responsible for the end
+                        raise NotImplementedError
+                        self._agent_rewards[agent] = 0
+                    else:
+                        if is_episode_over: #only assign defender's reward when episode ends
+                            sucessful_attacks = list(self._agent_statuses.values).count("goal_reached")
+                            if sucessful_attacks > 0:
+                                self._agent_rewards[agent] = sucessful_attacks*self._world._rewards["detection"]
+                                self._agent_statuses[agent] = "game_lost"
+                            else: #no successful attacker
+                                self._agent_rewards[agent] = self._world._rewards["goal"]
+                                self._agent_statuses[agent] = "goal_reached"
+                            self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}")
+                else:
+                    if is_episode_over:
+                        self._agent_rewards[agent] = 0
+                        self.logger.info(f"End reward for {agent_name}({agent_role}, status: '{status}') = {self._agent_rewards[agent]}")
+
+
+
 __version__ = "v0.2.2"
 
 
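The new `_assign_end_rewards` is the single place where terminal rewards are handed out: each finished attacker is paid according to its own status, while the defender is only scored once the whole episode is over, earning the goal reward if no attacker succeeded and a detection-sized penalty per successful attacker otherwise. The sketch below condenses that logic under assumed reward values; note that it calls `.values()` explicitly, and the helper itself is illustrative rather than the repository's code:

```python
# Condensed sketch of the end-reward rules introduced in this commit.
# REWARDS values are placeholders; the coordinator reads them from the world.
REWARDS = {"goal": 100, "detection": -50}

def end_reward(role: str, status: str, all_statuses: dict, episode_over: bool):
    if role == "Attacker":
        return {"goal_reached": REWARDS["goal"],
                "max_steps": 0,
                "blocked": REWARDS["detection"]}.get(status)
    if role == "Defender" and episode_over:
        # Count attackers that reached their goal; .values() must be called.
        successful = list(all_statuses.values()).count("goal_reached")
        if successful > 0:
            return successful * REWARDS["detection"]  # defender lost the game
        return REWARDS["goal"]                        # defender kept all attackers out
    return None  # not decided yet (e.g. defender while attackers still play)

statuses = {"attacker@1": "goal_reached", "defender@1": "playing"}
print(end_reward("Attacker", "goal_reached", statuses, episode_over=True))  # 100
print(end_reward("Defender", "playing", statuses, episode_over=True))       # -50
```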
@@ -668,7 +735,7 @@ def _timeout_reached(self, agent_addr:tuple) ->bool:
         action="store",
         required=False,
         type=str,
-        default="INFO",
+        default="WARNING",
     )
 
     args = parser.parse_args()
