Skip to content

Commit 39353ee

Browse files
committed
Change how the library of conceptual mapping is called
1 parent 27f7f21 commit 39353ee

File tree

1 file changed

+5
-81
lines changed

1 file changed

+5
-81
lines changed

agents/attackers/conceptual_q_learning/conceptual_q_agent.py

Lines changed: 5 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,14 @@
1111
import logging
1212
import subprocess
1313
import time
14-
import mlflow
1514
import wandb
1615

1716
from os import path, makedirs
1817
# with the path fixed, we can import now
1918
from AIDojoCoordinator.game_components import Action, Observation, GameState, AgentStatus, ActionType
2019
from NetSecGameAgents.agents.base_agent import BaseAgent
2120
from NetSecGameAgents.agents.agent_utils import state_as_ordered_string, convert_ips_to_concepts, convert_concepts_to_actions, generate_valid_actions_concepts
22-
from concept_mapping_logger import ConceptMappingLogger
21+
from NetSecGameAgents.utils.concept_mapping_logger import ConceptMappingLogger
2322

2423
class QAgent(BaseAgent):
2524

@@ -268,9 +267,8 @@ def play_game(self, concept_observation, episode_num, testing=False):
268267
parser.add_argument("--models_dir", help="Folder to store models", default=path.join(path.dirname(path.abspath(__file__)), "models"))
269268
parser.add_argument("--previous_model", help="Load the previous model. If training, it will start from here. If testing, will use to test.", type=str)
270269
parser.add_argument("--testing", help="Test the agent. No train.", default=False, type=bool)
271-
parser.add_argument("--experiment_id", help="Id of the experiment to record into Mlflow and/or Wandb.", default='', type=str)
270+
parser.add_argument("--experiment_id", help="Id of the experiment to record into Wandb.", default='', type=str)
272271
# Logging platform selection arguments
273-
parser.add_argument("--use_mlflow", help="Enable MLflow logging.", action='store_true')
274272
parser.add_argument("--disable_wandb", help="Disable Wandb logging (enabled by default).", action='store_true')
275273
# Wandb-specific arguments
276274
parser.add_argument("--wandb_project", help="Wandb project name.", default="netsec-conceptual-qlearning", type=str)
@@ -298,14 +296,9 @@ def play_game(self, concept_observation, episode_num, testing=False):
298296
agent.enable_enhanced_logging(verbose=True)
299297

300298
# Set logging platform usage based on flags
301-
# MLflow is disabled by default, wandb is enabled by default
299+
# wandb is enabled by default
302300
args.use_wandb = not args.disable_wandb
303301

304-
# Validate logging platform selection
305-
if not args.use_mlflow and not args.use_wandb:
306-
agent._logger.warning("No logging platform selected. Enabling Wandb by default.")
307-
args.use_wandb = True
308-
309302
# Early stop flag. Used to stop the training if the win rate goes over a threshold.
310303
early_stop = False
311304

@@ -321,17 +314,13 @@ def play_game(self, concept_observation, episode_num, testing=False):
321314
print(message)
322315

323316

324-
# Set mlflow for local tracking
317+
# Set W&B for tracking
325318
if not args.testing:
326319
# Experiment name
327320
experiment_name = "Training and Eval of Conceptual Q-learning Agent"
328-
if args.use_mlflow:
329-
mlflow.set_experiment(experiment_name)
330321
elif args.testing:
331322
# Experiment name
332323
experiment_name = "Testing of Conceptual Q-learning Agent against defender agent"
333-
if args.use_mlflow:
334-
mlflow.set_experiment(experiment_name)
335324

336325
# This code runs for both training and testing.
337326
# How it works:
@@ -340,7 +329,7 @@ def play_game(self, concept_observation, episode_num, testing=False):
340329
# - Test for --test_for episodes
341330
# - When each episode finishes you have: steps played, return, win/lose.
342331
# - For each episode, store all values and compute the avg and std of each of them
343-
# - Every --test_for episodes and at the end of the testing, report results in log file, mlflow and console.
332+
# - Every --test_for episodes and at the end of the testing, report results in log file, remote log and console.
344333

345334
# Register the agent
346335
# Observation is in IPs
@@ -363,12 +352,6 @@ def play_game(self, concept_observation, episode_num, testing=False):
363352
mode=args.wandb_mode
364353
)
365354

366-
# Start MLflow run if enabled
367-
if args.use_mlflow:
368-
mlflow_run = mlflow.start_run(run_name=experiment_name + f'. ID {args.experiment_id}')
369-
else:
370-
mlflow_run = None
371-
372355
try:
373356
# To keep statistics of each episode
374357
wins = 0
@@ -388,29 +371,6 @@ def play_game(self, concept_observation, episode_num, testing=False):
388371
agents_git_result = subprocess.run(agents_command, shell=True, capture_output=True, text=True).stdout
389372
agent._logger.info(f'Using commits. NetSecEnv: {netsecenv_git_result}. Agents: {agents_git_result}')
390373

391-
# Log configuration to MLflow if enabled
392-
if args.use_mlflow:
393-
mlflow.set_tag("experiment_name", experiment_name)
394-
mlflow.set_tag("notes", "This is a training and evaluation of the conceptual Q-learning agent.")
395-
if args.previous_model:
396-
mlflow.set_tag("Previous q-learning model loaded", str(args.previous_model))
397-
mlflow.log_param("alpha", args.alpha)
398-
mlflow.log_param("epsilon_start", args.epsilon_start)
399-
mlflow.log_param("epsilon_end", args.epsilon_end)
400-
mlflow.log_param("epsilon_max_episodes", args.epsilon_max_episodes)
401-
mlflow.log_param("gamma", args.gamma)
402-
mlflow.log_param("Episodes", args.episodes)
403-
mlflow.log_param("Test each", str(args.test_each))
404-
mlflow.log_param("Test for", str(args.test_for))
405-
mlflow.log_param("Testing", str(args.testing))
406-
mlflow.set_tag("NetSecEnv commit", netsecenv_git_result)
407-
mlflow.set_tag("Agents commit", agents_git_result)
408-
# Log the env conf
409-
try:
410-
mlflow.log_artifact(args.env_conf)
411-
except Exception as e:
412-
agent._logger.warning(f"Could not log env config file to MLflow: {e}")
413-
414374
# Log configuration to Wandb if enabled
415375
if args.use_wandb:
416376
wandb.config.update({
@@ -520,23 +480,6 @@ def play_game(self, concept_observation, episode_num, testing=False):
520480
'''
521481
agent._logger.info(text)
522482

523-
# Log evaluation metrics to MLflow if enabled
524-
if args.use_mlflow:
525-
mlflow.log_metric("eval_avg_win_rate", eval_win_rate, step=episode)
526-
mlflow.log_metric("eval_avg_detection_rate", eval_detection_rate, step=episode)
527-
mlflow.log_metric("eval_avg_returns", eval_average_returns, step=episode)
528-
mlflow.log_metric("eval_std_returns", eval_std_returns, step=episode)
529-
mlflow.log_metric("eval_avg_episode_steps", eval_average_episode_steps, step=episode)
530-
mlflow.log_metric("eval_std_episode_steps", eval_std_episode_steps, step=episode)
531-
mlflow.log_metric("eval_avg_win_steps", eval_average_win_steps, step=episode)
532-
mlflow.log_metric("eval_std_win_steps", eval_std_win_steps, step=episode)
533-
mlflow.log_metric("eval_avg_detected_steps", eval_average_detected_steps, step=episode)
534-
mlflow.log_metric("eval_std_detected_steps", eval_std_detected_steps, step=episode)
535-
mlflow.log_metric("eval_avg_max_steps_steps", eval_average_max_steps_steps, step=episode)
536-
mlflow.log_metric("eval_std_max_steps_steps", eval_std_max_steps_steps, step=episode)
537-
mlflow.log_metric("current_epsilon", agent.current_epsilon, step=episode)
538-
mlflow.log_metric("current_episode", episode, step=episode)
539-
540483
# Log evaluation metrics to Wandb if enabled
541484
if args.use_wandb:
542485
wandb.log({
@@ -640,23 +583,6 @@ def play_game(self, concept_observation, episode_num, testing=False):
640583
agent._logger.info(text)
641584
print(text)
642585

643-
# Log test metrics to MLflow if enabled
644-
if args.use_mlflow:
645-
mlflow.log_metric("test_avg_win_rate", test_win_rate, step=episode)
646-
mlflow.log_metric("test_avg_detection_rate", test_detection_rate, step=episode)
647-
mlflow.log_metric("test_avg_returns", test_average_returns, step=episode)
648-
mlflow.log_metric("test_std_returns", test_std_returns, step=episode)
649-
mlflow.log_metric("test_avg_episode_steps", test_average_episode_steps, step=episode)
650-
mlflow.log_metric("test_std_episode_steps", test_std_episode_steps, step=episode)
651-
mlflow.log_metric("test_avg_win_steps", test_average_win_steps, step=episode)
652-
mlflow.log_metric("test_std_win_steps", test_std_win_steps, step=episode)
653-
mlflow.log_metric("test_avg_detected_steps", test_average_detected_steps, step=episode)
654-
mlflow.log_metric("test_std_detected_steps", test_std_detected_steps, step=episode)
655-
mlflow.log_metric("test_avg_max_steps_steps", test_average_max_steps_steps, step=episode)
656-
mlflow.log_metric("test_std_max_steps_steps", test_std_max_steps_steps, step=episode)
657-
mlflow.log_metric("current_epsilon", agent.current_epsilon, step=episode)
658-
mlflow.log_metric("current_episode", episode, step=episode)
659-
660586
# Log test metrics to Wandb if enabled
661587
if args.use_wandb:
662588
wandb.log({
@@ -704,8 +630,6 @@ def play_game(self, concept_observation, episode_num, testing=False):
704630

705631
finally:
706632
# Clean up logging sessions
707-
if args.use_mlflow and mlflow_run:
708-
mlflow.end_run()
709633
if args.use_wandb:
710634
wandb.finish()
711635

0 commit comments

Comments
 (0)