@@ -21,7 +21,7 @@ class Actor(nn.Module):
2121 Args:
2222 action_dim (int): Number of action dimensions per agent.
2323 embedding_dim (int): Dimensionality of the attention embedding.
24- attention (str): Attention backend, one of {"iga ", "g2anet"}.
24+ attention (str): Attention backend, one of {"igs ", "g2anet"}.
2525
2626 Attributes:
2727 attention (nn.Module): Attention encoder producing attended embeddings and
@@ -31,7 +31,7 @@ class Actor(nn.Module):
3131
3232 def __init__ (self , action_dim , embedding_dim , attention ):
3333 super ().__init__ ()
34- if attention == "iga " :
34+ if attention == "igs " :
3535 self .attention = Attention (embedding_dim )
3636 elif attention == "g2anet" :
3737 self .attention = G2ANet (embedding_dim ) # ➊ edge classifier
@@ -86,7 +86,7 @@ class Critic(nn.Module):
8686 Args:
8787 action_dim (int): Number of action dimensions per agent.
8888 embedding_dim (int): Dimensionality of the attention embedding.
89- attention (str): Attention backend, one of {"iga ", "g2anet"}.
89+ attention (str): Attention backend, one of {"igs ", "g2anet"}.
9090
9191 Attributes:
9292 attention (nn.Module): Attention encoder producing attended embeddings and
@@ -97,7 +97,7 @@ class Critic(nn.Module):
9797 def __init__ (self , action_dim , embedding_dim , attention ):
9898 super (Critic , self ).__init__ ()
9999 self .embedding_dim = embedding_dim
100- if attention == "iga " :
100+ if attention == "igs " :
101101 self .attention = Attention (embedding_dim )
102102 elif attention == "g2anet" :
103103 self .attention = G2ANet (embedding_dim ) # ➊ edge classifier
@@ -189,7 +189,7 @@ class TD3(object):
189189 model_name (str, optional): Base filename for checkpoints. Defaults to "marlTD3".
190190 load_model_name (str or None, optional): Filename base to load. Defaults to None (uses model_name).
191191 load_directory (Path, optional): Directory to load checkpoints from.
192- attention (str, optional): Attention backend, one of {"iga ", "g2anet"}. Defaults to "iga ".
192+ attention (str, optional): Attention backend, one of {"igs ", "g2anet"}. Defaults to "igs ".
193193
194194 Attributes:
195195 actor (Actor): Policy network.
@@ -220,10 +220,10 @@ def __init__(
220220 model_name = "marlTD3" ,
221221 load_model_name = None ,
222222 load_directory = Path ("robot_nav/models/MARL/marlTD3/checkpoint" ),
223- attention = "iga " ,
223+ attention = "igs " ,
224224 ):
225225 # Initialize the Actor network
226- if attention not in ["iga " , "g2anet" ]:
226+ if attention not in ["igs " , "g2anet" ]:
227227 raise ValueError ("unknown attention mechanism specified for TD3 model" )
228228 self .num_robots = num_robots
229229 self .device = device
0 commit comments