@@ -32,12 +32,12 @@ def __init__(self, embedding_dim):
 
         # Soft attention projections
         self.q = nn.Linear(embedding_dim, embedding_dim, bias=False)
-        self.k = nn.Linear(embedding_dim, embedding_dim, bias=False)
-        self.v = nn.Linear(embedding_dim, embedding_dim)
+        self.k = nn.Linear(7, embedding_dim, bias=False)
+        self.v = nn.Linear(7, embedding_dim)
 
         # Soft attention score network (with distance)
         self.attn_score_layer = nn.Sequential(
-            nn.Linear(embedding_dim + 7, embedding_dim),
+            nn.Linear(embedding_dim * 2, embedding_dim),
             nn.ReLU(),
             nn.Linear(embedding_dim, 1)
         )
@@ -132,9 +132,10 @@ def forward(self, embedding):
             mask[i] = False
             edge_i_wo_self = edge_features[:, i, mask, :]
             edge_i_wo_self = edge_i_wo_self.squeeze(1)  # (B, N-1, 7)
+            k = F.leaky_relu(self.k(edge_i_wo_self))  # (B, N-1, D)
 
             q_i_expanded = q_i.expand(-1, n_agents - 1, -1)  # (B, N-1, D)
-            attention_input = torch.cat([q_i_expanded, edge_i_wo_self], dim=-1)  # (B, N-1, D+7)
+            attention_input = torch.cat([q_i_expanded, k], dim=-1)  # (B, N-1, 2D)
 
             # Score computation
             scores = self.attn_score_layer(attention_input).transpose(1, 2)  # (B, 1, N-1)
@@ -166,7 +167,8 @@ def forward(self, embedding):
             entropy_list.append(entropy)
 
             # Project each other agent's features to embedding dim *before* the attention-weighted sum
-            v_j = self.v_proj(edge_i_wo_self)  # (B, N-1, embedding_dim)
+            # v_j = self.v_proj(edge_i_wo_self)  # (B, N-1, embedding_dim)
+            v_j = F.leaky_relu(self.v(edge_i_wo_self))  # (B, N-1, embedding_dim)
             attn_output = torch.bmm(combined_weights, v_j).squeeze(1)  # (B, embedding_dim)
             attention_outputs.append(attn_output)
 
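Taken together, these hunks change the soft-attention path so that the 7-dimensional edge features are first projected into the embedding space by `self.k` and `self.v` (each followed by a LeakyReLU), and the score MLP then compares a query and key of equal size, hence `nn.Linear(embedding_dim * 2, ...)` instead of `nn.Linear(embedding_dim + 7, ...)`. The sketch below is a minimal, self-contained reconstruction of that per-agent attention step under stated assumptions: the module name `SoftEdgeAttention`, the `EDGE_DIM` constant, the `forward` signature, and the plain softmax standing in for the repo's `combined_weights` (which also folds in hard-attention and entropy terms not shown in this diff) are illustrative, not the file's actual code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

EDGE_DIM = 7  # per-pair edge feature size, as used in this commit


class SoftEdgeAttention(nn.Module):
    """Illustrative sketch of the attention path changed in this commit."""

    def __init__(self, embedding_dim):
        super().__init__()
        # Queries come from agent embeddings; keys/values from 7-dim edge features.
        self.q = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.k = nn.Linear(EDGE_DIM, embedding_dim, bias=False)
        self.v = nn.Linear(EDGE_DIM, embedding_dim)
        # Score network consumes [query, key] of size 2 * embedding_dim.
        self.attn_score_layer = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, 1),
        )

    def forward(self, embedding, edge_features):
        # embedding:     (B, N, D)    agent embeddings
        # edge_features: (B, N, N, 7) pairwise edge features (assumed layout)
        B, n_agents, _ = embedding.shape
        outputs = []
        for i in range(n_agents):
            # Drop agent i's self-edge before attending over the others.
            mask = torch.ones(n_agents, dtype=torch.bool, device=embedding.device)
            mask[i] = False
            edge_i_wo_self = edge_features[:, i, mask, :]            # (B, N-1, 7)

            q_i = self.q(embedding[:, i:i + 1, :])                   # (B, 1, D)
            k = F.leaky_relu(self.k(edge_i_wo_self))                 # (B, N-1, D)
            v_j = F.leaky_relu(self.v(edge_i_wo_self))               # (B, N-1, D)

            q_i_expanded = q_i.expand(-1, n_agents - 1, -1)          # (B, N-1, D)
            attention_input = torch.cat([q_i_expanded, k], dim=-1)   # (B, N-1, 2D)
            scores = self.attn_score_layer(attention_input).transpose(1, 2)  # (B, 1, N-1)

            # Plain softmax weights stand in for the repo's `combined_weights`.
            weights = F.softmax(scores, dim=-1)                      # (B, 1, N-1)
            outputs.append(torch.bmm(weights, v_j).squeeze(1))       # (B, D)
        return torch.stack(outputs, dim=1)                           # (B, N, D)
```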