Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion env/netsecenv_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ env:
random_seed: 'random'
# Or you can fix the seed
# random_seed: 42
scenario: 'three_nets'
scenario: 'scenario1'
use_global_defender: False
max_steps: 50
use_dynamic_addresses: False
Expand Down
233 changes: 208 additions & 25 deletions utils/gamaplay_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,169 @@
import os
import utils
import argparse
import matplotlib.pyplot as plt

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__) )))
from env.game_components import GameState, Action

class TrajectoryGraph:
def __init__(self)->None:
self._checkpoints = {}
self._checkpoint_edges = {}
self._checkpoint_simple_edges = {}
self._wins_per_checkpoint = {}
self._state_to_id = {}
self._id_to_state = {}
self._action_to_id = {}
self._id_to_action = {}

@property
def num_checkpoints(self)->int:
return len(self._checkpoints)

def get_state_id(self, state:GameState)->int:
"""
Returns state id or creates new one if the state was not registered before
"""
state_str = utils.state_as_ordered_string(state)
if state_str not in self._state_to_id.keys():
self._state_to_id[state_str] = len(self._state_to_id)
self._id_to_state[self._state_to_id[state_str]] = state
return self._state_to_id[state_str]

def get_state(self, id:int)->GameState:
return self._id_to_state[id]

def get_action_id(self, action:Action)->int:
"""
Returns action id or creates new one if the state was not registered before
"""
if action not in self._action_to_id.keys():
self._action_to_id[action] = len(self._action_to_id)
self._id_to_action[self._action_to_id[action]] = action
return self._action_to_id[action]

def get_action(self, id:int)-> Action:
return self._id_to_action[id]

def add_checkpoint(self, trajectories:list, end_reason=None)->None:
# Add complete trajectory list
wins = []
edges = {}
simple_edges = {}
for play in trajectories:
if end_reason and play["end_reason"] not in end_reason:
continue
if len(play["trajectory"]["actions"]) == 0:
continue
if play["end_reason"] == "goal_reached":
wins.append(1)
else:
wins.append(0)
state_id = self.get_state_id(GameState.from_dict(play["trajectory"]["states"][0]))
#print(f'Trajectory len: {len(play["trajectory"]["actions"])}')
for i in range(1, len(play["trajectory"]["actions"])):
next_state_id = self.get_state_id(GameState.from_dict(play["trajectory"]["states"][i]))
action_id = self.get_action_id(Action.from_dict((play["trajectory"]["actions"][i])))
# fullgraph
if (state_id, next_state_id, action_id) not in edges:
edges[state_id, next_state_id, action_id] = 0
edges[state_id, next_state_id, action_id] += 1

#simplified graph
if (state_id, next_state_id)not in simple_edges:
simple_edges[state_id, next_state_id] = 0
simple_edges[state_id, next_state_id] += 1
state_id = next_state_id
self._checkpoint_simple_edges[self.num_checkpoints] = simple_edges
self._checkpoint_edges[self.num_checkpoints] = edges
self._wins_per_checkpoint[self.num_checkpoints] = np.array(wins)
self._checkpoints[self.num_checkpoints] = trajectories

def get_checkpoint_wr(self, checkpoint_id:int)->tuple:
if checkpoint_id not in self._wins_per_checkpoint:
raise IndexError(f"Checkpoint id '{checkpoint_id}' not found!")
else:
return np.mean(self._wins_per_checkpoint[checkpoint_id]), np.std(self._wins_per_checkpoint[checkpoint_id])

def get_wr_progress(self)->dict:
ret = {}
for i in self._wins_per_checkpoint.keys():
wr, std = self.get_checkpoint_wr(i)
ret[i] = {"wr":wr, "std":std}
print(f"Checkpoint {i}: WR={wr}±{std}")
return ret

def get_graph_stats_progress(self):
ret = {}
print("Checkpoint,\tWR,\tEdges,\tSimpleEdges,\tNodes,\tLoops,\tSimpleLoops")
for i in self._wins_per_checkpoint.keys():
data = self.get_checkpoint_stats(i)
ret[i] = data
print(f'{i},\t{data["winrate"]},\t{data["num_edges"]},\t{data["num_simplified_edges"]},\t{data["num_nodes"]},\t{data["num_loops"]},\t{data["num_simplified_loops"]}')
return ret

def plot_graph_stats_progress(self, filedir="figures", filename="trajectory_graph_stats.png"):
data = self.get_graph_stats_progress()
wr = [data[i]["winrate"] for i in range(len(data))]
num_nodes = [data[i]["num_nodes"] for i in range(len(data))]
num_edges = [data[i]["num_edges"] for i in range(len(data))]
num_simle_edges = [data[i]["num_simplified_edges"] for i in range(len(data))]
num_loops = [data[i]["num_loops"] for i in range(len(data))]
num_simplified_loops = [data[i]["num_simplified_loops"] for i in range(len(data))]
checkpoints = range(len(wr))
plt.plot(checkpoints, num_nodes, label='Number of nodes')
plt.plot(checkpoints, num_edges, label='Number of edges')
plt.plot(checkpoints, num_simle_edges, label='Number of simplified edges')
plt.plot(checkpoints, num_loops, label='Number of loops')
plt.plot(checkpoints, num_simplified_loops, label='Number of simplified loops')

plt.title("Graph statistics per checkpoint")
plt.yscale('log')
plt.xlabel("Checkpoints")
# Show legend
plt.legend()

# Save the figure as an image file
plt.savefig(os.path.join(filedir, filename))

def get_checkpoint_stats(self, checkpoint_id:int)->dict:
if checkpoint_id not in self._wins_per_checkpoint:
raise IndexError(f"Checkpoint id '{checkpoint_id}' not found!")
else:
data = {}
data["winrate"] = np.mean(self._wins_per_checkpoint[checkpoint_id])
data["winrate_std"] = np.std(self._wins_per_checkpoint[checkpoint_id])
data["num_edges"] = len(self._checkpoint_edges[checkpoint_id])
data["num_simplified_edges"] = len(self._checkpoint_simple_edges[checkpoint_id])
data["num_loops"] = len([edge for edge in self._checkpoint_edges[checkpoint_id].keys() if edge[0]==edge[1]])
data["num_simplified_loops"] = len([edge for edge in self._checkpoint_simple_edges[checkpoint_id].keys() if edge[0]==edge[1]])
node_set = set([src_node for src_node,_,_ in self._checkpoint_edges[checkpoint_id].keys()]) | set([dst_node for _,dst_node,_ in self._checkpoint_edges[checkpoint_id].keys()])
data["num_nodes"] = len(node_set)
return data

def get_graph_structure_progress(self)->dict:

all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values()))
super_graph = {key:np.zeros(self.num_checkpoints) for key in all_edges}
for i, edge_list in self._checkpoint_edges.items():
for edge in edge_list:
super_graph[edge][i] = 1
return super_graph

def get_graph_structure_probabilistic_progress(self)->dict:

all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values()))
super_graph = {key:np.zeros(self.num_checkpoints) for key in all_edges}
for i, edge_list in self._checkpoint_edges.items():
total_out_edges_use = {}
for (src, _, _), frequency in edge_list.items():
if src not in total_out_edges_use:
total_out_edges_use[src] = 0
total_out_edges_use[src] += frequency
for (src,dst,edge), value in edge_list.items():
super_graph[(src,dst,edge)][i] = value/total_out_edges_use[src]
return super_graph

def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple:
edges = {}
Expand Down Expand Up @@ -94,33 +252,58 @@ def get_graph_modificiation(edge_list1, edge_list2):


parser = argparse.ArgumentParser()
parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True)
parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True)
# parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True)
# parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True)
parser.add_argument("--end_reason", help="Filter options for trajectories", default=None, type=str, action='store', required=False)
parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=1000, required=False)
parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=10000, required=False)

args = parser.parse_args()
trajectories1 = read_json(args.t1, max_lines=args.n_trajectories)
trajectories2 = read_json(args.t2, max_lines=args.n_trajectories)
states = {}
actions = {}
# trajectories1 = read_json(args.t1, max_lines=args.n_trajectories)
# trajectories2 = read_json(args.t2, max_lines=args.n_trajectories)
# states = {}
# actions = {}

graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason)
graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason)
# graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason)
# graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason)

state_to_id = {v:k for k,v in states.items()}
action_to_id = {v:k for k,v in states.items()}

print(f"Trajectory 1: {args.t1}")
print(f"WR={t1_wr_mean}±{t1_wr_std}")
get_graph_stats(graph_t1, state_to_id, action_to_id)
print(f"Trajectory 2: {args.t2}")
print(f"WR={t2_wr_mean}±{t2_wr_std}")
get_graph_stats(graph_t2, state_to_id, action_to_id)

a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2)
print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}")
# print("positions of same states:")
# for node in node_set(graph_t1).intersection(node_set(graph_t2)):
# print(g1_timestaps[node], g2_timestaps[node])
# print("-----------------------")
# state_to_id = {v:k for k,v in states.items()}
# action_to_id = {v:k for k,v in states.items()}

# print(f"Trajectory 1: {args.t1}")
# print(f"WR={t1_wr_mean}±{t1_wr_std}")
# get_graph_stats(graph_t1, state_to_id, action_to_id)
# print(f"Trajectory 2: {args.t2}")
# print(f"WR={t2_wr_mean}±{t2_wr_std}")
# get_graph_stats(graph_t2, state_to_id, action_to_id)

# a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2)
# print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}")
# # print("positions of same states:")
# # for node in node_set(graph_t1).intersection(node_set(graph_t2)):
# # print(g1_timestaps[node], g2_timestaps[node])
# # print("-----------------------")
# tg_no_blocks = TrajectoryGraph()

# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-2000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-4000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-6000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-8000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories))
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories))

# tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_no_blocks.jsonl",max_lines=args.n_trajectories))
# tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_no_blocks.jsonl",max_lines=args.n_trajectories))
# tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_no_blocks.jsonl",max_lines=args.n_trajectories))
# tg_no_blocks.plot_graph_stats_progress()

tg_blocks = TrajectoryGraph()
tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_blocks.jsonl",max_lines=args.n_trajectories))
tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_blocks.jsonl",max_lines=args.n_trajectories))
tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_blocks.jsonl",max_lines=args.n_trajectories))
tg_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-20000_blocks.jsonl",max_lines=args.n_trajectories))
tg_blocks.plot_graph_stats_progress()

super_graph = tg_blocks.get_graph_structure_probabilistic_progress()
print(len(super_graph))
edges_present_everycheckpoint = [k for k,v in super_graph.items() if np.min(v) > 0]
print(len(edges_present_everycheckpoint))