from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch.autograd import Variable
import torch.nn.functional as F

from ray.rllib.a3c.torchpolicy import TorchPolicy
from ray.rllib.models.pytorch.misc import var_to_np, convert_batch
from ray.rllib.models.catalog import ModelCatalog


class SharedTorchPolicy(TorchPolicy):
    """A torch policy whose action (logits) and value heads share one model.

    Assumes a non-recurrent model."""

    def __init__(self, ob_space, ac_space, **kwargs):
        super(SharedTorchPolicy, self).__init__(
            ob_space, ac_space, **kwargs)

    def _setup_graph(self, ob_space, ac_space):
        _, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
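        # A single shared model produces both the policy logits and the value
        # estimate; one Adam optimizer updates all of its parameters.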
        self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim)
        self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.0001)

    def compute_action(self, ob, *args):
        """Computes an action for a SINGLE observation (not a batch).

        Returns the sampled action and the value estimate as numpy arrays."""
        with self.lock:
            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
            logits, values = self._model(ob)
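            # Sample one action from the categorical distribution defined by
            # the policy logits.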
            samples = self._model.probs(logits).multinomial(1).squeeze()
            values = values.squeeze(0)
            return var_to_np(samples), var_to_np(values)

    def compute_logits(self, ob, *args):
        with self.lock:
            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
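            # Run only the shared hidden layers and the logits head; the value
            # branch is not evaluated here.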
            res = self._model.hidden_layers(ob)
            return var_to_np(self._model.logits(res))

    def value(self, ob, *args):
        with self.lock:
            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
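            # Run the shared hidden layers, then only the value branch.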
            res = self._model.hidden_layers(ob)
            res = self._model.value_branch(res)
            res = res.squeeze(0)
            return var_to_np(res)

    def _evaluate(self, obs, actions):
        """Evaluates a BATCH of observations and the actions taken on them.

        Returns the value estimates, the log-probabilities of the given
        actions, and the total entropy of the policy."""
        logits, values = self._model(obs)
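        # Per-action log-probabilities and probabilities for every observation
        # in the batch; dim 1 is the action dimension.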
        log_probs = F.log_softmax(logits, dim=1)
        probs = self._model.probs(logits)
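        # Log-probability of the action that was actually taken at each step.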
        action_log_probs = log_probs.gather(1, actions.view(-1, 1))
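        # Policy entropy summed over the batch; _backward uses it as an
        # exploration bonus in the loss.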
        entropy = -(log_probs * probs).sum(-1).sum()
        return values, action_log_probs, entropy

    def _backward(self, batch):
        """Computes gradients of the A3C loss for one sample batch.

        The loss is defined here; implementing a new loss function would
        start by rewriting this method."""

        states, acs, advs, rs, _ = convert_batch(batch)
        values, ac_logprobs, entropy = self._evaluate(states, acs)
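        # Policy-gradient loss weighted by the advantages, plus a squared-error
        # value loss against the return targets.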
        pi_err = -(advs * ac_logprobs).sum()
        value_err = 0.5 * (values - rs).pow(2).sum()

        self.optimizer.zero_grad()
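        # Combined objective: scaled value loss, policy loss, and an entropy
        # bonus that encourages exploration.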
        overall_err = 0.5 * value_err + pi_err - entropy * 0.01
        overall_err.backward()
        torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)
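        # Note that there is no optimizer.step() here: this method only
        # computes and clips the gradients; applying them is presumably
        # handled by the caller.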

    def get_initial_features(self):
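        # Non-recurrent model, so there is no recurrent state to initialize.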
        return [None]