
Commit 9a54277

working phase with polar other goal coordinates
1 parent 8741354 commit 9a54277

File tree: 3 files changed, +44 -44 lines changed

robot_nav/models/CNNTD3/att.py

Lines changed: 19 additions & 19 deletions
@@ -32,8 +32,8 @@ def __init__(self, embedding_dim):

        # Soft attention projections
        self.q = nn.Linear(embedding_dim, embedding_dim, bias=False)
-        self.k = nn.Linear(9, embedding_dim, bias=False)
-        self.v = nn.Linear(9, embedding_dim)
+        self.k = nn.Linear(10, embedding_dim, bias=False)
+        self.v = nn.Linear(10, embedding_dim)

        # Soft attention score network (with distance)
        self.attn_score_layer = nn.Sequential(
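Why the width changes from 9 to 10: k and v consume the per-edge soft-attention input, which further down in forward is the concatenation of the edge features with a goal encoding. This commit swaps the 2-D Cartesian rel_goal for a 3-D polar encoding (distance, cos, sin), so the input gains one dimension.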
@@ -65,7 +65,7 @@ def forward(self, embedding):
        goal = embedding[:, :, -2:].reshape(batch_size, n_agents, 2)
        goal_j = goal.unsqueeze(1).expand(-1, n_agents, -1, -1)  # (B, N, N, 2)
        pos_i = position.unsqueeze(2)  # (B, N, 1, 2)
-        rel_goal = goal_j - pos_i
+        goal_rel_vec = goal_j - pos_i

        agent_embed = self.encode_agent_features(embed)
        agent_embed = agent_embed.view(batch_size, n_agents, self.embedding_dim)
@@ -80,7 +80,7 @@ def forward(self, embedding):
        # Compute relative vectors and distance
        rel_vec = pos_j - pos_i  # (B, N, N, 2)
        dx, dy = rel_vec[..., 0], rel_vec[..., 1]
-        rel_dist = torch.linalg.vector_norm(rel_vec, dim=-1, keepdim=True)  # (B, N, N, 1)
+        rel_dist = torch.linalg.vector_norm(rel_vec, dim=-1, keepdim=True) / 12  # (B, N, N, 1)

        # Relative angle in agent i's frame
        angle = torch.atan2(dy, dx) - torch.atan2(heading_i[..., 1], heading_i[..., 0])
@@ -119,7 +119,7 @@ def forward(self, embedding):
        hard_weights = F.gumbel_softmax(hard_logits, hard=False, tau=0.5, dim=-1)[..., 1].unsqueeze(2)
        hard_weights = hard_weights.view(batch_size, n_agents, n_agents - 1)

-        unnorm_rel_vec = rel_vec * 12
+        unnorm_rel_vec = rel_vec
        unnorm_rel_dist = torch.linalg.vector_norm(unnorm_rel_vec, dim=-1, keepdim=True)
        unnorm_rel_dist = unnorm_rel_dist[:, mask].reshape(batch_size * n_agents, n_agents - 1, 1)
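Read together with the previous hunk, the normalization appears to move into the attention module: rel_dist is now divided by 12 at the point of use, and unnorm_rel_vec no longer needs the * 12 rescaling because rel_vec is already in world units once prepare_state (below) stops dividing positions by 12.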

@@ -129,7 +129,19 @@ def forward(self, embedding):
        attention_outputs = []
        entropy_list = []
        combined_w = []
-        soft_edge_features = torch.cat([edge_features, rel_goal], dim=-1)
+
+        goal_rel_dist = torch.linalg.vector_norm(goal_rel_vec, dim=-1, keepdim=True)
+        goal_angle_global = torch.atan2(goal_rel_vec[..., 1], goal_rel_vec[..., 0])
+        heading_angle = torch.atan2(heading_i[..., 1], heading_i[..., 0])
+        goal_rel_angle = goal_angle_global - heading_angle
+        goal_rel_angle = (goal_rel_angle + np.pi) % (2 * np.pi) - np.pi
+        goal_rel_angle_cos = torch.cos(goal_rel_angle).unsqueeze(-1)
+        goal_rel_angle_sin = torch.sin(goal_rel_angle).unsqueeze(-1)
+        goal_polar = torch.cat([goal_rel_dist, goal_rel_angle_cos, goal_rel_angle_sin], dim=-1)
+
+
+
+        soft_edge_features = torch.cat([edge_features, goal_polar], dim=-1)
        for i in range(n_agents):
            q_i = q[:, i:i + 1, :]  # (B, 1, D)
            mask = torch.ones(n_agents, dtype=torch.bool, device=edge_features.device)
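A minimal standalone sketch of the polar conversion these added lines perform, assuming goal_rel_vec has shape (B, N, N, 2) and heading_i has shape (B, N, 1, 2) holding unit heading vectors; polar_goal_encoding is a hypothetical helper name, not a function in the repo:

import torch
import numpy as np

def polar_goal_encoding(goal_rel_vec, heading_i):
    # Distance from agent i to agent j's goal, trailing singleton dim kept
    dist = torch.linalg.vector_norm(goal_rel_vec, dim=-1, keepdim=True)     # (B, N, N, 1)
    # Global bearing of each goal offset, minus agent i's heading angle
    angle_global = torch.atan2(goal_rel_vec[..., 1], goal_rel_vec[..., 0])  # (B, N, N)
    heading = torch.atan2(heading_i[..., 1], heading_i[..., 0])             # (B, N, 1)
    rel_angle = angle_global - heading                                      # broadcasts to (B, N, N)
    # Wrap into [-pi, pi) so the difference stays within a single period
    rel_angle = (rel_angle + np.pi) % (2 * np.pi) - np.pi
    cos = torch.cos(rel_angle).unsqueeze(-1)
    sin = torch.sin(rel_angle).unsqueeze(-1)
    return torch.cat([dist, cos, sin], dim=-1)                              # (B, N, N, 3)

# Example: batch of 1, two agents
offsets = torch.randn(1, 2, 2, 2)
headings = torch.nn.functional.normalize(torch.randn(1, 2, 1, 2), dim=-1)
print(polar_goal_encoding(offsets, headings).shape)  # torch.Size([1, 2, 2, 3])

Encoding the angle as (cos, sin) rather than a raw radian keeps the feature continuous across the ±π wrap, which is presumably why three polar components replace the two Cartesian ones.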
@@ -591,18 +603,6 @@ def prepare_state(self, poses, distance, cos, sin, collision, goal, action, posi
            px, py, theta = pose
            gx, gy = goal_pos

-            # Global position (keep for boundary awareness)
-            x = px / 12
-            y = py / 12
-
-            # Relative goal position in local frame
-            dx = gx - px
-            dy = gy - py
-            rel_gx = dx * np.cos(theta) + dy * np.sin(theta)
-            rel_gy = -dx * np.sin(theta) + dy * np.cos(theta)
-            rel_gx /= 12
-            rel_gy /= 12
-
            # Heading as cos/sin
            heading_cos = np.cos(theta)
            heading_sin = np.sin(theta)

@@ -612,7 +612,7 @@ def prepare_state(self, poses, distance, cos, sin, collision, goal, action, posi
            ang_vel = (act[1] + 1) / 2  # Assuming original range [-1, 1]

            # Final state vector
-            state = [x, y, heading_cos, heading_sin, distance[i]/17, cos[i], sin[i], lin_vel, ang_vel, gx, gy]
+            state = [px, py, heading_cos, heading_sin, distance[i]/17, cos[i], sin[i], lin_vel, ang_vel, gx, gy]

            assert len(state) == self.state_dim, f"State length mismatch: expected {self.state_dim}, got {len(state)}"
            states.append(state)
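The state vector keeps its 11 entries; the only change to the vector itself is raw px, py replacing the /12-normalized x, y. The removed rel_gx/rel_gy block was computed but never written into the state, and the goal-to-local-frame conversion now happens, in polar form, inside the attention module instead.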

robot_nav/multi_train2.py

Lines changed: 2 additions & 2 deletions
@@ -56,8 +56,8 @@ def main(args=None):
        num_robots=sim.num_robots,
        device=device,
        save_every=save_every,
-        load_model=True,
-        model_name="phase2",
+        load_model=False,
+        model_name="phase1",
        load_model_name="phase1"
    )  # instantiate a model
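Training configuration change: rather than loading a saved "phase2" model and continuing from it, training now starts a fresh model named "phase1" (load_model=False), matching the phase-1 reward re-activated in sim2.py below.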

robot_nav/sim2.py

Lines changed: 23 additions & 23 deletions
@@ -249,45 +249,45 @@ def get_reward(goal, collision, action, closest_robots, distance):
        # return 2*action[0] - abs(action[1]) - cl_pen + r_dist

        # phase1
+        if goal:
+            return 100.0
+        elif collision:
+            return -100.0 * 3 * action[0]
+        else:
+            r_dist = 1.5/distance
+            cl_pen = 0
+            for rob in closest_robots:
+                add = 1.5 - rob if rob < 1.5 else 0
+                cl_pen += add
+
+            return action[0] - 0.5 * abs(action[1])-cl_pen + r_dist
+
+
+        # phase2
        # if goal:
        #     return 100.0
        # elif collision:
-        #     return -100.0 * 3 * action[0]
+        #     return -100.0
        # else:
        #     r_dist = 1.5/distance
        #     cl_pen = 0
        #     for rob in closest_robots:
        #         add = 1.5 - rob if rob < 1.5 else 0
        #         cl_pen += add
        #
-        #     return action[0] - 0.5 * abs(action[1])-cl_pen + r_dist
-
+        #     return -0.5*abs(action[1])-cl_pen

-        # phase2
+        # phase3
        # if goal:
-        #     return 100.0
+        #     return 70.0
        # elif collision:
-        #     return -100.0
+        #     return -100.0 * 3 * action[0]
        # else:
-        #     r_dist = 1.5/distance
+        #     r_dist = 1.5 / distance
        #     cl_pen = 0
        #     for rob in closest_robots:
-        #         add = 1.5 - rob if rob < 1.5 else 0
+        #         add = (3 - rob)**2 if rob < 3 else 0
        #         cl_pen += add
        #
-        #     return -0.5*abs(action[1])-cl_pen
-
-        # phase3
-        if goal:
-            return 70.0
-        elif collision:
-            return -100.0 * 3 * action[0]
-        else:
-            r_dist = 1.5 / distance
-            cl_pen = 0
-            for rob in closest_robots:
-                add = (3 - rob)**2 if rob < 3 else 0
-                cl_pen += add
-
-            return -0.5 * abs(action[1]) - cl_pen
+        #     return -0.5 * abs(action[1]) - cl_pen
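For reference, the phase-1 reward this commit re-activates, extracted into a standalone function; a sketch mirroring the diff above, with phase1_reward as a hypothetical name:

def phase1_reward(goal, collision, action, closest_robots, distance):
    # action = (linear velocity, angular velocity); distance = distance to own goal
    if goal:
        return 100.0                    # sparse success bonus
    if collision:
        return -100.0 * 3 * action[0]   # penalty grows with forward speed at impact
    r_dist = 1.5 / distance             # shaping term that grows near the goal
    # linear penalty for every robot closer than 1.5 distance units
    cl_pen = sum(1.5 - rob for rob in closest_robots if rob < 1.5)
    return action[0] - 0.5 * abs(action[1]) - cl_pen + r_dist

The commented-out variants keep the same skeleton: phase 2 drops the progress terms from the return and flattens the collision penalty to -100.0, while phase 3 lowers the goal bonus to 70.0 and switches to a quadratic proximity penalty with a 3-unit radius.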
