44import os
55import utils
66import argparse
7+ import matplotlib .pyplot as plt
78
89sys .path .append (os .path .dirname (os .path .dirname (os .path .abspath (__file__ ) )))
910from env .game_components import GameState , Action
1011
class TrajectoryGraph:
    """
    Collects agent trajectories checkpoint by checkpoint and derives
    transition-graph statistics from them.

    Every distinct state and action is mapped to an integer id that is shared
    by all checkpoints. For each checkpoint two edge counters are kept:
    full edges ``(src_state_id, dst_state_id, action_id)`` and simplified
    edges ``(src_state_id, dst_state_id)``.
    """

    def __init__(self) -> None:
        # checkpoint id -> trajectory list exactly as passed to add_checkpoint
        self._checkpoints = {}
        # checkpoint id -> {(src_state_id, dst_state_id, action_id): count}
        self._checkpoint_edges = {}
        # checkpoint id -> {(src_state_id, dst_state_id): count}
        self._checkpoint_simple_edges = {}
        # checkpoint id -> np.array with 1 (goal reached) or 0 per counted play
        self._wins_per_checkpoint = {}
        # two-way lookup tables between states/actions and their integer ids
        self._state_to_id = {}
        self._id_to_state = {}
        self._action_to_id = {}
        self._id_to_action = {}

    @property
    def num_checkpoints(self) -> int:
        """Number of checkpoints added so far."""
        return len(self._checkpoints)

    def get_state_id(self, state: GameState) -> int:
        """
        Returns state id or creates new one if the state was not registered before.

        States are keyed by ``utils.state_as_ordered_string(state)``.
        """
        state_str = utils.state_as_ordered_string(state)
        if state_str not in self._state_to_id:
            new_id = len(self._state_to_id)
            self._state_to_id[state_str] = new_id
            self._id_to_state[new_id] = state
        return self._state_to_id[state_str]

    def get_state(self, id: int) -> GameState:
        """Returns the state registered under `id` (KeyError if unknown)."""
        return self._id_to_state[id]

    def get_action_id(self, action: Action) -> int:
        """
        Returns action id or creates new one if the action was not registered before.

        The action itself is used as dictionary key, so it must be hashable.
        """
        if action not in self._action_to_id:
            new_id = len(self._action_to_id)
            self._action_to_id[action] = new_id
            self._id_to_action[new_id] = action
        return self._action_to_id[action]

    def get_action(self, id: int) -> Action:
        """Returns the action registered under `id` (KeyError if unknown)."""
        return self._id_to_action[id]

    def add_checkpoint(self, trajectories: list, end_reason=None) -> None:
        """
        Registers a complete trajectory list as the next checkpoint.

        Each play is a dict with an "end_reason" entry and a "trajectory"
        entry holding "states" and "actions" lists. Plays whose end reason is
        not contained in `end_reason` (when a filter is given) and plays
        without actions are ignored. A play counts as a win when its end
        reason equals "goal_reached".
        """
        # Fix the id once so all four tables are guaranteed to share one key.
        checkpoint_id = self.num_checkpoints
        wins = []
        edges = {}
        simple_edges = {}
        for play in trajectories:
            if end_reason and play["end_reason"] not in end_reason:
                continue
            actions = play["trajectory"]["actions"]
            if len(actions) == 0:
                continue
            wins.append(1 if play["end_reason"] == "goal_reached" else 0)
            states = play["trajectory"]["states"]
            state_id = self.get_state_id(GameState.from_dict(states[0]))
            # NOTE(review): the transition states[i-1] -> states[i] is labelled
            # with actions[i]; whether that is the action which caused the
            # transition depends on how the trajectory file aligns the two
            # lists - confirm against the code that writes the trajectories.
            for i in range(1, len(actions)):
                next_state_id = self.get_state_id(GameState.from_dict(states[i]))
                action_id = self.get_action_id(Action.from_dict(actions[i]))
                # full graph: parallel edges are told apart by the action
                full_edge = (state_id, next_state_id, action_id)
                edges[full_edge] = edges.get(full_edge, 0) + 1
                # simplified graph: the action is ignored
                simple_edge = (state_id, next_state_id)
                simple_edges[simple_edge] = simple_edges.get(simple_edge, 0) + 1
                state_id = next_state_id
        self._checkpoint_simple_edges[checkpoint_id] = simple_edges
        self._checkpoint_edges[checkpoint_id] = edges
        self._wins_per_checkpoint[checkpoint_id] = np.array(wins)
        self._checkpoints[checkpoint_id] = trajectories

    def get_checkpoint_wr(self, checkpoint_id: int) -> tuple:
        """
        Returns (mean, standard deviation) of the win indicator of a checkpoint.

        Raises IndexError if the checkpoint id is unknown.
        """
        if checkpoint_id not in self._wins_per_checkpoint:
            raise IndexError(f"Checkpoint id '{checkpoint_id} ' not found!")
        wins = self._wins_per_checkpoint[checkpoint_id]
        return np.mean(wins), np.std(wins)

    def get_wr_progress(self) -> dict:
        """Prints and returns {checkpoint id: {"wr": mean, "std": std}}."""
        ret = {}
        for i in self._wins_per_checkpoint:
            wr, std = self.get_checkpoint_wr(i)
            ret[i] = {"wr": wr, "std": std}
            print(f"Checkpoint {i} : WR={wr} ±{std} ")
        return ret

    def get_graph_stats_progress(self):
        """
        Prints a table with the statistics of every checkpoint and returns
        {checkpoint id: statistics dict as produced by get_checkpoint_stats}.
        """
        ret = {}
        print("Checkpoint,\t WR,\t Edges,\t SimpleEdges,\t Nodes,\t Loops,\t SimpleLoops")
        for i in self._wins_per_checkpoint:
            data = self.get_checkpoint_stats(i)
            ret[i] = data
            print(f'{i} ,\t {data["winrate"]} ,\t {data["num_edges"]} ,\t {data["num_simplified_edges"]} ,\t {data["num_nodes"]} ,\t {data["num_loops"]} ,\t {data["num_simplified_loops"]} ')
        return ret

    def plot_graph_stats_progress(self, filedir="figures", filename="trajectory_graph_stats.png"):
        """
        Plots the graph statistics of all checkpoints (log-scaled y axis) and
        saves the image to `filedir`/`filename`.

        The plot is drawn on its own figure which is closed afterwards, so
        repeated calls do not draw on top of each other. `filedir` is created
        when it does not exist yet.
        """
        data = self.get_graph_stats_progress()
        checkpoints = sorted(data)
        series = (
            ("num_nodes", 'Number of nodes'),
            ("num_edges", 'Number of edges'),
            ("num_simplified_edges", 'Number of simplified edges'),
            ("num_loops", 'Number of loops'),
            ("num_simplified_loops", 'Number of simplified loops'),
        )
        if filedir:
            os.makedirs(filedir, exist_ok=True)
        fig, ax = plt.subplots()
        try:
            for key, label in series:
                ax.plot(checkpoints, [data[i][key] for i in checkpoints], label=label)
            ax.set_title("Graph statistics per checkpoint")
            ax.set_yscale('log')
            ax.set_xlabel("Checkpoints")
            # Show legend
            ax.legend()
            # Save the figure as an image file
            fig.savefig(os.path.join(filedir, filename))
        finally:
            # Release the figure even if plotting or saving failed
            plt.close(fig)

    def get_checkpoint_stats(self, checkpoint_id: int) -> dict:
        """
        Returns the statistics of one checkpoint: "winrate", "winrate_std",
        "num_edges", "num_simplified_edges", "num_loops",
        "num_simplified_loops" and "num_nodes".

        A loop is an edge whose source and destination state are the same.
        Nodes are the states that appear as an endpoint of at least one edge.
        Raises IndexError if the checkpoint id is unknown.
        """
        if checkpoint_id not in self._wins_per_checkpoint:
            raise IndexError(f"Checkpoint id '{checkpoint_id} ' not found!")
        wins = self._wins_per_checkpoint[checkpoint_id]
        edges = self._checkpoint_edges[checkpoint_id]
        simple_edges = self._checkpoint_simple_edges[checkpoint_id]
        node_set = {src for src, _, _ in edges} | {dst for _, dst, _ in edges}
        return {
            "winrate": np.mean(wins),
            "winrate_std": np.std(wins),
            "num_edges": len(edges),
            "num_simplified_edges": len(simple_edges),
            "num_loops": sum(1 for edge in edges if edge[0] == edge[1]),
            "num_simplified_loops": sum(1 for edge in simple_edges if edge[0] == edge[1]),
            "num_nodes": len(node_set),
        }

    def get_graph_structure_progress(self) -> dict:
        """
        Returns {edge: array of length num_checkpoints} over the union of the
        full edges of all checkpoints; entry i is 1 when the edge occurs in
        checkpoint i and 0 otherwise.
        """
        all_edges = set().union(*(edge_counts.keys() for edge_counts in self._checkpoint_edges.values()))
        super_graph = {edge: np.zeros(self.num_checkpoints) for edge in all_edges}
        for i, edge_counts in self._checkpoint_edges.items():
            for edge in edge_counts:
                super_graph[edge][i] = 1
        return super_graph

    def get_graph_structure_probabilistic_progress(self) -> dict:
        """
        Returns {edge: array of length num_checkpoints} over the union of the
        full edges of all checkpoints; entry i is the count of the edge in
        checkpoint i divided by the summed count of all edges leaving the same
        source state in that checkpoint (0 when the edge does not occur).
        """
        all_edges = set().union(*(edge_counts.keys() for edge_counts in self._checkpoint_edges.values()))
        super_graph = {edge: np.zeros(self.num_checkpoints) for edge in all_edges}
        for i, edge_counts in self._checkpoint_edges.items():
            # total usage of the outgoing edges of every source state
            total_out_edges_use = {}
            for (src, _, _), frequency in edge_counts.items():
                total_out_edges_use[src] = total_out_edges_use.get(src, 0) + frequency
            for edge, frequency in edge_counts.items():
                super_graph[edge][i] = frequency / total_out_edges_use[edge[0]]
        return super_graph
12170
13171def gameplay_graph (game_plays :list , states , actions , end_reason = None )-> tuple :
14172 edges = {}
@@ -94,33 +252,58 @@ def get_graph_modificiation(edge_list1, edge_list2):
94252
95253
96254 parser = argparse .ArgumentParser ()
97- parser .add_argument ("--t1" , help = "Trajectory file #1" , action = 'store' , required = True )
98- parser .add_argument ("--t2" , help = "Trajectory file #2" , action = 'store' , required = True )
255+ # parser.add_argument("--t1", help="Trajectory file #1", action='store', required=True)
256+ # parser.add_argument("--t2", help="Trajectory file #2", action='store', required=True)
99257 parser .add_argument ("--end_reason" , help = "Filter options for trajectories" , default = None , type = str , action = 'store' , required = False )
100- parser .add_argument ("--n_trajectories" , help = "Limit of how many trajectories to use" , action = 'store' , default = 1000 , required = False )
258+ parser .add_argument ("--n_trajectories" , help = "Limit of how many trajectories to use" , action = 'store' , default = 10000 , required = False )
101259
102260 args = parser .parse_args ()
103- trajectories1 = read_json (args .t1 , max_lines = args .n_trajectories )
104- trajectories2 = read_json (args .t2 , max_lines = args .n_trajectories )
105- states = {}
106- actions = {}
261+ # trajectories1 = read_json(args.t1, max_lines=args.n_trajectories)
262+ # trajectories2 = read_json(args.t2, max_lines=args.n_trajectories)
263+ # states = {}
264+ # actions = {}
107265
108- graph_t1 , g1_timestaps , t1_wr_mean , t1_wr_std = gameplay_graph (trajectories1 , states , actions ,end_reason = args .end_reason )
109- graph_t2 , g2_timestaps , t2_wr_mean , t2_wr_std = gameplay_graph (trajectories2 , states , actions ,end_reason = args .end_reason )
266+ # graph_t1, g1_timestaps, t1_wr_mean, t1_wr_std = gameplay_graph(trajectories1, states, actions,end_reason=args.end_reason)
267+ # graph_t2, g2_timestaps, t2_wr_mean, t2_wr_std = gameplay_graph(trajectories2, states, actions,end_reason=args.end_reason)
110268
111- state_to_id = {v :k for k ,v in states .items ()}
112- action_to_id = {v :k for k ,v in states .items ()}
113-
114- print (f"Trajectory 1: { args .t1 } " )
115- print (f"WR={ t1_wr_mean } ±{ t1_wr_std } " )
116- get_graph_stats (graph_t1 , state_to_id , action_to_id )
117- print (f"Trajectory 2: { args .t2 } " )
118- print (f"WR={ t2_wr_mean } ±{ t2_wr_std } " )
119- get_graph_stats (graph_t2 , state_to_id , action_to_id )
120-
121- a_edges , d_edges , a_nodes , d_nodes = get_graph_modificiation (graph_t1 , graph_t2 )
122- print (f"AE:{ len (a_edges )} ,DE:{ len (d_edges )} , AN:{ len (a_nodes )} ,DN:{ len (d_nodes )} " )
123- # print("positions of same states:")
124- # for node in node_set(graph_t1).intersection(node_set(graph_t2)):
125- # print(g1_timestaps[node], g2_timestaps[node])
126- # print("-----------------------")
269+ # state_to_id = {v:k for k,v in states.items()}
270+ # action_to_id = {v:k for k,v in states.items()}
271+
272+ # print(f"Trajectory 1: {args.t1}")
273+ # print(f"WR={t1_wr_mean}±{t1_wr_std}")
274+ # get_graph_stats(graph_t1, state_to_id, action_to_id)
275+ # print(f"Trajectory 2: {args.t2}")
276+ # print(f"WR={t2_wr_mean}±{t2_wr_std}")
277+ # get_graph_stats(graph_t2, state_to_id, action_to_id)
278+
279+ # a_edges, d_edges, a_nodes, d_nodes = get_graph_modificiation(graph_t1, graph_t2)
280+ # print(f"AE:{len(a_edges)},DE:{len(d_edges)}, AN:{len(a_nodes)},DN:{len(d_nodes)}")
281+ # # print("positions of same states:")
282+ # # for node in node_set(graph_t1).intersection(node_set(graph_t2)):
283+ # # print(g1_timestaps[node], g2_timestaps[node])
284+ # # print("-----------------------")
285+ # tg_no_blocks = TrajectoryGraph()
286+
287+ # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-2000.jsonl",max_lines=args.n_trajectories))
288+ # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-4000.jsonl",max_lines=args.n_trajectories))
289+ # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-6000.jsonl",max_lines=args.n_trajectories))
290+ # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-8000.jsonl",max_lines=args.n_trajectories))
291+ # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-10000.jsonl",max_lines=args.n_trajectories))
292+ # # tg.add_checkpoint(read_json("./trajectories/experiment0002/2024-08-02_QAgent_Attacker_experiment0002-episodes-12000.jsonl",max_lines=args.n_trajectories))
293+
294+ # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_no_blocks.jsonl",max_lines=args.n_trajectories))
295+ # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_no_blocks.jsonl",max_lines=args.n_trajectories))
296+ # tg_no_blocks.add_checkpoint(read_json("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_no_blocks.jsonl",max_lines=args.n_trajectories))
297+ # tg_no_blocks.plot_graph_stats_progress()
298+
299+ tg_blocks = TrajectoryGraph ()
300+ tg_blocks .add_checkpoint (read_json ("./trajectories/2024-11-12_QAgent_Attacker-episodes-5000_blocks.jsonl" ,max_lines = args .n_trajectories ))
301+ tg_blocks .add_checkpoint (read_json ("./trajectories/2024-11-12_QAgent_Attacker-episodes-10000_blocks.jsonl" ,max_lines = args .n_trajectories ))
302+ tg_blocks .add_checkpoint (read_json ("./trajectories/2024-11-12_QAgent_Attacker-episodes-15000_blocks.jsonl" ,max_lines = args .n_trajectories ))
303+ tg_blocks .add_checkpoint (read_json ("./trajectories/2024-11-12_QAgent_Attacker-episodes-20000_blocks.jsonl" ,max_lines = args .n_trajectories ))
304+ tg_blocks .plot_graph_stats_progress ()
305+
306+ super_graph = tg_blocks .get_graph_structure_probabilistic_progress ()
307+ print (len (super_graph ))
308+ edges_present_everycheckpoint = [k for k ,v in super_graph .items () if np .min (v ) > 0 ]
309+ print (len (edges_present_everycheckpoint ))
0 commit comments