[RLlib] Working implementation of pettingzoo_shared_value_function.py #56309
Open
MatthewCWeston wants to merge 21 commits into ray-project:master from MatthewCWeston:pettingzoo_shared_vf.
Commits (the changes shown below reflect 9 of the 21 commits):
5071568 (MatthewCWeston): Initial commit of working code. Will lint later and then submit PR.
1c842c3 (MatthewCWeston): Linted and cleaned up the code.
c002277 (MatthewCWeston): Minor formatting/debugging fixes to PR submission.
55e8989 (MatthewCWeston): Fixed an inheritance issue and added global observation handling.
5756bf9 (MatthewCWeston): Tabs-->spaces in BUILD to satisfy tests.
dc8fdc9 (MatthewCWeston): Merge remote-tracking branch 'upstream/master' into pettingzoo_shared_vf
2ec32d4 (MatthewCWeston): Cleaned up two lines of code, one of which had a bug in an edge case.
dc74d6f (pseudo-rnd-thoughts): Merge branch 'master' into pettingzoo_shared_vf
bfa6c08 (MatthewCWeston): Re-linted and cleaned up a function signature.
060e2c3 (MatthewCWeston): Deleted an unused line of code.
323866a (MatthewCWeston): Merged in the updates from the master branch.
37845b2 (MatthewCWeston): Re-added episode postprocessing and updated expected results.
e7b02c4 (MatthewCWeston): Merge branch 'master' into pettingzoo_shared_vf
fd7fa91 (MatthewCWeston): Fixed actor module batching and loss masking.
5791729 (MatthewCWeston): Patched critic loss pooling.
8a69229 (MatthewCWeston): Merge branch 'master' into pettingzoo_shared_vf
3f2f51b (MatthewCWeston): Linted.
f645709 (MatthewCWeston): Merge branch 'pettingzoo_shared_vf' of https://github.com/MatthewCWes…
30f99db (MatthewCWeston): Cleaned up code and improved inheritance.
fca0c28 (MatthewCWeston): Merge branch 'master' into pettingzoo_shared_vf
14d48dd (MatthewCWeston): Merge branch 'master' into pettingzoo_shared_vf
173 changes: 173 additions & 0 deletions
173
rllib/examples/algorithms/mappo/connectors/general_advantage_estimation.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
```python
from typing import Any, Dict, List

import numpy as np

from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor
from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets
from ray.rllib.utils.postprocessing.zero_padding import (
    split_and_zero_pad_n_episodes,
    unpad_data_if_necessary,
)
from ray.rllib.utils.typing import EpisodeType

torch, nn = try_import_torch()

SHARED_CRITIC_ID = "shared_critic"


class MAPPOGAEConnector(ConnectorV2):
    def __init__(
        self,
        input_observation_space=None,
        input_action_space=None,
        *,
        gamma,
        lambda_,
    ):
        super().__init__(input_observation_space, input_action_space)
        self.gamma = gamma
        self.lambda_ = lambda_
        # Internal numpy-to-tensor connector to translate GAE results (advantages
        # and vf targets) into tensors.
        self._numpy_to_tensor_connector = None

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: MultiRLModule,
        episodes: List[EpisodeType],
        batch: Dict[str, Any],
        **kwargs,
    ):
        # Device to place all GAE result tensors (advantages and value targets) on.
        device = None
        # Extract all single-agent episodes.
        sa_episodes_list = list(
            self.single_agent_episode_iterator(
                episodes, agents_that_stepped_only=False
            )
        )
        # Perform the value net's forward pass.
        critic_batch = {}
        # Concatenate all agent observations in the batch, using a fixed order.
        obs_mids = [
            k
            for k in sorted(batch.keys())
            if (Columns.OBS in batch[k])
            and (not isinstance(rl_module[k], SelfSupervisedLossAPI))
        ]
        critic_batch[Columns.OBS] = torch.cat(
            [batch[k][Columns.OBS] for k in obs_mids], dim=1
        )
        # Compute value predictions.
        vf_preds = rl_module[SHARED_CRITIC_ID].compute_values(critic_batch)
        vf_preds = {mid: vf_preds[:, i] for i, mid in enumerate(obs_mids)}
        # Loop through all modules and perform each one's GAE computation.
        for module_id, module_vf_preds in vf_preds.items():
            # Skip those outputs of RLModules that are not implementers of
            # `ValueFunctionAPI`.
            if module_vf_preds is None:
                continue

            module = rl_module[module_id]
            device = module_vf_preds.device
            # Convert to numpy for the upcoming GAE computations.
            module_vf_preds = convert_to_numpy(module_vf_preds)

            # Collect (single-agent) episode lengths for this particular module.
            episode_lens = [
                len(e) for e in sa_episodes_list if e.module_id in [None, module_id]
            ]

            # Remove all zero-padding again, if applicable, for the upcoming
            # GAE computations.
            module_vf_preds = unpad_data_if_necessary(episode_lens, module_vf_preds)
            # Compute value targets.
            module_value_targets = compute_value_targets(
                values=module_vf_preds,
                rewards=unpad_data_if_necessary(
                    episode_lens,
                    convert_to_numpy(batch[module_id][Columns.REWARDS]),
                ),
                terminateds=unpad_data_if_necessary(
                    episode_lens,
                    convert_to_numpy(batch[module_id][Columns.TERMINATEDS]),
                ),
                truncateds=unpad_data_if_necessary(
                    episode_lens,
                    convert_to_numpy(batch[module_id][Columns.TRUNCATEDS]),
                ),
                gamma=self.gamma,
                lambda_=self.lambda_,
            )
            assert module_value_targets.shape[0] == sum(episode_lens)

            module_advantages = module_value_targets - module_vf_preds
            # Drop vf-preds, not needed in the loss. Note that in the
            # DefaultPPORLModule, vf-preds are recomputed with each
            # `forward_train` call anyway to compute the vf loss.
            # Standardize advantages (used for more stable and better weighted
            # policy gradient computations).
            module_advantages = (module_advantages - module_advantages.mean()) / max(
                1e-4, module_advantages.std()
            )

            # Zero-pad the new computations, if necessary.
            if module.is_stateful():
                module_advantages = np.stack(
                    split_and_zero_pad_n_episodes(
                        module_advantages,
                        episode_lens=episode_lens,
                        max_seq_len=module.model_config["max_seq_len"],
                    ),
                    axis=0,
                )
                module_value_targets = np.stack(
                    split_and_zero_pad_n_episodes(
                        module_value_targets,
                        episode_lens=episode_lens,
                        max_seq_len=module.model_config["max_seq_len"],
                    ),
                    axis=0,
                )
            batch[module_id][Postprocessing.ADVANTAGES] = module_advantages
            batch[module_id][Postprocessing.VALUE_TARGETS] = module_value_targets
        # Add GAE results to the critic batch.
        critic_batch[Postprocessing.VALUE_TARGETS] = np.stack(
            [batch[mid][Postprocessing.VALUE_TARGETS] for mid in obs_mids], axis=1
        )
        critic_batch[Postprocessing.ADVANTAGES] = np.stack(
            [batch[mid][Postprocessing.ADVANTAGES] for mid in obs_mids], axis=1
        )

        batch[SHARED_CRITIC_ID] = critic_batch  # Critic data -> training batch.

        # Convert all GAE results to tensors.
        if self._numpy_to_tensor_connector is None:
            self._numpy_to_tensor_connector = NumpyToTensor(
                as_learner_connector=True, device=device
            )
        tensor_results = self._numpy_to_tensor_connector(
            rl_module=rl_module,
            batch={
                mid: {
                    Postprocessing.ADVANTAGES: module_batch[Postprocessing.ADVANTAGES],
                    Postprocessing.VALUE_TARGETS: (
                        module_batch[Postprocessing.VALUE_TARGETS]
                    ),
                }
                for mid, module_batch in batch.items()
                if (mid == SHARED_CRITIC_ID) or (mid in vf_preds)
            },
            episodes=episodes,
        )
        # Move converted tensors back to `batch`.
        for mid, module_batch in tensor_results.items():
            batch[mid].update(module_batch)

        return batch
```
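For orientation, the sketch below (standalone, not code from this PR) illustrates the two ideas at the heart of this connector: the fixed-order concatenation of per-agent observations into one shared-critic input, and the standard GAE recursion underlying RLlib's `compute_value_targets` helper (shown here in textbook form, which may differ in detail from RLlib's internals). The toy critic, shapes, and agent IDs are all invented for illustration, and the episode is assumed to end at its last step, so no bootstrap value is needed.

```python
import numpy as np

B, N, OBS = 5, 2, 4  # timesteps, number of agents, per-agent obs size

# (1) Concatenate agent observations along the feature axis -> [B, N * OBS].
# Sorting the module IDs fixes the order, so the critic's per-agent value
# outputs can later be split back out by the same index.
agent_obs = {f"agent_{i}": np.random.randn(B, OBS) for i in range(N)}
obs_mids = sorted(agent_obs)
critic_in = np.concatenate([agent_obs[k] for k in obs_mids], axis=1)
assert critic_in.shape == (B, N * OBS)

# A toy linear "shared critic" with one value output per agent -> [B, N].
values = critic_in @ np.random.randn(N * OBS, N)
vf_preds = {mid: values[:, i] for i, mid in enumerate(obs_mids)}


# (2) GAE in its textbook form, per agent:
#   delta_t  = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
#   A_t      = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
#   target_t = A_t + V_t
def gae_value_targets(v, r, done, gamma=0.99, lam=0.95):
    adv = np.zeros_like(v)
    last = 0.0
    for t in reversed(range(len(v))):
        v_next = v[t + 1] if t + 1 < len(v) else 0.0
        delta = r[t] + gamma * v_next * (1.0 - done[t]) - v[t]
        last = delta + gamma * lam * (1.0 - done[t]) * last
        adv[t] = last
    return adv + v  # value targets


rewards = np.random.randn(B)
dones = np.array([0.0, 0.0, 0.0, 0.0, 1.0])
targets = gae_value_targets(vf_preds["agent_0"], rewards, dones)
advantages = targets - vf_preds["agent_0"]  # as in the connector above
```

Judging by its name and file location, this connector presumably stands in for PPO's default GAE learner-connector step in the example algorithm, extended with the shared-critic forward pass and the extra `SHARED_CRITIC_ID` sub-batch.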
54 changes: 54 additions & 0 deletions
54
rllib/examples/algorithms/mappo/default_mappo_rl_module.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
```python
import abc
from typing import List

from ray.rllib.core.models.configs import RecurrentEncoderConfig
from ray.rllib.core.rl_module.apis import InferenceOnlyAPI
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import (
    OverrideToImplementCustomLogic_CallToSuperRecommended,
    override,
)
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
class DefaultMAPPORLModule(RLModule, InferenceOnlyAPI, abc.ABC):
    """Default RLModule used by MAPPO, if the user does not specify a custom RLModule.

    Users who want to train their RLModules with MAPPO may implement any
    RLModule (or TorchRLModule) subclass.
    """

    @override(RLModule)
    def setup(self):
        # __sphinx_doc_begin__
        # If we have a stateful model, states for the critic need to be collected
        # during sampling, and `inference_only` needs to be `False`. Note that at
        # this point the encoder is not built yet, and therefore `is_stateful()`
        # does not work.
        is_stateful = isinstance(
            self.catalog.encoder_config,
            RecurrentEncoderConfig,
        )
        if is_stateful:
            self.inference_only = False
        # If this is an `inference_only` Module, we'll have to pass this
        # information to the encoder config as well.
        if self.inference_only and self.framework == "torch":
            self.catalog.encoder_config.inference_only = True
        # Build models from the catalog.
        self.encoder = self.catalog.build_encoder(framework=self.framework)
        self.pi = self.catalog.build_pi_head(framework=self.framework)
        # __sphinx_doc_end__

    @override(RLModule)
    def get_initial_state(self) -> dict:
        if hasattr(self.encoder, "get_initial_state"):
            return self.encoder.get_initial_state()
        else:
            return {}

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    @override(InferenceOnlyAPI)
    def get_non_inference_attributes(self) -> List[str]:
        """Return attributes which are NOT inference-only (only used for training)."""
        return []
```
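As a hypothetical usage sketch (not part of this diff), a Torch implementation could combine this base class with `TorchRLModule`, mirroring how RLlib composes its default PPO modules. The `aux_head` attribute is invented, and the example-module import path is an assumption derived from where this PR places the file:

```python
from typing import List

from ray.rllib.core.rl_module.apis import InferenceOnlyAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.annotations import override

# Assumed import path, based on this PR's file layout.
from ray.rllib.examples.algorithms.mappo.default_mappo_rl_module import (
    DefaultMAPPORLModule,
)


class MyMAPPOTorchModule(TorchRLModule, DefaultMAPPORLModule):
    @override(InferenceOnlyAPI)
    def get_non_inference_attributes(self) -> List[str]:
        # Keep hypothetical training-only parts (e.g., an auxiliary head) out
        # of the inference-only copy that gets synced to the EnvRunners.
        return super().get_non_inference_attributes() + ["aux_head"]
```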