1010from os import path
1111sys .path .append (path .dirname (path .dirname (path .dirname ( path .abspath (__file__ ) ))))
1212#with the path fixed, we can import now
13- from env .game_components import Action , ActionType , GameState , Observation , IP , Network
13+ from AIDojoCoordinator .game_components import Action , ActionType , GameState , Observation , IP , Network
1414import ipaddress
1515
1616def generate_valid_actions_concepts (state : GameState )-> list :
@@ -22,12 +22,12 @@ def generate_valid_actions_concepts(state: GameState)->list:
2222 # TODO ADD neighbouring networks
2323 # Only scan local networks from local hosts
2424 if network .is_private () and source_host .is_private ():
25- valid_actions .add (Action (ActionType .ScanNetwork , params = {"target_network" : network , "source_host" : source_host ,}))
25+ valid_actions .add (Action (ActionType .ScanNetwork , parameters = {"target_network" : network , "source_host" : source_host ,}))
2626 # Service Scans
2727 for host in state .known_hosts :
2828 # Do not try to scan a service from hosts outside local networks towards local networks
2929 if host .is_private () and source_host .is_private ():
30- valid_actions .add (Action (ActionType .FindServices , params = {"target_host" : host , "source_host" : source_host ,}))
30+ valid_actions .add (Action (ActionType .FindServices , parameters = {"target_host" : host , "source_host" : source_host ,}))
3131 # Service Exploits
3232 for host , service_list in state .known_services .items ():
3333 # Only exploit local services from local hosts
@@ -37,52 +37,65 @@ def generate_valid_actions_concepts(state: GameState)->list:
3737 for service in service_list :
3838 # Do not consider local services, which are internal to the host
3939 if not service .is_local :
40- valid_actions .add (Action (ActionType .ExploitService , params = {"target_host" : host ,"target_service" : service ,"source_host" : source_host ,}))
40+ valid_actions .add (Action (ActionType .ExploitService , parameters = {"target_host" : host ,"target_service" : service ,"source_host" : source_host ,}))
4141 # Find Data Scans
4242 for host in state .controlled_hosts :
43- valid_actions .add (Action (ActionType .FindData , params = {"target_host" : host , "source_host" : host }))
43+ valid_actions .add (Action (ActionType .FindData , parameters = {"target_host" : host , "source_host" : host }))
4444
4545 # Data Exfiltration
4646 for source_host , data_list in state .known_data .items ():
4747 for data in data_list :
4848 for target_host in state .controlled_hosts :
4949 if target_host != source_host :
50- valid_actions .add (Action (ActionType .ExfiltrateData , params = {"target_host" : target_host , "source_host" : source_host , "data" : data }))
50+ valid_actions .add (Action (ActionType .ExfiltrateData , parameters = {"target_host" : target_host , "source_host" : source_host , "data" : data }))
5151 return list (valid_actions )
5252
53- def generate_valid_actions (state : GameState )-> list :
53+ def generate_valid_actions (state : GameState , include_blocks = False )-> list :
5454 """Function that generates a list of all valid actions in a given state"""
5555 valid_actions = set ()
56+ def is_fw_blocked (state , src_ip , dst_ip )-> bool :
57+ blocked = False
58+ try :
59+ blocked = dst_ip in state .known_blocks [src_ip ]
60+ except KeyError :
61+ pass #this src ip has no known blocks
62+ return blocked
63+
5664 for src_host in state .controlled_hosts :
5765 #Network Scans
5866 for network in state .known_networks :
5967 # TODO ADD neighbouring networks
60- valid_actions .add (Action (ActionType .ScanNetwork , params = {"target_network" : network , "source_host" : src_host ,}))
68+ valid_actions .add (Action (ActionType .ScanNetwork , parameters = {"target_network" : network , "source_host" : src_host ,}))
6169 # Service Scans
6270 for host in state .known_hosts :
63- valid_actions .add (Action (ActionType .FindServices , params = {"target_host" : host , "source_host" : src_host ,}))
71+ if not is_fw_blocked (state , src_host ,host ):
72+ valid_actions .add (Action (ActionType .FindServices , parameters = {"target_host" : host , "source_host" : src_host ,}))
6473 # Service Exploits
6574 for host , service_list in state .known_services .items ():
66- for service in service_list :
67- valid_actions .add (Action (ActionType .ExploitService , params = {"target_host" : host ,"target_service" : service ,"source_host" : src_host ,}))
75+ if not is_fw_blocked (state , src_host ,host ):
76+ for service in service_list :
77+ valid_actions .add (Action (ActionType .ExploitService , parameters = {"target_host" : host ,"target_service" : service ,"source_host" : src_host ,}))
6878 # Data Scans
6979 for host in state .controlled_hosts :
70- valid_actions .add (Action (ActionType .FindData , params = {"target_host" : host , "source_host" : host }))
80+ if not is_fw_blocked (state , src_host ,host ):
81+ valid_actions .add (Action (ActionType .FindData , parameters = {"target_host" : host , "source_host" : host }))
7182
7283 # Data Exfiltration
7384 for src_host , data_list in state .known_data .items ():
7485 for data in data_list :
7586 for trg_host in state .controlled_hosts :
7687 if trg_host != src_host :
77- valid_actions .add (Action (ActionType .ExfiltrateData , params = {"target_host" : trg_host , "source_host" : src_host , "data" : data }))
88+ if not is_fw_blocked (state , src_host ,trg_host ):
89+ valid_actions .add (Action (ActionType .ExfiltrateData , parameters = {"target_host" : trg_host , "source_host" : src_host , "data" : data }))
7890
79- # BlockIP
80- for src_host in state .controlled_hosts :
81- for target_host in state .controlled_hosts :
82- for blocked_ip in state .known_hosts :
83- valid_actions .add (Action (ActionType .BlockIP , {"target_host" :target_host , "source_host" :src_host , "blocked_host" :blocked_ip }))
84-
85-
91+ if include_blocks :
92+ # BlockIP
93+ if include_blocks :
94+ for src_host in state .controlled_hosts :
95+ for target_host in state .controlled_hosts :
96+ if not is_fw_blocked (state , src_host ,target_host ):
97+ for blocked_ip in state .known_hosts :
98+ valid_actions .add (Action (ActionType .BlockIP , {"target_host" :target_host , "source_host" :src_host , "blocked_host" :blocked_ip }))
8699 return list (valid_actions )
87100
88101def state_as_ordered_string (state :GameState )-> str :
@@ -97,6 +110,9 @@ def state_as_ordered_string(state:GameState)->str:
97110 ret += "},data:{"
98111 for host in sorted (state .known_data .keys ()):
99112 ret += f"{ host } :[{ ',' .join ([str (x ) for x in sorted (state .known_data [host ])])} ]"
113+ ret += "},blocks:{"
114+ for host in sorted (state .known_blocks .keys ()):
115+ ret += f"{ host } :[{ ',' .join ([str (x ) for x in sorted (state .known_blocks [host ])])} ]"
100116 ret += "}"
101117 return ret
102118
@@ -480,7 +496,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
480496 if src_host_concept == host_concept :
481497 new_src_host = concept_mapping ['controlled_hosts' ][host_concept ]
482498
483- action = Action (ActionType .ExploitService , params = {"target_host" : new_target_host , "target_service" : new_target_service , "source_host" : new_src_host })
499+ action = Action (ActionType .ExploitService , parameters = {"target_host" : new_target_host , "target_service" : new_target_service , "source_host" : new_src_host })
484500
485501 elif action ._type == ActionType .ExfiltrateData :
486502 # parameters = {
@@ -508,7 +524,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
508524 data_concept = action .parameters ['data' ]
509525 new_data = data_concept
510526
511- action = Action (ActionType .ExfiltrateData , params = {"target_host" : new_target_host , "source_host" : new_src_host , "data" : new_data })
527+ action = Action (ActionType .ExfiltrateData , parameters = {"target_host" : new_target_host , "source_host" : new_src_host , "data" : new_data })
512528
513529 elif action ._type == ActionType .FindData :
514530 # parameters = {
@@ -529,7 +545,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
529545 if src_host_concept == host_concept :
530546 new_src_host = concept_mapping ['controlled_hosts' ][host_concept ]
531547
532- action = Action (ActionType .FindData , params = {"target_host" : new_target_host , "source_host" : new_src_host })
548+ action = Action (ActionType .FindData , parameters = {"target_host" : new_target_host , "source_host" : new_src_host })
533549
534550 elif action ._type == ActionType .ScanNetwork :
535551 target_net_concept = action .parameters ['target_network' ]
@@ -545,7 +561,7 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
545561 for host_concept in concept_mapping ['controlled_hosts' ]:
546562 if src_host_concept == host_concept :
547563 new_src_host = concept_mapping ['controlled_hosts' ][host_concept ]
548- action = Action (ActionType .ScanNetwork , params = {"source_host" : new_src_host , "target_network" : new_target_network } )
564+ action = Action (ActionType .ScanNetwork , parameters = {"source_host" : new_src_host , "target_network" : new_target_network } )
549565
550566 elif action ._type == ActionType .FindServices :
551567 # parameters = {
@@ -565,6 +581,6 @@ def convert_concepts_to_actions(action, concept_mapping, logger):
565581 for host_concept in concept_mapping ['controlled_hosts' ]:
566582 if src_host_concept == host_concept :
567583 new_src_host = concept_mapping ['controlled_hosts' ][host_concept ]
568- action = Action (ActionType .FindServices , params = {"source_host" : new_src_host , "target_host" : new_target_host } )
584+ action = Action (ActionType .FindServices , parameters = {"source_host" : new_src_host , "target_host" : new_target_host } )
569585
570586 return action
0 commit comments