class Attention(nn.Module):
+    """
+    Multi-robot attention mechanism for learning hard and soft attention.
+
+    This module provides both hard (binary) and soft (weighted) attention,
+    combining feature encoding, relative pose and goal geometry, and
+    message passing between agents.
+
+    Args:
+        embedding_dim (int): Dimension of the agent embedding vector.
+
+    Attributes:
+        embedding1 (nn.Linear): First layer for agent feature encoding.
+        embedding2 (nn.Linear): Second layer for agent feature encoding.
+        hard_mlp (nn.Sequential): MLP to process concatenated agent and edge features.
+        hard_encoding (nn.Linear): Outputs logits for hard (binary) attention.
+        q, k, v (nn.Linear): Layers for query, key, and value projections for soft attention.
+        attn_score_layer (nn.Sequential): Computes unnormalized attention scores for each pair.
+        decode_1, decode_2 (nn.Linear): Decoding layers to produce the final attended embedding.
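+
+    Example (illustrative sketch; assumes an observation layout with
+    D >= 11 features per agent, as described in forward()):
+
+        attn = Attention(embedding_dim=64)
+        obs = torch.randn(2, 4, 11)  # (B, N, D)
+        att_emb, hard_logits, rel_dist, mean_entropy, hard_w, comb_w = attn(obs)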
+    """
+
    def __init__(self, embedding_dim):
+        """
+        Initialize the attention mechanism for multi-agent communication.
+
+        Args:
+            embedding_dim (int): Output embedding dimension per agent.
+        """
        super(Attention, self).__init__()
        self.embedding_dim = embedding_dim

@@ -41,30 +67,59 @@ def __init__(self, embedding_dim):
        nn.init.kaiming_uniform_(self.decode_2.weight, nonlinearity="leaky_relu")
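+        # Kaiming init with nonlinearity="leaky_relu" sets the gain for the
+        # leaky-ReLU activations used in this module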

    def encode_agent_features(self, embed):
+        """
+        Encode agent features using a small MLP.
+
+        Args:
+            embed (Tensor): Input features (B*N, 5).
+
+        Returns:
+            Tensor: Encoded embedding (B*N, embedding_dim).
+        """
        embed = F.leaky_relu(self.embedding1(embed))
        embed = F.leaky_relu(self.embedding2(embed))
        return embed

    def forward(self, embedding):
+        """
+        Forward pass: computes both hard and soft attention among agents and
+        produces the attended embedding for each agent, along with diagnostic info.
+
+        Args:
+            embedding (Tensor): Input tensor of shape (B, N, D), where D is at least 11.
+
+        Returns:
+            tuple:
+                att_embedding (Tensor): Final attended embedding, shape (B*N, 2*embedding_dim).
+                hard_logits (Tensor): Logits for hard attention, (B*N, N-1).
+                unnorm_rel_dist (Tensor): Unnormalized pairwise distances between agents, (B*N, N-1, 1).
+                mean_entropy (Tensor): Mean entropy of the soft attention distributions.
+                hard_weights (Tensor): Binary hard attention mask, (B, N, N-1).
+                comb_w (Tensor): Final combined attention weights, (N, N*(N-1)).
+        """
        if embedding.dim() == 2:
            embedding = embedding.unsqueeze(0)
        batch_size, n_agents, _ = embedding.shape

+        # Extract sub-features
        embed = embedding[:, :, 4:9].reshape(batch_size * n_agents, -1)
        position = embedding[:, :, :2].reshape(batch_size, n_agents, 2)
        heading = embedding[:, :, 2:4].reshape(
            batch_size, n_agents, 2
        )  # assume (cos(θ), sin(θ))
        action = embedding[:, :, 7:9].reshape(batch_size, n_agents, 2)
        goal = embedding[:, :, -2:].reshape(batch_size, n_agents, 2)
+
+        # Compute pairwise relative goal vectors (for each i, j)
        goal_j = goal.unsqueeze(1).expand(-1, n_agents, -1, -1)
        pos_i = position.unsqueeze(2)
        goal_rel_vec = goal_j - pos_i
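+        # goal_rel_vec[b, i, j]: vector from agent i's position to agent j's goal, shape (B, N, N, 2)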

+        # Encode agent features
        agent_embed = self.encode_agent_features(embed)
        agent_embed = agent_embed.view(batch_size, n_agents, self.embedding_dim)

-        # For hard attention
+        # Prep for hard attention: compute all relative geometry for each agent pair
        h_i = agent_embed.unsqueeze(2)  # (B, N, 1, D)
        pos_i = position.unsqueeze(2)  # (B, N, 1, 2)
        pos_j = position.unsqueeze(1)  # (B, 1, N, 2)
@@ -88,7 +143,7 @@ def forward(self, embedding):
        heading_j_cos = heading_j[..., 0]  # (B, 1, N)
        heading_j_sin = heading_j[..., 1]  # (B, 1, N)

-        # Stack edge features
+        # Edge features for hard attention
        edge_features = torch.cat(
            [
                rel_dist,  # (B, N, N, 1)
@@ -101,7 +156,7 @@ def forward(self, embedding):
            dim=-1,
        )

-        # Broadcast h_i along N (for each pair)
+        # Broadcast agent embedding for all pairs (self-pairs removed below)
        h_i_expanded = h_i.expand(-1, -1, n_agents, -1)
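+        # h_i_expanded: (B, N, N, D)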

        # Remove self-pairs using mask
@@ -129,13 +184,14 @@ def forward(self, embedding):
            batch_size * n_agents, n_agents - 1, 1
        )

-        # Soft attention
+        # ---- Soft attention computation ----
        q = self.q(agent_embed)

        attention_outputs = []
        entropy_list = []
        combined_w = []

+        # Goal-relative polar features for soft attention
        goal_rel_dist = torch.linalg.vector_norm(goal_rel_vec, dim=-1, keepdim=True)
        goal_angle_global = torch.atan2(goal_rel_vec[..., 1], goal_rel_vec[..., 0])
        heading_angle = torch.atan2(heading_i[..., 1], heading_i[..., 0])
@@ -147,6 +203,7 @@ def forward(self, embedding):
            [goal_rel_dist, goal_rel_angle_cos, goal_rel_angle_sin], dim=-1
        )

+        # Soft attention edge features (include goal polar)
        soft_edge_features = torch.cat([edge_features, goal_polar], dim=-1)
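+        # Per-agent loop below: score each of the N-1 candidate neighbours with an
+        # MLP over [q_i, k_j], then gate the soft weights with the hard-attention mask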
        for i in range(n_agents):
            q_i = q[:, i : i + 1, :]
@@ -159,10 +216,10 @@ def forward(self, embedding):
            q_i_expanded = q_i.expand(-1, n_agents - 1, -1)
            attention_input = torch.cat([q_i_expanded, k], dim=-1)

-            # Score computation
+            # Score computation (per pair)
            scores = self.attn_score_layer(attention_input).transpose(1, 2)

-            # Mask using hard weights
+            # Mask using hard attention
            h_weights = hard_weights[:, i].unsqueeze(1)
            mask = (h_weights > 0.5).float()
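+            # Thresholding at 0.5 makes the mask binary, so agents dropped by
+            # hard attention receive zero combined weight below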

@@ -183,12 +240,12 @@ def forward(self, embedding):
            combined_weights = soft_weights * mask  # (B, 1, N-1)
            combined_w.append(combined_weights)

-            # Normalize combined_weights for entropy calculation
+            # Normalize for entropy calculation
            combined_weights_norm = combined_weights / (
                combined_weights.sum(dim=-1, keepdim=True) + epsilon
            )

-            # Calculate entropy from combined_weights
+            # Entropy for analysis/logging
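+            # H_i = -sum_j w_ij * log(w_ij + epsilon), over the normalized combined weights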
            entropy = (
                -(combined_weights_norm * (combined_weights_norm + epsilon).log())
                .sum(dim=-1)