diff --git a/rllib/BUILD.bazel b/rllib/BUILD.bazel index fa4cebba23f8..946230bab75f 100644 --- a/rllib/BUILD.bazel +++ b/rllib/BUILD.bazel @@ -4694,15 +4694,24 @@ py_test( ], ) -# TODO (sven): Activate this test once this script is ready. -# py_test( -# name = "examples/multi_agent/pettingzoo_shared_value_function", -# main = "examples/multi_agent/pettingzoo_shared_value_function.py", -# tags = ["team:rllib", "exclusive", "examples"], -# size = "large", -# srcs = ["examples/multi_agent/pettingzoo_shared_value_function.py"], -# args = ["--num-agents=2", "--as-test", "--framework=torch", "--stop-reward=-100.0", "--num-cpus=4"], -# ) +py_test( + name = "examples/multi_agent/pettingzoo_shared_value_function", + size = "large", + srcs = ["examples/multi_agent/pettingzoo_shared_value_function.py"], + args = [ + "--num-agents=2", + "--as-test", + "--framework=torch", + "--stop-reward=-100.0", + "--num-cpus=4", + ], + main = "examples/multi_agent/pettingzoo_shared_value_function.py", + tags = [ + "examples", + "exclusive", + "team:rllib", + ], +) py_test( name = "examples/checkpoints/restore_1_of_n_agents_from_checkpoint", diff --git a/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py new file mode 100644 index 000000000000..dcca1146302e --- /dev/null +++ b/rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py @@ -0,0 +1,170 @@ +from typing import Any, Dict, List + +import numpy as np + +from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets +from ray.rllib.utils.postprocessing.zero_padding import ( + split_and_zero_pad_n_episodes, + unpad_data_if_necessary, +) +from ray.rllib.utils.typing import EpisodeType + +torch, nn = try_import_torch() + +SHARED_CRITIC_ID = "shared_critic" + + +class MAPPOGAEConnector(ConnectorV2): + def __init__( + self, + input_observation_space=None, + input_action_space=None, + *, + gamma, + lambda_, + ): + super().__init__(input_observation_space, input_action_space) + self.gamma = gamma + self.lambda_ = lambda_ + # Internal numpy-to-tensor connector to translate GAE results (advantages and + # vf targets) into tensors. + self._numpy_to_tensor_connector = None + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: MultiRLModule, + episodes: List[EpisodeType], + batch: Dict[str, Any], + **kwargs, + ): + # Device to place all GAE result tensors (advantages and value targets) on. + device = None + # Extract all single-agent episodes. + sa_episodes_list = list( + self.single_agent_episode_iterator(episodes, agents_that_stepped_only=False) + ) + # Perform the value net's forward pass. 
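+        # The shared ("centralized") critic consumes the concatenation of all
+        # agents' observations (assembled below in a fixed, sorted module-ID
+        # order) and returns one value prediction per agent. Modules that only
+        # define a self-supervised loss (e.g. curiosity) are skipped.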
+ critic_batch = {} + # Concatenate all agent observations in batch, using a fixed order + obs_mids = [ + k + for k in sorted(batch.keys()) + if (Columns.OBS in batch[k]) + and (not isinstance(rl_module[k], SelfSupervisedLossAPI)) + ] + critic_batch[Columns.OBS] = torch.cat( + [batch[k][Columns.OBS] for k in obs_mids], dim=-1 + ) + if Columns.LOSS_MASK in batch[obs_mids[0]]: + critic_batch[Columns.LOSS_MASK] = batch[obs_mids[0]][Columns.LOSS_MASK] + # Compute value predictions + vf_preds = rl_module[SHARED_CRITIC_ID].compute_values(critic_batch) + vf_preds = {mid: vf_preds[..., i] for i, mid in enumerate(obs_mids)} + # Loop through all modules and perform each one's GAE computation. + for module_id, module_vf_preds in vf_preds.items(): + module = rl_module[module_id] + device = module_vf_preds.device + # Convert to numpy for the upcoming GAE computations. + module_vf_preds = convert_to_numpy(module_vf_preds) + + # Collect (single-agent) episode lengths for this particular module. + episode_lens = [ + len(e) for e in sa_episodes_list if e.module_id in [None, module_id] + ] + + # Remove all zero-padding again, if applicable, for the upcoming + # GAE computations. + module_vf_preds = unpad_data_if_necessary(episode_lens, module_vf_preds) + # Compute value targets. + module_value_targets = compute_value_targets( + values=module_vf_preds, + rewards=unpad_data_if_necessary( + episode_lens, + convert_to_numpy(batch[module_id][Columns.REWARDS]), + ), + terminateds=unpad_data_if_necessary( + episode_lens, + convert_to_numpy(batch[module_id][Columns.TERMINATEDS]), + ), + truncateds=unpad_data_if_necessary( + episode_lens, + convert_to_numpy(batch[module_id][Columns.TRUNCATEDS]), + ), + gamma=self.gamma, + lambda_=self.lambda_, + ) + assert module_value_targets.shape[0] == sum(episode_lens) + + module_advantages = module_value_targets - module_vf_preds + # Drop vf-preds, not needed in loss. Note that in the DefaultPPORLModule, + # vf-preds are recomputed with each `forward_train` call anyway to compute + # the vf loss. + # Standardize advantages (used for more stable and better weighted + # policy gradient computations). + module_advantages = (module_advantages - module_advantages.mean()) / max( + 1e-4, module_advantages.std() + ) + + # Zero-pad the new computations, if necessary. + if module.is_stateful(): + module_advantages = np.stack( + split_and_zero_pad_n_episodes( + module_advantages, + episode_lens=episode_lens, + max_seq_len=module.model_config["max_seq_len"], + ), + axis=0, + ) + module_value_targets = np.stack( + split_and_zero_pad_n_episodes( + module_value_targets, + episode_lens=episode_lens, + max_seq_len=module.model_config["max_seq_len"], + ), + axis=0, + ) + batch[module_id][Postprocessing.ADVANTAGES] = module_advantages + batch[module_id][Postprocessing.VALUE_TARGETS] = module_value_targets + # Add GAE results to the critic batch + critic_batch[Postprocessing.VALUE_TARGETS] = np.stack( + [batch[mid][Postprocessing.VALUE_TARGETS] for mid in obs_mids], axis=-1 + ) + critic_batch[Postprocessing.ADVANTAGES] = np.stack( + [batch[mid][Postprocessing.ADVANTAGES] for mid in obs_mids], axis=-1 + ) + batch[SHARED_CRITIC_ID] = critic_batch # Critic data -> training batch + # Convert all GAE results to tensors. 
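+        # The advantages and value targets above were computed in numpy; a
+        # NumpyToTensor connector converts them to framework tensors on the same
+        # device as the value predictions before they are merged back into `batch`.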
+ if self._numpy_to_tensor_connector is None: + self._numpy_to_tensor_connector = NumpyToTensor( + as_learner_connector=True, device=device + ) + tensor_results = self._numpy_to_tensor_connector( + rl_module=rl_module, + batch={ + mid: { + Postprocessing.ADVANTAGES: module_batch[Postprocessing.ADVANTAGES], + Postprocessing.VALUE_TARGETS: ( + module_batch[Postprocessing.VALUE_TARGETS] + ), + } + for mid, module_batch in batch.items() + if (mid == SHARED_CRITIC_ID) or (mid in vf_preds) + }, + episodes=episodes, + ) + # Move converted tensors back to `batch`. + for mid, module_batch in tensor_results.items(): + batch[mid].update(module_batch) + + return batch diff --git a/rllib/examples/algorithms/mappo/default_mappo_rl_module.py b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py new file mode 100644 index 000000000000..7b86ef73391c --- /dev/null +++ b/rllib/examples/algorithms/mappo/default_mappo_rl_module.py @@ -0,0 +1,54 @@ +import abc +from typing import List + +from ray.rllib.core.models.configs import RecurrentEncoderConfig +from ray.rllib.core.rl_module.apis import InferenceOnlyAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import ( + OverrideToImplementCustomLogic_CallToSuperRecommended, + override, +) +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class DefaultMAPPORLModule(RLModule, InferenceOnlyAPI, abc.ABC): + """Default RLModule used by MAPPO, if user does not specify a custom RLModule. + + Users who want to train their RLModules with MAPPO may implement any RLModule (or TorchRLModule) subclass. + """ + + @override(RLModule) + def setup(self): + # __sphinx_doc_begin__ + # If we have a stateful model, states for the critic need to be collected + # during sampling and `inference-only` needs to be `False`. Note, at this + # point the encoder is not built, yet and therefore `is_stateful()` does + # not work. + is_stateful = isinstance( + self.catalog.encoder_config, + RecurrentEncoderConfig, + ) + if is_stateful: + self.inference_only = False + # If this is an `inference_only` Module, we'll have to pass this information + # to the encoder config as well. + if self.inference_only and self.framework == "torch": + self.catalog.encoder_config.inference_only = True + # Build models from catalog. 
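+        # Note: unlike the default PPO RLModule, no value-function head is built
+        # here. Value estimation is handled entirely by the separate shared-critic
+        # RLModule.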
+ self.encoder = self.catalog.build_encoder(framework=self.framework) + self.pi = self.catalog.build_pi_head(framework=self.framework) + # __sphinx_doc_end__ + + @override(RLModule) + def get_initial_state(self) -> dict: + if hasattr(self.encoder, "get_initial_state"): + return self.encoder.get_initial_state() + else: + return {} + + @OverrideToImplementCustomLogic_CallToSuperRecommended + @override(InferenceOnlyAPI) + def get_non_inference_attributes(self) -> List[str]: + """Return attributes, which are NOT inference-only (only used for training).""" + return [] diff --git a/rllib/examples/algorithms/mappo/mappo.py b/rllib/examples/algorithms/mappo/mappo.py new file mode 100644 index 000000000000..07d68973f7aa --- /dev/null +++ b/rllib/examples/algorithms/mappo/mappo.py @@ -0,0 +1,160 @@ +import logging +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.algorithms.ppo.ppo import PPO +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.utils.annotations import override + +if TYPE_CHECKING: + from ray.rllib.core.learner.learner import Learner + +from ray.rllib.examples.algorithms.mappo.torch.default_mappo_torch_rl_module import ( + DefaultMAPPOTorchRLModule, +) +from ray.rllib.examples.algorithms.mappo.torch.mappo_torch_learner import ( + MAPPOTorchLearner, +) + +logger = logging.getLogger(__name__) + +LEARNER_RESULTS_KL_KEY = "mean_kl_loss" +LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff" +LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff" + + +class MAPPO(PPO): + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return MAPPOConfig() + + +class MAPPOConfig(AlgorithmConfig): + """Defines a configuration class from which a MAPPO Algorithm can be built.""" + + def __init__(self, algo_class=None): + """Initializes a MAPPOConfig instance.""" + self.exploration_config = { + # The Exploration class to use. In the simplest case, this is the name + # (str) of any class present in the `rllib.utils.exploration` package. + # You can also provide the python class directly or the full location + # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy. + # EpsilonGreedy"). + "type": "StochasticSampling", + # Add constructor kwargs here (if any). + } + + super().__init__(algo_class=algo_class or MAPPO) + + # fmt: off + # __sphinx_doc_begin__ + self.lr = 5e-5 + self.rollout_fragment_length = "auto" + self.train_batch_size = 4000 + + # MAPPO specific settings: + self.num_epochs = 30 + self.minibatch_size = 128 + self.shuffle_batch_per_epoch = True + self.lambda_ = 1.0 + self.use_kl_loss = True + self.kl_coeff = 0.2 + self.kl_target = 0.01 + self.entropy_coeff = 0.0 + self.clip_param = 0.3 + self.vf_clip_param = 10.0 + self.grad_clip = None + + # Override some of AlgorithmConfig's default values with MAPPO-specific values. 
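+        # These mirror PPO's defaults; the only AlgorithmConfig override below is
+        # the number of EnvRunners. MAPPO differs from PPO in how values are
+        # estimated (one critic shared by all policies), not in these settings.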
+ self.num_env_runners = 2 + # __sphinx_doc_end__ + # fmt: on + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> RLModuleSpec: + return RLModuleSpec(module_class=DefaultMAPPOTorchRLModule) + + @override(AlgorithmConfig) + def get_default_learner_class(self) -> Union[Type["Learner"], str]: + return MAPPOTorchLearner + + @override(AlgorithmConfig) + def training( + self, + *, + lambda_: Optional[float] = NotProvided, + use_kl_loss: Optional[bool] = NotProvided, + kl_coeff: Optional[float] = NotProvided, + kl_target: Optional[float] = NotProvided, + entropy_coeff: Optional[float] = NotProvided, + clip_param: Optional[float] = NotProvided, + vf_clip_param: Optional[float] = NotProvided, + grad_clip: Optional[float] = NotProvided, + **kwargs, + ) -> "MAPPOConfig": + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + if lambda_ is not NotProvided: + self.lambda_ = lambda_ + if use_kl_loss is not NotProvided: + self.use_kl_loss = use_kl_loss + if kl_coeff is not NotProvided: + self.kl_coeff = kl_coeff + if kl_target is not NotProvided: + self.kl_target = kl_target + if entropy_coeff is not NotProvided: + self.entropy_coeff = entropy_coeff + if clip_param is not NotProvided: + self.clip_param = clip_param + if vf_clip_param is not NotProvided: + self.vf_clip_param = vf_clip_param + if grad_clip is not NotProvided: + self.grad_clip = grad_clip + return self + + @override(AlgorithmConfig) + def validate(self) -> None: + # Call super's validation method. + super().validate() + + # Synchronous sampling, on-policy/PPO algos -> Check mismatches between + # `rollout_fragment_length` and `train_batch_size_per_learner` to avoid user + # confusion. + # TODO (sven): Make rollout_fragment_length a property and create a private + # attribute to store (possibly) user provided value (or "auto") in. Deprecate + # `self.get_rollout_fragment_length()`. + self.validate_train_batch_size_vs_rollout_fragment_length() + + # SGD minibatch size must be smaller than train_batch_size (b/c + # we subsample a batch of `minibatch_size` from the train-batch for + # each `num_epochs`). + if ( + not self.enable_rl_module_and_learner + and self.minibatch_size > self.train_batch_size + ): + self._value_error( + f"`minibatch_size` ({self.minibatch_size}) must be <= " + f"`train_batch_size` ({self.train_batch_size}). In MAPPO, the train batch" + f" will be split into {self.minibatch_size} chunks, each of which " + f"is iterated over (used for updating the policy) {self.num_epochs} " + "times." + ) + elif self.enable_rl_module_and_learner: + mbs = self.minibatch_size + tbs = self.train_batch_size_per_learner or self.train_batch_size + if isinstance(mbs, int) and isinstance(tbs, int) and mbs > tbs: + self._value_error( + f"`minibatch_size` ({mbs}) must be <= " + f"`train_batch_size_per_learner` ({tbs}). In MAPPO, the train batch" + f" will be split into {mbs} chunks, each of which is iterated over " + f"(used for updating the policy) {self.num_epochs} times." 
+                )
+        if isinstance(self.entropy_coeff, float) and self.entropy_coeff < 0.0:
+            self._value_error("`entropy_coeff` must be >= 0.0")
+
+    @property
+    @override(AlgorithmConfig)
+    def _model_config_auto_includes(self) -> Dict[str, Any]:
+        return super()._model_config_auto_includes | {}
diff --git a/rllib/examples/algorithms/mappo/mappo_catalog.py b/rllib/examples/algorithms/mappo/mappo_catalog.py
new file mode 100644
index 000000000000..ac422e60005f
--- /dev/null
+++ b/rllib/examples/algorithms/mappo/mappo_catalog.py
@@ -0,0 +1,60 @@
+# __sphinx_doc_begin__
+import gymnasium as gym
+
+from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
+from ray.rllib.core.models.base import Encoder
+from ray.rllib.core.models.catalog import Catalog
+from ray.rllib.utils import override
+
+
+class MAPPOCatalog(PPOCatalog):
+    """The Catalog class used to build models for MAPPO.
+
+    MAPPOCatalog provides the following models:
+        - Encoder: The encoder used to encode the observations.
+        - Pi Head: The head used to compute the policy logits.
+
+    Any custom Encoder can be built by overriding the build_encoder() method.
+    Alternatively, the EncoderConfig at MAPPOCatalog.encoder_config can be
+    overridden to build a custom Encoder during RLModule runtime.
+
+    Any custom head can be built by overriding the build_pi_head() method.
+    Alternatively, the PiHeadConfig can be overridden to build a custom head
+    during RLModule runtime.
+
+    Any module built for exploration or inference is built with the flag
+    `inference_only=True` and does not contain a value network. This flag can be
+    set in the `RLModuleSpec` through the `inference_only` boolean flag.
+    """
+
+    def __init__(
+        self,
+        observation_space: gym.Space,
+        action_space: gym.Space,
+        model_config_dict: dict,
+    ):
+        """Initializes the MAPPOCatalog.
+
+        Args:
+            observation_space: The observation space of the Encoder.
+            action_space: The action space for the Pi Head.
+            model_config_dict: The model config to use.
+        """
+        Catalog.__init__(  # Skip PPOCatalog.__init__, since it overrides the encoder configs.
+            self,
+            observation_space=observation_space,
+            action_space=action_space,
+            model_config_dict=model_config_dict,
+        )
+        self.encoder_config = self._encoder_config
+        # There is no vf head; the names below are held over from PPOCatalog.build_pi_head.
+        self.pi_and_vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"]
+        self.pi_and_vf_head_activation = self._model_config_dict[
+            "head_fcnet_activation"
+        ]
+        # We don't have the exact (framework specific) action dist class yet and thus
+        # cannot determine the exact number of output nodes (action space) required.
+        # -> Build pi config only in the `self.build_pi_head` method.
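+        # `build_pi_head()` itself is not overridden here, so PPOCatalog's
+        # implementation is reused; it is expected to read the `pi_and_vf_head_*`
+        # attributes set above and to create the pi head config lazily, once the
+        # framework-specific action distribution class is known.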
+ self.pi_head_config = None + + @override(Catalog) + def build_encoder(self, framework: str) -> Encoder: + """Builds the encoder.""" + return self.encoder_config.build(framework=framework) + + +# __sphinx_doc_end__ diff --git a/rllib/examples/algorithms/mappo/mappo_learner.py b/rllib/examples/algorithms/mappo/mappo_learner.py new file mode 100644 index 000000000000..7f6f8ac95ea4 --- /dev/null +++ b/rllib/examples/algorithms/mappo/mappo_learner.py @@ -0,0 +1,141 @@ +import abc +from typing import Any, Dict + +from ray.rllib.algorithms.ppo.ppo import ( + LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY, + LEARNER_RESULTS_KL_KEY, + PPOConfig, +) +from ray.rllib.connectors.learner import ( + AddOneTsToEpisodesAndTruncate, +) +from ray.rllib.core.learner.learner import Learner +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( + SHARED_CRITIC_ID, + MAPPOGAEConnector, +) +from ray.rllib.utils.annotations import ( + OverrideToImplementCustomLogic_CallToSuperRecommended, + override, +) +from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict +from ray.rllib.utils.metrics import ( + NUM_ENV_STEPS_SAMPLED_LIFETIME, + NUM_MODULE_STEPS_TRAINED, +) +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.schedules.scheduler import Scheduler +from ray.rllib.utils.typing import ModuleID, TensorType + + +class MAPPOLearner(Learner): + @override(Learner) + def build(self) -> None: + super().build() + + # Dict mapping module IDs to the respective entropy Scheduler instance. + self.entropy_coeff_schedulers_per_module: Dict[ + ModuleID, Scheduler + ] = LambdaDefaultDict( + lambda module_id: Scheduler( + fixed_value_or_schedule=( + self.config.get_config_for_module(module_id).entropy_coeff + ), + framework=self.framework, + device=self._device, + ) + ) + + # Set up KL coefficient variables (per module). + # Note that the KL coeff is not controlled by a Scheduler, but seeks + # to stay close to a given kl_target value. + self.curr_kl_coeffs_per_module: Dict[ModuleID, TensorType] = LambdaDefaultDict( + lambda module_id: self._get_tensor_variable( + self.config.get_config_for_module(module_id).kl_coeff + ) + ) + + # Extend all episodes by one artificial timestep to allow the value function net + # to compute the bootstrap values (and add a mask to the batch to know, which + # slots to mask out). + if ( + self._learner_connector is not None + and self.config.add_default_connectors_to_learner_pipeline + ): + self._learner_connector.prepend(AddOneTsToEpisodesAndTruncate()) + # At the end of the pipeline (when the batch is already completed), add the + # GAE connector, which performs a vf forward pass, then computes the GAE + # computations, and puts the results of this (advantages, value targets) + # directly back in the batch. This is then the batch used for + # `forward_train` and `compute_losses`. 
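+            # Resulting learner connector pipeline (simplified):
+            #   AddOneTsToEpisodesAndTruncate -> ...default connectors... ->
+            #   MAPPOGAEConnector, which writes ADVANTAGES and VALUE_TARGETS
+            #   (for each policy and for the shared critic) into the batch.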
+            self._learner_connector.append(
+                MAPPOGAEConnector(
+                    gamma=self.config.gamma,
+                    lambda_=self.config.lambda_,
+                )
+            )
+
+    @override(Learner)
+    def remove_module(self, module_id: ModuleID, **kwargs):
+        marl_spec = super().remove_module(module_id, **kwargs)
+        self.entropy_coeff_schedulers_per_module.pop(module_id, None)
+        self.curr_kl_coeffs_per_module.pop(module_id, None)
+        return marl_spec
+
+    @OverrideToImplementCustomLogic_CallToSuperRecommended
+    @override(Learner)
+    def after_gradient_based_update(
+        self,
+        *,
+        timesteps: Dict[str, Any],
+    ) -> None:
+        super().after_gradient_based_update(timesteps=timesteps)
+
+        for module_id, module in self.module._rl_modules.items():
+            if module_id == SHARED_CRITIC_ID:
+                continue  # Policy terms are irrelevant to the shared critic.
+            config = self.config.get_config_for_module(module_id)
+            # Update entropy coefficient via our Scheduler.
+            new_entropy_coeff = self.entropy_coeff_schedulers_per_module[
+                module_id
+            ].update(timestep=timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0))
+            self.metrics.log_value(
+                (module_id, LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY),
+                new_entropy_coeff,
+                window=1,
+            )
+            if (
+                config.use_kl_loss
+                and self.metrics.peek((module_id, NUM_MODULE_STEPS_TRAINED), default=0)
+                > 0
+                and (module_id, LEARNER_RESULTS_KL_KEY) in self.metrics
+            ):
+                kl_loss = convert_to_numpy(
+                    self.metrics.peek((module_id, LEARNER_RESULTS_KL_KEY))
+                )
+                self._update_module_kl_coeff(
+                    module_id=module_id,
+                    config=config,
+                    kl_loss=kl_loss,
+                )
+
+    @abc.abstractmethod
+    def _update_module_kl_coeff(
+        self,
+        *,
+        module_id: ModuleID,
+        config: PPOConfig,
+        kl_loss: float,
+    ) -> None:
+        """Dynamically updates the KL loss coefficient of a module.
+
+        The update uses the mean KL divergence between the action distributions of
+        the module's current policy and its old policy, computed during the most
+        recent update/call to `compute_loss`.
+
+        Args:
+            module_id: The module whose KL loss coefficient to update.
+            config: The AlgorithmConfig specific to the given `module_id`.
+            kl_loss: The mean KL loss of the module, computed inside
+                `compute_loss_for_module()`.
+        """
diff --git a/rllib/examples/algorithms/mappo/shared_critic_catalog.py b/rllib/examples/algorithms/mappo/shared_critic_catalog.py
new file mode 100644
index 000000000000..e346b41e026f
--- /dev/null
+++ b/rllib/examples/algorithms/mappo/shared_critic_catalog.py
@@ -0,0 +1,72 @@
+import gymnasium as gym
+
+from ray.rllib.core.models.base import Encoder, Model
+from ray.rllib.core.models.catalog import Catalog
+from ray.rllib.core.models.configs import (
+    MLPHeadConfig,
+)
+from ray.rllib.utils import override
+from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
+
+
+class SharedCriticCatalog(Catalog):
+    def __init__(
+        self,
+        observation_space: gym.Space,
+        action_space: gym.Space,  # TODO: Remove?
+        model_config_dict: dict,
+    ):
+        """Initializes the SharedCriticCatalog.
+
+        Args:
+            observation_space: The observation space of the Encoder.
+            action_space: The action space (unused by the critic itself, but
+                checked by the base Catalog class).
+            model_config_dict: The model config to use.
+        """
+        super().__init__(
+            observation_space=observation_space,
+            action_space=action_space,  # Base Catalog class checks for this.
+            model_config_dict=model_config_dict,
+        )
+        # We only want one encoder, so we use the base encoder config.
+        self.encoder_config = self._encoder_config
+        # Adjust the input and output dimensions of the encoder.
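+        # The encoder input dim is the sum of all (1D) per-agent observation
+        # sizes, matching the `torch.cat` over per-agent observations in the GAE
+        # connector; the value head below then emits one value per agent.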
+ observation_spaces = self._model_config_dict["observation_spaces"] + obs_size = 0 + for agent, obs in observation_spaces.items(): + obs_size += obs.shape[0] # Assume 1D observations + self.encoder_config.input_dims = (obs_size,) + # Value head architecture + self.vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"] + self.vf_head_activation = self._model_config_dict["head_fcnet_activation"] + self.vf_head_config = MLPHeadConfig( + input_dims=self.latent_dims, + hidden_layer_dims=self.vf_head_hiddens, + hidden_layer_activation=self.vf_head_activation, + output_layer_activation="linear", + output_layer_dim=len(observation_spaces), # 1 value pred. per agent + ) + + @override(Catalog) + def build_encoder(self, framework: str) -> Encoder: + """Builds the encoder.""" + return self.encoder_config.build(framework=framework) + + @OverrideToImplementCustomLogic + def build_vf_head(self, framework: str) -> Model: + """Builds the value function head. + + The default behavior is to build the head from the vf_head_config. + This can be overridden to build a custom value function head as a means of + configuring the behavior of a MAPPORLModule implementation. + + Args: + framework: The framework to use. Either "torch" or "tf2". + + Returns: + The value function head. + """ + return self.vf_head_config.build(framework=framework) + + +# __sphinx_doc_end__ diff --git a/rllib/examples/algorithms/mappo/shared_critic_rl_module.py b/rllib/examples/algorithms/mappo/shared_critic_rl_module.py new file mode 100644 index 000000000000..01a85bfd9e24 --- /dev/null +++ b/rllib/examples/algorithms/mappo/shared_critic_rl_module.py @@ -0,0 +1,46 @@ +import abc + +from ray.rllib.core.models.configs import RecurrentEncoderConfig +from ray.rllib.core.rl_module.apis import ValueFunctionAPI +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import ( + override, +) +from ray.util.annotations import DeveloperAPI + + +@DeveloperAPI +class SharedCriticRLModule(RLModule, ValueFunctionAPI, abc.ABC): + """Standard shared critic RLModule usable in MAPPO, if user does not intend to use a custom RLModule. + + Users who want to train custom shared critics with MAPPO may implement any RLModule (or TorchRLModule) subclass as long as the custom class also implements the `ValueFunctionAPI` (see ray.rllib.core.rl_module.apis.value_function_api.py) + """ + + @override(RLModule) + def setup(self): + # __sphinx_doc_begin__ + # If we have a stateful model, states for the critic need to be collected + # during sampling and `inference-only` needs to be `False`. Note, at this + # point the encoder is not built, yet and therefore `is_stateful()` does + # not work. + is_stateful = isinstance( + self.catalog.encoder_config, + RecurrentEncoderConfig, + ) + if is_stateful: + self.inference_only = False + # If this is an `inference_only` Module, we'll have to pass this information + # to the encoder config as well. + if self.inference_only and self.framework == "torch": + self.catalog.encoder_config.inference_only = True + # Build models from catalog. 
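+        # The shared critic is just an encoder plus a value head with one output
+        # per agent; it has no policy (pi) head.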
+ self.encoder = self.catalog.build_encoder(framework=self.framework) + self.vf = self.catalog.build_vf_head(framework=self.framework) + # __sphinx_doc_end__ + + @override(RLModule) + def get_initial_state(self) -> dict: + if hasattr(self.encoder, "get_initial_state"): + return self.encoder.get_initial_state() + else: + return {} diff --git a/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py b/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py new file mode 100644 index 000000000000..3207494b2084 --- /dev/null +++ b/rllib/examples/algorithms/mappo/torch/default_mappo_torch_rl_module.py @@ -0,0 +1,48 @@ +from typing import Any, Dict + +from ray.rllib.core.columns import Columns +from ray.rllib.core.models.base import ENCODER_OUT +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.examples.algorithms.mappo.default_mappo_rl_module import ( + DefaultMAPPORLModule, +) +from ray.rllib.examples.algorithms.mappo.mappo_catalog import MAPPOCatalog +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.util.annotations import DeveloperAPI + +torch, nn = try_import_torch() + + +@DeveloperAPI +class DefaultMAPPOTorchRLModule(TorchRLModule, DefaultMAPPORLModule): + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = MAPPOCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(RLModule) + def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Default forward pass (used for inference and exploration).""" + output = {} + # Encoder forward pass. + encoder_outs = self.encoder(batch) + # Stateful encoder? + if Columns.STATE_OUT in encoder_outs: + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + # Pi head. 
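+        # The pi head maps the encoder embedding to action-distribution inputs
+        # (e.g. logits); this module produces no value outputs.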
+ output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT]) + return output + + @override(RLModule) + def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Train forward pass.""" + output = {} + encoder_outs = self.encoder(batch) + output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT] + if Columns.STATE_OUT in encoder_outs: + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT]) + return output diff --git a/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py new file mode 100644 index 000000000000..44b55f8e0539 --- /dev/null +++ b/rllib/examples/algorithms/mappo/torch/mappo_torch_learner.py @@ -0,0 +1,235 @@ +import logging +from collections.abc import Callable +from typing import Any, Dict + +import numpy as np + +from ray.rllib.algorithms.ppo.ppo import ( + LEARNER_RESULTS_CURR_KL_COEFF_KEY, + LEARNER_RESULTS_KL_KEY, + LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY, + LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY, + PPOConfig, +) +from ray.rllib.core.columns import Columns +from ray.rllib.core.learner.learner import ( + ENTROPY_KEY, + POLICY_LOSS_KEY, + VF_LOSS_KEY, + Learner, +) +from ray.rllib.core.learner.torch.torch_learner import TorchLearner +from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( + SHARED_CRITIC_ID, +) +from ray.rllib.examples.algorithms.mappo.mappo_learner import MAPPOLearner +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_utils import explained_variance +from ray.rllib.utils.typing import ModuleID, TensorType + +torch, nn = try_import_torch() + +logger = logging.getLogger(__name__) + + +class MAPPOTorchLearner(MAPPOLearner, TorchLearner): + """ + Implements MAPPO in Torch, on top of a MAPPOLearner. + """ + + def get_pmm(self, batch: Dict[str, Any]) -> Callable: + """Gets the possibly_masked_mean function""" + if Columns.LOSS_MASK in batch: + mask = batch[Columns.LOSS_MASK] + num_valid = torch.sum(mask) + + def possibly_masked_mean(data_): + return torch.sum(data_[mask]) / num_valid + + else: + possibly_masked_mean = torch.mean + return possibly_masked_mean + + def compute_loss_for_critic(self, batch: Dict[str, Any]): + """Computes the loss for the shared critic module.""" + possibly_masked_mean = self.get_pmm(batch) + module = self.module[SHARED_CRITIC_ID].unwrapped() + vf_preds = module.compute_values(batch) + vf_targets = batch[Postprocessing.VALUE_TARGETS] + # Compute a value function loss. 
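+        # Squared error per timestep, averaged over the per-agent value outputs,
+        # then clipped at `vf_clip_param` (as in PPO) before taking the (possibly
+        # masked) mean over the batch.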
+ vf_loss = torch.pow(vf_preds - vf_targets, 2.0) + vf_loss = vf_loss.mean(dim=-1) # Reduce for accurate masked mean + vf_loss_clipped = torch.clamp(vf_loss, 0, self.config.vf_clip_param) + mean_vf_loss = possibly_masked_mean(vf_loss_clipped) + mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) + # record metrics + self.metrics.log_dict( + { + VF_LOSS_KEY: mean_vf_loss, + LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY: mean_vf_unclipped_loss, + LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY: explained_variance( + vf_targets.reshape(-1), vf_preds.reshape(-1) + ), # Flatten multi-agent value predictions + }, + key=SHARED_CRITIC_ID, + window=1, + ) + return mean_vf_loss + + @override(Learner) + def compute_losses( + self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Args: + fwd_out: Output from a call to the `forward_train()` method of the + underlying MultiRLModule (`self.module`) during training + (`self.update()`). + batch: The train batch that was used to compute `fwd_out`. + + Returns: + A dictionary mapping module IDs to individual loss terms. + """ + loss_per_module = {} + # Optimize the critic + loss_per_module[SHARED_CRITIC_ID] = self.compute_loss_for_critic( + batch[SHARED_CRITIC_ID] + ) + # Calculate loss for agent policies + for module_id in fwd_out: + if module_id == SHARED_CRITIC_ID: # Computed above + continue + module_batch = batch[module_id] + module_fwd_out = fwd_out[module_id] + module = self.module[module_id].unwrapped() + if isinstance(module, SelfSupervisedLossAPI): + # For training e.g. intrinsic curiosity modules. + loss = module.compute_self_supervised_loss( + learner=self, + module_id=module_id, + config=self.config.get_config_for_module(module_id), + batch=module_batch, + fwd_out=module_fwd_out, + ) + else: + # For every module we're going to touch, sans the critic + loss = self.compute_loss_for_module( + module_id=module_id, + config=self.config.get_config_for_module(module_id), + batch=module_batch, + fwd_out=module_fwd_out, + ) + loss_per_module[module_id] = loss + return loss_per_module + + # We strip out the value function optimization here. + @override(TorchLearner) + def compute_loss_for_module( + self, + *, + module_id: ModuleID, + config: PPOConfig, + batch: Dict[str, Any], + fwd_out: Dict[str, TensorType], + ) -> TensorType: + module = self.module[module_id].unwrapped() + possibly_masked_mean = self.get_pmm(batch) + # Possibly apply masking to some sub loss terms and to the total loss term + # at the end. Masking could be used for RNN-based model (zero padded `batch`) + # and for MAPPO's batched value function (and bootstrap value) computations, + # for which we add an (artificial) timestep to each episode to + # simplify the actual computation. + action_dist_class_train = module.get_train_action_dist_cls() + action_dist_class_exploration = module.get_exploration_action_dist_cls() + + curr_action_dist = action_dist_class_train.from_logits( + fwd_out[Columns.ACTION_DIST_INPUTS] + ) + prev_action_dist = action_dist_class_exploration.from_logits( + batch[Columns.ACTION_DIST_INPUTS] + ) + + logp_ratio = torch.exp( + curr_action_dist.logp(batch[Columns.ACTIONS]) - batch[Columns.ACTION_LOGP] + ) + + # Only calculate kl loss if necessary (kl-coeff > 0.0). 
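+        # KL(pi_old || pi_new) between the behavior policy's action distribution
+        # (reconstructed from the sampled batch) and the current policy's
+        # distribution; it is added to the total loss further below, weighted by
+        # the per-module KL coefficient.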
+ if config.use_kl_loss: + action_kl = prev_action_dist.kl(curr_action_dist) + mean_kl_loss = possibly_masked_mean(action_kl) + else: + mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device) + + curr_entropy = curr_action_dist.entropy() + mean_entropy = possibly_masked_mean(curr_entropy) + + surrogate_loss = torch.min( + batch[Postprocessing.ADVANTAGES] * logp_ratio, + batch[Postprocessing.ADVANTAGES] + * torch.clamp(logp_ratio, 1 - config.clip_param, 1 + config.clip_param), + ) + # Removed critic loss from per-module computation + total_loss = possibly_masked_mean( + -surrogate_loss + - ( + self.entropy_coeff_schedulers_per_module[module_id].get_current_value() + * curr_entropy + ) + ) + + # Add mean_kl_loss (already processed through `possibly_masked_mean`), + # if necessary. + if config.use_kl_loss: + total_loss += self.curr_kl_coeffs_per_module[module_id] * mean_kl_loss + + # Log important loss stats. + self.metrics.log_dict( + { + POLICY_LOSS_KEY: -possibly_masked_mean(surrogate_loss), + ENTROPY_KEY: mean_entropy, + LEARNER_RESULTS_KL_KEY: mean_kl_loss, + }, + key=module_id, + window=1, # <- single items (should not be mean/ema-reduced over time). + ) + # Return the total loss. + return total_loss + + @override(MAPPOLearner) + def _update_module_kl_coeff( + self, + *, + module_id: ModuleID, + config: PPOConfig, + kl_loss: float, + ) -> None: + # Same as PPOTorchLearner + if np.isnan(kl_loss): + logger.warning( + f"KL divergence for Module {module_id} is non-finite, this " + "will likely destabilize your model and the training " + "process. Action(s) in a specific state have near-zero " + "probability. This can happen naturally in deterministic " + "environments where the optimal policy has zero mass for a " + "specific action. To fix this issue, consider setting " + "`kl_coeff` to 0.0 or increasing `entropy_coeff` in your " + "config." + ) + + # Update the KL coefficient. + curr_var = self.curr_kl_coeffs_per_module[module_id] + if kl_loss > 2.0 * config.kl_target: + # TODO (Kourosh) why not 2? + curr_var.data *= 1.5 + elif kl_loss < 0.5 * config.kl_target: + curr_var.data *= 0.5 + + # Log the updated KL-coeff value. 
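+        # Summary of the adaptive-KL rule applied above (same as in PPO):
+        # KL > 2 * kl_target -> coeff *= 1.5; KL < 0.5 * kl_target -> coeff *= 0.5;
+        # otherwise the coefficient stays unchanged.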
+ self.metrics.log_value( + (module_id, LEARNER_RESULTS_CURR_KL_COEFF_KEY), + curr_var.item(), + window=1, + ) diff --git a/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py new file mode 100644 index 000000000000..7489d29689f1 --- /dev/null +++ b/rllib/examples/algorithms/mappo/torch/shared_critic_torch_rl_module.py @@ -0,0 +1,39 @@ +import typing +from typing import Any, Optional + +from ray.rllib.core.models.base import ENCODER_OUT +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.examples.algorithms.mappo.shared_critic_catalog import ( + SharedCriticCatalog, +) +from ray.rllib.examples.algorithms.mappo.shared_critic_rl_module import ( + SharedCriticRLModule, +) +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType +from ray.util.annotations import DeveloperAPI + +torch, nn = try_import_torch() + + +@DeveloperAPI +class SharedCriticTorchRLModule(TorchRLModule, SharedCriticRLModule): + def __init__(self, *args, **kwargs): + catalog_class = kwargs.pop("catalog_class", None) + if catalog_class is None: + catalog_class = SharedCriticCatalog + super().__init__(*args, **kwargs, catalog_class=catalog_class) + + @override(ValueFunctionAPI) + def compute_values( + self, + batch: typing.Dict[str, Any], + embeddings: Optional[Any] = None, + ) -> TensorType: + if embeddings is None: + embeddings = self.encoder(batch)[ENCODER_OUT] + vf_out = self.vf(embeddings) + # Don't squeeze out last dimension (multi node value head). + return vf_out diff --git a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py new file mode 100644 index 000000000000..12cd2befb612 --- /dev/null +++ b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py @@ -0,0 +1,114 @@ +"""Runs the PettingZoo Waterworld env in RLlib using a shared critic. + +See: https://pettingzoo.farama.org/environments/sisl/waterworld/ +for more details on the environment. + + +How to run this script +---------------------- +`python [script file name].py --num-agents=2` + +Control the number of agents and policies (RLModules) via --num-agents and +--num-policies. + +This works with hundreds of agents and policies, but note that initializing +many policies might take some time. + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +The above options will typically reach a combined reward of 0 or more before 500k env timesteps. Keep in mind, though, that due to the separate learned policies in general, one agent's gain (in per-agent reward) might cause the other agent's reward to decrease at the same time. However, over time, both agents should simply improve, with the shared critic stabilizing this process significantly. 
+ ++-----------------------+------------+--------------------+--------+------------------+ +| Trial name | status | loc | iter | total time (s) | +|-----------------------+------------+--------------------+--------+------------------+ +| MAPPO_env_aaaf6_00000 | TERMINATED | 172.29.87.208:6972 | 56 | 1386.86 | ++-----------------------+------------+--------------------+--------+------------------+ + ++--------+-------------------+--------------------+--------------------+ +| ts | combined return | return pursuer_0 | return pursuer_1 | ++--------+-------------------+--------------------+--------------------| +| 224000 | 29.5466 | 77.6161 | -48.0695 | ++--------+-------------------+--------------------+--------------------+ + +Note that the two agents (`pursuer_0` and `pursuer_1`) are optimized on the exact same +objective and thus differences in the rewards can be attributed to weight initialization +(and sampling randomness) only. +""" + +from pettingzoo.sisl import waterworld_v4 + +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv +from ray.rllib.examples.algorithms.mappo.connectors.general_advantage_estimation import ( + SHARED_CRITIC_ID, +) +from ray.rllib.examples.algorithms.mappo.mappo import MAPPOConfig +from ray.rllib.examples.algorithms.mappo.torch.shared_critic_torch_rl_module import ( + SharedCriticTorchRLModule, +) +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import register_env + +parser = add_rllib_example_script_args( + default_iters=200, + default_timesteps=1000000, + default_reward=0.0, +) + + +if __name__ == "__main__": + args = parser.parse_args() + + assert args.num_agents > 0, "Must set --num-agents > 0 when running this script!" + + # Here, we use the "Parallel" PettingZoo environment type. + # This allows MAPPO's global observations to be constructed more neatly. + def get_env(_): + return ParallelPettingZooEnv(waterworld_v4.parallel_env()) + + register_env("env", get_env) + + # Policies are called just like the agents (exact 1:1 mapping). + policies = [f"pursuer_{i}" for i in range(args.num_agents)] + + # An agent for each of our policies, and a single shared critic + env_instantiated = get_env({}) # neccessary for non-agent modules + specs = {p: RLModuleSpec() for p in policies} + specs[SHARED_CRITIC_ID] = RLModuleSpec( + module_class=SharedCriticTorchRLModule, + observation_space=env_instantiated.observation_space[policies[0]], + action_space=env_instantiated.action_space[policies[0]], + learner_only=True, # Only build on learner + model_config={"observation_spaces": env_instantiated.observation_space}, + ) + + base_config = ( + MAPPOConfig() + .environment("env") + .multi_agent( + policies=policies + [SHARED_CRITIC_ID], + # Exact 1:1 mapping from AgentID to ModuleID. 
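+            # The shared critic is listed under `policies`, but no agent ever
+            # maps to it (it's `learner_only`), so it's built and updated on the
+            # Learner side only.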
+ policy_mapping_fn=(lambda aid, *args, **kwargs: aid), + ) + .rl_module( + rl_module_spec=MultiRLModuleSpec( + rl_module_specs=specs, + ), + ) + ) + + run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/multi_agent/shared_encoder_cartpole.py b/rllib/examples/multi_agent/shared_encoder_cartpole.py index d0cc66b18f5d..b61c03e5f119 100644 --- a/rllib/examples/multi_agent/shared_encoder_cartpole.py +++ b/rllib/examples/multi_agent/shared_encoder_cartpole.py @@ -20,7 +20,7 @@ Results to expect ----------------- -Under the shared encoder architecture, the target reward of 700 will typically be reached well before 100,000 iterations. A trial concludes as below: +Under the shared encoder architecture, the target reward of 600 will typically be reached well before 100,000 iterations. A trial concludes as below: +---------------------+------------+-----------------+--------+------------------+-------+-------------------+-------------+-------------+ | Trial name | status | loc | iter | total time (s) | ts | combined return | return p1 | return p0 |