From 5071568f6aac6230ffa01bbdf648656684778144 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 5 Sep 2025 17:27:02 -0500 Subject: [PATCH 01/13] Initial commit of working code. Will lint later and then submit PR. Signed-off-by: Matthew --- .../general_advantage_estimation.py | 168 +++++++++++++ .../mappo/default_mappo_rl_module.py | 61 +++++ rllib/examples/algorithms/mappo/mappo.py | 229 +++++++++++++++++ .../algorithms/mappo/mappo_catalog.py | 137 ++++++++++ .../algorithms/mappo/mappo_learner.py | 149 +++++++++++ .../algorithms/mappo/shared_critic_catalog.py | 78 ++++++ .../mappo/shared_critic_rl_module.py | 48 ++++ .../torch/default_mappo_torch_rl_module.py | 51 ++++ .../mappo/torch/mappo_torch_learner.py | 234 ++++++++++++++++++ .../torch/shared_critic_torch_rl_module.py | 36 +++ .../pettingzoo_shared_value_function.py | 118 ++++++++- .../multi_agent/shared_encoder_cartpole.py | 2 +- 12 files changed, 1305 insertions(+), 6 deletions(-) create mode 100644 rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py create mode 100644 rllib/examples/algorithms/mappo/default_mappo_rl_module.py create mode 100644 rllib/examples/algorithms/mappo/mappo.py create mode 100644 rllib/examples/algorithms/mappo/mappo_catalog.py create mode 100644 rllib/examples/algorithms/mappo/mappo_learner.py create mode 100644 rllib/examples/algorithms/mappo/shared_critic_catalog.py create mode 100644 rllib/examples/algorithms/mappo/shared_critic_rl_module.py create mode 100644 rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py create mode 100644 rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py create mode 100644 rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py diff --git a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py new file mode 100644 index 000000000000..6f5353b12b8e --- /dev/null +++ b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py @@ -0,0 +1,168 @@ +from typing import Any, List, Dict + +import numpy as np +import torch +from collections import defaultdict + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.utils.annotations import override +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets +from ray.rllib.utils.postprocessing.zero_padding import ( + split_and_zero_pad_n_episodes, + unpad_data_if_necessary, +) +from ray.rllib.utils.typing import EpisodeType + + +SHARED_CRITIC_ID = "shared_critic" + + +class MAPPOGAEConnector(ConnectorV2): + def __init__( + self, + input_observation_space=None, + input_action_space=None, + *, + gamma, + lambda_, + ): + super().__init__(input_observation_space, input_action_space) + self.gamma = gamma + self.lambda_ = lambda_ + # Internal numpy-to-tensor connector to translate GAE results (advantages and + # vf targets) into tensors. 
+ self._numpy_to_tensor_connector = None + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: MultiRLModule, + episodes: List[EpisodeType], + batch: Dict[str, Any], + **kwargs, + ): + # Device to place all GAE result tensors (advantages and value targets) on. + device = None + + # Extract all single-agent episodes. + sa_episodes_list = list( + self.single_agent_episode_iterator(episodes, agents_that_stepped_only=False) + ) + # Perform the value nets' forward passes. + # TODO (sven): We need to check here in the pipeline already, whether a module + # should even be updated or not (which we usually do after(!) the Learner + # pipeline). This is an open TODO to move this filter into a connector as well. + # For now, we'll just check, whether `mid` is in batch and skip if it isn't. + # For MAPPO, we can't check ValueFunctionAPI, so we check for the presence of an observation and a lack of self-supervision instead. + vf_preds = rl_module.foreach_module( + func=lambda mid, module: ( + rl_module[SHARED_CRITIC_ID].compute_values(batch[mid]) + if (mid in batch) and (Columns.OBS in batch[mid]) + and (not isinstance(module, SelfSupervisedLossAPI)) + else None + ), + return_dict=True, + ) + # Loop through all modules and perform each one's GAE computation. + for module_id, module_vf_preds in vf_preds.items(): + # Skip those outputs of RLModules that are not implementers of + # `ValueFunctionAPI`. + if module_vf_preds is None: + continue + + module = rl_module[module_id] + device = module_vf_preds.device + # Convert to numpy for the upcoming GAE computations. + module_vf_preds = convert_to_numpy(module_vf_preds) + + # Collect (single-agent) episode lengths for this particular module. + episode_lens = [ + len(e) for e in sa_episodes_list if e.module_id in [None, module_id] + ] + + # Remove all zero-padding again, if applicable, for the upcoming + # GAE computations. + module_vf_preds = unpad_data_if_necessary(episode_lens, module_vf_preds) + # Compute value targets. + module_value_targets = compute_value_targets( + values=module_vf_preds, + rewards=unpad_data_if_necessary( + episode_lens, + convert_to_numpy(batch[module_id][Columns.REWARDS]), + ), + terminateds=unpad_data_if_necessary( + episode_lens, + convert_to_numpy(batch[module_id][Columns.TERMINATEDS]), + ), + truncateds=unpad_data_if_necessary( + episode_lens, + convert_to_numpy(batch[module_id][Columns.TRUNCATEDS]), + ), + gamma=self.gamma, + lambda_=self.lambda_, + ) + assert module_value_targets.shape[0] == sum(episode_lens) + + module_advantages = module_value_targets - module_vf_preds + # Drop vf-preds, not needed in loss. Note that in the DefaultPPORLModule, + # vf-preds are recomputed with each `forward_train` call anyway to compute + # the vf loss. + # Standardize advantages (used for more stable and better weighted + # policy gradient computations). + module_advantages = (module_advantages - module_advantages.mean()) / max( + 1e-4, module_advantages.std() + ) + + # Zero-pad the new computations, if necessary. 
+ if module.is_stateful(): + module_advantages = np.stack( + split_and_zero_pad_n_episodes( + module_advantages, + episode_lens=episode_lens, + max_seq_len=module.model_config["max_seq_len"], + ), + axis=0, + ) + module_value_targets = np.stack( + split_and_zero_pad_n_episodes( + module_value_targets, + episode_lens=episode_lens, + max_seq_len=module.model_config["max_seq_len"], + ), + axis=0, + ) + batch[module_id][Postprocessing.ADVANTAGES] = module_advantages + batch[module_id][Postprocessing.VALUE_TARGETS] = module_value_targets + + # Convert all GAE results to tensors. + if self._numpy_to_tensor_connector is None: + self._numpy_to_tensor_connector = NumpyToTensor( + as_learner_connector=True, device=device + ) + tensor_results = self._numpy_to_tensor_connector( + rl_module=rl_module, + batch={ + mid: { + Postprocessing.ADVANTAGES: module_batch[Postprocessing.ADVANTAGES], + Postprocessing.VALUE_TARGETS: ( + module_batch[Postprocessing.VALUE_TARGETS] + ), + } + for mid, module_batch in batch.items() + if vf_preds[mid] is not None + }, + episodes=episodes, + ) + # Move converted tensors back to `batch`. + for mid, module_batch in tensor_results.items(): + batch[mid].update(module_batch) + + return batch \ No newline at end of file diff --git a/rllib/examples/algorithms/mappo/default_mappo_rl_module.py b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py new file mode 100644 index 000000000000..137a8284b1de --- /dev/null +++ b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py @@ -0,0 +1,61 @@ +import abc +from typing import List + +from ray.rllib.core.models.configs import RecurrentEncoderConfig +from ray.rllib.core.rl_module.apis import InferenceOnlyAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class DefaultMAPPORLModule(RLModule, InferenceOnlyAPI, abc.ABC): + """Default RLModule used by PPO, if user does not specify a custom RLModule. + + Users who want to train their RLModules with PPO may implement any RLModule + (or TorchRLModule) subclass as long as the custom class also implements the + `ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py) + """ + + @override(RLModule) + def setup(self): + try: + # __sphinx_doc_begin__ + # If we have a stateful model, states for the critic need to be collected + # during sampling and `inference-only` needs to be `False`. Note, at this + # point the encoder is not built, yet and therefore `is_stateful()` does + # not work. + is_stateful = isinstance( + self.catalog.encoder_config.base_encoder_config, + RecurrentEncoderConfig, + ) + if is_stateful: + self.inference_only = False + # If this is an `inference_only` Module, we'll have to pass this information + # to the encoder config as well. + if self.inference_only and self.framework == "torch": + self.catalog.encoder_config.inference_only = True + # Build models from catalog. 
+ self.encoder = self.catalog.build_encoder(framework=self.framework) + self.pi = self.catalog.build_pi_head(framework=self.framework) + except Exception as e: + print("Error in DefaultMAPPORLModule setup") + print(e) + raise e + # __sphinx_doc_end__ + + @override(RLModule) + def get_initial_state(self) -> dict: + if hasattr(self.encoder, "get_initial_state"): + return self.encoder.get_initial_state() + else: + return {} + + @OverrideToImplementCustomLogic_CallToSuperRecommended + @override(InferenceOnlyAPI) + def get_non_inference_attributes(self) -> List[str]: + """Return attributes, which are NOT inference-only (only used for training).""" + return [] \ No newline at end of file diff --git a/rllib/examples/algorithms/mappo/mappo.py b/rllib/examples/algorithms/mappo/mappo.py new file mode 100644 index 000000000000..0d4d9bd7ac73 --- /dev/null +++ b/rllib/examples/algorithms/mappo/mappo.py @@ -0,0 +1,229 @@ +import logging +from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.execution.rollout_ops import ( + standardize_fields, + synchronous_parallel_sample, +) +from ray.rllib.execution.train_ops import ( + train_one_step, + multi_gpu_train_one_step, +) +from ray.rllib.policy.policy import Policy +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.metrics import ( + ENV_RUNNER_RESULTS, + ENV_RUNNER_SAMPLING_TIMER, + LEARNER_RESULTS, + LEARNER_UPDATE_TIMER, + NUM_AGENT_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_LIFETIME, + SYNCH_WORKER_WEIGHTS_TIMER, + SAMPLE_TIMER, + TIMERS, + ALL_MODULES, +) +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.schedules.scheduler import Scheduler +from ray.rllib.utils.typing import ResultDict +from ray.util.debug import log_once + +from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO + +if TYPE_CHECKING: + from ray.rllib.core.learner.learner import Learner + +from ray.rllib.examples.algorithms.mappo.torch.mappo_torch_learner import MAPPOTorchLearner +from ray.rllib.examples.algorithms.mappo.torch.default_mappo_torch_rl_module import DefaultMAPPOTorchRLModule + +logger = logging.getLogger(__name__) + +LEARNER_RESULTS_KL_KEY = "mean_kl_loss" +LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff" +LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff" + + +class MAPPOConfig(PPOConfig): # AlgorithmConfig -> PPOConfig -> MAPPO + """Defines a configuration class from which a MAPPO Algorithm can be built. + """ + + def __init__(self, algo_class=None): + """Initializes a MAPPOConfig instance.""" + self.exploration_config = { + # The Exploration class to use. In the simplest case, this is the name + # (str) of any class present in the `rllib.utils.exploration` package. + # You can also provide the python class directly or the full location + # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy. + # EpsilonGreedy"). + "type": "StochasticSampling", + # Add constructor kwargs here (if any). 
+ } + + super().__init__(algo_class=algo_class or PPO) + + # fmt: off + # __sphinx_doc_begin__ + self.lr = 5e-5 + self.rollout_fragment_length = "auto" + self.train_batch_size = 4000 + + # PPO specific settings: + self.num_epochs = 30 + self.minibatch_size = 128 + self.shuffle_batch_per_epoch = True + self.lambda_ = 1.0 + self.use_kl_loss = True + self.kl_coeff = 0.2 + self.kl_target = 0.01 + self.entropy_coeff = 0.0 + self.clip_param = 0.3 + self.grad_clip = None + + # Override some of AlgorithmConfig's default values with PPO-specific values. + self.num_env_runners = 2 + # __sphinx_doc_end__ + # fmt: on + + self.entropy_coeff_schedule = None # OldAPIStack + self.lr_schedule = None # OldAPIStack + + # Deprecated keys. + self.sgd_minibatch_size = DEPRECATED_VALUE + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> RLModuleSpec: + if self.framework_str == "torch": + return RLModuleSpec(module_class=DefaultMAPPOTorchRLModule) + raise NotImplementedError() + + @override(AlgorithmConfig) + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + if self.framework_str == "torch": + return MAPPOTorchLearner + raise NotImplementedError() + + @override(AlgorithmConfig) + def training( + self, + *, + lambda_: Optional[float] = NotProvided, + use_kl_loss: Optional[bool] = NotProvided, + kl_coeff: Optional[float] = NotProvided, + kl_target: Optional[float] = NotProvided, + entropy_coeff: Optional[float] = NotProvided, + entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, + clip_param: Optional[float] = NotProvided, + grad_clip: Optional[float] = NotProvided, + # OldAPIStack + lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, + **kwargs, + ) -> "PPOConfig": + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + if lambda_ is not NotProvided: + self.lambda_ = lambda_ + if use_kl_loss is not NotProvided: + self.use_kl_loss = use_kl_loss + if kl_coeff is not NotProvided: + self.kl_coeff = kl_coeff + if kl_target is not NotProvided: + self.kl_target = kl_target + if entropy_coeff is not NotProvided: + self.entropy_coeff = entropy_coeff + if clip_param is not NotProvided: + self.clip_param = clip_param + if grad_clip is not NotProvided: + self.grad_clip = grad_clip + + # TODO (sven): Remove these once new API stack is only option for PPO. + if lr_schedule is not NotProvided: + self.lr_schedule = lr_schedule + if entropy_coeff_schedule is not NotProvided: + self.entropy_coeff_schedule = entropy_coeff_schedule + + return self + + @override(AlgorithmConfig) + def validate(self) -> None: + # Call super's validation method. + super().validate() + + # Synchronous sampling, on-policy/PPO algos -> Check mismatches between + # `rollout_fragment_length` and `train_batch_size_per_learner` to avoid user + # confusion. + # TODO (sven): Make rollout_fragment_length a property and create a private + # attribute to store (possibly) user provided value (or "auto") in. Deprecate + # `self.get_rollout_fragment_length()`. + self.validate_train_batch_size_vs_rollout_fragment_length() + + # SGD minibatch size must be smaller than train_batch_size (b/c + # we subsample a batch of `minibatch_size` from the train-batch for + # each `num_epochs`). + if ( + not self.enable_rl_module_and_learner + and self.minibatch_size > self.train_batch_size + ): + self._value_error( + f"`minibatch_size` ({self.minibatch_size}) must be <= " + f"`train_batch_size` ({self.train_batch_size}). 
In PPO, the train batch" + f" will be split into {self.minibatch_size} chunks, each of which " + f"is iterated over (used for updating the policy) {self.num_epochs} " + "times." + ) + elif self.enable_rl_module_and_learner: + mbs = self.minibatch_size + tbs = self.train_batch_size_per_learner or self.train_batch_size + if isinstance(mbs, int) and isinstance(tbs, int) and mbs > tbs: + self._value_error( + f"`minibatch_size` ({mbs}) must be <= " + f"`train_batch_size_per_learner` ({tbs}). In PPO, the train batch" + f" will be split into {mbs} chunks, each of which is iterated over " + f"(used for updating the policy) {self.num_epochs} times." + ) + + # Episodes may only be truncated (and passed into PPO's + # `postprocessing_fn`), iff generalized advantage estimation is used + # (value function estimate at end of truncated episode to estimate + # remaining value). + if ( + not self.in_evaluation + and self.batch_mode == "truncate_episodes" + and not self.use_gae + ): + self._value_error( + "Episode truncation is not supported without a value " + "function (to estimate the return at the end of the truncated" + " trajectory). Consider setting " + "batch_mode=complete_episodes." + ) + + # New API stack checks. + if self.enable_rl_module_and_learner: + # `lr_schedule` checking. + if self.lr_schedule is not None: + self._value_error( + "`lr_schedule` is deprecated and must be None! Use the " + "`lr` setting to setup a schedule." + ) + if self.entropy_coeff_schedule is not None: + self._value_error( + "`entropy_coeff_schedule` is deprecated and must be None! Use the " + "`entropy_coeff` setting to setup a schedule." + ) + Scheduler.validate( + fixed_value_or_schedule=self.entropy_coeff, + setting_name="entropy_coeff", + description="entropy coefficient", + ) + if isinstance(self.entropy_coeff, float) and self.entropy_coeff < 0.0: + self._value_error("`entropy_coeff` must be >= 0.0") + + @property + @override(AlgorithmConfig) + def _model_config_auto_includes(self) -> Dict[str, Any]: + return super()._model_config_auto_includes | {} \ No newline at end of file diff --git a/rllib/examples/algorithms/mappo/mappo_catalog.py b/rllib/examples/algorithms/mappo/mappo_catalog.py new file mode 100644 index 000000000000..216b0fe04968 --- /dev/null +++ b/rllib/examples/algorithms/mappo/mappo_catalog.py @@ -0,0 +1,137 @@ +# @title MAPPOCatalog +# __sphinx_doc_begin__ +import gymnasium as gym + +from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.models.configs import ( + ActorCriticEncoderConfig, + MLPHeadConfig, + FreeLogStdMLPHeadConfig, +) +from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model +from ray.rllib.utils import override +from ray.rllib.utils.annotations import OverrideToImplementCustomLogic + + +from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian + + +class MAPPOCatalog(Catalog): + """The Catalog class used to build models for MAPPO. + + MAPPOCatalog provides the following models: + - ActorCriticEncoder: The encoder used to encode the observations. + - Pi Head: The head used to compute the policy logits. + + The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs + for the policy and value function. See implementations of DefaultPPORLModule for + more details. + + Any custom ActorCriticEncoder can be built by overriding the + build_actor_critic_encoder() method. 
Alternatively, the ActorCriticEncoderConfig + at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom + ActorCriticEncoder during RLModule runtime. + + Any custom head can be built by overriding the build_pi_head() and build_vf_head() + methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to + build custom heads during RLModule runtime. + + Any module built for exploration or inference is built with the flag + `ìnference_only=True` and does not contain a value network. This flag can be set + in the `SingleAgentModuleSpec` through the `inference_only` boolean flag. + In case that the actor-critic-encoder is not shared between the policy and value + function, the inference-only module will contain only the actor encoder network. + """ + + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + model_config_dict: dict, + ): + """Initializes the PPOCatalog. + + Args: + observation_space: The observation space of the Encoder. + action_space: The action space for the Pi Head. + model_config_dict: The model config to use. + """ + super().__init__( + observation_space=observation_space, + action_space=action_space, + model_config_dict=model_config_dict, + ) + # + self.encoder_config = ActorCriticEncoderConfig( + base_encoder_config=self._encoder_config, + shared=True # Since we don't want to instantiate an extra network + ) + + self.pi_and_vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] + self.pi_and_vf_head_activation = self._model_config_dict[ + "head_fcnet_activation" + ] + + # We don't have the exact (framework specific) action dist class yet and thus + # cannot determine the exact number of output nodes (action space) required. + # -> Build pi config only in the `self.build_pi_head` method. + self.pi_head_config = None + + @override(Catalog) + def build_encoder(self, framework: str) -> Encoder: + """Builds the encoder. + + Since PPO uses an ActorCriticEncoder, this method should not be implemented. + """ + return self.encoder_config.build(framework=framework) + + @OverrideToImplementCustomLogic + def build_pi_head(self, framework: str) -> Model: + """Builds the policy head. + + The default behavior is to build the head from the pi_head_config. + This can be overridden to build a custom policy head as a means of configuring + the behavior of a PPORLModule implementation. + + Args: + framework: The framework to use. Either "torch" or "tf2". + + Returns: + The policy head. + """ + # Get action_distribution_cls to find out about the output dimension for pi_head + action_distribution_cls = self.get_action_dist_cls(framework=framework) + if self._model_config_dict["free_log_std"]: + _check_if_diag_gaussian( + action_distribution_cls=action_distribution_cls, framework=framework + ) + is_diag_gaussian = True + else: + is_diag_gaussian = _check_if_diag_gaussian( + action_distribution_cls=action_distribution_cls, + framework=framework, + no_error=True, + ) + required_output_dim = action_distribution_cls.required_input_dim( + space=self.action_space, model_config=self._model_config_dict + ) + # Now that we have the action dist class and number of outputs, we can define + # our pi-config and build the pi head. 
+ pi_head_config_class = ( + FreeLogStdMLPHeadConfig + if self._model_config_dict["free_log_std"] + else MLPHeadConfig + ) + self.pi_head_config = pi_head_config_class( + input_dims=self.latent_dims, + hidden_layer_dims=self.pi_and_vf_head_hiddens, + hidden_layer_activation=self.pi_and_vf_head_activation, + output_layer_dim=required_output_dim, + output_layer_activation="linear", + clip_log_std=is_diag_gaussian, + log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20), + ) + + return self.pi_head_config.build(framework=framework) + +# __sphinx_doc_end__ \ No newline at end of file diff --git a/rllib/examples/algorithms/mappo/mappo_learner.py b/rllib/examples/algorithms/mappo/mappo_learner.py new file mode 100644 index 000000000000..547d2ca5ffaf --- /dev/null +++ b/rllib/examples/algorithms/mappo/mappo_learner.py @@ -0,0 +1,149 @@ +# @title MAPPOLearner +import abc +from typing import Any, Dict + +from ray.rllib.algorithms.ppo.ppo import ( + LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY, + LEARNER_RESULTS_KL_KEY, + PPOConfig, +) +from ray.rllib.connectors.learner import ( + AddOneTsToEpisodesAndTruncate, + GeneralAdvantageEstimation, +) +from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict +from ray.rllib.utils.metrics import ( + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_MODULE_STEPS_TRAINED, +) +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.schedules.scheduler import Scheduler +from ray.rllib.utils.typing import ModuleID, TensorType + +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import MAPPOGAEConnector, SHARED_CRITIC_ID + +class MAPPOLearner(Learner): + + # Deal with GAE somehow. Maybe skip it and move the logic over here, into the value loss calculation method. Saves us a VF pass, too. + @override(Learner) + def build(self) -> None: + super().build() # We call Learner's build function, not PPOLearner's + + # Dict mapping module IDs to the respective entropy Scheduler instance. + self.entropy_coeff_schedulers_per_module: Dict[ + ModuleID, Scheduler + ] = LambdaDefaultDict( + lambda module_id: Scheduler( + fixed_value_or_schedule=( + self.config.get_config_for_module(module_id).entropy_coeff + ), + framework=self.framework, + device=self._device, + ) + ) + + # Set up KL coefficient variables (per module). + # Note that the KL coeff is not controlled by a Scheduler, but seeks + # to stay close to a given kl_target value. + self.curr_kl_coeffs_per_module: Dict[ModuleID, TensorType] = LambdaDefaultDict( + lambda module_id: self._get_tensor_variable( + self.config.get_config_for_module(module_id).kl_coeff + ) + ) + + # Extend all episodes by one artificial timestep to allow the value function net + # to compute the bootstrap values (and add a mask to the batch to know, which + # slots to mask out). + if ( + self._learner_connector is not None + and self.config.add_default_connectors_to_learner_pipeline + ): + # At the end of the pipeline (when the batch is already completed), add the + # GAE connector, which performs a vf forward pass, then computes the GAE + # computations, and puts the results of this (advantages, value targets) + # directly back in the batch. This is then the batch used for + # `forward_train` and `compute_losses`. 
+ self._learner_connector.append( + MAPPOGAEConnector( + gamma=self.config.gamma, + lambda_=self.config.lambda_, + ) + ) + + @override(Learner) + def remove_module(self, module_id: ModuleID, **kwargs): + marl_spec = super().remove_module(module_id, **kwargs) + if (module_id != SHARED_CRITIC_ID): + self.entropy_coeff_schedulers_per_module.pop(module_id, None) + self.curr_kl_coeffs_per_module.pop(module_id, None) + return marl_spec + + @OverrideToImplementCustomLogic_CallToSuperRecommended + @override(Learner) + def after_gradient_based_update( + self, + *, + timesteps: Dict[str, Any], + ) -> None: + super().after_gradient_based_update(timesteps=timesteps) + + for module_id, module in self.module._rl_modules.items(): + if (module_id == SHARED_CRITIC_ID): + continue # Policy terms irrelevant to shared critic. + config = self.config.get_config_for_module(module_id) + # Update entropy coefficient via our Scheduler. + new_entropy_coeff = self.entropy_coeff_schedulers_per_module[ + module_id + ].update(timestep=timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0)) + self.metrics.log_value( + (module_id, LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY), + new_entropy_coeff, + window=1, + ) + if ( + config.use_kl_loss + and self.metrics.peek((module_id, NUM_MODULE_STEPS_TRAINED), default=0) + > 0 + and (module_id, LEARNER_RESULTS_KL_KEY) in self.metrics + ): + kl_loss = convert_to_numpy( + self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY)) + ) + self._update_module_kl_coeff( + module_id=module_id, + config=config, + kl_loss=kl_loss, + ) + + @classmethod + @override(Learner) + def rl_module_required_apis(cls) -> list[type]: + # We no longer require value functions for modules, since there's a central critic + return [] + + @abc.abstractmethod + def _update_module_kl_coeff( + self, + *, + module_id: ModuleID, + config: PPOConfig, + kl_loss: float, + ) -> None: + """Dynamically update the KL loss coefficients of each module. + + The update is completed using the mean KL divergence between the action + distributions current policy and old policy of each module. That action + distribution is computed during the most recent update/call to `compute_loss`. + + Args: + module_id: The module whose KL loss coefficient to update. + config: The AlgorithmConfig specific to the given `module_id`. + kl_loss: The mean KL loss of the module, computed inside + `compute_loss_for_module()`. + """ \ No newline at end of file diff --git a/rllib/examples/algorithms/mappo/shared_critic_catalog.py b/rllib/examples/algorithms/mappo/shared_critic_catalog.py new file mode 100644 index 000000000000..48a9b1c9a46e --- /dev/null +++ b/rllib/examples/algorithms/mappo/shared_critic_catalog.py @@ -0,0 +1,78 @@ +import gymnasium as gym + +from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.models.configs import ( + ActorCriticEncoderConfig, + MLPHeadConfig, + FreeLogStdMLPHeadConfig, +) +from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model +from ray.rllib.utils import override +from ray.rllib.utils.annotations import OverrideToImplementCustomLogic + +from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian + +class SharedCriticCatalog(Catalog): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, # TODO: Remove? + model_config_dict: dict, + ): + """Initializes the PPOCatalog. + + Args: + observation_space: The observation space of the Encoder. + action_space: The action space for the Pi Head. + model_config_dict: The model config to use. 
+ """ + super().__init__( + observation_space=observation_space, + action_space=action_space, # We shouldn't need to provide this. I should reconfigure at some point. + model_config_dict=model_config_dict, + ) + # We only want one encoder, so we use the base encoder config. + self._encoder_config.shared = True + self.encoder_config = ActorCriticEncoderConfig( + base_encoder_config=self._encoder_config, + shared=True, # Because we only want one network + ) + self.vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] + self.vf_head_activation = self._model_config_dict[ + "head_fcnet_activation" + ] + + self.vf_head_config = MLPHeadConfig( + input_dims=self.latent_dims, + hidden_layer_dims=self.vf_head_hiddens, + hidden_layer_activation=self.vf_head_activation, + output_layer_activation="linear", + output_layer_dim=1, + ) + + @override(Catalog) + def build_encoder(self, framework: str) -> Encoder: + """Builds the encoder. + + Since PPO uses an ActorCriticEncoder, this method should not be implemented. + """ + return self.encoder_config.build(framework=framework) + + @OverrideToImplementCustomLogic + def build_vf_head(self, framework: str) -> Model: + """Builds the value function head. + + The default behavior is to build the head from the vf_head_config. + This can be overridden to build a custom value function head as a means of + configuring the behavior of a PPORLModule implementation. + + Args: + framework: The framework to use. Either "torch" or "tf2". + + Returns: + The value function head. + """ + return self.vf_head_config.build(framework=framework) + + +# __sphinx_doc_end__ \ No newline at end of file diff --git a/rllib/examples/algorithms/mappo/shared_critic_rl_module.py b/rllib/examples/algorithms/mappo/shared_critic_rl_module.py new file mode 100644 index 000000000000..237af7073ade --- /dev/null +++ b/rllib/examples/algorithms/mappo/shared_critic_rl_module.py @@ -0,0 +1,48 @@ +import abc +from typing import List + +from ray.rllib.core.models.configs import RecurrentEncoderConfig +from ray.rllib.core.rl_module.apis import ValueFunctionAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import ( + override, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class SharedCriticRLModule(RLModule, ValueFunctionAPI, abc.ABC): + """Default RLModule used by PPO, if user does not specify a custom RLModule. + + Users who want to train their RLModules with PPO may implement any RLModule + (or TorchRLModule) subclass as long as the custom class also implements the + `ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py) + """ + + @override(RLModule) + def setup(self): + try: + # __sphinx_doc_begin__ + # If we have a stateful model, states for the critic need to be collected + # during sampling and `inference-only` needs to be `False`. Note, at this + # point the encoder is not built, yet and therefore `is_stateful()` does + # not work. + is_stateful = isinstance( + self.catalog.encoder_config.base_encoder_config, + RecurrentEncoderConfig, + ) + # Build models from catalog. 
+ self.encoder = self.catalog.build_encoder(framework=self.framework) + self.vf = self.catalog.build_vf_head(framework=self.framework) + except Exception as e: + print(e) + raise e + # __sphinx_doc_end__ + + @override(RLModule) + def get_initial_state(self) -> dict: + if hasattr(self.encoder, "get_initial_state"): + return self.encoder.get_initial_state() + else: + return {} \ No newline at end of file diff --git a/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py b/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py new file mode 100644 index 000000000000..1317bcbf00cf --- /dev/null +++ b/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py @@ -0,0 +1,51 @@ +from typing import Any, Dict, Optional + +from ray.rllib.algorithms.ppo.default_ppo_rl_module import DefaultPPORLModule +from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog +from ray.rllib.core.columns import Columns +from ray.rllib.core.models.base import ACTOR, CRITIC, ENCODER_OUT +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType +from ray.util.annotations import DeveloperAPI + +from ray.rllib.examples.algorithms.mappo.mappo_catalog import MAPPOCatalog +from ray.rllib.examples.algorithms.mappo.default_mappo_rl_module import DefaultMAPPORLModule +from ray.rllib.examples.algorithms.mappo.shared_critic_catalog import SharedCriticCatalog + +torch, nn = try_import_torch() + +@DeveloperAPI +class DefaultMAPPOTorchRLModule(TorchRLModule, DefaultMAPPORLModule): + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = MAPPOCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(RLModule) + def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Default forward pass (used for inference and exploration).""" + output = {} + # Encoder forward pass. + encoder_outs = self.encoder(batch) + # Stateful encoder? + if Columns.STATE_OUT in encoder_outs: + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + # Pi head. + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) + return output + + @override(RLModule) + def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Train forward pass (keep embeddings for possible shared value func. 
call).""" + output = {} + encoder_outs = self.encoder(batch) + output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC] + if Columns.STATE_OUT in encoder_outs: + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) + return output \ No newline at end of file diff --git a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py new file mode 100644 index 000000000000..1a9e4fc85e5b --- /dev/null +++ b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py @@ -0,0 +1,234 @@ +# @title MAPPOTorchLearner +import logging +from typing import Any, Dict +from collections.abc import Callable + +import numpy as np + +from ray.rllib.algorithms.ppo.ppo import ( + LEARNER_RESULTS_KL_KEY, + LEARNER_RESULTS_CURR_KL_COEFF_KEY, + LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY, + LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY, + PPOConfig, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import Learner, POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY +from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_utils import explained_variance +from ray.rllib.utils.typing import ModuleID, TensorType + +from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI + +from ray.rllib.examples.algorithms.mappo.mappo_learner import MAPPOLearner +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import SHARED_CRITIC_ID + +torch, nn = try_import_torch() + +logger = logging.getLogger(__name__) + + +class MAPPOTorchLearner(MAPPOLearner, TorchLearner): + def get_pmm( + self, + batch: Dict[str, Any] + ) -> Callable: + """ Gets the possibly_masked_mean function """ + if Columns.LOSS_MASK in batch: + mask = batch[Columns.LOSS_MASK] + num_valid = torch.sum(mask) + def possibly_masked_mean(data_): + return torch.sum(data_[mask]) / num_valid + else: + possibly_masked_mean = torch.mean + return possibly_masked_mean + + """ + Implements MAPPO in Torch, on top of a MAPPOLearner. + """ + def compute_loss_for_critic( + self, + batch: Dict[str, Any] + ): + """ + Computes loss for critic, and returns a list of advantages and rewards for the target batch. + """ + possibly_masked_mean = self.get_pmm(batch) + module = self.module[SHARED_CRITIC_ID].unwrapped() + vf_preds = module.compute_values(batch) + vf_targets = batch[Postprocessing.VALUE_TARGETS] + # Compute a value function loss. + vf_loss = torch.pow(vf_preds - vf_targets, 2.0) + vf_loss_clipped = torch.clamp(vf_loss, 0, self.config.vf_clip_param) + mean_vf_loss = possibly_masked_mean(vf_loss_clipped) + mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) + # record metrics + self.metrics.log_dict( + { + VF_LOSS_KEY: mean_vf_loss, + LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY: mean_vf_unclipped_loss, + LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance( + vf_targets, vf_preds + ), + }, + key=SHARED_CRITIC_ID, + window=1, + ) + return mean_vf_loss + + @override(Learner) + def compute_losses( + self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Args: + fwd_out: Output from a call to the `forward_train()` method of the + underlying MultiRLModule (`self.module`) during training + (`self.update()`). 
+ batch: The train batch that was used to compute `fwd_out`. + + Returns: + A dictionary mapping module IDs to individual loss terms. + """ + loss_per_module = {SHARED_CRITIC_ID: 0} + # Calculate loss for agent policies + for module_id in fwd_out: + if (module_id == SHARED_CRITIC_ID): # Computed for each module + continue + # + module = self.module[module_id].unwrapped() + if isinstance(module, SelfSupervisedLossAPI): + # For e.g. enabling intrinsic curiosity modules. + loss = module.compute_self_supervised_loss( + learner=self, + module_id=module_id, + config=self.config.get_config_for_module(module_id), + batch=module_batch, + fwd_out=module_fwd_out, + ) + else: + module_batch = batch[module_id] + module_fwd_out = fwd_out[module_id] + # For every module we're going to touch, sans the critic + loss = self.compute_loss_for_module( + module_id=module_id, + config=self.config.get_config_for_module(module_id), + batch=module_batch, + fwd_out=module_fwd_out, + ) + # Optimize the critic + loss_per_module[SHARED_CRITIC_ID] += self.compute_loss_for_critic(module_batch) + loss_per_module[module_id] = loss + return loss_per_module + + # We strip out the value function optimization here. + @override(TorchLearner) + def compute_loss_for_module( + self, + *, + module_id: ModuleID, + config: PPOConfig, + batch: Dict[str, Any], + fwd_out: Dict[str, TensorType], + ) -> TensorType: + module = self.module[module_id].unwrapped() + possibly_masked_mean = self.get_pmm(batch) + # Possibly apply masking to some sub loss terms and to the total loss term + # at the end. Masking could be used for RNN-based model (zero padded `batch`) + # and for PPO's batched value function (and bootstrap value) computations, + # for which we add an (artificial) timestep to each episode to + # simplify the actual computation. + action_dist_class_train = module.get_train_action_dist_cls() + action_dist_class_exploration = module.get_exploration_action_dist_cls() + + curr_action_dist = action_dist_class_train.from_logits( + fwd_out[Columns.ACTION_DIST_INPUTS] + ) + prev_action_dist = action_dist_class_exploration.from_logits( + batch[Columns.ACTION_DIST_INPUTS] + ) + + logp_ratio = torch.exp( + curr_action_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP] + ) + + # Only calculate kl loss if necessary (kl-coeff > 0.0). + if config.use_kl_loss: + action_kl = prev_action_dist.kl(curr_action_dist) + mean_kl_loss = possibly_masked_mean(action_kl) + else: + mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device) + + curr_entropy = curr_action_dist.entropy() + mean_entropy = possibly_masked_mean(curr_entropy) + + surrogate_loss = torch.min( + batch[Postprocessing.ADVANTAGES] * logp_ratio, + batch[Postprocessing.ADVANTAGES] + * torch.clamp(logp_ratio, 1 - config.clip_param, 1 + config.clip_param), + ) + # Remove critic loss from per-module computation + total_loss = possibly_masked_mean( + -surrogate_loss + - ( + self.entropy_coeff_schedulers_per_module[module_id].get_current_value() + * curr_entropy + ) + ) + + # Add mean_kl_loss (already processed through `possibly_masked_mean`), + # if necessary. + if config.use_kl_loss: + total_loss += self.curr_kl_coeffs_per_module[module_id] * mean_kl_loss + + # Log important loss stats. + self.metrics.log_dict( + { + POLICY_LOSS_KEY: -possibly_masked_mean(surrogate_loss), + ENTROPY_KEY: mean_entropy, + LEARNER_RESULTS_KL_KEY: mean_kl_loss, + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + # Return the total loss. 
+ return total_loss + + # Same as PPOTorchLearner + @override(MAPPOLearner) + def _update_module_kl_coeff( + self, + *, + module_id: ModuleID, + config: PPOConfig, + kl_loss: float, + ) -> None: + if np.isnan(kl_loss): + logger.warning( + f"KL divergence for Module {module_id} is non-finite, this " + "will likely destabilize your model and the training " + "process. Action(s) in a specific state have near-zero " + "probability. This can happen naturally in deterministic " + "environments where the optimal policy has zero mass for a " + "specific action. To fix this issue, consider setting " + "`kl_coeff` to 0.0 or increasing `entropy_coeff` in your " + "config." + ) + + # Update the KL coefficient. + curr_var = self.curr_kl_coeffs_per_module[module_id] + if kl_loss > 2.0 * config.kl_target: + # TODO (Kourosh) why not 2? + curr_var.data *= 1.5 + elif kl_loss < 0.5 * config.kl_target: + curr_var.data *= 0.5 + + # Log the updated KL-coeff value. + self.metrics.log_value( + (module_id, LEARNER_RESULTS_CURR_KL_COEFF_KEY), + curr_var.item(), + window=1, + ) \ No newline at end of file diff --git a/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py new file mode 100644 index 000000000000..e35c00b6f342 --- /dev/null +++ b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py @@ -0,0 +1,36 @@ +import typing +from typing import Any, Optional + +from ray.rllib.core.models.base import ENCODER_OUT +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType +from ray.util.annotations import DeveloperAPI + +torch, nn = try_import_torch() + +from ray.rllib.examples.algorithms.mappo.shared_critic_rl_module import SharedCriticRLModule +from ray.rllib.examples.algorithms.mappo.shared_critic_catalog import SharedCriticCatalog + +@DeveloperAPI +class SharedCriticTorchRLModule(TorchRLModule, SharedCriticRLModule): + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = SharedCriticCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(ValueFunctionAPI) + def compute_values( + self, + batch: typing.Dict[str, Any], + embeddings: Optional[Any] = None, + ) -> TensorType: + if embeddings is None: + embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC] + # Value head. + vf_out = self.vf(embeddings) + # Squeeze out last dimension (single node value head). + return vf_out.squeeze(-1) \ No newline at end of file diff --git a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py index e2c8bb9a4ffb..2bebf4b05bdc 100644 --- a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py +++ b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py @@ -1,7 +1,115 @@ -msg = """ -This script is NOT yet ready, but will be available soon at this location. It will -feature a MultiRLModule with one shared value function and n policy heads for -cooperative multi-agent learning. +"""Runs the PettingZoo Waterworld env in RLlib using a shared critic. + +See: https://pettingzoo.farama.org/environments/sisl/waterworld/ +for more details on the environment. 
+ + +How to run this script +---------------------- +`python [script file name].py --num-agents=2` + +Control the number of agents and policies (RLModules) via --num-agents and +--num-policies. + +This works with hundreds of agents and policies, but note that initializing +many policies might take some time. + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +The above options can reach a combined reward of _ or more after about _ env timesteps. Keep in mind, though, that due to the separate learned policies in general, one agent's gain (in per-agent reward) might cause the other agent's reward to decrease at the same time. However, over time, both agents should simply improve. For reasons similar to those described in pettingzoo_parameter_sharing.py, learning may take slightly longer than in fully-independent settings, as agents are less inclined to specialize and thereby balance out one anothers' mistakes. + ++---------------------+------------+--------------------+--------+------------------+ +| Trial name | status | loc | iter | total time (s) | +|---------------------+------------+--------------------+--------+------------------+ +| PPO_env_c90f4_00000 | TERMINATED | 172.29.87.208:6322 | 101 | 1269.24 | ++---------------------+------------+--------------------+--------+------------------+ + +--------+-------------------+--------------------+--------------------+ + ts | combined return | return pursuer_0 | return pursuer_1 | +--------+-------------------+--------------------+--------------------| + 404000 | 1.31496 | 48.0908 | -46.7758 | +--------+-------------------+--------------------+--------------------+ + +Note that the two agents (`pursuer_0` and `pursuer_1`) are optimized on the exact same +objective and thus differences in the rewards can be attributed to weight initialization +(and sampling randomness) only. """ -raise NotImplementedError(msg) +from pettingzoo.sisl import waterworld_v4 + +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import get_trainable_cls, register_env + +from ray.rllib.examples.algorithms.mappo.mappo import MAPPOConfig +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import SHARED_CRITIC_ID +from ray.rllib.examples.algorithms.mappo.torch.shared_critic_torch_rl_module import SharedCriticTorchRLModule + + +parser = add_rllib_example_script_args( + default_iters=200, + default_timesteps=1000000, + default_reward=0.0, +) + + +if __name__ == "__main__": + args = parser.parse_args() + + assert args.num_agents > 0, "Must set --num-agents > 0 when running this script!" + + # Here, we use the "Agent Environment Cycle" (AEC) PettingZoo environment type. 
+ # For a "Parallel" environment example, see the rock paper scissors examples + # in this same repository folder. + get_env = lambda _: PettingZooEnv(waterworld_v4.env()) + register_env("env", get_env) + + # Policies are called just like the agents (exact 1:1 mapping). + policies = [f"pursuer_{i}" for i in range(args.num_agents)] + + # An agent for each of our policies, and a single shared critic + env_instantiated = get_env({}) # neccessary for non-agent modules + specs = {p: RLModuleSpec() for p in policies} + specs[SHARED_CRITIC_ID] = RLModuleSpec( + module_class=SharedCriticTorchRLModule, + observation_space=env_instantiated.observation_spaces[policies[0]], + action_space=env_instantiated.action_spaces[policies[0]], + learner_only=True, # Only build on learner + model_config={}, + ) + + base_config = ( + MAPPOConfig() + .environment("env") + .multi_agent( + policies=policies, + # Exact 1:1 mapping from AgentID to ModuleID. + policy_mapping_fn=(lambda aid, *args, **kwargs: aid), + ) + .training( + vf_loss_coeff=0.005, + ) + .rl_module( + rl_module_spec=MultiRLModuleSpec( + rl_module_specs=specs, + ), + ) + ) + + run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/multi_agent/shared_encoder_cartpole.py b/rllib/examples/multi_agent/shared_encoder_cartpole.py index caea04adef8c..75044a4c4ad6 100644 --- a/rllib/examples/multi_agent/shared_encoder_cartpole.py +++ b/rllib/examples/multi_agent/shared_encoder_cartpole.py @@ -20,7 +20,7 @@ Results to expect ----------------- -Under the shared encoder architecture, the target reward of 700 will typically be reached well before 100,000 iterations. A trial concludes as below: +Under the shared encoder architecture, the target reward of 600 will typically be reached well before 100,000 iterations. A trial concludes as below: +---------------------+------------+-----------------+--------+------------------+-------+-------------------+-------------+-------------+ | Trial name | status | loc | iter | total time (s) | ts | combined return | return p1 | return p0 | From 1c842c3a60018e1502aa9d321ed9c4aa9c3ae263 Mon Sep 17 00:00:00 2001 From: Matthew Date: Sat, 6 Sep 2025 19:45:17 -0500 Subject: [PATCH 02/13] Linted and cleaned up the code. Signed-off-by: Matthew --- rllib/BUILD | 17 ++- .../general_advantage_estimation.py | 10 +- .../mappo/default_mappo_rl_module.py | 50 ++++----- rllib/examples/algorithms/mappo/mappo.py | 91 +++------------ .../algorithms/mappo/mappo_catalog.py | 54 +++------ .../algorithms/mappo/mappo_learner.py | 32 ++---- .../algorithms/mappo/shared_critic_catalog.py | 29 ++--- .../mappo/shared_critic_rl_module.py | 45 ++++---- .../torch/default_mappo_torch_rl_module.py | 24 ++-- .../mappo/torch/mappo_torch_learner.py | 105 +++++++++--------- .../torch/shared_critic_torch_rl_module.py | 13 ++- .../pettingzoo_shared_value_function.py | 26 +++-- 12 files changed, 199 insertions(+), 297 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 4dbfccb6c865..b6699b84b5e2 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -5004,15 +5004,14 @@ py_test( ], ) -# TODO (sven): Activate this test once this script is ready. 
-# py_test( -# name = "examples/multi_agent/pettingzoo_shared_value_function", -# main = "examples/multi_agent/pettingzoo_shared_value_function.py", -# tags = ["team:rllib", "exclusive", "examples"], -# size = "large", -# srcs = ["examples/multi_agent/pettingzoo_shared_value_function.py"], -# args = ["--num-agents=2", "--as-test", "--framework=torch", "--stop-reward=-100.0", "--num-cpus=4"], -# ) +py_test( + name = "examples/multi_agent/pettingzoo_shared_value_function", + main = "examples/multi_agent/pettingzoo_shared_value_function.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "large", + srcs = ["examples/multi_agent/pettingzoo_shared_value_function.py"], + args = ["--num-agents=2", "--as-test", "--framework=torch", "--stop-reward=-100.0", "--num-cpus=4"], +) py_test( name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint", diff --git a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py index 6f5353b12b8e..771c342f5ec0 100644 --- a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py +++ b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py @@ -1,15 +1,12 @@ from typing import Any, List, Dict import numpy as np -import torch -from collections import defaultdict from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule -from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.utils.annotations import override from ray.rllib.utils.numpy import convert_to_numpy @@ -65,7 +62,8 @@ def __call__( vf_preds = rl_module.foreach_module( func=lambda mid, module: ( rl_module[SHARED_CRITIC_ID].compute_values(batch[mid]) - if (mid in batch) and (Columns.OBS in batch[mid]) + if (mid in batch) + and (Columns.OBS in batch[mid]) and (not isinstance(module, SelfSupervisedLossAPI)) else None ), @@ -165,4 +163,4 @@ def __call__( for mid, module_batch in tensor_results.items(): batch[mid].update(module_batch) - return batch \ No newline at end of file + return batch diff --git a/rllib/examples/algorithms/mappo/default_mappo_rl_module.py b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py index 137a8284b1de..8288a48a4674 100644 --- a/rllib/examples/algorithms/mappo/default_mappo_rl_module.py +++ b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py @@ -13,38 +13,36 @@ @DeveloperAPI class DefaultMAPPORLModule(RLModule, InferenceOnlyAPI, abc.ABC): - """Default RLModule used by PPO, if user does not specify a custom RLModule. + """Default RLModule used by MAPPO, if user does not specify a custom RLModule. - Users who want to train their RLModules with PPO may implement any RLModule - (or TorchRLModule) subclass as long as the custom class also implements the - `ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py) + Users who want to train their RLModules with MAPPO may implement any RLModule (or TorchRLModule) subclass. 
""" @override(RLModule) def setup(self): try: - # __sphinx_doc_begin__ - # If we have a stateful model, states for the critic need to be collected - # during sampling and `inference-only` needs to be `False`. Note, at this - # point the encoder is not built, yet and therefore `is_stateful()` does - # not work. - is_stateful = isinstance( - self.catalog.encoder_config.base_encoder_config, - RecurrentEncoderConfig, - ) - if is_stateful: - self.inference_only = False - # If this is an `inference_only` Module, we'll have to pass this information - # to the encoder config as well. - if self.inference_only and self.framework == "torch": - self.catalog.encoder_config.inference_only = True - # Build models from catalog. - self.encoder = self.catalog.build_encoder(framework=self.framework) - self.pi = self.catalog.build_pi_head(framework=self.framework) + # __sphinx_doc_begin__ + # If we have a stateful model, states for the critic need to be collected + # during sampling and `inference-only` needs to be `False`. Note, at this + # point the encoder is not built, yet and therefore `is_stateful()` does + # not work. + is_stateful = isinstance( + self.catalog.encoder_config, + RecurrentEncoderConfig, + ) + if is_stateful: + self.inference_only = False + # If this is an `inference_only` Module, we'll have to pass this information + # to the encoder config as well. + if self.inference_only and self.framework == "torch": + self.catalog.encoder_config.inference_only = True + # Build models from catalog. + self.encoder = self.catalog.build_encoder(framework=self.framework) + self.pi = self.catalog.build_pi_head(framework=self.framework) except Exception as e: - print("Error in DefaultMAPPORLModule setup") - print(e) - raise e + print("Error in DefaultMAPPORLModule setup") + print(e) + raise e # __sphinx_doc_end__ @override(RLModule) @@ -58,4 +56,4 @@ def get_initial_state(self) -> dict: @override(InferenceOnlyAPI) def get_non_inference_attributes(self) -> List[str]: """Return attributes, which are NOT inference-only (only used for training).""" - return [] \ No newline at end of file + return [] diff --git a/rllib/examples/algorithms/mappo/mappo.py b/rllib/examples/algorithms/mappo/mappo.py index 0d4d9bd7ac73..ec5ea1d2f11a 100644 --- a/rllib/examples/algorithms/mappo/mappo.py +++ b/rllib/examples/algorithms/mappo/mappo.py @@ -1,45 +1,21 @@ import logging from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING -from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.core.rl_module.rl_module import RLModuleSpec -from ray.rllib.execution.rollout_ops import ( - standardize_fields, - synchronous_parallel_sample, -) -from ray.rllib.execution.train_ops import ( - train_one_step, - multi_gpu_train_one_step, -) -from ray.rllib.policy.policy import Policy -from ray.rllib.utils.annotations import OldAPIStack, override -from ray.rllib.utils.deprecation import DEPRECATED_VALUE -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - ENV_RUNNER_SAMPLING_TIMER, - LEARNER_RESULTS, - LEARNER_UPDATE_TIMER, - NUM_AGENT_STEPS_SAMPLED, - NUM_ENV_STEPS_SAMPLED, - NUM_ENV_STEPS_SAMPLED_LIFETIME, - SYNCH_WORKER_WEIGHTS_TIMER, - SAMPLE_TIMER, - TIMERS, - ALL_MODULES, -) -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY -from ray.rllib.utils.schedules.scheduler import Scheduler -from ray.rllib.utils.typing import ResultDict -from ray.util.debug import log_once +from ray.rllib.utils.annotations import 
override -from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO +from ray.rllib.algorithms.ppo.ppo import PPO if TYPE_CHECKING: from ray.rllib.core.learner.learner import Learner -from ray.rllib.examples.algorithms.mappo.torch.mappo_torch_learner import MAPPOTorchLearner -from ray.rllib.examples.algorithms.mappo.torch.default_mappo_torch_rl_module import DefaultMAPPOTorchRLModule +from ray.rllib.examples.algorithms.mappo.torch.mappo_torch_learner import ( + MAPPOTorchLearner, +) +from ray.rllib.examples.algorithms.mappo.torch.default_mappo_torch_rl_module import ( + DefaultMAPPOTorchRLModule, +) logger = logging.getLogger(__name__) @@ -48,9 +24,8 @@ LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff" -class MAPPOConfig(PPOConfig): # AlgorithmConfig -> PPOConfig -> MAPPO - """Defines a configuration class from which a MAPPO Algorithm can be built. - """ +class MAPPOConfig(AlgorithmConfig): # AlgorithmConfig -> PPOConfig -> MAPPO + """Defines a configuration class from which a MAPPO Algorithm can be built.""" def __init__(self, algo_class=None): """Initializes a MAPPOConfig instance.""" @@ -72,7 +47,7 @@ def __init__(self, algo_class=None): self.rollout_fragment_length = "auto" self.train_batch_size = 4000 - # PPO specific settings: + # MAPPO specific settings: self.num_epochs = 30 self.minibatch_size = 128 self.shuffle_batch_per_epoch = True @@ -84,17 +59,11 @@ def __init__(self, algo_class=None): self.clip_param = 0.3 self.grad_clip = None - # Override some of AlgorithmConfig's default values with PPO-specific values. + # Override some of AlgorithmConfig's default values with MAPPO-specific values. self.num_env_runners = 2 # __sphinx_doc_end__ # fmt: on - self.entropy_coeff_schedule = None # OldAPIStack - self.lr_schedule = None # OldAPIStack - - # Deprecated keys. - self.sgd_minibatch_size = DEPRECATED_VALUE - @override(AlgorithmConfig) def get_default_rl_module_spec(self) -> RLModuleSpec: if self.framework_str == "torch": @@ -119,10 +88,8 @@ def training( entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, clip_param: Optional[float] = NotProvided, grad_clip: Optional[float] = NotProvided, - # OldAPIStack - lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, **kwargs, - ) -> "PPOConfig": + ) -> "MAPPOConfig": # Pass kwargs onto super's `training()` method. super().training(**kwargs) if lambda_ is not NotProvided: @@ -139,13 +106,6 @@ def training( self.clip_param = clip_param if grad_clip is not NotProvided: self.grad_clip = grad_clip - - # TODO (sven): Remove these once new API stack is only option for PPO. - if lr_schedule is not NotProvided: - self.lr_schedule = lr_schedule - if entropy_coeff_schedule is not NotProvided: - self.entropy_coeff_schedule = entropy_coeff_schedule - return self @override(AlgorithmConfig) @@ -170,7 +130,7 @@ def validate(self) -> None: ): self._value_error( f"`minibatch_size` ({self.minibatch_size}) must be <= " - f"`train_batch_size` ({self.train_batch_size}). In PPO, the train batch" + f"`train_batch_size` ({self.train_batch_size}). In MAPPO, the train batch" f" will be split into {self.minibatch_size} chunks, each of which " f"is iterated over (used for updating the policy) {self.num_epochs} " "times." @@ -181,7 +141,7 @@ def validate(self) -> None: if isinstance(mbs, int) and isinstance(tbs, int) and mbs > tbs: self._value_error( f"`minibatch_size` ({mbs}) must be <= " - f"`train_batch_size_per_learner` ({tbs}). In PPO, the train batch" + f"`train_batch_size_per_learner` ({tbs}). 
In MAPPO, the train batch" f" will be split into {mbs} chunks, each of which is iterated over " f"(used for updating the policy) {self.num_epochs} times." ) @@ -201,29 +161,10 @@ def validate(self) -> None: " trajectory). Consider setting " "batch_mode=complete_episodes." ) - - # New API stack checks. - if self.enable_rl_module_and_learner: - # `lr_schedule` checking. - if self.lr_schedule is not None: - self._value_error( - "`lr_schedule` is deprecated and must be None! Use the " - "`lr` setting to setup a schedule." - ) - if self.entropy_coeff_schedule is not None: - self._value_error( - "`entropy_coeff_schedule` is deprecated and must be None! Use the " - "`entropy_coeff` setting to setup a schedule." - ) - Scheduler.validate( - fixed_value_or_schedule=self.entropy_coeff, - setting_name="entropy_coeff", - description="entropy coefficient", - ) if isinstance(self.entropy_coeff, float) and self.entropy_coeff < 0.0: self._value_error("`entropy_coeff` must be >= 0.0") @property @override(AlgorithmConfig) def _model_config_auto_includes(self) -> Dict[str, Any]: - return super()._model_config_auto_includes | {} \ No newline at end of file + return super()._model_config_auto_includes | {} diff --git a/rllib/examples/algorithms/mappo/mappo_catalog.py b/rllib/examples/algorithms/mappo/mappo_catalog.py index 216b0fe04968..d078c2cb9df7 100644 --- a/rllib/examples/algorithms/mappo/mappo_catalog.py +++ b/rllib/examples/algorithms/mappo/mappo_catalog.py @@ -4,11 +4,10 @@ from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import ( - ActorCriticEncoderConfig, MLPHeadConfig, FreeLogStdMLPHeadConfig, ) -from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model +from ray.rllib.core.models.base import Encoder, Model from ray.rllib.utils import override from ray.rllib.utils.annotations import OverrideToImplementCustomLogic @@ -20,27 +19,14 @@ class MAPPOCatalog(Catalog): """The Catalog class used to build models for MAPPO. MAPPOCatalog provides the following models: - - ActorCriticEncoder: The encoder used to encode the observations. + - Encoder: The encoder used to encode the observations. - Pi Head: The head used to compute the policy logits. - The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs - for the policy and value function. See implementations of DefaultPPORLModule for - more details. + Any custom Encoder can be built by overriding the build_encoder() method. Alternatively, the EncoderConfig at MAPPOCatalog.encoder_config can be overridden to build a custom Encoder during RLModule runtime. - Any custom ActorCriticEncoder can be built by overriding the - build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig - at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom - ActorCriticEncoder during RLModule runtime. + Any custom head can be built by overriding the build_pi_head() method. Alternatively, the PiHeadConfig can be overridden to build a custom head during RLModule runtime. - Any custom head can be built by overriding the build_pi_head() and build_vf_head() - methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to - build custom heads during RLModule runtime. - - Any module built for exploration or inference is built with the flag - `ìnference_only=True` and does not contain a value network. This flag can be set - in the `SingleAgentModuleSpec` through the `inference_only` boolean flag. 
- In case that the actor-critic-encoder is not shared between the policy and value - function, the inference-only module will contain only the actor encoder network. + Any module built for exploration or inference is built with the flag `ìnference_only=True` and does not contain a value network. This flag can be set in the `SingleAgentModuleSpec` through the `inference_only` boolean flag. """ def __init__( @@ -49,7 +35,7 @@ def __init__( action_space: gym.Space, model_config_dict: dict, ): - """Initializes the PPOCatalog. + """Initializes the MAPPOCatalog. Args: observation_space: The observation space of the Encoder. @@ -62,16 +48,9 @@ def __init__( model_config_dict=model_config_dict, ) # - self.encoder_config = ActorCriticEncoderConfig( - base_encoder_config=self._encoder_config, - shared=True # Since we don't want to instantiate an extra network - ) - - self.pi_and_vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] - self.pi_and_vf_head_activation = self._model_config_dict[ - "head_fcnet_activation" - ] - + self.encoder_config = self._encoder_config + self.pi_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] + self.pi_head_activation = self._model_config_dict["head_fcnet_activation"] # We don't have the exact (framework specific) action dist class yet and thus # cannot determine the exact number of output nodes (action space) required. # -> Build pi config only in the `self.build_pi_head` method. @@ -79,10 +58,7 @@ def __init__( @override(Catalog) def build_encoder(self, framework: str) -> Encoder: - """Builds the encoder. - - Since PPO uses an ActorCriticEncoder, this method should not be implemented. - """ + """Builds the encoder.""" return self.encoder_config.build(framework=framework) @OverrideToImplementCustomLogic @@ -90,8 +66,7 @@ def build_pi_head(self, framework: str) -> Model: """Builds the policy head. The default behavior is to build the head from the pi_head_config. - This can be overridden to build a custom policy head as a means of configuring - the behavior of a PPORLModule implementation. + This can be overridden to build a custom policy head as a means of configuring the behavior of a MAPPORLModule implementation. Args: framework: The framework to use. Either "torch" or "tf2". 
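As the docstring above notes, `build_pi_head()` can be overridden in a catalog
subclass to supply a custom policy head. The following is a minimal sketch of such
an override; the subclass name `MyMAPPOCatalog` and the hidden-layer sizes are
illustrative assumptions and not part of this PR:

from ray.rllib.core.models.configs import MLPHeadConfig


class MyMAPPOCatalog(MAPPOCatalog):
    def build_pi_head(self, framework: str):
        # Size the output layer from the catalog's action-distribution class.
        action_dist_cls = self.get_action_dist_cls(framework=framework)
        required_output_dim = action_dist_cls.required_input_dim(
            space=self.action_space, model_config=self._model_config_dict
        )
        # Illustrative custom head: two 512-unit ReLU layers feeding the logits.
        pi_head_config = MLPHeadConfig(
            input_dims=self.latent_dims,
            hidden_layer_dims=[512, 512],
            hidden_layer_activation="relu",
            output_layer_dim=required_output_dim,
            output_layer_activation="linear",
        )
        return pi_head_config.build(framework=framework)

A subclass like this would then be wired in via the `catalog_class` argument of the
policy's `RLModuleSpec`.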
@@ -124,8 +99,8 @@ def build_pi_head(self, framework: str) -> Model: ) self.pi_head_config = pi_head_config_class( input_dims=self.latent_dims, - hidden_layer_dims=self.pi_and_vf_head_hiddens, - hidden_layer_activation=self.pi_and_vf_head_activation, + hidden_layer_dims=self.pi_head_hiddens, + hidden_layer_activation=self.pi_head_activation, output_layer_dim=required_output_dim, output_layer_activation="linear", clip_log_std=is_diag_gaussian, @@ -134,4 +109,5 @@ def build_pi_head(self, framework: str) -> Model: return self.pi_head_config.build(framework=framework) -# __sphinx_doc_end__ \ No newline at end of file + +# __sphinx_doc_end__ diff --git a/rllib/examples/algorithms/mappo/mappo_learner.py b/rllib/examples/algorithms/mappo/mappo_learner.py index 547d2ca5ffaf..c5f416946a9f 100644 --- a/rllib/examples/algorithms/mappo/mappo_learner.py +++ b/rllib/examples/algorithms/mappo/mappo_learner.py @@ -7,12 +7,7 @@ LEARNER_RESULTS_KL_KEY, PPOConfig, ) -from ray.rllib.connectors.learner import ( - AddOneTsToEpisodesAndTruncate, - GeneralAdvantageEstimation, -) from ray.rllib.core.learner.learner import Learner -from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.utils.annotations import ( override, OverrideToImplementCustomLogic_CallToSuperRecommended, @@ -26,14 +21,16 @@ from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ModuleID, TensorType -from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import MAPPOGAEConnector, SHARED_CRITIC_ID +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( + MAPPOGAEConnector, + SHARED_CRITIC_ID, +) -class MAPPOLearner(Learner): - # Deal with GAE somehow. Maybe skip it and move the logic over here, into the value loss calculation method. Saves us a VF pass, too. +class MAPPOLearner(Learner): @override(Learner) def build(self) -> None: - super().build() # We call Learner's build function, not PPOLearner's + super().build() # Dict mapping module IDs to the respective entropy Scheduler instance. self.entropy_coeff_schedulers_per_module: Dict[ @@ -79,9 +76,8 @@ def build(self) -> None: @override(Learner) def remove_module(self, module_id: ModuleID, **kwargs): marl_spec = super().remove_module(module_id, **kwargs) - if (module_id != SHARED_CRITIC_ID): - self.entropy_coeff_schedulers_per_module.pop(module_id, None) - self.curr_kl_coeffs_per_module.pop(module_id, None) + self.entropy_coeff_schedulers_per_module.pop(module_id, None) + self.curr_kl_coeffs_per_module.pop(module_id, None) return marl_spec @OverrideToImplementCustomLogic_CallToSuperRecommended @@ -94,8 +90,8 @@ def after_gradient_based_update( super().after_gradient_based_update(timesteps=timesteps) for module_id, module in self.module._rl_modules.items(): - if (module_id == SHARED_CRITIC_ID): - continue # Policy terms irrelevant to shared critic. + if module_id == SHARED_CRITIC_ID: + continue # Policy terms irrelevant to shared critic. config = self.config.get_config_for_module(module_id) # Update entropy coefficient via our Scheduler. 
new_entropy_coeff = self.entropy_coeff_schedulers_per_module[ @@ -121,12 +117,6 @@ def after_gradient_based_update( kl_loss=kl_loss, ) - @classmethod - @override(Learner) - def rl_module_required_apis(cls) -> list[type]: - # We no longer require value functions for modules, since there's a central critic - return [] - @abc.abstractmethod def _update_module_kl_coeff( self, @@ -146,4 +136,4 @@ def _update_module_kl_coeff( config: The AlgorithmConfig specific to the given `module_id`. kl_loss: The mean KL loss of the module, computed inside `compute_loss_for_module()`. - """ \ No newline at end of file + """ diff --git a/rllib/examples/algorithms/mappo/shared_critic_catalog.py b/rllib/examples/algorithms/mappo/shared_critic_catalog.py index 48a9b1c9a46e..6b79df601d16 100644 --- a/rllib/examples/algorithms/mappo/shared_critic_catalog.py +++ b/rllib/examples/algorithms/mappo/shared_critic_catalog.py @@ -2,21 +2,18 @@ from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import ( - ActorCriticEncoderConfig, MLPHeadConfig, - FreeLogStdMLPHeadConfig, ) -from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model +from ray.rllib.core.models.base import Encoder, Model from ray.rllib.utils import override from ray.rllib.utils.annotations import OverrideToImplementCustomLogic -from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian class SharedCriticCatalog(Catalog): def __init__( self, observation_space: gym.Space, - action_space: gym.Space, # TODO: Remove? + action_space: gym.Space, # TODO: Remove? model_config_dict: dict, ): """Initializes the PPOCatalog. @@ -28,20 +25,13 @@ def __init__( """ super().__init__( observation_space=observation_space, - action_space=action_space, # We shouldn't need to provide this. I should reconfigure at some point. + action_space=action_space, # Base Catalog class checks for this. model_config_dict=model_config_dict, ) # We only want one encoder, so we use the base encoder config. - self._encoder_config.shared = True - self.encoder_config = ActorCriticEncoderConfig( - base_encoder_config=self._encoder_config, - shared=True, # Because we only want one network - ) + self.encoder_config = self._encoder_config self.vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] - self.vf_head_activation = self._model_config_dict[ - "head_fcnet_activation" - ] - + self.vf_head_activation = self._model_config_dict["head_fcnet_activation"] self.vf_head_config = MLPHeadConfig( input_dims=self.latent_dims, hidden_layer_dims=self.vf_head_hiddens, @@ -52,10 +42,7 @@ def __init__( @override(Catalog) def build_encoder(self, framework: str) -> Encoder: - """Builds the encoder. - - Since PPO uses an ActorCriticEncoder, this method should not be implemented. - """ + """Builds the encoder.""" return self.encoder_config.build(framework=framework) @OverrideToImplementCustomLogic @@ -64,7 +51,7 @@ def build_vf_head(self, framework: str) -> Model: The default behavior is to build the head from the vf_head_config. This can be overridden to build a custom value function head as a means of - configuring the behavior of a PPORLModule implementation. + configuring the behavior of a MAPPORLModule implementation. Args: framework: The framework to use. Either "torch" or "tf2". 
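Similarly, `build_vf_head()` can be overridden on a catalog subclass to customize
the shared critic's value head. The sketch below is illustrative only:
`MyCriticCatalog` and the widened hidden layers are assumptions, and the output
size is copied from the default head config so it stays consistent with what the
critic module expects:

from ray.rllib.core.models.configs import MLPHeadConfig


class MyCriticCatalog(SharedCriticCatalog):
    def build_vf_head(self, framework: str):
        # Wider hidden layers, same output size as the default `vf_head_config`
        # built in `__init__`.
        custom_vf_head_config = MLPHeadConfig(
            input_dims=self.latent_dims,
            hidden_layer_dims=[1024, 1024],
            hidden_layer_activation="relu",
            output_layer_dim=self.vf_head_config.output_layer_dim,
            output_layer_activation="linear",
        )
        return custom_vf_head_config.build(framework=framework)

Such a catalog would be passed through the shared critic's `RLModuleSpec` via its
`catalog_class` argument.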
@@ -75,4 +62,4 @@ def build_vf_head(self, framework: str) -> Model: return self.vf_head_config.build(framework=framework) -# __sphinx_doc_end__ \ No newline at end of file +# __sphinx_doc_end__ diff --git a/rllib/examples/algorithms/mappo/shared_critic_rl_module.py b/rllib/examples/algorithms/mappo/shared_critic_rl_module.py index 237af7073ade..e9c81aabccb1 100644 --- a/rllib/examples/algorithms/mappo/shared_critic_rl_module.py +++ b/rllib/examples/algorithms/mappo/shared_critic_rl_module.py @@ -1,43 +1,46 @@ import abc -from typing import List from ray.rllib.core.models.configs import RecurrentEncoderConfig from ray.rllib.core.rl_module.apis import ValueFunctionAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import ( override, - OverrideToImplementCustomLogic_CallToSuperRecommended, ) from ray.util.annotations import DeveloperAPI @DeveloperAPI class SharedCriticRLModule(RLModule, ValueFunctionAPI, abc.ABC): - """Default RLModule used by PPO, if user does not specify a custom RLModule. + """Standard shared critic RLModule usable in MAPPO, if user does not intend to use a custom RLModule. - Users who want to train their RLModules with PPO may implement any RLModule - (or TorchRLModule) subclass as long as the custom class also implements the - `ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py) + Users who want to train custom shared critics with MAPPO may implement any RLModule (or TorchRLModule) subclass as long as the custom class also implements the `ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py) """ @override(RLModule) def setup(self): try: - # __sphinx_doc_begin__ - # If we have a stateful model, states for the critic need to be collected - # during sampling and `inference-only` needs to be `False`. Note, at this - # point the encoder is not built, yet and therefore `is_stateful()` does - # not work. - is_stateful = isinstance( - self.catalog.encoder_config.base_encoder_config, - RecurrentEncoderConfig, - ) - # Build models from catalog. - self.encoder = self.catalog.build_encoder(framework=self.framework) - self.vf = self.catalog.build_vf_head(framework=self.framework) + # __sphinx_doc_begin__ + # If we have a stateful model, states for the critic need to be collected + # during sampling and `inference-only` needs to be `False`. Note, at this + # point the encoder is not built, yet and therefore `is_stateful()` does + # not work. + is_stateful = isinstance( + self.catalog.encoder_config, + RecurrentEncoderConfig, + ) + if is_stateful: + self.inference_only = False + # If this is an `inference_only` Module, we'll have to pass this information + # to the encoder config as well. + if self.inference_only and self.framework == "torch": + self.catalog.encoder_config.inference_only = True + # Build models from catalog. 
+ self.encoder = self.catalog.build_encoder(framework=self.framework) + self.vf = self.catalog.build_vf_head(framework=self.framework) except Exception as e: - print(e) - raise e + print("Error in SharedCriticRLModule.setup:") + print(e) + raise e # __sphinx_doc_end__ @override(RLModule) @@ -45,4 +48,4 @@ def get_initial_state(self) -> dict: if hasattr(self.encoder, "get_initial_state"): return self.encoder.get_initial_state() else: - return {} \ No newline at end of file + return {} diff --git a/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py b/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py index 1317bcbf00cf..d78e979893e4 100644 --- a/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py +++ b/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py @@ -1,23 +1,21 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict -from ray.rllib.algorithms.ppo.default_ppo_rl_module import DefaultPPORLModule -from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.core.columns import Columns -from ray.rllib.core.models.base import ACTOR, CRITIC, ENCODER_OUT -from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.core.models.base import ENCODER_OUT from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import TensorType from ray.util.annotations import DeveloperAPI from ray.rllib.examples.algorithms.mappo.mappo_catalog import MAPPOCatalog -from ray.rllib.examples.algorithms.mappo.default_mappo_rl_module import DefaultMAPPORLModule -from ray.rllib.examples.algorithms.mappo.shared_critic_catalog import SharedCriticCatalog +from ray.rllib.examples.algorithms.mappo.default_mappo_rl_module import ( + DefaultMAPPORLModule, +) torch, nn = try_import_torch() + @DeveloperAPI class DefaultMAPPOTorchRLModule(TorchRLModule, DefaultMAPPORLModule): def __init__(self, *args, **kwargs): @@ -36,16 +34,16 @@ def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: if Columns.STATE_OUT in encoder_outs: output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] # Pi head. - output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT]) return output @override(RLModule) def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Train forward pass (keep embeddings for possible shared value func. 
call).""" + """Train forward pass.""" output = {} encoder_outs = self.encoder(batch) - output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC] + output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT] if Columns.STATE_OUT in encoder_outs: output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] - output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - return output \ No newline at end of file + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT]) + return output diff --git a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py index 1a9e4fc85e5b..127a6f6bd229 100644 --- a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py +++ b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py @@ -13,7 +13,12 @@ PPOConfig, ) from ray.rllib.core.columns import Columns -from ray.rllib.core.learner.learner import Learner, POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY +from ray.rllib.core.learner.learner import ( + Learner, + POLICY_LOSS_KEY, + VF_LOSS_KEY, + ENTROPY_KEY, +) from ray.rllib.core.learner.torch.torch_learner import TorchLearner from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.utils.annotations import override @@ -24,7 +29,9 @@ from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI from ray.rllib.examples.algorithms.mappo.mappo_learner import MAPPOLearner -from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import SHARED_CRITIC_ID +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( + SHARED_CRITIC_ID, +) torch, nn = try_import_torch() @@ -32,16 +39,15 @@ class MAPPOTorchLearner(MAPPOLearner, TorchLearner): - def get_pmm( - self, - batch: Dict[str, Any] - ) -> Callable: - """ Gets the possibly_masked_mean function """ + def get_pmm(self, batch: Dict[str, Any]) -> Callable: + """Gets the possibly_masked_mean function""" if Columns.LOSS_MASK in batch: - mask = batch[Columns.LOSS_MASK] - num_valid = torch.sum(mask) - def possibly_masked_mean(data_): - return torch.sum(data_[mask]) / num_valid + mask = batch[Columns.LOSS_MASK] + num_valid = torch.sum(mask) + + def possibly_masked_mean(data_): + return torch.sum(data_[mask]) / num_valid + else: possibly_masked_mean = torch.mean return possibly_masked_mean @@ -49,35 +55,33 @@ def possibly_masked_mean(data_): """ Implements MAPPO in Torch, on top of a MAPPOLearner. """ - def compute_loss_for_critic( - self, - batch: Dict[str, Any] - ): - """ + + def compute_loss_for_critic(self, batch: Dict[str, Any]): + """ Computes loss for critic, and returns a list of advantages and rewards for the target batch. - """ - possibly_masked_mean = self.get_pmm(batch) - module = self.module[SHARED_CRITIC_ID].unwrapped() - vf_preds = module.compute_values(batch) - vf_targets = batch[Postprocessing.VALUE_TARGETS] - # Compute a value function loss. 
- vf_loss = torch.pow(vf_preds - vf_targets, 2.0) - vf_loss_clipped = torch.clamp(vf_loss, 0, self.config.vf_clip_param) - mean_vf_loss = possibly_masked_mean(vf_loss_clipped) - mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) - # record metrics - self.metrics.log_dict( - { - VF_LOSS_KEY: mean_vf_loss, - LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY: mean_vf_unclipped_loss, - LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance( - vf_targets, vf_preds - ), - }, - key=SHARED_CRITIC_ID, - window=1, - ) - return mean_vf_loss + """ + possibly_masked_mean = self.get_pmm(batch) + module = self.module[SHARED_CRITIC_ID].unwrapped() + vf_preds = module.compute_values(batch) + vf_targets = batch[Postprocessing.VALUE_TARGETS] + # Compute a value function loss. + vf_loss = torch.pow(vf_preds - vf_targets, 2.0) + vf_loss_clipped = torch.clamp(vf_loss, 0, self.config.vf_clip_param) + mean_vf_loss = possibly_masked_mean(vf_loss_clipped) + mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) + # record metrics + self.metrics.log_dict( + { + VF_LOSS_KEY: mean_vf_loss, + LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY: mean_vf_unclipped_loss, + LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance( + vf_targets, vf_preds + ), + }, + key=SHARED_CRITIC_ID, + window=1, + ) + return mean_vf_loss @override(Learner) def compute_losses( @@ -96,12 +100,13 @@ def compute_losses( loss_per_module = {SHARED_CRITIC_ID: 0} # Calculate loss for agent policies for module_id in fwd_out: - if (module_id == SHARED_CRITIC_ID): # Computed for each module - continue - # + if module_id == SHARED_CRITIC_ID: # Computed for each module + continue + module_batch = batch[module_id] + module_fwd_out = fwd_out[module_id] module = self.module[module_id].unwrapped() if isinstance(module, SelfSupervisedLossAPI): - # For e.g. enabling intrinsic curiosity modules. + # For training e.g. intrinsic curiosity modules. loss = module.compute_self_supervised_loss( learner=self, module_id=module_id, @@ -110,8 +115,6 @@ def compute_losses( fwd_out=module_fwd_out, ) else: - module_batch = batch[module_id] - module_fwd_out = fwd_out[module_id] # For every module we're going to touch, sans the critic loss = self.compute_loss_for_module( module_id=module_id, @@ -120,7 +123,9 @@ def compute_losses( fwd_out=module_fwd_out, ) # Optimize the critic - loss_per_module[SHARED_CRITIC_ID] += self.compute_loss_for_critic(module_batch) + loss_per_module[SHARED_CRITIC_ID] += self.compute_loss_for_critic( + module_batch + ) loss_per_module[module_id] = loss return loss_per_module @@ -138,7 +143,7 @@ def compute_loss_for_module( possibly_masked_mean = self.get_pmm(batch) # Possibly apply masking to some sub loss terms and to the total loss term # at the end. Masking could be used for RNN-based model (zero padded `batch`) - # and for PPO's batched value function (and bootstrap value) computations, + # and for MAPPO's batched value function (and bootstrap value) computations, # for which we add an (artificial) timestep to each episode to # simplify the actual computation. action_dist_class_train = module.get_train_action_dist_cls() @@ -170,7 +175,7 @@ def compute_loss_for_module( batch[Postprocessing.ADVANTAGES] * torch.clamp(logp_ratio, 1 - config.clip_param, 1 + config.clip_param), ) - # Remove critic loss from per-module computation + # Removed critic loss from per-module computation total_loss = possibly_masked_mean( -surrogate_loss - ( @@ -197,7 +202,6 @@ def compute_loss_for_module( # Return the total loss. 
return total_loss - # Same as PPOTorchLearner @override(MAPPOLearner) def _update_module_kl_coeff( self, @@ -206,6 +210,7 @@ def _update_module_kl_coeff( config: PPOConfig, kl_loss: float, ) -> None: + # Same as PPOTorchLearner if np.isnan(kl_loss): logger.warning( f"KL divergence for Module {module_id} is non-finite, this " @@ -231,4 +236,4 @@ def _update_module_kl_coeff( (module_id, LEARNER_RESULTS_CURR_KL_COEFF_KEY), curr_var.item(), window=1, - ) \ No newline at end of file + ) diff --git a/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py index e35c00b6f342..f77f364ef6ec 100644 --- a/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py +++ b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py @@ -9,10 +9,15 @@ from ray.rllib.utils.typing import TensorType from ray.util.annotations import DeveloperAPI +from ray.rllib.examples.algorithms.mappo.shared_critic_rl_module import ( + SharedCriticRLModule, +) +from ray.rllib.examples.algorithms.mappo.shared_critic_catalog import ( + SharedCriticCatalog, +) + torch, nn = try_import_torch() -from ray.rllib.examples.algorithms.mappo.shared_critic_rl_module import SharedCriticRLModule -from ray.rllib.examples.algorithms.mappo.shared_critic_catalog import SharedCriticCatalog @DeveloperAPI class SharedCriticTorchRLModule(TorchRLModule, SharedCriticRLModule): @@ -29,8 +34,8 @@ def compute_values( embeddings: Optional[Any] = None, ) -> TensorType: if embeddings is None: - embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC] + embeddings = self.encoder(batch)[ENCODER_OUT] # Value head. vf_out = self.vf(embeddings) # Squeeze out last dimension (single node value head). - return vf_out.squeeze(-1) \ No newline at end of file + return vf_out.squeeze(-1) diff --git a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py index 2bebf4b05bdc..6eddbcaf41e7 100644 --- a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py +++ b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py @@ -26,7 +26,7 @@ Results to expect ----------------- -The above options can reach a combined reward of _ or more after about _ env timesteps. Keep in mind, though, that due to the separate learned policies in general, one agent's gain (in per-agent reward) might cause the other agent's reward to decrease at the same time. However, over time, both agents should simply improve. For reasons similar to those described in pettingzoo_parameter_sharing.py, learning may take slightly longer than in fully-independent settings, as agents are less inclined to specialize and thereby balance out one anothers' mistakes. +The above options can reach a combined reward of 0 or more after about 500k-1M env timesteps. Keep in mind, though, that due to the separate learned policies in general, one agent's gain (in per-agent reward) might cause the other agent's reward to decrease at the same time. However, over time, both agents should simply improve. For reasons similar to those described in pettingzoo_parameter_sharing.py, learning may take slightly longer than in fully-independent settings, as agents are less inclined to specialize and thereby balance out one anothers' mistakes. 
+---------------------+------------+--------------------+--------+------------------+ | Trial name | status | loc | iter | total time (s) | @@ -47,7 +47,6 @@ from pettingzoo.sisl import waterworld_v4 -from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv @@ -55,11 +54,15 @@ add_rllib_example_script_args, run_rllib_example_script_experiment, ) -from ray.tune.registry import get_trainable_cls, register_env +from ray.tune.registry import register_env from ray.rllib.examples.algorithms.mappo.mappo import MAPPOConfig -from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import SHARED_CRITIC_ID -from ray.rllib.examples.algorithms.mappo.torch.shared_critic_torch_rl_module import SharedCriticTorchRLModule +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( + SHARED_CRITIC_ID, +) +from ray.rllib.examples.algorithms.mappo.torch.shared_critic_torch_rl_module import ( + SharedCriticTorchRLModule, +) parser = add_rllib_example_script_args( @@ -77,20 +80,22 @@ # Here, we use the "Agent Environment Cycle" (AEC) PettingZoo environment type. # For a "Parallel" environment example, see the rock paper scissors examples # in this same repository folder. - get_env = lambda _: PettingZooEnv(waterworld_v4.env()) + def get_env(_): + return PettingZooEnv(waterworld_v4.env()) + register_env("env", get_env) # Policies are called just like the agents (exact 1:1 mapping). policies = [f"pursuer_{i}" for i in range(args.num_agents)] - + # An agent for each of our policies, and a single shared critic - env_instantiated = get_env({}) # neccessary for non-agent modules + env_instantiated = get_env({}) # neccessary for non-agent modules specs = {p: RLModuleSpec() for p in policies} specs[SHARED_CRITIC_ID] = RLModuleSpec( module_class=SharedCriticTorchRLModule, observation_space=env_instantiated.observation_spaces[policies[0]], action_space=env_instantiated.action_spaces[policies[0]], - learner_only=True, # Only build on learner + learner_only=True, # Only build on learner model_config={}, ) @@ -102,9 +107,6 @@ # Exact 1:1 mapping from AgentID to ModuleID. policy_mapping_fn=(lambda aid, *args, **kwargs: aid), ) - .training( - vf_loss_coeff=0.005, - ) .rl_module( rl_module_spec=MultiRLModuleSpec( rl_module_specs=specs, From c002277ac5b8f36133e53ffc3ff280fc099ba536 Mon Sep 17 00:00:00 2001 From: Matthew Date: Sat, 6 Sep 2025 20:44:42 -0500 Subject: [PATCH 03/13] Minor formatting/debugging fixes to PR submission. 
Signed-off-by: Matthew --- rllib/BUILD | 16 ++++++-- .../mappo/default_mappo_rl_module.py | 41 ++++++++----------- .../mappo/shared_critic_rl_module.py | 41 ++++++++----------- .../mappo/torch/mappo_torch_learner.py | 4 +- 4 files changed, 50 insertions(+), 52 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index b6699b84b5e2..3baf0ca361a2 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -5006,11 +5006,21 @@ py_test( py_test( name = "examples/multi_agent/pettingzoo_shared_value_function", - main = "examples/multi_agent/pettingzoo_shared_value_function.py", - tags = ["team:rllib", "exclusive", "examples"], size = "large", srcs = ["examples/multi_agent/pettingzoo_shared_value_function.py"], - args = ["--num-agents=2", "--as-test", "--framework=torch", "--stop-reward=-100.0", "--num-cpus=4"], + args = [ + "--num-agents=2", + "--as-test", + "--framework=torch", + "--stop-reward=-100.0", + "--num-cpus=4", + ], + main = "examples/multi_agent/pettingzoo_shared_value_function.py", + tags = [ + "team:rllib", + "exclusive", + "examples", + ], ) py_test( diff --git a/rllib/examples/algorithms/mappo/default_mappo_rl_module.py b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py index 8288a48a4674..2e260326d8b9 100644 --- a/rllib/examples/algorithms/mappo/default_mappo_rl_module.py +++ b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py @@ -20,29 +20,24 @@ class DefaultMAPPORLModule(RLModule, InferenceOnlyAPI, abc.ABC): @override(RLModule) def setup(self): - try: - # __sphinx_doc_begin__ - # If we have a stateful model, states for the critic need to be collected - # during sampling and `inference-only` needs to be `False`. Note, at this - # point the encoder is not built, yet and therefore `is_stateful()` does - # not work. - is_stateful = isinstance( - self.catalog.encoder_config, - RecurrentEncoderConfig, - ) - if is_stateful: - self.inference_only = False - # If this is an `inference_only` Module, we'll have to pass this information - # to the encoder config as well. - if self.inference_only and self.framework == "torch": - self.catalog.encoder_config.inference_only = True - # Build models from catalog. - self.encoder = self.catalog.build_encoder(framework=self.framework) - self.pi = self.catalog.build_pi_head(framework=self.framework) - except Exception as e: - print("Error in DefaultMAPPORLModule setup") - print(e) - raise e + # __sphinx_doc_begin__ + # If we have a stateful model, states for the critic need to be collected + # during sampling and `inference-only` needs to be `False`. Note, at this + # point the encoder is not built, yet and therefore `is_stateful()` does + # not work. + is_stateful = isinstance( + self.catalog.encoder_config, + RecurrentEncoderConfig, + ) + if is_stateful: + self.inference_only = False + # If this is an `inference_only` Module, we'll have to pass this information + # to the encoder config as well. + if self.inference_only and self.framework == "torch": + self.catalog.encoder_config.inference_only = True + # Build models from catalog. 
+ self.encoder = self.catalog.build_encoder(framework=self.framework) + self.pi = self.catalog.build_pi_head(framework=self.framework) # __sphinx_doc_end__ @override(RLModule) diff --git a/rllib/examples/algorithms/mappo/shared_critic_rl_module.py b/rllib/examples/algorithms/mappo/shared_critic_rl_module.py index e9c81aabccb1..01a85bfd9e24 100644 --- a/rllib/examples/algorithms/mappo/shared_critic_rl_module.py +++ b/rllib/examples/algorithms/mappo/shared_critic_rl_module.py @@ -18,29 +18,24 @@ class SharedCriticRLModule(RLModule, ValueFunctionAPI, abc.ABC): @override(RLModule) def setup(self): - try: - # __sphinx_doc_begin__ - # If we have a stateful model, states for the critic need to be collected - # during sampling and `inference-only` needs to be `False`. Note, at this - # point the encoder is not built, yet and therefore `is_stateful()` does - # not work. - is_stateful = isinstance( - self.catalog.encoder_config, - RecurrentEncoderConfig, - ) - if is_stateful: - self.inference_only = False - # If this is an `inference_only` Module, we'll have to pass this information - # to the encoder config as well. - if self.inference_only and self.framework == "torch": - self.catalog.encoder_config.inference_only = True - # Build models from catalog. - self.encoder = self.catalog.build_encoder(framework=self.framework) - self.vf = self.catalog.build_vf_head(framework=self.framework) - except Exception as e: - print("Error in SharedCriticRLModule.setup:") - print(e) - raise e + # __sphinx_doc_begin__ + # If we have a stateful model, states for the critic need to be collected + # during sampling and `inference-only` needs to be `False`. Note, at this + # point the encoder is not built, yet and therefore `is_stateful()` does + # not work. + is_stateful = isinstance( + self.catalog.encoder_config, + RecurrentEncoderConfig, + ) + if is_stateful: + self.inference_only = False + # If this is an `inference_only` Module, we'll have to pass this information + # to the encoder config as well. + if self.inference_only and self.framework == "torch": + self.catalog.encoder_config.inference_only = True + # Build models from catalog. + self.encoder = self.catalog.build_encoder(framework=self.framework) + self.vf = self.catalog.build_vf_head(framework=self.framework) # __sphinx_doc_end__ @override(RLModule) diff --git a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py index 127a6f6bd229..5e7425d36e9e 100644 --- a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py +++ b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py @@ -57,9 +57,7 @@ def possibly_masked_mean(data_): """ def compute_loss_for_critic(self, batch: Dict[str, Any]): - """ - Computes loss for critic, and returns a list of advantages and rewards for the target batch. - """ + """Computes the loss for the shared critic module.""" possibly_masked_mean = self.get_pmm(batch) module = self.module[SHARED_CRITIC_ID].unwrapped() vf_preds = module.compute_values(batch) From 55e8989e4c71d88efcb5cd4a2847992710207067 Mon Sep 17 00:00:00 2001 From: Matthew Date: Sun, 7 Sep 2025 16:47:49 -0500 Subject: [PATCH 04/13] Fixed an inheritance issue and added global observation handling. 
Signed-off-by: Matthew --- rllib/BUILD | 4 +- .../general_advantage_estimation.py | 43 +++++++++++-------- rllib/examples/algorithms/mappo/mappo.py | 30 ++++++------- .../algorithms/mappo/mappo_learner.py | 1 - .../algorithms/mappo/shared_critic_catalog.py | 10 ++++- .../mappo/torch/mappo_torch_learner.py | 25 ++++++----- .../torch/shared_critic_torch_rl_module.py | 5 +-- .../pettingzoo_shared_value_function.py | 39 ++++++++--------- 8 files changed, 82 insertions(+), 75 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 3baf0ca361a2..6b6a669f8c50 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -5017,9 +5017,9 @@ py_test( ], main = "examples/multi_agent/pettingzoo_shared_value_function.py", tags = [ - "team:rllib", - "exclusive", "examples", + "exclusive", + "team:rllib", ], ) diff --git a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py index 771c342f5ec0..e429958708ca 100644 --- a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py +++ b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py @@ -9,6 +9,7 @@ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets from ray.rllib.utils.postprocessing.zero_padding import ( @@ -17,6 +18,7 @@ ) from ray.rllib.utils.typing import EpisodeType +torch, nn = try_import_torch() SHARED_CRITIC_ID = "shared_critic" @@ -48,27 +50,25 @@ def __call__( ): # Device to place all GAE result tensors (advantages and value targets) on. device = None - # Extract all single-agent episodes. sa_episodes_list = list( self.single_agent_episode_iterator(episodes, agents_that_stepped_only=False) ) - # Perform the value nets' forward passes. - # TODO (sven): We need to check here in the pipeline already, whether a module - # should even be updated or not (which we usually do after(!) the Learner - # pipeline). This is an open TODO to move this filter into a connector as well. - # For now, we'll just check, whether `mid` is in batch and skip if it isn't. - # For MAPPO, we can't check ValueFunctionAPI, so we check for the presence of an observation and a lack of self-supervision instead. - vf_preds = rl_module.foreach_module( - func=lambda mid, module: ( - rl_module[SHARED_CRITIC_ID].compute_values(batch[mid]) - if (mid in batch) - and (Columns.OBS in batch[mid]) - and (not isinstance(module, SelfSupervisedLossAPI)) - else None - ), - return_dict=True, + # Perform the value net's forward pass. + critic_batch = {} + # Concatenate all agent observations in batch, using a fixed order + obs_mids = [ + k + for k in sorted(batch.keys()) + if (Columns.OBS in batch[k]) + and (not isinstance(rl_module[k], SelfSupervisedLossAPI)) + ] + critic_batch[Columns.OBS] = torch.cat( + [batch[k][Columns.OBS] for k in obs_mids], dim=1 ) + # Compute value predictions + vf_preds = rl_module[SHARED_CRITIC_ID].compute_values(critic_batch) + vf_preds = {mid: vf_preds[:, i] for i, mid in enumerate(obs_mids)} # Loop through all modules and perform each one's GAE computation. 
for module_id, module_vf_preds in vf_preds.items(): # Skip those outputs of RLModules that are not implementers of @@ -139,7 +139,14 @@ def __call__( ) batch[module_id][Postprocessing.ADVANTAGES] = module_advantages batch[module_id][Postprocessing.VALUE_TARGETS] = module_value_targets - + # Add GAE results to the critic batch + critic_batch[Postprocessing.VALUE_TARGETS] = np.stack( + [batch[mid][Postprocessing.VALUE_TARGETS] for mid in obs_mids], axis=1 + ) + critic_batch[Postprocessing.ADVANTAGES] = np.stack( + [batch[mid][Postprocessing.ADVANTAGES] for mid in obs_mids], axis=1 + ) + batch[SHARED_CRITIC_ID] = critic_batch # Critic data -> training batch # Convert all GAE results to tensors. if self._numpy_to_tensor_connector is None: self._numpy_to_tensor_connector = NumpyToTensor( @@ -155,7 +162,7 @@ def __call__( ), } for mid, module_batch in batch.items() - if vf_preds[mid] is not None + if (mid == SHARED_CRITIC_ID) or (vf_preds[mid] is not None) }, episodes=episodes, ) diff --git a/rllib/examples/algorithms/mappo/mappo.py b/rllib/examples/algorithms/mappo/mappo.py index ec5ea1d2f11a..767dae645e3e 100644 --- a/rllib/examples/algorithms/mappo/mappo.py +++ b/rllib/examples/algorithms/mappo/mappo.py @@ -1,6 +1,7 @@ import logging from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING +from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.annotations import override @@ -24,6 +25,13 @@ LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff" +class MAPPO(PPO): + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return MAPPOConfig() + + class MAPPOConfig(AlgorithmConfig): # AlgorithmConfig -> PPOConfig -> MAPPO """Defines a configuration class from which a MAPPO Algorithm can be built.""" @@ -39,7 +47,7 @@ def __init__(self, algo_class=None): # Add constructor kwargs here (if any). } - super().__init__(algo_class=algo_class or PPO) + super().__init__(algo_class=algo_class or MAPPO) # fmt: off # __sphinx_doc_begin__ @@ -57,6 +65,7 @@ def __init__(self, algo_class=None): self.kl_target = 0.01 self.entropy_coeff = 0.0 self.clip_param = 0.3 + self.vf_clip_param = 10.0 self.grad_clip = None # Override some of AlgorithmConfig's default values with MAPPO-specific values. @@ -87,6 +96,7 @@ def training( entropy_coeff: Optional[float] = NotProvided, entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, clip_param: Optional[float] = NotProvided, + vf_clip_param: Optional[float] = NotProvided, grad_clip: Optional[float] = NotProvided, **kwargs, ) -> "MAPPOConfig": @@ -104,6 +114,8 @@ def training( self.entropy_coeff = entropy_coeff if clip_param is not NotProvided: self.clip_param = clip_param + if vf_clip_param is not NotProvided: + self.vf_clip_param = vf_clip_param if grad_clip is not NotProvided: self.grad_clip = grad_clip return self @@ -145,22 +157,6 @@ def validate(self) -> None: f" will be split into {mbs} chunks, each of which is iterated over " f"(used for updating the policy) {self.num_epochs} times." ) - - # Episodes may only be truncated (and passed into PPO's - # `postprocessing_fn`), iff generalized advantage estimation is used - # (value function estimate at end of truncated episode to estimate - # remaining value). 
- if ( - not self.in_evaluation - and self.batch_mode == "truncate_episodes" - and not self.use_gae - ): - self._value_error( - "Episode truncation is not supported without a value " - "function (to estimate the return at the end of the truncated" - " trajectory). Consider setting " - "batch_mode=complete_episodes." - ) if isinstance(self.entropy_coeff, float) and self.entropy_coeff < 0.0: self._value_error("`entropy_coeff` must be >= 0.0") diff --git a/rllib/examples/algorithms/mappo/mappo_learner.py b/rllib/examples/algorithms/mappo/mappo_learner.py index c5f416946a9f..ccde860acfa1 100644 --- a/rllib/examples/algorithms/mappo/mappo_learner.py +++ b/rllib/examples/algorithms/mappo/mappo_learner.py @@ -1,4 +1,3 @@ -# @title MAPPOLearner import abc from typing import Any, Dict diff --git a/rllib/examples/algorithms/mappo/shared_critic_catalog.py b/rllib/examples/algorithms/mappo/shared_critic_catalog.py index 6b79df601d16..7be5460f7ceb 100644 --- a/rllib/examples/algorithms/mappo/shared_critic_catalog.py +++ b/rllib/examples/algorithms/mappo/shared_critic_catalog.py @@ -30,6 +30,14 @@ def __init__( ) # We only want one encoder, so we use the base encoder config. self.encoder_config = self._encoder_config + # Adjust the input and output dimensions of the encoder. + observation_spaces = self._model_config_dict["observation_spaces"] + obs_size = 0 + self.encoder_config.output_dim = len(observation_spaces) + for agent, obs in observation_spaces.items(): + obs_size += obs.shape[0] # Assume 1D observations + self.encoder_config.input_dims = (obs_size,) + # Value head architecture self.vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] self.vf_head_activation = self._model_config_dict["head_fcnet_activation"] self.vf_head_config = MLPHeadConfig( @@ -37,7 +45,7 @@ def __init__( hidden_layer_dims=self.vf_head_hiddens, hidden_layer_activation=self.vf_head_activation, output_layer_activation="linear", - output_layer_dim=1, + output_layer_dim=len(observation_spaces), # 1 value pred. per agent ) @override(Catalog) diff --git a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py index 5e7425d36e9e..e525619d61a3 100644 --- a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py +++ b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py @@ -1,4 +1,3 @@ -# @title MAPPOTorchLearner import logging from typing import Any, Dict from collections.abc import Callable @@ -39,6 +38,10 @@ class MAPPOTorchLearner(MAPPOLearner, TorchLearner): + """ + Implements MAPPO in Torch, on top of a MAPPOLearner. + """ + def get_pmm(self, batch: Dict[str, Any]) -> Callable: """Gets the possibly_masked_mean function""" if Columns.LOSS_MASK in batch: @@ -52,10 +55,6 @@ def possibly_masked_mean(data_): possibly_masked_mean = torch.mean return possibly_masked_mean - """ - Implements MAPPO in Torch, on top of a MAPPOLearner. 
- """ - def compute_loss_for_critic(self, batch: Dict[str, Any]): """Computes the loss for the shared critic module.""" possibly_masked_mean = self.get_pmm(batch) @@ -73,8 +72,8 @@ def compute_loss_for_critic(self, batch: Dict[str, Any]): VF_LOSS_KEY: mean_vf_loss, LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY: mean_vf_unclipped_loss, LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance( - vf_targets, vf_preds - ), + vf_targets.reshape(-1), vf_preds.reshape(-1) + ), # Flatten multi-agent value predictions }, key=SHARED_CRITIC_ID, window=1, @@ -95,10 +94,14 @@ def compute_losses( Returns: A dictionary mapping module IDs to individual loss terms. """ - loss_per_module = {SHARED_CRITIC_ID: 0} + loss_per_module = {} + # Optimize the critic + loss_per_module[SHARED_CRITIC_ID] = self.compute_loss_for_critic( + batch[SHARED_CRITIC_ID] + ) # Calculate loss for agent policies for module_id in fwd_out: - if module_id == SHARED_CRITIC_ID: # Computed for each module + if module_id == SHARED_CRITIC_ID: # Computed above continue module_batch = batch[module_id] module_fwd_out = fwd_out[module_id] @@ -120,10 +123,6 @@ def compute_losses( batch=module_batch, fwd_out=module_fwd_out, ) - # Optimize the critic - loss_per_module[SHARED_CRITIC_ID] += self.compute_loss_for_critic( - module_batch - ) loss_per_module[module_id] = loss return loss_per_module diff --git a/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py index f77f364ef6ec..bc7231596fc4 100644 --- a/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py +++ b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py @@ -35,7 +35,6 @@ def compute_values( ) -> TensorType: if embeddings is None: embeddings = self.encoder(batch)[ENCODER_OUT] - # Value head. vf_out = self.vf(embeddings) - # Squeeze out last dimension (single node value head). - return vf_out.squeeze(-1) + # Don't squeeze out last dimension (multi node value head). + return vf_out diff --git a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py index 6eddbcaf41e7..d654362903e7 100644 --- a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py +++ b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py @@ -26,19 +26,19 @@ Results to expect ----------------- -The above options can reach a combined reward of 0 or more after about 500k-1M env timesteps. Keep in mind, though, that due to the separate learned policies in general, one agent's gain (in per-agent reward) might cause the other agent's reward to decrease at the same time. However, over time, both agents should simply improve. For reasons similar to those described in pettingzoo_parameter_sharing.py, learning may take slightly longer than in fully-independent settings, as agents are less inclined to specialize and thereby balance out one anothers' mistakes. +The above options will typically reach a combined reward of 0 or more before 500k env timesteps. Keep in mind, though, that due to the separate learned policies in general, one agent's gain (in per-agent reward) might cause the other agent's reward to decrease at the same time. However, over time, both agents should simply improve, with the shared critic stabilizing this process significantly. 
-+---------------------+------------+--------------------+--------+------------------+ -| Trial name | status | loc | iter | total time (s) | -|---------------------+------------+--------------------+--------+------------------+ -| PPO_env_c90f4_00000 | TERMINATED | 172.29.87.208:6322 | 101 | 1269.24 | -+---------------------+------------+--------------------+--------+------------------+ ++-----------------------+------------+--------------------+--------+------------------+--------+-------------------+--------------------+--------------------+ +| Trial name | status | loc | iter | total time (s) | ts | combined return | return pursuer_0 | return pursuer_1 | +|-----------------------+------------+--------------------+--------+------------------+--------+-------------------+--------------------+--------------------| +| MAPPO_env_39b0c_00000 | TERMINATED | 172.29.87.208:9993 | 148 | 2690.21 | 592000 | 2.06999 | 38.2254 | -36.1554 | ++-----------------------+------------+--------------------+--------+------------------+--------+-------------------+--------------------+--------------------+ ---------+-------------------+--------------------+--------------------+ - ts | combined return | return pursuer_0 | return pursuer_1 | ---------+-------------------+--------------------+--------------------| - 404000 | 1.31496 | 48.0908 | -46.7758 | ---------+-------------------+--------------------+--------------------+ ++--------+-------------------+--------------------+--------------------+ +| ts | combined return | return pursuer_0 | return pursuer_1 | ++--------+-------------------+--------------------+--------------------| +| 592000 | 2.06999 | 38.2254 | -36.1554 | ++--------+-------------------+--------------------+--------------------+ Note that the two agents (`pursuer_0` and `pursuer_1`) are optimized on the exact same objective and thus differences in the rewards can be attributed to weight initialization @@ -49,7 +49,7 @@ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec -from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv +from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv from ray.rllib.utils.test_utils import ( add_rllib_example_script_args, run_rllib_example_script_experiment, @@ -77,11 +77,10 @@ assert args.num_agents > 0, "Must set --num-agents > 0 when running this script!" - # Here, we use the "Agent Environment Cycle" (AEC) PettingZoo environment type. - # For a "Parallel" environment example, see the rock paper scissors examples - # in this same repository folder. + # Here, we use the "Parallel" PettingZoo environment type. + # This allows MAPPO's global observations to be constructed more neatly. 
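+    # (At training time, the MAPPOGAEConnector assembles that global input by
+    # concatenating the per-agent observations along the feature axis, in sorted
+    # module-ID order, before calling the shared critic's `compute_values()`.)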
def get_env(_): - return PettingZooEnv(waterworld_v4.env()) + return ParallelPettingZooEnv(waterworld_v4.parallel_env()) register_env("env", get_env) @@ -93,17 +92,17 @@ def get_env(_): specs = {p: RLModuleSpec() for p in policies} specs[SHARED_CRITIC_ID] = RLModuleSpec( module_class=SharedCriticTorchRLModule, - observation_space=env_instantiated.observation_spaces[policies[0]], - action_space=env_instantiated.action_spaces[policies[0]], + observation_space=env_instantiated.observation_space[policies[0]], + action_space=env_instantiated.action_space[policies[0]], learner_only=True, # Only build on learner - model_config={}, + model_config={"observation_spaces": env_instantiated.observation_space}, ) base_config = ( MAPPOConfig() .environment("env") .multi_agent( - policies=policies, + policies=policies + [SHARED_CRITIC_ID], # Exact 1:1 mapping from AgentID to ModuleID. policy_mapping_fn=(lambda aid, *args, **kwargs: aid), ) From 5756bf9f157814b0de74cad54be6c28f99aa1091 Mon Sep 17 00:00:00 2001 From: Matthew Date: Sun, 7 Sep 2025 19:31:14 -0500 Subject: [PATCH 05/13] Tabs-->spaces in BUILD to satisfy tests. Signed-off-by: Matthew --- rllib/BUILD | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 6b6a669f8c50..c1f54697fbd8 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -5005,22 +5005,22 @@ py_test( ) py_test( - name = "examples/multi_agent/pettingzoo_shared_value_function", - size = "large", - srcs = ["examples/multi_agent/pettingzoo_shared_value_function.py"], - args = [ - "--num-agents=2", - "--as-test", - "--framework=torch", - "--stop-reward=-100.0", - "--num-cpus=4", - ], - main = "examples/multi_agent/pettingzoo_shared_value_function.py", - tags = [ - "examples", - "exclusive", - "team:rllib", - ], + name = "examples/multi_agent/pettingzoo_shared_value_function", + size = "large", + srcs = ["examples/multi_agent/pettingzoo_shared_value_function.py"], + args = [ + "--num-agents=2", + "--as-test", + "--framework=torch", + "--stop-reward=-100.0", + "--num-cpus=4", + ], + main = "examples/multi_agent/pettingzoo_shared_value_function.py", + tags = [ + "examples", + "exclusive", + "team:rllib", + ], ) py_test( From 2ec32d4d9b3f856720785734faa36cdb88197415 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 29 Sep 2025 00:46:45 -0500 Subject: [PATCH 06/13] Cleaned up two lines of code, one of which had a bug in an edge case. 
Signed-off-by: Matthew --- .../algorithms/mappo/connectors/general_advantage_estimation.py | 2 +- rllib/examples/algorithms/mappo/shared_critic_catalog.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py index e429958708ca..0c095f04a838 100644 --- a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py +++ b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py @@ -162,7 +162,7 @@ def __call__( ), } for mid, module_batch in batch.items() - if (mid == SHARED_CRITIC_ID) or (vf_preds[mid] is not None) + if (mid == SHARED_CRITIC_ID) or (mid in vf_preds) }, episodes=episodes, ) diff --git a/rllib/examples/algorithms/mappo/shared_critic_catalog.py b/rllib/examples/algorithms/mappo/shared_critic_catalog.py index 7be5460f7ceb..98a2dcc59cd1 100644 --- a/rllib/examples/algorithms/mappo/shared_critic_catalog.py +++ b/rllib/examples/algorithms/mappo/shared_critic_catalog.py @@ -33,7 +33,6 @@ def __init__( # Adjust the input and output dimensions of the encoder. observation_spaces = self._model_config_dict["observation_spaces"] obs_size = 0 - self.encoder_config.output_dim = len(observation_spaces) for agent, obs in observation_spaces.items(): obs_size += obs.shape[0] # Assume 1D observations self.encoder_config.input_dims = (obs_size,) From bfa6c080bdc1f7acc3b9e1c4b576d690caf4cfe0 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 10 Nov 2025 21:17:28 -0600 Subject: [PATCH 07/13] Re-linted and cleaned up a function signature. Signed-off-by: Matthew --- .../general_advantage_estimation.py | 4 ++-- .../mappo/default_mappo_rl_module.py | 2 +- rllib/examples/algorithms/mappo/mappo.py | 12 +++++------ .../algorithms/mappo/mappo_catalog.py | 8 +++----- .../algorithms/mappo/mappo_learner.py | 11 +++++----- .../algorithms/mappo/shared_critic_catalog.py | 2 +- .../torch/default_mappo_torch_rl_module.py | 9 ++++----- .../mappo/torch/mappo_torch_learner.py | 20 +++++++++---------- .../torch/shared_critic_torch_rl_module.py | 13 ++++++------ .../pettingzoo_shared_value_function.py | 14 ++++++------- 10 files changed, 42 insertions(+), 53 deletions(-) diff --git a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py index 0c095f04a838..af99180e8d4e 100644 --- a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py +++ b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py @@ -1,9 +1,9 @@ -from typing import Any, List, Dict +from typing import Any, Dict, List import numpy as np -from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor +from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.core.columns import Columns from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule diff --git a/rllib/examples/algorithms/mappo/default_mappo_rl_module.py b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py index 2e260326d8b9..7b86ef73391c 100644 --- a/rllib/examples/algorithms/mappo/default_mappo_rl_module.py +++ b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py @@ -5,8 +5,8 @@ from ray.rllib.core.rl_module.apis import InferenceOnlyAPI from ray.rllib.core.rl_module.rl_module import RLModule from 
ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.util.annotations import DeveloperAPI diff --git a/rllib/examples/algorithms/mappo/mappo.py b/rllib/examples/algorithms/mappo/mappo.py index 767dae645e3e..920b972b4a5b 100644 --- a/rllib/examples/algorithms/mappo/mappo.py +++ b/rllib/examples/algorithms/mappo/mappo.py @@ -1,22 +1,21 @@ import logging -from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.algorithms.ppo.ppo import PPO from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.annotations import override -from ray.rllib.algorithms.ppo.ppo import PPO - if TYPE_CHECKING: from ray.rllib.core.learner.learner import Learner -from ray.rllib.examples.algorithms.mappo.torch.mappo_torch_learner import ( - MAPPOTorchLearner, -) from ray.rllib.examples.algorithms.mappo.torch.default_mappo_torch_rl_module import ( DefaultMAPPOTorchRLModule, ) +from ray.rllib.examples.algorithms.mappo.torch.mappo_torch_learner import ( + MAPPOTorchLearner, +) logger = logging.getLogger(__name__) @@ -94,7 +93,6 @@ def training( kl_coeff: Optional[float] = NotProvided, kl_target: Optional[float] = NotProvided, entropy_coeff: Optional[float] = NotProvided, - entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, clip_param: Optional[float] = NotProvided, vf_clip_param: Optional[float] = NotProvided, grad_clip: Optional[float] = NotProvided, diff --git a/rllib/examples/algorithms/mappo/mappo_catalog.py b/rllib/examples/algorithms/mappo/mappo_catalog.py index d078c2cb9df7..08d2e8090fef 100644 --- a/rllib/examples/algorithms/mappo/mappo_catalog.py +++ b/rllib/examples/algorithms/mappo/mappo_catalog.py @@ -2,19 +2,17 @@ # __sphinx_doc_begin__ import gymnasium as gym +from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian +from ray.rllib.core.models.base import Encoder, Model from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import ( - MLPHeadConfig, FreeLogStdMLPHeadConfig, + MLPHeadConfig, ) -from ray.rllib.core.models.base import Encoder, Model from ray.rllib.utils import override from ray.rllib.utils.annotations import OverrideToImplementCustomLogic -from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian - - class MAPPOCatalog(Catalog): """The Catalog class used to build models for MAPPO. 
diff --git a/rllib/examples/algorithms/mappo/mappo_learner.py b/rllib/examples/algorithms/mappo/mappo_learner.py index ccde860acfa1..7081e99c3f3b 100644 --- a/rllib/examples/algorithms/mappo/mappo_learner.py +++ b/rllib/examples/algorithms/mappo/mappo_learner.py @@ -7,9 +7,13 @@ PPOConfig, ) from ray.rllib.core.learner.learner import Learner +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( + SHARED_CRITIC_ID, + MAPPOGAEConnector, +) from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict from ray.rllib.utils.metrics import ( @@ -20,11 +24,6 @@ from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ModuleID, TensorType -from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( - MAPPOGAEConnector, - SHARED_CRITIC_ID, -) - class MAPPOLearner(Learner): @override(Learner) diff --git a/rllib/examples/algorithms/mappo/shared_critic_catalog.py b/rllib/examples/algorithms/mappo/shared_critic_catalog.py index 98a2dcc59cd1..e346b41e026f 100644 --- a/rllib/examples/algorithms/mappo/shared_critic_catalog.py +++ b/rllib/examples/algorithms/mappo/shared_critic_catalog.py @@ -1,10 +1,10 @@ import gymnasium as gym +from ray.rllib.core.models.base import Encoder, Model from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import ( MLPHeadConfig, ) -from ray.rllib.core.models.base import Encoder, Model from ray.rllib.utils import override from ray.rllib.utils.annotations import OverrideToImplementCustomLogic diff --git a/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py b/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py index d78e979893e4..3207494b2084 100644 --- a/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py +++ b/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py @@ -4,14 +4,13 @@ from ray.rllib.core.models.base import ENCODER_OUT from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch import TorchRLModule -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_torch -from ray.util.annotations import DeveloperAPI - -from ray.rllib.examples.algorithms.mappo.mappo_catalog import MAPPOCatalog from ray.rllib.examples.algorithms.mappo.default_mappo_rl_module import ( DefaultMAPPORLModule, ) +from ray.rllib.examples.algorithms.mappo.mappo_catalog import MAPPOCatalog +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.util.annotations import DeveloperAPI torch, nn = try_import_torch() diff --git a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py index e525619d61a3..6fffeac8e09f 100644 --- a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py +++ b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py @@ -1,37 +1,35 @@ import logging -from typing import Any, Dict from collections.abc import Callable +from typing import Any, Dict import numpy as np from ray.rllib.algorithms.ppo.ppo import ( - LEARNER_RESULTS_KL_KEY, LEARNER_RESULTS_CURR_KL_COEFF_KEY, + LEARNER_RESULTS_KL_KEY, LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY, LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY, PPOConfig, ) from ray.rllib.core.columns import Columns from 
ray.rllib.core.learner.learner import ( - Learner, + ENTROPY_KEY, POLICY_LOSS_KEY, VF_LOSS_KEY, - ENTROPY_KEY, + Learner, ) from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( + SHARED_CRITIC_ID, +) +from ray.rllib.examples.algorithms.mappo.mappo_learner import MAPPOLearner from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_utils import explained_variance from ray.rllib.utils.typing import ModuleID, TensorType -from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI - -from ray.rllib.examples.algorithms.mappo.mappo_learner import MAPPOLearner -from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( - SHARED_CRITIC_ID, -) - torch, nn = try_import_torch() logger = logging.getLogger(__name__) diff --git a/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py index bc7231596fc4..7489d29689f1 100644 --- a/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py +++ b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py @@ -4,18 +4,17 @@ from ray.rllib.core.models.base import ENCODER_OUT from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.examples.algorithms.mappo.shared_critic_catalog import ( + SharedCriticCatalog, +) +from ray.rllib.examples.algorithms.mappo.shared_critic_rl_module import ( + SharedCriticRLModule, +) from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import TensorType from ray.util.annotations import DeveloperAPI -from ray.rllib.examples.algorithms.mappo.shared_critic_rl_module import ( - SharedCriticRLModule, -) -from ray.rllib.examples.algorithms.mappo.shared_critic_catalog import ( - SharedCriticCatalog, -) - torch, nn = try_import_torch() diff --git a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py index d654362903e7..d2a425b24297 100644 --- a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py +++ b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py @@ -50,20 +50,18 @@ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv -from ray.rllib.utils.test_utils import ( - add_rllib_example_script_args, - run_rllib_example_script_experiment, -) -from ray.tune.registry import register_env - -from ray.rllib.examples.algorithms.mappo.mappo import MAPPOConfig from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( SHARED_CRITIC_ID, ) +from ray.rllib.examples.algorithms.mappo.mappo import MAPPOConfig from ray.rllib.examples.algorithms.mappo.torch.shared_critic_torch_rl_module import ( SharedCriticTorchRLModule, ) - +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import register_env parser = add_rllib_example_script_args( default_iters=200, From 
060e2c3806caf8c18fb494c7490a7ea5acc66f99 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 10 Nov 2025 21:27:00 -0600 Subject: [PATCH 08/13] Deleted an unused line of code. Signed-off-by: Matthew --- .../mappo/connectors/general_advantage_estimation.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py index af99180e8d4e..6afd5f416f00 100644 --- a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py +++ b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py @@ -71,11 +71,6 @@ def __call__( vf_preds = {mid: vf_preds[:, i] for i, mid in enumerate(obs_mids)} # Loop through all modules and perform each one's GAE computation. for module_id, module_vf_preds in vf_preds.items(): - # Skip those outputs of RLModules that are not implementers of - # `ValueFunctionAPI`. - if module_vf_preds is None: - continue - module = rl_module[module_id] device = module_vf_preds.device # Convert to numpy for the upcoming GAE computations. From 37845b2e339188f21fc7a941a86035cd664481d2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Wed, 7 Jan 2026 15:02:43 -0600 Subject: [PATCH 09/13] Re-added episode postproc and updated expected results. Signed-off-by: Matthew --- rllib/examples/algorithms/mappo/mappo_learner.py | 4 ++++ .../multi_agent/pettingzoo_shared_value_function.py | 12 ++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/rllib/examples/algorithms/mappo/mappo_learner.py b/rllib/examples/algorithms/mappo/mappo_learner.py index 7081e99c3f3b..7f6f8ac95ea4 100644 --- a/rllib/examples/algorithms/mappo/mappo_learner.py +++ b/rllib/examples/algorithms/mappo/mappo_learner.py @@ -6,6 +6,9 @@ LEARNER_RESULTS_KL_KEY, PPOConfig, ) +from ray.rllib.connectors.learner import ( + AddOneTsToEpisodesAndTruncate, +) from ray.rllib.core.learner.learner import Learner from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( SHARED_CRITIC_ID, @@ -59,6 +62,7 @@ def build(self) -> None: self._learner_connector is not None and self.config.add_default_connectors_to_learner_pipeline ): + self._learner_connector.prepend(AddOneTsToEpisodesAndTruncate()) # At the end of the pipeline (when the batch is already completed), add the # GAE connector, which performs a vf forward pass, then computes the GAE # computations, and puts the results of this (advantages, value targets) diff --git a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py index d2a425b24297..12cd2befb612 100644 --- a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py +++ b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py @@ -28,16 +28,16 @@ ----------------- The above options will typically reach a combined reward of 0 or more before 500k env timesteps. Keep in mind, though, that due to the separate learned policies in general, one agent's gain (in per-agent reward) might cause the other agent's reward to decrease at the same time. However, over time, both agents should simply improve, with the shared critic stabilizing this process significantly. 
-+-----------------------+------------+--------------------+--------+------------------+--------+-------------------+--------------------+--------------------+ -| Trial name | status | loc | iter | total time (s) | ts | combined return | return pursuer_0 | return pursuer_1 | -|-----------------------+------------+--------------------+--------+------------------+--------+-------------------+--------------------+--------------------| -| MAPPO_env_39b0c_00000 | TERMINATED | 172.29.87.208:9993 | 148 | 2690.21 | 592000 | 2.06999 | 38.2254 | -36.1554 | -+-----------------------+------------+--------------------+--------+------------------+--------+-------------------+--------------------+--------------------+ ++-----------------------+------------+--------------------+--------+------------------+ +| Trial name | status | loc | iter | total time (s) | +|-----------------------+------------+--------------------+--------+------------------+ +| MAPPO_env_aaaf6_00000 | TERMINATED | 172.29.87.208:6972 | 56 | 1386.86 | ++-----------------------+------------+--------------------+--------+------------------+ +--------+-------------------+--------------------+--------------------+ | ts | combined return | return pursuer_0 | return pursuer_1 | +--------+-------------------+--------------------+--------------------| -| 592000 | 2.06999 | 38.2254 | -36.1554 | +| 224000 | 29.5466 | 77.6161 | -48.0695 | +--------+-------------------+--------------------+--------------------+ Note that the two agents (`pursuer_0` and `pursuer_1`) are optimized on the exact same From fd7fa91b3a13200346b9b7124e568ca32d02b481 Mon Sep 17 00:00:00 2001 From: Matthew Date: Sun, 11 Jan 2026 09:00:30 -0600 Subject: [PATCH 10/13] Fixed actor module batching and loss masking. Signed-off-by: Matthew --- .../mappo/connectors/general_advantage_estimation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py index 6afd5f416f00..688d2d81621e 100644 --- a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py +++ b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py @@ -64,11 +64,13 @@ def __call__( and (not isinstance(rl_module[k], SelfSupervisedLossAPI)) ] critic_batch[Columns.OBS] = torch.cat( - [batch[k][Columns.OBS] for k in obs_mids], dim=1 + [batch[k][Columns.OBS] for k in obs_mids], dim=-1 ) + if (Columns.LOSS_MASK in batch[obs_mids[0]]): + critic_batch[Columns.LOSS_MASK] = batch[obs_mids[0]][Columns.LOSS_MASK] # Compute value predictions vf_preds = rl_module[SHARED_CRITIC_ID].compute_values(critic_batch) - vf_preds = {mid: vf_preds[:, i] for i, mid in enumerate(obs_mids)} + vf_preds = {mid: vf_preds[..., i] for i, mid in enumerate(obs_mids)} # Loop through all modules and perform each one's GAE computation. 
for module_id, module_vf_preds in vf_preds.items(): module = rl_module[module_id] @@ -136,10 +138,10 @@ def __call__( batch[module_id][Postprocessing.VALUE_TARGETS] = module_value_targets # Add GAE results to the critic batch critic_batch[Postprocessing.VALUE_TARGETS] = np.stack( - [batch[mid][Postprocessing.VALUE_TARGETS] for mid in obs_mids], axis=1 + [batch[mid][Postprocessing.VALUE_TARGETS] for mid in obs_mids], axis=-1 ) critic_batch[Postprocessing.ADVANTAGES] = np.stack( - [batch[mid][Postprocessing.ADVANTAGES] for mid in obs_mids], axis=1 + [batch[mid][Postprocessing.ADVANTAGES] for mid in obs_mids], axis=-1 ) batch[SHARED_CRITIC_ID] = critic_batch # Critic data -> training batch # Convert all GAE results to tensors. From 57917298da8f12a70b3f68120522a922ae65312c Mon Sep 17 00:00:00 2001 From: Matthew Date: Sun, 11 Jan 2026 09:24:26 -0600 Subject: [PATCH 11/13] Patched critic loss pooling. Signed-off-by: Matthew --- rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py index 6fffeac8e09f..85c3179b0dab 100644 --- a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py +++ b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py @@ -61,6 +61,7 @@ def compute_loss_for_critic(self, batch: Dict[str, Any]): vf_targets = batch[Postprocessing.VALUE_TARGETS] # Compute a value function loss. vf_loss = torch.pow(vf_preds - vf_targets, 2.0) + vf_loss = vf_loss.mean(dim=-1) # Reduce for accurate masked mean vf_loss_clipped = torch.clamp(vf_loss, 0, self.config.vf_clip_param) mean_vf_loss = possibly_masked_mean(vf_loss_clipped) mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) From 3f2f51b120dba405574d5111526ace74f5a8f950 Mon Sep 17 00:00:00 2001 From: Matthew Date: Sun, 11 Jan 2026 17:40:02 -0600 Subject: [PATCH 12/13] Linted. Signed-off-by: Matthew --- .../algorithms/mappo/connectors/general_advantage_estimation.py | 2 +- rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py index 688d2d81621e..dcca1146302e 100644 --- a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py +++ b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py @@ -66,7 +66,7 @@ def __call__( critic_batch[Columns.OBS] = torch.cat( [batch[k][Columns.OBS] for k in obs_mids], dim=-1 ) - if (Columns.LOSS_MASK in batch[obs_mids[0]]): + if Columns.LOSS_MASK in batch[obs_mids[0]]: critic_batch[Columns.LOSS_MASK] = batch[obs_mids[0]][Columns.LOSS_MASK] # Compute value predictions vf_preds = rl_module[SHARED_CRITIC_ID].compute_values(critic_batch) diff --git a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py index 85c3179b0dab..44b55f8e0539 100644 --- a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py +++ b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py @@ -61,7 +61,7 @@ def compute_loss_for_critic(self, batch: Dict[str, Any]): vf_targets = batch[Postprocessing.VALUE_TARGETS] # Compute a value function loss. 
vf_loss = torch.pow(vf_preds - vf_targets, 2.0) - vf_loss = vf_loss.mean(dim=-1) # Reduce for accurate masked mean + vf_loss = vf_loss.mean(dim=-1) # Reduce for accurate masked mean vf_loss_clipped = torch.clamp(vf_loss, 0, self.config.vf_clip_param) mean_vf_loss = possibly_masked_mean(vf_loss_clipped) mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) From 30f99db2bf56af596ab842c9b399c04325a620d6 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 19 Jan 2026 10:11:35 -0600 Subject: [PATCH 13/13] Cleaned up code and improved inheritance. Signed-off-by: Matthew --- rllib/examples/algorithms/mappo/mappo.py | 10 +-- .../algorithms/mappo/mappo_catalog.py | 71 +++---------------- 2 files changed, 13 insertions(+), 68 deletions(-) diff --git a/rllib/examples/algorithms/mappo/mappo.py b/rllib/examples/algorithms/mappo/mappo.py index 920b972b4a5b..07d68973f7aa 100644 --- a/rllib/examples/algorithms/mappo/mappo.py +++ b/rllib/examples/algorithms/mappo/mappo.py @@ -31,7 +31,7 @@ def get_default_config(cls) -> AlgorithmConfig: return MAPPOConfig() -class MAPPOConfig(AlgorithmConfig): # AlgorithmConfig -> PPOConfig -> MAPPO +class MAPPOConfig(AlgorithmConfig): """Defines a configuration class from which a MAPPO Algorithm can be built.""" def __init__(self, algo_class=None): @@ -74,15 +74,11 @@ def __init__(self, algo_class=None): @override(AlgorithmConfig) def get_default_rl_module_spec(self) -> RLModuleSpec: - if self.framework_str == "torch": - return RLModuleSpec(module_class=DefaultMAPPOTorchRLModule) - raise NotImplementedError() + return RLModuleSpec(module_class=DefaultMAPPOTorchRLModule) @override(AlgorithmConfig) def get_default_learner_class(self) -> Union[Type["Learner"], str]: - if self.framework_str == "torch": - return MAPPOTorchLearner - raise NotImplementedError() + return MAPPOTorchLearner @override(AlgorithmConfig) def training( diff --git a/rllib/examples/algorithms/mappo/mappo_catalog.py b/rllib/examples/algorithms/mappo/mappo_catalog.py index 08d2e8090fef..ac422e60005f 100644 --- a/rllib/examples/algorithms/mappo/mappo_catalog.py +++ b/rllib/examples/algorithms/mappo/mappo_catalog.py @@ -1,19 +1,13 @@ -# @title MAPPOCatalog # __sphinx_doc_begin__ import gymnasium as gym -from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian -from ray.rllib.core.models.base import Encoder, Model +from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog +from ray.rllib.core.models.base import Encoder from ray.rllib.core.models.catalog import Catalog -from ray.rllib.core.models.configs import ( - FreeLogStdMLPHeadConfig, - MLPHeadConfig, -) from ray.rllib.utils import override -from ray.rllib.utils.annotations import OverrideToImplementCustomLogic -class MAPPOCatalog(Catalog): +class MAPPOCatalog(PPOCatalog): """The Catalog class used to build models for MAPPO. MAPPOCatalog provides the following models: @@ -40,15 +34,18 @@ def __init__( action_space: The action space for the Pi Head. model_config_dict: The model config to use. """ - super().__init__( + Catalog.__init__( # Skip PPOCatalog.__init__, since it overrides the encoder configs. 
+ self, observation_space=observation_space, action_space=action_space, model_config_dict=model_config_dict, ) - # self.encoder_config = self._encoder_config - self.pi_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] - self.pi_head_activation = self._model_config_dict["head_fcnet_activation"] + # There is no vf head; the names below are held over from PPOCatalog.build_pi_head + self.pi_and_vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] + self.pi_and_vf_head_activation = self._model_config_dict[ + "head_fcnet_activation" + ] # We don't have the exact (framework specific) action dist class yet and thus # cannot determine the exact number of output nodes (action space) required. # -> Build pi config only in the `self.build_pi_head` method. @@ -59,53 +56,5 @@ def build_encoder(self, framework: str) -> Encoder: """Builds the encoder.""" return self.encoder_config.build(framework=framework) - @OverrideToImplementCustomLogic - def build_pi_head(self, framework: str) -> Model: - """Builds the policy head. - - The default behavior is to build the head from the pi_head_config. - This can be overridden to build a custom policy head as a means of configuring the behavior of a MAPPORLModule implementation. - - Args: - framework: The framework to use. Either "torch" or "tf2". - - Returns: - The policy head. - """ - # Get action_distribution_cls to find out about the output dimension for pi_head - action_distribution_cls = self.get_action_dist_cls(framework=framework) - if self._model_config_dict["free_log_std"]: - _check_if_diag_gaussian( - action_distribution_cls=action_distribution_cls, framework=framework - ) - is_diag_gaussian = True - else: - is_diag_gaussian = _check_if_diag_gaussian( - action_distribution_cls=action_distribution_cls, - framework=framework, - no_error=True, - ) - required_output_dim = action_distribution_cls.required_input_dim( - space=self.action_space, model_config=self._model_config_dict - ) - # Now that we have the action dist class and number of outputs, we can define - # our pi-config and build the pi head. - pi_head_config_class = ( - FreeLogStdMLPHeadConfig - if self._model_config_dict["free_log_std"] - else MLPHeadConfig - ) - self.pi_head_config = pi_head_config_class( - input_dims=self.latent_dims, - hidden_layer_dims=self.pi_head_hiddens, - hidden_layer_activation=self.pi_head_activation, - output_layer_dim=required_output_dim, - output_layer_activation="linear", - clip_log_std=is_diag_gaussian, - log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20), - ) - - return self.pi_head_config.build(framework=framework) - # __sphinx_doc_end__
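
For readers following the patches above, the shared-critic pattern they implement can be summarized outside of RLlib: each agent's 1D observation is concatenated along the last axis into a global observation, a single critic network maps that joint observation to one value prediction per agent, and per-agent values are recovered by indexing the last dimension (mirroring the dim=-1 concatenation and vf_preds[..., i] slicing in MAPPOGAEConnector). The sketch below is a minimal, self-contained illustration in plain PyTorch; SharedCriticSketch, its layer sizes, and the observation sizes are hypothetical stand-ins for illustration only, not code from the patches.

import torch
import torch.nn as nn


class SharedCriticSketch(nn.Module):
    """Hypothetical centralized critic: joint observation in, one value per agent out."""

    def __init__(self, obs_sizes, hidden=64):
        super().__init__()
        # obs_sizes: per-agent 1D observation sizes, in a fixed agent ordering.
        self.net = nn.Sequential(
            nn.Linear(sum(obs_sizes), hidden),
            nn.ReLU(),
            nn.Linear(hidden, len(obs_sizes)),  # one value output per agent
        )

    def forward(self, per_agent_obs):
        # per_agent_obs: list of [batch, obs_size_i] tensors, same order as obs_sizes.
        joint_obs = torch.cat(per_agent_obs, dim=-1)  # global observation, as in the connector
        return self.net(joint_obs)  # [batch, num_agents] value predictions


if __name__ == "__main__":
    critic = SharedCriticSketch(obs_sizes=[12, 12])  # arbitrary sizes for illustration
    obs = [torch.randn(8, 12), torch.randn(8, 12)]
    values = critic(obs)
    # Per-agent value predictions, analogous to vf_preds[..., i] in MAPPOGAEConnector.
    v_agent_0, v_agent_1 = values[..., 0], values[..., 1]
    print(v_agent_0.shape, v_agent_1.shape)  # torch.Size([8]) torch.Size([8])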