@@ -32,8 +32,8 @@ def __init__(self, embedding_dim):
32
32
33
33
# Soft attention projections
34
34
self .q = nn .Linear (embedding_dim , embedding_dim , bias = False )
35
- self .k = nn .Linear (9 , embedding_dim , bias = False )
36
- self .v = nn .Linear (9 , embedding_dim )
35
+ self .k = nn .Linear (10 , embedding_dim , bias = False )
36
+ self .v = nn .Linear (10 , embedding_dim )
37
37
38
38
# Soft attention score network (with distance)
39
39
self .attn_score_layer = nn .Sequential (
@@ -65,7 +65,7 @@ def forward(self, embedding):
65
65
goal = embedding [:, :, - 2 :].reshape (batch_size , n_agents , 2 )
66
66
goal_j = goal .unsqueeze (1 ).expand (- 1 , n_agents , - 1 , - 1 ) # (B, N, N, 2)
67
67
pos_i = position .unsqueeze (2 ) # (B, N, 1, 2)
68
- rel_goal = goal_j - pos_i
68
+ goal_rel_vec = goal_j - pos_i
69
69
70
70
agent_embed = self .encode_agent_features (embed )
71
71
agent_embed = agent_embed .view (batch_size , n_agents , self .embedding_dim )
@@ -80,7 +80,7 @@ def forward(self, embedding):
80
80
# Compute relative vectors and distance
81
81
rel_vec = pos_j - pos_i # (B, N, N, 2)
82
82
dx , dy = rel_vec [..., 0 ], rel_vec [..., 1 ]
83
- rel_dist = torch .linalg .vector_norm (rel_vec , dim = - 1 , keepdim = True ) # (B, N, N, 1)
83
+ rel_dist = torch .linalg .vector_norm (rel_vec , dim = - 1 , keepdim = True )/ 12 # (B, N, N, 1)
84
84
85
85
# Relative angle in agent i's frame
86
86
angle = torch .atan2 (dy , dx ) - torch .atan2 (heading_i [..., 1 ], heading_i [..., 0 ])
@@ -119,7 +119,7 @@ def forward(self, embedding):
119
119
hard_weights = F .gumbel_softmax (hard_logits , hard = False , tau = 0.5 , dim = - 1 )[..., 1 ].unsqueeze (2 )
120
120
hard_weights = hard_weights .view (batch_size , n_agents , n_agents - 1 )
121
121
122
- unnorm_rel_vec = rel_vec * 12
122
+ unnorm_rel_vec = rel_vec
123
123
unnorm_rel_dist = torch .linalg .vector_norm (unnorm_rel_vec , dim = - 1 , keepdim = True )
124
124
unnorm_rel_dist = unnorm_rel_dist [:, mask ].reshape (batch_size * n_agents , n_agents - 1 , 1 )
125
125
@@ -129,7 +129,19 @@ def forward(self, embedding):
129
129
attention_outputs = []
130
130
entropy_list = []
131
131
combined_w = []
132
- soft_edge_features = torch .cat ([edge_features , rel_goal ], dim = - 1 )
132
+
133
+ goal_rel_dist = torch .linalg .vector_norm (goal_rel_vec , dim = - 1 , keepdim = True )
134
+ goal_angle_global = torch .atan2 (goal_rel_vec [..., 1 ], goal_rel_vec [..., 0 ])
135
+ heading_angle = torch .atan2 (heading_i [..., 1 ], heading_i [..., 0 ])
136
+ goal_rel_angle = goal_angle_global - heading_angle
137
+ goal_rel_angle = (goal_rel_angle + np .pi ) % (2 * np .pi ) - np .pi
138
+ goal_rel_angle_cos = torch .cos (goal_rel_angle ).unsqueeze (- 1 )
139
+ goal_rel_angle_sin = torch .sin (goal_rel_angle ).unsqueeze (- 1 )
140
+ goal_polar = torch .cat ([goal_rel_dist , goal_rel_angle_cos , goal_rel_angle_sin ], dim = - 1 )
141
+
142
+
143
+
144
+ soft_edge_features = torch .cat ([edge_features , goal_polar ], dim = - 1 )
133
145
for i in range (n_agents ):
134
146
q_i = q [:, i :i + 1 , :] # (B, 1, D)
135
147
mask = torch .ones (n_agents , dtype = torch .bool , device = edge_features .device )
@@ -591,18 +603,6 @@ def prepare_state(self, poses, distance, cos, sin, collision, goal, action, posi
591
603
px , py , theta = pose
592
604
gx , gy = goal_pos
593
605
594
- # Global position (keep for boundary awareness)
595
- x = px / 12
596
- y = py / 12
597
-
598
- # Relative goal position in local frame
599
- dx = gx - px
600
- dy = gy - py
601
- rel_gx = dx * np .cos (theta ) + dy * np .sin (theta )
602
- rel_gy = - dx * np .sin (theta ) + dy * np .cos (theta )
603
- rel_gx /= 12
604
- rel_gy /= 12
605
-
606
606
# Heading as cos/sin
607
607
heading_cos = np .cos (theta )
608
608
heading_sin = np .sin (theta )
@@ -612,7 +612,7 @@ def prepare_state(self, poses, distance, cos, sin, collision, goal, action, posi
612
612
ang_vel = (act [1 ] + 1 ) / 2 # Assuming original range [-1, 1]
613
613
614
614
# Final state vector
615
- state = [x , y , heading_cos , heading_sin , distance [i ]/ 17 , cos [i ], sin [i ], lin_vel , ang_vel , gx , gy ]
615
+ state = [px , py , heading_cos , heading_sin , distance [i ]/ 17 , cos [i ], sin [i ], lin_vel , ang_vel , gx , gy ]
616
616
617
617
assert len (state ) == self .state_dim , f"State length mismatch: expected { self .state_dim } , got { len (state )} "
618
618
states .append (state )
0 commit comments