Skip to content

Commit 3693e0e

Browse files
authored
Merge pull request #79 from stratosphereips/cyst-integration
Cyst integration
2 parents d15cf5c + 77e5edd commit 3693e0e

File tree

22 files changed

+230
-425
lines changed

22 files changed

+230
-425
lines changed

agents/agent_utils.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from os import path
1111
sys.path.append(path.dirname(path.dirname(path.dirname( path.abspath(__file__) ))))
1212
#with the path fixed, we can import now
13-
from env.game_components import Action, ActionType, GameState, Observation, IP, Network
13+
from AIDojoCoordinator.game_components import Action, ActionType, GameState, Observation, IP, Network
1414
import ipaddress
1515

1616
def generate_valid_actions_concepts(state: GameState)->list:
@@ -22,12 +22,12 @@ def generate_valid_actions_concepts(state: GameState)->list:
2222
# TODO ADD neighbouring networks
2323
# Only scan local networks from local hosts
2424
if network.is_private() and source_host.is_private():
25-
valid_actions.add(Action(ActionType.ScanNetwork, params={"target_network": network, "source_host": source_host,}))
25+
valid_actions.add(Action(ActionType.ScanNetwork, parameters={"target_network": network, "source_host": source_host,}))
2626
# Service Scans
2727
for host in state.known_hosts:
2828
# Do not try to scan a service from hosts outside local networks towards local networks
2929
if host.is_private() and source_host.is_private():
30-
valid_actions.add(Action(ActionType.FindServices, params={"target_host": host, "source_host": source_host,}))
30+
valid_actions.add(Action(ActionType.FindServices, parameters={"target_host": host, "source_host": source_host,}))
3131
# Service Exploits
3232
for host, service_list in state.known_services.items():
3333
# Only exploit local services from local hosts
@@ -37,52 +37,65 @@ def generate_valid_actions_concepts(state: GameState)->list:
3737
for service in service_list:
3838
# Do not consider local services, which are internal to the host
3939
if not service.is_local:
40-
valid_actions.add(Action(ActionType.ExploitService, params={"target_host": host,"target_service": service,"source_host": source_host,}))
40+
valid_actions.add(Action(ActionType.ExploitService, parameters={"target_host": host,"target_service": service,"source_host": source_host,}))
4141
# Find Data Scans
4242
for host in state.controlled_hosts:
43-
valid_actions.add(Action(ActionType.FindData, params={"target_host": host, "source_host": host}))
43+
valid_actions.add(Action(ActionType.FindData, parameters={"target_host": host, "source_host": host}))
4444

4545
# Data Exfiltration
4646
for source_host, data_list in state.known_data.items():
4747
for data in data_list:
4848
for target_host in state.controlled_hosts:
4949
if target_host != source_host:
50-
valid_actions.add(Action(ActionType.ExfiltrateData, params={"target_host": target_host, "source_host": source_host, "data": data}))
50+
valid_actions.add(Action(ActionType.ExfiltrateData, parameters={"target_host": target_host, "source_host": source_host, "data": data}))
5151
return list(valid_actions)
5252

53-
def generate_valid_actions(state: GameState)->list:
53+
def generate_valid_actions(state: GameState, include_blocks=False)->list:
5454
"""Function that generates a list of all valid actions in a given state"""
5555
valid_actions = set()
56+
def is_fw_blocked(state, src_ip, dst_ip)->bool:
57+
blocked = False
58+
try:
59+
blocked = dst_ip in state.known_blocks[src_ip]
60+
except KeyError:
61+
pass #this src ip has no known blocks
62+
return blocked
63+
5664
for src_host in state.controlled_hosts:
5765
#Network Scans
5866
for network in state.known_networks:
5967
# TODO ADD neighbouring networks
60-
valid_actions.add(Action(ActionType.ScanNetwork, params={"target_network": network, "source_host": src_host,}))
68+
valid_actions.add(Action(ActionType.ScanNetwork, parameters={"target_network": network, "source_host": src_host,}))
6169
# Service Scans
6270
for host in state.known_hosts:
63-
valid_actions.add(Action(ActionType.FindServices, params={"target_host": host, "source_host": src_host,}))
71+
if not is_fw_blocked(state, src_host,host):
72+
valid_actions.add(Action(ActionType.FindServices, parameters={"target_host": host, "source_host": src_host,}))
6473
# Service Exploits
6574
for host, service_list in state.known_services.items():
66-
for service in service_list:
67-
valid_actions.add(Action(ActionType.ExploitService, params={"target_host": host,"target_service": service,"source_host": src_host,}))
75+
if not is_fw_blocked(state, src_host,host):
76+
for service in service_list:
77+
valid_actions.add(Action(ActionType.ExploitService, parameters={"target_host": host,"target_service": service,"source_host": src_host,}))
6878
# Data Scans
6979
for host in state.controlled_hosts:
70-
valid_actions.add(Action(ActionType.FindData, params={"target_host": host, "source_host": host}))
80+
if not is_fw_blocked(state, src_host,host):
81+
valid_actions.add(Action(ActionType.FindData, parameters={"target_host": host, "source_host": host}))
7182

7283
# Data Exfiltration
7384
for src_host, data_list in state.known_data.items():
7485
for data in data_list:
7586
for trg_host in state.controlled_hosts:
7687
if trg_host != src_host:
77-
valid_actions.add(Action(ActionType.ExfiltrateData, params={"target_host": trg_host, "source_host": src_host, "data": data}))
88+
if not is_fw_blocked(state, src_host,trg_host):
89+
valid_actions.add(Action(ActionType.ExfiltrateData, parameters={"target_host": trg_host, "source_host": src_host, "data": data}))
7890

79-
# BlockIP
80-
for src_host in state.controlled_hosts:
81-
for target_host in state.controlled_hosts:
82-
for blocked_ip in state.known_hosts:
83-
valid_actions.add(Action(ActionType.BlockIP, {"target_host":target_host, "source_host":src_host, "blocked_host":blocked_ip}))
84-
85-
91+
if include_blocks:
92+
# BlockIP
93+
if include_blocks:
94+
for src_host in state.controlled_hosts:
95+
for target_host in state.controlled_hosts:
96+
if not is_fw_blocked(state, src_host,target_host):
97+
for blocked_ip in state.known_hosts:
98+
valid_actions.add(Action(ActionType.BlockIP, {"target_host":target_host, "source_host":src_host, "blocked_host":blocked_ip}))
8699
return list(valid_actions)
87100

88101
def state_as_ordered_string(state:GameState)->str:
@@ -97,6 +110,9 @@ def state_as_ordered_string(state:GameState)->str:
97110
ret += "},data:{"
98111
for host in sorted(state.known_data.keys()):
99112
ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_data[host])])}]"
113+
ret += "},blocks:{"
114+
for host in sorted(state.known_blocks.keys()):
115+
ret += f"{host}:[{','.join([str(x) for x in sorted(state.known_blocks[host])])}]"
100116
ret += "}"
101117
return ret
102118

@@ -480,7 +496,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
480496
if src_host_concept == host_concept:
481497
new_src_host = concept_mapping['controlled_hosts'][host_concept]
482498

483-
action = Action(ActionType.ExploitService, params={"target_host": new_target_host, "target_service": new_target_service, "source_host": new_src_host})
499+
action = Action(ActionType.ExploitService, parameters={"target_host": new_target_host, "target_service": new_target_service, "source_host": new_src_host})
484500

485501
elif action._type == ActionType.ExfiltrateData:
486502
# parameters = {
@@ -508,7 +524,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
508524
data_concept = action.parameters['data']
509525
new_data = data_concept
510526

511-
action = Action(ActionType.ExfiltrateData, params={"target_host": new_target_host, "source_host": new_src_host, "data": new_data})
527+
action = Action(ActionType.ExfiltrateData, parameters={"target_host": new_target_host, "source_host": new_src_host, "data": new_data})
512528

513529
elif action._type == ActionType.FindData:
514530
# parameters = {
@@ -529,7 +545,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
529545
if src_host_concept == host_concept:
530546
new_src_host = concept_mapping['controlled_hosts'][host_concept]
531547

532-
action = Action(ActionType.FindData, params={"target_host": new_target_host, "source_host": new_src_host})
548+
action = Action(ActionType.FindData, parameters={"target_host": new_target_host, "source_host": new_src_host})
533549

534550
elif action._type == ActionType.ScanNetwork:
535551
target_net_concept = action.parameters['target_network']
@@ -545,7 +561,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
545561
for host_concept in concept_mapping['controlled_hosts']:
546562
if src_host_concept == host_concept:
547563
new_src_host = concept_mapping['controlled_hosts'][host_concept]
548-
action = Action(ActionType.ScanNetwork, params={"source_host": new_src_host, "target_network": new_target_network} )
564+
action = Action(ActionType.ScanNetwork, parameters={"source_host": new_src_host, "target_network": new_target_network} )
549565

550566
elif action._type == ActionType.FindServices:
551567
# parameters = {
@@ -565,6 +581,6 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
565581
for host_concept in concept_mapping['controlled_hosts']:
566582
if src_host_concept == host_concept:
567583
new_src_host = concept_mapping['controlled_hosts'][host_concept]
568-
action = Action(ActionType.FindServices, params={"source_host": new_src_host, "target_host": new_target_host} )
584+
action = Action(ActionType.FindServices, parameters={"source_host": new_src_host, "target_host": new_target_host} )
569585

570586
return action

agents/attackers/concepts_q_learning/conceptual_q_agent.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,21 @@
22
# Arti
33
# Sebastian Garcia. sebastian.garcia@agents.fel.cvut.cz
44
import sys
5-
from os import path, makedirs
65
import numpy as np
76
import random
87
import pickle
98
import argparse
109
import logging
11-
# This is used so the agent can see the environment and game component
12-
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__) ) ) ))))
13-
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__) ))))
10+
import mlflow
11+
import subprocess
12+
from os import path, makedirs
13+
from AIDojoCoordinator.game_components import Action, Observation, GameState, AgentStatus
1414

1515
# This is used so the agent can see the environment and game component
1616
# with the path fixed, we can import now
17-
from env.game_components import Action, Observation, GameState
17+
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__) ))))
1818
from base_agent import BaseAgent
1919
from agent_utils import generate_valid_actions, state_as_ordered_string, convert_concepts_to_actions, convert_ips_to_concepts
20-
import mlflow
21-
import subprocess
22-
2320

2421
class QAgent(BaseAgent):
2522

@@ -383,15 +380,15 @@ def play_game(self, observation_ip, episode_num, testing=False):
383380
test_end = test_observation.end
384381
test_info = test_observation.info
385382

386-
if test_info and test_info['end_reason'] == 'blocked':
383+
if test_info and test_info['end_reason'] == AgentStatus.Fail:
387384
test_detected +=1
388385
test_num_detected_steps += [num_steps]
389386
test_num_detected_returns += [reward]
390-
elif test_info and test_info['end_reason'] == 'goal_reached':
387+
elif test_info and test_info['end_reason'] == AgentStatus.Success:
391388
test_wins += 1
392389
test_num_win_steps += [num_steps]
393390
test_num_win_returns += [reward]
394-
elif test_info and test_info['end_reason'] == 'max_steps':
391+
elif test_info and test_info['end_reason'] == AgentStatus.TimeoutReached:
395392
test_max_steps += 1
396393
test_num_max_steps_steps += [num_steps]
397394
test_num_max_steps_returns += [reward]

agents/attackers/double_q_learning/double_q_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
import logging
1515
from torch.utils.tensorboard import SummaryWriter
1616
import time
17-
from env.worlds.network_security_game import NetworkSecurityEnvironment
18-
from env.game_components import Action, Observation, GameState
17+
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
18+
from AIDojoCoordinator.game_components import Action, Observation, GameState
1919

2020
class DoubleQAgent:
2121

agents/attackers/gnn_reinforce/gnn_REINFORCE_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))
1919

2020
#with the path fixed, we can import now
21-
from env.worlds.network_security_game import NetworkSecurityEnvironment
21+
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
2222

2323
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
2424
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

agents/attackers/interactive_tui/assistant.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,8 @@
1010
import jinja2
1111
from tenacity import retry, stop_after_attempt
1212

13-
sys.path.append(
14-
path.dirname(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
15-
)
16-
from env.game_components import (
17-
ActionType,
18-
Observation,
19-
)
13+
14+
from AIDojoCoordinator.game_components import ActionType, Observation
2015

2116
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
2217

agents/attackers/interactive_tui/interactive_tui.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,20 @@
11
#
22
# Author: Maria Rigaki - maria.rigaki@aic.fel.cvut.cz
3-
#
4-
from textual.app import App, ComposeResult, Widget
5-
from textual.widgets import Tree, Button, RichLog, Select, Input
6-
from textual.containers import Vertical, VerticalScroll, Horizontal
7-
from textual.validation import Function
8-
from textual import on
9-
from textual.reactive import reactive
103

114
import sys
12-
from os import path
135
import os
146
import logging
157
import ipaddress
168
import argparse
179
import asyncio
18-
10+
from textual.app import App, ComposeResult, Widget
11+
from textual.widgets import Tree, Button, RichLog, Select, Input
12+
from textual.containers import Vertical, VerticalScroll, Horizontal
13+
from textual.validation import Function
14+
from textual import on
15+
from textual.reactive import reactive
1916
from assistant import LLMAssistant
20-
21-
# This is used so the agent can see the environment and game components
22-
sys.path.append(
23-
path.dirname(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
24-
)
25-
from env.game_components import Network, IP
26-
from env.game_components import ActionType, Action, GameState, Observation
17+
from AIDojoCoordinator.game_components import Network, IP, ActionType, Action, GameState, Observation
2718

2819
# This is used so the agent can see the BaseAgent
2920
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -558,7 +549,7 @@ def generate_action(self, state: GameState) -> Action:
558549
self.network_input[:-3], mask=int(self.network_input[-2:])
559550
),
560551
}
561-
action = Action(action_type=self.next_action, params=parameters)
552+
action = Action(action_type=self.next_action, parameters=parameters)
562553
else:
563554
self.notify("Please provide valid inputs", severity="error")
564555
elif self.next_action in [ActionType.FindServices, ActionType.FindData]:
@@ -567,7 +558,7 @@ def generate_action(self, state: GameState) -> Action:
567558
"source_host": IP(self.src_host_input),
568559
"target_host": IP(self.target_host_input),
569560
}
570-
action = Action(action_type=self.next_action, params=parameters)
561+
action = Action(action_type=self.next_action, parameters=parameters)
571562
else:
572563
self.notify("Please provide valid inputs", severity="error")
573564
elif self.next_action == ActionType.ExploitService:
@@ -582,7 +573,7 @@ def generate_action(self, state: GameState) -> Action:
582573
"target_service": service,
583574
}
584575
action = Action(
585-
action_type=self.next_action, params=parameters
576+
action_type=self.next_action, parameters=parameters
586577
)
587578
break
588579
else:
@@ -600,7 +591,7 @@ def generate_action(self, state: GameState) -> Action:
600591
"data": datum,
601592
}
602593
action = Action(
603-
action_type=self.next_action, params=parameters
594+
action_type=self.next_action, parameters=parameters
604595
)
605596
else:
606597
parameters = self.data_input

agents/attackers/llm/llm_agent-2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))
99

1010

11-
from env.worlds.network_security_game import NetworkSecurityEnvironment
12-
from env.game_components import ActionType, Action, IP, Data, Network, Service
11+
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
12+
from AIDojoCoordinator.game_components import ActionType, Action, IP, Data, Network, Service
1313

1414
import openai
1515
from tenacity import retry, stop_after_attempt

agents/attackers/llm/llm_agent-3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))
88

99

10-
from env.worlds.network_security_game import NetworkSecurityEnvironment
11-
from env.game_components import ActionType, Action, IP, Data, Network, Service
10+
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
11+
from AIDojoCoordinator.game_components import ActionType, Action, IP, Data, Network, Service
1212

1313
import openai
1414
from tenacity import retry, stop_after_attempt

agents/attackers/llm/llm_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))
88

99

10-
from env.worlds.network_security_game import NetworkSecurityEnvironment
11-
from env.game_components import ActionType, Action, IP, Data, Network, Service
10+
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
11+
from AIDojoCoordinator.game_components import ActionType, Action, IP, Data, Network, Service
1212

1313
import openai
1414
from tenacity import retry, stop_after_attempt

agents/attackers/llm_embed/llm_embed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from os import path
1212
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) ))))
1313

14-
from env.worlds.network_security_game import NetworkSecurityEnvironment
15-
from env.game_components import Action, ActionType
14+
from AIDojoCoordinator.worlds.network_security_game import NetworkSecurityEnvironment
15+
from AIDojoCoordinator.game_components import Action, ActionType
1616

1717
import numpy as np
1818
import torch

0 commit comments

Comments
 (0)