From c06fb6c7ed7f0e91a8855d88387f37c408b6adf2 Mon Sep 17 00:00:00 2001 From: Aidandos Date: Wed, 23 Nov 2022 16:26:47 +0000 Subject: [PATCH 1/6] wip: there's a bug --- .pre-commit-config.yaml | 2 +- pax/agents/ppo/networks.py | 54 +- pax/agents/ppo/ppo_lstm.py | 579 ++++++++++++++++++ pax/conf/experiment/cg/earl_fixed.yaml | 1 + pax/conf/experiment/cg/earl_ppo_memory.yaml | 1 + pax/conf/experiment/cg/earl_v_ppo_memory.yaml | 1 + pax/conf/experiment/cg/earl_v_tabular.yaml | 1 + pax/conf/experiment/cg/greedy.yaml | 1 + pax/conf/experiment/cg/gs_v_ppo_memory.yaml | 1 + pax/conf/experiment/cg/gs_v_tabular.yaml | 1 + pax/conf/experiment/cg/mfos.yaml | 1 + pax/conf/experiment/cg/mfos_v_ppo_memory.yaml | 1 + pax/conf/experiment/cg/mfos_v_tabular.yaml | 1 + pax/conf/experiment/cg/pre_train.yaml | 1 + pax/conf/experiment/cg/sanity.yaml | 1 + pax/conf/experiment/cg/tabular.yaml | 1 + pax/conf/experiment/ipd/earl_fixed.yaml | 3 +- pax/conf/experiment/ipd/earl_infinite.yaml | 2 +- pax/conf/experiment/ipd/earl_nl_cma.yaml | 3 +- pax/conf/experiment/ipd/earl_nl_open.yaml | 3 +- pax/conf/experiment/ipd/earl_nl_pgpe.yaml | 3 +- pax/conf/experiment/ipd/earl_v_ppo.yaml | 2 +- pax/conf/experiment/ipd/earl_v_ppo_mem.yaml | 3 +- pax/conf/experiment/ipd/earl_v_tabular.yaml | 2 +- pax/conf/experiment/ipd/gs_v_ppo.yaml | 2 +- pax/conf/experiment/ipd/gs_v_ppo_mem.yaml | 2 +- pax/conf/experiment/ipd/gs_v_tabular.yaml | 3 +- pax/conf/experiment/ipd/inf_mfos_v_nl.yaml | 2 +- pax/conf/experiment/ipd/inf_mfos_v_tft.yaml | 3 +- pax/conf/experiment/ipd/marl2_v_nl.yaml | 3 +- pax/conf/experiment/ipd/marl2_v_tft.yaml | 2 +- pax/conf/experiment/ipd/mfos_v_ppo.yaml | 2 +- pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml | 2 +- pax/conf/experiment/ipd/mfos_v_tabular.yaml | 2 +- pax/conf/experiment/ipd/ppo_mem_v_tft.yaml | 3 +- pax/conf/experiment/ipd/ppo_v_tft.yaml | 1 + pax/conf/experiment/ipd/tabular_v_all_c.yaml | 3 +- pax/conf/experiment/ipd/tabular_v_all_d.yaml | 3 +- .../experiment/ipd/tabular_v_tabular.yaml | 3 +- pax/conf/experiment/ipd/tabular_v_tft.yaml | 3 +- pax/conf/experiment/mp/earl_v_ppo.yaml | 1 + pax/conf/experiment/mp/earl_v_ppo_mem.yaml | 1 + pax/conf/experiment/mp/earl_v_tabular.yaml | 1 + pax/conf/experiment/mp/gs_v_ppo.yaml | 1 + pax/conf/experiment/mp/gs_v_ppo_mem.yaml | 1 + pax/conf/experiment/mp/gs_v_tabular.yaml | 1 + pax/conf/experiment/mp/mfos_v_ppo.yaml | 1 + pax/conf/experiment/mp/mfos_v_ppo_mem.yaml | 1 + pax/conf/experiment/mp/mfos_v_tabular.yaml | 1 + pax/conf/experiment/mp/ppo_v_all_heads.yaml | 1 + pax/conf/experiment/mp/ppo_v_all_tails.yaml | 1 + pax/conf/experiment/mp/ppo_v_ppo.yaml | 1 + pax/conf/experiment/mp/ppo_v_tft.yaml | 1 + .../experiment/mp/tabular_v_all_heads.yaml | 1 + .../experiment/mp/tabular_v_all_tails.yaml | 1 + pax/conf/experiment/mp/tabular_v_tabular.yaml | 1 + pax/conf/experiment/mp/tabular_v_tft.yaml | 1 + pax/conf/experiment/sarl/acrobot.yaml | 1 + pax/conf/experiment/sarl/cartpole.yaml | 1 + pax/conf/experiment/sarl/pendulum.yaml | 69 +++ pax/experiment.py | 24 +- pax/runner_marl.py | 28 +- 62 files changed, 797 insertions(+), 51 deletions(-) create mode 100644 pax/agents/ppo/ppo_lstm.py create mode 100644 pax/conf/experiment/sarl/pendulum.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da853901..704ab3c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: hooks: - id: black language_version: python3.9 -- repo: https://gitlab.com/pycqa/flake8 +- repo: https://github.com/pycqa/flake8 rev: '3.9.2' hooks: - id: 
flake8 diff --git a/pax/agents/ppo/networks.py b/pax/agents/ppo/networks.py index 19f5bf7c..d6297a71 100644 --- a/pax/agents/ppo/networks.py +++ b/pax/agents/ppo/networks.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, NamedTuple import distrax import haiku as hk @@ -44,24 +44,24 @@ def __init__( ): super().__init__(name=name) self._action_body = hk.nets.MLP( - [64, 64], - w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), - b_init=hk.initializers.Constant(0), - activate_final=True, - activation=jnp.tanh, - ) + [64, 64], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jnp.tanh, + ) self._logit_layer = hk.Linear( num_values, w_init=hk.initializers.Orthogonal(1.0), b_init=hk.initializers.Constant(0), ) self._value_body = hk.nets.MLP( - [64, 64], - w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), - b_init=hk.initializers.Constant(0), - activate_final=True, - activation=jnp.tanh, - ) + [64, 64], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jnp.tanh, + ) self._value_layer = hk.Linear( 1, w_init=hk.initializers.Orthogonal(0.01), @@ -316,16 +316,13 @@ def forward_fn(inputs): network = hk.without_apply_rng(hk.transform(forward_fn)) return network + def make_sarl_network(num_actions: int): """Creates a hk network using the baseline hyperparameters from OpenAI""" def forward_fn(inputs): layers = [] - layers.extend( - [ - CategoricalValueHeadSeparate(num_values=num_actions) - ] - ) + layers.extend([CategoricalValueHeadSeparate(num_values=num_actions)]) policy_value_network = hk.Sequential(layers) return policy_value_network(inputs) @@ -350,6 +347,27 @@ def forward_fn( return network, hidden_state +def make_LSTM_ipd_network(num_actions: int, args): + hidden = jnp.zeros((1, args.ppo.hidden_size)) + cell = jnp.zeros((1, args.ppo.hidden_size)) + hidden_state = hk.LSTMState(hidden=hidden, cell=cell) + # hidden_state = hidden_state_lstm.initial_state() + # hidden_state = jnp.zeros((1, args.ppo.hidden_size)) + + def forward_fn( + inputs: jnp.ndarray, state: NamedTuple + ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], NamedTuple]: + """forward function""" + lstm = hk.LSTM(args.ppo.hidden_size) + embedding, state = lstm(inputs, state) + logits, values = CategoricalValueHead(num_actions)(embedding) + return (logits, values), state + + network = hk.without_apply_rng(hk.transform(forward_fn)) + + return network, hidden_state + + def make_GRU_cartpole_network(num_actions: int): hidden_size = 256 hidden_state = jnp.zeros((1, hidden_size)) diff --git a/pax/agents/ppo/ppo_lstm.py b/pax/agents/ppo/ppo_lstm.py new file mode 100644 index 00000000..b2d6244e --- /dev/null +++ b/pax/agents/ppo/ppo_lstm.py @@ -0,0 +1,579 @@ +# Adapted from https://github.com/deepmind/acme/blob/master/acme/agents/jax/ppo/learning.py + +from typing import Any, Dict, NamedTuple, Tuple + +import haiku as hk +import jax +import jax.numpy as jnp +import optax + +from pax import utils +from pax.agents.agent import AgentInterface +from pax.agents.ppo.networks import ( + make_GRU_cartpole_network, + make_GRU_coingame_network, + make_GRU_ipd_network, + make_LSTM_ipd_network, +) +from pax.utils import MemoryState, TrainingState, get_advantages + +# from dm_env import TimeStep + + +class Batch(NamedTuple): + """A batch of data; all shapes are expected to be [B, ...].""" + + observations: jnp.ndarray + actions: jnp.ndarray + advantages: jnp.ndarray + + # Target 
value estimate used to bootstrap the value function.
+    target_values: jnp.ndarray
+
+    # Value estimate and action log-prob at behavior time.
+    behavior_values: jnp.ndarray
+    behavior_log_probs: jnp.ndarray
+
+    # GRU specific
+    hiddens: jnp.ndarray
+
+
+class Logger:
+    metrics: dict
+
+
+class PPO(AgentInterface):
+    """A simple PPO agent with memory using JAX"""
+
+    def __init__(
+        self,
+        network: NamedTuple,
+        initial_hidden_state: jnp.ndarray,
+        optimizer: optax.GradientTransformation,
+        random_key: jnp.ndarray,
+        gru_dim: int,
+        obs_spec: Tuple,
+        batch_size: int = 2000,
+        num_envs: int = 4,
+        num_steps: int = 500,
+        num_minibatches: int = 16,
+        num_epochs: int = 4,
+        clip_value: bool = True,
+        value_coeff: float = 0.5,
+        anneal_entropy: bool = False,
+        entropy_coeff_start: float = 0.1,
+        entropy_coeff_end: float = 0.01,
+        entropy_coeff_horizon: int = 3_000_000,
+        ppo_clipping_epsilon: float = 0.2,
+        gamma: float = 0.99,
+        gae_lambda: float = 0.95,
+        player_id: int = 0,
+    ):
+        @jax.jit
+        def policy(
+            state: TrainingState, observation: jnp.ndarray, mem: MemoryState
+        ):
+            """Agent policy to select actions and calculate agent-specific information"""
+            key, subkey = jax.random.split(state.random_key)
+            (dist, values), hidden_state = network.apply(
+                state.params, observation, mem.hidden
+            )
+
+            actions = dist.sample(seed=subkey)
+            mem.extras["values"] = values
+            mem.extras["log_probs"] = dist.log_prob(actions)
+            mem = mem._replace(hidden=hidden_state, extras=mem.extras)
+            state = state._replace(random_key=key)
+            return (
+                actions,
+                state,
+                mem,
+            )
+
+        @jax.jit
+        def gae_advantages(
+            rewards: jnp.ndarray, values: jnp.ndarray, dones: jnp.ndarray
+        ) -> jnp.ndarray:
+            """Calculates the GAE advantages from a sequence. Note that the
+            arguments are of length = rollout length + 1"""
+            # 'Zero out' the terminated states
+            discounts = gamma * jnp.logical_not(dones)
+            reverse_batch = (
+                jnp.flip(values[:-1], axis=0),
+                jnp.flip(rewards, axis=0),
+                jnp.flip(discounts, axis=0),
+            )
+
+            _, advantages = jax.lax.scan(
+                get_advantages,
+                (
+                    jnp.zeros_like(values[-1]),
+                    values[-1],
+                    jnp.ones_like(values[-1]) * gae_lambda,
+                ),
+                reverse_batch,
+            )
+
+            advantages = jnp.flip(advantages, axis=0)
+            target_values = values[:-1] + advantages  # Q-value estimates
+            target_values = jax.lax.stop_gradient(target_values)
+            return advantages, target_values
+
+        def loss(
+            params: hk.Params,
+            timesteps: int,
+            observations: jnp.ndarray,
+            actions: jnp.array,
+            behavior_log_probs: jnp.array,
+            target_values: jnp.array,
+            advantages: jnp.array,
+            behavior_values: jnp.array,
+            hiddens: jnp.ndarray,
+        ):
+            """Surrogate loss using clipped probability ratios."""
+            (distribution, values), _ = network.apply(
+                params, observations, hiddens
+            )
+
+            log_prob = distribution.log_prob(actions)
+            entropy = distribution.entropy()
+
+            # Compute importance sampling weights: current policy / behavior policy.
+            rhos = jnp.exp(log_prob - behavior_log_probs)
+
+            # Policy loss: Clipping
+            clipped_ratios_t = jnp.clip(
+                rhos, 1.0 - ppo_clipping_epsilon, 1.0 + ppo_clipping_epsilon
+            )
+            clipped_objective = jnp.fmin(
+                rhos * advantages, clipped_ratios_t * advantages
+            )
+            policy_loss = -jnp.mean(clipped_objective)
+
+            # Value loss: MSE
+            value_cost = value_coeff
+            unclipped_value_error = target_values - values
+            unclipped_value_loss = unclipped_value_error**2
+
+            # Value clipping
+            if clip_value:
+                # Clip values to reduce variability during critic training.
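+                # This mirrors the PPO2-style pessimistic value update: the new
+                # prediction may move at most ppo_clipping_epsilon away from the
+                # value recorded at rollout time, and the larger of the clipped
+                # and unclipped squared errors is minimised below. E.g. with a
+                # behavior value of 1.0, a new value of 1.5, and epsilon of 0.2,
+                # the clipped prediction used in the loss is 1.2.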
+                clipped_values = behavior_values + jnp.clip(
+                    values - behavior_values,
+                    -ppo_clipping_epsilon,
+                    ppo_clipping_epsilon,
+                )
+                clipped_value_error = target_values - clipped_values
+                clipped_value_loss = clipped_value_error**2
+                value_loss = jnp.mean(
+                    jnp.fmax(unclipped_value_loss, clipped_value_loss)
+                )
+            else:
+                value_loss = jnp.mean(unclipped_value_loss)
+
+            # Entropy loss: Standard entropy term
+            # Calculate the new value based on linear annealing formula
+            if anneal_entropy:
+                fraction = jnp.fmax(1 - timesteps / entropy_coeff_horizon, 0)
+                entropy_cost = (
+                    fraction * entropy_coeff_start
+                    + (1 - fraction) * entropy_coeff_end
+                )
+            # Constant Entropy term
+            else:
+                entropy_cost = entropy_coeff_start
+            entropy_loss = -jnp.mean(entropy)
+
+            # Total loss: Minimize policy and value loss; maximize entropy
+            total_loss = (
+                policy_loss
+                + entropy_cost * entropy_loss
+                + value_loss * value_cost
+            )
+
+            return total_loss, {
+                "loss_total": total_loss,
+                "loss_policy": policy_loss,
+                "loss_value": value_loss,
+                "loss_entropy": entropy_loss,
+                "entropy_cost": entropy_cost,
+            }
+
+        @jax.jit
+        def sgd_step(
+            state: TrainingState, sample: NamedTuple
+        ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
+            """Performs a minibatch SGD step, returning new state and metrics."""
+            # Extract data
+            (
+                observations,
+                actions,
+                rewards,
+                behavior_log_probs,
+                behavior_values,
+                dones,
+                hiddens,
+            ) = (
+                sample.observations,
+                sample.actions,
+                sample.rewards,
+                sample.behavior_log_probs,
+                sample.behavior_values,
+                sample.dones,
+                sample.hiddens,
+            )
+
+            # batch_gae_advantages = jax.vmap(gae_advantages, 1, (0, 0))
+            advantages, target_values = gae_advantages(
+                rewards=rewards, values=behavior_values, dones=dones
+            )
+
+            # Exclude the last step - it was only used for bootstrapping.
+            # The shape is [num_steps, num_envs, ..]
+            behavior_values = behavior_values[:-1, :]
+            trajectories = Batch(
+                observations=observations,
+                actions=actions,
+                advantages=advantages,
+                behavior_log_probs=behavior_log_probs,
+                target_values=target_values,
+                behavior_values=behavior_values,
+                hiddens=hiddens,
+            )
+            # Concatenate all trajectories. Reshape from [num_steps, num_envs, ..]
+            # to [num_steps * num_envs, ..]
+            assert len(target_values.shape) > 1
+            num_envs = target_values.shape[1]
+            num_steps = target_values.shape[0]
+            batch_size = num_envs * num_steps
+            assert batch_size % num_minibatches == 0, (
+                "Num minibatches must divide batch size. Got batch_size={}"
+                " num_minibatches={}."
+            ).format(batch_size, num_minibatches)
+
+            batch = jax.tree_util.tree_map(
+                lambda x: x.reshape((batch_size,) + x.shape[2:]), trajectories
+            )
+            # Compute gradients.
+            grad_fn = jax.jit(jax.grad(loss, has_aux=True))
+
+            def model_update_minibatch(
+                carry: Tuple[hk.Params, optax.OptState, int],
+                minibatch: Batch,
+            ) -> Tuple[
+                Tuple[hk.Params, optax.OptState, int], Dict[str, jnp.ndarray]
+            ]:
+                """Performs model update for a single minibatch."""
+                params, opt_state, timesteps = carry
+                # Normalize advantages at the minibatch level before using them.
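+                # Standardising to zero mean and unit variance (the 1e-8 term
+                # guards against division by zero) rescales the policy-gradient
+                # magnitude per minibatch without changing which actions are
+                # favoured relative to one another.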
+                advantages = (
+                    minibatch.advantages
+                    - jnp.mean(minibatch.advantages, axis=0)
+                ) / (jnp.std(minibatch.advantages, axis=0) + 1e-8)
+                gradients, metrics = grad_fn(
+                    params,
+                    timesteps,
+                    minibatch.observations,
+                    minibatch.actions,
+                    minibatch.behavior_log_probs,
+                    minibatch.target_values,
+                    advantages,
+                    minibatch.behavior_values,
+                    minibatch.hiddens,
+                )
+
+                # Apply updates
+                updates, opt_state = optimizer.update(gradients, opt_state)
+                params = optax.apply_updates(params, updates)
+
+                metrics["norm_grad"] = optax.global_norm(gradients)
+                metrics["norm_updates"] = optax.global_norm(updates)
+                return (params, opt_state, timesteps), metrics
+
+            def model_update_epoch(
+                carry: Tuple[
+                    jnp.ndarray, hk.Params, optax.OptState, int, Batch
+                ],
+                unused_t: Tuple[()],
+            ) -> Tuple[
+                Tuple[jnp.ndarray, hk.Params, optax.OptState, Batch],
+                Dict[str, jnp.ndarray],
+            ]:
+                """Performs model updates based on one epoch of data."""
+                key, params, opt_state, timesteps, batch = carry
+                key, subkey = jax.random.split(key)
+                permutation = jax.random.permutation(subkey, batch_size)
+                shuffled_batch = jax.tree_util.tree_map(
+                    lambda x: jnp.take(x, permutation, axis=0), batch
+                )
+                minibatches = jax.tree_util.tree_map(
+                    lambda x: jnp.reshape(
+                        x, [num_minibatches, -1] + list(x.shape[1:])
+                    ),
+                    shuffled_batch,
+                )
+
+                (params, opt_state, timesteps), metrics = jax.lax.scan(
+                    model_update_minibatch,
+                    (params, opt_state, timesteps),
+                    minibatches,
+                    length=num_minibatches,
+                )
+                return (key, params, opt_state, timesteps, batch), metrics
+
+            params = state.params
+            opt_state = state.opt_state
+            timesteps = state.timesteps
+
+            # Repeat training for the given number of epochs, taking a random
+            # permutation for every epoch.
+            # signature is scan(function, carry, tuple to iterate over, length)
+            (key, params, opt_state, timesteps, _), metrics = jax.lax.scan(
+                model_update_epoch,
+                (state.random_key, params, opt_state, timesteps, batch),
+                (),
+                length=num_epochs,
+            )
+
+            metrics = jax.tree_util.tree_map(jnp.mean, metrics)
+            metrics["rewards_mean"] = jnp.mean(
+                jnp.abs(jnp.mean(rewards, axis=(0, 1)))
+            )
+            metrics["rewards_std"] = jnp.std(rewards, axis=(0, 1))
+
+            # Reset the memory
+            new_state = TrainingState(
+                params=params,
+                opt_state=opt_state,
+                random_key=key,
+                timesteps=timesteps + batch_size,
+            )
+
+            new_memory = MemoryState(
+                hidden=jnp.zeros(shape=(self._num_envs,) + (gru_dim,)),
+                extras={
+                    "log_probs": jnp.zeros(self._num_envs),
+                    "values": jnp.zeros(self._num_envs),
+                },
+            )
+
+            return new_state, new_memory, metrics
+
+        def make_initial_state(
+            key: Any, initial_hidden_state: NamedTuple
+        ) -> TrainingState:
+            """Initialises the training state (parameters and optimiser state)."""
+
+            # We pass through initial_hidden_state so it's easy to batch memory
+            key, subkey = jax.random.split(key)
+            dummy_obs = jnp.zeros(shape=obs_spec)
+            dummy_obs = utils.add_batch_dim(dummy_obs)
+            # import pdb; pdb.set_trace()
+            dummy_obs = dummy_obs.repeat(
+                initial_hidden_state[0].shape[-2], axis=0
+            )
+            # if len(initial_hidden_state[0].shape) > 2:
+            # dummy_obs = utils.add_batch_dim(dummy_obs)
+            initial_params = network.init(
+                subkey, dummy_obs, initial_hidden_state
+            )
+            initial_opt_state = optimizer.init(initial_params)
+            return TrainingState(
+                random_key=key,
+                params=initial_params,
+                opt_state=initial_opt_state,
+                timesteps=0,
+            ), MemoryState(
+                hidden=initial_hidden_state,
+                # initial_hidden_state,
+                extras={
+                    "values": jnp.zeros(num_envs),
+                    "log_probs": jnp.zeros(num_envs),
+                },
+            )
+
+        #
@jax.jit + def prepare_batch( + traj_batch: NamedTuple, + done: Any, + action_extras: dict, + ): + # Rollouts complete -> Training begins + # Add an additional rollout step for advantage calculation + _value = jax.lax.select( + done, + jnp.zeros_like(action_extras["values"]), + action_extras["values"], + ) + + _value = jax.lax.expand_dims(_value, [0]) + + # need to add final value here + traj_batch = traj_batch._replace( + behavior_values=jnp.concatenate( + [traj_batch.behavior_values, _value], axis=0 + ) + ) + return traj_batch + + # Initialise training state (parameters, optimiser state, extras). + self._state, self._mem = make_initial_state( + random_key, initial_hidden_state + ) + + self.make_initial_state = make_initial_state + + self._prepare_batch = prepare_batch + self._sgd_step = jax.jit(sgd_step) + + # Set up counters and logger + self._logger = Logger() + self._total_steps = 0 + self._until_sgd = 0 + self._logger.metrics = { + "total_steps": 0, + "sgd_steps": 0, + "loss_total": 0, + "loss_policy": 0, + "loss_value": 0, + "loss_entropy": 0, + "entropy_cost": entropy_coeff_start, + } + + # Initialize functions + self._policy = policy + self.forward = network.apply + self.player_id = player_id + + # Other useful hyperparameters + self._num_envs = num_envs # number of environments + self._num_steps = num_steps # number of steps per environment + self._batch_size = int(num_envs * num_steps) # number in one batch + self._num_minibatches = num_minibatches # number of minibatches + self._num_epochs = num_epochs # number of epochs to use sample + self._gru_dim = gru_dim + + def reset_memory(self, memory, eval=False) -> TrainingState: + num_envs = 1 if eval else self._num_envs + memory = memory._replace( + extras={ + "values": jnp.zeros(num_envs), + "log_probs": jnp.zeros(num_envs), + }, + hidden=hk.LSTMState( + hidden=jnp.zeros((num_envs, self._gru_dim)), + cell=jnp.zeros((num_envs, self._gru_dim)), + ), + ) + return memory + + def update( + self, + traj_batch: NamedTuple, + obs: jnp.ndarray, + state: TrainingState, + mem: MemoryState, + ): + + """Update the agent -> only called at the end of a trajectory""" + + _, _, mem = self._policy(state, obs, mem) + traj_batch = self._prepare_batch( + traj_batch, traj_batch.dones[-1, ...], mem.extras + ) + state, mem, metrics = self._sgd_step(state, traj_batch) + + # update logging + + self._logger.metrics["sgd_steps"] += ( + self._num_minibatches * self._num_epochs + ) + self._logger.metrics["loss_total"] = metrics["loss_total"] + self._logger.metrics["loss_policy"] = metrics["loss_policy"] + self._logger.metrics["loss_value"] = metrics["loss_value"] + self._logger.metrics["loss_entropy"] = metrics["loss_entropy"] + self._logger.metrics["entropy_cost"] = metrics["entropy_cost"] + return state, mem, metrics + + +# TODO: seed, and player_id not used in CartPole +def make_lstm_agent(args, obs_spec, action_spec, seed: int, player_id: int): + """Make PPO agent""" + # Network + if args.env_id == "CartPole-v1": + network, initial_hidden_state = make_GRU_cartpole_network(action_spec) + elif args.env_id == "coin_game": + if args.ppo.with_cnn: + print(f"Making network for {args.env_id} with CNN") + else: + print(f"Making network for {args.env_id} without CNN") + network, initial_hidden_state = make_GRU_coingame_network( + action_spec, args + ) + else: + network, _ = make_LSTM_ipd_network(action_spec, args) + + hidden = jnp.zeros((args.num_envs, args.ppo.hidden_size)) + cell = jnp.zeros((args.num_envs, args.ppo.hidden_size)) + initial_hidden_state = 
hk.LSTMState(hidden=hidden, cell=cell) + + # Optimizer + batch_size = int(args.num_envs * args.num_steps * args.num_opps) + transition_steps = ( + args.total_timesteps + / batch_size + * args.ppo.num_epochs + * args.ppo.num_minibatches + ) + + if args.ppo.lr_scheduling: + scheduler = optax.linear_schedule( + init_value=args.ppo.learning_rate, + end_value=0, + transition_steps=transition_steps, + ) + optimizer = optax.chain( + optax.clip_by_global_norm(args.ppo.max_gradient_norm), + optax.scale_by_adam(eps=args.ppo.adam_epsilon), + optax.scale_by_schedule(scheduler), + optax.scale(-1), + ) + + else: + optimizer = optax.chain( + optax.clip_by_global_norm(args.ppo.max_gradient_norm), + optax.scale_by_adam(eps=args.ppo.adam_epsilon), + optax.scale(-args.ppo.learning_rate), + ) + + # Random key + random_key = jax.random.PRNGKey(seed=seed) + + agent = PPO( + network=network, + initial_hidden_state=initial_hidden_state, + optimizer=optimizer, + random_key=random_key, + gru_dim=args.ppo.hidden_size, + obs_spec=obs_spec, + batch_size=args.num_envs * args.num_opps, + num_envs=args.num_envs, + num_steps=args.num_steps, + num_minibatches=args.ppo.num_minibatches, + num_epochs=args.ppo.num_epochs, + clip_value=args.ppo.clip_value, + value_coeff=args.ppo.value_coeff, + anneal_entropy=args.ppo.anneal_entropy, + entropy_coeff_start=args.ppo.entropy_coeff_start, + entropy_coeff_end=args.ppo.entropy_coeff_end, + entropy_coeff_horizon=args.ppo.entropy_coeff_horizon, + ppo_clipping_epsilon=args.ppo.ppo_clipping_epsilon, + gamma=args.ppo.gamma, + gae_lambda=args.ppo.gae_lambda, + player_id=player_id, + ) + return agent + + +if __name__ == "__main__": + pass diff --git a/pax/conf/experiment/cg/earl_fixed.yaml b/pax/conf/experiment/cg/earl_fixed.yaml index 75442f1a..495869b2 100644 --- a/pax/conf/experiment/cg/earl_fixed.yaml +++ b/pax/conf/experiment/cg/earl_fixed.yaml @@ -57,6 +57,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 + rnn_type: gru # ES parameters es: diff --git a/pax/conf/experiment/cg/earl_ppo_memory.yaml b/pax/conf/experiment/cg/earl_ppo_memory.yaml index 82aac4aa..d4bd3204 100644 --- a/pax/conf/experiment/cg/earl_ppo_memory.yaml +++ b/pax/conf/experiment/cg/earl_ppo_memory.yaml @@ -50,6 +50,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 + rnn_type: gru # ES parameters es: diff --git a/pax/conf/experiment/cg/earl_v_ppo_memory.yaml b/pax/conf/experiment/cg/earl_v_ppo_memory.yaml index 80f6e18d..a55fda39 100644 --- a/pax/conf/experiment/cg/earl_v_ppo_memory.yaml +++ b/pax/conf/experiment/cg/earl_v_ppo_memory.yaml @@ -57,6 +57,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 + rnn_type: gru # ES parameters es: diff --git a/pax/conf/experiment/cg/earl_v_tabular.yaml b/pax/conf/experiment/cg/earl_v_tabular.yaml index dffd4ece..825fed91 100644 --- a/pax/conf/experiment/cg/earl_v_tabular.yaml +++ b/pax/conf/experiment/cg/earl_v_tabular.yaml @@ -57,6 +57,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 + rnn_type: gru # ES parameters es: diff --git a/pax/conf/experiment/cg/greedy.yaml b/pax/conf/experiment/cg/greedy.yaml index 97fdd8b6..c5cd785c 100644 --- a/pax/conf/experiment/cg/greedy.yaml +++ b/pax/conf/experiment/cg/greedy.yaml @@ -50,6 +50,7 @@ ppo: separate: True # only works with CNN with_cnn: False hidden_size: 16 #50 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/cg/gs_v_ppo_memory.yaml 
b/pax/conf/experiment/cg/gs_v_ppo_memory.yaml index bb20f83a..42aabea4 100644 --- a/pax/conf/experiment/cg/gs_v_ppo_memory.yaml +++ b/pax/conf/experiment/cg/gs_v_ppo_memory.yaml @@ -57,6 +57,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 + rnn_type: gru # ES parameters es: diff --git a/pax/conf/experiment/cg/gs_v_tabular.yaml b/pax/conf/experiment/cg/gs_v_tabular.yaml index 1e6decf2..100ce30d 100644 --- a/pax/conf/experiment/cg/gs_v_tabular.yaml +++ b/pax/conf/experiment/cg/gs_v_tabular.yaml @@ -57,6 +57,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 + rnn_type: gru # ES parameters es: diff --git a/pax/conf/experiment/cg/mfos.yaml b/pax/conf/experiment/cg/mfos.yaml index 3c89cfb3..8c08e4e6 100644 --- a/pax/conf/experiment/cg/mfos.yaml +++ b/pax/conf/experiment/cg/mfos.yaml @@ -49,6 +49,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 #50 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/cg/mfos_v_ppo_memory.yaml b/pax/conf/experiment/cg/mfos_v_ppo_memory.yaml index 0868e389..3b2206d9 100644 --- a/pax/conf/experiment/cg/mfos_v_ppo_memory.yaml +++ b/pax/conf/experiment/cg/mfos_v_ppo_memory.yaml @@ -57,6 +57,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 + rnn_type: gru # ES parameters es: diff --git a/pax/conf/experiment/cg/mfos_v_tabular.yaml b/pax/conf/experiment/cg/mfos_v_tabular.yaml index a8e1ba34..2010c714 100644 --- a/pax/conf/experiment/cg/mfos_v_tabular.yaml +++ b/pax/conf/experiment/cg/mfos_v_tabular.yaml @@ -57,6 +57,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 + rnn_type: gru # ES parameters es: diff --git a/pax/conf/experiment/cg/pre_train.yaml b/pax/conf/experiment/cg/pre_train.yaml index 68a895a7..5caa3314 100644 --- a/pax/conf/experiment/cg/pre_train.yaml +++ b/pax/conf/experiment/cg/pre_train.yaml @@ -89,6 +89,7 @@ ppo: separate: True # only works with CNN with_cnn: False hidden_size: 16 #50 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/cg/sanity.yaml b/pax/conf/experiment/cg/sanity.yaml index 9febcde1..c28f8d19 100644 --- a/pax/conf/experiment/cg/sanity.yaml +++ b/pax/conf/experiment/cg/sanity.yaml @@ -49,6 +49,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 #50 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/cg/tabular.yaml b/pax/conf/experiment/cg/tabular.yaml index 65c0b65b..46de1a66 100644 --- a/pax/conf/experiment/cg/tabular.yaml +++ b/pax/conf/experiment/cg/tabular.yaml @@ -49,6 +49,7 @@ ppo: kernel_shape: [3, 3] separate: True # only works with CNN hidden_size: 16 #50 + rnn_type: gru naive: num_minibatches: 8 diff --git a/pax/conf/experiment/ipd/earl_fixed.yaml b/pax/conf/experiment/ipd/earl_fixed.yaml index 2350c275..78709d02 100644 --- a/pax/conf/experiment/ipd/earl_fixed.yaml +++ b/pax/conf/experiment/ipd/earl_fixed.yaml @@ -49,7 +49,8 @@ ppo: with_cnn: False separate: False hidden_size: 16 - + rnn_type: gru + # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/earl_infinite.yaml b/pax/conf/experiment/ipd/earl_infinite.yaml index 124d6a44..22a5c466 100644 --- a/pax/conf/experiment/ipd/earl_infinite.yaml +++ b/pax/conf/experiment/ipd/earl_infinite.yaml @@ -50,7 +50,7 @@ ppo: with_cnn: False separate: False hidden_size: 16 - + rnn_type: gru # Naive Learner parameters naive: num_minibatches: 1 diff --git 
a/pax/conf/experiment/ipd/earl_nl_cma.yaml b/pax/conf/experiment/ipd/earl_nl_cma.yaml index c802da36..c29f1cc3 100644 --- a/pax/conf/experiment/ipd/earl_nl_cma.yaml +++ b/pax/conf/experiment/ipd/earl_nl_cma.yaml @@ -43,7 +43,8 @@ ppo: with_memory: True with_cnn: False hidden_size: 16 - + rnn_type: gru + # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/earl_nl_open.yaml b/pax/conf/experiment/ipd/earl_nl_open.yaml index 1bea503e..80be6203 100644 --- a/pax/conf/experiment/ipd/earl_nl_open.yaml +++ b/pax/conf/experiment/ipd/earl_nl_open.yaml @@ -61,7 +61,8 @@ ppo: with_memory: True with_cnn: False hidden_size: 16 - + rnn_type: gru + # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/earl_nl_pgpe.yaml b/pax/conf/experiment/ipd/earl_nl_pgpe.yaml index 0eb025c8..03214dae 100644 --- a/pax/conf/experiment/ipd/earl_nl_pgpe.yaml +++ b/pax/conf/experiment/ipd/earl_nl_pgpe.yaml @@ -41,7 +41,8 @@ ppo: adam_epsilon: 1e-5 with_memory: True hidden_size: 16 - + rnn_type: gru + # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/earl_v_ppo.yaml b/pax/conf/experiment/ipd/earl_v_ppo.yaml index 8c0fe4d1..492fd389 100644 --- a/pax/conf/experiment/ipd/earl_v_ppo.yaml +++ b/pax/conf/experiment/ipd/earl_v_ppo.yaml @@ -60,7 +60,7 @@ ppo: with_memory: True with_cnn: False hidden_size: 16 - + rnn_type: gru # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/earl_v_ppo_mem.yaml b/pax/conf/experiment/ipd/earl_v_ppo_mem.yaml index 1ad59579..e9e90a58 100644 --- a/pax/conf/experiment/ipd/earl_v_ppo_mem.yaml +++ b/pax/conf/experiment/ipd/earl_v_ppo_mem.yaml @@ -44,7 +44,8 @@ ppo: with_memory: True with_cnn: False hidden_size: 16 - + rnn_type: gru + # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/earl_v_tabular.yaml b/pax/conf/experiment/ipd/earl_v_tabular.yaml index 1dda5b3e..1f907bba 100644 --- a/pax/conf/experiment/ipd/earl_v_tabular.yaml +++ b/pax/conf/experiment/ipd/earl_v_tabular.yaml @@ -62,7 +62,7 @@ ppo: with_memory: True with_cnn: False hidden_size: 16 - + rnn_type: gru # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/gs_v_ppo.yaml b/pax/conf/experiment/ipd/gs_v_ppo.yaml index a3af42bb..b23402ba 100644 --- a/pax/conf/experiment/ipd/gs_v_ppo.yaml +++ b/pax/conf/experiment/ipd/gs_v_ppo.yaml @@ -46,7 +46,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 - + rnn_type: gru # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/gs_v_ppo_mem.yaml b/pax/conf/experiment/ipd/gs_v_ppo_mem.yaml index 5b7b5eac..0bad4acb 100644 --- a/pax/conf/experiment/ipd/gs_v_ppo_mem.yaml +++ b/pax/conf/experiment/ipd/gs_v_ppo_mem.yaml @@ -46,7 +46,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 - + rnn_type: gru # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/gs_v_tabular.yaml b/pax/conf/experiment/ipd/gs_v_tabular.yaml index ed1e2b50..3128df50 100644 --- a/pax/conf/experiment/ipd/gs_v_tabular.yaml +++ b/pax/conf/experiment/ipd/gs_v_tabular.yaml @@ -83,7 +83,8 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 - + rnn_type: gru + # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/inf_mfos_v_nl.yaml b/pax/conf/experiment/ipd/inf_mfos_v_nl.yaml index 90a14f27..02804e0e 100644 --- a/pax/conf/experiment/ipd/inf_mfos_v_nl.yaml +++ 
b/pax/conf/experiment/ipd/inf_mfos_v_nl.yaml @@ -41,7 +41,7 @@ ppo: adam_epsilon: 1e-5 with_memory: False hidden_size: 16 - + rnn_type: gru # Naive Learner parameters naive: lr: 1.0 diff --git a/pax/conf/experiment/ipd/inf_mfos_v_tft.yaml b/pax/conf/experiment/ipd/inf_mfos_v_tft.yaml index 5b18fea0..f530b4d3 100644 --- a/pax/conf/experiment/ipd/inf_mfos_v_tft.yaml +++ b/pax/conf/experiment/ipd/inf_mfos_v_tft.yaml @@ -40,7 +40,8 @@ ppo: learning_rate: 2.5e-3 adam_epsilon: 1e-5 with_memory: False - + rnn_type: gru + # Logging setup wandb: entity: "ucl-dark" diff --git a/pax/conf/experiment/ipd/marl2_v_nl.yaml b/pax/conf/experiment/ipd/marl2_v_nl.yaml index 795f755b..74591827 100644 --- a/pax/conf/experiment/ipd/marl2_v_nl.yaml +++ b/pax/conf/experiment/ipd/marl2_v_nl.yaml @@ -40,7 +40,8 @@ ppo: learning_rate: 3e-4 adam_epsilon: 1e-5 with_memory: True - + rnn_type: gru + # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/marl2_v_tft.yaml b/pax/conf/experiment/ipd/marl2_v_tft.yaml index bc6754bd..abc95e2d 100644 --- a/pax/conf/experiment/ipd/marl2_v_tft.yaml +++ b/pax/conf/experiment/ipd/marl2_v_tft.yaml @@ -40,7 +40,7 @@ ppo: learning_rate: 3e-4 adam_epsilon: 1e-5 with_memory: True - + rnn_type: gru # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/mfos_v_ppo.yaml b/pax/conf/experiment/ipd/mfos_v_ppo.yaml index a38edae7..635e07bf 100644 --- a/pax/conf/experiment/ipd/mfos_v_ppo.yaml +++ b/pax/conf/experiment/ipd/mfos_v_ppo.yaml @@ -44,7 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 - + rnn_type: gru # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml b/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml index 37174388..09eae4ea 100644 --- a/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml +++ b/pax/conf/experiment/ipd/mfos_v_ppo_mem.yaml @@ -44,7 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 - + rnn_type: gru # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/mfos_v_tabular.yaml b/pax/conf/experiment/ipd/mfos_v_tabular.yaml index 9e212373..8ef74d55 100644 --- a/pax/conf/experiment/ipd/mfos_v_tabular.yaml +++ b/pax/conf/experiment/ipd/mfos_v_tabular.yaml @@ -60,7 +60,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 - + rnn_type: gru # Naive Learner parameters naive: num_minibatches: 1 diff --git a/pax/conf/experiment/ipd/ppo_mem_v_tft.yaml b/pax/conf/experiment/ipd/ppo_mem_v_tft.yaml index b953beb9..67dc9df1 100644 --- a/pax/conf/experiment/ipd/ppo_mem_v_tft.yaml +++ b/pax/conf/experiment/ipd/ppo_mem_v_tft.yaml @@ -2,7 +2,7 @@ # Agents agent1: 'PPO_memory' -agent2: 'TitForTat' +agent2: 'PPO_memory' # Environment env_id: iterated_matrix_game @@ -46,6 +46,7 @@ ppo: with_memory: True with_cnn: False hidden_size: 4 + rnn_type: lstm # Logging setup wandb: diff --git a/pax/conf/experiment/ipd/ppo_v_tft.yaml b/pax/conf/experiment/ipd/ppo_v_tft.yaml index c53b1c16..8a4da0f7 100644 --- a/pax/conf/experiment/ipd/ppo_v_tft.yaml +++ b/pax/conf/experiment/ipd/ppo_v_tft.yaml @@ -43,6 +43,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/ipd/tabular_v_all_c.yaml b/pax/conf/experiment/ipd/tabular_v_all_c.yaml index 648b8876..620fae44 100644 --- a/pax/conf/experiment/ipd/tabular_v_all_c.yaml +++ b/pax/conf/experiment/ipd/tabular_v_all_c.yaml @@ -42,7 +42,8 @@ ppo: adam_epsilon: 1e-5 with_memory: False with_cnn: False - + 
rnn_type: gru + # Logging setup wandb: entity: "ucl-dark" diff --git a/pax/conf/experiment/ipd/tabular_v_all_d.yaml b/pax/conf/experiment/ipd/tabular_v_all_d.yaml index 3c1e66c5..4ac5572d 100644 --- a/pax/conf/experiment/ipd/tabular_v_all_d.yaml +++ b/pax/conf/experiment/ipd/tabular_v_all_d.yaml @@ -43,7 +43,8 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 - + rnn_type: gru + # Logging setup wandb: entity: "ucl-dark" diff --git a/pax/conf/experiment/ipd/tabular_v_tabular.yaml b/pax/conf/experiment/ipd/tabular_v_tabular.yaml index da4ad825..8020a9ca 100644 --- a/pax/conf/experiment/ipd/tabular_v_tabular.yaml +++ b/pax/conf/experiment/ipd/tabular_v_tabular.yaml @@ -44,7 +44,8 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 - + rnn_type: gru + # Logging setup wandb: entity: "ucl-dark" diff --git a/pax/conf/experiment/ipd/tabular_v_tft.yaml b/pax/conf/experiment/ipd/tabular_v_tft.yaml index a4e5afeb..f00eec62 100644 --- a/pax/conf/experiment/ipd/tabular_v_tft.yaml +++ b/pax/conf/experiment/ipd/tabular_v_tft.yaml @@ -43,7 +43,8 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 - + rnn_type: gru + # Logging setup wandb: entity: "ucl-dark" diff --git a/pax/conf/experiment/mp/earl_v_ppo.yaml b/pax/conf/experiment/mp/earl_v_ppo.yaml index 11a35f44..ec65c7cb 100644 --- a/pax/conf/experiment/mp/earl_v_ppo.yaml +++ b/pax/conf/experiment/mp/earl_v_ppo.yaml @@ -61,6 +61,7 @@ ppo: with_memory: True with_cnn: False hidden_size: 16 + rnn_type: gru # Naive Learner parameters naive: diff --git a/pax/conf/experiment/mp/earl_v_ppo_mem.yaml b/pax/conf/experiment/mp/earl_v_ppo_mem.yaml index ef81fc24..fe107fc3 100644 --- a/pax/conf/experiment/mp/earl_v_ppo_mem.yaml +++ b/pax/conf/experiment/mp/earl_v_ppo_mem.yaml @@ -61,6 +61,7 @@ ppo: with_memory: True with_cnn: False hidden_size: 16 + rnn_type: gru # Naive Learner parameters naive: diff --git a/pax/conf/experiment/mp/earl_v_tabular.yaml b/pax/conf/experiment/mp/earl_v_tabular.yaml index 8cdda137..a11f5a66 100644 --- a/pax/conf/experiment/mp/earl_v_tabular.yaml +++ b/pax/conf/experiment/mp/earl_v_tabular.yaml @@ -61,6 +61,7 @@ ppo: adam_epsilon: 1e-5 with_memory: True with_cnn: False + rnn_type: gru # Naive Learner parameters naive: diff --git a/pax/conf/experiment/mp/gs_v_ppo.yaml b/pax/conf/experiment/mp/gs_v_ppo.yaml index c94b6772..ccd0a920 100644 --- a/pax/conf/experiment/mp/gs_v_ppo.yaml +++ b/pax/conf/experiment/mp/gs_v_ppo.yaml @@ -62,6 +62,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Naive Learner parameters naive: diff --git a/pax/conf/experiment/mp/gs_v_ppo_mem.yaml b/pax/conf/experiment/mp/gs_v_ppo_mem.yaml index 388bcd16..8d4d7557 100644 --- a/pax/conf/experiment/mp/gs_v_ppo_mem.yaml +++ b/pax/conf/experiment/mp/gs_v_ppo_mem.yaml @@ -60,6 +60,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Naive Learner parameters naive: diff --git a/pax/conf/experiment/mp/gs_v_tabular.yaml b/pax/conf/experiment/mp/gs_v_tabular.yaml index c33b6a0e..c1b2d692 100644 --- a/pax/conf/experiment/mp/gs_v_tabular.yaml +++ b/pax/conf/experiment/mp/gs_v_tabular.yaml @@ -60,6 +60,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Naive Learner parameters naive: diff --git a/pax/conf/experiment/mp/mfos_v_ppo.yaml b/pax/conf/experiment/mp/mfos_v_ppo.yaml index ffbd26eb..42d091a0 100644 --- a/pax/conf/experiment/mp/mfos_v_ppo.yaml +++ b/pax/conf/experiment/mp/mfos_v_ppo.yaml @@ -62,6 +62,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + 
rnn_type: gru # Naive Learner parameters naive: diff --git a/pax/conf/experiment/mp/mfos_v_ppo_mem.yaml b/pax/conf/experiment/mp/mfos_v_ppo_mem.yaml index 322871cb..349e7578 100644 --- a/pax/conf/experiment/mp/mfos_v_ppo_mem.yaml +++ b/pax/conf/experiment/mp/mfos_v_ppo_mem.yaml @@ -62,6 +62,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Naive Learner parameters naive: diff --git a/pax/conf/experiment/mp/mfos_v_tabular.yaml b/pax/conf/experiment/mp/mfos_v_tabular.yaml index 11d0fb8e..8475ef59 100644 --- a/pax/conf/experiment/mp/mfos_v_tabular.yaml +++ b/pax/conf/experiment/mp/mfos_v_tabular.yaml @@ -61,6 +61,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Naive Learner parameters naive: diff --git a/pax/conf/experiment/mp/ppo_v_all_heads.yaml b/pax/conf/experiment/mp/ppo_v_all_heads.yaml index fde623c7..0af15d0a 100644 --- a/pax/conf/experiment/mp/ppo_v_all_heads.yaml +++ b/pax/conf/experiment/mp/ppo_v_all_heads.yaml @@ -44,6 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/mp/ppo_v_all_tails.yaml b/pax/conf/experiment/mp/ppo_v_all_tails.yaml index 1db3d8a2..a1e1c6d1 100644 --- a/pax/conf/experiment/mp/ppo_v_all_tails.yaml +++ b/pax/conf/experiment/mp/ppo_v_all_tails.yaml @@ -44,6 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/mp/ppo_v_ppo.yaml b/pax/conf/experiment/mp/ppo_v_ppo.yaml index c712f6f8..092d5f7f 100644 --- a/pax/conf/experiment/mp/ppo_v_ppo.yaml +++ b/pax/conf/experiment/mp/ppo_v_ppo.yaml @@ -44,6 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/mp/ppo_v_tft.yaml b/pax/conf/experiment/mp/ppo_v_tft.yaml index 6be62361..37b1a91b 100644 --- a/pax/conf/experiment/mp/ppo_v_tft.yaml +++ b/pax/conf/experiment/mp/ppo_v_tft.yaml @@ -44,6 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/mp/tabular_v_all_heads.yaml b/pax/conf/experiment/mp/tabular_v_all_heads.yaml index 1731dd83..c5be8a97 100644 --- a/pax/conf/experiment/mp/tabular_v_all_heads.yaml +++ b/pax/conf/experiment/mp/tabular_v_all_heads.yaml @@ -44,6 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/mp/tabular_v_all_tails.yaml b/pax/conf/experiment/mp/tabular_v_all_tails.yaml index 51da2208..5f0ad4cc 100644 --- a/pax/conf/experiment/mp/tabular_v_all_tails.yaml +++ b/pax/conf/experiment/mp/tabular_v_all_tails.yaml @@ -44,6 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/mp/tabular_v_tabular.yaml b/pax/conf/experiment/mp/tabular_v_tabular.yaml index 6047a3d1..1fc57dcd 100644 --- a/pax/conf/experiment/mp/tabular_v_tabular.yaml +++ b/pax/conf/experiment/mp/tabular_v_tabular.yaml @@ -44,6 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/mp/tabular_v_tft.yaml b/pax/conf/experiment/mp/tabular_v_tft.yaml index 5ea38924..967a4e1a 100644 --- a/pax/conf/experiment/mp/tabular_v_tft.yaml +++ b/pax/conf/experiment/mp/tabular_v_tft.yaml @@ -44,6 +44,7 @@ ppo: with_memory: False with_cnn: False hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git 
a/pax/conf/experiment/sarl/acrobot.yaml b/pax/conf/experiment/sarl/acrobot.yaml index dfb010b8..c3f1171b 100644 --- a/pax/conf/experiment/sarl/acrobot.yaml +++ b/pax/conf/experiment/sarl/acrobot.yaml @@ -68,6 +68,7 @@ ppo: kernel_shape: [3, 3] separate: True hidden_size: 16 + rnn_type: gru # Logging setup wandb: diff --git a/pax/conf/experiment/sarl/cartpole.yaml b/pax/conf/experiment/sarl/cartpole.yaml index 9f233020..d73f0366 100644 --- a/pax/conf/experiment/sarl/cartpole.yaml +++ b/pax/conf/experiment/sarl/cartpole.yaml @@ -46,6 +46,7 @@ ppo: kernel_shape: [3, 3] separate: True hidden_size: 16 + rnn_type: gru # naive: # num_minibatches: 1 diff --git a/pax/conf/experiment/sarl/pendulum.yaml b/pax/conf/experiment/sarl/pendulum.yaml new file mode 100644 index 00000000..89feff1f --- /dev/null +++ b/pax/conf/experiment/sarl/pendulum.yaml @@ -0,0 +1,69 @@ +# @package _global_ + +# Agents +agent1: 'PPO' + +# Environment +env_id: MountainCar-v0 +env_type: sequential +egocentric: True +env_discount: 0.96 +payoff: [[1, 1, -2], [1, 1, -2]] +runner: sarl + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 16 +num_steps: 400 # 500 Cartpole +total_timesteps: 5e6 +save_interval: 100 + +# Evaluation +run_path: ucl-dark/cg/3sp0y2cy +model_path: exp/coin_game-PPO_memory-vs-PPO_memory-parity/run-seed-0/2022-09-12_11.21.52.633382/iteration_74900 + +# PPO agent parameters +ppo: + num_minibatches: 4 + num_epochs: 4 + gamma: 0.99 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.01 + entropy_coeff_horizon: 1000000 + entropy_coeff_end: 0.001 + lr_scheduling: True + learning_rate: 5e-4 #5e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 16 + rnn_type: gru + +# naive: +# num_minibatches: 1 +# num_epochs: 1 +# gamma: 0.96 +# gae_lambda: 0.95 +# max_gradient_norm: 1.0 +# learning_rate: 1.0 +# adam_epsilon: 1e-5 +# entropy_coeff: 0 + + + +# Logging setup +wandb: + entity: "ucl-dark" + project: sarl + group: 'sanity-${agent1}-vs-${agent2}-parity' + name: run-seed-${seed} + log: False \ No newline at end of file diff --git a/pax/experiment.py b/pax/experiment.py index a72fc289..a8e0e6dd 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -16,6 +16,7 @@ from pax.agents.naive_exact import NaiveExact from pax.agents.ppo.ppo import make_agent from pax.agents.ppo.ppo_gru import make_gru_agent +from pax.agents.ppo.ppo_lstm import make_lstm_agent from pax.agents.strategies import ( Altruistic, Defect, @@ -253,13 +254,22 @@ def agent_setup(args, env, env_params, logger): num_actions = env.num_actions def get_PPO_memory_agent(seed, player_id): - ppo_memory_agent = make_gru_agent( - args, - obs_spec=obs_shape, - action_spec=num_actions, - seed=seed, - player_id=player_id, - ) + if args.ppo.rnn_type == "lstm": + ppo_memory_agent = make_lstm_agent( + args, + obs_spec=obs_shape, + action_spec=num_actions, + seed=seed, + player_id=player_id, + ) + else: + ppo_memory_agent = make_gru_agent( + args, + obs_spec=obs_shape, + action_spec=num_actions, + seed=seed, + player_id=player_id, + ) return ppo_memory_agent def get_PPO_agent(seed, player_id): diff --git a/pax/runner_marl.py b/pax/runner_marl.py index df31cf26..4f55949f 100644 --- a/pax/runner_marl.py +++ b/pax/runner_marl.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp +from haiku import LSTMState import wandb from pax.utils import 
MemoryState, TrainingState, save @@ -146,14 +147,35 @@ def _reshape_opp_dim(x): if args.agent1 != "NaiveEx": # NaiveEx requires env first step to init. - init_hidden = jnp.tile(agent1._mem.hidden, (args.num_opps, 1, 1)) + if args.ppo.rnn_type == "lstm" and args.agent1 == "PPO_memory": + print("hello") + hidden = jnp.tile(agent1._mem.hidden[0], (args.num_opps, 1, 1)) + cell = jnp.tile(agent1._mem.hidden[1], (args.num_opps, 1, 1)) + init_hidden = LSTMState(hidden=hidden, cell=cell) + else: + print("hello2") + init_hidden = jnp.tile( + agent1._mem.hidden, (args.num_opps, 1, 1) + ) + + # import pdb; pdb.set_trace() agent1._state, agent1._mem = agent1.batch_init( agent1._state.random_key, init_hidden ) if args.agent2 != "NaiveEx": # NaiveEx requires env first step to init. - init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) + if args.ppo.rnn_type == "lstm" and args.agent2 == "PPO_memory": + print("hello3") + hidden = jnp.tile(agent2._mem.hidden[0], (args.num_opps, 1, 1)) + cell = jnp.tile(agent2._mem.hidden[1], (args.num_opps, 1, 1)) + init_hidden = LSTMState(hidden=hidden, cell=cell) + else: + print("hello4") + init_hidden = jnp.tile( + agent2._mem.hidden, (args.num_opps, 1, 1) + ) + # import pdb; pdb.set_trace() agent2._state, agent2._mem = agent2.batch_init( jax.random.split(agent2._state.random_key, args.num_opps), init_hidden, @@ -279,6 +301,7 @@ def _outer_rollout(carry, unused): a2_state, a2_mem, ) + return ( rngs, obs1, @@ -326,6 +349,7 @@ def _rollout( _a2_state, _a2_mem = agent2.batch_init( jax.random.split(_rng_run, self.num_opps), _a2_mem.hidden ) + # run trials vals, stack = jax.lax.scan( _outer_rollout, From 4c64e127d79bfc8dee7b454294f70f2046be3c93 Mon Sep 17 00:00:00 2001 From: akbir Date: Wed, 23 Nov 2022 18:31:16 +0000 Subject: [PATCH 2/6] updated vmap + lstm runs --- pax/agents/ppo/ppo_lstm.py | 15 ++++++++------- pax/conf/experiment/ipd/ppo_mem_v_tft.yaml | 2 +- pax/runner_marl.py | 19 +++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pax/agents/ppo/ppo_lstm.py b/pax/agents/ppo/ppo_lstm.py index b2d6244e..10320033 100644 --- a/pax/agents/ppo/ppo_lstm.py +++ b/pax/agents/ppo/ppo_lstm.py @@ -12,7 +12,6 @@ from pax.agents.ppo.networks import ( make_GRU_cartpole_network, make_GRU_coingame_network, - make_GRU_ipd_network, make_LSTM_ipd_network, ) from pax.utils import MemoryState, TrainingState, get_advantages @@ -350,7 +349,10 @@ def model_update_epoch( ) new_memory = MemoryState( - hidden=jnp.zeros(shape=(self._num_envs,) + (gru_dim,)), + hidden=hk.LSTMState( + hidden=jnp.zeros((self._num_envs, self._gru_dim)), + cell=jnp.zeros((self._num_envs, self._gru_dim)), + ), extras={ "log_probs": jnp.zeros(self._num_envs), "values": jnp.zeros(self._num_envs), @@ -363,17 +365,17 @@ def make_initial_state( key: Any, initial_hidden_state: NamedTuple ) -> TrainingState: """Initialises the training state (parameters and optimiser state).""" - # We pass through initial_hidden_state so its easy to batch memory key, subkey = jax.random.split(key) + # import pdb; pdb.set_trace() dummy_obs = jnp.zeros(shape=obs_spec) + + # if len(initial_hidden_state[0].shape) > 2: dummy_obs = utils.add_batch_dim(dummy_obs) - # import pdb; pdb.set_trace() dummy_obs = dummy_obs.repeat( initial_hidden_state[0].shape[-2], axis=0 ) - # if len(initial_hidden_state[0].shape) > 2: - # dummy_obs = utils.add_batch_dim(dummy_obs) + initial_params = network.init( subkey, dummy_obs, initial_hidden_state ) @@ -476,7 +478,6 @@ def update( ): """Update the agent -> only called at 
the end of a trajectory""" - _, _, mem = self._policy(state, obs, mem) traj_batch = self._prepare_batch( traj_batch, traj_batch.dones[-1, ...], mem.extras diff --git a/pax/conf/experiment/ipd/ppo_mem_v_tft.yaml b/pax/conf/experiment/ipd/ppo_mem_v_tft.yaml index 67dc9df1..a3117fde 100644 --- a/pax/conf/experiment/ipd/ppo_mem_v_tft.yaml +++ b/pax/conf/experiment/ipd/ppo_mem_v_tft.yaml @@ -18,7 +18,7 @@ evo: False eval: False # Training hyperparameters num_envs: 100 -num_opps: 1 +num_opps: 5 num_steps: 150 # number of steps per episode total_timesteps: 2e7 diff --git a/pax/runner_marl.py b/pax/runner_marl.py index 4f55949f..84e197ea 100644 --- a/pax/runner_marl.py +++ b/pax/runner_marl.py @@ -136,9 +136,8 @@ def _reshape_opp_dim(x): # special case where NaiveEx has a different call signature agent2.batch_init = jax.jit(jax.vmap(agent2.make_initial_state)) else: - agent2.batch_init = jax.vmap( - agent2.make_initial_state, (0, None), 0 - ) + agent2.batch_init = jax.vmap(agent2.make_initial_state) + agent2.batch_policy = jax.jit(jax.vmap(agent2._policy)) agent2.batch_reset = jax.jit( jax.vmap(agent2.reset_memory, (0, None), 0), static_argnums=1 @@ -148,17 +147,14 @@ def _reshape_opp_dim(x): if args.agent1 != "NaiveEx": # NaiveEx requires env first step to init. if args.ppo.rnn_type == "lstm" and args.agent1 == "PPO_memory": - print("hello") hidden = jnp.tile(agent1._mem.hidden[0], (args.num_opps, 1, 1)) cell = jnp.tile(agent1._mem.hidden[1], (args.num_opps, 1, 1)) init_hidden = LSTMState(hidden=hidden, cell=cell) else: - print("hello2") init_hidden = jnp.tile( agent1._mem.hidden, (args.num_opps, 1, 1) ) - # import pdb; pdb.set_trace() agent1._state, agent1._mem = agent1.batch_init( agent1._state.random_key, init_hidden ) @@ -166,18 +162,21 @@ def _reshape_opp_dim(x): if args.agent2 != "NaiveEx": # NaiveEx requires env first step to init. 
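+        # An hk.LSTMState is a (hidden, cell) pair, so the two fields are
+        # tiled across the opponent dimension separately and repacked below;
+        # the GRU path only carries a single hidden array.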
if args.ppo.rnn_type == "lstm" and args.agent2 == "PPO_memory": - print("hello3") + rngs = jnp.concatenate( + [jax.random.split(agent2._state.random_key, args.num_opps)] + ).reshape((args.num_opps, -1)) hidden = jnp.tile(agent2._mem.hidden[0], (args.num_opps, 1, 1)) cell = jnp.tile(agent2._mem.hidden[1], (args.num_opps, 1, 1)) init_hidden = LSTMState(hidden=hidden, cell=cell) else: - print("hello4") init_hidden = jnp.tile( agent2._mem.hidden, (args.num_opps, 1, 1) ) - # import pdb; pdb.set_trace() + rngs = jnp.concatenate( + [jax.random.split(agent2._state.random_key, args.num_opps)] + ).reshape((args.num_opps, -1)) agent2._state, agent2._mem = agent2.batch_init( - jax.random.split(agent2._state.random_key, args.num_opps), + rngs, init_hidden, ) From ceee54716a1e2b909b98d50b756138a1189108b1 Mon Sep 17 00:00:00 2001 From: akbir Date: Wed, 23 Nov 2022 22:58:04 +0000 Subject: [PATCH 3/6] cleaning up --- pax/agents/ppo/ppo_lstm.py | 14 ++----------- pax/runner_eval.py | 15 ++++++++++---- pax/runner_evo.py | 42 +++++++++++++++++++++++++++----------- pax/runner_marl.py | 17 +++++++-------- 4 files changed, 50 insertions(+), 38 deletions(-) diff --git a/pax/agents/ppo/ppo_lstm.py b/pax/agents/ppo/ppo_lstm.py index 10320033..9d23be54 100644 --- a/pax/agents/ppo/ppo_lstm.py +++ b/pax/agents/ppo/ppo_lstm.py @@ -10,14 +10,10 @@ from pax import utils from pax.agents.agent import AgentInterface from pax.agents.ppo.networks import ( - make_GRU_cartpole_network, - make_GRU_coingame_network, make_LSTM_ipd_network, ) from pax.utils import MemoryState, TrainingState, get_advantages -# from dm_env import TimeStep - class Batch(NamedTuple): """A batch of data; all shapes are expected to be [B, ...].""" @@ -502,15 +498,9 @@ def make_lstm_agent(args, obs_spec, action_spec, seed: int, player_id: int): """Make PPO agent""" # Network if args.env_id == "CartPole-v1": - network, initial_hidden_state = make_GRU_cartpole_network(action_spec) + raise ValueError("CartPole-v1 not supported") elif args.env_id == "coin_game": - if args.ppo.with_cnn: - print(f"Making network for {args.env_id} with CNN") - else: - print(f"Making network for {args.env_id} without CNN") - network, initial_hidden_state = make_GRU_coingame_network( - action_spec, args - ) + raise ValueError("CoinGame not supported") else: network, _ = make_LSTM_ipd_network(action_spec, args) diff --git a/pax/runner_eval.py b/pax/runner_eval.py index 5edc57b0..3475eb55 100644 --- a/pax/runner_eval.py +++ b/pax/runner_eval.py @@ -8,6 +8,7 @@ import wandb from pax.utils import load from pax.watchers import cg_visitation, ipd_visitation +from haiku import LSTMState MAX_WANDB_CALLS = 10000 @@ -91,9 +92,7 @@ def __init__(self, agents, env, args): # special case where NaiveEx has a different call signature agent2.batch_init = jax.jit(jax.vmap(agent2.make_initial_state)) else: - agent2.batch_init = jax.vmap( - agent2.make_initial_state, (0, None), 0 - ) + agent2.batch_init = jax.vmap(agent2.make_initial_state) agent2.batch_policy = jax.jit(jax.vmap(agent2._policy)) agent2.batch_reset = jax.jit( jax.vmap(agent2.reset_memory, (0, None), 0), static_argnums=1 @@ -102,7 +101,14 @@ def __init__(self, agents, env, args): if args.agent1 != "NaiveEx": # NaiveEx requires env first step to init. 
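+        # jnp.tile with reps (num_opps, 1, 1) promotes the (1, hidden_size)
+        # state to rank 3 and repeats it num_opps times, giving a
+        # (num_opps, 1, hidden_size) batch for the vmapped init.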
- init_hidden = jnp.tile(agent1._mem.hidden, (args.num_opps, 1, 1)) + if args.ppo.rnn_type == "lstm" and args.agent1 == "PPO_memory": + hidden = jnp.tile(agent1._mem.hidden[0], (args.num_opps, 1, 1)) + cell = jnp.tile(agent1._mem.hidden[1], (args.num_opps, 1, 1)) + init_hidden = LSTMState(hidden=hidden, cell=cell) + else: + init_hidden = jnp.tile( + agent1._mem.hidden, (args.num_opps, 1, 1) + ) agent1._state, agent1._mem = agent1.batch_init( agent1._state.random_key, init_hidden ) @@ -268,6 +274,7 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers): ).reshape((self.args.num_opps, self.args.num_envs, -1)) # run actual loop for i in range(num_episodes): + rngs = jax.random.split(rngs, 1) obs, env_state = env.reset(rngs, env_params) rewards = [ jnp.zeros((self.args.num_opps, self.args.num_envs)), diff --git a/pax/runner_evo.py b/pax/runner_evo.py index 3a9b7412..fb2716f2 100644 --- a/pax/runner_evo.py +++ b/pax/runner_evo.py @@ -9,6 +9,7 @@ import wandb from pax.utils import MemoryState, TrainingState, save +from haiku import LSTMState # TODO: import when evosax library is updated # from evosax.utils import ESLog @@ -141,13 +142,11 @@ def __init__( else: agent2.batch_init = jax.jit( jax.vmap( - jax.vmap(agent2.make_initial_state, (0, None), 0), - (0, None), - 0, + jax.vmap(agent2.make_initial_state), ) ) - agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0))) + agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy))) agent2.batch_reset = jax.jit( jax.vmap( jax.vmap(agent2.reset_memory, (0, None), 0), (0, None), 0 @@ -163,11 +162,21 @@ def __init__( ) if args.agent2 != "NaiveEx": # NaiveEx requires env first step to init. - init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) - key = jax.random.split( agent2._state.random_key, args.popsize * args.num_opps ).reshape(args.popsize, args.num_opps, -1) + if args.ppo.rnn_type == "lstm" and args.agent2 == "PPO_memory": + hidden = jnp.tile( + agent2._mem.hidden[0], (args.popsize, args.num_opps, 1, 1) + ) + cell = jnp.tile( + agent2._mem.hidden[1], (args.popsize, args.num_opps, 1, 1) + ) + init_hidden = LSTMState(hidden=hidden, cell=cell) + else: + init_hidden = jnp.tile( + agent2._mem.hidden, (args.popsize, args.num_opps, 1, 1) + ) agent2._state, agent2._mem = agent2.batch_init( key, @@ -379,7 +388,8 @@ def _rollout( rewards_2 = traj_2.rewards.sum(axis=1).mean() elif args.env_id in [ - "matrix_game", + "infinite_matrix_game", + "iterated_matrix_game", ]: env_stats = jax.tree_util.tree_map( lambda x: x.mean(), @@ -391,6 +401,10 @@ def _rollout( ) rewards_1 = traj_1.rewards.mean() rewards_2 = traj_2.rewards.mean() + else: + env_stats = {} + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() return ( fitness, other_fitness, @@ -446,15 +460,19 @@ def run_loop( num_devices = self.args.num_devices # Reshape a single agent's params before vmapping - init_hidden = jnp.tile( - agent1._mem.hidden, - (popsize, num_opps, 1, 1), - ) + if self.args.ppo.rnn_type == "lstm": + hidden = jnp.tile(agent1._mem.hidden[0], (popsize, num_opps, 1, 1)) + cell = jnp.tile(agent1._mem.hidden[1], (popsize, num_opps, 1, 1)) + init_hidden = LSTMState(hidden=hidden, cell=cell) + else: + init_hidden = jnp.tile( + agent1._mem.hidden, + (popsize, num_opps, 1, 1), + ) agent1._state, agent1._mem = agent1.batch_init( jax.random.split(agent1._state.random_key, popsize), init_hidden, ) - a1_state, a1_mem = agent1._state, agent1._mem for gen in range(num_gens): diff --git a/pax/runner_marl.py b/pax/runner_marl.py 
From b5a539d43f7a6d089640d5bd3b3908e39a94b15c Mon Sep 17 00:00:00 2001
From: akbir
Date: Thu, 24 Nov 2022 12:01:22 +0000
Subject: [PATCH 4/6] added better cartpole networks

---
 pax/agents/ppo/networks.py | 30 ++++++++++++++++++++----------
 pax/agents/ppo/ppo_lstm.py |  3 ++-
 pax/runner_sarl.py         | 33 ++++++++++++++++-----------------
 3 files changed, 38 insertions(+), 28 deletions(-)

diff --git a/pax/agents/ppo/networks.py b/pax/agents/ppo/networks.py
index d6297a71..cf920b84 100644
--- a/pax/agents/ppo/networks.py
+++ b/pax/agents/ppo/networks.py
@@ -351,8 +351,6 @@ def make_LSTM_ipd_network(num_actions: int, args):
     hidden = jnp.zeros((1, args.ppo.hidden_size))
     cell = jnp.zeros((1, args.ppo.hidden_size))
     hidden_state = hk.LSTMState(hidden=hidden, cell=cell)
-    # hidden_state = hidden_state_lstm.initial_state()
-    # hidden_state = jnp.zeros((1, args.ppo.hidden_size))
 
     def forward_fn(
         inputs: jnp.ndarray, state: NamedTuple
@@ -376,15 +374,27 @@ def forward_fn(
         inputs: jnp.ndarray, state: jnp.ndarray
     ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
         """forward function"""
-        torso = hk.nets.MLP(
-            [hidden_size, hidden_size],
-            w_init=hk.initializers.Orthogonal(jnp.sqrt(2)),
-            b_init=hk.initializers.Constant(0),
-            activate_final=True,
-        )
         gru = hk.GRU(hidden_size)
-        embedding = torso(inputs)
-        embedding, state = gru(embedding, state)
+        embedding, state = gru(inputs, state)
         logits, values = CategoricalValueHead(num_actions)(embedding)
         return (logits, values), state
 
     network = hk.without_apply_rng(hk.transform(forward_fn))
 
     return network, hidden_state
 
 
+def make_LSTM_cartpole_network(num_actions: int, args):
+    hidden = jnp.zeros((1, args.ppo.hidden_size))
+    cell = jnp.zeros((1, args.ppo.hidden_size))
+    hidden_state = hk.LSTMState(hidden=hidden, cell=cell)
+
+    def forward_fn(
+        inputs: jnp.ndarray, state: jnp.ndarray
+    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
+        """forward function"""
+        lstm = hk.LSTM(args.ppo.hidden_size)
+        embedding, state = lstm(inputs, state)
         logits, values = CategoricalValueHead(num_actions)(embedding)
         return (logits, values), state
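make_LSTM_cartpole_network above follows the same recipe as the IPD network: an hk.LSTM core transformed with hk.without_apply_rng(hk.transform(...)), with an explicit hk.LSTMState threaded through every call. A self-contained sketch of initialising and stepping such a network — the two hk.Linear heads stand in for pax's CategoricalValueHead, and the batch size and 4-feature observation are illustrative:

import haiku as hk
import jax
import jax.numpy as jnp

hidden_size, num_actions, batch = 16, 2, 1  # illustrative sizes

def forward_fn(inputs, state):
    embedding, state = hk.LSTM(hidden_size)(inputs, state)
    logits = hk.Linear(num_actions)(embedding)  # stand-in policy head
    values = hk.Linear(1)(embedding)  # stand-in value head
    return (logits, values), state

network = hk.without_apply_rng(hk.transform(forward_fn))

obs = jnp.zeros((batch, 4))  # e.g. a CartPole-sized observation
state = hk.LSTMState(
    hidden=jnp.zeros((batch, hidden_size)),
    cell=jnp.zeros((batch, hidden_size)),
)
params = network.init(jax.random.PRNGKey(0), obs, state)
(logits, values), state = network.apply(params, obs, state)
assert logits.shape == (batch, num_actions)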
diff --git a/pax/agents/ppo/ppo_lstm.py b/pax/agents/ppo/ppo_lstm.py
index 9d23be54..ed9c71d8 100644
--- a/pax/agents/ppo/ppo_lstm.py
+++ b/pax/agents/ppo/ppo_lstm.py
@@ -10,6 +10,7 @@
 from pax import utils
 from pax.agents.agent import AgentInterface
 from pax.agents.ppo.networks import (
+    make_LSTM_cartpole_network,
     make_LSTM_ipd_network,
 )
 from pax.utils import MemoryState, TrainingState, get_advantages
@@ -498,7 +499,7 @@ def make_lstm_agent(args, obs_spec, action_spec, seed: int, player_id: int):
     """Make PPO agent"""
     # Network
     if args.env_id == "CartPole-v1":
-        raise ValueError("CartPole-v1 not supported")
+        network, _ = make_LSTM_cartpole_network(action_spec, args)
     elif args.env_id == "coin_game":
         raise ValueError("CoinGame not supported")
     else:
diff --git a/pax/runner_sarl.py b/pax/runner_sarl.py
index a05869e0..c3c7331e 100644
--- a/pax/runner_sarl.py
+++ b/pax/runner_sarl.py
@@ -6,13 +6,13 @@
 import jax.numpy as jnp
 import wandb
 
-from pax.watchers import cg_visitation, ipd_visitation
 from pax.utils import MemoryState, TrainingState, save
+from haiku import LSTMState
 
 # from jax.config import config
 # config.update('jax_disable_jit', True)
 
-MAX_WANDB_CALLS = 1000000
+MAX_WANDB_CALLS = 10000
 
 
 class Sample(NamedTuple):
@@ -49,22 +49,21 @@ def __init__(self, agent, env, save_dir, args):
         self.split = jax.vmap(jax.random.split, (0, None))
 
         # set up agent
         if args.agent1 == "NaiveEx":
-            # special case where NaiveEx has a different call signature
-            agent.batch_init = jax.jit(jax.vmap(agent.make_initial_state))
-        else:
-            # batch MemoryState not TrainingState
-            agent.batch_init = jax.jit(agent.make_initial_state)
+            raise ValueError("NaiveEx not supported in SARL")
+        agent.batch_init = jax.jit(agent.make_initial_state)
 
         agent.batch_reset = jax.jit(agent.reset_memory, static_argnums=1)
         agent.batch_policy = jax.jit(agent._policy)
 
-        if args.agent1 != "NaiveEx":
-            # NaiveEx requires env first step to init.
+        if args.ppo.rnn_type == "lstm" and args.agent1 == "PPO_memory":
+            hidden = jnp.tile(agent._mem.hidden[0], (1))
+            cell = jnp.tile(agent._mem.hidden[1], (1))
+            init_hidden = LSTMState(hidden=hidden, cell=cell)
+        else:
+            init_hidden = jnp.tile(agent._mem.hidden, (1))
-            agent._state, agent._mem = agent.batch_init(
-                agent._state.random_key, init_hidden
-            )
+        agent._state, agent._mem = agent.batch_init(
+            agent._state.random_key, init_hidden
+        )
 
     def _inner_rollout(carry, unused):
         """Runner for inner episode"""
@@ -134,7 +133,7 @@ def _rollout(
         # run trials
         vals, traj = jax.lax.scan(
             _inner_rollout,
-            (
+            (
                 rngs,
                 obs,
                 _a1_state,
@@ -167,7 +166,7 @@ def _rollout(
         _a1_mem = agent.batch_reset(_a1_mem, False)
 
         # Stats
-        rewards = jnp.sum(traj.rewards)/(jnp.sum(traj.dones)+1e-8)
+        rewards = jnp.sum(traj.rewards) / (jnp.sum(traj.dones) + 1e-8)
         env_stats = {}
 
         return (
@@ -218,7 +217,7 @@ def run_loop(self, env, env_params, agent, num_iters, watcher):
             # logging
             self.train_episodes += 1
-            if num_iters % log_interval == 0:
+            if i % log_interval == 0:
                 print(f"Episode {i}")
                 print(f"Env Stats: {env_stats}")
@@ -246,4 +245,4 @@ def run_loop(self, env, env_params, agent, num_iters, watcher):
         )
         agent._state = a1_state
-        return agent
\ No newline at end of file
+        return agent
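The stats change in runner_sarl.py above normalises the summed reward by the number of finished episodes — jnp.sum(traj.rewards) / (jnp.sum(traj.dones) + 1e-8) — so the logged figure is mean undiscounted return per completed episode, with the epsilon keeping the division finite before any episode terminates. A toy check with invented values:

import jax.numpy as jnp

# Two episodes (lengths 3 and 2) flattened into one trajectory.
rewards = jnp.array([1.0, 1.0, 1.0, 2.0, 2.0])
dones = jnp.array([0.0, 0.0, 1.0, 0.0, 1.0])

mean_return = jnp.sum(rewards) / (jnp.sum(dones) + 1e-8)
assert jnp.isclose(mean_return, 3.5)  # (3.0 + 4.0) / 2 episodes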
From d17a3d801e994b46292b87b122315a1ecc7d0262 Mon Sep 17 00:00:00 2001
From: akbir
Date: Thu, 24 Nov 2022 12:07:23 +0000
Subject: [PATCH 5/6] added support for coingame network

---
 pax/agents/ppo/networks.py | 31 +++++++++++++++++++++++++++++++
 pax/agents/ppo/ppo_lstm.py | 10 +++++++---
 2 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/pax/agents/ppo/networks.py b/pax/agents/ppo/networks.py
index cf920b84..45979e7a 100644
--- a/pax/agents/ppo/networks.py
+++ b/pax/agents/ppo/networks.py
@@ -432,6 +432,37 @@ def forward_fn(
     return network, hidden_state
 
 
+def make_LSTM_coingame_network(num_actions: int, args):
+    hidden = jnp.zeros((1, args.ppo.hidden_size))
+    cell = jnp.zeros((1, args.ppo.hidden_size))
+    hidden_state = hk.LSTMState(hidden=hidden, cell=cell)
+
+    def forward_fn(
+        inputs: jnp.ndarray, state: jnp.ndarray
+    ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
+
+        if args.ppo.with_cnn:
+            torso = CNN(args)
+
+        else:
+            torso = hk.nets.MLP(
+                [args.ppo.hidden_size],
+                w_init=hk.initializers.Orthogonal(jnp.sqrt(2)),
+                b_init=hk.initializers.Constant(0),
+                activate_final=True,
+                activation=jnp.tanh,
+            )
+        lstm = hk.LSTM(args.ppo.hidden_size)
+        embedding = torso(inputs)
+        embedding, state = lstm(embedding, state)
+        logits, values = CategoricalValueHead(num_actions)(embedding)
+
+        return (logits, values), state
+
+    network = hk.without_apply_rng(hk.transform(forward_fn))
+    return network, hidden_state
+
+
 def test_GRU():
     key = jax.random.PRNGKey(seed=0)
     num_actions = 2
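In the coin-game network above, each config branch selects a torso module — a CNN when args.ppo.with_cnn is set, an MLP otherwise — and the torso is applied exactly once before the shared LSTM and value head. A stripped-down sketch of that branch-then-apply pattern; the use_cnn flag, conv shape, and sizes are illustrative assumptions, not pax's actual CNN module:

import haiku as hk
import jax
import jax.numpy as jnp

hidden_size, num_actions = 16, 4  # illustrative sizes

def forward_fn(inputs, state, use_cnn=False):
    if use_cnn:
        torso = hk.Sequential(
            [hk.Conv2D(8, kernel_shape=3), jax.nn.relu, hk.Flatten()]
        )
    else:
        torso = hk.nets.MLP([hidden_size], activate_final=True)
    embedding = torso(inputs)  # torso is built once, applied once
    embedding, state = hk.LSTM(hidden_size)(embedding, state)
    logits = hk.Linear(num_actions)(embedding)
    return logits, state

network = hk.without_apply_rng(hk.transform(forward_fn))

obs = jnp.zeros((1, 8))  # flat observation exercising the MLP branch
state = hk.LSTMState(
    hidden=jnp.zeros((1, hidden_size)), cell=jnp.zeros((1, hidden_size))
)
params = network.init(jax.random.PRNGKey(0), obs, state)
logits, state = network.apply(params, obs, state)
assert logits.shape == (1, num_actions)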
diff --git a/pax/agents/ppo/ppo_lstm.py b/pax/agents/ppo/ppo_lstm.py
index ed9c71d8..2b133bbf 100644
--- a/pax/agents/ppo/ppo_lstm.py
+++ b/pax/agents/ppo/ppo_lstm.py
@@ -11,6 +11,7 @@
 from pax.agents.agent import AgentInterface
 from pax.agents.ppo.networks import (
     make_LSTM_cartpole_network,
+    make_LSTM_coingame_network,
     make_LSTM_ipd_network,
 )
 from pax.utils import MemoryState, TrainingState, get_advantages
@@ -501,10 +502,13 @@ def make_lstm_agent(args, obs_spec, action_spec, seed: int, player_id: int):
     if args.env_id == "CartPole-v1":
         network, _ = make_LSTM_cartpole_network(action_spec, args)
     elif args.env_id == "coin_game":
-        raise ValueError("CoinGame not supported")
-    else:
+        network, _ = make_LSTM_coingame_network(action_spec, args)
+    elif args.env_id in ["iterated_matrix_game", "infinite_matrix_game"]:
         network, _ = make_LSTM_ipd_network(action_spec, args)
-
+    else:
+        raise ValueError(
+            f"PPO LSTM Agent {player_id}: not implemented for {args.env_id} environment"
+        )
     hidden = jnp.zeros((args.num_envs, args.ppo.hidden_size))
     cell = jnp.zeros((args.num_envs, args.ppo.hidden_size))
     initial_hidden_state = hk.LSTMState(hidden=hidden, cell=cell)

From 526dd658e3da478fbae3a29d4b0d117f8298d997 Mon Sep 17 00:00:00 2001
From: akbir
Date: Thu, 24 Nov 2022 12:08:30 +0000
Subject: [PATCH 6/6] fixed make_lstm_agent docstring

---
 pax/agents/ppo/ppo_lstm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pax/agents/ppo/ppo_lstm.py b/pax/agents/ppo/ppo_lstm.py
index 2b133bbf..79b34a5d 100644
--- a/pax/agents/ppo/ppo_lstm.py
+++ b/pax/agents/ppo/ppo_lstm.py
@@ -497,7 +497,7 @@ def update(
 
 # TODO: seed, and player_id not used in CartPole
 def make_lstm_agent(args, obs_spec, action_spec, seed: int, player_id: int):
-    """Make PPO agent"""
+    """Make PPO LSTM agent"""
     # Network
     if args.env_id == "CartPole-v1":
         network, _ = make_LSTM_cartpole_network(action_spec, args)
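After the series, make_lstm_agent dispatches on args.env_id alone and sizes the initial recurrent state by args.num_envs, one hidden/cell row per parallel environment. A sketch of that final shape with a toy args namespace — field names follow the diffs, the values are invented:

from types import SimpleNamespace

import haiku as hk
import jax.numpy as jnp

args = SimpleNamespace(
    env_id="iterated_matrix_game",
    num_envs=2,
    ppo=SimpleNamespace(hidden_size=16),
)

supported = (
    "CartPole-v1",
    "coin_game",
    "iterated_matrix_game",
    "infinite_matrix_game",
)
if args.env_id not in supported:
    raise ValueError(f"not implemented for {args.env_id} environment")

hidden = jnp.zeros((args.num_envs, args.ppo.hidden_size))
cell = jnp.zeros((args.num_envs, args.ppo.hidden_size))
initial_hidden_state = hk.LSTMState(hidden=hidden, cell=cell)
assert initial_hidden_state.hidden.shape == (args.num_envs, args.ppo.hidden_size)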