
Commit 564294f

second refactor
1 parent d290d93 commit 564294f

File tree: 8 files changed (+80, -90 lines)

robot_nav/SIM_ENV/marl_sim.py

Lines changed: 44 additions & 38 deletions

@@ -8,10 +8,10 @@
 
 class MARL_SIM(SIM_ENV):
     """
-    A simulation environment interface for robot navigation using IRSim.
+    A simulation environment interface for robot navigation using IRSim in MARL setting.
 
     This class wraps around the IRSim environment and provides methods for stepping,
-    resetting, and interacting with a mobile robot, including reward computation.
+    resetting, and interacting with mobile robots, including reward computation.
 
     Attributes:
         env (object): The simulation environment instance from IRSim.
@@ -33,6 +33,8 @@ def __init__(self, world_file="multi_robot_world.yaml", disable_plotting=False):
         robot_info = self.env.get_robot_info(0)
         self.robot_goal = robot_info.goal
         self.num_robots = len(self.env.robot_list)
+        self.x_range = self.env._world.x_range
+        self.y_range = self.env._world.y_range
 
     def step(self, action, connection, combined_weights=None):
         """
@@ -46,7 +48,6 @@ def step(self, action, connection, combined_weights=None):
             (tuple): Contains the latest LIDAR scan, distance to goal, cosine and sine of angle to goal,
                 collision flag, goal reached flag, applied action, and computed reward.
         """
-        # action = [[lin_velocity, ang_velocity], [lin_velocity, ang_velocity], [lin_velocity, ang_velocity], [lin_velocity, ang_velocity], [lin_velocity, ang_velocity]]
         self.env.step(action_id=[i for i in range(self.num_robots)], action=action)
         self.env.render()
 
@@ -139,8 +140,8 @@ def step(self, action, connection, combined_weights=None):
                 obstacle_list=self.env.obstacle_list,
                 init=True,
                 range_limits=[
-                    [1, 1, -3.141592653589793],
-                    [11, 11, 3.141592653589793],
+                    [self.x_range[0] + 1, self.y_range[0] + 1, -3.141592653589793],
+                    [self.x_range[1] - 1, self.y_range[1] - 1, 3.141592653589793],
                 ],
             )
 
@@ -209,8 +210,8 @@ def reset(
             if random_obstacle_ids is None:
                 random_obstacle_ids = [i + self.num_robots for i in range(7)]
             self.env.random_obstacle_position(
-                range_low=[0, 0, -3.14],
-                range_high=[12, 12, 3.14],
+                range_low=[self.x_range[0], self.y_range[0], -3.14],
+                range_high=[self.x_range[1], self.y_range[1], 3.14],
                 ids=random_obstacle_ids,
                 non_overlapping=True,
             )
@@ -221,8 +222,8 @@ def reset(
                 obstacle_list=self.env.obstacle_list,
                 init=True,
                 range_limits=[
-                    [1, 1, -3.141592653589793],
-                    [11, 11, 3.141592653589793],
+                    [self.x_range[0] + 1, self.y_range[0] + 1, -3.141592653589793],
+                    [self.x_range[1] - 1, self.y_range[1] - 1, 3.141592653589793],
                 ],
             )
         else:
@@ -251,44 +252,49 @@ def reset(
             )
 
     @staticmethod
-    def get_reward(goal, collision, action, closest_robots, distance):
+    def get_reward(goal, collision, action, closest_robots, distance, phase=1):
         """
         Calculate the reward for the current step.
 
         Args:
             goal (bool): Whether the goal has been reached.
             collision (bool): Whether a collision occurred.
             action (list): The action taken [linear velocity, angular velocity].
-            laser_scan (list): The LIDAR scan readings.
+            closest_robots (list): Distances to the closest robots.
+            distance (float): Distance to goal.
+            phase (int, optional): Reward function phase. Defaults to 1.
 
         Returns:
             (float): Computed reward for the current state.
         """
 
-        # phase1
-        if goal:
-            return 100.0
-        elif collision:
-            return -100.0 * 3 * action[0]
-        else:
-            r_dist = 1.5 / distance
-            cl_pen = 0
-            for rob in closest_robots:
-                add = 1.5 - rob if rob < 1.5 else 0
-                cl_pen += add
-
-            return action[0] - 0.5 * abs(action[1]) - cl_pen + r_dist
-
-        # phase2
-        # if goal:
-        #     return 70.0
-        # elif collision:
-        #     return -100.0 * 3 * action[0]
-        # else:
-        #     r_dist = 1.5 / distance
-        #     cl_pen = 0
-        #     for rob in closest_robots:
-        #         add = (3 - rob)**2 if rob < 3 else 0
-        #         cl_pen += add
-        #
-        #     return -0.5 * abs(action[1]) - cl_pen
+        match phase:
+            case 1:
+                if goal:
+                    return 100.0
+                elif collision:
+                    return -100.0 * 3 * action[0]
+                else:
+                    r_dist = 1.5 / distance
+                    cl_pen = 0
+                    for rob in closest_robots:
+                        add = 1.5 - rob if rob < 1.5 else 0
+                        cl_pen += add
+
+                    return action[0] - 0.5 * abs(action[1]) - cl_pen + r_dist
+
+            case 2:
+                if goal:
+                    return 70.0
+                elif collision:
+                    return -100.0 * 3 * action[0]
+                else:
+                    cl_pen = 0
+                    for rob in closest_robots:
+                        add = (3 - rob) ** 2 if rob < 3 else 0
+                        cl_pen += add
+
+                    return -0.5 * abs(action[1]) - cl_pen
+
+            case _:
+                raise ValueError("Unknown reward phase")
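
For orientation, here is a small worked example of the two reward phases above, run as plain Python with made-up inputs (the numbers are arbitrary; only the arithmetic mirrors the diff):

# Hypothetical inputs chosen only to illustrate the two phases of get_reward
action = [0.5, 0.3]          # [linear velocity, angular velocity]
closest_robots = [1.0, 2.0]  # distances to the two nearest robots
distance = 2.0               # distance to the goal

# Phase 1: progress-oriented shaping
r_dist = 1.5 / distance                                       # 0.75
cl_pen = sum(1.5 - r for r in closest_robots if r < 1.5)      # 0.5
phase1 = action[0] - 0.5 * abs(action[1]) - cl_pen + r_dist   # 0.5 - 0.15 - 0.5 + 0.75 = 0.6

# Phase 2: stronger separation penalty, no progress term
cl_pen2 = sum((3 - r) ** 2 for r in closest_robots if r < 3)  # 4 + 1 = 5
phase2 = -0.5 * abs(action[1]) - cl_pen2                      # -0.15 - 5 = -5.15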

robot_nav/models/CNNTD3/CNNTD3.py

Lines changed: 2 additions & 2 deletions

@@ -325,8 +325,8 @@ def train(
         state = torch.Tensor(batch_states).to(self.device)
         next_state = torch.Tensor(batch_next_states).to(self.device)
         action = torch.Tensor(batch_actions).to(self.device)
-        reward = torch.Tensor(batch_rewards).to(self.device)
-        done = torch.Tensor(batch_dones).to(self.device)
+        reward = torch.Tensor(batch_rewards).to(self.device).reshape(-1, 1)
+        done = torch.Tensor(batch_dones).to(self.device).reshape(-1, 1)
 
         # Obtain the estimated action from the next state by using the actor-target
         next_action = self.actor_target(next_state)
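
Reshaping reward and done to (batch, 1) presumably keeps the Bellman-target arithmetic at column-vector shape instead of silently broadcasting a 1-D tensor against the (batch, 1) critic output; the same two-line change recurs in DDPG, SAC, and TD3 below. A minimal, shape-only sketch of the pitfall the reshape avoids (not the repository's training code):

import torch

batch = 4
reward = torch.rand(batch)        # (batch,), as it comes out of the replay buffer
target_q = torch.rand(batch, 1)   # critic output, (batch, 1)

wrong = reward + 0.99 * target_q                  # broadcasts to (batch, batch)
right = reward.reshape(-1, 1) + 0.99 * target_q   # stays (batch, 1)
print(wrong.shape, right.shape)   # torch.Size([4, 4]) torch.Size([4, 1])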

robot_nav/models/DDPG/DDPG.py

Lines changed: 2 additions & 2 deletions

@@ -254,8 +254,8 @@ def train(
         state = torch.Tensor(batch_states).to(self.device)
         next_state = torch.Tensor(batch_next_states).to(self.device)
         action = torch.Tensor(batch_actions).to(self.device)
-        reward = torch.Tensor(batch_rewards).to(self.device)
-        done = torch.Tensor(batch_dones).to(self.device)
+        reward = torch.Tensor(batch_rewards).to(self.device).reshape(-1, 1)
+        done = torch.Tensor(batch_dones).to(self.device).reshape(-1, 1)
 
         # Obtain the estimated action from the next state by using the actor-target
         next_action = self.actor_target(next_state)

robot_nav/models/MARL/hsAttention.py

Lines changed: 14 additions & 21 deletions

@@ -9,7 +9,6 @@ def __init__(self, embedding_dim):
         super(Attention, self).__init__()
         self.embedding_dim = embedding_dim
 
-        # CNN for laser scan
         self.embedding1 = nn.Linear(5, 128)
         nn.init.kaiming_uniform_(self.embedding1.weight, nonlinearity="leaky_relu")
         self.embedding2 = nn.Linear(128, embedding_dim)
@@ -28,7 +27,7 @@ def __init__(self, embedding_dim):
         self.k = nn.Linear(10, embedding_dim, bias=False)
         self.v = nn.Linear(10, embedding_dim)
 
-        # Soft attention score network (with distance)
+        # Soft attention score network (with polar other robot goal position)
         self.attn_score_layer = nn.Sequential(
             nn.Linear(embedding_dim * 2, embedding_dim),
             nn.ReLU(),
@@ -58,8 +57,8 @@ def forward(self, embedding):
         )  # assume (cos(θ), sin(θ))
         action = embedding[:, :, 7:9].reshape(batch_size, n_agents, 2)
         goal = embedding[:, :, -2:].reshape(batch_size, n_agents, 2)
-        goal_j = goal.unsqueeze(1).expand(-1, n_agents, -1, -1)  # (B, N, N, 2)
-        pos_i = position.unsqueeze(2)  # (B, N, 1, 2)
+        goal_j = goal.unsqueeze(1).expand(-1, n_agents, -1, -1)
+        pos_i = position.unsqueeze(2)
         goal_rel_vec = goal_j - pos_i
 
         agent_embed = self.encode_agent_features(embed)
@@ -100,10 +99,10 @@ def forward(self, embedding):
                 action.unsqueeze(1).expand(-1, n_agents, -1, -1),  # (B, N, N, 2)
             ],
             dim=-1,
-        )  # (B, N, N, 7)
+        )
 
         # Broadcast h_i along N (for each pair)
-        h_i_expanded = h_i.expand(-1, -1, n_agents, -1)  # (B, N, N, D)
+        h_i_expanded = h_i.expand(-1, -1, n_agents, -1)
 
         # Remove self-pairs using mask
         mask = ~torch.eye(n_agents, dtype=torch.bool, device=embedding.device)
@@ -115,7 +114,7 @@ def forward(self, embedding):
         )
 
         # Concatenate agent embedding and edge features
-        hard_input = torch.cat([h_i_flat, edge_flat], dim=-1)  # (B*N, N-1, D+7)
+        hard_input = torch.cat([h_i_flat, edge_flat], dim=-1)
 
         # Hard attention forward
         h_hard = self.hard_mlp(hard_input)
@@ -125,8 +124,7 @@ def forward(self, embedding):
         ].unsqueeze(2)
         hard_weights = hard_weights.view(batch_size, n_agents, n_agents - 1)
 
-        unnorm_rel_vec = rel_vec
-        unnorm_rel_dist = torch.linalg.vector_norm(unnorm_rel_vec, dim=-1, keepdim=True)
+        unnorm_rel_dist = torch.linalg.vector_norm(rel_vec, dim=-1, keepdim=True)
         unnorm_rel_dist = unnorm_rel_dist[:, mask].reshape(
             batch_size * n_agents, n_agents - 1, 1
         )
@@ -151,23 +149,21 @@ def forward(self, embedding):
 
         soft_edge_features = torch.cat([edge_features, goal_polar], dim=-1)
         for i in range(n_agents):
-            q_i = q[:, i : i + 1, :]  # (B, 1, D)
+            q_i = q[:, i : i + 1, :]
             mask = torch.ones(n_agents, dtype=torch.bool, device=edge_features.device)
             mask[i] = False
             edge_i_wo_self = soft_edge_features[:, i, mask, :]
-            edge_i_wo_self = edge_i_wo_self.squeeze(1)  # (B, N-1, 7)
+            edge_i_wo_self = edge_i_wo_self.squeeze(1)
             k = F.leaky_relu(self.k(edge_i_wo_self))
 
-            q_i_expanded = q_i.expand(-1, n_agents - 1, -1)  # (B, N-1, D)
-            attention_input = torch.cat([q_i_expanded, k], dim=-1)  # (B, N-1, D+7)
+            q_i_expanded = q_i.expand(-1, n_agents - 1, -1)
+            attention_input = torch.cat([q_i_expanded, k], dim=-1)
 
             # Score computation
-            scores = self.attn_score_layer(attention_input).transpose(
-                1, 2
-            )  # (B, 1, N-1)
+            scores = self.attn_score_layer(attention_input).transpose(1, 2)
 
             # Mask using hard weights
-            h_weights = hard_weights[:, i].unsqueeze(1)  # (B, 1, N-1)
+            h_weights = hard_weights[:, i].unsqueeze(1)
             mask = (h_weights > 0.5).float()
 
             # All-zero mask handling
@@ -200,11 +196,8 @@ def forward(self, embedding):
             )
             entropy_list.append(entropy)
 
-            # Project each other agent's features to embedding dim *before* the attention-weighted sum
             v_j = F.leaky_relu(self.v(edge_i_wo_self))
-            attn_output = torch.bmm(combined_weights, v_j).squeeze(
-                1
-            )  # (B, embedding_dim)
+            attn_output = torch.bmm(combined_weights, v_j).squeeze(1)
             attention_outputs.append(attn_output)
 
         comb_w = torch.stack(combined_w, dim=1).reshape(n_agents, -1)
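
The loop above gates a soft attention score with a binary hard-attention decision (h_weights > 0.5) before the attention-weighted sum over the other agents' value vectors v_j. A toy sketch of that general pattern with random stand-in tensors (not the module's exact computation, which also tracks an entropy term and has explicit all-zero-mask handling):

import torch

B, N_minus_1, D = 2, 4, 8                            # batch, other agents, embedding dim
scores = torch.randn(B, 1, N_minus_1)                # soft attention scores for one agent
hard = (torch.rand(B, 1, N_minus_1) > 0.5).float()   # stand-in for the hard gate
v_j = torch.randn(B, N_minus_1, D)                   # other agents' value vectors

masked = scores.masked_fill(hard == 0, float("-inf"))
weights = torch.softmax(masked, dim=-1)              # renormalise over surviving neighbours
weights = torch.nan_to_num(weights)                  # rows where every neighbour was gated out
attn_output = torch.bmm(weights, v_j).squeeze(1)     # (B, D), mirroring the final bmm above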

robot_nav/models/MARL/marlTD3.py

Lines changed: 12 additions & 16 deletions

@@ -215,6 +215,9 @@ def train(
         policy_noise=0.2,
         noise_clip=0.5,
         policy_freq=2,
+        bce_weight=0.1,
+        entropy_weight=1,
+        connection_proximity_threshold=4,
     ):
         av_Q = 0
         max_Q = -inf
@@ -298,19 +301,16 @@ def train(
                 current_Q2, target_Q
             )
 
-            proximity_threshold = 4  # You may need to adjust this
-            targets = (unnorm_rel_dist.flatten() < proximity_threshold).float()
+            targets = (
+                unnorm_rel_dist.flatten() < connection_proximity_threshold
+            ).float()
             flat_logits = hard_logits.flatten()
             bce_loss = F.binary_cross_entropy_with_logits(flat_logits, targets)
 
-            bce_weight = 0.1
            av_critic_bce_loss.append(bce_loss)
 
-            critic_entropy_weight = 1  # or tuneable
             total_loss = (
-                critic_loss
-                - critic_entropy_weight * mean_entropy
-                + bce_weight * bce_loss
+                critic_loss - entropy_weight * mean_entropy + bce_weight * bce_loss
             )
             av_critic_entropy.append(mean_entropy)
 
@@ -328,20 +328,18 @@ def train(
                 action, hard_logits, unnorm_rel_dist, mean_entropy, hard_weights, _ = (
                     self.actor(state, detach_attn=False)
                 )
-                targets = (unnorm_rel_dist.flatten() < proximity_threshold).float()
+                targets = (
+                    unnorm_rel_dist.flatten() < connection_proximity_threshold
+                ).float()
                 flat_logits = hard_logits.flatten()
                 bce_loss = F.binary_cross_entropy_with_logits(flat_logits, targets)
 
-                bce_weight = 0.1
                 av_actor_bce_loss.append(bce_loss)
 
                 actor_Q, _, _, _, _, _ = self.critic(state, action)
                 actor_loss = -actor_Q.mean()
-                actor_entropy_weight = 0.05
                 total_loss = (
-                    actor_loss
-                    - actor_entropy_weight * mean_entropy
-                    + bce_weight * bce_loss
+                    actor_loss - entropy_weight * mean_entropy + bce_weight * bce_loss
                 )
                 av_actor_entropy.append(mean_entropy)
 
@@ -458,9 +456,7 @@ def prepare_state(
             poses (list): Each agent's global pose [x, y, theta].
             distance, cos, sin: Unused, can be removed or ignored.
             collision (list): Collision flags per agent.
-            goal (list): Goal reached flags per agent.
             action (list): Last action taken [lin_vel, ang_vel].
-            positions (list): Extra features (e.g., neighbors).
             goal_positions (list): Each agent's goal [x, y].
 
         Returns:
@@ -483,7 +479,7 @@ def prepare_state(
             heading_sin = np.sin(theta)
 
             # Last velocity
-            lin_vel = act[0] * 2  # Assuming original range [-0.5, 0.5]
+            lin_vel = act[0] * 2  # Assuming original range [0, 0.5]
             ang_vel = (act[1] + 1) / 2  # Assuming original range [-1, 1]
 
             # Final state vector
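
The hard-coded proximity_threshold, bce_weight, and entropy weights are promoted to train() keyword arguments, and the critic and actor now compose their auxiliary losses identically. A condensed sketch of that composition with stand-in tensors (not the repository's training loop):

import torch
import torch.nn.functional as F

bce_weight, entropy_weight, connection_proximity_threshold = 0.1, 1.0, 4  # new defaults

# Stand-ins for quantities produced by the critic/actor forward passes
critic_loss = torch.tensor(1.25)
mean_entropy = torch.tensor(0.40)     # entropy of the soft attention weights
hard_logits = torch.randn(6)          # hard-attention logits, flattened
unnorm_rel_dist = torch.rand(6) * 8   # unnormalised inter-robot distances

# Supervise the hard gate: "connected" when robots are closer than the threshold
targets = (unnorm_rel_dist.flatten() < connection_proximity_threshold).float()
bce_loss = F.binary_cross_entropy_with_logits(hard_logits.flatten(), targets)

total_loss = critic_loss - entropy_weight * mean_entropy + bce_weight * bce_loss

Note that the actor previously used its own hard-coded actor_entropy_weight = 0.05; after this change both losses share the single entropy_weight argument, so the actor's entropy term is weighted more heavily at the default value.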

robot_nav/models/SAC/SAC.py

Lines changed: 2 additions & 2 deletions

@@ -346,8 +346,8 @@ def update(self, replay_buffer, step, batch_size):
         state = torch.Tensor(batch_states).to(self.device)
         next_state = torch.Tensor(batch_next_states).to(self.device)
         action = torch.Tensor(batch_actions).to(self.device)
-        reward = torch.Tensor(batch_rewards).to(self.device)
-        done = torch.Tensor(batch_dones).to(self.device)
+        reward = torch.Tensor(batch_rewards).to(self.device).reshape(-1, 1)
+        done = torch.Tensor(batch_dones).to(self.device).reshape(-1, 1)
         self.train_metrics_dict["train/batch_reward_av"].append(
             batch_rewards.mean().item()
         )

robot_nav/models/TD3/TD3.py

Lines changed: 2 additions & 2 deletions

@@ -268,8 +268,8 @@ def train(
         state = torch.Tensor(batch_states).to(self.device)
         next_state = torch.Tensor(batch_next_states).to(self.device)
         action = torch.Tensor(batch_actions).to(self.device)
-        reward = torch.Tensor(batch_rewards).to(self.device)
-        done = torch.Tensor(batch_dones).to(self.device)
+        reward = torch.Tensor(batch_rewards).to(self.device).reshape(-1, 1)
+        done = torch.Tensor(batch_dones).to(self.device).reshape(-1, 1)
 
         # Obtain the estimated action from the next state by using the actor-target
         next_action = self.actor_target(next_state)

robot_nav/multi_robot_world.yaml

Lines changed: 2 additions & 7 deletions

@@ -1,8 +1,8 @@
 world:
   height: 12  # the height of the world
   width: 12  # the height of the world
-  step_time: 0.3  # 10Hz calculate each step
-  sample_time: 0.3  # 10 Hz for render and data extraction
+  step_time: 0.3  # Calculate each step
+  sample_time: 0.3  # For render and data extraction
   collision_mode: 'reactive'
 
 robot:
@@ -20,8 +20,3 @@ robot:
 
 plot:
   show_trajectory: False
-
-#obstacle:
-#  - shape: { name: 'linestring', vertices: [ [ 0, 0 ], [ 12, 0 ], [ 12, 12 ], [ 0, 12 ],[ 0, 0 ] ] } # vertices
-#    kinematics: {name: 'static'}
-#    state: [ 0, 0, 0 ]
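
With step_time: 0.3 the simulator advances 1 / 0.3 ≈ 3.3 steps per second, so the old "10Hz" comments did not match the configured value, which is presumably why they were reworded; the commented-out boundary obstacle is simply dropped.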
