Commit cfbee86

Random agent compatible with the whitebox version of the env
1 parent 3fa4582 commit cfbee86

File tree

1 file changed: +221 −0 lines changed
Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
# Author: Ondrej Lukas, ondrej.lukas@aic.cvut.cz
# This agent just randomly picks actions. No learning.
import logging
import argparse
import numpy as np
import mlflow
from os import path, makedirs
from random import choice, seed as seed_rng
from AIDojoCoordinator.game_components import Action, Observation, AgentStatus
from NetSecGameAgents.agents.action_list_base_agent import ActionListAgent

class RandomWhiteboxAttackerAgent(ActionListAgent):
    """
    A random attacker agent that selects an action uniformly at random from the actions valid in the current state.
    """

    def __init__(self, host, port, role, seed) -> None:
        super().__init__(host, port, role)
        # Seed the standard-library RNG so the random action choices are reproducible
        seed_rng(seed)

    def play_game(self, observation, num_episodes=1):
        """
        The main function for the gameplay. Handles the main interaction loop over episodes.
        """
        returns = []
        num_steps = 0
        last_observation = observation
        for episode in range(num_episodes):
            self._logger.info(f"Playing episode {episode}")
            episodic_returns = []
            while observation and not observation.end:
                num_steps += 1
                self._logger.debug(f'Observation received:{observation}')
                # Store the rewards collected in the episode
                episodic_returns.append(observation.reward)

                # Select the action randomly
                action = self.select_action(observation)
                observation = self.make_step(action)
                # Keep the latest observation so it can be returned after the reset
                last_observation = observation
                self._logger.debug(f'Observation received:{observation}')
            returns.append(np.sum(episodic_returns))
            self._logger.info(f"Episode {episode} ended with return {np.sum(episodic_returns)}. Mean returns={np.mean(returns)}±{np.std(returns)}")
            # Reset the episode
            observation = self.request_game_reset()
        self._logger.info(f"Final results for {self.__class__.__name__} after {num_episodes} episodes: {np.mean(returns)}±{np.std(returns)}")
        # This is the last observation played before the reset
        return (last_observation, num_steps)

    def select_action(self, observation: Observation) -> Action:
        # Get the valid action mask (boolean array) for the current state using the parent class method
        action_mask = self.get_valid_action_mask(observation.state)
        if not np.any(action_mask):
            raise ValueError("No valid actions available to select.")
        # Filter the action list to include only valid actions
        valid_actions = [action for action, valid in zip(self._action_list, action_mask) if valid]
        if not valid_actions:
            raise ValueError("No valid actions found after filtering.")
        # Randomly choose one of the valid actions
        action = choice(valid_actions)
        return action

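    # Illustration (hypothetical values, not from the environment): if
    # self._action_list held [scan_network, find_services, exploit_service] and
    # get_valid_action_mask() returned [True, False, True] for the current state,
    # valid_actions would be [scan_network, exploit_service] and choice() would
    # pick one of those two uniformly at random.
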
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", help="Host where the game server is", default="127.0.0.1", action='store', required=False)
    parser.add_argument("--port", help="Port where the game server is", default=9000, type=int, action='store', required=False)
    parser.add_argument("--episodes", help="Sets number of episodes to play or evaluate", default=100, type=int)
    parser.add_argument("--test_each", help="Report evaluation metrics every this number of episodes.", default=10, type=int)
    parser.add_argument("--logdir", help="Folder to store logs", default=path.join(path.dirname(path.abspath(__file__)), "logs"))
    parser.add_argument("--evaluate", help="Evaluate the agent and report, instead of playing the game only once.", default=True, action=argparse.BooleanOptionalAction)  # allows --evaluate / --no-evaluate (Python 3.9+)
    parser.add_argument("--mlflow_url", help="URL for mlflow tracking server. If not provided, mlflow will store locally.", default=None)
    args = parser.parse_args()

    if not path.exists(args.logdir):
        makedirs(args.logdir)
    logging.basicConfig(filename=path.join(args.logdir, "random_agent.log"), filemode='w', format='%(asctime)s %(name)s %(levelname)s %(message)s', datefmt='%H:%M:%S', level=logging.INFO)

    # Create agent
    agent = RandomWhiteboxAttackerAgent(args.host, args.port, "Attacker", seed=42)

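    # Example invocations (illustrative; the script filename is not shown in this
    # diff, "random_whitebox_agent.py" below is a placeholder):
    #   python random_whitebox_agent.py --episodes 200 --test_each 20
    #   python random_whitebox_agent.py --no-evaluate --episodes 1
    #   python random_whitebox_agent.py --mlflow_url http://localhost:5000
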
    if not args.evaluate:
        # Play the normal game
        observation = agent.register()
        agent.play_game(observation, args.episodes)
        agent._logger.info("Terminating interaction")
        agent.terminate_connection()
    else:
        # Evaluate the agent performance

        # How it works:
        # - Evaluate for several 'episodes' (parameter)
        # - Each episode finishes with: steps played, return, win/lose. Store all of them.
        # - After each episode, compute the avg and std over all episodes so far.
        # - Every X episodes (parameter), report in the log and in mlflow.
        # - At the end, report in the log, in mlflow and on the console.

        # Mlflow experiment name
        experiment_name = "Evaluation of Random Attacker Agent"
        if args.mlflow_url:
            mlflow.set_tracking_uri(args.mlflow_url)
        mlflow.set_experiment(experiment_name)
        # Register in the game
        observation = agent.register()
        with mlflow.start_run(run_name=experiment_name) as run:
            # To keep statistics of each episode
            wins = 0
            detected = 0
            max_steps = 0
            num_win_steps = []
            num_detected_steps = []
            num_max_steps_steps = []
            num_detected_returns = []
            num_win_returns = []
            num_max_steps_returns = []

            # Log more things in Mlflow
            mlflow.set_tag("experiment_name", experiment_name)
            # Log notes or additional information
            mlflow.set_tag("notes", "This is an evaluation")
            mlflow.set_tag("episode_number", args.episodes)
            #mlflow.log_param("learning_rate", learning_rate)

            for episode in range(1, args.episodes + 1):
                agent._logger.info(f'Starting the testing for episode {episode}')
                print(f'Starting the testing for episode {episode}')

                # Play the game for one episode
                observation, num_steps = agent.play_game(observation, 1)

                state = observation.state
                reward = observation.reward
                end = observation.end
                info = observation.info

                # Classify the episode by its end reason
                if observation.info and observation.info['end_reason'] == AgentStatus.Fail:
                    detected += 1
                    num_detected_steps += [num_steps]
                    num_detected_returns += [reward]
                elif observation.info and observation.info['end_reason'] == AgentStatus.Success:
                    wins += 1
                    num_win_steps += [num_steps]
                    num_win_returns += [reward]
                elif observation.info and observation.info['end_reason'] == AgentStatus.TimeoutReached:
                    max_steps += 1
                    num_max_steps_steps += [num_steps]
                    num_max_steps_returns += [reward]

                # Reset the game
                observation = agent.request_game_reset()

                eval_win_rate = (wins/episode) * 100
                eval_detection_rate = (detected/episode) * 100
                eval_average_returns = np.mean(num_detected_returns+num_win_returns+num_max_steps_returns)
                eval_std_returns = np.std(num_detected_returns+num_win_returns+num_max_steps_returns)
                eval_average_episode_steps = np.mean(num_win_steps+num_detected_steps+num_max_steps_steps)
                eval_std_episode_steps = np.std(num_win_steps+num_detected_steps+num_max_steps_steps)
                eval_average_win_steps = np.mean(num_win_steps)
                eval_std_win_steps = np.std(num_win_steps)
                eval_average_detected_steps = np.mean(num_detected_steps)
                eval_std_detected_steps = np.std(num_detected_steps)
                eval_average_max_steps_steps = np.mean(num_max_steps_steps)
                eval_std_max_steps_steps = np.std(num_max_steps_steps)

                # Log and report every X episodes
                if episode % args.test_each == 0 and episode != 0:
                    text = f'''Tested after {episode} episodes.
                        Wins={wins},
                        Detections={detected},
                        winrate={eval_win_rate:.3f}%,
                        detection_rate={eval_detection_rate:.3f}%,
                        average_returns={eval_average_returns:.3f} +- {eval_std_returns:.3f},
                        average_episode_steps={eval_average_episode_steps:.3f} +- {eval_std_episode_steps:.3f},
                        average_win_steps={eval_average_win_steps:.3f} +- {eval_std_win_steps:.3f},
                        average_detected_steps={eval_average_detected_steps:.3f} +- {eval_std_detected_steps:.3f},
                        average_max_steps_steps={eval_average_max_steps_steps:.3f} +- {eval_std_max_steps_steps:.3f},
                        '''
                    agent._logger.info(text)
                    # Store in mlflow
                    mlflow.log_metric("eval_avg_win_rate", eval_win_rate, step=episode)
                    mlflow.log_metric("eval_avg_detection_rate", eval_detection_rate, step=episode)
                    mlflow.log_metric("eval_avg_returns", eval_average_returns, step=episode)
                    mlflow.log_metric("eval_std_returns", eval_std_returns, step=episode)
                    mlflow.log_metric("eval_avg_episode_steps", eval_average_episode_steps, step=episode)
                    mlflow.log_metric("eval_std_episode_steps", eval_std_episode_steps, step=episode)
                    mlflow.log_metric("eval_avg_win_steps", eval_average_win_steps, step=episode)
                    mlflow.log_metric("eval_std_win_steps", eval_std_win_steps, step=episode)
                    mlflow.log_metric("eval_avg_detected_steps", eval_average_detected_steps, step=episode)
                    mlflow.log_metric("eval_std_detected_steps", eval_std_detected_steps, step=episode)
                    mlflow.log_metric("eval_avg_max_steps_steps", eval_average_max_steps_steps, step=episode)
                    mlflow.log_metric("eval_std_max_steps_steps", eval_std_max_steps_steps, step=episode)

            # Log the final evaluation after the last episode ends
            text = f'''Episode {episode}. Final eval after {episode} episodes, for {args.episodes} episodes.
                Wins={wins},
                Detections={detected},
                winrate={eval_win_rate:.3f}%,
                detection_rate={eval_detection_rate:.3f}%,
                average_returns={eval_average_returns:.3f} +- {eval_std_returns:.3f},
                average_episode_steps={eval_average_episode_steps:.3f} +- {eval_std_episode_steps:.3f},
                average_win_steps={eval_average_win_steps:.3f} +- {eval_std_win_steps:.3f},
                average_detected_steps={eval_average_detected_steps:.3f} +- {eval_std_detected_steps:.3f},
                average_max_steps_steps={eval_average_max_steps_steps:.3f} +- {eval_std_max_steps_steps:.3f},
                '''

            agent._logger.info(text)
            print(text)
            agent._logger.info("Terminating interaction")
            agent.terminate_connection()

            # Print and log the mlflow experiment ID, run ID, and storage location
            experiment_id = run.info.experiment_id
            run_id = run.info.run_id
            storage_location = "locally" if not args.mlflow_url else f"at {args.mlflow_url}"
            print(f"MLflow Experiment ID: {experiment_id}")
            print(f"MLflow Run ID: {run_id}")
            print(f"Experiment saved {storage_location}")
            agent._logger.info(f"MLflow Experiment ID: {experiment_id}")
            agent._logger.info(f"MLflow Run ID: {run_id}")
            agent._logger.info(f"Experiment saved {storage_location}")
