
Commit 5305fd4

Merge pull request #97 from stratosphereips/sebas-add-w&b-to-qlearning-attacker
Sebas add w&b to qlearning attacker
2 parents 5e855a1 + 28f3b80 commit 5305fd4

File tree

2 files changed: +143 -99 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -153,3 +153,4 @@ agents/mlruns*
 agents/*/*/mlruns/
 agents/*/*/logs
 aim/*
+wandb/

agents/attackers/q_learning/q_agent.py

Lines changed: 142 additions & 99 deletions
@@ -7,7 +7,7 @@
 import pickle
 import argparse
 import logging
-import mlflow
+import wandb
 import subprocess
 import time

@@ -185,7 +185,7 @@ def play_game(self, observation, episode_num, testing=False):
 parser.add_argument("--logdir", help="Folder to store logs", default=path.join(path.dirname(path.abspath(__file__)), "logs"))
 parser.add_argument("--previous_model", help="Load the previous model. If training, it will start from here. If testing, will use to test.", type=str)
 parser.add_argument("--testing", help="Test the agent. No train.", default=False, type=bool)
-parser.add_argument("--experiment_id", help="Id of the experiment to record into Mlflow.", default='', type=str)
+parser.add_argument("--experiment_id", help="Id of the experiment to record into Weights & Biases.", default='', type=str)
 parser.add_argument("--store_actions", help="Store actions in the log file q_agents_actions.log.", default=False, type=bool)
 parser.add_argument("--store_models_every", help="Store a model to disk every these number of episodes.", default=2000, type=int)
 parser.add_argument("--env_conf", help="Configuration file of the env. Only for logging purposes.", required=False, default='./env/netsecenv_conf.yaml', type=str)
@@ -225,71 +225,89 @@ def play_game(self, observation, episode_num, testing=False):


 if not args.testing:
-# Mlflow experiment name
+# Wandb experiment name
 experiment_name = "Training and Eval of Q-learning Agent"
-mlflow.set_experiment(experiment_name)
 elif args.testing:
 # Evaluate the agent performance

-# Mlflow experiment name
-experiment_name = "Testing of Q-learning Agent against defender agent"
-mlflow.set_experiment(experiment_name)
+# Wandb experiment name
+experiment_name = "Testing of Q-learning Agent"


 # This code runs for both training and testing. The difference is in the args.testing variable that is passed along
 # How it works:
 # - Evaluate for several 'episodes' (parameter)
 # - Each episode finishes with: steps played, return, win/lose. Store all
 # - Each episode compute the avg and std of all.
-# - Every X episodes (parameter), report in log and mlflow
-# - At the end, report in log and mlflow and console
+# - Every X episodes (parameter), report in log and wandb
+# - At the end, report in log and wandb and console

 # Register the agent
 observation = agent.register()

 try:
-with mlflow.start_run(run_name=experiment_name + f'. ID {args.experiment_id}') as run:
-# To keep statistics of each episode
-wins = 0
-detected = 0
-max_steps = 0
-num_win_steps = []
-num_detected_steps = []
-num_max_steps_steps = []
-num_detected_returns = []
-num_win_returns = []
-num_max_steps_returns = []
-
-# Log more things in Mlflow
-mlflow.set_tag("experiment_name", experiment_name)
-# Log notes or additional information
-mlflow.set_tag("notes", "This is an evaluation")
-if args.previous_model:
-mlflow.set_tag("Previous q-learning model loaded", str(args.previous_model))
-mlflow.log_param("alpha", args.alpha)
-mlflow.log_param("epsilon_start", args.epsilon_start)
-mlflow.log_param("epsilon_end", args.epsilon_end)
-mlflow.log_param("epsilon_max_episodes", args.epsilon_max_episodes)
-mlflow.log_param("gamma", args.gamma)
-mlflow.log_param("Episodes", args.episodes)
-mlflow.log_param("Test each", str(args.test_each))
-mlflow.log_param("Test for", str(args.test_for))
-mlflow.log_param("Testing", str(args.testing))
-# Use subprocess.run to get the commit hash
-netsecenv_command = "git rev-parse HEAD"
-netsecenv_git_result = subprocess.run(netsecenv_command, shell=True, capture_output=True, text=True).stdout
-agents_command = "cd NetSecGameAgents; git rev-parse HEAD"
-agents_git_result = subprocess.run(agents_command, shell=True, capture_output=True, text=True).stdout
-agent._logger.info(f'Using commits. NetSecEnv: {netsecenv_git_result}. Agents: {agents_git_result}')
-mlflow.set_tag("NetSecEnv commit", netsecenv_git_result)
-mlflow.set_tag("Agents commit", agents_git_result)
-# Log the env conf
-mlflow.log_artifact(args.env_conf)
-agent._logger.info(f'Epsilon Start: {agent.epsilon_start}')
-agent._logger.info(f'Epsilon End: {agent.epsilon_end}')
-agent._logger.info(f'Epsilon Max Episodes: {agent.epsilon_max_episodes}')
-
-for episode in range(1, args.episodes + 1):
+# Initialize wandb
+wandb.init(
+entity='Stratosphere',
+project='UTEP-Collaboration',
+group='sebas-qlearning',
+name=experiment_name + f'. ID {args.experiment_id}'
+)
+
+# To keep statistics of each episode
+wins = 0
+detected = 0
+max_steps = 0
+num_win_steps = []
+num_detected_steps = []
+num_max_steps_steps = []
+num_detected_returns = []
+num_win_returns = []
+num_max_steps_returns = []
+
+# Configure wandb with parameters and tags
+wandb.config.update({
+"alpha": args.alpha,
+"epsilon_start": args.epsilon_start,
+"epsilon_end": args.epsilon_end,
+"epsilon_max_episodes": args.epsilon_max_episodes,
+"gamma": args.gamma,
+"episodes": args.episodes,
+"test_each": args.test_each,
+"test_for": args.test_for,
+"testing": args.testing,
+"experiment_name": experiment_name,
+"notes": "This is an evaluation"
+})
+
+if args.previous_model:
+wandb.config.update({"previous_model_loaded": str(args.previous_model)})
+
+# Use subprocess.run to get the commit hash
+netsecenv_command = "git rev-parse HEAD"
+netsecenv_git_result = subprocess.run(netsecenv_command, shell=True, capture_output=True, text=True).stdout
+agents_command = "cd NetSecGameAgents; git rev-parse HEAD"
+agents_git_result = subprocess.run(agents_command, shell=True, capture_output=True, text=True).stdout
+agent._logger.info(f'Using commits. NetSecEnv: {netsecenv_git_result}. Agents: {agents_git_result}')
+wandb.config.update({
+"netsecenv_commit": netsecenv_git_result.strip(),
+"agents_commit": agents_git_result.strip()
+})
+# Log the env conf
+try:
+if path.exists(args.env_conf):
+wandb.save(args.env_conf, base_path=path.dirname(path.abspath(args.env_conf)))
+else:
+agent._logger.warning(f"Environment config file not found: {args.env_conf}")
+wandb.config.update({"env_conf_path": args.env_conf})
+except Exception as e:
+agent._logger.warning(f"Could not save env config file: {e}")
+wandb.config.update({"env_conf_path": args.env_conf})
+agent._logger.info(f'Epsilon Start: {agent.epsilon_start}')
+agent._logger.info(f'Epsilon End: {agent.epsilon_end}')
+agent._logger.info(f'Epsilon Max Episodes: {agent.epsilon_max_episodes}')
+
+for episode in range(1, args.episodes + 1):
 if not early_stop:
 # Play 1 episode
 observation, num_steps = agent.play_game(observation, testing=args.testing, episode_num=episode)
@@ -333,6 +351,24 @@ def play_game(self, observation, episode_num, testing=False):
 eval_average_max_steps_steps = np.mean(num_max_steps_steps)
 eval_std_max_steps_steps = np.std(num_max_steps_steps)

+# Log results for testing mode every episode
+if args.testing:
+wandb.log({
+"test_avg_win_rate": eval_win_rate,
+"test_avg_detection_rate": eval_detection_rate,
+"test_avg_returns": eval_average_returns,
+"test_std_returns": eval_std_returns,
+"test_avg_episode_steps": eval_average_episode_steps,
+"test_std_episode_steps": eval_std_episode_steps,
+"test_avg_win_steps": eval_average_win_steps,
+"test_std_win_steps": eval_std_win_steps,
+"test_avg_detected_steps": eval_average_detected_steps,
+"test_std_detected_steps": eval_std_detected_steps,
+"test_avg_max_steps_steps": eval_average_max_steps_steps,
+"test_std_max_steps_steps": eval_std_max_steps_steps,
+"current_episode": episode
+}, step=episode)
+
 # Now Test, log and report. This happens every X training episodes
 if episode % args.test_each == 0 and episode != 0:
 # If we are training, every these number of episodes, we need to test for some episodes.
@@ -354,20 +390,22 @@ def play_game(self, observation, episode_num, testing=False):
 epsilon={agent.current_epsilon}
 '''
 agent._logger.info(text)
-mlflow.log_metric("eval_avg_win_rate", eval_win_rate, step=episode)
-mlflow.log_metric("eval_avg_detection_rate", eval_detection_rate, step=episode)
-mlflow.log_metric("eval_avg_returns", eval_average_returns, step=episode)
-mlflow.log_metric("eval_std_returns", eval_std_returns, step=episode)
-mlflow.log_metric("eval_avg_episode_steps", eval_average_episode_steps, step=episode)
-mlflow.log_metric("eval_std_episode_steps", eval_std_episode_steps, step=episode)
-mlflow.log_metric("eval_avg_win_steps", eval_average_win_steps, step=episode)
-mlflow.log_metric("eval_std_win_steps", eval_std_win_steps, step=episode)
-mlflow.log_metric("eval_avg_detected_steps", eval_average_detected_steps, step=episode)
-mlflow.log_metric("eval_std_detected_steps", eval_std_detected_steps, step=episode)
-mlflow.log_metric("eval_avg_max_steps_steps", eval_average_max_steps_steps, step=episode)
-mlflow.log_metric("eval_std_max_steps_steps", eval_std_max_steps_steps, step=episode)
-mlflow.log_metric("current_epsilon", agent.current_epsilon, step=episode)
-mlflow.log_metric("current_episode", episode, step=episode)
+wandb.log({
+"eval_avg_win_rate": eval_win_rate,
+"eval_avg_detection_rate": eval_detection_rate,
+"eval_avg_returns": eval_average_returns,
+"eval_std_returns": eval_std_returns,
+"eval_avg_episode_steps": eval_average_episode_steps,
+"eval_std_episode_steps": eval_std_episode_steps,
+"eval_avg_win_steps": eval_average_win_steps,
+"eval_std_win_steps": eval_std_win_steps,
+"eval_avg_detected_steps": eval_average_detected_steps,
+"eval_std_detected_steps": eval_std_detected_steps,
+"eval_avg_max_steps_steps": eval_average_max_steps_steps,
+"eval_std_max_steps_steps": eval_std_max_steps_steps,
+"current_epsilon": agent.current_epsilon,
+"current_episode": episode
+}, step=episode)

 # To keep statistics of testing each episode
 test_wins = 0
@@ -441,45 +479,50 @@ def play_game(self, observation, episode_num, testing=False):
 '''
 agent._logger.info(text)
 print(text)
-# Store in mlflow
-mlflow.log_metric("test_avg_win_rate", test_win_rate, step=episode)
-mlflow.log_metric("test_avg_detection_rate", test_detection_rate, step=episode)
-mlflow.log_metric("test_avg_returns", test_average_returns, step=episode)
-mlflow.log_metric("test_std_returns", test_std_returns, step=episode)
-mlflow.log_metric("test_avg_episode_steps", test_average_episode_steps, step=episode)
-mlflow.log_metric("test_std_episode_steps", test_std_episode_steps, step=episode)
-mlflow.log_metric("test_avg_win_steps", test_average_win_steps, step=episode)
-mlflow.log_metric("test_std_win_steps", test_std_win_steps, step=episode)
-mlflow.log_metric("test_avg_detected_steps", test_average_detected_steps, step=episode)
-mlflow.log_metric("test_std_detected_steps", test_std_detected_steps, step=episode)
-mlflow.log_metric("test_avg_max_steps_steps", test_average_max_steps_steps, step=episode)
-mlflow.log_metric("test_std_max_steps_steps", test_std_max_steps_steps, step=episode)
-mlflow.log_metric("current_epsilon", agent.current_epsilon, step=episode)
-mlflow.log_metric("current_episode", episode, step=episode)
+# Store in wandb
+wandb.log({
+"test_avg_win_rate": test_win_rate,
+"test_avg_detection_rate": test_detection_rate,
+"test_avg_returns": test_average_returns,
+"test_std_returns": test_std_returns,
+"test_avg_episode_steps": test_average_episode_steps,
+"test_std_episode_steps": test_std_episode_steps,
+"test_avg_win_steps": test_average_win_steps,
+"test_std_win_steps": test_std_win_steps,
+"test_avg_detected_steps": test_average_detected_steps,
+"test_std_detected_steps": test_std_detected_steps,
+"test_avg_max_steps_steps": test_average_max_steps_steps,
+"test_std_max_steps_steps": test_std_max_steps_steps,
+"test_current_epsilon": agent.current_epsilon,
+"test_current_episode": episode
+}, step=episode)

 if test_win_rate >= args.early_stop_threshold:
 agent.logger.info(f'Early stopping. Test win rate: {test_win_rate}. Threshold: {args.early_stop_threshold}')
 early_stop = True

-
-# Log the last final episode when it ends
-text = f'''Final model performance after {episode} episodes.
-Wins={wins},
-Detections={detected},
-winrate={eval_win_rate:.3f}%,
-detection_rate={eval_detection_rate:.3f}%,
-average_returns={eval_average_returns:.3f} +- {eval_std_returns:.3f},
-average_episode_steps={eval_average_episode_steps:.3f} +- {eval_std_episode_steps:.3f},
-average_win_steps={eval_average_win_steps:.3f} +- {eval_std_win_steps:.3f},
-average_detected_steps={eval_average_detected_steps:.3f} +- {eval_std_detected_steps:.3f}
-average_max_steps_steps={eval_std_max_steps_steps:.3f} +- {eval_std_max_steps_steps:.3f},
-epsilon={agent.current_epsilon}
-'''
-
-agent._logger.info(text)
-print(text)
-agent._logger.error("Terminating interaction")
-agent.terminate_connection()
+
+# Log the last final episode when it ends
+text = f'''Final model performance after {episode} episodes.
+Wins={wins},
+Detections={detected},
+winrate={eval_win_rate:.3f}%,
+detection_rate={eval_detection_rate:.3f}%,
+average_returns={eval_average_returns:.3f} +- {eval_std_returns:.3f},
+average_episode_steps={eval_average_episode_steps:.3f} +- {eval_std_episode_steps:.3f},
+average_win_steps={eval_average_win_steps:.3f} +- {eval_std_win_steps:.3f},
+average_detected_steps={eval_average_detected_steps:.3f} +- {eval_std_detected_steps:.3f}
+average_max_steps_steps={eval_std_max_steps_steps:.3f} +- {eval_std_max_steps_steps:.3f},
+epsilon={agent.current_epsilon}
+'''
+
+agent._logger.info(text)
+print(text)
+agent._logger.error("Terminating interaction")
+agent.terminate_connection()
+
+# Finish wandb run
+wandb.finish()

 except KeyboardInterrupt:
 # Store the q-table

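For readers less familiar with Weights & Biases, the run lifecycle this commit switches to (replacing the previous mlflow calls) follows the standard wandb pattern: open a run with wandb.init, record hyperparameters once with wandb.config.update, log metrics per episode with wandb.log, and close the run with wandb.finish. The sketch below is a minimal illustration, not the agent's code: the entity, project, and group names are the ones that appear in the diff, while the hyperparameter and metric values are placeholders, and running it requires a wandb login.

# Minimal sketch of the wandb run lifecycle adopted in this commit.
# Placeholder values; the real agent fills these from argparse and played episodes.
import wandb

wandb.init(
    entity='Stratosphere',            # entity/project/group taken from the diff
    project='UTEP-Collaboration',
    group='sebas-qlearning',
    name='Training and Eval of Q-learning Agent. ID example',
)

# Hyperparameters and metadata are recorded once in the run config.
wandb.config.update({
    "alpha": 0.1,                     # placeholder hyperparameters
    "gamma": 0.9,
    "epsilon_start": 0.9,
    "episodes": 5000,
})

# Metrics are logged as a dict; step ties each entry to the training episode.
for episode in range(1, 4):
    wandb.log({
        "eval_avg_win_rate": 0.0,     # placeholder metric value
        "current_episode": episode,
    }, step=episode)

# Closing the run flushes and uploads all logged data.
wandb.finish()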