@@ -32,8 +32,8 @@ def __init__(self, embedding_dim):
32
32
33
33
# Soft attention projections
34
34
self .q = nn .Linear (embedding_dim , embedding_dim , bias = False )
35
- self .k = nn .Linear (7 , embedding_dim , bias = False )
36
- self .v = nn .Linear (7 , embedding_dim )
35
+ self .k = nn .Linear (9 , embedding_dim , bias = False )
36
+ self .v = nn .Linear (9 , embedding_dim )
37
37
38
38
# Soft attention score network (with distance)
39
39
self .attn_score_layer = nn .Sequential (
@@ -42,7 +42,6 @@ def __init__(self, embedding_dim):
42
42
nn .Linear (embedding_dim , 1 )
43
43
)
44
44
45
- self .v_proj = nn .Linear (7 , embedding_dim )
46
45
# Decoder
47
46
self .decode_1 = nn .Linear (embedding_dim * 2 , embedding_dim * 2 )
48
47
nn .init .kaiming_uniform_ (self .decode_1 .weight , nonlinearity = "leaky_relu" )
@@ -59,10 +58,14 @@ def forward(self, embedding):
59
58
embedding = embedding .unsqueeze (0 )
60
59
batch_size , n_agents , _ = embedding .shape
61
60
62
- embed = embedding [:, :, 4 :].reshape (batch_size * n_agents , - 1 )
61
+ embed = embedding [:, :, 4 :9 ].reshape (batch_size * n_agents , - 1 )
63
62
position = embedding [:, :, :2 ].reshape (batch_size , n_agents , 2 )
64
63
heading = embedding [:, :, 2 :4 ].reshape (batch_size , n_agents , 2 ) # assume (cos(θ), sin(θ))
65
- action = embedding [:, :, - 2 :].reshape (batch_size , n_agents , 2 )
64
+ action = embedding [:, :, 7 :9 ].reshape (batch_size , n_agents , 2 )
65
+ goal = embedding [:, :, - 2 :].reshape (batch_size , n_agents , 2 )
66
+ goal_j = goal .unsqueeze (1 ).expand (- 1 , n_agents , - 1 , - 1 ) # (B, N, N, 2)
67
+ pos_i = position .unsqueeze (2 ) # (B, N, 1, 2)
68
+ rel_goal = goal_j - pos_i
66
69
67
70
agent_embed = self .encode_agent_features (embed )
68
71
agent_embed = agent_embed .view (batch_size , n_agents , self .embedding_dim )
@@ -126,11 +129,12 @@ def forward(self, embedding):
126
129
attention_outputs = []
127
130
entropy_list = []
128
131
combined_w = []
132
+ soft_edge_features = torch .cat ([edge_features , rel_goal ], dim = - 1 )
129
133
for i in range (n_agents ):
130
134
q_i = q [:, i :i + 1 , :] # (B, 1, D)
131
135
mask = torch .ones (n_agents , dtype = torch .bool , device = edge_features .device )
132
136
mask [i ] = False
133
- edge_i_wo_self = edge_features [:, i , mask , :]
137
+ edge_i_wo_self = soft_edge_features [:, i , mask , :]
134
138
edge_i_wo_self = edge_i_wo_self .squeeze (1 ) # (B, N-1, 7)
135
139
k = F .leaky_relu (self .k (edge_i_wo_self ))
136
140
@@ -167,7 +171,6 @@ def forward(self, embedding):
167
171
entropy_list .append (entropy )
168
172
169
173
# Project each other agent's features to embedding dim *before* the attention-weighted sum
170
- # v_j = self.v_proj(edge_i_wo_self) # (B, N-1, embedding_dim)
171
174
v_j = F .leaky_relu (self .v (edge_i_wo_self ))
172
175
attn_output = torch .bmm (combined_weights , v_j ).squeeze (1 ) # (B, embedding_dim)
173
176
attention_outputs .append (attn_output )
@@ -346,7 +349,7 @@ def get_action(self, obs, add_noise):
346
349
"""
347
350
action , connection , combined_weights = self .act (obs )
348
351
if add_noise :
349
- noise = np .random .normal (0 , 0.4 , size = action .shape )
352
+ noise = np .random .normal (0 , 0.5 , size = action .shape )
350
353
noise = [n / 4 if i % 2 else n for i , n in enumerate (noise )]
351
354
action = (action + noise
352
355
).clip (- self .max_action , self .max_action )
@@ -609,7 +612,7 @@ def prepare_state(self, poses, distance, cos, sin, collision, goal, action, posi
609
612
ang_vel = (act [1 ] + 1 ) / 2 # Assuming original range [-1, 1]
610
613
611
614
# Final state vector
612
- state = [x , y , heading_cos , heading_sin , distance [i ]/ 17 , cos [i ], sin [i ], lin_vel , ang_vel ]
615
+ state = [x , y , heading_cos , heading_sin , distance [i ]/ 17 , cos [i ], sin [i ], lin_vel , ang_vel , gx , gy ]
613
616
614
617
assert len (state ) == self .state_dim , f"State length mismatch: expected { self .state_dim } , got { len (state )} "
615
618
states .append (state )
0 commit comments