
Commit b936f52

Added support for stateful agents.
Signed-off-by: Matthew <[email protected]>
1 parent 37845b2 commit b936f52

File tree

2 files changed: +14 -5 lines changed

rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py

Lines changed: 4 additions & 4 deletions
@@ -64,11 +64,11 @@ def __call__(
            and (not isinstance(rl_module[k], SelfSupervisedLossAPI))
        ]
        critic_batch[Columns.OBS] = torch.cat(
-            [batch[k][Columns.OBS] for k in obs_mids], dim=1
+            [batch[k][Columns.OBS] for k in obs_mids], dim=-1
        )
        # Compute value predictions
        vf_preds = rl_module[SHARED_CRITIC_ID].compute_values(critic_batch)
-        vf_preds = {mid: vf_preds[:, i] for i, mid in enumerate(obs_mids)}
+        vf_preds = {mid: vf_preds[..., i] for i, mid in enumerate(obs_mids)}
        # Loop through all modules and perform each one's GAE computation.
        for module_id, module_vf_preds in vf_preds.items():
            module = rl_module[module_id]

@@ -136,10 +136,10 @@ def __call__(
            batch[module_id][Postprocessing.VALUE_TARGETS] = module_value_targets
        # Add GAE results to the critic batch
        critic_batch[Postprocessing.VALUE_TARGETS] = np.stack(
-            [batch[mid][Postprocessing.VALUE_TARGETS] for mid in obs_mids], axis=1
+            [batch[mid][Postprocessing.VALUE_TARGETS] for mid in obs_mids], axis=-1
        )
        critic_batch[Postprocessing.ADVANTAGES] = np.stack(
-            [batch[mid][Postprocessing.ADVANTAGES] for mid in obs_mids], axis=1
+            [batch[mid][Postprocessing.ADVANTAGES] for mid in obs_mids], axis=-1
        )
        batch[SHARED_CRITIC_ID] = critic_batch  # Critic data -> training batch
        # Convert all GAE results to tensors.
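The axis changes above are what let the shared-critic GAE connector handle stateful agents: with LSTM encoders, per-agent observations and value predictions carry an extra time dimension ([B, T, ...] instead of [B, ...]), and concatenating or stacking along the last axis works for both layouts, whereas axis 1 only matches the time-less case. A minimal shape sketch, with illustrative sizes that are not taken from the commit:

import numpy as np

# Stateless case: one value target per agent, shaped [B].
stateless = [np.zeros(32), np.zeros(32)]
# Stateful (LSTM) case: value targets keep a time axis, shaped [B, T].
stateful = [np.zeros((8, 16)), np.zeros((8, 16))]

# axis=-1 appends the per-agent dimension last in both cases ...
print(np.stack(stateless, axis=-1).shape)  # (32, 2)
print(np.stack(stateful, axis=-1).shape)   # (8, 16, 2)
# ... whereas axis=1 would wedge the agent dimension in front of the time axis.
print(np.stack(stateful, axis=1).shape)    # (8, 2, 16)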

rllib/examples/multi_agent/pettingzoo_shared_value_function.py

Lines changed: 10 additions & 1 deletion
@@ -47,6 +47,7 @@

from pettingzoo.sisl import waterworld_v4

+from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv

@@ -68,6 +69,11 @@
    default_timesteps=1000000,
    default_reward=0.0,
)
+parser.add_argument(
+    "--use-lstm",
+    action="store_true",
+    help="Whether to use LSTM encoders for the agents' policies.",
+)


if __name__ == "__main__":

@@ -87,7 +93,10 @@ def get_env(_):

    # An agent for each of our policies, and a single shared critic
    env_instantiated = get_env({})  # neccessary for non-agent modules
-    specs = {p: RLModuleSpec() for p in policies}
+    model_config = DefaultModelConfig(
+        use_lstm=args.use_lstm,
+    )
+    specs = {p: RLModuleSpec(model_config=model_config) for p in policies}
    specs[SHARED_CRITIC_ID] = RLModuleSpec(
        module_class=SharedCriticTorchRLModule,
        observation_space=env_instantiated.observation_space[policies[0]],
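With the new flag, the example should be runnable with LSTM encoders for every policy, e.g. (any other arguments come from RLlib's standard example-script parser and are omitted here):

python rllib/examples/multi_agent/pettingzoo_shared_value_function.py --use-lstm

Leaving the flag off keeps the previous feed-forward setup, since action="store_true" defaults to False.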
