Skip to content

Commit 62ded8e

Browse files
committed
working setup with k and v
1 parent bc5e0ae commit 62ded8e

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

robot_nav/models/CNNTD3/att.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ def __init__(self, embedding_dim):
3232

3333
# Soft attention projections
3434
self.q = nn.Linear(embedding_dim, embedding_dim, bias=False)
35-
self.k = nn.Linear(embedding_dim, embedding_dim, bias=False)
36-
self.v = nn.Linear(embedding_dim, embedding_dim)
35+
self.k = nn.Linear(7, embedding_dim, bias=False)
36+
self.v = nn.Linear(7, embedding_dim)
3737

3838
# Soft attention score network (with distance)
3939
self.attn_score_layer = nn.Sequential(
40-
nn.Linear(embedding_dim + 7, embedding_dim),
40+
nn.Linear(embedding_dim *2, embedding_dim),
4141
nn.ReLU(),
4242
nn.Linear(embedding_dim, 1)
4343
)
@@ -132,9 +132,10 @@ def forward(self, embedding):
132132
mask[i] = False
133133
edge_i_wo_self = edge_features[:, i, mask, :]
134134
edge_i_wo_self = edge_i_wo_self.squeeze(1) # (B, N-1, 7)
135+
k = F.leaky_relu(self.k(edge_i_wo_self))
135136

136137
q_i_expanded = q_i.expand(-1, n_agents - 1, -1) # (B, N-1, D)
137-
attention_input = torch.cat([q_i_expanded, edge_i_wo_self], dim=-1) # (B, N-1, D+7)
138+
attention_input = torch.cat([q_i_expanded, k], dim=-1)  # (B, N-1, 2*D) — k is projected to embedding_dim, matching the attn_score_layer input of embedding_dim * 2
138139

139140
# Score computation
140141
scores = self.attn_score_layer(attention_input).transpose(1, 2) # (B, 1, N-1)
@@ -166,7 +167,8 @@ def forward(self, embedding):
166167
entropy_list.append(entropy)
167168

168169
# Project each other agent's features to embedding dim *before* the attention-weighted sum
169-
v_j = self.v_proj(edge_i_wo_self) # (B, N-1, embedding_dim)
170+
# v_j = self.v_proj(edge_i_wo_self) # (B, N-1, embedding_dim)
171+
v_j = F.leaky_relu(self.v(edge_i_wo_self))
170172
attn_output = torch.bmm(combined_weights, v_j).squeeze(1) # (B, embedding_dim)
171173
attention_outputs.append(attn_output)
172174

0 commit comments

Comments
 (0)