
Commit 8864d0a

update docs
1 parent fe977a2 commit 8864d0a

10 files changed: +259 −65 lines changed

docs/api/IR-SIM/ir-marl-sim.md

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# MARL-IR-SIM
+
+::: robot_nav.SIM_ENV.marl_sim
+    options:
+      show_root_heading: true
+      show_source: true

docs/api/IR-SIM/ir-sim.md

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # IR-SIM
 
-::: robot_nav.sim
+::: robot_nav.SIM_ENV.sim
     options:
       show_root_heading: true
       show_source: true

docs/api/models/MARL/Attention.md

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# Hard-Soft Attention
+
+::: robot_nav.models.MARL.hardsoftAttention
+    options:
+      show_root_heading: true
+      show_source: true

docs/api/models/MARL/TD3.md

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+# MARL TD3
+
+::: robot_nav.models.MARL.marlTD3
+    options:
+      show_root_heading: true
+      show_source: true

mkdocs.yml

Lines changed: 6 additions & 1 deletion
@@ -4,7 +4,9 @@ extra:
 nav:
   - Home: index.md
   - API Reference:
-      IR-SIM: api/IR-SIM/ir-sim
+      IR-SIM:
+        - SIM: api/IR-SIM/ir-sim
+        - MARL SIM: api/IR-SIM/ir-marl-sim
       Models:
         - DDPG: api/models/DDPG.md
         - TD3: api/models/TD3.md
@@ -13,6 +15,9 @@ nav:
         - HCM: api/models/HCM.md
         - PPO: api/models/PPO.md
         - SAC: api/models/SAC.md
+      - MARL:
+        - HardSoft Attention: api/models/MARL/Attention
+        - TD3: api/models/MARL/TD3
       Training:
         - Train: api/Training/train.md
         - Train RNN: api/Training/trainrnn.md

robot_nav/SIM_ENV/marl_sim.py

Lines changed: 51 additions & 27 deletions
@@ -8,23 +8,26 @@
 
 class MARL_SIM(SIM_ENV):
     """
-    A simulation environment interface for robot navigation using IRSim in MARL setting.
+    Simulation environment for multi-agent robot navigation using IRSim.
 
-    This class wraps around the IRSim environment and provides methods for stepping,
-    resetting, and interacting with mobile robots, including reward computation.
+    This class extends the SIM_ENV and provides a wrapper for multi-robot
+    simulation and interaction, supporting reward computation and custom reset logic.
 
     Attributes:
-        env (object): The simulation environment instance from IRSim.
-        robot_goal (np.ndarray): The goal position of the robot.
+        env (object): IRSim simulation environment instance.
+        robot_goal (np.ndarray): Current goal position(s) for the robots.
+        num_robots (int): Number of robots in the environment.
+        x_range (tuple): World x-range.
+        y_range (tuple): World y-range.
     """
 
     def __init__(self, world_file="multi_robot_world.yaml", disable_plotting=False):
         """
-        Initialize the simulation environment.
+        Initialize the MARL_SIM environment.
 
         Args:
-            world_file (str): Path to the world configuration YAML file.
-            disable_plotting (bool): If True, disables rendering and plotting.
+            world_file (str, optional): Path to the world configuration YAML file.
+            disable_plotting (bool, optional): If True, disables IRSim rendering and plotting.
         """
         display = False if disable_plotting else True
         self.env = irsim.make(
@@ -38,15 +41,26 @@ def __init__(self, world_file="multi_robot_world.yaml", disable_plotting=False):
 
     def step(self, action, connection, combined_weights=None):
         """
-        Perform one step in the simulation using the given control commands.
+        Perform a simulation step for all robots using the provided actions and connections.
 
         Args:
-            lin_velocity (float): Linear velocity to apply to the robot.
-            ang_velocity (float): Angular velocity to apply to the robot.
+            action (list): List of actions for each robot [[lin_vel, ang_vel], ...].
+            connection (Tensor): Tensor of shape (num_robots, num_robots-1) containing logits indicating connections between robots.
+            combined_weights (Tensor or None, optional): Optional weights for each connection, shape (num_robots, num_robots-1).
 
         Returns:
-            (tuple): Contains the latest LIDAR scan, distance to goal, cosine and sine of angle to goal,
-            collision flag, goal reached flag, applied action, and computed reward.
+            tuple: (
+                poses (list): List of [x, y, theta] for each robot,
+                distances (list): Distance to goal for each robot,
+                coss (list): Cosine of angle to goal for each robot,
+                sins (list): Sine of angle to goal for each robot,
+                collisions (list): Collision status for each robot,
+                goals (list): Goal reached status for each robot,
+                action (list): Actions applied,
+                rewards (list): Rewards computed,
+                positions (list): Current [x, y] for each robot,
+                goal_positions (list): Goal [x, y] for each robot,
+            )
         """
         self.env.step(action_id=[i for i in range(self.num_robots)], action=action)
         self.env.render()
@@ -166,17 +180,27 @@ def reset(
         random_obstacle_ids=None,
     ):
         """
-        Reset the simulation environment, optionally setting robot and obstacle states.
+        Reset the simulation environment and optionally set robot and obstacle positions.
 
         Args:
-            robot_state (list or None): Initial state of the robot as a list of [x, y, theta, speed].
-            robot_goal (list or None): Goal state for the robot.
-            random_obstacles (bool): Whether to randomly reposition obstacles.
-            random_obstacle_ids (list or None): Specific obstacle IDs to randomize.
+            robot_state (list or None, optional): Initial state for robots as [x, y, theta, speed].
+            robot_goal (list or None, optional): Goal position(s) for the robots.
+            random_obstacles (bool, optional): If True, randomly position obstacles.
+            random_obstacle_ids (list or None, optional): IDs of obstacles to randomize.
 
         Returns:
-            (tuple): Initial observation after reset, including LIDAR scan, distance, cos/sin,
-            and reward-related flags and values.
+            tuple: (
+                poses (list): List of [x, y, theta] for each robot,
+                distances (list): Distance to goal for each robot,
+                coss (list): Cosine of angle to goal for each robot,
+                sins (list): Sine of angle to goal for each robot,
+                collisions (list): All False after reset,
+                goals (list): All False after reset,
+                action (list): Initial action ([[0.0, 0.0], ...]),
+                rewards (list): Rewards for initial state,
+                positions (list): Initial [x, y] for each robot,
+                goal_positions (list): Initial goal [x, y] for each robot,
+            )
         """
         if robot_state is None:
             robot_state = [[random.uniform(3, 9)], [random.uniform(3, 9)], [0]]
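Taken together, the updated `step()` and `reset()` docstrings describe a tuple-based multi-robot interface. Below is a minimal usage sketch assuming exactly the signatures documented in this diff; the world file name, the connection-tensor shape, and the 10-field return tuple are taken from the docstrings above and have not been checked against the rest of the source.

```python
# Minimal usage sketch; shapes and the return-tuple layout are assumptions
# taken from the docstrings in this commit.
import torch
from robot_nav.SIM_ENV.marl_sim import MARL_SIM

sim = MARL_SIM(world_file="multi_robot_world.yaml", disable_plotting=True)
n = sim.num_robots

# reset() and step() return the same 10-field tuple.
(poses, distances, coss, sins, collisions, goals,
 action, rewards, positions, goal_positions) = sim.reset()

for _ in range(100):
    action = [[0.5, 0.0] for _ in range(n)]   # one [lin_vel, ang_vel] pair per robot
    connection = torch.zeros(n, n - 1)        # hard-attention logits, one entry per other robot
    (poses, distances, coss, sins, collisions, goals,
     action, rewards, positions, goal_positions) = sim.step(action, connection)
    if all(goals) or any(collisions):
        break
```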
@@ -254,18 +278,18 @@ def reset(
     @staticmethod
     def get_reward(goal, collision, action, closest_robots, distance, phase=1):
         """
-        Calculate the reward for the current step.
+        Calculate the reward for a robot given the current state and action.
 
         Args:
-            goal (bool): Whether the goal has been reached.
+            goal (bool): Whether the robot reached its goal.
             collision (bool): Whether a collision occurred.
-            action (list): The action taken [linear velocity, angular velocity].
-            closest_robots (list): Distances to the closest robots.
-            distance (float): Distance to goal.
-            phase (int, optional): Reward function phase. Defaults to 1.
+            action (list): [linear_velocity, angular_velocity] applied.
+            closest_robots (list): Distances to the closest other robots.
+            distance (float): Distance to the goal.
+            phase (int, optional): Reward phase/function selector (default: 1).
 
         Returns:
-            (float): Computed reward for the current state.
+            float: Computed reward.
         """
 
         match phase:
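Since `get_reward` is a static method, it can be probed in isolation. A hedged sketch follows; only the signature comes from the docstring above, the argument values are made up, and the actual shaping lives in the phase-dependent `match` block in the source.

```python
# Illustrative call with made-up values; the returned number depends on the
# phase-specific reward shaping, which this diff does not show.
from robot_nav.SIM_ENV.marl_sim import MARL_SIM

r = MARL_SIM.get_reward(
    goal=False,                  # goal not reached this step
    collision=False,             # no collision occurred
    action=[0.5, 0.1],           # [linear_velocity, angular_velocity]
    closest_robots=[1.2, 2.7],   # distances to the other robots
    distance=3.4,                # distance to this robot's goal
    phase=1,
)
print(r)
```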

robot_nav/models/MARL/__init__.py

Whitespace-only changes.

robot_nav/models/MARL/hardsoftAttention.py

Lines changed: 65 additions & 8 deletions
@@ -5,7 +5,33 @@
 
 
 class Attention(nn.Module):
+    """
+    Multi-robot attention mechanism for learning hard and soft attentions.
+
+    This module provides both hard (binary) and soft (weighted) attention,
+    combining feature encoding, relative pose and goal geometry, and
+    message passing between agents.
+
+    Args:
+        embedding_dim (int): Dimension of the agent embedding vector.
+
+    Attributes:
+        embedding1 (nn.Linear): First layer for agent feature encoding.
+        embedding2 (nn.Linear): Second layer for agent feature encoding.
+        hard_mlp (nn.Sequential): MLP to process concatenated agent and edge features.
+        hard_encoding (nn.Linear): Outputs logits for hard (binary) attention.
+        q, k, v (nn.Linear): Layers for query, key, value projections for soft attention.
+        attn_score_layer (nn.Sequential): Computes unnormalized attention scores for each pair.
+        decode_1, decode_2 (nn.Linear): Decoding layers to produce the final attended embedding.
+    """
+
     def __init__(self, embedding_dim):
+        """
+        Initialize attention mechanism for multi-agent communication.
+
+        Args:
+            embedding_dim (int): Output embedding dimension per agent.
+        """
         super(Attention, self).__init__()
         self.embedding_dim = embedding_dim

@@ -41,30 +67,59 @@ def __init__(self, embedding_dim):
         nn.init.kaiming_uniform_(self.decode_2.weight, nonlinearity="leaky_relu")
 
     def encode_agent_features(self, embed):
+        """
+        Encode agent features using a small MLP.
+
+        Args:
+            embed (Tensor): Input features (B*N, 5).
+
+        Returns:
+            Tensor: Encoded embedding (B*N, embedding_dim).
+        """
         embed = F.leaky_relu(self.embedding1(embed))
         embed = F.leaky_relu(self.embedding2(embed))
         return embed
 
     def forward(self, embedding):
+        """
+        Forward pass: computes both hard and soft attentions among agents,
+        produces the attended embedding for each agent, as well as diagnostic info.
+
+        Args:
+            embedding (Tensor): Input tensor of shape (B, N, D), where D is at least 11.
+
+        Returns:
+            tuple:
+                att_embedding (Tensor): Final attended embedding, shape (B*N, 2*embedding_dim).
+                hard_logits (Tensor): Logits for hard attention, (B*N, N-1).
+                unnorm_rel_dist (Tensor): Pairwise distances between agents (not normalized), (B*N, N-1, 1).
+                mean_entropy (Tensor): Mean entropy of soft attention distributions.
+                hard_weights (Tensor): Binary hard attention mask, (B, N, N-1).
+                comb_w (Tensor): Final combined attention weights, (N, N*(N-1)).
+        """
         if embedding.dim() == 2:
             embedding = embedding.unsqueeze(0)
         batch_size, n_agents, _ = embedding.shape
 
+        # Extract sub-features
         embed = embedding[:, :, 4:9].reshape(batch_size * n_agents, -1)
         position = embedding[:, :, :2].reshape(batch_size, n_agents, 2)
         heading = embedding[:, :, 2:4].reshape(
             batch_size, n_agents, 2
         )  # assume (cos(θ), sin(θ))
         action = embedding[:, :, 7:9].reshape(batch_size, n_agents, 2)
         goal = embedding[:, :, -2:].reshape(batch_size, n_agents, 2)
+
+        # Compute pairwise relative goal vectors (for each i,j)
         goal_j = goal.unsqueeze(1).expand(-1, n_agents, -1, -1)
         pos_i = position.unsqueeze(2)
         goal_rel_vec = goal_j - pos_i
 
+        # Encode agent features
         agent_embed = self.encode_agent_features(embed)
         agent_embed = agent_embed.view(batch_size, n_agents, self.embedding_dim)
 
-        # For hard attention
+        # Prep for hard attention: compute all relative geometry for each agent pair
         h_i = agent_embed.unsqueeze(2)  # (B, N, 1, D)
         pos_i = position.unsqueeze(2)  # (B, N, 1, 2)
         pos_j = position.unsqueeze(1)  # (B, 1, N, 2)
@@ -88,7 +143,7 @@ def forward(self, embedding):
         heading_j_cos = heading_j[..., 0]  # (B, 1, N)
         heading_j_sin = heading_j[..., 1]  # (B, 1, N)
 
-        # Stack edge features
+        # Edge features for hard attention
         edge_features = torch.cat(
             [
                 rel_dist,  # (B, N, N, 1)
@@ -101,7 +156,7 @@ def forward(self, embedding):
             dim=-1,
         )
 
-        # Broadcast h_i along N (for each pair)
+        # Broadcast agent embedding for all pairs (except self-pairs)
         h_i_expanded = h_i.expand(-1, -1, n_agents, -1)
 
         # Remove self-pairs using mask
@@ -129,13 +184,14 @@ def forward(self, embedding):
             batch_size * n_agents, n_agents - 1, 1
         )
 
-        # Soft attention
+        # ---- Soft attention computation ----
        q = self.q(agent_embed)
 
        attention_outputs = []
        entropy_list = []
        combined_w = []
 
+        # Goal-relative polar features for soft attention
        goal_rel_dist = torch.linalg.vector_norm(goal_rel_vec, dim=-1, keepdim=True)
        goal_angle_global = torch.atan2(goal_rel_vec[..., 1], goal_rel_vec[..., 0])
        heading_angle = torch.atan2(heading_i[..., 1], heading_i[..., 0])
@@ -147,6 +203,7 @@ def forward(self, embedding):
             [goal_rel_dist, goal_rel_angle_cos, goal_rel_angle_sin], dim=-1
         )
 
+        # Soft attention edge features (include goal polar)
         soft_edge_features = torch.cat([edge_features, goal_polar], dim=-1)
         for i in range(n_agents):
             q_i = q[:, i : i + 1, :]
@@ -159,10 +216,10 @@ def forward(self, embedding):
             q_i_expanded = q_i.expand(-1, n_agents - 1, -1)
             attention_input = torch.cat([q_i_expanded, k], dim=-1)
 
-            # Score computation
+            # Score computation (per pair)
             scores = self.attn_score_layer(attention_input).transpose(1, 2)
 
-            # Mask using hard weights
+            # Mask using hard attention
             h_weights = hard_weights[:, i].unsqueeze(1)
             mask = (h_weights > 0.5).float()

@@ -183,12 +240,12 @@ def forward(self, embedding):
             combined_weights = soft_weights * mask  # (B, 1, N-1)
             combined_w.append(combined_weights)
 
-            # Normalize combined_weights for entropy calculation
+            # Normalize for entropy calculation
             combined_weights_norm = combined_weights / (
                 combined_weights.sum(dim=-1, keepdim=True) + epsilon
             )
 
-            # Calculate entropy from combined_weights
+            # Entropy for analysis/logging
             entropy = (
                 -(combined_weights_norm * (combined_weights_norm + epsilon).log())
                 .sum(dim=-1)
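The new docstrings pin down the tensor shapes flowing through `Attention.forward`, so the module can be exercised with random data. A sketch under those assumptions follows; the batch size, agent count, and `embedding_dim=128` are illustrative choices, and only the shape contracts come from the docstrings in this commit.

```python
# Sketch driving the Attention module with random inputs; all sizes here are
# assumptions, only the documented shape contracts are taken from this diff.
import torch
from robot_nav.models.MARL.hardsoftAttention import Attention

B, N, D = 8, 4, 11                    # batch, agents, per-agent feature dim (D >= 11)
attn = Attention(embedding_dim=128)   # embedding_dim chosen arbitrarily

features = torch.randn(B, N, D)
(att_embedding, hard_logits, unnorm_rel_dist,
 mean_entropy, hard_weights, comb_w) = attn(features)

print(att_embedding.shape)   # expected: (B * N, 2 * embedding_dim)
print(hard_weights.shape)    # expected: (B, N, N - 1)
```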
