diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da853901..97d4625a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ repos: - repo: https://github.com/ambv/black - rev: 22.6.0 + rev: 22.10.0 hooks: - id: black language_version: python3.9 -- repo: https://gitlab.com/pycqa/flake8 - rev: '3.9.2' +- repo: https://github.com/pycqa/flake8 + rev: '5.0.4' hooks: - id: flake8 additional_dependencies: [flake8-bugbear] diff --git a/pax/agents/act/act_agent.py b/pax/agents/act/act_agent.py new file mode 100644 index 00000000..49e9b471 --- /dev/null +++ b/pax/agents/act/act_agent.py @@ -0,0 +1,128 @@ +from typing import Any, Dict, NamedTuple, Tuple + +import haiku as hk +import jax +import jax.numpy as jnp +import optax + +from pax.agents.agent import AgentInterface +from pax import utils +from pax.utils import Logger, MemoryState, TrainingState, get_advantages +from pax.agents.act.networks import make_act_network + +class ActAgent(AgentInterface): + def __init__( + self, + network: NamedTuple, + optimizer: optax.GradientTransformation, + random_key: jnp.ndarray, + obs_spec: Tuple, + num_envs: int = 4, + entropy_coeff_start: float = 0.1, + player_id: int = 0, + ): + @jax.jit + def policy( + state: TrainingState, observation: jnp.ndarray, mem: MemoryState + ): + """Agent policy to select actions and calculate agent specific information""" + values = network.apply(state.params, observation) + mem.extras["values"] = values + mem = mem._replace(extras=mem.extras) + return values, state, mem + + def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + key, subkey = jax.random.split(key) + dummy_obs = jnp.zeros(shape=obs_spec) + + dummy_obs = utils.add_batch_dim(dummy_obs) + initial_params = network.init(subkey, dummy_obs) + initial_opt_state = optimizer.init(initial_params) + return TrainingState( + random_key=key, + params=initial_params, + opt_state=initial_opt_state, + timesteps=0, + ), MemoryState( + hidden=jnp.zeros((num_envs, 1)), + extras={ + "values": jnp.zeros((num_envs, 2)), + "log_probs": jnp.zeros(num_envs), + }, + ) + self.make_initial_state = make_initial_state + self._state, self._mem = make_initial_state(random_key, jnp.zeros(1)) + + # Set up counters and logger + self._logger = Logger() + self._total_steps = 0 + self._until_sgd = 0 + self._logger.metrics = { + "total_steps": 0, + "sgd_steps": 0, + "loss_total": 0, + "loss_policy": 0, + "loss_value": 0, + "loss_entropy": 0, + "entropy_cost": entropy_coeff_start, + } + + # Initialize functions + self._policy = policy + self.player_id = player_id + + # Other useful hyperparameters + self._num_envs = num_envs # number of environments + + def reset_memory(self, memory, eval=False) -> MemoryState: + num_envs = 1 if eval else self._num_envs + memory = memory._replace( + extras={ + "values": jnp.zeros((num_envs, 2)), + "log_probs": jnp.zeros(num_envs), + }, + ) + return memory + +def make_act_agent( + args, + obs_spec, + action_spec, + seed: int, + player_id: int, + tabular=False, +): + """Make PPO agent""" + if args.runner == "act_evo": + network = make_act_network(action_spec) + else: + raise NotImplementedError + + # Optimizer + batch_size = int(args.num_envs * args.num_steps) + transition_steps = ( + args.total_timesteps + / batch_size + * args.ppo.num_epochs + * args.ppo.num_minibatches + ) + + optimizer = optax.chain( + optax.clip_by_global_norm(args.ppo.max_gradient_norm), + 
optax.scale_by_adam(eps=args.ppo.adam_epsilon), + optax.scale(-args.ppo.learning_rate), + ) + + random_key = jax.random.PRNGKey(seed=seed) + + agent = ActAgent( + network=network, + optimizer=optimizer, + random_key=random_key, + obs_spec=obs_spec, + num_envs=args.num_envs, + entropy_coeff_start=args.ppo.entropy_coeff_start, + player_id=player_id, + ) + return agent \ No newline at end of file diff --git a/pax/agents/act/networks.py b/pax/agents/act/networks.py new file mode 100644 index 00000000..c7a9aa9f --- /dev/null +++ b/pax/agents/act/networks.py @@ -0,0 +1,54 @@ +from typing import Optional, Tuple + +import haiku as hk +import jax +import jax.numpy as jnp + +from pax import utils + +class DeterministicFunction(hk.Module): + """Network head that produces a categorical distribution and value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None): + super().__init__(name=name) + self._act_body = hk.nets.MLP( + [64, 64], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jax.nn.relu, + ) + self._act_output = hk.nets.MLP( + [num_values], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jnp.tanh, + ) + + + def __call__(self, inputs: jnp.ndarray): + output = self._act_body(inputs) + output = self._act_output(output) + + return output + + +def make_act_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + DeterministicFunction(num_values=num_actions) + ] + ) + act_network = hk.Sequential(layers) + return act_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network \ No newline at end of file diff --git a/pax/agents/ppo/ppo.py b/pax/agents/ppo/ppo.py index 457c9b10..f9334ef5 100644 --- a/pax/agents/ppo/ppo.py +++ b/pax/agents/ppo/ppo.py @@ -459,7 +459,7 @@ def make_agent( tabular=False, ): """Make PPO agent""" - if args.runner == "sarl": + if args.runner in ["sarl", "sarl_eval"]: network = make_sarl_network(action_spec) elif args.env_id == "coin_game": print(f"Making network for {args.env_id}") diff --git a/pax/agents/strategies.py b/pax/agents/strategies.py index 89efebf0..679f7266 100644 --- a/pax/agents/strategies.py +++ b/pax/agents/strategies.py @@ -481,6 +481,41 @@ def reset_memory(self, mem, *args) -> MemoryState: def make_initial_state(self, _unused, *args) -> TrainingState: return self._state, self._mem +class RandomACT(AgentInterface): + def __init__(self, num_actions: int, num_envs: int): + self.make_initial_state = initial_state_fun(num_envs) + self._state, self._mem = self.make_initial_state(None, None) + self.reset_memory = reset_mem_fun(num_envs) + self._logger = Logger() + self._logger.metrics = {} + self._num_actions = num_actions + print('self num actions', self._num_actions) + + def _policy( + state: NamedTuple, + obs: jnp.array, + mem: NamedTuple, + ) -> jnp.ndarray: + # state is [batch x time_step x num_players] + # return [batch] + batch_size = obs.shape[0] + new_key, _ = jax.random.split(state.random_key) + action = jnp.zeros((batch_size, num_actions)) + # action = jax.random.uniform(new_key, (batch_size, num_actions), dtype=jnp.float32, minval=0.0, maxval=1.0) + state = state._replace(random_key=new_key) + return action, state, mem + + self._policy = jax.jit(_policy) + + def update(self, unused0, unused1, state, mem) -> None: + return state, mem, {} + 
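
# Editor's sketch (not part of the diff): minimal usage of the ACT network
# defined in pax/agents/act/networks.py above. obs_dim=4 and num_actions=3
# are made-up sizes. The head is a plain deterministic MLP; activate_final
# with jnp.tanh bounds each output to (-1, 1), and hk.without_apply_rng
# means no rng is needed at apply time.
import jax
import jax.numpy as jnp

from pax.agents.act.networks import make_act_network

network = make_act_network(num_actions=3)
rng = jax.random.PRNGKey(0)
dummy_obs = jnp.zeros((1, 4))                     # [batch, obs_dim]
params = network.init(rng, dummy_obs)
action_values = network.apply(params, dummy_obs)  # shape [1, 3]
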
+ def reset_memory(self, mem, *args) -> MemoryState: + return self._mem + + def make_initial_state(self, _unused, *args) -> TrainingState: + return self._state, self._mem + class Stay(AgentInterface): def __init__(self, num_actions: int, num_envs: int): diff --git a/pax/agents/synq/networks.py b/pax/agents/synq/networks.py new file mode 100644 index 00000000..517e26b4 --- /dev/null +++ b/pax/agents/synq/networks.py @@ -0,0 +1,113 @@ +from typing import Optional, Tuple + +import haiku as hk +import jax +import jax.numpy as jnp + +from pax import utils + +class CategoricalValueHeadSeparate(hk.Module): + """Network head that produces a categorical distribution and value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._action_body = hk.nets.MLP( + [64, 64], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jnp.tanh, + ) + self._logit_layer = hk.Linear( + num_values, + w_init=hk.initializers.Orthogonal(1.0), + b_init=hk.initializers.Constant(0), + ) + self._value_body = hk.nets.MLP( + [64, 64], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jnp.tanh, + ) + self._value_layer = hk.Linear( + 1, + w_init=hk.initializers.Orthogonal(0.01), + b_init=hk.initializers.Constant(0), + ) + + def __call__(self, inputs: jnp.ndarray): + # action_output, value_output = inputs + logits = self._action_body(inputs) + logits = self._logit_layer(logits) + + value = self._value_body(inputs) + value = jnp.squeeze(self._value_layer(value), axis=-1) + return (distrax.Categorical(logits=logits), value) + +class DeterministicFunction(hk.Module): + """Network head that produces a categorical distribution and value.""" + + def __init__( + self, + num_values: int, + name: Optional[str] = None): + super().__init__(name=name) + self._act_body = hk.nets.MLP( + [64, 64], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jax.nn.relu, + ) + self._act_output = hk.nets.MLP( + [num_values], + w_init=hk.initializers.Orthogonal(jnp.sqrt(2)), + b_init=hk.initializers.Constant(0), + activate_final=True, + activation=jnp.tanh, + ) + + + def __call__(self, inputs: jnp.ndarray): + output = self._act_body(inputs) + output = self._act_output(output) + + return output + + +def make_synq_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + DeterministicFunction(num_values=num_actions) + ] + ) + act_network = hk.Sequential(layers) + return act_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network + +def make_policy_network(num_actions: int): + """Creates a hk network using the baseline hyperparameters from OpenAI""" + + def forward_fn(inputs): + layers = [] + layers.extend( + [ + CategoricalValueHeadSeparate(num_values=num_actions) + ] + ) + policy_value_network = hk.Sequential(layers) + return policy_value_network(inputs) + + network = hk.without_apply_rng(hk.transform(forward_fn)) + return network \ No newline at end of file diff --git a/pax/agents/synq/ppo.py b/pax/agents/synq/ppo.py new file mode 100644 index 00000000..caee3917 --- /dev/null +++ b/pax/agents/synq/ppo.py @@ -0,0 +1,547 @@ +# Adapted from https://github.com/deepmind/acme/blob/master/acme/agents/jax/ppo/learning.py + +from typing 
import Any, Dict, NamedTuple, Tuple
+
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import optax
+
+from pax import utils
+from pax.agents.agent import AgentInterface
+from pax.agents.ppo.networks import (
+    make_sarl_network,
+    make_coingame_network,
+    make_ipd_network,
+)
+from pax.agents.synq.networks import make_synq_network
+from pax.utils import Logger, MemoryState, TrainingState, get_advantages
+
+
+class Batch(NamedTuple):
+    """A batch of data; all shapes are expected to be [B, ...]."""
+
+    observations: jnp.ndarray
+    actions: jnp.ndarray
+    advantages: jnp.ndarray
+
+    # Target value estimate used to bootstrap the value function.
+    target_values: jnp.ndarray
+
+    # Value estimate and action log-prob at behavior time.
+    behavior_values: jnp.ndarray
+    behavior_log_probs: jnp.ndarray
+    behavior_synq_values: jnp.ndarray
+
+
+class PPO(AgentInterface):
+    """A simple PPO agent using JAX"""
+
+    def __init__(
+        self,
+        network: NamedTuple,
+        network_synq: NamedTuple,
+        optimizer: optax.GradientTransformation,
+        random_key: jnp.ndarray,
+        obs_spec: Tuple,
+        num_envs: int = 4,
+        num_steps: int = 500,
+        num_minibatches: int = 16,
+        num_epochs: int = 4,
+        clip_value: bool = True,
+        value_coeff: float = 0.5,
+        anneal_entropy: bool = False,
+        entropy_coeff_start: float = 0.1,
+        entropy_coeff_end: float = 0.01,
+        entropy_coeff_horizon: int = 3_000_000,
+        ppo_clipping_epsilon: float = 0.2,
+        gamma: float = 0.99,
+        gae_lambda: float = 0.95,
+        tabular: bool = False,
+        player_id: int = 0,
+    ):
+        @jax.jit
+        def policy(
+            state: TrainingState, observation: jnp.ndarray, mem: MemoryState
+        ):
+            """Agent policy to select actions and calculate agent specific information"""
+            key, subkey = jax.random.split(state.random_key)
+            dist, values = network.apply(state.params, observation)
+            actions = dist.sample(seed=subkey)
+            mem.extras["values"] = values
+            mem.extras["log_probs"] = dist.log_prob(actions)
+            mem = mem._replace(extras=mem.extras)
+            state = state._replace(random_key=key)
+            return actions, state, mem
+
+        @jax.jit
+        def synq_value(
+            state: TrainingState, observation: jnp.ndarray, mem: MemoryState
+        ):
+            """Compute SynQ value estimates and store them in agent memory"""
+            key, subkey = jax.random.split(state.random_key)
+            values = network_synq.apply(state.params_synq, observation)
+            mem.extras["synq_values"] = values
+            mem = mem._replace(extras=mem.extras)
+            state = state._replace(random_key=key)
+            return values, state, mem
+
+        @jax.jit
+        def gae_advantages(
+            rewards: jnp.ndarray, values: jnp.ndarray, dones: jnp.ndarray
+        ) -> jnp.ndarray:
+            """Calculates the gae advantages from a sequence.
Note that the + arguments are of length = rollout length + 1""" + # 'Zero out' the terminated states + discounts = gamma * jnp.logical_not(dones) + + reverse_batch = ( + jnp.flip(values[:-1], axis=0), + jnp.flip(rewards, axis=0), + jnp.flip(discounts, axis=0), + ) + + _, advantages = jax.lax.scan( + get_advantages, + ( + jnp.zeros_like(values[-1]), + values[-1], + jnp.ones_like(values[-1]) * gae_lambda, + ), + reverse_batch, + ) + + advantages = jnp.flip(advantages, axis=0) + target_values = values[:-1] + advantages # Q-value estimates + target_values = jax.lax.stop_gradient(target_values) + return advantages, target_values + + def loss( + params: hk.Params, + timesteps: int, + observations: jnp.ndarray, + actions: jnp.array, + behavior_log_probs: jnp.array, + target_values: jnp.array, + advantages: jnp.array, + behavior_values: jnp.array, + behavior_synq_values: jnp.ndarray, + target_synq_values: jnp.ndarray, + ): + """Surrogate loss using clipped probability ratios.""" + distribution, values = network.apply(params, observations) + log_prob = distribution.log_prob(actions) + entropy = distribution.entropy() + + # Compute importance sampling weights: current policy / behavior policy. + rhos = jnp.exp(log_prob - behavior_log_probs) + + # Policy loss: Clipping + clipped_ratios_t = jnp.clip( + rhos, 1.0 - ppo_clipping_epsilon, 1.0 + ppo_clipping_epsilon + ) + clipped_objective = jnp.fmin( + rhos * advantages, clipped_ratios_t * advantages + ) + policy_loss = -jnp.mean(clipped_objective) + + # Value loss: MSE + value_cost = value_coeff + unclipped_value_error = target_values - values + unclipped_value_loss = unclipped_value_error**2 + + # Value clipping + if clip_value: + # Clip values to reduce variablility during critic training. + clipped_values = behavior_values + jnp.clip( + values - behavior_values, + -ppo_clipping_epsilon, + ppo_clipping_epsilon, + ) + clipped_value_error = target_values - clipped_values + clipped_value_loss = clipped_value_error**2 + value_loss = jnp.mean( + jnp.fmax(unclipped_value_loss, clipped_value_loss) + ) + else: + value_loss = jnp.mean(unclipped_value_loss) + + # Entropy loss: Standard entropy term + # Calculate the new value based on linear annealing formula + if anneal_entropy: + fraction = jnp.fmax(1 - timesteps / entropy_coeff_horizon, 0) + entropy_cost = ( + fraction * entropy_coeff_start + + (1 - fraction) * entropy_coeff_end + ) + # Constant Entropy term + else: + entropy_cost = entropy_coeff_start + entropy_loss = -jnp.mean(entropy) + + # Total loss: Minimize policy and value loss; maximize entropy + total_loss = ( + policy_loss + + entropy_cost * entropy_loss + + value_loss * value_cost + ) + + return total_loss, { + "loss_total": total_loss, + "loss_policy": policy_loss, + "loss_value": value_loss, + "loss_entropy": entropy_loss, + "entropy_cost": entropy_cost, + } + + @jax.jit + def sgd_step( + state: TrainingState, sample: NamedTuple + ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]: + """Performs a minibatch SGD step, returning new state and metrics.""" + + # Extract data + ( + observations, + actions, + rewards, + behavior_log_probs, + behavior_values, + synq_values, + dones, + ) = ( + sample.observations, + sample.actions, + sample.rewards, + sample.behavior_log_probs, + sample.behavior_values, + sample.synq_values, + sample.dones, + ) + + advantages, target_values = gae_advantages( + rewards=rewards, values=behavior_values, dones=dones + ) + + # Exclude the last step - it was only used for bootstrapping. 
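
# Editor's sketch (not part of the diff): gae_advantages above delegates the
# backward recursion to pax.utils.get_advantages, whose body is not shown
# here. A standalone single-environment version of the same recursion, for
# reference:
#   delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t
#   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
import jax
import jax.numpy as jnp

def gae_reference(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    # values has length T + 1 (last entry is the bootstrap value)
    discounts = gamma * jnp.logical_not(dones)

    def scan_fn(advantage, xs):
        value, reward, discount, next_value = xs
        delta = reward + discount * next_value - value
        advantage = delta + discount * gae_lambda * advantage
        return advantage, advantage

    _, advantages = jax.lax.scan(
        scan_fn,
        jnp.zeros(()),
        (values[:-1], rewards, discounts, values[1:]),
        reverse=True,  # accumulate from the final timestep backwards
    )
    target_values = jax.lax.stop_gradient(values[:-1] + advantages)
    return advantages, target_values
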
+ # The shape is [num_steps, num_envs, ..] + behavior_values = behavior_values[:-1, :] + trajectories = Batch( + observations=observations, + actions=actions, + advantages=advantages, + behavior_log_probs=behavior_log_probs, + target_values=target_values, + behavior_values=behavior_values, + behavior_synq_values=synq_values, + ) + + # Concatenate all trajectories. Reshape from [num_envs, num_steps, ..] + # to [num_envs * num_steps,..] + assert len(target_values.shape) > 1 + num_envs = target_values.shape[1] + num_steps = target_values.shape[0] + batch_size = num_envs * num_steps + assert batch_size % num_minibatches == 0, ( + "Num minibatches must divide batch size. Got batch_size={}" + " num_minibatches={}." + ).format(batch_size, num_minibatches) + + batch = jax.tree_util.tree_map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), trajectories + ) + + # Compute gradients. + grad_fn = jax.jit(jax.grad(loss, has_aux=True)) + + @jax.jit + def model_update_minibatch( + carry: Tuple[hk.Params, optax.OptState, int], + minibatch: Batch, + ) -> Tuple[ + Tuple[hk.Params, optax.OptState, int], Dict[str, jnp.ndarray] + ]: + """Performs model update for a single minibatch.""" + params, opt_state, timesteps = carry + # Normalize advantages at the minibatch level before using them. + advantages = ( + minibatch.advantages + - jnp.mean(minibatch.advantages, axis=0) + ) / (jnp.std(minibatch.advantages, axis=0) + 1e-8) + gradients, metrics = grad_fn( + params, + timesteps, + minibatch.observations, + minibatch.actions, + minibatch.behavior_log_probs, + minibatch.target_values, + advantages, + minibatch.behavior_values, + minibatch.behavior_synq_values, + ) + + # Apply updates + updates, opt_state = optimizer.update(gradients, opt_state) + params = optax.apply_updates(params, updates) + + metrics["norm_grad"] = optax.global_norm(gradients) + metrics["norm_updates"] = optax.global_norm(updates) + return (params, opt_state, timesteps), metrics + + @jax.jit + def model_update_epoch( + carry: Tuple[ + jnp.ndarray, hk.Params, optax.OptState, int, Batch + ], + unused_t: Tuple[()], + ) -> Tuple[ + Tuple[jnp.ndarray, hk.Params, optax.OptState, Batch], + Dict[str, jnp.ndarray], + ]: + """Performs model updates based on one epoch of data.""" + key, params, opt_state, timesteps, batch = carry + key, subkey = jax.random.split(key) + permutation = jax.random.permutation(subkey, batch_size) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, [num_minibatches, -1] + list(x.shape[1:]) + ), + shuffled_batch, + ) + + (params, opt_state, timesteps), metrics = jax.lax.scan( + model_update_minibatch, + (params, opt_state, timesteps), + minibatches, + length=num_minibatches, + ) + return (key, params, opt_state, timesteps, batch), metrics + + params = state.params + opt_state = state.opt_state + timesteps = state.timesteps + + # Repeat training for the given number of epoch, taking a random + # permutation for every epoch. 
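
# Editor's sketch (illustrative sizes, not from the diff): the epoch loop
# above shuffles the flat batch with one permutation, then reshapes it to
# [num_minibatches, batch_size // num_minibatches, ...] so lax.scan can feed
# model_update_minibatch one slice per step.
import jax
import jax.numpy as jnp

batch_size, num_minibatches = 32, 4
batch = jnp.arange(batch_size * 3).reshape(batch_size, 3)  # stand-in field

key = jax.random.PRNGKey(0)
permutation = jax.random.permutation(key, batch_size)
shuffled = jnp.take(batch, permutation, axis=0)
minibatches = jnp.reshape(shuffled, (num_minibatches, -1) + batch.shape[1:])
assert minibatches.shape == (4, 8, 3)
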
+ # signature is scan(function, carry, tuple to iterate over, length) + (key, params, opt_state, timesteps, _), metrics = jax.lax.scan( + model_update_epoch, + (state.random_key, params, opt_state, timesteps, batch), + (), + length=num_epochs, + ) + + metrics = jax.tree_util.tree_map(jnp.mean, metrics) + metrics["rewards_mean"] = jnp.mean( + jnp.abs(jnp.mean(rewards, axis=(0, 1))) + ) + metrics["rewards_std"] = jnp.std(rewards, axis=(0, 1)) + + new_state = TrainingState( + params=params, + opt_state=opt_state, + random_key=key, + timesteps=timesteps + batch_size, + ) + + new_memory = MemoryState( + hidden=jnp.zeros((num_envs, 1)), + extras={ + "log_probs": jnp.zeros(num_envs), + "values": jnp.zeros(num_envs), + }, + ) + + return new_state, new_memory, metrics + + def make_initial_state(key: Any, hidden: jnp.ndarray) -> TrainingState: + """Initialises the training state (parameters and optimiser state).""" + key, subkey = jax.random.split(key) + if not tabular: + dummy_obs = jnp.zeros(shape=obs_spec) + else: + dummy_obs = jnp.zeros(shape=obs_spec) + dummy_obs = dummy_obs.at[0].set(1) + dummy_obs = dummy_obs.at[9].set(1) + dummy_obs = dummy_obs.at[18].set(1) + dummy_obs = dummy_obs.at[27].set(1) + dummy_obs = utils.add_batch_dim(dummy_obs) + initial_params = network.init(subkey, dummy_obs) + initial_opt_state = optimizer.init(initial_params) + return TrainingState( + random_key=key, + params=initial_params, + opt_state=initial_opt_state, + timesteps=0, + ), MemoryState( + hidden=jnp.zeros((num_envs, 1)), + extras={ + "values": jnp.zeros(num_envs), + "log_probs": jnp.zeros(num_envs), + }, + ) + + def prepare_batch( + traj_batch: NamedTuple, done: Any, action_extras: dict + ): + # Rollouts complete -> Training begins + # Add an additional rollout step for advantage calculation + _value = jax.lax.select( + done, + jnp.zeros_like(action_extras["values"]), + action_extras["values"], + ) + + _value = jax.lax.expand_dims(_value, [0]) + # need to add final value here + traj_batch = traj_batch._replace( + behavior_values=jnp.concatenate( + [traj_batch.behavior_values, _value], axis=0 + ) + ) + return traj_batch + + # Initialise training state (parameters, optimiser state, extras). 
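
# Editor's sketch (not part of the diff): the outer jax.lax.scan above uses
# xs=None with a fixed length, i.e. "run the body num_epochs times while
# threading the carry through". In miniature:
import jax

def epoch(carry, _):
    counter = carry
    return counter + 1, {"epoch_index": counter}

final_counter, metrics = jax.lax.scan(epoch, 0, None, length=4)
# final_counter == 4, metrics["epoch_index"] == [0, 1, 2, 3]
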
+ self.make_initial_state = make_initial_state + self._state, self._mem = make_initial_state(random_key, jnp.zeros(1)) + self._prepare_batch = jax.jit(prepare_batch) + self._sgd_step = jax.jit(sgd_step) + + # Set up counters and logger + self._logger = Logger() + self._total_steps = 0 + self._until_sgd = 0 + self._logger.metrics = { + "total_steps": 0, + "sgd_steps": 0, + "loss_total": 0, + "loss_policy": 0, + "loss_value": 0, + "loss_entropy": 0, + "entropy_cost": entropy_coeff_start, + } + + # Initialize functions + self._policy = policy + self._synq_value = synq_value + self.player_id = player_id + + # Other useful hyperparameters + self._num_envs = num_envs # number of environments + self._num_steps = num_steps # number of steps per environment + self._batch_size = int(num_envs * num_steps) # number in one batch + self._num_minibatches = num_minibatches # number of minibatches + self._num_epochs = num_epochs # number of epochs to use sample + + def reset_memory(self, memory, eval=False) -> MemoryState: + num_envs = 1 if eval else self._num_envs + memory = memory._replace( + extras={ + "values": jnp.zeros(num_envs), + "log_probs": jnp.zeros(num_envs), + }, + ) + return memory + + def update( + self, + traj_batch, + obs: jnp.ndarray, + state: TrainingState, + mem: MemoryState, + ): + """Update the agent -> only called at the end of a trajectory""" + _, _, mem = self._policy(state, obs, mem) + + traj_batch = self._prepare_batch( + traj_batch, traj_batch.dones[-1, ...], mem.extras + ) + state, mem, metrics = self._sgd_step(state, traj_batch) + self._logger.metrics["sgd_steps"] += ( + self._num_minibatches * self._num_epochs + ) + self._logger.metrics["loss_total"] = metrics["loss_total"] + self._logger.metrics["loss_policy"] = metrics["loss_policy"] + self._logger.metrics["loss_value"] = metrics["loss_value"] + self._logger.metrics["loss_entropy"] = metrics["loss_entropy"] + self._logger.metrics["entropy_cost"] = metrics["entropy_cost"] + + return state, mem, metrics + + +def make_agent( + args, + obs_spec, + action_spec, + seed: int, + player_id: int, + tabular=False, +): + """Make PPO agent""" + if args.runner in ["synq"]: + network = make_sarl_network(action_spec) + network_synq = make_synq_network(action_spec) + else: + raise NotImplementedError + + # Optimizer + batch_size = int(args.num_envs * args.num_steps) + transition_steps = ( + args.total_timesteps + / batch_size + * args.ppo.num_epochs + * args.ppo.num_minibatches + ) + + if args.ppo.lr_scheduling: + scheduler = optax.linear_schedule( + init_value=args.ppo.learning_rate, + end_value=0, + transition_steps=transition_steps, + ) + optimizer = optax.chain( + optax.clip_by_global_norm(args.ppo.max_gradient_norm), + optax.scale_by_adam(eps=args.ppo.adam_epsilon), + optax.scale_by_schedule(scheduler), + optax.scale(-1), + ) + + else: + optimizer = optax.chain( + optax.clip_by_global_norm(args.ppo.max_gradient_norm), + optax.scale_by_adam(eps=args.ppo.adam_epsilon), + optax.scale(-args.ppo.learning_rate), + ) + + # Random key + random_key = jax.random.PRNGKey(seed=seed) + + agent = PPO( + network=network, + network_synq=network_synq, + optimizer=optimizer, + random_key=random_key, + obs_spec=obs_spec, + num_envs=args.num_envs, + num_steps=args.num_steps, + num_minibatches=args.ppo.num_minibatches, + num_epochs=args.ppo.num_epochs, + clip_value=args.ppo.clip_value, + value_coeff=args.ppo.value_coeff, + anneal_entropy=args.ppo.anneal_entropy, + entropy_coeff_start=args.ppo.entropy_coeff_start, + 
entropy_coeff_end=args.ppo.entropy_coeff_end, + entropy_coeff_horizon=args.ppo.entropy_coeff_horizon, + ppo_clipping_epsilon=args.ppo.ppo_clipping_epsilon, + gamma=args.ppo.gamma, + gae_lambda=args.ppo.gae_lambda, + tabular=tabular, + player_id=player_id, + ) + return agent + + +if __name__ == "__main__": + pass diff --git a/pax/conf/experiment/act/act.yaml b/pax/conf/experiment/act/act.yaml new file mode 100644 index 00000000..8470ded9 --- /dev/null +++ b/pax/conf/experiment/act/act.yaml @@ -0,0 +1,87 @@ +# @package _global_ + +# Agents +agent1: 'ACT' +agent2: 'PPO' + +# Environment +env_id: CartPole-v1 +env_type: meta +egocentric: True +env_discount: 0.96 +payoff: [[1, 1, -2], [1, 1, -2]] + +# Save +save: True +save_interval: 100 +adversary_type: ally # adversary, random adversary, zeroes adversary + +# Runner +runner: act_evo +output_dim: 0 + +# Training +top_k: 5 +popsize: 256 #512 +# total popsize = popsize * num_devices +num_envs: 8 +num_opps: 1 +num_devices: 4 +num_steps: 100000 +num_inner_steps: 200 +num_generations: 3000 + +# Evaluation +run_path: ucl-dark/cg/3mpgbfm2 +model_path: exp/coin_game-EARL-PPO_memory-vs-Random/run-seed-0/2022-09-08_20.41.03.643377/generation_30 + +# num_generations = total_timesteps / (num_envs*num_opps*num_steps) +# PPO agent parameters +ppo: + num_minibatches: 4 + num_epochs: 4 + gamma: 0.99 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.01 + entropy_coeff_horizon: 1000000 + entropy_coeff_end: 0.001 + lr_scheduling: True + learning_rate: 5e-4 #5e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 16 + +# ES parameters +es: + algo: OpenES # [OpenES, CMA_ES, SimpleGA] + sigma_init: 0.04 # Initial scale of isotropic Gaussian noise + sigma_decay: 0.999 # Multiplicative decay factor + sigma_limit: 0.01 # Smallest possible scale + init_min: 0.0 # Range of parameter mean initialization - Min + init_max: 0.0 # Range of parameter mean initialization - Max + clip_min: -1e10 # Range of parameter proposals - Min + clip_max: 1e10 # Range of parameter proposals - Max + lrate_init: 0.1 # Initial learning rate + lrate_decay: 0.9999 # Multiplicative decay factor + lrate_limit: 0.001 # Smallest possible lrate + beta_1: 0.99 # Adam - beta_1 + beta_2: 0.999 # Adam - beta_2 + eps: 1e-8 # eps constant, + elite_ratio: 0.1 + +# Logging setup +wandb: + entity: "ucl-dark" + project: act + group: 'GS-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: False \ No newline at end of file diff --git a/pax/conf/experiment/act/random.yaml b/pax/conf/experiment/act/random.yaml new file mode 100644 index 00000000..e2d1172f --- /dev/null +++ b/pax/conf/experiment/act/random.yaml @@ -0,0 +1,63 @@ +# @package _global_ + +# Agents +agent1: 'PPO' +agent2: 'PPO' + +# Environment +env_id: CartPole-v1 +env_type: meta +egocentric: True +env_discount: 0.96 +payoff: [[1, 1, -2], [1, 1, -2]] + +# Save +save: True +save_interval: 100 +adversary_type: ally # adversary, random adversary, zeroes adversary + +# Runner +runner: act_rl +output_dim: 0 + +# total popsize = popsize * num_devices +num_envs: 8 +num_opps: 1 +num_steps: 200 +num_inner_steps: 200 +total_timesteps: 1.0e6 + + +# num_generations = total_timesteps / (num_envs*num_opps*num_steps) +# PPO agent parameters +ppo: + num_minibatches: 4 + num_epochs: 4 + gamma: 0.99 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + 
clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.01 + entropy_coeff_horizon: 1000000 + entropy_coeff_end: 0.001 + lr_scheduling: True + learning_rate: 5e-4 #5e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 16 + + +# Logging setup +wandb: + entity: "ucl-dark" + project: act + group: 'GS-${agent1}-vs-${agent2}' + name: run-seed-${seed} + log: False \ No newline at end of file diff --git a/pax/conf/experiment/sarl/acrobot.yaml b/pax/conf/experiment/sarl/acrobot.yaml index dfb010b8..ac5ee0c8 100644 --- a/pax/conf/experiment/sarl/acrobot.yaml +++ b/pax/conf/experiment/sarl/acrobot.yaml @@ -16,7 +16,7 @@ runner: sarl # env_batch_size = num_envs * num_opponents num_envs: 50 num_steps: 128 # 500 Cartpole -total_timesteps: 1e9 +total_timesteps: 8e7 save_interval: 100 # Evaluation @@ -72,7 +72,7 @@ ppo: # Logging setup wandb: entity: "ucl-dark" - project: synq + project: sarl group: 'sanity-${agent1}-vs-${agent2}-parity' name: run-seed-${seed} log: False \ No newline at end of file diff --git a/pax/conf/experiment/sarl/acrobot_eval.yaml b/pax/conf/experiment/sarl/acrobot_eval.yaml new file mode 100644 index 00000000..d44dd2f0 --- /dev/null +++ b/pax/conf/experiment/sarl/acrobot_eval.yaml @@ -0,0 +1,78 @@ +# @package _global_ + +# Agents +agent1: 'PPO' + +# Environment +env_id: Acrobot-v1 +env_type: sequential +egocentric: True +env_discount: 0.96 +payoff: [[1, 1, -2], [1, 1, -2]] +runner: sarl_eval + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 1 +num_steps: 128 # 500 Cartpole +total_timesteps: 128 +save_interval: 100 + +# Evaluation +run_path: ucl-dark/sarl/2wnfkuoa +model_path: exp/sanity-PPO-vs-PPO-parity/run-seed-0/2022-11-17_12.28.11.590592/iteration_7800 + +# PPO agent parameters +# ppo: +# num_minibatches: 4 +# num_epochs: 2 +# gamma: 0.96 +# gae_lambda: 0.95 +# ppo_clipping_epsilon: 0.2 +# value_coeff: 0.5 +# clip_value: True +# max_gradient_norm: 0.5 +# anneal_entropy: False +# entropy_coeff_start: 0.1 +# entropy_coeff_horizon: 0.6e8 +# entropy_coeff_end: 0.03 +# lr_scheduling: False +# learning_rate: 1e-5 #0.05 +# adam_epsilon: 1e-5 +# with_memory: True +# with_cnn: False +# output_channels: 16 +# kernel_shape: [3, 3] +# separate: True # only works with CNN +# hidden_size: 16 #50 +ppo: + num_minibatches: 4 + num_epochs: 4 + gamma: 0.99 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.01 + entropy_coeff_horizon: 1000000 + entropy_coeff_end: 0.001 + lr_scheduling: True + learning_rate: 2.5e-4 #5e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 16 + +# Logging setup +wandb: + entity: "ucl-dark" + project: sarl + group: 'sanity-${agent1}-vs-${agent2}-parity' + name: run-seed-${seed} + log: False \ No newline at end of file diff --git a/pax/conf/experiment/sarl/cartpole.yaml b/pax/conf/experiment/sarl/cartpole.yaml index 9f233020..f68825ed 100644 --- a/pax/conf/experiment/sarl/cartpole.yaml +++ b/pax/conf/experiment/sarl/cartpole.yaml @@ -15,8 +15,8 @@ runner: sarl # env_batch_size = num_envs * num_opponents num_envs: 8 -num_steps: 500 # 500 Cartpole -total_timesteps: 1e9 +num_steps: 200 # 500 Cartpole +total_timesteps: 1e6 save_interval: 100 # Evaluation @@ -62,7 +62,7 @@ ppo: # Logging setup wandb: entity: "ucl-dark" - 
project: synq + project: sarl group: 'sanity-${agent1}-vs-${agent2}-parity' name: run-seed-${seed} log: False \ No newline at end of file diff --git a/pax/conf/experiment/sarl/cartpole_eval.yaml b/pax/conf/experiment/sarl/cartpole_eval.yaml new file mode 100644 index 00000000..afd4ae63 --- /dev/null +++ b/pax/conf/experiment/sarl/cartpole_eval.yaml @@ -0,0 +1,70 @@ +# @package _global_ + +# Agents +agent1: 'PPO' + +# Environment +env_id: CartPole-v1 +env_type: sequential +egocentric: True +env_discount: 0.96 +payoff: [[1, 1, -2], [1, 1, -2]] +runner: sarl_eval + +# Training hyperparameters + +# env_batch_size = num_envs * num_opponents +num_envs: 1 +num_steps: 500 # 500 Cartpole +total_timesteps: 500 +save_interval: 100 + +# Evaluation +# run_path: ucl-dark/cg/3sp0y2cy +# model_path: exp/coin_game-PPO_memory-vs-PPO_memory-parity/run-seed-0/2022-09-12_11.21.52.633382/iteration_74900 +run_path: ucl-dark/sarl/1nby54fs +model_path: exp/sanity-PPO-vs-PPO-parity/run-seed-0/2022-11-17_16.00.47.035029/iteration_900 + +# PPO agent parameters +ppo: + num_minibatches: 4 + num_epochs: 4 + gamma: 0.99 + gae_lambda: 0.95 + ppo_clipping_epsilon: 0.2 + value_coeff: 0.5 + clip_value: True + max_gradient_norm: 0.5 + anneal_entropy: True + entropy_coeff_start: 0.01 + entropy_coeff_horizon: 1000000 + entropy_coeff_end: 0.001 + lr_scheduling: True + learning_rate: 5e-4 #5e-4 + adam_epsilon: 1e-5 + with_memory: True + with_cnn: False + output_channels: 16 + kernel_shape: [3, 3] + separate: True + hidden_size: 16 + +# naive: +# num_minibatches: 1 +# num_epochs: 1 +# gamma: 0.96 +# gae_lambda: 0.95 +# max_gradient_norm: 1.0 +# learning_rate: 1.0 +# adam_epsilon: 1e-5 +# entropy_coeff: 0 + + + +# Logging setup +wandb: + entity: "ucl-dark" + project: sarl + group: 'sanity-${agent1}-vs-${agent2}-parity' + name: run-seed-${seed} + log: False \ No newline at end of file diff --git a/pax/experiment.py b/pax/experiment.py index a72fc289..04ab907a 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -16,6 +16,7 @@ from pax.agents.naive_exact import NaiveExact from pax.agents.ppo.ppo import make_agent from pax.agents.ppo.ppo_gru import make_gru_agent +from pax.agents.act.act_agent import make_act_agent from pax.agents.strategies import ( Altruistic, Defect, @@ -27,6 +28,7 @@ HyperTFT, Random, RandomGreedy, + RandomACT, Stay, TitForTat, ) @@ -40,6 +42,9 @@ from pax.runner_evo import EvoRunner from pax.runner_marl import RLRunner from pax.runner_sarl import SARLRunner +from pax.runner_sarl_eval import SARLEvalRunner +from pax.runner_act_evo import ActEvoRunner +from pax.runner_act_rl import ActRLRunner from pax.utils import Section from pax.watchers import ( logger_hyper, @@ -134,7 +139,9 @@ def env_setup(args, logger=None): logger.info( f"Env Type: CoinGame | Episode Length: {args.num_steps}" ) - elif args.runner == "sarl": + elif args.runner in ["sarl", "sarl_eval"]: + env, env_params = gymnax.make(args.env_id) + elif args.runner == "act_evo" or args.runner == "act_rl": env, env_params = gymnax.make(args.env_id) else: raise ValueError(f"Unknown env id {args.env_id}") @@ -146,7 +153,7 @@ def runner_setup(args, env, agents, save_dir, logger): logger.info("Evaluating with EvalRunner") return EvalRunner(agents, env, args) - if args.runner == "evo": + if args.runner == "evo" or args.runner == "act_evo": agent1, _ = agents algo = args.es.algo strategies = {"CMA_ES", "OpenES", "PGPE", "SimpleGA"} @@ -227,10 +234,26 @@ def get_pgpe_strategy(agent): strategy, es_params, param_reshaper = get_ga_strategy(agent1) 
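
# Editor's sketch: the bodies of get_ga_strategy / get_pgpe_strategy live
# outside this diff, so the following OpenES analogue is an assumption about
# their shape, based on how strategy / es_params / param_reshaper are used by
# ActEvoRunner later in this patch.
from evosax import OpenES, ParameterReshaper

def get_open_es_strategy(agent, popsize):
    param_reshaper = ParameterReshaper(agent._state.params)
    strategy = OpenES(popsize=popsize, num_dims=param_reshaper.total_params)
    es_params = strategy.default_params
    return strategy, es_params, param_reshaper
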
logger.info(f"Evolution Strategy: {algo}") - - return EvoRunner( - agents, env, strategy, es_params, param_reshaper, save_dir, args - ) + if args.runner == "evo": + return EvoRunner( + agents, + env, + strategy, + es_params, + param_reshaper, + save_dir, + args, + ) + elif args.runner == "act_evo": + return ActEvoRunner( + agents, + env, + strategy, + es_params, + param_reshaper, + save_dir, + args, + ) elif args.runner == "rl": logger.info("Training with RL Runner") @@ -238,6 +261,12 @@ def get_pgpe_strategy(agent): elif args.runner == "sarl": logger.info("Training with SARL Runner") return SARLRunner(agents, env, save_dir, args) + elif args.runner == "sarl_eval": + logger.info("Evaluating with SARLEval Runner") + return SARLEvalRunner(agents, env, save_dir, args) + elif args.runner == "act_rl": + logger.info("Training with RL Runner") + return ActRLRunner(agents, env, save_dir, args) else: raise ValueError(f"Unknown runner type {args.runner}") @@ -251,8 +280,9 @@ def agent_setup(args, env, env_params, logger): else: obs_shape = env.observation_space(env_params).shape num_actions = env.num_actions + print(num_actions) - def get_PPO_memory_agent(seed, player_id): + def get_PPO_memory_agent(seed, player_id, obs_shape=obs_shape): ppo_memory_agent = make_gru_agent( args, obs_spec=obs_shape, @@ -262,7 +292,7 @@ def get_PPO_memory_agent(seed, player_id): ) return ppo_memory_agent - def get_PPO_agent(seed, player_id): + def get_PPO_agent(seed, player_id, obs_shape=obs_shape): ppo_agent = make_agent( args, obs_spec=obs_shape, @@ -272,7 +302,7 @@ def get_PPO_agent(seed, player_id): ) return ppo_agent - def get_PPO_tabular_agent(seed, player_id): + def get_PPO_tabular_agent(seed, player_id, obs_shape=obs_shape): ppo_agent = make_agent( args, obs_spec=obs_shape, @@ -283,7 +313,7 @@ def get_PPO_tabular_agent(seed, player_id): ) return ppo_agent - def get_mfos_agent(seed, player_id): + def get_mfos_agent(seed, player_id, obs_shape=obs_shape): ppo_agent = make_mfos_agent( args, obs_spec=obs_shape, @@ -293,7 +323,7 @@ def get_mfos_agent(seed, player_id): ) return ppo_agent - def get_hyper_agent(seed, player_id): + def get_hyper_agent(seed, player_id, obs_shape=obs_shape): hyper_agent = make_hyper( args, obs_spec=obs_shape, @@ -303,7 +333,7 @@ def get_hyper_agent(seed, player_id): ) return hyper_agent - def get_naive_pg(seed, player_id): + def get_naive_pg(seed, player_id, obs_shape=obs_shape): naive_agent = make_naive_pg( args, obs_spec=obs_shape, @@ -328,17 +358,33 @@ def get_random_agent(seed, player_id): random_agent.player_id = player_id return random_agent + def get_random_act_agent(seed, player_id): + random_agent = RandomACT(args.output_dim, args.num_envs) + random_agent.player_id = player_id + return random_agent + # flake8: noqa: C901 def get_stay_agent(seed, player_id): agent = Stay(num_actions, args.num_envs) agent.player_id = player_id return agent + def get_ACT_agent(seed, player_id, obs_shape=obs_shape): + act_agent = make_act_agent( + args, + obs_spec=obs_shape, + action_spec=args.output_dim, + seed=seed, + player_id=player_id, + ) + return act_agent + strategies = { "TitForTat": partial(TitForTat, args.num_envs), "Defect": partial(Defect, args.num_envs), "Altruistic": partial(Altruistic, args.num_envs), "Random": get_random_agent, + "RandomACT": get_random_act_agent, "Stay": get_stay_agent, "Grim": partial(GrimTrigger, args.num_envs), "GoodGreedy": partial(GoodGreedy, args.num_envs), @@ -355,9 +401,10 @@ def get_stay_agent(seed, player_id): "HyperAltruistic": partial(HyperAltruistic, 
args.num_envs), "HyperDefect": partial(HyperDefect, args.num_envs), "HyperTFT": partial(HyperTFT, args.num_envs), + "ACT": get_ACT_agent, } - if args.runner == "sarl": + if args.runner in ["sarl", "sarl_eval"]: assert args.agent1 in strategies num_agents = 1 seeds = [args.seed] @@ -370,12 +417,15 @@ def get_stay_agent(seed, player_id): logger.info(f"Agent Pair: {args.agent1}") logger.info(f"Agent seeds: {seeds[0]}") - if args.runner in ["eval", "sarl"]: + if args.runner in ["eval", "sarl", "sarl_eval"]: logger.info("Using Independent Learners") return agent_1 else: assert args.agent1 in strategies assert args.agent2 in strategies + if args.runner == "act_evo": + if args.agent1 != "ACT": + raise NotImplementedError num_agents = 2 seeds = [seed for seed in range(args.seed, args.seed + num_agents)] @@ -384,18 +434,24 @@ def get_stay_agent(seed, player_id): seed % seed + i if seed != 0 else 1 for seed, i in zip(seeds, range(1, num_agents + 1)) ] - agent_0 = strategies[args.agent1](seeds[0], pids[0]) # player 1 - agent_1 = strategies[args.agent2](seeds[1], pids[1]) # player 2 + if args.runner == "act_evo" or args.runner=="act_rl": + agent_0 = strategies[args.agent1](seeds[0], pids[0]) # player 1 + agent_1 = strategies[args.agent2]( + seeds[1], pids[1], obs_shape=(obs_shape[0] + args.output_dim,) + ) # player 2 + else: + agent_0 = strategies[args.agent1](seeds[0], pids[0]) # player 1 + agent_1 = strategies[args.agent2](seeds[1], pids[1]) # player 2 if args.agent1 in ["PPO", "PPO_memory"] and args.ppo.with_cnn: logger.info(f"PPO with CNN: {args.ppo.with_cnn}") logger.info(f"Agent Pair: {args.agent1} | {args.agent2}") logger.info(f"Agent seeds: {seeds[0]} | {seeds[1]}") - if args.runner in ["eval", "rl"]: + if args.runner in ["eval", "rl", "act_rl"]: logger.info("Using Independent Learners") return (agent_0, agent_1) - if args.runner == "evo": + if args.runner == "evo" or args.runner == "act_evo": logger.info("Using EvolutionaryLearners") return (agent_0, agent_1) @@ -459,6 +515,7 @@ def naive_pg_log(agent): "Altruistic": dumb_log, "Human": dumb_log, "Random": dumb_log, + "RandomACT": dumb_log, "Stay": dumb_log, "Grim": dumb_log, "GoodGreedy": dumb_log, @@ -476,9 +533,10 @@ def naive_pg_log(agent): "Tabular": ppo_log, "PPO_memory_pretrained": ppo_memory_log, "MFOS_pretrained": dumb_log, + "ACT": dumb_log, } - if args.runner == "sarl": + if args.runner in ["sarl", "sarl_eval"]: assert args.agent1 in strategies agent_1_log = naive_pg_log # strategies[args.agent1] # @@ -535,7 +593,7 @@ def main(args): print(f"Number of Episodes: {num_iters}") runner.run_loop(env, env_params, agent_pair, num_iters, watchers) - elif args.runner == "eval": + elif args.runner == "sarl_eval": num_iters = int( args.total_timesteps / args.num_steps ) # number of episodes @@ -549,6 +607,18 @@ def main(args): print(f"Number of Episodes: {num_iters}") runner.run_loop(env, env_params, agent_pair, num_iters, watchers) + elif args.runner == "act_evo": + num_iters = args.num_generations # number of generations + print(f"Number of Generations: {num_iters}") + runner.run_loop(env_params, agent_pair, num_iters, watchers) + + elif args.runner == "act_rl": + num_iters = int( + args.total_timesteps / args.num_steps + ) # number of episodes + print(f"Number of Episodes: {num_iters}") + runner.run_loop(env_params, agent_pair, num_iters, watchers) + wandb.finish() diff --git a/pax/runner_act_evo.py b/pax/runner_act_evo.py new file mode 100644 index 00000000..bb42be1a --- /dev/null +++ b/pax/runner_act_evo.py @@ -0,0 +1,613 @@ +import os 
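
# Editor's note (worked numbers from the two ACT configs in this diff), for
# the num_iters dispatch in experiment.py above:
#   act_rl:  num_iters = total_timesteps / num_steps = 1.0e6 / 200 = 5000 episodes
#   act_evo: num_iters = num_generations = 3000 generations
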
+import time
+from datetime import datetime
+from typing import Any, Callable, NamedTuple
+
+import jax
+import jax.numpy as jnp
+from evosax import FitnessShaper
+
+import wandb
+from pax.utils import MemoryState, TrainingState, save
+
+# TODO: import when evosax library is updated
+# from evosax.utils import ESLog
+from pax.watchers import ESLog, cg_visitation, ipd_visitation
+
+MAX_WANDB_CALLS = 1000
+
+
+class Sample(NamedTuple):
+    """Object containing a batch of data"""
+
+    observations: jnp.ndarray
+    actions: jnp.ndarray
+    rewards: jnp.ndarray
+    behavior_log_probs: jnp.ndarray
+    behavior_values: jnp.ndarray
+    dones: jnp.ndarray
+    hiddens: jnp.ndarray
+
+
+class ActEvoRunner:
+    """
+    Evolutionary strategy runner that provides a convenient example for
+    quickly writing a MARL runner for PAX. The ActEvoRunner class can be used
+    to run an RL agent (optimised by an evolutionary strategy) against a
+    reinforcement learner. It composes together agents, watchers, and the
+    environment. Within the init, we declare vmaps and pmaps for training.
+    The environment provided must conform to a meta-environment.
+    Args:
+        agents (Tuple[agents]):
+            The set of agents that will run in the experiment. Note, ordering is
+            important for logic used in the class.
+        env (gymnax.envs.Environment):
+            The meta-environment that the agents will run in.
+        strategy (evosax.Strategy):
+            The evolutionary strategy that will be used to train the agents.
+        param_reshaper (evosax.param_reshaper.ParameterReshaper):
+            A function that reshapes the parameters of the agents into a format
+            that can be used by the strategy.
+        save_dir (string):
+            The directory to save the model to.
+        args (NamedTuple):
+            A tuple of experiment arguments used (usually provided by HydraConfig).
+    """
+
+    def __init__(
+        self, agents, env, strategy, es_params, param_reshaper, save_dir, args
+    ):
+        self.args = args
+        self.algo = args.es.algo
+        self.es_params = es_params
+        self.generations = 0
+        self.num_opps = args.num_opps
+        self.param_reshaper = param_reshaper
+        self.popsize = args.popsize
+        self.random_key = jax.random.PRNGKey(args.seed)
+        self.start_datetime = datetime.now()
+        self.save_dir = save_dir
+        self.start_time = time.time()
+        self.strategy = strategy
+        self.top_k = args.top_k
+        self.train_steps = 0
+        self.train_episodes = 0
+        self.ipd_stats = jax.jit(ipd_visitation)
+        self.cg_stats = jax.jit(jax.vmap(cg_visitation))
+
+        # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs)
+        # Evo Runner also has an additional pmap dim (num_devices, ...)
+ # For the env we vmap over the rng but not params + + # num envs + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + + # num opps + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + # pop size + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + self.split = jax.vmap( + jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)), + (0, None), + ) + + num_outer_steps = ( + 1 + if self.args.env_type == "sequential" + else self.args.num_steps // self.args.num_inner_steps + ) + + agent1, agent2 = agents + + # vmap agents accordingly + # agent 1 is batched over popsize and num_opps + agent1.batch_init = jax.vmap( + jax.vmap( + agent1.make_initial_state, + (None, 0), # (params, rng) + (None, 0), # (TrainingState, MemoryState) + ), + # both for Population + ) + agent1.batch_reset = jax.jit( + jax.vmap( + jax.vmap(agent1.reset_memory, (0, None), 0), (0, None), 0 + ), + static_argnums=1, + ) + + agent1.batch_policy = jax.jit( + jax.vmap( + jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)), + ) + ) + + if args.agent2 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent2.batch_init = jax.jit( + jax.vmap(jax.vmap(agent2.make_initial_state)) + ) + else: + agent2.batch_init = jax.jit( + jax.vmap( + jax.vmap(agent2.make_initial_state, (0, None), 0), + (0, None), + 0, + ) + ) + + agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0))) + agent2.batch_reset = jax.jit( + jax.vmap( + jax.vmap(agent2.reset_memory, (0, None), 0), (0, None), 0 + ), + static_argnums=1, + ) + + agent2.batch_update = jax.jit( + jax.vmap( + jax.vmap(agent2.update, (1, 0, 0, 0)), + (1, 0, 0, 0), + ) + ) + if args.agent2 != "NaiveEx": + # NaiveEx requires env first step to init. 
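
# Editor's sketch (illustrative sizes): after the three stacked vmaps above,
# one reset call consumes rngs shaped [popsize, num_opps, num_envs, 2] and
# returns observations with the same three leading batch dims.
import jax
import jax.numpy as jnp

def toy_reset(rng, params):
    # stand-in for gymnax env.reset: a dummy 4-dim observation, no state
    return jax.random.uniform(rng, (4,)), None

toy_reset = jax.vmap(toy_reset, (0, None), 0)           # num_envs
toy_reset = jax.vmap(toy_reset, (0, None), 0)           # num_opps
toy_reset = jax.jit(jax.vmap(toy_reset, (0, None), 0))  # popsize

rngs = jax.random.split(jax.random.PRNGKey(0), 8 * 1 * 2).reshape(8, 1, 2, 2)
obs, _ = toy_reset(rngs, None)
assert obs.shape == (8, 1, 2, 4)  # [popsize, num_opps, num_envs, obs_dim]
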
+ init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) + + key = jax.random.split( + agent2._state.random_key, args.popsize * args.num_opps + ).reshape(args.popsize, args.num_opps, -1) + + agent2._state, agent2._mem = agent2.batch_init( + key, + init_hidden, + ) + + def _inner_rollout(carry, unused): + """Runner for inner episode""" + ( + rngs, + obs1, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = carry + + # unpack rngs + rngs = self.split(rngs, 4) + env_rng = rngs[:, :, :, 0, :] + # a1_rng = rngs[:, :, :, 1, :] + # a2_rng = rngs[:, :, :, 2, :] + rngs = rngs[:, :, :, 3, :] + + a1, a1_state, new_a1_mem = agent1.batch_policy( + a1_state, + obs1, + a1_mem, + ) + + obs2 = jnp.concatenate([obs1, a1], axis=-1) + + a2, a2_state, new_a2_mem = agent2.batch_policy( + a2_state, + obs2, + a2_mem, + ) + + next_obs, env_state, rewards, done, info = env.step( + env_rng, + env_state, + a2, + env_params, + ) + + traj1 = Sample( + obs1, + a1, + rewards * jnp.logical_not(done), + new_a1_mem.extras["log_probs"], + new_a1_mem.extras["values"], + done, + a1_mem.hidden, + ) + traj2 = Sample( + obs2, + a2, + rewards * jnp.logical_not(done), + new_a2_mem.extras["log_probs"], + new_a2_mem.extras["values"], + done, + a2_mem.hidden, + ) + + obs1 = next_obs + r1 = rewards + r2 = rewards + + return ( + rngs, + obs1, + r1, + r2, + a1_state, + new_a1_mem, + a2_state, + new_a2_mem, + env_state, + env_params, + ), ( + traj1, + traj2, + ) + + def _outer_rollout(carry, unused): + """Runner for trial""" + # play episode of the game + vals, trajectories = jax.lax.scan( + _inner_rollout, + carry, + None, + length=args.num_inner_steps, + ) + ( + rngs, + obs1, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = vals + # MFOS has to take a meta-action for each episode + if args.agent1 == "MFOS": + a1_mem = agent1.meta_policy(a1_mem) + + a1, a1_state, a1_mem = agent1.batch_policy( + a1_state, + obs1, + a1_mem, + ) + + obs2 = jnp.concatenate([obs1, a1], axis=-1) + # update second agent + a2_state, a2_mem, a2_metrics = agent2.batch_update( + trajectories[1], + obs2, + a2_state, + a2_mem, + ) + return ( + rngs, + obs1, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ), (*trajectories, a2_metrics) + + def _rollout( + _params: jnp.ndarray, + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _a1_mem: MemoryState, + _env_params: Any, + ): + # env reset + rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] + * args.num_opps + * args.popsize + ).reshape((args.popsize, args.num_opps, args.num_envs, -1)) + + obs1, env_state = env.reset(rngs, _env_params) + rewards = [ + jnp.zeros((args.popsize, args.num_opps, args.num_envs)), + jnp.zeros((args.popsize, args.num_opps, args.num_envs)), + ] + + # Player 1 + _a1_state = _a1_state._replace(params=_params) + _a1_mem = agent1.batch_reset(_a1_mem, False) + + # Player 2 + if args.agent2 == "NaiveEx": + a2_state, a2_mem = agent2.batch_init(obs1) + + else: + # meta-experiments - init 2nd agent per trial + a2_state, a2_mem = agent2.batch_init( + jax.random.split( + _rng_run, args.popsize * args.num_opps + ).reshape(args.popsize, args.num_opps, -1), + agent2._mem.hidden, + ) + + # run trials + vals, stack = jax.lax.scan( + _outer_rollout, + ( + rngs, + obs1, + *rewards, + _a1_state, + _a1_mem, + a2_state, + a2_mem, + env_state, + _env_params, + ), + None, + length=num_outer_steps, + ) + + ( + rngs, + obs1, + r1, + r2, + _a1_state, + _a1_mem, + a2_state, + 
a2_mem, + env_state, + _env_params, + ) = vals + traj_1, traj_2, a2_metrics = stack + + # Fitness + # fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4)) + # other_fitness = traj_2.rewards.mean(axis=(0, 1, 3, 4)) + if args.adversary_type == "ally": + fitness = traj_1.rewards.sum(axis=(0, 1, 3, 4))/(traj_1.dones.sum(axis=(0, 1, 3, 4))+1) + elif args.adversary_type == "adversary": + fitness = -(traj_1.rewards.sum(axis=(0, 1, 3, 4))/(traj_1.dones.sum(axis=(0, 1, 3, 4))+1)) + other_fitness = traj_2.rewards.sum(axis=(0, 1, 3, 4))/(traj_2.dones.sum(axis=(0, 1, 3, 4))+1) + + # fitness = jnp.sum(traj_1.rewards)/(jnp.sum(traj_1.dones)+1e-8) + # other_fitness = jnp.sum(traj_2.rewards)/(jnp.sum(traj_2.dones)+1e-8) + # Stats + if args.env_id == "coin_game": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cg_stats(env_state), + ) + + rewards_1 = traj_1.rewards.sum(axis=1).mean() + rewards_2 = traj_2.rewards.sum(axis=1).mean() + + elif args.env_id in [ + "matrix_game", + ]: + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipd_stats( + traj_1.observations, + traj_1.actions, + obs1, + ), + ) + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() + else: + env_stats = {} + rewards_1 = jnp.sum(traj_1.rewards)/(jnp.sum(traj_1.dones)+1) + rewards_2 = jnp.sum(traj_2.rewards)/(jnp.sum(traj_2.dones)+1) + return ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + ) + + self.rollout = jax.pmap( + _rollout, + in_axes=(0, None, None, None, None), + ) + + def run_loop( + self, + env_params, + agents, + num_generations: int, + watchers: Callable, + ): + """Run training of agents in environment""" + print("Training") + print("------------------------------") + log_interval = max(num_generations / MAX_WANDB_CALLS, 5) + print(f"Number of Generations: {num_generations}") + print(f"Number of Meta Episodes: {num_generations}") + print(f"Population Size: {self.popsize}") + print(f"Number of Environments: {self.args.num_envs}") + print(f"Number of Opponent: {self.args.num_opps}") + print(f"Log Interval: {log_interval}") + print("------------------------------") + # Initialize agents and RNG + agent1, agent2 = agents + rng, _ = jax.random.split(self.random_key) + + # Initialize evolution + num_gens = num_generations + strategy = self.strategy + es_params = self.es_params + param_reshaper = self.param_reshaper + popsize = self.popsize + num_opps = self.num_opps + evo_state = strategy.initialize(rng, es_params) + fit_shaper = FitnessShaper(maximize=True) + es_logging = ESLog( + param_reshaper.total_params, + num_gens, + top_k=self.top_k, + maximize=True, + ) + log = es_logging.initialize() + num_devices = self.args.num_devices + + # Reshape a single agent's params before vmapping + init_hidden = jnp.tile( + agent1._mem.hidden, + (popsize, num_opps, 1, 1), + ) + agent1._state, agent1._mem = agent1.batch_init( + jax.random.split(agent1._state.random_key, popsize), + init_hidden, + ) + + a1_state, a1_mem = agent1._state, agent1._mem + + for gen in range(num_gens): + rng, rng_run, rng_gen, rng_key = jax.random.split(rng, 4) + + # Ask + x, evo_state = strategy.ask(rng_gen, evo_state, es_params) + params = param_reshaper.reshape(x) + if num_devices == 1: + params = jax.tree_util.tree_map( + lambda x: jax.lax.expand_dims(x, (0,)), params + ) + # Evo Rollout + ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) + + # Reshape over devices + fitness = 
jnp.reshape(fitness, popsize * num_devices) + env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) + + # Maximize fitness + fitness_re = fit_shaper.apply(x, fitness) + + # Tell + evo_state = strategy.tell( + x, fitness_re - fitness_re.mean(), evo_state, es_params + ) + # Logging + log = es_logging.update(log, x, fitness) + + # Saving + if self.args.save and gen % self.args.save_interval == 0: + log_savepath = os.path.join(self.save_dir, f"generation_{gen}") + if num_devices > 1: + top_params = param_reshaper.reshape( + log["top_gen_params"][0 : self.args.num_devices] + ) + top_params = jax.tree_util.tree_map( + lambda x: x[0].reshape(x[0].shape[1:]), top_params + ) + else: + top_params = param_reshaper.reshape( + log["top_gen_params"][0:1] + ) + top_params = jax.tree_util.tree_map( + lambda x: x.reshape(x.shape[1:]), top_params + ) + save(top_params, log_savepath) + if watchers: + print(f"Saving generation {gen} locally and to WandB") + wandb.save(log_savepath) + else: + print(f"Saving iteration {gen} locally") + + if gen % log_interval == 0: + print(f"Generation: {gen}") + print( + "--------------------------------------------------------------------------" + ) + print( + f"Fitness: {fitness.mean()} | Other Fitness: {other_fitness.mean()}" + ) + print( + f"Total Episode Reward: {float(rewards_1.mean()), float(rewards_2.mean())}" + ) + print(f"Env Stats: {env_stats}") + print( + "--------------------------------------------------------------------------" + ) + print( + f"Top 5: Generation | Mean: {log['log_top_gen_mean'][gen]}" + f" | Std: {log['log_top_gen_std'][gen]}" + ) + print( + "--------------------------------------------------------------------------" + ) + print(f"Agent {1} | Fitness: {log['top_gen_fitness'][0]}") + print(f"Agent {2} | Fitness: {log['top_gen_fitness'][1]}") + print(f"Agent {3} | Fitness: {log['top_gen_fitness'][2]}") + print(f"Agent {4} | Fitness: {log['top_gen_fitness'][3]}") + print(f"Agent {5} | Fitness: {log['top_gen_fitness'][4]}") + print() + + if watchers: + wandb_log = { + "generations": gen, + "train/fitness/player_1": float(fitness.mean()), + "train/fitness/player_2": float(other_fitness.mean()), + "train/fitness/top_overall_mean": log["log_top_mean"][gen], + "train/fitness/top_overall_std": log["log_top_std"][gen], + "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], + "train/fitness/top_gen_std": log["log_top_gen_std"][gen], + "train/fitness/gen_std": log["log_gen_std"][gen], + "train/time/minutes": float( + (time.time() - self.start_time) / 60 + ), + "train/time/seconds": float( + (time.time() - self.start_time) + ), + "train/episode_reward/player_1": float(rewards_1.mean()), + "train/episode_reward/player_2": float(rewards_2.mean()), + } + wandb_log.update(env_stats) + # loop through population + for idx, (overall_fitness, gen_fitness) in enumerate( + zip(log["top_fitness"], log["top_gen_fitness"]) + ): + wandb_log[ + f"train/fitness/top_overall_agent_{idx+1}" + ] = overall_fitness + wandb_log[ + f"train/fitness/top_gen_agent_{idx+1}" + ] = gen_fitness + + # player 2 metrics + # metrics [outer_timesteps, num_opps] + flattened_metrics = jax.tree_util.tree_map( + lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics + ) + + agent2._logger.metrics.update(flattened_metrics) + for watcher, agent in zip(watchers, agents): + watcher(agent) + wandb.log(wandb_log) + + return agents diff --git a/pax/runner_act_rl.py b/pax/runner_act_rl.py new file mode 100644 index 00000000..60bf79e1 --- /dev/null +++ b/pax/runner_act_rl.py @@ -0,0 +1,563 @@ 
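
# Editor's sketch (not part of the diff): the fitness pipeline in the evo
# loop above. Raw fitness is a done-normalised return sum; FitnessShaper
# with maximize=True negates it for the minimising ES, and it is mean-centred
# again before strategy.tell.
import jax.numpy as jnp
from evosax import FitnessShaper

fit_shaper = FitnessShaper(maximize=True)
x = jnp.zeros((3, 5))                       # dummy population parameters
fitness = jnp.array([1.0, 3.0, 2.0])        # per-member average return
fitness_re = fit_shaper.apply(x, fitness)
centered = fitness_re - fitness_re.mean()   # as passed to strategy.tell
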
+import os +import time +from typing import Any, NamedTuple + +import jax +import jax.numpy as jnp + +import wandb +from pax.utils import MemoryState, TrainingState, save +from pax.watchers import cg_visitation, ipd_visitation + +MAX_WANDB_CALLS = 1000 + + +class Sample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + + +class MFOSSample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + meta_actions: jnp.ndarray + + +# @jax.jit +# def reduce_outer_traj(traj: Sample) -> Sample: +# """Used to collapse lax.scan outputs dims""" +# # x: [outer_loop, inner_loop, num_opps, num_envs ...] +# # x: [timestep, batch_size, ...] +# num_envs = traj.observations.shape[2] * traj.observations.shape[3] +# num_timesteps = traj.observations.shape[0] * traj.observations.shape[1] +# return jax.tree_util.tree_map( +# lambda x: x.reshape((num_timesteps, num_envs) + x.shape[6:]), +# traj, +# ) + +# @jax.jit +# def reduce_outer_traj_but_not_opp_dim_for_timon_debug(traj: Sample) -> Sample: +# """Used to collapse lax.scan outputs dims""" +# # x: [outer_loop, inner_loop, num_envs ...] +# # x: [timestep, batch_size, ...] +# num_envs = traj.observations.shape[2] +# num_timesteps = traj.observations.shape[0] * traj.observations.shape[1] +# return jax.tree_util.tree_map( +# lambda x: x.reshape((num_timesteps, num_envs) + x.shape[3:]), +# traj, +# ) + + +class ActRLRunner: + """ + Reinforcement Learning runner provides a convenient example for quickly writing + a MARL runner for PAX. The MARLRunner class can be used to + run any two RL agents together either in a meta-game or regular game, it composes together agents, + watchers, and the environment. Within the init, we declare vmaps and pmaps for training. + Args: + agents (Tuple[agents]): + The set of agents that will run in the experiment. Note, ordering is + important for logic used in the class. + env (gymnax.envs.Environment): + The environment that the agents will run in. + save_dir (string): + The directory to save the model to. + args (NamedTuple): + A tuple of experiment arguments used (usually provided by HydraConfig). + """ + + def __init__(self, agents, env, save_dir, args): + self.train_steps = 0 + self.train_episodes = 0 + self.start_time = time.time() + self.args = args + self.num_opps = args.num_opps + self.random_key = jax.random.PRNGKey(args.seed) + self.save_dir = save_dir + + def _reshape_opp_dim(x): + # x: [num_opps, num_envs ...] + # x: [batch_size, ...] 
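+            # e.g. with num_opps = 2 and num_envs = 16, a leaf of shape
+            # (2, 16, obs_dim) is folded to (32, obs_dim), so downstream
+            # code can treat opponents and environments as one batch axis.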
+ batch_size = args.num_envs * args.num_opps + return jax.tree_util.tree_map( + lambda x: x.reshape((batch_size,)+ x.shape[2:]), x + ) + + # self.reduce_opp_dim = jax.jit(_reshape_opp_dim) + self.ipd_stats = jax.jit(ipd_visitation) + self.cg_stats = jax.jit(cg_visitation) + # VMAP for num envs: we vmap over the rng but not params + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.jit(jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + )) + + # VMAP for num opps: we vmap over the rng but not params + # env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + # env.step = jax.jit( + # jax.vmap( + # env.step, (0, 0, 0, None), 0 # rng, state, actions, params + # ) + # ) + + # self.split = jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)) + self.split = jax.vmap(jax.random.split, (0, None)) + # self.split = jax.random.split + num_outer_steps = ( + 1 + if self.args.env_type == "sequential" + else self.args.num_steps // self.args.num_inner_steps + ) + + agent1, agent2 = agents + + # set up agents + if args.agent1 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent1.batch_init = jax.jit(jax.vmap(agent1.make_initial_state)) + else: + # batch MemoryState not TrainingState + # agent1.batch_init = jax.vmap( + # agent1.make_initial_state, + # (None, 0), + # (None, 0), + # ) + agent1.batch_init = jax.jit(agent1.make_initial_state) + # agent1.batch_reset = jax.jit( + # jax.vmap(agent1.reset_memory, (0, None), 0), static_argnums=1 + # ) + + # agent1.batch_policy = jax.jit( + # jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)) + # ) + + # removed opps dim + agent1.batch_reset = jax.jit( + agent1.reset_memory, static_argnums=1) + + agent1.batch_policy = jax.jit( + agent1._policy) + + + # batch all for Agent2 + if args.agent2 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent2.batch_init = jax.jit(jax.vmap(agent2.make_initial_state)) + else: + agent2.batch_init = jax.vmap( + agent2.make_initial_state, (0, None), 0 + ) + # agent2.batch_init = jax.jit( + # agent2.make_initial_state + # ) + # agent2.batch_policy = jax.jit(jax.vmap(agent2._policy)) + agent2.batch_reset = jax.jit( + jax.vmap(agent2.reset_memory, (0, None), 0), static_argnums=1 + ) + agent2.batch_update = jax.jit(jax.vmap(agent2.update, (1, 0, 0, 0), 0)) + agent2.batch_policy = jax.jit(agent2._policy) + # agent2.batch_reset = jax.jit(agent2.reset_memory, static_argnums=1 + # ) + # agent2.batch_update = agent2.update + + if args.agent1 != "NaiveEx": + # NaiveEx requires env first step to init. + # init_hidden = jnp.tile(agent1._mem.hidden, (args.num_opps, 1, 1)) + init_hidden = jnp.tile(agent1._mem.hidden, (1)) + agent1._state, agent1._mem = agent1.batch_init( + agent1._state.random_key, init_hidden + ) + + if args.agent2 != "NaiveEx": + # NaiveEx requires env first step to init. 
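+            # agent 2 is re-initialised with one PRNG key per opponent
+            # (batch_init vmaps over the key axis), while the hidden
+            # template (borrowed from agent 1 here, presumably
+            # shape-identical) is shared across opponents.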
+ # init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1)) + init_hidden = jnp.tile(agent1._mem.hidden, (1)) + agent2._state, agent2._mem = agent2.batch_init( + jax.random.split(agent2._state.random_key, args.num_opps), + init_hidden, + ) + + def _inner_rollout(carry, unused): + """Runner for inner episode""" + ( + rngs, + obs, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = carry + + # unpack rngs + rngs = self.split(rngs, 2) + env_rng = rngs[:, 0, :] + # a1_rng = rngs[:, :, 1, :] + # a2_rng = rngs[:, :, 2, :] + rngs = rngs[:, 1, :] + a2, a2_state, new_a2_mem = agent2.batch_policy( + a2_state, + obs, + a2_mem, + ) + + # obs2 = jnp.concatenate([obs1, a1], axis=-1) + # obs2 = obs1 + a1 = a2 + new_a1_mem = a1_mem + # a2, a2_state, new_a2_mem = agent2.batch_policy( + # a2_state, + # obs2, + # a2_mem, + # ) + next_obs, env_state, rewards, done, info = env.step( + env_rng, + env_state, + a2, + env_params, + ) + + traj1 = Sample( + obs, + a1, + rewards * jnp.logical_not(done), + new_a1_mem.extras["log_probs"], + new_a1_mem.extras["values"], + done, + a1_mem.hidden, + ) + traj2 = Sample( + obs, + a2, + rewards * jnp.logical_not(done), + new_a2_mem.extras["log_probs"], + new_a2_mem.extras["values"], + done, + a2_mem.hidden, + ) + + obs = next_obs + # obs2 = next_obs + r1 = rewards + r2 = rewards + + return ( + rngs, + obs, + obs, + r1, + r2, + a1_state, + new_a1_mem, + a2_state, + new_a2_mem, + env_state, + env_params, + ), ( + traj1, + traj2, + ) + + def _outer_rollout(carry, unused): + """Runner for trial""" + # play episode of the game + vals, trajectories = jax.lax.scan( + _inner_rollout, + carry, + None, + length=self.args.num_inner_steps, + ) + ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = vals + + # a1, a1_state, a1_mem = agent1.batch_policy( + # a1_state, + # obs1, + # a1_mem, + # ) + + # obs2 = jnp.concatenate([obs1, a1], axis=-1) + + # update second agent + # jax.debug.breakpoint() + # traj_2.rewards = traj_2.rewards.squeeze(1) + a2_state, a2_mem, a2_metrics = agent2.batch_update( + trajectories[1], + obs2, + a2_state, + a2_mem, + ) + # a2_metrics = {} + return ( + rngs, + obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ), (*trajectories, a2_metrics) + + def _rollout( + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _a1_mem: MemoryState, + _a2_state: TrainingState, + _a2_mem: MemoryState, + _env_params: Any, + ): + # env reset + rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] * args.num_opps + ).reshape((args.num_envs, -1)) #args.num_opps, + + obs, env_state = env.reset(rngs, _env_params) + obs1 = obs + obs2 = obs + + rewards = [ + jnp.zeros((args.num_envs)), #args.num_opps, + jnp.zeros((args.num_envs)), #args.num_opps, + ] + # Player 1 + _a1_mem = agent1.batch_reset(_a1_mem, False) + + # Player 2 + if args.agent1 == "NaiveEx": + _a1_state, _a1_mem = agent1.batch_init(obs[0]) + + if args.agent2 == "NaiveEx": + _a2_state, _a2_mem = agent2.batch_init(obs[1]) + + elif self.args.env_type in ["meta"]: + # meta-experiments - init 2nd agent per trial + _a2_state, _a2_mem = agent2.batch_init( + jax.random.split(_rng_run), _a2_mem.hidden #, args.num_opps + ) + # run trials + vals, stack = jax.lax.scan( + _outer_rollout, + ( + rngs, + obs1, + obs2, + *rewards, + _a1_state, + _a1_mem, + _a2_state, + _a2_mem, + env_state, + _env_params, + ), + None, + length=num_outer_steps, + ) + + ( + rngs, + 
obs1, + obs2, + r1, + r2, + a1_state, + a1_mem, + a2_state, + a2_mem, + env_state, + env_params, + ) = vals + traj_1, traj_2, a2_metrics = stack + + # update outer agent + # print(reduce_outer_traj(traj_1).dones.shape) + # print(traj_1.dones.shape) + traj_1 = jax.tree_util.tree_map( + lambda x: x.squeeze(0), traj_1 + ) + # print(a1_mem.extras['values'].shape) + # print(self.reduce_opp_dim(a1_mem).extras['values'].shape) + # print(traj_1.dones.shape) + # a1_state, _, a1_metrics = agent1.update( + # reduce_outer_traj(traj_1), + # self.reduce_opp_dim(obs1), + # a1_state, + # self.reduce_opp_dim(a1_mem), + # ) + + a1_state, _, a1_metrics = agent1.update( + traj_1, + obs1, + a1_state, + a1_mem, + ) + + # reset memory + a1_mem = agent1.batch_reset(a1_mem, False) + a2_mem = agent2.batch_reset(a2_mem, False) + + # Stats + if args.env_id == "coin_game": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cg_stats(env_state), + ) + + rewards_1 = traj_1.rewards.sum(axis=1).mean() + rewards_2 = traj_2.rewards.sum(axis=1).mean() + + elif args.env_id == "iterated_matrix_game": + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipd_stats( + traj_1.observations, + traj_1.actions, + obs1, + ), + ) + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() + else: + env_stats = {} + # rewards_1 = traj_1.rewards.mean() + # rewards_2 = traj_2.rewards.mean() + rewards_1 = (traj_1.rewards.sum(axis=0)/(traj_1.dones.sum(axis=0)+1)).mean() + rewards_2 = (traj_2.rewards[-1].sum(axis=0)/(traj_2.dones[-1].sum(axis=0)+1)).mean() + traj_1.dones[-1].sum(axis=0).mean() + + return ( + env_stats, + rewards_1, + rewards_2, + a1_state, + a1_mem, + a1_metrics, + a2_state, + a2_mem, + a2_metrics, + ) + + self.rollout = _rollout + # self.rollout = jax.jit(_rollout) + + def run_loop(self, env_params, agents, num_iters, watchers): + """Run training of agents in environment""" + print("Training") + print("-----------------------") + agent1, agent2 = agents + rng, _ = jax.random.split(self.random_key) + + a1_state, a1_mem = agent1._state, agent1._mem + a2_state, a2_mem = agent2._state, agent2._mem + + num_iters = max( + int(num_iters / (self.args.num_envs * self.num_opps)), 1 + ) + log_interval = max(num_iters / MAX_WANDB_CALLS, 5) + + print(f"Log Interval {log_interval}") + + # run actual loop + for i in range(num_iters): + rng, rng_run = jax.random.split(rng, 2) + # RL Rollout + ( + env_stats, + rewards_1, + rewards_2, + a1_state, + a1_mem, + a1_metrics, + a2_state, + a2_mem, + a2_metrics, + ) = self.rollout( + rng_run, a1_state, a1_mem, a2_state, a2_mem, env_params + ) + + if self.args.save and i % self.args.save_interval == 0: + log_savepath = os.path.join(self.save_dir, f"iteration_{i}") + save(a1_state.params, log_savepath) + if watchers: + print(f"Saving iteration {i} locally and to WandB") + wandb.save(log_savepath) + else: + print(f"Saving iteration {i} locally") + + # logging + self.train_episodes += 1 + if i % log_interval == 0: + print(f"Episode {i}") + + print(f"Env Stats: {env_stats}") + print( + f"Total Episode Reward: {float(rewards_1.mean()), float(rewards_2.mean())}" + ) + print() + + if watchers: + # metrics [outer_timesteps] + flattened_metrics_1 = jax.tree_util.tree_map( + lambda x: jnp.mean(x), a1_metrics + ) + agent1._logger.metrics = ( + agent1._logger.metrics | flattened_metrics_1 + ) + # metrics [outer_timesteps, num_opps] + flattened_metrics_2 = jax.tree_util.tree_map( + lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics + ) + agent2._logger.metrics = ( + 
agent2._logger.metrics | flattened_metrics_2
+                )
+
+                for watcher, agent in zip(watchers, agents):
+                    watcher(agent)
+                wandb.log(
+                    {
+                        "episodes": self.train_episodes,
+                        "train/episode_reward/player_1": float(
+                            rewards_1.mean()
+                        ),
+                        "train/episode_reward/player_2": float(
+                            rewards_2.mean()
+                        ),
+                    }
+                    | env_stats,
+                )
+
+        agents[0]._state = a1_state
+        agents[1]._state = a2_state
+        return agents
diff --git a/pax/runner_act_testtime.py b/pax/runner_act_testtime.py
new file mode 100644
index 00000000..871fd035
--- /dev/null
+++ b/pax/runner_act_testtime.py
@@ -0,0 +1,806 @@
+import os
+import time
+from datetime import datetime
+from typing import Any, Callable, NamedTuple
+
+import jax
+import jax.numpy as jnp
+from evosax import FitnessShaper
+
+import wandb
+from pax.utils import MemoryState, TrainingState, save
+
+# TODO: import when evosax library is updated
+# from evosax.utils import ESLog
+from pax.watchers import ESLog, cg_visitation, ipd_visitation
+
+MAX_WANDB_CALLS = 1000
+
+
+class Sample(NamedTuple):
+    """Object containing a batch of data"""
+
+    observations: jnp.ndarray
+    actions: jnp.ndarray
+    rewards: jnp.ndarray
+    behavior_log_probs: jnp.ndarray
+    behavior_values: jnp.ndarray
+    dones: jnp.ndarray
+    hiddens: jnp.ndarray
+
+
+class ActRunner:
+    """
+    Evolutionary Strategy runner provides a convenient example for quickly writing
+    a MARL runner for PAX. The EvoRunner class can be used to
+    run an RL agent (optimised by an Evolutionary Strategy) against a Reinforcement Learner.
+    It composes together agents, watchers, and the environment.
+    Within the init, we declare vmaps and pmaps for training.
+    The environment provided must conform to a meta-environment.
+    Args:
+        agents (Tuple[agents]):
+            The set of agents that will run in the experiment. Note, ordering is
+            important for logic used in the class.
+        env (gymnax.envs.Environment):
+            The meta-environment that the agents will run in.
+        strategy (evosax.Strategy):
+            The evolutionary strategy that will be used to train the agents.
+        param_reshaper (evosax.param_reshaper.ParameterReshaper):
+            A function that reshapes the parameters of the agents into a format that can be
+            used by the strategy.
+        save_dir (string):
+            The directory to save the model to.
+        args (NamedTuple):
+            A tuple of experiment arguments used (usually provided by HydraConfig).
+    """
+
+    def __init__(
+        self, agents, env, strategy, es_params, param_reshaper, save_dir, args
+    ):
+        self.args = args
+        self.algo = args.es.algo
+        self.es_params = es_params
+        self.generations = 0
+        self.num_opps = args.num_opps
+        self.param_reshaper = param_reshaper
+        self.popsize = args.popsize
+        self.random_key = jax.random.PRNGKey(args.seed)
+        self.start_datetime = datetime.now()
+        self.save_dir = save_dir
+        self.start_time = time.time()
+        self.strategy = strategy
+        self.top_k = args.top_k
+        self.train_steps = 0
+        self.train_episodes = 0
+        self.ipd_stats = jax.jit(ipd_visitation)
+        self.cg_stats = jax.jit(jax.vmap(cg_visitation))
+
+        # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs)
+        # Evo Runner also has an additional pmap dim (num_devices, ...)
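+        # Shape sketch (assuming the axis ordering set up below): rngs are
+        # carried as [popsize, num_opps, num_envs, 2], so every leaf that
+        # flows through the rollout keeps those three leading batch dims,
+        # and jax.pmap later adds a num_devices axis in front.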
+        # For the env we vmap over the rng but not params
+
+        # num envs
+        env.reset = jax.vmap(env.reset, (0, None), 0)
+        env.step = jax.vmap(
+            env.step, (0, 0, 0, None), 0  # rng, state, actions, params
+        )
+
+        # num opps
+        env.reset = jax.vmap(env.reset, (0, None), 0)
+        env.step = jax.vmap(
+            env.step, (0, 0, 0, None), 0  # rng, state, actions, params
+        )
+        # pop size
+        env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0))
+        env.step = jax.jit(
+            jax.vmap(
+                env.step, (0, 0, 0, None), 0  # rng, state, actions, params
+            )
+        )
+        self.split = jax.vmap(
+            jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)),
+            (0, None),
+        )
+
+        num_outer_steps = (
+            1
+            if self.args.env_type == "sequential"
+            else self.args.num_steps // self.args.num_inner_steps
+        )
+
+        agent1, agent2 = agents
+
+        # vmap agents accordingly
+        # agent 1 is batched over popsize and num_opps
+        agent1.batch_init = jax.vmap(
+            jax.vmap(
+                agent1.make_initial_state,
+                (None, 0),  # (params, rng)
+                (None, 0),  # (TrainingState, MemoryState)
+            ),
+            # both for Population
+        )
+        agent1.batch_reset = jax.jit(
+            jax.vmap(
+                jax.vmap(agent1.reset_memory, (0, None), 0), (0, None), 0
+            ),
+            static_argnums=1,
+        )
+
+        agent1.batch_policy = jax.jit(
+            jax.vmap(
+                jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)),
+            )
+        )
+
+        if args.agent2 == "NaiveEx":
+            # special case where NaiveEx has a different call signature
+            agent2.batch_init = jax.jit(
+                jax.vmap(jax.vmap(agent2.make_initial_state))
+            )
+        else:
+            agent2.batch_init = jax.jit(
+                jax.vmap(
+                    jax.vmap(agent2.make_initial_state, (0, None), 0),
+                    (0, None),
+                    0,
+                )
+            )
+
+        agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0)))
+        agent2.batch_reset = jax.jit(
+            jax.vmap(
+                jax.vmap(agent2.reset_memory, (0, None), 0), (0, None), 0
+            ),
+            static_argnums=1,
+        )
+
+        agent2.batch_update = jax.jit(
+            jax.vmap(
+                jax.vmap(agent2.update, (1, 0, 0, 0)),
+                (1, 0, 0, 0),
+            )
+        )
+        if args.agent2 != "NaiveEx":
+            # NaiveEx requires env first step to init.
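+            # batch_init is double-vmapped (popsize, then num_opps), so the
+            # keys are split once per (population member, opponent) pair and
+            # reshaped to [popsize, num_opps, 2] to match; the hidden
+            # template is tiled along a new num_opps axis.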
+            init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1))
+
+            key = jax.random.split(
+                agent2._state.random_key, args.popsize * args.num_opps
+            ).reshape(args.popsize, args.num_opps, -1)
+
+            agent2._state, agent2._mem = agent2.batch_init(
+                key,
+                init_hidden,
+            )
+
+        # the train and the test phases of this runner share one rollout
+        # pair; _rollout below scans the same _outer_rollout for both
+        def _inner_rollout(carry, unused):
+            """Runner for inner episode"""
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                env_params,
+            ) = carry
+
+            # unpack rngs
+            rngs = self.split(rngs, 4)
+            env_rng = rngs[:, :, :, 0, :]
+            # a1_rng = rngs[:, :, :, 1, :]
+            # a2_rng = rngs[:, :, :, 2, :]
+            rngs = rngs[:, :, :, 3, :]
+
+            a1, a1_state, new_a1_mem = agent1.batch_policy(
+                a1_state,
+                obs1,
+                a1_mem,
+            )
+
+            obs2 = jnp.concatenate([obs1, a1], axis=-1)
+
+            a2, a2_state, new_a2_mem = agent2.batch_policy(
+                a2_state,
+                obs2,
+                a2_mem,
+            )
+
+            next_obs, env_state, rewards, done, info = env.step(
+                env_rng,
+                env_state,
+                a2,
+                env_params,
+            )
+
+            traj1 = Sample(
+                obs1,
+                a1,
+                rewards,
+                new_a1_mem.extras["log_probs"],
+                new_a1_mem.extras["values"],
+                done,
+                a1_mem.hidden,
+            )
+            traj2 = Sample(
+                obs2,
+                a2,
+                rewards * jnp.logical_not(done),
+                new_a2_mem.extras["log_probs"],
+                new_a2_mem.extras["values"],
+                done,
+                a2_mem.hidden,
+            )
+
+            obs1 = next_obs
+            r1 = rewards
+            r2 = rewards
+
+            return (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                new_a1_mem,
+                a2_state,
+                new_a2_mem,
+                env_state,
+                env_params,
+            ), (
+                traj1,
+                traj2,
+            )
+
+        def _outer_rollout(carry, unused):
+            """Runner for trial"""
+            # play episode of the game
+            vals, trajectories = jax.lax.scan(
+                _inner_rollout,
+                carry,
+                None,
+                length=args.num_inner_steps,
+            )
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                env_params,
+            ) = vals
+            # MFOS has to take a meta-action for each episode
+            if args.agent1 == "MFOS":
+                a1_mem = agent1.meta_policy(a1_mem)
+
+            a1, a1_state, a1_mem = agent1.batch_policy(
+                a1_state,
+                obs1,
+                a1_mem,
+            )
+
+            obs2 = jnp.concatenate([obs1, a1], axis=-1)
+            # update second agent
+            a2_state, a2_mem, a2_metrics = agent2.batch_update(
+                trajectories[1],
+                obs2,
+                a2_state,
+                a2_mem,
+            )
+            return (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                env_params,
+            ), (*trajectories, a2_metrics)
+
+        def _rollout(
+            _params: jnp.ndarray,
+            _rng_run: jnp.ndarray,
+            _a1_state: TrainingState,
+            _a1_mem: MemoryState,
+            _env_params: Any,
+        ):
+            # env reset
+            rngs = jnp.concatenate(
+                [jax.random.split(_rng_run, args.num_envs)]
+                * args.num_opps
+                * args.popsize
+            ).reshape((args.popsize, args.num_opps, args.num_envs, -1))
+
+            obs1, env_state = env.reset(rngs, _env_params)
+            rewards = [
+                jnp.zeros((args.popsize, args.num_opps, args.num_envs)),
+                jnp.zeros((args.popsize, args.num_opps, args.num_envs)),
+            ]
+
+            # Player 1
+            _a1_state = _a1_state._replace(params=_params)
+            _a1_mem = agent1.batch_reset(_a1_mem, False)
+
+            # Player 2
+            if args.agent2 == "NaiveEx":
+                a2_state, a2_mem = agent2.batch_init(obs1)
+
+            else:
+                # meta-experiments - init 2nd agent per trial
+                a2_state, a2_mem = agent2.batch_init(
+                    jax.random.split(
+                        _rng_run, args.popsize * args.num_opps
+                    ).reshape(args.popsize, args.num_opps, -1),
+                    agent2._mem.hidden,
+                )
+
+            # Train Episodes
+            vals, stack = jax.lax.scan(
+                _outer_rollout,
+                (
+                    rngs,
+                    obs1,
+                    *rewards,
+                    _a1_state,
+                    _a1_mem,
+                    a2_state,
+                    a2_mem,
+                    env_state,
+                    _env_params,
+                ),
+                None,
+                length=num_outer_steps,
+            )
+
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                _a1_state,
+                _a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                _env_params,
+            ) = vals
+            traj_1, traj_2, a2_metrics = stack
+
+            # Test Episodes
+            # Reset Environment and Returns for Test Episodes
+            obs1, env_state = env.reset(rngs, _env_params)
+            rewards = [
+                jnp.zeros((args.popsize, args.num_opps, args.num_envs)),
+                jnp.zeros((args.popsize, args.num_opps, args.num_envs)),
+            ]
+
+            # Reset Shaper's memory
+            _a1_mem = agent1.batch_reset(_a1_mem, False)
+
+            # Reinitialize Player 2 for Test Episodes
+            # Player 2
+            if args.agent2 == "NaiveEx":
+                a2_state, a2_mem = agent2.batch_init(obs1)
+
+            else:
+                # meta-experiments - init 2nd agent per trial
+                a2_state, a2_mem = agent2.batch_init(
+                    jax.random.split(
+                        _rng_run, args.popsize * args.num_opps
+                    ).reshape(args.popsize, args.num_opps, -1),
+                    agent2._mem.hidden,
+                )
+
+            vals, stack = jax.lax.scan(
+                _outer_rollout,
+                (
+                    rngs,
+                    obs1,
+                    *rewards,
+                    _a1_state,
+                    _a1_mem,
+                    a2_state,
+                    a2_mem,
+                    env_state,
+                    _env_params,
+                ),
+                None,
+                length=1,
+            )
+
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                _a1_state,
+                _a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                _env_params,
+            ) = vals
+            traj_1, traj_2, a2_metrics = stack
+
+            # Fitness
+            # fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4))
+            # other_fitness = traj_2.rewards.mean(axis=(0, 1, 3, 4))
+            if args.adversary_type == "ally":
+                fitness = traj_1.rewards.sum(axis=(0, 1, 3, 4))/(traj_1.dones.sum(axis=(0, 1, 3, 4))+1e-8)
+            elif args.adversary_type == "adversary":
+                fitness = -(traj_1.rewards.sum(axis=(0, 1, 3, 4))/(traj_1.dones.sum(axis=(0, 1, 3, 4))+1e-8))
other_fitness = traj_2.rewards.sum(axis=(0, 1, 3, 4))/(traj_2.dones.sum(axis=(0, 1, 3, 4))+1e-8) + # fitness = jnp.sum(traj_1.rewards)/(jnp.sum(traj_1.dones)+1e-8) + # other_fitness = jnp.sum(traj_2.rewards)/(jnp.sum(traj_2.dones)+1e-8) + # Stats + if args.env_id == "coin_game": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cg_stats(env_state), + ) + + rewards_1 = traj_1.rewards.sum(axis=1).mean() + rewards_2 = traj_2.rewards.sum(axis=1).mean() + + elif args.env_id in [ + "matrix_game", + ]: + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipd_stats( + traj_1.observations, + traj_1.actions, + obs1, + ), + ) + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() + else: + env_stats = {} + rewards_1 = jnp.sum(traj_1.rewards)/(jnp.sum(traj_1.dones)+1e-8) + rewards_2 = jnp.sum(traj_2.rewards)/(jnp.sum(traj_2.dones)+1e-8) + return ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + ) + + self.rollout = jax.pmap( + _rollout, + in_axes=(0, None, None, None, None), + ) + + def run_loop( + self, + env_params, + agents, + num_generations: int, + watchers: Callable, + ): + """Run training of agents in environment""" + print("Training") + print("------------------------------") + log_interval = max(num_generations / MAX_WANDB_CALLS, 5) + print(f"Number of Generations: {num_generations}") + print(f"Number of Meta Episodes: {num_generations}") + print(f"Population Size: {self.popsize}") + print(f"Number of Environments: {self.args.num_envs}") + print(f"Number of Opponent: {self.args.num_opps}") + print(f"Log Interval: {log_interval}") + print("------------------------------") + # Initialize agents and RNG + agent1, agent2 = agents + rng, _ = jax.random.split(self.random_key) + + # Initialize evolution + num_gens = num_generations + strategy = self.strategy + es_params = self.es_params + param_reshaper = self.param_reshaper + popsize = self.popsize + num_opps = self.num_opps + evo_state = strategy.initialize(rng, es_params) + fit_shaper = FitnessShaper(maximize=True) + es_logging = ESLog( + param_reshaper.total_params, + num_gens, + top_k=self.top_k, + maximize=True, + ) + log = es_logging.initialize() + num_devices = self.args.num_devices + + # Reshape a single agent's params before vmapping + init_hidden = jnp.tile( + agent1._mem.hidden, + (popsize, num_opps, 1, 1), + ) + agent1._state, agent1._mem = agent1.batch_init( + jax.random.split(agent1._state.random_key, popsize), + init_hidden, + ) + + a1_state, a1_mem = agent1._state, agent1._mem + + for gen in range(num_gens): + rng, rng_run, rng_gen, rng_key = jax.random.split(rng, 4) + + # Ask + x, evo_state = strategy.ask(rng_gen, evo_state, es_params) + params = param_reshaper.reshape(x) + if num_devices == 1: + params = jax.tree_util.tree_map( + lambda x: jax.lax.expand_dims(x, (0,)), params + ) + # Evo Train and Test Rollout + ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) + + # Reshape over devices + fitness = jnp.reshape(fitness, popsize * num_devices) + env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) + + # Maximize fitness + fitness_re = fit_shaper.apply(x, fitness) + + # Tell + evo_state = strategy.tell( + x, fitness_re - fitness_re.mean(), evo_state, es_params + ) + # Logging + log = es_logging.update(log, x, fitness) + + # Saving + if self.args.save and gen % self.args.save_interval == 0: + log_savepath = os.path.join(self.save_dir, 
f"generation_{gen}") + if num_devices > 1: + top_params = param_reshaper.reshape( + log["top_gen_params"][0 : self.args.num_devices] + ) + top_params = jax.tree_util.tree_map( + lambda x: x[0].reshape(x[0].shape[1:]), top_params + ) + else: + top_params = param_reshaper.reshape( + log["top_gen_params"][0:1] + ) + top_params = jax.tree_util.tree_map( + lambda x: x.reshape(x.shape[1:]), top_params + ) + save(top_params, log_savepath) + if watchers: + print(f"Saving generation {gen} locally and to WandB") + wandb.save(log_savepath) + else: + print(f"Saving iteration {gen} locally") + + if gen % log_interval == 0: + print(f"Generation: {gen}") + print( + "--------------------------------------------------------------------------" + ) + print( + f"Fitness: {fitness.mean()} | Other Fitness: {other_fitness.mean()}" + ) + print( + f"Total Episode Reward: {float(rewards_1.mean()), float(rewards_2.mean())}" + ) + print(f"Env Stats: {env_stats}") + print( + "--------------------------------------------------------------------------" + ) + print( + f"Top 5: Generation | Mean: {log['log_top_gen_mean'][gen]}" + f" | Std: {log['log_top_gen_std'][gen]}" + ) + print( + "--------------------------------------------------------------------------" + ) + print(f"Agent {1} | Fitness: {log['top_gen_fitness'][0]}") + print(f"Agent {2} | Fitness: {log['top_gen_fitness'][1]}") + print(f"Agent {3} | Fitness: {log['top_gen_fitness'][2]}") + print(f"Agent {4} | Fitness: {log['top_gen_fitness'][3]}") + print(f"Agent {5} | Fitness: {log['top_gen_fitness'][4]}") + print() + + if watchers: + wandb_log = { + "generations": gen, + "train/fitness/player_1": float(fitness.mean()), + "train/fitness/player_2": float(other_fitness.mean()), + "train/fitness/top_overall_mean": log["log_top_mean"][gen], + "train/fitness/top_overall_std": log["log_top_std"][gen], + "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], + "train/fitness/top_gen_std": log["log_top_gen_std"][gen], + "train/fitness/gen_std": log["log_gen_std"][gen], + "train/time/minutes": float( + (time.time() - self.start_time) / 60 + ), + "train/time/seconds": float( + (time.time() - self.start_time) + ), + "train/episode_reward/player_1": float(rewards_1.mean()), + "train/episode_reward/player_2": float(rewards_2.mean()), + } + wandb_log.update(env_stats) + # loop through population + for idx, (overall_fitness, gen_fitness) in enumerate( + zip(log["top_fitness"], log["top_gen_fitness"]) + ): + wandb_log[ + f"train/fitness/top_overall_agent_{idx+1}" + ] = overall_fitness + wandb_log[ + f"train/fitness/top_gen_agent_{idx+1}" + ] = gen_fitness + + # player 2 metrics + # metrics [outer_timesteps, num_opps] + flattened_metrics = jax.tree_util.tree_map( + lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics + ) + + agent2._logger.metrics.update(flattened_metrics) + for watcher, agent in zip(watchers, agents): + watcher(agent) + wandb.log(wandb_log) + + return agents diff --git a/pax/runner_sarl.py b/pax/runner_sarl.py index a05869e0..5152cba7 100644 --- a/pax/runner_sarl.py +++ b/pax/runner_sarl.py @@ -97,7 +97,6 @@ def _inner_rollout(carry, unused): a1, env_params, ) - traj1 = Sample( obs, a1, @@ -134,7 +133,7 @@ def _rollout( # run trials vals, traj = jax.lax.scan( _inner_rollout, - ( + ( rngs, obs, _a1_state, @@ -167,7 +166,7 @@ def _rollout( _a1_mem = agent.batch_reset(_a1_mem, False) # Stats - rewards = jnp.sum(traj.rewards)/(jnp.sum(traj.dones)+1e-8) + rewards = jnp.sum(traj.rewards,axis=0) / (jnp.sum(traj.dones,axis=0) + 1) env_stats = {} return ( @@ -218,9 +217,8 @@ 
def run_loop(self, env, env_params, agent, num_iters, watcher): # logging self.train_episodes += 1 - if num_iters % log_interval == 0: + if i % log_interval == 0: print(f"Episode {i}") - print(f"Env Stats: {env_stats}") print(f"Total Episode Reward: {float(rewards_1.mean())}") print() @@ -246,4 +244,4 @@ def run_loop(self, env, env_params, agent, num_iters, watcher): ) agent._state = a1_state - return agent \ No newline at end of file + return agent diff --git a/pax/runner_sarl_eval.py b/pax/runner_sarl_eval.py new file mode 100644 index 00000000..2c7fb2de --- /dev/null +++ b/pax/runner_sarl_eval.py @@ -0,0 +1,306 @@ +import os +import time +from typing import Any, NamedTuple +import sys + +import jax +import jax.numpy as jnp +import wandb + +from gymnax.visualize import Visualizer +from gymnax.environments.classic_control.cartpole import ( + EnvState as EnvStateCartPole, +) +from gymnax.environments.classic_control.acrobot import ( + EnvState as EnvStateAcrobot, +) + +from pax.watchers import cg_visitation, ipd_visitation +from pax.utils import MemoryState, TrainingState, save, load + +from jax.config import config + +MAX_WANDB_CALLS = 1000000 + + +class Sample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + env_state: jnp.ndarray + + +class SARLEvalRunner: + """Holds the runner's state.""" + + def __init__(self, agent, env, save_dir, args): + self.train_steps = 0 + self.train_episodes = 0 + self.start_time = time.time() + self.args = args + self.random_key = jax.random.PRNGKey(args.seed) + self.save_dir = save_dir + self.run_path = args.run_path + self.model_path = args.model_path + + # VMAP for num envs: we vmap over the rng but not params + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + + self.split = jax.vmap(jax.random.split, (0, None)) + # set up agent + if args.agent1 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent.batch_init = jax.jit(jax.vmap(agent.make_initial_state)) + else: + # batch MemoryState not TrainingState + agent.batch_init = jax.jit(agent.make_initial_state) + + agent.batch_reset = jax.jit(agent.reset_memory, static_argnums=1) + + agent.batch_policy = jax.jit(agent._policy) + + if args.agent1 != "NaiveEx": + # NaiveEx requires env first step to init. 
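+            # The eval runner reuses the training-style init: a fresh state
+            # is created here, and run_loop later overwrites its params with
+            # the checkpoint restored (when watchers are enabled) via
+            # wandb.restore and load.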
+ init_hidden = jnp.tile(agent._mem.hidden, (1)) + agent._state, agent._mem = agent.batch_init( + agent._state.random_key, init_hidden + ) + + def _inner_rollout(carry, unused): + """Runner for inner episode""" + ( + rngs, + obs, + a1_state, + a1_mem, + env_state, + env_params, + ) = carry + + # unpack rngs + # import pdb; pdb.set_trace() + rngs = self.split(rngs, 2) + env_rng = rngs[:, 0, :] + # a1_rng = rngs[:, 1, :] + # a2_rng = rngs[:, 2, :] + rngs = rngs[:, 1, :] + + a1, a1_state, new_a1_mem = agent.batch_policy( + a1_state, + obs, + a1_mem, + ) + + next_obs, env_state, rewards, done, info = env.step( + env_rng, + env_state, + a1, + env_params, + ) + + traj1 = Sample( + obs, + a1, + rewards * jnp.logical_not(done), + new_a1_mem.extras["log_probs"], + new_a1_mem.extras["values"], + done, + a1_mem.hidden, + env_state, + ) + + return ( + rngs, + next_obs, # next_obs + a1_state, + new_a1_mem, + env_state, + env_params, + ), traj1 + + def _rollout( + _rng_run: jnp.ndarray, + _a1_state: TrainingState, + _a1_mem: MemoryState, + _env_params: Any, + ): + # env reset + rngs = jnp.concatenate( + [jax.random.split(_rng_run, args.num_envs)] + ).reshape((args.num_envs, -1)) + + obs, env_state = env.reset(rngs, _env_params) + _a1_mem = agent.batch_reset(_a1_mem, False) + + # run trials + vals, traj = jax.lax.scan( + _inner_rollout, + ( + rngs, + obs, + _a1_state, + _a1_mem, + env_state, + _env_params, + ), + None, + length=args.num_steps, + ) + + ( + rngs, + obs, + _a1_state, + _a1_mem, + env_state, + env_params, + ) = vals + + # update outer agent + _a1_state, _, _a1_metrics = agent.update( + traj, + obs, + _a1_state, + _a1_mem, + ) + + # reset memory + _a1_mem = agent.batch_reset(_a1_mem, False) + + # Stats + rewards = jnp.sum(traj.rewards) / (jnp.sum(traj.dones) + 1e-8) + env_stats = {} + + return ( + env_stats, + rewards, + _a1_state, + _a1_mem, + _a1_metrics, + ), traj + + self.rollout = _rollout + # self.rollout = jax.jit(_rollout) + + def run_loop(self, env, env_params, agent, num_iters, watchers): + """Run training of agent in environment""" + print("Training") + print("-----------------------") + config.update("jax_disable_jit", True) + agent = agent + rng, _ = jax.random.split(self.random_key) + + a1_state, a1_mem = agent._state, agent._mem + + if watchers: + wandb.restore( + name=self.model_path, run_path=self.run_path, root=os.getcwd() + ) + pretrained_params = load(self.model_path) + a1_state = a1_state._replace(params=pretrained_params) + + num_iters = max(int(num_iters / (self.args.num_envs)), 1) + log_interval = max(num_iters / MAX_WANDB_CALLS, 5) + + print(f"Log Interval {log_interval}") + print(f"Running for total iterations: {num_iters}") + # run actual loop + for i in range(num_iters): + rng, rng_run = jax.random.split(rng, 2) + # RL Rollout + ( + env_stats, + rewards_1, + a1_state, + a1_mem, + a1_metrics, + ), traj = self.rollout(rng_run, a1_state, a1_mem, env_params) + if self.args.env_id == "CartPole-v1": + obs_traj = [ + EnvStateCartPole( + x=traj.env_state.x[i], + x_dot=traj.env_state.x_dot[i], + theta=traj.env_state.theta[i], + theta_dot=traj.env_state.theta_dot[i], + time=i, + ) + for i in range(self.args.num_steps) + ] + elif self.args.env_id == "Acrobot-v1": + obs_traj = [ + EnvStateAcrobot( + joint_angle1=traj.env_state.joint_angle1[i], + joint_angle2=traj.env_state.joint_angle2[i], + velocity_1=traj.env_state.velocity_1[i], + velocity_2=traj.env_state.velocity_2[i], + time=i, + ) + for i in range(self.args.num_steps) + ] + + vis = Visualizer( + env, env_params, 
obs_traj, jnp.cumsum(traj.rewards) + ) + vis.animate(f"pax/vis/{self.args.env_id}.gif") + + if self.args.save and i % self.args.save_interval == 0: + log_savepath = os.path.join(self.save_dir, f"iteration_{i}") + save(a1_state.params, log_savepath) + if watchers: + print(f"Saving iteration {i} locally and to WandB") + wandb.save(log_savepath) + + else: + print(f"Saving iteration {i} locally") + + # logging + self.train_episodes += 1 + if i % log_interval == 0: + print(f"Episode {i}") + + print(f"Env Stats: {env_stats}") + print(f"Total Episode Reward: {float(rewards_1.mean())}") + print() + + if watchers: + # metrics [outer_timesteps] + flattened_metrics_1 = jax.tree_util.tree_map( + lambda x: jnp.mean(x), a1_metrics + ) + agent._logger.metrics = ( + agent._logger.metrics | flattened_metrics_1 + ) + + watchers(agent) + wandb.log( + { + "episodes": self.train_episodes, + "train/episode_reward/player_1": float( + rewards_1.mean() + ), + } + | env_stats, + ) + wandb.log( + { + "video": wandb.Video( + f"pax/vis/{self.args.env_id}.gif", + fps=4, + format="gif", + ) + } + ) + + agent._state = a1_state + return agent diff --git a/pax/runner_synq.py b/pax/runner_synq.py new file mode 100644 index 00000000..c56d60d6 --- /dev/null +++ b/pax/runner_synq.py @@ -0,0 +1,814 @@ +import os +import time +from datetime import datetime +from typing import Any, Callable, NamedTuple + +import jax +import jax.numpy as jnp +from evosax import FitnessShaper + +import wandb +from pax.utils import MemoryState, TrainingState, save + +# TODO: import when evosax library is updated +# from evosax.utils import ESLog +from pax.watchers import ESLog, cg_visitation, ipd_visitation + +MAX_WANDB_CALLS = 1000 + +class Sample(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + +class SampleSynq(NamedTuple): + """Object containing a batch of data""" + + observations: jnp.ndarray + actions: jnp.ndarray + rewards: jnp.ndarray + behavior_log_probs: jnp.ndarray + behavior_values: jnp.ndarray + behavior_synq: jnp.ndarray + target_synq: jnp.ndarray + dones: jnp.ndarray + hiddens: jnp.ndarray + + +class ActRunner: + """ + Evoluationary Strategy runner provides a convenient example for quickly writing + a MARL runner for PAX. The EvoRunner class can be used to + run an RL agent (optimised by an Evolutionary Strategy) against an Reinforcement Learner. + It composes together agents, watchers, and the environment. + Within the init, we declare vmaps and pmaps for training. + The environment provided must conform to a meta-environment. + Args: + agents (Tuple[agents]): + The set of agents that will run in the experiment. Note, ordering is + important for logic used in the class. + env (gymnax.envs.Environment): + The meta-environment that the agents will run in. + strategy (evosax.Strategy): + The evolutionary strategy that will be used to train the agents. + param_reshaper (evosax.param_reshaper.ParameterReshaper): + A function that reshapes the parameters of the agents into a format that can be + used by the strategy. + save_dir (string): + The directory to save the model to. + args (NamedTuple): + A tuple of experiment arguments used (usually provided by HydraConfig). 
+ """ + + def __init__( + self, agents, env, strategy, es_params, param_reshaper, save_dir, args + ): + self.args = args + self.algo = args.es.algo + self.es_params = es_params + self.generations = 0 + self.num_opps = args.num_opps + self.param_reshaper = param_reshaper + self.popsize = args.popsize + self.random_key = jax.random.PRNGKey(args.seed) + self.start_datetime = datetime.now() + self.save_dir = save_dir + self.start_time = time.time() + self.strategy = strategy + self.top_k = args.top_k + self.train_steps = 0 + self.train_episodes = 0 + self.ipd_stats = jax.jit(ipd_visitation) + self.cg_stats = jax.jit(jax.vmap(cg_visitation)) + + # Evo Runner has 3 vmap dims (popsize, num_opps, num_envs) + # Evo Runner also has an additional pmap dim (num_devices, ...) + # For the env we vmap over the rng but not params + + # num envs + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + + # num opps + env.reset = jax.vmap(env.reset, (0, None), 0) + env.step = jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + # pop size + env.reset = jax.jit(jax.vmap(env.reset, (0, None), 0)) + env.step = jax.jit( + jax.vmap( + env.step, (0, 0, 0, None), 0 # rng, state, actions, params + ) + ) + self.split = jax.vmap( + jax.vmap(jax.vmap(jax.random.split, (0, None)), (0, None)), + (0, None), + ) + + num_outer_steps = ( + 1 + if self.args.env_type == "sequential" + else self.args.num_steps // self.args.num_inner_steps + ) + + agent1, agent2 = agents + + # vmap agents accordingly + # agent 1 is batched over popsize and num_opps + agent1.batch_init = jax.vmap( + jax.vmap( + agent1.make_initial_state, + (None, 0), # (params, rng) + (None, 0), # (TrainingState, MemoryState) + ), + # both for Population + ) + agent1.batch_reset = jax.jit( + jax.vmap( + jax.vmap(agent1.reset_memory, (0, None), 0), (0, None), 0 + ), + static_argnums=1, + ) + + agent1.batch_policy = jax.jit( + jax.vmap( + jax.vmap(agent1._policy, (None, 0, 0), (0, None, 0)), + ) + ) + + if args.agent2 == "NaiveEx": + # special case where NaiveEx has a different call signature + agent2.batch_init = jax.jit( + jax.vmap(jax.vmap(agent2.make_initial_state)) + ) + else: + agent2.batch_init = jax.jit( + jax.vmap( + jax.vmap(agent2.make_initial_state, (0, None), 0), + (0, None), + 0, + ) + ) + + agent2.batch_policy = jax.jit(jax.vmap(jax.vmap(agent2._policy, 0, 0))) + agent2.batch_synq_value = jax.jit(jax.vmap(jax.vmap(agent2._synq, 0, 0))) + agent2.batch_reset = jax.jit( + jax.vmap( + jax.vmap(agent2.reset_memory, (0, None), 0), (0, None), 0 + ), + static_argnums=1, + ) + + agent2.batch_update = jax.jit( + jax.vmap( + jax.vmap(agent2.update, (1, 0, 0, 0)), + (1, 0, 0, 0), + ) + ) + if args.agent2 != "NaiveEx": + # NaiveEx requires env first step to init. 
+            init_hidden = jnp.tile(agent2._mem.hidden, (args.num_opps, 1, 1))
+
+            key = jax.random.split(
+                agent2._state.random_key, args.popsize * args.num_opps
+            ).reshape(args.popsize, args.num_opps, -1)
+
+            agent2._state, agent2._mem = agent2.batch_init(
+                key,
+                init_hidden,
+            )
+
+        def _inner_rollout_train(carry, unused):
+            """Runner for inner episode"""
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                env_params,
+            ) = carry
+
+            # unpack rngs
+            rngs = self.split(rngs, 4)
+            env_rng = rngs[:, :, :, 0, :]
+            # a1_rng = rngs[:, :, :, 1, :]
+            # a2_rng = rngs[:, :, :, 2, :]
+            rngs = rngs[:, :, :, 3, :]
+
+            a1, a1_state, new_a1_mem = agent1.batch_policy(
+                a1_state,
+                obs1,
+                a1_mem,
+            )
+
+            obs2 = jnp.concatenate([obs1, a1], axis=-1)
+
+            a2, a2_state, new_a2_mem = agent2.batch_policy(
+                a2_state,
+                obs2,
+                a2_mem,
+            )
+
+            # the synq head predicts agent 1's action from the raw
+            # observation, so the same head can be queried at test time
+            # when a1 is not available
+            synq2, a2_state, new_a2_mem = agent2.batch_synq_value(
+                a2_state,
+                obs1,
+                new_a2_mem,
+            )
+
+            next_obs, env_state, rewards, done, info = env.step(
+                env_rng,
+                env_state,
+                a2,
+                env_params,
+            )
+
+            traj1 = Sample(
+                obs1,
+                a1,
+                rewards,
+                new_a1_mem.extras["log_probs"],
+                new_a1_mem.extras["values"],
+                done,
+                a1_mem.hidden,
+            )
+
+            traj2 = SampleSynq(
+                obs2,
+                a2,
+                rewards * jnp.logical_not(done),
+                new_a2_mem.extras["log_probs"],
+                new_a2_mem.extras["values"],
+                synq2,
+                a1,
+                done,
+                a2_mem.hidden,
+            )
+
+            obs1 = next_obs
+            r1 = rewards
+            r2 = rewards
+
+            return (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                new_a1_mem,
+                a2_state,
+                new_a2_mem,
+                env_state,
+                env_params,
+            ), (
+                traj1,
+                traj2,
+            )
+
+        def _outer_rollout_train(carry, unused):
+            """Runner for trial"""
+            # play episode of the game
+            vals, trajectories = jax.lax.scan(
+                _inner_rollout_train,
+                carry,
+                None,
+                length=args.num_inner_steps,
+            )
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                env_params,
+            ) = vals
+            # MFOS has to take a meta-action for each episode
+            if args.agent1 == "MFOS":
+                a1_mem = agent1.meta_policy(a1_mem)
+
+            a1, a1_state, a1_mem = agent1.batch_policy(
+                a1_state,
+                obs1,
+                a1_mem,
+            )
+
+            obs2 = jnp.concatenate([obs1, a1], axis=-1)
+            # update second agent
+            a2_state, a2_mem, a2_metrics = agent2.batch_update(
+                trajectories[1],
+                obs2,
+                a2_state,
+                a2_mem,
+            )
+            return (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                env_params,
+            ), (*trajectories, a2_metrics)
+
+        def _inner_rollout_test(carry, unused):
+            """Runner for inner episode"""
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                env_params,
+            ) = carry
+
+            # unpack rngs
+            rngs = self.split(rngs, 4)
+            env_rng = rngs[:, :, :, 0, :]
+            # a1_rng = rngs[:, :, :, 1, :]
+            # a2_rng = rngs[:, :, :, 2, :]
+            rngs = rngs[:, :, :, 3, :]
+
+            # at test time agent 1 does not act: agent 2's synq head
+            # predicts agent 1's action from the raw observation and the
+            # prediction is concatenated in its place
+            synq2, a2_state, new_a2_mem = agent2.batch_synq_value(
+                a2_state,
+                obs1,
+                a2_mem,
+            )
+
+            obs2 = jnp.concatenate([obs1, synq2], axis=-1)
+
+            a2, a2_state, new_a2_mem = agent2.batch_policy(
+                a2_state,
+                obs2,
+                new_a2_mem,
+            )
+
+            next_obs, env_state, rewards, done, info = env.step(
+                env_rng,
+                env_state,
+                a2,
+                env_params,
+            )
+
+            # agent 1 never acted, so its trajectory logs the synq
+            # prediction in place of an action and its memory is carried
+            # through unchanged
+            a1 = synq2
+            new_a1_mem = a1_mem
+            traj1 = Sample(
+                obs1,
+                a1,
+                rewards,
+                new_a1_mem.extras["log_probs"],
+                new_a1_mem.extras["values"],
+                done,
+                a1_mem.hidden,
+            )
+            traj2 = SampleSynq(
+                obs2,
+                a2,
+                rewards * jnp.logical_not(done),
+                new_a2_mem.extras["log_probs"],
+                new_a2_mem.extras["values"],
+                synq2,
+                synq2,
+                done,
+                a2_mem.hidden,
+            )
+
+            obs1 = next_obs
+            r1 = rewards
+            r2 = rewards
+
+            return (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                new_a1_mem,
+                a2_state,
+                new_a2_mem,
+                env_state,
+                env_params,
+            ), (
+                traj1,
+                traj2,
+            )
+
+        def _outer_rollout_test(carry, unused):
+            """Runner for trial"""
+            # play episode of the game
+            vals, trajectories = jax.lax.scan(
+                _inner_rollout_test,
+                carry,
+                None,
+                length=args.num_inner_steps,
+            )
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                env_params,
+            ) = vals
+            # MFOS has to take a meta-action for each episode
+            if args.agent1 == "MFOS":
+                a1_mem = agent1.meta_policy(a1_mem)
+
+            # no agent 2 update at test time, so only the trajectories are
+            # stacked
+            return (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                a1_state,
+                a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                env_params,
+            ), trajectories
+
+        def _rollout(
+            _params: jnp.ndarray,
+            _rng_run: jnp.ndarray,
+            _a1_state: TrainingState,
+            _a1_mem: MemoryState,
+            _env_params: Any,
+        ):
+            # env reset
+            rngs = jnp.concatenate(
+                [jax.random.split(_rng_run, args.num_envs)]
+                * args.num_opps
+                * args.popsize
+            ).reshape((args.popsize, args.num_opps, args.num_envs, -1))
+
+            obs1, env_state = env.reset(rngs, _env_params)
+            rewards = [
+                jnp.zeros((args.popsize, args.num_opps, args.num_envs)),
+                jnp.zeros((args.popsize, args.num_opps, args.num_envs)),
+            ]
+
+            # Player 1
+            _a1_state = _a1_state._replace(params=_params)
+            _a1_mem = agent1.batch_reset(_a1_mem, False)
+
+            # Player 2
+            if args.agent2 == "NaiveEx":
+                a2_state, a2_mem = agent2.batch_init(obs1)
+
+            else:
+                # meta-experiments - init 2nd agent per trial
+                a2_state, a2_mem = agent2.batch_init(
+                    jax.random.split(
+                        _rng_run, args.popsize * args.num_opps
+                    ).reshape(args.popsize, args.num_opps, -1),
+                    agent2._mem.hidden,
+                )
+
+            # Train Episodes
+            vals, stack = jax.lax.scan(
+                _outer_rollout_train,
+                (
+                    rngs,
+                    obs1,
+                    *rewards,
+                    _a1_state,
+                    _a1_mem,
+                    a2_state,
+                    a2_mem,
+                    env_state,
+                    _env_params,
+                ),
+                None,
+                length=num_outer_steps,
+            )
+
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                _a1_state,
+                _a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                _env_params,
+            ) = vals
+            traj_1, traj_2, a2_metrics = stack
+
+            # Test Episodes
+            # Reset Environment and Returns for Test Episodes
+            obs1, env_state = env.reset(rngs, _env_params)
+            rewards = [
+                jnp.zeros((args.popsize, args.num_opps, args.num_envs)),
+                jnp.zeros((args.popsize, args.num_opps, args.num_envs)),
+            ]
+
+            # Reset Shaper's memory
+            _a1_mem = agent1.batch_reset(_a1_mem, False)
+
+            # Reinitialize Player 2 for Test Episodes
+            # Player 2
+            if args.agent2 == "NaiveEx":
+                a2_state, a2_mem = agent2.batch_init(obs1)
+
+            else:
+                # meta-experiments - init 2nd agent per trial
+                a2_state, a2_mem = agent2.batch_init(
+                    jax.random.split(
+                        _rng_run, args.popsize * args.num_opps
+                    ).reshape(args.popsize, args.num_opps, -1),
+                    agent2._mem.hidden,
+                )
+
+            vals, stack = jax.lax.scan(
+                _outer_rollout_test,
+                (
+                    rngs,
+                    obs1,
+                    *rewards,
+                    _a1_state,
+                    _a1_mem,
+                    a2_state,
+                    a2_mem,
+                    env_state,
+                    _env_params,
+                ),
+                None,
+                length=1,
+            )
+
+            (
+                rngs,
+                obs1,
+                r1,
+                r2,
+                _a1_state,
+                _a1_mem,
+                a2_state,
+                a2_mem,
+                env_state,
+                _env_params,
+            ) = vals
+            # no agent 2 update happens at test time, so a2_metrics keeps
+            # the value collected during the train episodes above
+            traj_1, traj_2 = stack
+
+            # Fitness
+            # fitness = traj_1.rewards.mean(axis=(0, 1, 3, 4))
+            # other_fitness = traj_2.rewards.mean(axis=(0, 1, 3, 4))
+            if args.adversary_type == "ally":
+                fitness = traj_1.rewards.sum(axis=(0, 1, 3, 4))/(traj_1.dones.sum(axis=(0, 1, 3, 4))+1e-8)
+            elif args.adversary_type == "adversary":
+                fitness = -(traj_1.rewards.sum(axis=(0, 1, 3, 4))/(traj_1.dones.sum(axis=(0, 1, 3, 4))+1e-8))
+            other_fitness = traj_2.rewards.sum(axis=(0, 1, 3,
4))+1e-8) + # fitness = jnp.sum(traj_1.rewards)/(jnp.sum(traj_1.dones)+1e-8) + # other_fitness = jnp.sum(traj_2.rewards)/(jnp.sum(traj_2.dones)+1e-8) + # Stats + if args.env_id == "coin_game": + env_stats = jax.tree_util.tree_map( + lambda x: x, + self.cg_stats(env_state), + ) + + rewards_1 = traj_1.rewards.sum(axis=1).mean() + rewards_2 = traj_2.rewards.sum(axis=1).mean() + + elif args.env_id in [ + "matrix_game", + ]: + env_stats = jax.tree_util.tree_map( + lambda x: x.mean(), + self.ipd_stats( + traj_1.observations, + traj_1.actions, + obs1, + ), + ) + rewards_1 = traj_1.rewards.mean() + rewards_2 = traj_2.rewards.mean() + else: + env_stats = {} + rewards_1 = jnp.sum(traj_1.rewards)/(jnp.sum(traj_1.dones)+1e-8) + rewards_2 = jnp.sum(traj_2.rewards)/(jnp.sum(traj_2.dones)+1e-8) + return ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + ) + + self.rollout = jax.pmap( + _rollout, + in_axes=(0, None, None, None, None), + ) + + def run_loop( + self, + env_params, + agents, + num_generations: int, + watchers: Callable, + ): + """Run training of agents in environment""" + print("Training") + print("------------------------------") + log_interval = max(num_generations / MAX_WANDB_CALLS, 5) + print(f"Number of Generations: {num_generations}") + print(f"Number of Meta Episodes: {num_generations}") + print(f"Population Size: {self.popsize}") + print(f"Number of Environments: {self.args.num_envs}") + print(f"Number of Opponent: {self.args.num_opps}") + print(f"Log Interval: {log_interval}") + print("------------------------------") + # Initialize agents and RNG + agent1, agent2 = agents + rng, _ = jax.random.split(self.random_key) + + # Initialize evolution + num_gens = num_generations + strategy = self.strategy + es_params = self.es_params + param_reshaper = self.param_reshaper + popsize = self.popsize + num_opps = self.num_opps + evo_state = strategy.initialize(rng, es_params) + fit_shaper = FitnessShaper(maximize=True) + es_logging = ESLog( + param_reshaper.total_params, + num_gens, + top_k=self.top_k, + maximize=True, + ) + log = es_logging.initialize() + num_devices = self.args.num_devices + + # Reshape a single agent's params before vmapping + init_hidden = jnp.tile( + agent1._mem.hidden, + (popsize, num_opps, 1, 1), + ) + agent1._state, agent1._mem = agent1.batch_init( + jax.random.split(agent1._state.random_key, popsize), + init_hidden, + ) + + a1_state, a1_mem = agent1._state, agent1._mem + + for gen in range(num_gens): + rng, rng_run, rng_gen, rng_key = jax.random.split(rng, 4) + + # Ask + x, evo_state = strategy.ask(rng_gen, evo_state, es_params) + params = param_reshaper.reshape(x) + if num_devices == 1: + params = jax.tree_util.tree_map( + lambda x: jax.lax.expand_dims(x, (0,)), params + ) + # Evo Train and Test Rollout + ( + fitness, + other_fitness, + env_stats, + rewards_1, + rewards_2, + a2_metrics, + ) = self.rollout(params, rng_run, a1_state, a1_mem, env_params) + + # Reshape over devices + fitness = jnp.reshape(fitness, popsize * num_devices) + env_stats = jax.tree_util.tree_map(lambda x: x.mean(), env_stats) + + # Maximize fitness + fitness_re = fit_shaper.apply(x, fitness) + + # Tell + evo_state = strategy.tell( + x, fitness_re - fitness_re.mean(), evo_state, es_params + ) + # Logging + log = es_logging.update(log, x, fitness) + + # Saving + if self.args.save and gen % self.args.save_interval == 0: + log_savepath = os.path.join(self.save_dir, f"generation_{gen}") + if num_devices > 1: + top_params = param_reshaper.reshape( + 
log["top_gen_params"][0 : self.args.num_devices] + ) + top_params = jax.tree_util.tree_map( + lambda x: x[0].reshape(x[0].shape[1:]), top_params + ) + else: + top_params = param_reshaper.reshape( + log["top_gen_params"][0:1] + ) + top_params = jax.tree_util.tree_map( + lambda x: x.reshape(x.shape[1:]), top_params + ) + save(top_params, log_savepath) + if watchers: + print(f"Saving generation {gen} locally and to WandB") + wandb.save(log_savepath) + else: + print(f"Saving iteration {gen} locally") + + if gen % log_interval == 0: + print(f"Generation: {gen}") + print( + "--------------------------------------------------------------------------" + ) + print( + f"Fitness: {fitness.mean()} | Other Fitness: {other_fitness.mean()}" + ) + print( + f"Total Episode Reward: {float(rewards_1.mean()), float(rewards_2.mean())}" + ) + print(f"Env Stats: {env_stats}") + print( + "--------------------------------------------------------------------------" + ) + print( + f"Top 5: Generation | Mean: {log['log_top_gen_mean'][gen]}" + f" | Std: {log['log_top_gen_std'][gen]}" + ) + print( + "--------------------------------------------------------------------------" + ) + print(f"Agent {1} | Fitness: {log['top_gen_fitness'][0]}") + print(f"Agent {2} | Fitness: {log['top_gen_fitness'][1]}") + print(f"Agent {3} | Fitness: {log['top_gen_fitness'][2]}") + print(f"Agent {4} | Fitness: {log['top_gen_fitness'][3]}") + print(f"Agent {5} | Fitness: {log['top_gen_fitness'][4]}") + print() + + if watchers: + wandb_log = { + "generations": gen, + "train/fitness/player_1": float(fitness.mean()), + "train/fitness/player_2": float(other_fitness.mean()), + "train/fitness/top_overall_mean": log["log_top_mean"][gen], + "train/fitness/top_overall_std": log["log_top_std"][gen], + "train/fitness/top_gen_mean": log["log_top_gen_mean"][gen], + "train/fitness/top_gen_std": log["log_top_gen_std"][gen], + "train/fitness/gen_std": log["log_gen_std"][gen], + "train/time/minutes": float( + (time.time() - self.start_time) / 60 + ), + "train/time/seconds": float( + (time.time() - self.start_time) + ), + "train/episode_reward/player_1": float(rewards_1.mean()), + "train/episode_reward/player_2": float(rewards_2.mean()), + } + wandb_log.update(env_stats) + # loop through population + for idx, (overall_fitness, gen_fitness) in enumerate( + zip(log["top_fitness"], log["top_gen_fitness"]) + ): + wandb_log[ + f"train/fitness/top_overall_agent_{idx+1}" + ] = overall_fitness + wandb_log[ + f"train/fitness/top_gen_agent_{idx+1}" + ] = gen_fitness + + # player 2 metrics + # metrics [outer_timesteps, num_opps] + flattened_metrics = jax.tree_util.tree_map( + lambda x: jnp.sum(jnp.mean(x, 1)), a2_metrics + ) + + agent2._logger.metrics.update(flattened_metrics) + for watcher, agent in zip(watchers, agents): + watcher(agent) + wandb.log(wandb_log) + + return agents diff --git a/pax/utils.py b/pax/utils.py index 68f7233b..0c9ed1ef 100644 --- a/pax/utils.py +++ b/pax/utils.py @@ -101,6 +101,15 @@ class TrainingState(NamedTuple): random_key: jnp.ndarray timesteps: int +class TrainingStateSynq(NamedTuple): + """Training state consists of network parameters, optimiser state, random key, timesteps""" + + params: hk.Params + params_synq: hk.Params + opt_state: optax.GradientTransformation + random_key: jnp.ndarray + timesteps: int + class MemoryState(NamedTuple): """State consists of network extras (to be batched)"""