Commits (21)
5071568
Initial commit of working code. Will lint later and then submit PR.
MatthewCWeston Sep 5, 2025
1c842c3
Linted and cleaned up the code.
MatthewCWeston Sep 7, 2025
c002277
Minor formatting/debugging fixes to PR submission.
MatthewCWeston Sep 7, 2025
55e8989
Fixed an inheritance issue and added global observation handling.
MatthewCWeston Sep 7, 2025
5756bf9
Tabs-->spaces in BUILD to satisfy tests.
MatthewCWeston Sep 8, 2025
dc8fdc9
Merge remote-tracking branch 'upstream/master' into pettingzoo_shared_vf
MatthewCWeston Sep 29, 2025
2ec32d4
Cleaned up two lines of code, one of which had a bug in an edge case.
MatthewCWeston Sep 29, 2025
dc74d6f
Merge branch 'master' into pettingzoo_shared_vf
pseudo-rnd-thoughts Nov 10, 2025
bfa6c08
Re-linted and cleaned up a function signature.
MatthewCWeston Nov 11, 2025
060e2c3
Deleted an unused line of code.
MatthewCWeston Nov 11, 2025
323866a
Merged in the updates from the master branch.
MatthewCWeston Jan 7, 2026
37845b2
Re-added episode postproc and updated expected results.
MatthewCWeston Jan 7, 2026
e7b02c4
Merge branch 'master' into pettingzoo_shared_vf
MatthewCWeston Jan 8, 2026
fd7fa91
Fixed actor module batching and loss masking.
MatthewCWeston Jan 11, 2026
5791729
Patched critic loss pooling.
MatthewCWeston Jan 11, 2026
8a69229
Merge branch 'master' into pettingzoo_shared_vf
MatthewCWeston Jan 11, 2026
3f2f51b
Linted.
MatthewCWeston Jan 11, 2026
f645709
Merge branch 'pettingzoo_shared_vf' of https://github.com/MatthewCWes…
MatthewCWeston Jan 11, 2026
30f99db
Cleaned up code and improved inheritance.
MatthewCWeston Jan 19, 2026
fca0c28
Merge branch 'master' into pettingzoo_shared_vf
MatthewCWeston Jan 19, 2026
14d48dd
Merge branch 'master' into pettingzoo_shared_vf
MatthewCWeston Jan 24, 2026
27 changes: 18 additions & 9 deletions rllib/BUILD.bazel
@@ -5004,15 +5004,24 @@ py_test(
    ],
)

# TODO (sven): Activate this test once this script is ready.
# py_test(
#     name = "examples/multi_agent/pettingzoo_shared_value_function",
#     main = "examples/multi_agent/pettingzoo_shared_value_function.py",
#     tags = ["team:rllib", "exclusive", "examples"],
#     size = "large",
#     srcs = ["examples/multi_agent/pettingzoo_shared_value_function.py"],
#     args = ["--num-agents=2", "--as-test", "--framework=torch", "--stop-reward=-100.0", "--num-cpus=4"],
# )
py_test(
    name = "examples/multi_agent/pettingzoo_shared_value_function",
    size = "large",
    srcs = ["examples/multi_agent/pettingzoo_shared_value_function.py"],
    args = [
        "--num-agents=2",
        "--as-test",
        "--framework=torch",
        "--stop-reward=-100.0",
        "--num-cpus=4",
    ],
    main = "examples/multi_agent/pettingzoo_shared_value_function.py",
    tags = [
        "examples",
        "exclusive",
        "team:rllib",
    ],
)

py_test(
    name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint",
@@ -0,0 +1,173 @@
from typing import Any, Dict, List

import numpy as np

from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor
from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets
from ray.rllib.utils.postprocessing.zero_padding import (
    split_and_zero_pad_n_episodes,
    unpad_data_if_necessary,
)
from ray.rllib.utils.typing import EpisodeType

torch, nn = try_import_torch()

# ModuleID under which the shared critic is stored in the MultiRLModule.
SHARED_CRITIC_ID = "shared_critic"


class MAPPOGAEConnector(ConnectorV2):
    def __init__(
        self,
        input_observation_space=None,
        input_action_space=None,
        *,
        gamma,
        lambda_,
    ):
        super().__init__(input_observation_space, input_action_space)
        self.gamma = gamma
        self.lambda_ = lambda_
        # Internal numpy-to-tensor connector to translate GAE results
        # (advantages and vf targets) into tensors.
        self._numpy_to_tensor_connector = None

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: MultiRLModule,
        episodes: List[EpisodeType],
        batch: Dict[str, Any],
        **kwargs,
    ):
        # Device to place all GAE result tensors (advantages and value targets) on.
        device = None
        # Extract all single-agent episodes.
        sa_episodes_list = list(
            self.single_agent_episode_iterator(
                episodes, agents_that_stepped_only=False
            )
        )
        # Perform the value net's forward pass.
        critic_batch = {}
        # Concatenate all agent observations in the batch, using a fixed
        # (sorted-by-ModuleID) order. Modules with self-supervised losses are
        # excluded.
        obs_mids = [
            k
            for k in sorted(batch.keys())
            if (Columns.OBS in batch[k])
            and (not isinstance(rl_module[k], SelfSupervisedLossAPI))
        ]
        critic_batch[Columns.OBS] = torch.cat(
            [batch[k][Columns.OBS] for k in obs_mids], dim=1
        )
        # Compute value predictions: the shared critic returns one value column
        # per agent module.
        vf_preds = rl_module[SHARED_CRITIC_ID].compute_values(critic_batch)
        vf_preds = {mid: vf_preds[:, i] for i, mid in enumerate(obs_mids)}
        # Loop through all modules and perform each one's GAE computation.
        for module_id, module_vf_preds in vf_preds.items():
            # Skip modules for which no value predictions exist.
            if module_vf_preds is None:
                continue

            module = rl_module[module_id]
            device = module_vf_preds.device
            # Convert to numpy for the upcoming GAE computations.
            module_vf_preds = convert_to_numpy(module_vf_preds)

            # Collect (single-agent) episode lengths for this particular module.
            episode_lens = [
                len(e) for e in sa_episodes_list if e.module_id in [None, module_id]
            ]

            # Remove all zero-padding again, if applicable, for the upcoming
            # GAE computations.
            module_vf_preds = unpad_data_if_necessary(episode_lens, module_vf_preds)
            # Compute value targets.
            module_value_targets = compute_value_targets(
                values=module_vf_preds,
                rewards=unpad_data_if_necessary(
                    episode_lens,
                    convert_to_numpy(batch[module_id][Columns.REWARDS]),
                ),
                terminateds=unpad_data_if_necessary(
                    episode_lens,
                    convert_to_numpy(batch[module_id][Columns.TERMINATEDS]),
                ),
                truncateds=unpad_data_if_necessary(
                    episode_lens,
                    convert_to_numpy(batch[module_id][Columns.TRUNCATEDS]),
                ),
                gamma=self.gamma,
                lambda_=self.lambda_,
            )
            assert module_value_targets.shape[0] == sum(episode_lens)

            module_advantages = module_value_targets - module_vf_preds
            # Vf-preds are not written into the train batch: they are not needed
            # in the actor loss, and the shared critic recomputes its values
            # during its own loss computation anyway.
            # Standardize advantages (used for more stable and better weighted
            # policy gradient computations).
            module_advantages = (module_advantages - module_advantages.mean()) / max(
                1e-4, module_advantages.std()
            )

            # Zero-pad the new computations, if necessary.
            if module.is_stateful():
                module_advantages = np.stack(
                    split_and_zero_pad_n_episodes(
                        module_advantages,
                        episode_lens=episode_lens,
                        max_seq_len=module.model_config["max_seq_len"],
                    ),
                    axis=0,
                )
                module_value_targets = np.stack(
                    split_and_zero_pad_n_episodes(
                        module_value_targets,
                        episode_lens=episode_lens,
                        max_seq_len=module.model_config["max_seq_len"],
                    ),
                    axis=0,
                )
            batch[module_id][Postprocessing.ADVANTAGES] = module_advantages
            batch[module_id][Postprocessing.VALUE_TARGETS] = module_value_targets
        # Add the (per-agent) GAE results to the critic batch.
        critic_batch[Postprocessing.VALUE_TARGETS] = np.stack(
            [batch[mid][Postprocessing.VALUE_TARGETS] for mid in obs_mids], axis=1
        )
        critic_batch[Postprocessing.ADVANTAGES] = np.stack(
            [batch[mid][Postprocessing.ADVANTAGES] for mid in obs_mids], axis=1
        )
        batch[SHARED_CRITIC_ID] = critic_batch  # Critic data -> training batch.
        # Convert all GAE results to tensors.
        if self._numpy_to_tensor_connector is None:
            self._numpy_to_tensor_connector = NumpyToTensor(
                as_learner_connector=True, device=device
            )
        tensor_results = self._numpy_to_tensor_connector(
            rl_module=rl_module,
            batch={
                mid: {
                    Postprocessing.ADVANTAGES: module_batch[Postprocessing.ADVANTAGES],
                    Postprocessing.VALUE_TARGETS: (
                        module_batch[Postprocessing.VALUE_TARGETS]
                    ),
                }
                for mid, module_batch in batch.items()
                if (mid == SHARED_CRITIC_ID) or (mid in vf_preds)
            },
            episodes=episodes,
        )
        # Move the converted tensors back into `batch`.
        for mid, module_batch in tensor_results.items():
            batch[mid].update(module_batch)

        return batch
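
For orientation only (this wiring is not part of the diff shown here): on the new API stack, a custom learner connector like this one is typically plugged in through the config's `learner_connector` hook. Below is a minimal sketch, assuming a PPO-based config; the env name and the `gamma`/`lambda_` values are placeholders, and how this PR's MAPPO setup actually registers the connector (e.g., whether it replaces PPO's default GAE piece) is defined elsewhere in the branch.

# Hypothetical wiring sketch (not shown in this diff). Import MAPPOGAEConnector
# from this PR's connector file (its path is not shown here).
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("my_pettingzoo_env")  # placeholder env name
    .training(
        # The hook receives the input spaces and returns the connector piece.
        learner_connector=lambda obs_space, act_space: MAPPOGAEConnector(
            input_observation_space=obs_space,
            input_action_space=act_space,
            gamma=0.99,  # placeholder; match your training config
            lambda_=0.95,  # placeholder
        ),
    )
)
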
54 changes: 54 additions & 0 deletions rllib/examples/algorithms/mappo/default_mappo_rl_module.py
@@ -0,0 +1,54 @@
import abc
from typing import List

from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import (
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    override,
)
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
class DefaultMAPPORLModule(RLModule, InferenceOnlyAPI, abc.ABC):
    """Default RLModule used by MAPPO, if the user does not specify a custom RLModule.

    Users who want to train their RLModules with MAPPO may implement any RLModule
    (or TorchRLModule) subclass.
    """

    @override(RLModule)
    def setup(self):
        # __sphinx_doc_begin__
        # If we have a stateful model, states for the critic need to be collected
        # during sampling and `inference_only` needs to be `False`. Note that at
        # this point the encoder is not built yet, so `is_stateful()` cannot be
        # used.
        is_stateful = isinstance(
            self.catalog.encoder_config,
            RecurrentEncoderConfig,
        )
        if is_stateful:
            self.inference_only = False
        # If this is an `inference_only` Module, we'll have to pass this
        # information to the encoder config as well.
        if self.inference_only and self.framework == "torch":
            self.catalog.encoder_config.inference_only = True
        # Build models from the catalog.
        self.encoder = self.catalog.build_encoder(framework=self.framework)
        self.pi = self.catalog.build_pi_head(framework=self.framework)
        # __sphinx_doc_end__

    @override(RLModule)
    def get_initial_state(self) -> dict:
        if hasattr(self.encoder, "get_initial_state"):
            return self.encoder.get_initial_state()
        else:
            return {}

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(InferenceOnlyAPI)
    def get_non_inference_attributes(self) -> List[str]:
        """Return attributes, which are NOT inference-only (only used for training)."""
        return []
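
For context, here is a hedged sketch of the contract this module family relies on: the shared critic (stored under SHARED_CRITIC_ID, see the connector above) consumes the concatenated observations of all agents and returns one value prediction per agent. The class below, its model_config keys, and the MultiRLModuleSpec assembly are illustrative assumptions, not code from this PR.

# Hypothetical shared-critic sketch (illustration only; this PR's real critic
# module lives elsewhere in the branch). compute_values() maps the concatenated
# observations, shape [B, sum_of_agent_obs_dims], to one value per agent,
# shape [B, num_agents], which MAPPOGAEConnector then splits per ModuleID.
import torch.nn as nn

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class SharedCriticSketch(TorchRLModule, ValueFunctionAPI):
    def setup(self):
        # Assumed model_config keys, for illustration only.
        in_dim = self.model_config["concat_obs_dim"]
        num_agents = self.model_config["num_agents"]
        self._net = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(), nn.Linear(256, num_agents)
        )

    def compute_values(self, batch, embeddings=None):
        # One value head output per agent, computed from the joint observation.
        return self._net(batch[Columns.OBS])


# Hypothetical assembly: per-agent policy modules plus the shared critic,
# keyed by SHARED_CRITIC_ID ("shared_critic", defined in the connector file).
spec = MultiRLModuleSpec(
    rl_module_specs={
        "agent_0": RLModuleSpec(),  # e.g., a DefaultMAPPORLModule subclass
        "agent_1": RLModuleSpec(),
        "shared_critic": RLModuleSpec(
            module_class=SharedCriticSketch,
            model_config={"concat_obs_dim": 16, "num_agents": 2},  # placeholders
        ),
    },
)
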