From e2f16c254aee66f6f8dfab8c937fe4fb7ed5d17a Mon Sep 17 00:00:00 2001
From: diegoforni <71357860+diegoforni@users.noreply.github.com>
Date: Wed, 5 Nov 2025 12:10:36 -0300
Subject: [PATCH 1/2] Created llm_bert

---
 .../attackers/llm_bert/llm_action_planner.py  | 603 ++++++++++++++++++
 agents/attackers/llm_bert/llm_agent_qa.py     | 450 +++++++++++++
 agents/attackers/llm_bert/prompts.yaml        |  70 ++
 .../attackers/llm_bert/validate_responses.py  | 142 +++++
 4 files changed, 1265 insertions(+)
 create mode 100644 agents/attackers/llm_bert/llm_action_planner.py
 create mode 100644 agents/attackers/llm_bert/llm_agent_qa.py
 create mode 100644 agents/attackers/llm_bert/prompts.yaml
 create mode 100644 agents/attackers/llm_bert/validate_responses.py

diff --git a/agents/attackers/llm_bert/llm_action_planner.py b/agents/attackers/llm_bert/llm_action_planner.py
new file mode 100644
index 00000000..f9da8f37
--- /dev/null
+++ b/agents/attackers/llm_bert/llm_action_planner.py
@@ -0,0 +1,603 @@
+"""
+@file llm_action_planner.py
+
+@brief Implementation of an LLM-based action planner with ModernBERT integration for reactive agent systems.
+
+This script defines classes and methods to interact with language models (LLMs), load
+configuration from YAML files, and parse LLM responses. The core functionality is planning
+actions based on observations and memory. Stage 2 of the planner has been replaced with two
+ModernBERT calls: action-type classification followed by masked language modeling that fills
+in the action parameters.
+"""
+
+import sys
+import os
+from os import path
+import yaml
+import logging
+import json
+from dotenv import dotenv_values
+from openai import OpenAI
+from tenacity import retry, stop_after_attempt
+import jinja2
+import torch
+# Compatibility shim: some Torch builds on Python 3.12 don't support torch.compile (Dynamo).
+# Transformers' ModernBERT uses @torch.compile; make it a no-op if unsupported to prevent
+# import-time failure. torch.compile can be used both as a bare decorator (@torch.compile)
+# and with arguments (@torch.compile(...)), so the fallbacks must return an identity
+# decorator when no function is passed.
+try:
+    _orig_torch_compile = getattr(torch, "compile", None)
+    if _orig_torch_compile is None:
+        def _noop_compile(fn=None, **kwargs):
+            if fn is None:
+                return lambda f: f
+            return fn
+        torch.compile = _noop_compile  # type: ignore[attr-defined]
+    else:
+        def _safe_compile(fn=None, **kwargs):
+            try:
+                return _orig_torch_compile(fn, **kwargs)
+            except Exception:
+                if fn is None:
+                    return lambda f: f
+                return fn
+        torch.compile = _safe_compile  # type: ignore[attr-defined]
+except Exception:
+    pass
+import numpy as np
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForMaskedLM
+import ast
+import gc
+
+import re
+from collections import Counter
+
+# Import validate_responses - works when run as module or script
+try:
+    from . 
import validate_responses
+except ImportError:
+    import validate_responses
+
+# Add parent directories dynamically
+sys.path.append(
+    path.dirname(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
+)
+sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
+
+from AIDojoCoordinator.game_components import ActionType, Observation
+from NetSecGameAgents.agents.llm_utils import create_action_from_response, create_status_from_state
+
+
+class ConfigLoader:
+    """Class to handle loading YAML configurations."""
+
+    @staticmethod
+    def load_config(file_name: str = 'prompts.yaml') -> dict:
+        possible_paths = [
+            path.join(path.dirname(__file__), file_name),
+            path.join(path.dirname(path.dirname(__file__)), file_name),
+            path.join(path.dirname(path.dirname(path.dirname(__file__))), file_name),
+        ]
+        for yaml_file in possible_paths:
+            if path.exists(yaml_file):
+                with open(yaml_file, 'r') as file:
+                    return yaml.safe_load(file)
+        raise FileNotFoundError(f"{file_name} not found in expected directories.")
+
+
+ACTION_MAPPER = {
+    "ScanNetwork": ActionType.ScanNetwork,
+    "ScanServices": ActionType.FindServices,
+    "FindData": ActionType.FindData,
+    "ExfiltrateData": ActionType.ExfiltrateData,
+    "ExploitService": ActionType.ExploitService,
+}
+
+
+class ModernBERTActionClassifier:
+    """Handler for the first ModernBERT model that classifies action types."""
+
+    def __init__(self, model_path: str):
+        # Force CPU usage to avoid CUDA compatibility issues
+        self.device = torch.device("cpu")
+        # Ensure we can load local directories and not treat them as hub repo ids
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
+        self.model = AutoModelForSequenceClassification.from_pretrained(model_path, local_files_only=True)
+        self.model.to(self.device)
+        self.model.eval()
+
+        self.label2id = {
+            "ScanNetwork": 0,
+            "ScanServices": 1,
+            "ExploitService": 2,
+            "FindData": 3,
+            "ExfiltrateData": 4
+        }
+
+        self.id2label = {
+            0: "ScanNetwork",
+            1: "ScanServices",
+            2: "ExploitService",
+            3: "FindData",
+            4: "ExfiltrateData"
+        }
+
+    def predict_action_type(self, prompt: str) -> tuple:
+        """Predict action type from the complete prompt."""
+        inputs = self.tokenizer(
+            prompt,
+            return_tensors="pt",
+            max_length=8048,
+            truncation=True,
+            padding=True
+        )
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
+            predicted_class_id = predictions.argmax().item()
+            confidence = predictions.max().item()
+
+        predicted_action = self.id2label[predicted_class_id]
+        return predicted_action, confidence
+
+
+class ModernBERTMaskedLM:
+    """Handler for the second ModernBERT model for masked language modeling - EXACT training implementation."""
+
+    def __init__(self, model_path: str):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        # Ensure we can load local directories and not treat them as hub repo ids
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
+        self.model = AutoModelForMaskedLM.from_pretrained(model_path, local_files_only=True)
+        self.model.to(self.device)
+        self.model.eval()
+
+        # Initialize logger
+        self.logger = logging.getLogger("ModernBERT-MLM")
+
+        # Ensure special tokens are available - EXACT training setup
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        if self.tokenizer.mask_token is None:
+            self.tokenizer.add_special_tokens({"mask_token": "[MASK]"})
+
+        # Resize embeddings if tokenizer changed - EXACT training setup
+        if len(self.tokenizer) != self.model.config.vocab_size:
+            self.model.resize_token_embeddings(len(self.tokenizer))
+
+    def fix_format_and_remove_spaces(self, text):
+        """EXACT training preprocessing - Convert dict strings to JSON AND remove ALL spaces between tokens"""
+        # First convert dict strings to JSON - EXACT training logic
+        if isinstance(text, str) and text.startswith("{'"):
+            try:
+                dict_obj = ast.literal_eval(text)
+                text = json.dumps(dict_obj)
+            except Exception:
+                pass
+
+        # Convert to string
+        text = str(text)
+
+        # Remove spaces before all special tokens - EXACT training preprocessing
+        text = text.replace(" [MASK]", "[MASK]")
+        text = text.replace(" [PAD]", "[PAD]")
+
+        # Remove spaces between special tokens - EXACT training preprocessing
+        text = text.replace("[MASK] [PAD]", "[MASK][PAD]")
+        text = text.replace("[PAD] [MASK]", "[PAD][MASK]")
+        text = text.replace("[PAD] [PAD]", "[PAD][PAD]")
+        text = text.replace("[MASK] [MASK]", "[MASK][MASK]")
+
+        return text
+
+    def predict_action_full(self, prompt: str, masked_action: str, max_length: int = 8192, iterations: int = 50) -> tuple:
+        """
+        Single-pass inference: predict all [MASK] tokens in one forward pass.
+        """
+        self.model.eval()
+
+        # Convert dict string to JSON if needed - EXACT training logic
+        if isinstance(masked_action, str) and masked_action.startswith("{'"):
+            try:
+                masked_action = json.dumps(ast.literal_eval(masked_action))
+            except Exception:
+                pass
+
+        masked_action = self.fix_format_and_remove_spaces(masked_action)
+
+        original_text = f"Action: {masked_action}\n\nPrompt: {prompt}"
+
+        # Log the exact masked prompt fed into the MLM
+        try:
+            self.logger.info(f"MLM masked prompt (exact input):\n{original_text}")
+        except Exception:
+            pass
+        predictions = []
+        raw_predictions = []
+
+        # Tokenize once and get predictions for all mask positions
+        inputs = self.tokenizer(
+            original_text,
+            return_tensors="pt",
+            max_length=max_length,
+            truncation=True,
+            padding=False  # Training used padding=False for inference
+        ).to(self.device)
+
+        input_ids = inputs["input_ids"][0]
+        mask_positions = (input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
+
+        if len(mask_positions) > 0:
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                logits = outputs.logits[0]  # [seq_len, vocab]
+
+            # Collect predictions for each mask position (left-to-right order)
+            for pos in mask_positions.tolist():
+                token_id = int(torch.argmax(logits[pos]).item())
+                token_str_raw = self.tokenizer.decode([token_id])
+                # Keep raw token for logging, but drop newlines before reinsertion
+                token_str_clean = token_str_raw.replace("\n", "")
+                predictions.append(token_str_clean)
+                raw_predictions.append(token_str_raw)
+
+            # Cleanup tensors
+            del outputs, logits, inputs
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        filled_text = original_text
+        for tok in predictions:
+            filled_text = filled_text.replace("[MASK]", tok, 1)
+
+        # Log the predicted tokens for traceability
+        try:
+            if raw_predictions:
+                self.logger.info(f"MLM predicted tokens (raw): {raw_predictions}")
+                self.logger.info(f"MLM predicted tokens (cleaned): {predictions}")
+            self.logger.info(f"MLM filled text (raw, includes specials):\n{filled_text}")
+        except Exception:
+            pass
+
+        # Extract result (strip prompt section)
+        if "Action: " in filled_text:
+            result = filled_text.split("Action: ", 1)[1].strip()
+            if "\n\nPrompt: " in result:
+                result = result.split("\n\nPrompt: ")[0].strip()
+        else:
+            result = filled_text
+
+        # Keep raw for logging, but return a cleaned JSON string for parsing
+        raw_result = result
+        try:
+            self.logger.info(f"MLM result before cleaning: {raw_result}")
+        except Exception:
+            pass
+
+        cleaned_result = raw_result.replace("[PAD]", "").replace("[UNK]", "")
+        return cleaned_result, predictions
+
+
+class LLMActionPlanner:
+    def __init__(self, model_name: str, goal: str, memory_len: int = 10, api_url=None,
+                 config: dict = None, use_reasoning: bool = False, use_reflection: bool = False,
+                 use_self_consistency: bool = False, classifier_model_path: str = None,
+                 mlm_model_path: str = None):
+        self.model = model_name
+        self.config = config or ConfigLoader.load_config()
+        self.use_reasoning = use_reasoning
+        self.use_reflection = use_reflection
+        self.use_self_consistency = use_self_consistency
+        self.modernbert_prompts = []
+
+        if "gpt" in self.model:
+            env_config = dotenv_values(".env")
+            api_key = env_config.get("OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
+            if not api_key:
+                raise RuntimeError("OPENAI_API_KEY not found. Please set it in a .env file or as an environment variable.")
+            self.client = OpenAI(api_key=api_key)
+        else:
+            self.client = OpenAI(base_url=api_url, api_key="ollama")
+
+        self.memory_len = memory_len
+        self.logger = logging.getLogger("REACT-agent")
+        self.update_instructions(goal.lower())
+        self.prompts = []
+        self.states = []
+        self.responses = []
+
+        # Initialize ModernBERT models
+        self.action_classifier = None
+        self.masked_lm = None
+
+        if classifier_model_path:
+            resolved_classifier = classifier_model_path
+            if not path.isabs(resolved_classifier):
+                candidate = path.join(path.dirname(__file__), classifier_model_path)
+                if path.isdir(candidate):
+                    resolved_classifier = candidate
+            self.logger.info(f"Loading action classifier from: {resolved_classifier}")
+            self.action_classifier = ModernBERTActionClassifier(resolved_classifier)
+
+        if mlm_model_path:
+            resolved_mlm = mlm_model_path
+            if not path.isabs(resolved_mlm):
+                candidate = path.join(path.dirname(__file__), mlm_model_path)
+                if path.isdir(candidate):
+                    resolved_mlm = candidate
+            self.logger.info(f"Loading masked language model from: {resolved_mlm}")
+            self.masked_lm = ModernBERTMaskedLM(resolved_mlm)
+
+        # EXACT training data format templates
+        self.action_templates = {
+            "ScanNetwork": {
+                "action": "ScanNetwork",
+                "parameters": {
+                    "target_network": "[MASK].[MASK].[MASK].[MASK]/[MASK]",
+                    "source_host": "[MASK].[MASK].[MASK].[MASK]"
+                }
+            },
+            "ScanServices": {
+                "action": "ScanServices",
+                "parameters": {
+                    "target_host": "[MASK].[MASK].[MASK].[MASK]",
+                    "source_host": "[MASK].[MASK].[MASK].[MASK]"
+                }
+            },
+            "ExploitService": {
+                "action": "ExploitService",
+                "parameters": {
+                    "target_host": "[MASK].[MASK].[MASK].[MASK]",
+                    "target_service": "[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]",
+                    "source_host": "[MASK].[MASK].[MASK].[MASK]"
+                }
+            },
+            "FindData": {
+                "action": "FindData",
+                "parameters": {
+                    "target_host": "[MASK].[MASK].[MASK].[MASK]",
+                    "source_host": "[MASK].[MASK].[MASK].[MASK]"
+                }
+            },
+            "ExfiltrateData": {
+                "action": "ExfiltrateData",
+                "parameters": {
+                    "target_host": "[MASK].[MASK].[MASK].[MASK]",
+                    "data": {
+                        "owner": "[MASK] [MASK] [MASK]",
+                        "id": "[MASK] [MASK] [MASK] [MASK] [MASK] [MASK]"
+                    },
+                    "source_host": "[MASK].[MASK].[MASK].[MASK]"
+                }
+            }
+        }
+
+    def get_prompts(self) -> list:
+        return self.prompts
+
+    def get_responses(self) -> list:
+        return self.responses
+
+    def get_states(self) -> list:
+        return self.states
+
+    def get_modernbert_prompts(self) -> list:
+        return self.modernbert_prompts
+
+    def update_instructions(self, new_goal: str) -> None:
+        template = jinja2.Environment().from_string(self.config['prompts']['INSTRUCTIONS_TEMPLATE'])
+        self.instructions = template.render(goal=new_goal)
+
+    def create_mem_prompt(self, memory_list: list) -> str:
+        prompt = ""
+        for memory, goodness in memory_list:
+            prompt += f"You have taken action {memory} in the past. This action was {goodness}.\n"
+        return prompt
+
+    @retry(stop=stop_after_attempt(3))
+    def openai_query(self, msg_list: list, max_tokens: int = 60, model: str = None, fmt=None, temperature: float = 0.0):
+        llm_response = self.client.chat.completions.create(
+            model=model or self.model,
+            messages=msg_list,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            response_format=fmt or {"type": "text"},
+        )
+        return llm_response.choices[0].message.content, llm_response.usage
+
+    def parse_response(self, llm_response: str, state: Observation.state):
+        """EXACT training parsing logic - no fallbacks or cleaning"""
+        response_dict = {"action": None, "parameters": None}
+        valid = False
+        action = None
+
+        try:
+            # Direct JSON parsing
+            response = json.loads(llm_response)
+            action_str = response.get("action", None)
+            action_params = response.get("parameters", None)
+
+            if action_str and action_params:
+                valid, action = create_action_from_response(response, state)
+                response_dict["action"] = action_str
+                response_dict["parameters"] = action_params
+                # Ensure we only report valid if we produced a concrete Action object
+                if valid and action is None:
+                    self.logger.warning("Validated response did not produce an actionable object; marking invalid.")
+                    valid = False
+            else:
+                self.logger.warning("Missing action or parameters in response.")
+
+        except json.JSONDecodeError:
+            self.logger.error("Failed to parse response as JSON.")
+            self.logger.error(f"Raw response: {llm_response}")
+            # No fallback - just mark as invalid
+            response_dict["action"] = "InvalidJSON"
+            response_dict["parameters"] = {"raw_response": llm_response}
+        except KeyError:
+            self.logger.error("Missing keys in response.")
+
+        return valid, response_dict, action
+
+    def remove_reasoning(self, text):
+        """Strip a <think>...</think> reasoning block, keeping only the text after the closing tag."""
+        match = re.search(r'</think>(.*)', text, re.DOTALL)
+        if match:
+            return match.group(1).strip()
+        return text
+
+    def check_repetition(self, memory_list):
+        repetitions = 0
+        past_memories = []
+        for memory, goodness in memory_list:
+            if memory in past_memories:
+                repetitions += 1
+            past_memories.append(memory)
+        return repetitions
+
+    def get_self_consistent_response(self, messages, temp=0.4, max_tokens=1024, n=3):
+        candidates = []
+        total_tokens_used = 0
+        for _ in range(n):
+            response, usage = self.openai_query(messages, temperature=temp, max_tokens=max_tokens)
+            total_tokens_used += usage.total_tokens
+            candidates.append(response.strip())
+
+        counts = Counter(candidates)
+        most_common = counts.most_common(1)
+        if most_common:
+            self.logger.info(f"Self-consistency candidates: {counts}")
+            return most_common[0][0], total_tokens_used
+        return candidates[0], total_tokens_used
+
+    def build_complete_training_prompt(self, stage1_response: str, status_prompt: str, memory_prompt: str) -> str:
+        """
+        Build the complete prompt that EXACTLY matches the training data format.
+        """
+        # Get the prompt components from config
+        cot_prompt = self.config['prompts']['COT_PROMPT']
+        q4 = self.config['questions'][3]['text']
+
+        complete_prompt = f"{self.instructions}\n\n"
+        complete_prompt += f"Current status:\n{status_prompt}\n\n"
+        complete_prompt += f"{memory_prompt}\n\n" if memory_prompt.strip() else ""
+        complete_prompt += f"{cot_prompt}\n\n"
+        complete_prompt += f"{stage1_response}\n\n"
+        complete_prompt += f"{q4}"
+
+        return complete_prompt
+
+    def get_action_with_modernbert(self, stage1_response: str, observation: Observation, status_prompt: str, memory_prompt: str) -> tuple:
+        """
+        Replace Stage 2 with ModernBERT calls using EXACT training implementation.
+        """
+        if not self.action_classifier or not self.masked_lm:
+            self.logger.error("ModernBERT models not initialized.")
+            return False, {"action": "ModelError", "parameters": {"error": "ModernBERT models not loaded"}}, None, 0
+
+        try:
+            complete_prompt = self.build_complete_training_prompt(stage1_response, status_prompt, memory_prompt)
+            self.modernbert_prompts.append(complete_prompt)
+
+            self.logger.info(f"Complete prompt for ModernBERT:\n{complete_prompt[:500]}...")
+
+            self.logger.info("Classifying action type with complete prompt...")
+            predicted_action, confidence = self.action_classifier.predict_action_type(complete_prompt)
+            self.logger.info(f"Predicted action: {predicted_action} (confidence: {confidence:.4f})")
+
+            if predicted_action not in self.action_templates:
+                self.logger.error(f"Unknown action type: {predicted_action}")
+                return False, {"action": "UnknownAction", "parameters": {"predicted": predicted_action}}, None, 0
+
+            masked_template = self.action_templates[predicted_action]
+            masked_action_str = json.dumps(masked_template)
+
+            self.logger.info(f"Using masked template: {masked_action_str}")
+
+            filled_action, predictions = self.masked_lm.predict_action_full(
+                prompt=complete_prompt,
+                masked_action=masked_action_str,
+                max_length=8192,  # Match training max_length
+                iterations=50
+            )
+
+            self.logger.info(f"MLM predictions: {predictions}")
+            self.logger.info(f"Filled action: {filled_action}")
+
+            is_valid, response_dict, action = self.parse_response(filled_action, observation.state)
+
+            self.logger.info(f"Final parsed action - Valid: {is_valid}, Response: {response_dict}")
+
+            return is_valid, response_dict, action, 0  # No LLM tokens used in this stage
+
+        except Exception as e:
+            self.logger.error(f"Error in ModernBERT action generation: {str(e)}")
+            import traceback
+            traceback.print_exc()
+            return False, {"action": "BERTError", "parameters": {"error": str(e)}}, None, 0
+
+    def get_action_from_obs_react(self, observation: Observation, memory_buf: list) -> tuple:
+        self.states.append(observation.state.as_json())
+        status_prompt = create_status_from_state(observation.state)
+        q1 = self.config['questions'][0]['text']
+        memory_prompt = self.create_mem_prompt(memory_buf)
+        total_tokens_used = 0
+
+        repetitions = self.check_repetition(memory_buf)
+        messages = [
+            {"role": "user", "content": self.instructions},
+            {"role": "user", "content": status_prompt},
+            {"role": "user", "content": memory_prompt},
+            {"role": "user", "content": q1},
+        ]
+        self.logger.info(f"Text sent to the LLM: {messages}")
+
+        # Stage 1: Get reasoning from LLM
+        if self.use_self_consistency:
+            response, tokens = self.get_self_consistent_response(messages, temp=repetitions/9, max_tokens=1024)
+            total_tokens_used += tokens
+        else:
+            response, usage = self.openai_query(messages, max_tokens=1024)
+            total_tokens_used += usage.total_tokens
+
+        if self.use_reflection:
+            reflection_prompt = [
+                {
+                    "role": "user",
+                    "content": f"""
+                    Instructions: {self.instructions}
+                    Task: {q1}
+
+                    Status: {status_prompt}
+                    Memory: {memory_prompt}
+
+                    Reasoning:
+                    {response}
+
+                    Is this reasoning valid given the Instructions, Status, and Memory?
+                    - If YES, repeat it exactly.
+                    - If NO, output the corrected reasoning only (no commentary).
+                    """
+                }
+            ]
+            response, usage = self.openai_query(reflection_prompt, max_tokens=1024)
+            total_tokens_used += usage.total_tokens
+
+        if self.use_reasoning:
+            response = self.remove_reasoning(response)
+        self.logger.info(f"(Stage 1) Response from LLM: {response}")
+
+        # Stage 2: Use ModernBERT with EXACT training implementation
+        self.logger.info("Using ModernBERT for Stage 2 action generation with EXACT training implementation...")
+        is_valid, response_dict, action, bert_tokens = self.get_action_with_modernbert(
+            response, observation, status_prompt, memory_prompt
+        )
+        total_tokens_used += bert_tokens  # Should be 0 for BERT models
+
+        # Format response for consistency with original code
+        final_response = json.dumps(response_dict, indent=2)
+        self.responses.append(final_response)
+        self.logger.info(f"(Stage 2) ModernBERT Response: {final_response}")
+        print(f"(Stage 2) ModernBERT Response: {final_response}")
+
+        return is_valid, response_dict, action, total_tokens_used
diff --git a/agents/attackers/llm_bert/llm_agent_qa.py b/agents/attackers/llm_bert/llm_agent_qa.py
new file mode 100644
index 00000000..ff56b239
--- /dev/null
+++ b/agents/attackers/llm_bert/llm_agent_qa.py
@@ -0,0 +1,450 @@
+import logging
+import argparse
+import numpy as np
+import pandas as pd
+import mlflow
+import sys
+import json
+from os import path
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+sys.path.append(
+    path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
+)
+
+# Import from the same directory - works when run as module or script
+try:
+    from .llm_action_planner import LLMActionPlanner
+except ImportError:
+    from llm_action_planner import LLMActionPlanner
+
+from AIDojoCoordinator.game_components import AgentStatus, Action, ActionType
+from NetSecGameAgents.agents.base_agent import BaseAgent
+
+#mlflow.set_tracking_uri("http://147.32.83.60")
+#mlflow.set_experiment("LLM_QA_netsecgame_dec2024")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--llm",
+        type=str,
+        # choices=[
+        #     "gpt-4",
+        #     "gpt-4-turbo-preview",
+        #     "gpt-3.5-turbo",
+        #     "gpt-3.5-turbo-16k",
+        #     "HuggingFaceH4/zephyr-7b-beta",
+        # ],
+        default="gpt-3.5-turbo",
+        help="LLM used with OpenAI API",
+    )
+    parser.add_argument(
+        "--test_episodes",
+        help="Number of test episodes to run",
+        default=30,
+        action="store",
+        required=False,
+        type=int,
+    )
+    parser.add_argument(
+        "--memory_buffer",
+        help="Number of actions to remember and pass to the LLM",
+        default=5,
+        action="store",
+        required=False,
+        type=int,
+    )
+    parser.add_argument(
+        "--host",
+        help="Host where the game server is",
+        default="127.0.0.1",
+        action="store",
+        required=False,
+    )
+    parser.add_argument(
+        "--port",
+        help="Port where the game server is",
+        default=9001,
+        type=int,
+        action="store",
+        required=False,
+    )
+
+    parser.add_argument(
+        "--api_url",
+        type=str,
+        default="http://127.0.0.1:11434/v1/"
+    )
+
+    parser.add_argument(
+        "--use_reasoning",
+        action="store_true",
+        help="Required for models that output reasoning using <think>...</think> tags."
+    )
+
+    parser.add_argument(
+        "--use_reflection",
+        action="store_true",
+        help="To use reflection prompting technique in the LLM calls."
+    )
+
+    parser.add_argument(
+        "--use_self_consistency",
+        action="store_true",
+        help="To use self-consistency prompting technique in the LLM calls."
+    )
+
+    parser.add_argument(
+        "--mlflow_tracking_uri",
+        type=str,
+        default="http://147.32.83.60",
+        help="MLflow tracking server URI (default: %(default)s)",
+    )
+
+    parser.add_argument(
+        "--mlflow_experiment",
+        type=str,
+        default="LLM_QA_netsecgame_dec2024",
+        help="MLflow experiment name (default: %(default)s)",
+    )
+
+    parser.add_argument(
+        "--mlflow_description",
+        type=str,
+        default=None,
+        help="Optional description for MLflow run (default is generated)",
+    )
+
+    parser.add_argument(
+        "--disable_mlflow",
+        action="store_true",
+        help="Disable mlflow logging",
+    )
+
+    # CHANGED: Help text updated to be more specific.
+    parser.add_argument(
+        "--max_tokens_limit",
+        type=int,
+        default=0,
+        help="If the cumulative total tokens used across all episodes exceeds this limit, the entire run is stopped. 0 means no limit. (default: %(default)s)",
+    )
+
+    parser.add_argument(
+        "--classifier_model_path",
+        type=str,
+        default="./bertClassifier",
+        help="Path to the ModernBERT action classifier model (default: %(default)s)"
+    )
+
+    parser.add_argument(
+        "--mlm_model_path",
+        type=str,
+        default="./bertMLM",
+        help="Path to the ModernBERT masked language model (default: %(default)s)"
+    )
+
+    args = parser.parse_args()
+
+    logging.basicConfig(
+        filename="llm_react.log",
+        filemode="w",
+        format="%(asctime)s %(name)s %(levelname)s %(message)s",
+        datefmt="%H:%M:%S",
+        level=logging.INFO,
+    )
+
+    logger = logging.getLogger("llm_react")
+    logger.info("Start")
+    agent = BaseAgent(args.host, args.port, "Attacker")
+
+    if not args.disable_mlflow:
+        mlflow.set_tracking_uri(args.mlflow_tracking_uri)
+        mlflow.set_experiment(args.mlflow_experiment)
+
+        # Use custom description if given, otherwise build a default
+        experiment_description = args.mlflow_description or (
+            f"{args.mlflow_experiment} | Model: {args.llm}"
+        )
+
+        mlflow.start_run(description=experiment_description)
+
+        params = {
+            "model": args.llm,
+            "memory_len": args.memory_buffer,
+            "episodes": args.test_episodes,
+            "host": args.host,
+            "port": args.port,
+            "api_url": args.api_url,
+            "max_tokens_limit": args.max_tokens_limit,
+        }
+        mlflow.log_params(params)
+        mlflow.set_tag("agent_role", "Attacker")
+
+    # Run multiple episodes to compute statistics
+    wins = 0
+    detected = 0
+    reach_max_steps = 0
+    returns = []
+    num_steps = []
+    num_win_steps = []
+    num_detected_steps = []
+    num_actions_repeated = []
+    reward_memory = ""
+
+    # ADDED: Counter for total tokens used across the entire run.
+    total_tokens_used = 0
+
+    # Store prompts, responses, and evaluations for each episode
+    prompt_table = []
+
+    # We are still not using this, but we keep track
+    is_detected = False
+
+    # Initialize the game
+    print("Registering")
+    agent.register()
+    print("Done")
+
+    for episode in range(1, args.test_episodes + 1):
+        actions_took_in_episode = []
+        evaluations = []  # used for prompt table storage.
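+        # Score scale for `evaluations`: 0 = invalid action, 3 = valid but state unchanged,
+        # 8 = valid and state-changing, 10 = the final step of a won episode.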
+        logger.info(f"Running episode {episode}")
+        print(f"Running episode {episode}")
+
+        # Reset the game at every episode and store the goal that changes
+        observation = agent.request_game_reset()
+        num_iterations = observation.info["max_steps"]
+        current_state = observation.state
+
+        taken_action = None
+        memories = []
+        total_reward = 0
+        num_actions = 0
+        repeated_actions = 0
+
+        if args.llm is not None:
+            llm_query = LLMActionPlanner(
+                model_name=args.llm,
+                goal=observation.info["goal_description"],
+                memory_len=args.memory_buffer,
+                api_url=args.api_url,
+                use_reasoning=args.use_reasoning,
+                use_reflection=args.use_reflection,
+                use_self_consistency=args.use_self_consistency,
+                classifier_model_path=args.classifier_model_path,
+                mlm_model_path=args.mlm_model_path
+            )
+        print(observation)
+        for i in range(num_iterations):
+            good_action = False
+
+            is_valid, response_dict, action, tokens_used = llm_query.get_action_from_obs_react(observation, memories)
+
+            # CHANGED: Increment the total token counter for the run
+            total_tokens_used += tokens_used
+            print(f"Total tokens used so far: {total_tokens_used}")
+
+            # --- MODIFIED BEHAVIOR ---
+            # If the total token limit is reached, save completed episode data and terminate the entire script.
+            if args.max_tokens_limit > 0 and total_tokens_used > args.max_tokens_limit:
+                termination_message = f"CRITICAL: Total token limit of {args.max_tokens_limit} reached (cumulative total: {total_tokens_used}). Terminating run."
+                print(termination_message)
+                logger.critical(termination_message)
+
+                print("Saving data from completed episodes before exiting...")
+                with open("episode_data.json", "w") as json_file:
+                    json.dump(prompt_table, json_file, indent=4)
+                print("Data saved to episode_data.json.")
+
+                if not args.disable_mlflow:
+                    mlflow.set_tag("termination_reason", "max_tokens_limit_exceeded")
+                    mlflow.log_param("termination_episode", episode)
+                    mlflow.log_metric("total_tokens_at_termination", total_tokens_used)
+                    mlflow.end_run("FAILED")  # Mark the run as failed
+                    print("MLflow run terminated.")
+
+                print("Exiting script.")
+                sys.exit(1)  # Exit with a non-zero status code to indicate abnormal termination
+
+            if is_valid and action is not None:
+                observation = agent.make_step(action)
+                logger.info(f"Observation received: {observation}")
+                taken_action = action
+                total_reward += observation.reward
+
+                if observation.state != current_state:
+                    good_action = True
+                    current_state = observation.state
+                    evaluations.append(8)
+                else:
+                    evaluations.append(3)
+            else:
+                print("Invalid action: ")
+                evaluations.append(0)
+
+            try:
+                if not is_valid:
+                    memories.append(
+                        (
+                            (response_dict["action"],
+                             response_dict["parameters"]),
+                            "not valid based on your status."
+                        )
+                    )
+                    print("not valid based on your status.")
+                else:
+                    if good_action:
+                        memories.append(
+                            (
+                                (response_dict["action"],
+                                 response_dict["parameters"]),
+                                "helpful."
+                            )
+                        )
+                        print("Helpful")
+                    else:
+                        memories.append(
+                            (
+                                (response_dict["action"],
+                                 response_dict["parameters"]),
+                                "not helpful."
+                            )
+                        )
+                        print("Not Helpful")
+                    if action in actions_took_in_episode:
+                        repeated_actions += 1
+                    actions_took_in_episode.append(action)
+            except Exception:
+                memories.append(
+                    (
+                        (response_dict["action"],
+                         response_dict["parameters"]),
+                        "badly formatted."
+                    )
+                )
+                print("badly formatted")
+
+            if len(memories) > args.memory_buffer:
+                memories.pop(0)
+
+            logger.info(f"Iteration: {i} Valid: {is_valid} Good: {good_action}")
+
+            if observation.end or i == (num_iterations - 1):
+                if i < (num_iterations - 1):
+                    reason = observation.info
+                else:
+                    reason = {"end_reason": AgentStatus.TimeoutReached}
+
+                steps = i
+                epi_last_reward = observation.reward
+                num_actions_repeated += [repeated_actions]
+                if AgentStatus.Success == reason["end_reason"]:
+                    wins += 1
+                    num_win_steps += [steps]
+                    evaluations[-1] = 10
+                elif AgentStatus.Fail == reason["end_reason"]:
+                    detected += 1
+                    num_detected_steps += [steps]
+                elif AgentStatus.TimeoutReached == reason["end_reason"]:
+                    reach_max_steps += 1
+                    total_reward = -100
+                    steps = num_iterations
+                else:
+                    reach_max_steps += 1
+
+                returns += [total_reward]
+                num_steps += [steps]
+
+                if not args.disable_mlflow:
+                    mlflow.log_metric("wins", wins, step=episode)
+                    mlflow.log_metric("num_steps", steps, step=episode)
+                    mlflow.log_metric("return", total_reward, step=episode)
+                    mlflow.log_metric("reached_max_steps", reach_max_steps, step=episode)
+                    mlflow.log_metric("detected", detected, step=episode)
+                    mlflow.log_metric("total_tokens_used", total_tokens_used, step=episode)  # Log total tokens each episode
+                    mlflow.log_metric("win_rate", (wins / episode) * 100, step=episode)
+                    mlflow.log_metric("avg_returns", np.mean(returns), step=episode)
+                    mlflow.log_metric("avg_steps", np.mean(num_steps), step=episode)
+
+                logger.info(
+                    f"\tEpisode {episode} of game ended after {steps} steps. Reason: {reason}. Last reward: {epi_last_reward}"
+                )
+                print(
+                    f"\tEpisode {episode} of game ended after {steps} steps. Reason: {reason}. Last reward: {epi_last_reward}"
+                )
+                break
+
+        episode_prompt_table = {
+            "episode": episode,
+            "state": llm_query.get_states(),
+            "prompt": llm_query.get_prompts(),
+            "modernbert_prompts": llm_query.get_modernbert_prompts(),
+            "response": llm_query.get_responses(),
+            "evaluation": evaluations,
+            "end_reason": str(reason["end_reason"])
+        }
+        prompt_table.append(episode_prompt_table)
+
+    # This part of the code will only be reached if the run completes all episodes
+    # without hitting the token limit.
+    print("All episodes completed successfully.")
+    with open("episode_data.json", "w") as json_file:
+        json.dump(prompt_table, json_file, indent=4)
+
+    # After all episodes are done, compute statistics
+    completed_episodes = len(prompt_table)
+    test_win_rate = (wins / completed_episodes) * 100 if completed_episodes > 0 else 0
+    test_detection_rate = (detected / completed_episodes) * 100 if completed_episodes > 0 else 0
+    test_max_steps_rate = (reach_max_steps / completed_episodes) * 100 if completed_episodes > 0 else 0
+    test_average_returns = np.mean(returns) if returns else 0
+    test_std_returns = np.std(returns) if returns else 0
+    test_average_episode_steps = np.mean(num_steps) if num_steps else 0
+    test_std_episode_steps = np.std(num_steps) if num_steps else 0
+    test_average_win_steps = np.mean(num_win_steps) if num_win_steps else 0
+    test_std_win_steps = np.std(num_win_steps) if num_win_steps else 0
+    test_average_detected_steps = np.mean(num_detected_steps) if num_detected_steps else 0
+    test_std_detected_steps = np.std(num_detected_steps) if num_detected_steps else 0
+    test_average_repeated_steps = np.mean(num_actions_repeated) if num_actions_repeated else 0
+    test_std_repeated_steps = np.std(num_actions_repeated) if num_actions_repeated else 0
+
+    tensorboard_dict = {
+        "test_avg_win_rate": test_win_rate,
+        "test_avg_detection_rate": test_detection_rate,
+        "test_avg_max_steps_rate": test_max_steps_rate,
+        "test_avg_returns": test_average_returns,
+        "test_std_returns": test_std_returns,
+        "test_avg_episode_steps": test_average_episode_steps,
+        "test_std_episode_steps": test_std_episode_steps,
+        "test_avg_win_steps": test_average_win_steps,
+        "test_std_win_steps": test_std_win_steps,
+        "test_avg_detected_steps": test_average_detected_steps,
+        "test_std_detected_steps": test_std_detected_steps,
+        "test_avg_repeated_steps": test_average_repeated_steps,
+        "test_std_repeated_steps": test_std_repeated_steps,
+        "final_total_tokens_used": total_tokens_used,
+    }
+
+    if not args.disable_mlflow:
+        mlflow.log_metrics(tensorboard_dict)
+        mlflow.end_run("FINISHED")
+
+    text = f"""Final test after {completed_episodes} episodes
+        Wins={wins},
+        Detections={detected},
+        Max Steps Reached={reach_max_steps},
+        winrate={test_win_rate:.3f}%,
+        detection_rate={test_detection_rate:.3f}%,
+        max_steps_rate={test_max_steps_rate:.3f}%,
+        average_returns={test_average_returns:.3f} +- {test_std_returns:.3f},
+        average_episode_steps={test_average_episode_steps:.3f} +- {test_std_episode_steps:.3f},
+        average_win_steps={test_average_win_steps:.3f} +- {test_std_win_steps:.3f},
+        average_detected_steps={test_average_detected_steps:.3f} +- {test_std_detected_steps:.3f},
+        average_repeated_steps={test_average_repeated_steps:.3f} +- {test_std_repeated_steps:.3f},
+        Total Tokens Used: {total_tokens_used}
+        """
+
+    print(text)
+    logger.info(text)
diff --git a/agents/attackers/llm_bert/prompts.yaml b/agents/attackers/llm_bert/prompts.yaml
new file mode 100644
index 00000000..c6b6327b
--- /dev/null
+++ b/agents/attackers/llm_bert/prompts.yaml
@@ -0,0 +1,70 @@
+prompts:
+  INSTRUCTIONS_TEMPLATE: |
+    You are a pentester and your goal is to {{goal}}.
+    The actions you can take are ScanServices, FindData, ScanNetwork, ExploitService, and ExfiltrateData.
+
+    The rules are:
+    1. You can only scan networks in your list of known networks.
+    2. You can only scan services in hosts you know.
+    3. You can only exploit services when you know the service.
+    4. You can find data in hosts you control.
+    5. You can exfiltrate known data to and from controlled hosts.
+
+    Service format: when specifying `target_service` for ExploitService, use the exact string shown under "Known services" in the status (including port/proto and name), e.g., "22/tcp, ssh" or "80/tcp, http"; do not use aliases like just "ssh" or only the port.
+
+    Do not repeat actions that you took in the past.
+    Do not scan or exploit services in hosts you control.
+    Do not scan the same network twice.
+    Do not exfiltrate the same data twice.
+    Exploit services that you know and find data in all controlled hosts.
+    Once you have known data exfiltrate them to the target one by one.
+
+  COT_PROMPT: |
+    Example status:
+    Known networks are 1.1.1.0/24
+    Known hosts are 2.2.2.2, 1.1.1.2, 2.2.2.3
+    Controlled hosts are 2.2.2.2, 1.1.1.2
+    Known data for source host 1.1.1.2: are ('User1', 'SomeData')
+    Known services for host 1.1.1.1 are "22/tcp, ssh"
+
+    Here are some examples of actions:
+    Action: {"action":"ScanNetwork", "parameters": {"target_network": "1.1.1.0/24", "source_host": "2.2.2.2"}}
+    Action: {"action":"ScanServices", "parameters":{"target_host":"2.2.2.3", "source_host": "2.2.2.2"}}
+    Action: {"action":"ExploitService", "parameters":{"target_host":"1.1.1.1", "target_service":"22/tcp, ssh", "source_host": "1.1.1.2"}}
+    Action: {"action":"FindData", "parameters":{"target_host":"1.1.1.1", "source_host": "1.1.1.2"}}
+    Action: {"action":"ExfiltrateData", "parameters": {"target_host": "2.2.2.2", "data": {"owner":"User1", "id":"WebData"}, "source_host": "1.1.1.2"}}
+    End of example.
+
+    Example status 2:
+    Known networks are 10.0.0.0/24, 192.168.1.0/24
+    Known hosts are 10.0.0.1, 10.0.0.2, 192.168.1.1, 192.168.1.2
+    Controlled hosts are 10.0.0.2, 192.168.1.2
+    Known data for source host 192.168.1.2: ('Admin', 'ConfidentialData')
+    Known services for host 192.168.1.1 are "80/tcp, http", "21/tcp, ftp"
+
+    Here are some examples of actions:
+    Action: {"action":"ScanNetwork", "parameters": {"target_network": "192.168.1.0/24", "source_host": "10.0.0.2"}}
+    Action: {"action":"ScanServices", "parameters":{"target_host":"192.168.1.1", "source_host": "192.168.1.2"}}
+    Action: {"action":"ExploitService", "parameters":{"target_host":"192.168.1.1", "target_service":"80/tcp, http", "source_host": "10.0.0.2"}}
+    Action: {"action":"FindData", "parameters":{"target_host":"192.168.1.1", "source_host": "192.168.1.2"}}
+    Action: {"action":"ExfiltrateData", "parameters": {"target_host": "10.0.0.2", "data": {"owner":"Admin", "id":"ConfidentialData"}, "source_host": "192.168.1.2"}}
+    End of example 2
+
+  COT_PROMPT2: |
+    Here are some examples of actions:
+    Action: {"action":"ScanNetwork", "parameters": {"target_network": "1.1.1.0/24", "source_host": "2.2.2.2"}}
+    Action: {"action":"ScanServices", "parameters":{"target_host":"2.2.2.3", "source_host": "2.2.2.2"}}
+    Action: {"action":"ExploitService", "parameters":{"target_host":"1.1.1.1", "target_service":"22/tcp, ssh", "source_host": "1.1.1.2"}}
+    Action: {"action":"FindData", "parameters":{"target_host":"1.1.1.1", "source_host": "1.1.1.2"}}
+    Action: {"action":"ExfiltrateData", "parameters": {"target_host": "2.2.2.2", "data": {"owner":"User1", "id":"WebData"}, "source_host": "1.1.1.2"}}
+    End of examples.
+
+questions:
+  - id: Q1
+    text: "List the objects in the current status and the actions they can be used with. Be specific."
+  - id: Q2
+    text: "List the top 3 sub-tasks you should follow with specific parameters. Indicate their priority out of 5."
+  - id: Q3
+    text: "Provide the action with the highest priority and its parameters in the correct JSON format. Do not repeat past actions.\nAction: "
+  - id: Q4
+    text: "Provide the best next action in the correct JSON format. An action should include the `source_host`. Action: "
diff --git a/agents/attackers/llm_bert/validate_responses.py b/agents/attackers/llm_bert/validate_responses.py
new file mode 100644
index 00000000..042b0eae
--- /dev/null
+++ b/agents/attackers/llm_bert/validate_responses.py
@@ -0,0 +1,142 @@
+import json
+
+# Define schema for each action
+ACTION_SCHEMA = {
+    "ExfiltrateData": {
+        "required": ["target_host", "source_host", "data"],
+        "schema": {
+            "target_host": str,
+            "source_host": str,
+            "data": {
+                "required": ["owner", "id"],
+                "schema": {
+                    "owner": str,
+                    "id": str
+                }
+            }
+        }
+    },
+    "FindData": {
+        "required": ["target_host", "source_host"],
+        "schema": {
+            "target_host": str,
+            "source_host": str
+        }
+    },
+    "ExploitService": {
+        "required": ["target_host", "target_service", "source_host"],
+        "schema": {
+            "target_host": str,
+            "target_service": str,
+            "source_host": str
+        }
+    },
+    "ScanServices": {
+        "required": ["target_host", "source_host"],
+        "schema": {
+            "target_host": str,
+            "source_host": str
+        }
+    },
+    "FindServices": {
+        "required": ["target_host", "source_host"],
+        "schema": {
+            "target_host": str,
+            "source_host": str
+        }
+    },
+    "ScanNetwork": {
+        "required": ["target_network", "source_host"],
+        "schema": {
+            "target_network": str,
+            "source_host": str
+        }
+    }
+}
+
+
+def validate_schema(data: dict, schema: dict) -> tuple[bool, str | bool]:
+    """Recursively validate that data matches the schema."""
+    for key, expected_type in schema.items():
+        if key not in data:
+            return False, f"Error: Missing required key '{key}'"
+
+        value = data[key]
+
+        if isinstance(expected_type, dict):
+            # Handle nested schema with "schema" and "required"
+            if "schema" in expected_type:
+                if not isinstance(value, dict):
+                    return False, f"Error: Field '{key}' must be a dictionary"
+
+                # Validate required fields
+                required = expected_type.get("required", [])
+                for req_key in required:
+                    if req_key not in value:
+                        return False, f"Error: Missing required key '{req_key}' in '{key}'"
+
+                # Recursively validate inner schema
+                inner_result = validate_schema(value, expected_type["schema"])
+                if not inner_result[0]:
+                    return inner_result
+
+            else:
+                # Regular nested dictionary check
+                if not isinstance(value, dict):
+                    return False, f"Error: Field '{key}' must be a dictionary"
+                inner_result = validate_schema(value, expected_type)
+                if not inner_result[0]:
+                    return inner_result
+
+        elif isinstance(expected_type, type):
+            if not isinstance(value, expected_type):
+                return False, f"Error: Field '{key}' must be of type {expected_type.__name__}"
+        else:
+            return False, f"Error: Invalid schema definition for key '{key}'"
+
+    return True, True
+
+
+def validate_agent_response(raw_response: str) -> tuple[dict | None, str | None]:
+    """Validate agent JSON response. Assumes raw_response is a single dict.
+
+    Args:
+        raw_response (str): Raw JSON string from agent
+
+    Returns:
+        Tuple (validated_dict, None) on success or (None, error_message) on failure.
+    """
+    try:
+        response = json.loads(raw_response)
+
+        if not isinstance(response, dict):
+            return None, "Error: Response must be a JSON object."
+
+        action = response.get("action")
+        if not action:
+            return None, "Error: Missing 'action' field."
+
+        schema_def = ACTION_SCHEMA.get(action)
+        if not schema_def:
+            return None, f"Error: Unknown action '{action}'."
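+        # The action name is recognized; validate its parameters against the schema next.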
+
+        parameters = response.get("parameters")
+        if not isinstance(parameters, dict):
+            return None, "Error: 'parameters' must be a dictionary."
+
+        # Check required fields
+        required_fields = schema_def.get("required", [])
+        for field in required_fields:
+            if field not in parameters:
+                return None, f"Error: Missing required parameter '{field}' for action '{action}'."
+
+        # Validate schema
+        schema = schema_def.get("schema", {})
+        is_valid, validation_error = validate_schema(parameters, schema)
+        if not is_valid:
+            return None, "Parameter validation failed: " + validation_error
+
+        return response, None  # ✅ always return a 2-tuple
+
+    except json.JSONDecodeError as e:
+        return None, f"Error: Invalid JSON format. {e}"

From 0bf17251f4a2ed46509e6b98b6e8b1ecdee345fb Mon Sep 17 00:00:00 2001
From: diegoforni <71357860+diegoforni@users.noreply.github.com>
Date: Wed, 5 Nov 2025 12:14:54 -0300
Subject: [PATCH 2/2] =?UTF-8?q?Restore=20random=20agent=20files=20to=20mat?=
 =?UTF-8?q?ch=20main=20=E2=80=94=20remove=20accidental=20edits=20from=20ll?=
 =?UTF-8?q?m=5Fbert?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 agents/attackers/random/random_agent.py          | 6 +++---
 agents/attackers/random/whitebox_random_agent.py | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/agents/attackers/random/random_agent.py b/agents/attackers/random/random_agent.py
index d44087d9..cd9cdcf4 100644
--- a/agents/attackers/random/random_agent.py
+++ b/agents/attackers/random/random_agent.py
@@ -42,7 +42,7 @@ def play_game(self, observation, num_episodes=1):
             observation = self.request_game_reset()
         self._logger.info(f"Final results for {self.__class__.__name__} after {num_episodes} episodes: {np.mean(returns)}±{np.std(returns)}")
         # This will be the last observation played before the reset
-        return (last_observation, np.mean(returns), num_steps)
+        return (last_observation, num_steps)
 
     def select_action(self, observation:Observation)->Action:
         valid_actions = generate_valid_actions(observation.state)
@@ -115,10 +115,10 @@ def select_action(self, observation:Observation)->Action:
             print(f'Starting the testing for episode {episode}')
 
             # Play the game for one episode
-            observation, avg_return, num_steps = agent.play_game(observation, 1)
+            observation, num_steps = agent.play_game(observation, 1)
 
             state = observation.state
-            reward = avg_return
+            reward = observation.reward
             end = observation.end
             info = observation.info
diff --git a/agents/attackers/random/whitebox_random_agent.py b/agents/attackers/random/whitebox_random_agent.py
index 4a5bc3db..b1ebc443 100644
--- a/agents/attackers/random/whitebox_random_agent.py
+++ b/agents/attackers/random/whitebox_random_agent.py
@@ -45,7 +45,7 @@ def play_game(self, observation, num_episodes=1):
             observation = self.request_game_reset()
         self._logger.info(f"Final results for {self.__class__.__name__} after {num_episodes} episodes: {np.mean(returns)}±{np.std(returns)}")
         # This will be the last observation played before the reset
-        return (last_observation, np.mean(returns), num_steps)
+        return (last_observation, num_steps)
 
     def select_action(self, observation:Observation)->Action:
         # Get the valid action mask (boolean array) for the current state using the parent class method
@@ -126,10 +126,10 @@ def select_action(self, observation:Observation)->Action:
             print(f'Starting the testing for episode {episode}')
 
             # Play the game for one episode
-            observation, avg_return, num_steps = agent.play_game(observation, 1)
+            observation, num_steps = agent.play_game(observation, 1)
 
             state = observation.state
-            reward = avg_return
+            reward = observation.reward
             end = observation.end
             info = observation.info
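
A minimal usage sketch of the two-stage action generation introduced in PATCH 1/2, assuming
the two fine-tuned ModernBERT checkpoints exist locally at the default paths used by
llm_agent_qa.py ("./bertClassifier" and "./bertMLM"); the prompt string below is a stand-in
for the full instructions/status/memory text that build_complete_training_prompt() produces.

    import json
    from llm_action_planner import ModernBERTActionClassifier, ModernBERTMaskedLM

    classifier = ModernBERTActionClassifier("./bertClassifier")
    mlm = ModernBERTMaskedLM("./bertMLM")

    prompt = "You are a pentester... Current status: Controlled hosts are 2.2.2.2 ..."

    # Stage 2a: classify which of the five action types to take next.
    action_type, confidence = classifier.predict_action_type(prompt)

    # Stage 2b: fill the masked parameter template for that action type.
    # LLMActionPlanner.action_templates maps each action type to its template;
    # the FindData template is shown here for illustration.
    template = {
        "action": "FindData",
        "parameters": {
            "target_host": "[MASK].[MASK].[MASK].[MASK]",
            "source_host": "[MASK].[MASK].[MASK].[MASK]",
        },
    }
    filled_json, predicted_tokens = mlm.predict_action_full(
        prompt=prompt, masked_action=json.dumps(template)
    )
    print(action_type, confidence, filled_json)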