@@ -9,7 +9,6 @@ def __init__(self, embedding_dim):
         super(Attention, self).__init__()
         self.embedding_dim = embedding_dim

-        # CNN for laser scan
         self.embedding1 = nn.Linear(5, 128)
         nn.init.kaiming_uniform_(self.embedding1.weight, nonlinearity="leaky_relu")
         self.embedding2 = nn.Linear(128, embedding_dim)
@@ -28,7 +27,7 @@ def __init__(self, embedding_dim):
         self.k = nn.Linear(10, embedding_dim, bias=False)
         self.v = nn.Linear(10, embedding_dim)

-        # Soft attention score network (with distance)
+        # Soft attention score network (with polar other robot goal position)
         self.attn_score_layer = nn.Sequential(
             nn.Linear(embedding_dim * 2, embedding_dim),
             nn.ReLU(),
@@ -58,8 +57,8 @@ def forward(self, embedding):
         )  # assume (cos(θ), sin(θ))
         action = embedding[:, :, 7:9].reshape(batch_size, n_agents, 2)
         goal = embedding[:, :, -2:].reshape(batch_size, n_agents, 2)
-        goal_j = goal.unsqueeze(1).expand(-1, n_agents, -1, -1)  # (B, N, N, 2)
-        pos_i = position.unsqueeze(2)  # (B, N, 1, 2)
+        goal_j = goal.unsqueeze(1).expand(-1, n_agents, -1, -1)
+        pos_i = position.unsqueeze(2)
         goal_rel_vec = goal_j - pos_i

         agent_embed = self.encode_agent_features(embed)
@@ -100,10 +99,10 @@ def forward(self, embedding):
                 action.unsqueeze(1).expand(-1, n_agents, -1, -1),  # (B, N, N, 2)
             ],
             dim=-1,
-        )  # (B, N, N, 7)
+        )

         # Broadcast h_i along N (for each pair)
-        h_i_expanded = h_i.expand(-1, -1, n_agents, -1)  # (B, N, N, D)
+        h_i_expanded = h_i.expand(-1, -1, n_agents, -1)

         # Remove self-pairs using mask
         mask = ~torch.eye(n_agents, dtype=torch.bool, device=embedding.device)
@@ -115,7 +114,7 @@ def forward(self, embedding):
         )

         # Concatenate agent embedding and edge features
-        hard_input = torch.cat([h_i_flat, edge_flat], dim=-1)  # (B*N, N-1, D+7)
+        hard_input = torch.cat([h_i_flat, edge_flat], dim=-1)

         # Hard attention forward
         h_hard = self.hard_mlp(hard_input)
@@ -125,8 +124,7 @@ def forward(self, embedding):
         ].unsqueeze(2)
         hard_weights = hard_weights.view(batch_size, n_agents, n_agents - 1)

-        unnorm_rel_vec = rel_vec
-        unnorm_rel_dist = torch.linalg.vector_norm(unnorm_rel_vec, dim=-1, keepdim=True)
+        unnorm_rel_dist = torch.linalg.vector_norm(rel_vec, dim=-1, keepdim=True)
         unnorm_rel_dist = unnorm_rel_dist[:, mask].reshape(
             batch_size * n_agents, n_agents - 1, 1
         )
@@ -151,23 +149,21 @@ def forward(self, embedding):

         soft_edge_features = torch.cat([edge_features, goal_polar], dim=-1)
         for i in range(n_agents):
-            q_i = q[:, i : i + 1, :]  # (B, 1, D)
+            q_i = q[:, i : i + 1, :]
             mask = torch.ones(n_agents, dtype=torch.bool, device=edge_features.device)
             mask[i] = False
             edge_i_wo_self = soft_edge_features[:, i, mask, :]
-            edge_i_wo_self = edge_i_wo_self.squeeze(1)  # (B, N-1, 7)
+            edge_i_wo_self = edge_i_wo_self.squeeze(1)
             k = F.leaky_relu(self.k(edge_i_wo_self))

-            q_i_expanded = q_i.expand(-1, n_agents - 1, -1)  # (B, N-1, D)
-            attention_input = torch.cat([q_i_expanded, k], dim=-1)  # (B, N-1, D+7)
+            q_i_expanded = q_i.expand(-1, n_agents - 1, -1)
+            attention_input = torch.cat([q_i_expanded, k], dim=-1)

             # Score computation
-            scores = self.attn_score_layer(attention_input).transpose(
-                1, 2
-            )  # (B, 1, N-1)
+            scores = self.attn_score_layer(attention_input).transpose(1, 2)

             # Mask using hard weights
-            h_weights = hard_weights[:, i].unsqueeze(1)  # (B, 1, N-1)
+            h_weights = hard_weights[:, i].unsqueeze(1)
             mask = (h_weights > 0.5).float()

             # All-zero mask handling
@@ -200,11 +196,8 @@ def forward(self, embedding):
             )
             entropy_list.append(entropy)

-            # Project each other agent's features to embedding dim *before* the attention-weighted sum
             v_j = F.leaky_relu(self.v(edge_i_wo_self))
-            attn_output = torch.bmm(combined_weights, v_j).squeeze(
-                1
-            )  # (B, embedding_dim)
+            attn_output = torch.bmm(combined_weights, v_j).squeeze(1)
             attention_outputs.append(attn_output)

         comb_w = torch.stack(combined_w, dim=1).reshape(n_agents, -1)
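
Note: the hunks above concatenate a goal_polar tensor into soft_edge_features, but the conversion from the Cartesian goal_rel_vec = goal_j - pos_i is outside this diff. As a minimal sketch only: assuming goal_polar carries three features per agent pair (distance plus the cosine and sine of the bearing in agent i's heading frame, which would bring the 7 edge features up to the 10 inputs that self.k and self.v expect), one plausible helper is shown below. to_goal_polar and heading_cos_sin are illustrative names, not identifiers from this commit.

    import torch

    def to_goal_polar(goal_rel_vec, heading_cos_sin):
        # goal_rel_vec: (B, N, N, 2), offset from agent i's position to agent j's goal
        # heading_cos_sin: (B, N, 2), the (cos(θ), sin(θ)) heading slice noted in forward()
        dist = torch.linalg.vector_norm(goal_rel_vec, dim=-1, keepdim=True)      # (B, N, N, 1)
        world_angle = torch.atan2(goal_rel_vec[..., 1], goal_rel_vec[..., 0])    # (B, N, N)
        heading = torch.atan2(heading_cos_sin[..., 1], heading_cos_sin[..., 0])  # (B, N)
        bearing = world_angle - heading.unsqueeze(-1)                            # angle in agent i's frame
        # (distance, cos(bearing), sin(bearing)) per pair -> (B, N, N, 3)
        return torch.cat(
            [dist, torch.cos(bearing).unsqueeze(-1), torch.sin(bearing).unsqueeze(-1)],
            dim=-1,
        )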