Skip to content

Commit fd9ead4

Browse files
authored
Merge pull request #254 from stratosphereips/trajectory_graph_object
Trajectory graph object
2 parents 79f7e26 + bf755dd commit fd9ead4

File tree

2 files changed

+209
-26
lines changed

2 files changed

+209
-26
lines changed

env/netsecenv_conf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ env:
9797
random_seed: 'random'
9898
# Or you can fix the seed
9999
# random_seed: 42
100-
scenario: 'three_nets'
100+
scenario: 'scenario1'
101101
use_global_defender: False
102102
max_steps: 50
103103
use_dynamic_addresses: False

utils/gamaplay_graphs.py

Lines changed: 208 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,169 @@
44
import os
55
import utils
66
import argparse
7+
import matplotlib.pyplot as plt
78

89
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__) )))
910
from env.game_components import GameState, Action
1011

12+
class TrajectoryGraph:
    """
    Collects gameplay trajectories in "checkpoints" and exposes graph statistics.

    Every checkpoint is a list of plays. For each checkpoint the class stores
    a full edge multiset keyed by (state_id, next_state_id, action_id), a
    simplified one keyed by (state_id, next_state_id), and a 0/1 win flag per
    processed play. States and actions are mapped to integer ids that are
    shared by all checkpoints of the same instance.
    """

    def __init__(self) -> None:
        # checkpoint_id -> raw trajectory list as passed to add_checkpoint
        self._checkpoints = {}
        # checkpoint_id -> {(state_id, next_state_id, action_id): use count}
        self._checkpoint_edges = {}
        # checkpoint_id -> {(state_id, next_state_id): use count}
        self._checkpoint_simple_edges = {}
        # checkpoint_id -> np.array of 0/1 flags (1 == "goal_reached")
        self._wins_per_checkpoint = {}
        # bidirectional id registries
        self._state_to_id = {}
        self._id_to_state = {}
        self._action_to_id = {}
        self._id_to_action = {}

    @property
    def num_checkpoints(self) -> int:
        """Number of checkpoints registered so far."""
        return len(self._checkpoints)

    def get_state_id(self, state: "GameState") -> int:
        """
        Returns state id or creates new one if the state was not registered before
        """
        state_str = utils.state_as_ordered_string(state)
        if state_str not in self._state_to_id:
            new_id = len(self._state_to_id)
            self._state_to_id[state_str] = new_id
            self._id_to_state[new_id] = state
        return self._state_to_id[state_str]

    def get_state(self, id: int) -> "GameState":
        """Returns the state registered under `id` (KeyError if unknown)."""
        return self._id_to_state[id]

    def get_action_id(self, action: "Action") -> int:
        """
        Returns action id or creates new one if the action was not registered before
        """
        if action not in self._action_to_id:
            new_id = len(self._action_to_id)
            self._action_to_id[action] = new_id
            self._id_to_action[new_id] = action
        return self._action_to_id[action]

    def get_action(self, id: int) -> "Action":
        """Returns the action registered under `id` (KeyError if unknown)."""
        return self._id_to_action[id]

    def add_checkpoint(self, trajectories: list, end_reason=None) -> None:
        """
        Registers a new checkpoint built from a list of plays.

        Each play is a dict with an "end_reason" entry and a "trajectory"
        entry holding "states" and "actions" lists. Plays without actions are
        skipped, as are plays whose end reason is not in `end_reason` when a
        filter is given.

        `end_reason` may be a collection of accepted reasons or a single
        string; a single string is treated as a one-element collection so the
        filter compares whole reasons instead of doing a substring test.
        """
        if isinstance(end_reason, str):
            end_reason = (end_reason,)
        wins = []
        edges = {}
        simple_edges = {}
        for play in trajectories:
            if end_reason and play["end_reason"] not in end_reason:
                continue
            actions = play["trajectory"]["actions"]
            if len(actions) == 0:
                continue
            wins.append(1 if play["end_reason"] == "goal_reached" else 0)
            states = play["trajectory"]["states"]
            state_id = self.get_state_id(GameState.from_dict(states[0]))
            # NOTE(review): index 0 of "actions" is never used and actions[i] is
            # paired with the transition states[i-1] -> states[i]; this presumably
            # matches the recorded trajectory layout -- confirm against the writer.
            for i in range(1, len(actions)):
                next_state_id = self.get_state_id(GameState.from_dict(states[i]))
                action_id = self.get_action_id(Action.from_dict(actions[i]))
                # full graph: parallel edges are distinguished by action
                full_key = (state_id, next_state_id, action_id)
                edges[full_key] = edges.get(full_key, 0) + 1
                # simplified graph: the action is ignored
                simple_key = (state_id, next_state_id)
                simple_edges[simple_key] = simple_edges.get(simple_key, 0) + 1
                state_id = next_state_id
        # Take the id once: num_checkpoints changes as soon as _checkpoints grows.
        checkpoint_id = self.num_checkpoints
        self._checkpoint_simple_edges[checkpoint_id] = simple_edges
        self._checkpoint_edges[checkpoint_id] = edges
        self._wins_per_checkpoint[checkpoint_id] = np.array(wins)
        self._checkpoints[checkpoint_id] = trajectories

    def get_checkpoint_wr(self, checkpoint_id: int) -> tuple:
        """
        Returns (mean, std) of the win flags of a checkpoint.

        Raises IndexError if the checkpoint id is not registered.
        """
        if checkpoint_id not in self._wins_per_checkpoint:
            raise IndexError(f"Checkpoint id '{checkpoint_id}' not found!")
        wins = self._wins_per_checkpoint[checkpoint_id]
        return np.mean(wins), np.std(wins)

    def get_wr_progress(self) -> dict:
        """Prints and returns {checkpoint_id: {"wr": mean, "std": std}}."""
        ret = {}
        for i in self._wins_per_checkpoint:
            wr, std = self.get_checkpoint_wr(i)
            ret[i] = {"wr": wr, "std": std}
            print(f"Checkpoint {i}: WR={wr}±{std}")
        return ret

    def get_graph_stats_progress(self):
        """Prints a table of per-checkpoint statistics and returns {checkpoint_id: stats}."""
        ret = {}
        print("Checkpoint,\tWR,\tEdges,\tSimpleEdges,\tNodes,\tLoops,\tSimpleLoops")
        for i in self._wins_per_checkpoint:
            data = self.get_checkpoint_stats(i)
            ret[i] = data
            print(f'{i},\t{data["winrate"]},\t{data["num_edges"]},\t{data["num_simplified_edges"]},\t{data["num_nodes"]},\t{data["num_loops"]},\t{data["num_simplified_loops"]}')
        return ret

    def plot_graph_stats_progress(self, filedir="figures", filename="trajectory_graph_stats.png"):
        """
        Plots the per-checkpoint graph statistics (log-scaled y axis) and saves
        the figure to os.path.join(filedir, filename).

        The plot is drawn on its own figure which is closed afterwards, so
        repeated calls do not accumulate curves. `filedir` is created if it
        does not exist yet.
        """
        data = self.get_graph_stats_progress()
        checkpoints = sorted(data)
        series = (
            ("num_nodes", 'Number of nodes'),
            ("num_edges", 'Number of edges'),
            ("num_simplified_edges", 'Number of simplified edges'),
            ("num_loops", 'Number of loops'),
            ("num_simplified_loops", 'Number of simplified loops'),
        )
        if filedir:
            os.makedirs(filedir, exist_ok=True)
        fig = plt.figure()
        try:
            for key, label in series:
                plt.plot(checkpoints, [data[c][key] for c in checkpoints], label=label)
            plt.title("Graph statistics per checkpoint")
            plt.yscale('log')
            plt.xlabel("Checkpoints")
            # Show legend
            plt.legend()
            # Save the figure as an image file
            plt.savefig(os.path.join(filedir, filename))
        finally:
            plt.close(fig)

    def get_checkpoint_stats(self, checkpoint_id: int) -> dict:
        """
        Returns win rate, edge, loop and node counts of one checkpoint.

        A loop is an edge whose source and destination state are the same.
        Nodes are the states that appear as an endpoint of at least one edge.
        Raises IndexError if the checkpoint id is not registered.
        """
        winrate, winrate_std = self.get_checkpoint_wr(checkpoint_id)
        edges = self._checkpoint_edges[checkpoint_id]
        simple_edges = self._checkpoint_simple_edges[checkpoint_id]
        node_set = {src for src, _, _ in edges} | {dst for _, dst, _ in edges}
        return {
            "winrate": winrate,
            "winrate_std": winrate_std,
            "num_edges": len(edges),
            "num_simplified_edges": len(simple_edges),
            "num_loops": sum(1 for edge in edges if edge[0] == edge[1]),
            "num_simplified_loops": sum(1 for edge in simple_edges if edge[0] == edge[1]),
            "num_nodes": len(node_set),
        }

    def get_graph_structure_progress(self) -> dict:
        """
        Returns {edge: array}, one array slot per checkpoint, holding 1 where
        the edge is present in that checkpoint and 0 elsewhere.
        """
        all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values()))
        super_graph = {key: np.zeros(self.num_checkpoints) for key in all_edges}
        for i, edge_list in self._checkpoint_edges.items():
            for edge in edge_list:
                super_graph[edge][i] = 1
        return super_graph

    def get_graph_structure_probabilistic_progress(self) -> dict:
        """
        Returns {edge: array}, one array slot per checkpoint, holding the use
        count of the edge divided by the total use count of all edges leaving
        the same source state in that checkpoint (0 where the edge is absent).
        """
        all_edges = set().union(*(inner_dict.keys() for inner_dict in self._checkpoint_edges.values()))
        super_graph = {key: np.zeros(self.num_checkpoints) for key in all_edges}
        for i, edge_list in self._checkpoint_edges.items():
            total_out_edges_use = {}
            for (src, _, _), frequency in edge_list.items():
                total_out_edges_use[src] = total_out_edges_use.get(src, 0) + frequency
            for (src, dst, edge), value in edge_list.items():
                super_graph[(src, dst, edge)][i] = value / total_out_edges_use[src]
        return super_graph
12170

13171
def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple:
14172
edges = {}
@@ -94,33 +252,58 @@ def get_graph_modificiation(edge_list1, edge_list2):
94252

95253

96254
parser = argparse.ArgumentParser()
97-
parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True)
98-
parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True)
255+
# parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True)
256+
# parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True)
99257
parser.add_argument("--end_reason", help="Filter options for trajectories", default=None, type=str, action='store', required=False)
100-
parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=1000, required=False)
258+
# type=int so a value given on the command line is parsed like the integer default
# (without it argparse would hand a str to read_json's max_lines).
parser.add_argument("--n_trajectories", help="Limit of how many trajectories to use", action='store', default=10000, type=int, required=False)
101259

102260
args = parser.parse_args()
103-
trajectories1 = read_json(args.t1, max_lines=args.n_trajectories)
104-
trajectories2 = read_json(args.t2, max_lines=args.n_trajectories)
105-
states = {}
106-
actions = {}
261+
# trajectories1 = read_json(args.t1, max_lines=args.n_trajectories)
262+
# trajectories2 = read_json(args.t2, max_lines=args.n_trajectories)
263+
# states = {}
264+
# actions = {}
107265

108-
graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason)
109-
graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason)
266+
# graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason)
267+
# graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason)
110268

111-
state_to_id = {v:k for k,v in states.items()}
112-
action_to_id = {v:k for k,v in states.items()}
113-
114-
print(f"Trajectory 1: {args.t1}")
115-
print(f"WR={t1_wr_mean}±{t1_wr_std}")
116-
get_graph_stats(graph_t1, state_to_id, action_to_id)
117-
print(f"Trajectory 2: {args.t2}")
118-
print(f"WR={t2_wr_mean}±{t2_wr_std}")
119-
get_graph_stats(graph_t2, state_to_id, action_to_id)
120-
121-
a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2)
122-
print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}")
123-
# print("positions of same states:")
124-
# for node in node_set(graph_t1).intersection(node_set(graph_t2)):
125-
# print(g1_timestaps[node], g2_timestaps[node])
126-
# print("-----------------------")
269+
# state_to_id = {v:k for k,v in states.items()}
270+
# action_to_id = {v:k for k,v in states.items()}
271+
272+
# print(f"Trajectory 1: {args.t1}")
273+
# print(f"WR={t1_wr_mean}±{t1_wr_std}")
274+
# get_graph_stats(graph_t1, state_to_id, action_to_id)
275+
# print(f"Trajectory 2: {args.t2}")
276+
# print(f"WR={t2_wr_mean}±{t2_wr_std}")
277+
# get_graph_stats(graph_t2, state_to_id, action_to_id)
278+
279+
# a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2)
280+
# print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}")
281+
# # print("positions of same states:")
282+
# # for node in node_set(graph_t1).intersection(node_set(graph_t2)):
283+
# # print(g1_timestaps[node], g2_timestaps[node])
284+
# # print("-----------------------")
285+
# tg_no_blocks = TrajectoryGraph()
286+
287+
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-2000.jsonl",max_lines=args.n_trajectories))
288+
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-4000.jsonl",max_lines=args.n_trajectories))
289+
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-6000.jsonl",max_lines=args.n_trajectories))
290+
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-8000.jsonl",max_lines=args.n_trajectories))
291+
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories))
292+
# # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories))
293+
294+
# tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_no_blocks.jsonl",max_lines=args.n_trajectories))
295+
# tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_no_blocks.jsonl",max_lines=args.n_trajectories))
296+
# tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_no_blocks.jsonl",max_lines=args.n_trajectories))
297+
# tg_no_blocks.plot_graph_stats_progress()
298+
299+
# One checkpoint per saved "blocks" snapshot, in order of training episodes.
tg_blocks = TrajectoryGraph()
for episodes in (5000, 10000, 15000, 20000):
    checkpoint_file = f"./trajectories/2024-11-12_QAgent_Attacker-episodes-{episodes}_blocks.jsonl"
    tg_blocks.add_checkpoint(read_json(checkpoint_file, max_lines=args.n_trajectories))
tg_blocks.plot_graph_stats_progress()

# Count all edges seen in any checkpoint, then those used in every checkpoint
# (an edge absent from a checkpoint has value 0 there).
super_graph = tg_blocks.get_graph_structure_probabilistic_progress()
print(len(super_graph))
edges_present_everycheckpoint = [
    edge for edge, usage in super_graph.items() if np.min(usage) > 0
]
print(len(edges_present_everycheckpoint))

0 commit comments

Comments
 (0)