Skip to content

Commit bffc661

Browse files
committed
spellcheck
1 parent 9aa771d commit bffc661

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ and information about the goal point a robot learns to navigate to a specified p
4343
| CNNTD3 | TD3 model with 1D CNN encoding of laser state | - |
4444
| RCPG | Recurrent Convolution Policy Gradient - adding recurrence layers (lstm/gru/rnn) to CNNTD3 model | - |
4545
| MARL: TD3-G2ANet | G2ANet attention encoder for TD3 model in MARL setting | G2ANet adapted from https://github.com/starry-sky6688/MARL-Algorithms |
46-
| MARL: TD3-IGA | In-Graph Attention model for TD3 model in MARL setting | - |
46+
| MARL: TD3-IGS | In-Graph Softmax attention model for TD3 model in MARL setting | - |
4747

4848
**Max Upper Bound Models**
4949

robot_nav/marl_test_random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def main(args=None):
6565
model_name="TDR-MARL-test",
6666
load_model_name="TDR-MARL-train",
6767
load_directory=Path("robot_nav/models/MARL/marlTD3/checkpoint"),
68-
attention="iga",
68+
attention="igs",
6969
) # instantiate a model
7070

7171
connections = torch.tensor(

robot_nav/marl_test_single.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def main(args=None):
329329
model_name="TDR-MARL-test",
330330
load_model_name="TDR-MARL-train",
331331
load_directory=Path("robot_nav/models/MARL/marlTD3/checkpoint"),
332-
attention="iga",
332+
attention="igs",
333333
) # instantiate a model
334334

335335
connections = torch.tensor(

robot_nav/models/MARL/marlTD3/marlTD3.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Actor(nn.Module):
2121
Args:
2222
action_dim (int): Number of action dimensions per agent.
2323
embedding_dim (int): Dimensionality of the attention embedding.
24-
attention (str): Attention backend, one of {"iga", "g2anet"}.
24+
attention (str): Attention backend, one of {"igs", "g2anet"}.
2525
2626
Attributes:
2727
attention (nn.Module): Attention encoder producing attended embeddings and
@@ -31,7 +31,7 @@ class Actor(nn.Module):
3131

3232
def __init__(self, action_dim, embedding_dim, attention):
3333
super().__init__()
34-
if attention == "iga":
34+
if attention == "igs":
3535
self.attention = Attention(embedding_dim)
3636
elif attention == "g2anet":
3737
self.attention = G2ANet(embedding_dim) # ➊ edge classifier
@@ -86,7 +86,7 @@ class Critic(nn.Module):
8686
Args:
8787
action_dim (int): Number of action dimensions per agent.
8888
embedding_dim (int): Dimensionality of the attention embedding.
89-
attention (str): Attention backend, one of {"iga", "g2anet"}.
89+
attention (str): Attention backend, one of {"igs", "g2anet"}.
9090
9191
Attributes:
9292
attention (nn.Module): Attention encoder producing attended embeddings and
@@ -97,7 +97,7 @@ class Critic(nn.Module):
9797
def __init__(self, action_dim, embedding_dim, attention):
9898
super(Critic, self).__init__()
9999
self.embedding_dim = embedding_dim
100-
if attention == "iga":
100+
if attention == "igs":
101101
self.attention = Attention(embedding_dim)
102102
elif attention == "g2anet":
103103
self.attention = G2ANet(embedding_dim) # ➊ edge classifier
@@ -189,7 +189,7 @@ class TD3(object):
189189
model_name (str, optional): Base filename for checkpoints. Defaults to "marlTD3".
190190
load_model_name (str or None, optional): Filename base to load. Defaults to None (uses model_name).
191191
load_directory (Path, optional): Directory to load checkpoints from.
192-
attention (str, optional): Attention backend, one of {"iga", "g2anet"}. Defaults to "iga".
192+
attention (str, optional): Attention backend, one of {"igs", "g2anet"}. Defaults to "igs".
193193
194194
Attributes:
195195
actor (Actor): Policy network.
@@ -220,10 +220,10 @@ def __init__(
220220
model_name="marlTD3",
221221
load_model_name=None,
222222
load_directory=Path("robot_nav/models/MARL/marlTD3/checkpoint"),
223-
attention="iga",
223+
attention="igs",
224224
):
225225
# Initialize the Actor network
226-
if attention not in ["iga", "g2anet"]:
226+
if attention not in ["igs", "g2anet"]:
227227
raise ValueError("unknown attention mechanism specified for TD3 model")
228228
self.num_robots = num_robots
229229
self.device = device

0 commit comments

Comments
 (0)