Commit afdc873

[rllib] PyTorch Models for A3C (#1187)
* fixing policy
* Compute Action is singular, fixed weird issue with arrays
* remove vestige
* extraneous ipdb
* Can Drop in Pytorch Model
* lint
* introducing models
* fix base policy
* Missed this from last time
* lint
* removedolds
* getting vision working
* LINT
* trying to fix test dependencies
* requiremnets
* try
* tryconda
* yes
* shutup
* flake_passes
* changes
* removing weight initializer for lstm for now
* unused
* adam
* clip
* zero
* properscaling
* weight
* try
* fix up pytorch visionnet
* bias correction
* fix model
* same visionnet
* matching_bad_things
* test
* try locking
* fixing_linear
* naming
* lint
* FORJENKINS
* clouds
* lint
* Lint + removed dependencies
* removed dependencies
* format
1 parent 9a6a056 commit afdc873

16 files changed: 462 additions & 16 deletions

docker/examples/Dockerfile

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ FROM ray-project/deploy
 RUN conda install -y -c conda-forge tensorflow
 RUN apt-get install -y zlib1g-dev
 RUN pip install gym[atari] opencv-python==3.2.0.8 smart_open
+RUN conda install -y -q pytorch torchvision -c soumith

python/ray/rllib/a3c/a3c.py

Lines changed: 5 additions & 1 deletion
@@ -20,10 +20,11 @@
     "num_batches_per_iteration": 100,
     "batch_size": 10,
     "use_lstm": True,
+    "use_pytorch": False,
     "model": {"grayscale": True,
               "zero_mean": False,
               "dim": 42,
-              "channel_major": True}
+              "channel_major": False}
 }
 
 
@@ -35,6 +36,9 @@ def _init(self):
         self.env = create_and_wrap(self.env_creator, self.config["model"])
         if self.config["use_lstm"]:
             policy_cls = SharedModelLSTM
+        elif self.config["use_pytorch"]:
+            from ray.rllib.a3c.shared_torch_policy import SharedTorchPolicy
+            policy_cls = SharedTorchPolicy
         else:
             policy_cls = SharedModel
         self.policy = policy_cls(
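For context, a minimal sketch of a user-side config that would enable the new backend; the agent construction around it is omitted, and the channel_major value is an assumption about what the PyTorch vision net expects rather than something shown in this diff.

# Hypothetical config sketch: flipping "use_pytorch" routes _init to
# SharedTorchPolicy (note that "use_lstm" takes precedence in the branch above).
config = {
    "num_batches_per_iteration": 100,
    "batch_size": 10,
    "use_lstm": False,
    "use_pytorch": True,
    "model": {"grayscale": True,
              "zero_mean": False,
              "dim": 42,
              # assumed: PyTorch conv nets consume channel-major (C, H, W) frames
              "channel_major": True},
}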

python/ray/rllib/a3c/policy.py

Lines changed: 0 additions & 3 deletions
@@ -20,9 +20,6 @@ def set_weights(self, weights):
     def compute_gradients(self, batch):
         raise NotImplementedError
 
-    def get_vf_loss(self):
-        raise NotImplementedError
-
     def compute_action(self, observations):
         """Compute action for a _single_ observation"""
         raise NotImplementedError

python/ray/rllib/a3c/shared_model.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ class SharedModel(TFPolicy):
     def __init__(self, ob_space, ac_space, **kwargs):
         super(SharedModel, self).__init__(ob_space, ac_space, **kwargs)
 
-    def setup_graph(self, ob_space, ac_space):
+    def _setup_graph(self, ob_space, ac_space):
         self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
         dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
         self._model = ModelCatalog.get_model(self.x, self.logit_dim)

python/ray/rllib/a3c/shared_model_lstm.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ class SharedModelLSTM(TFPolicy):
     def __init__(self, ob_space, ac_space, **kwargs):
         super(SharedModelLSTM, self).__init__(ob_space, ac_space, **kwargs)
 
-    def setup_graph(self, ob_space, ac_space):
+    def _setup_graph(self, ob_space, ac_space):
         self.x = tf.placeholder(tf.float32, [None] + list(ob_space))
         dist_class, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
         self._model = LSTM(self.x, self.logit_dim, {})

python/ray/rllib/a3c/shared_torch_policy.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+from ray.rllib.a3c.torchpolicy import TorchPolicy
+from ray.rllib.models.pytorch.misc import var_to_np, convert_batch
+from ray.rllib.models.catalog import ModelCatalog
+
+
+class SharedTorchPolicy(TorchPolicy):
+    """Assumes nonrecurrent."""
+
+    def __init__(self, ob_space, ac_space, **kwargs):
+        super(SharedTorchPolicy, self).__init__(
+            ob_space, ac_space, **kwargs)
+
+    def _setup_graph(self, ob_space, ac_space):
+        _, self.logit_dim = ModelCatalog.get_action_dist(ac_space)
+        self._model = ModelCatalog.get_torch_model(ob_space, self.logit_dim)
+        self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.0001)
+
+    def compute_action(self, ob, *args):
+        """Should take in a SINGLE ob"""
+        with self.lock:
+            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
+            logits, values = self._model(ob)
+            samples = self._model.probs(logits).multinomial().squeeze()
+            values = values.squeeze(0)
+            return var_to_np(samples), var_to_np(values)
+
+    def compute_logits(self, ob, *args):
+        with self.lock:
+            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
+            res = self._model.hidden_layers(ob)
+            return var_to_np(self._model.logits(res))
+
+    def value(self, ob, *args):
+        with self.lock:
+            ob = Variable(torch.from_numpy(ob).float().unsqueeze(0))
+            res = self._model.hidden_layers(ob)
+            res = self._model.value_branch(res)
+            res = res.squeeze(0)
+            return var_to_np(res)
+
+    def _evaluate(self, obs, actions):
+        """Passes in multiple obs."""
+        logits, values = self._model(obs)
+        log_probs = F.log_softmax(logits)
+        probs = self._model.probs(logits)
+        action_log_probs = log_probs.gather(1, actions.view(-1, 1))
+        entropy = -(log_probs * probs).sum(-1).sum()
+        return values, action_log_probs, entropy
+
+    def _backward(self, batch):
+        """Loss is encoded in here. Defining a new loss function
+        would start by rewriting this function"""
+
+        states, acs, advs, rs, _ = convert_batch(batch)
+        values, ac_logprobs, entropy = self._evaluate(states, acs)
+        pi_err = -(advs * ac_logprobs).sum()
+        value_err = 0.5 * (values - rs).pow(2).sum()
+
+        self.optimizer.zero_grad()
+        overall_err = 0.5 * value_err + pi_err - entropy * 0.01
+        overall_err.backward()
+        torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)
+
+    def get_initial_features(self):
+        return [None]
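
To make the loss in _backward easier to follow in isolation, here is a self-contained sketch of the same three terms (policy-gradient error, value error, entropy bonus) computed on dummy tensors. The tensor names and shapes are illustrative stand-ins for what convert_batch and the model would produce, not RLlib APIs.

import torch
import torch.nn.functional as F

# Illustrative stand-ins for one batch of rollout data.
logits = torch.randn(5, 4, requires_grad=True)   # policy logits: 5 steps x 4 actions
values = torch.randn(5, requires_grad=True)      # value predictions
actions = torch.randint(0, 4, (5,))              # sampled actions
advs = torch.randn(5)                            # advantages
returns = torch.randn(5)                         # discounted returns

log_probs = F.log_softmax(logits, dim=-1)
probs = F.softmax(logits, dim=-1)
action_log_probs = log_probs.gather(1, actions.view(-1, 1)).squeeze(1)
entropy = -(log_probs * probs).sum(-1).sum()

pi_err = -(advs * action_log_probs).sum()            # policy-gradient term
value_err = 0.5 * (values - returns).pow(2).sum()    # value-function term
overall_err = 0.5 * value_err + pi_err - 0.01 * entropy

overall_err.backward()   # gradients now live in logits.grad and values.grad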

python/ray/rllib/a3c/tfpolicy.py

Lines changed: 2 additions & 5 deletions
@@ -17,15 +17,15 @@ def __init__(self, ob_space, action_space, name="local", summarize=True):
         self.g = tf.Graph()
         with self.g.as_default(), tf.device(worker_device):
             with tf.variable_scope(name):
-                self.setup_graph(ob_space, action_space)
+                self._setup_graph(ob_space, action_space)
                 assert all([hasattr(self, attr)
                             for attr in ["vf", "logits", "x", "var_list"]])
             print("Setting up loss")
             self.setup_loss(action_space)
             self.setup_gradients()
             self.initialize()
 
-    def setup_graph(self):
+    def _setup_graph(self):
         raise NotImplementedError
 
     def setup_loss(self, action_space):
@@ -92,9 +92,6 @@ def set_weights(self, weights):
     def compute_gradients(self, batch):
         raise NotImplementedError
 
-    def get_vf_loss(self):
-        raise NotImplementedError
-
     def compute_action(self, observations):
         raise NotImplementedError
 

python/ray/rllib/a3c/torchpolicy.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch
+from torch.autograd import Variable
+
+from ray.rllib.a3c.policy import Policy
+from threading import Lock
+
+
+class TorchPolicy(Policy):
+    """The policy base class for Torch.
+
+    The model is a separate object than the policy. This could be changed
+    in the future."""
+
+    def __init__(self, ob_space, action_space, name="local", summarize=True):
+        self.local_steps = 0
+        self.summarize = summarize
+        self._setup_graph(ob_space, action_space)
+        torch.set_num_threads(2)
+        self.lock = Lock()
+
+    def apply_gradients(self, grads):
+        self.optimizer.zero_grad()
+        for g, p in zip(grads, self._model.parameters()):
+            p.grad = Variable(torch.from_numpy(g))
+        self.optimizer.step()
+
+    def get_weights(self):
+        # !! This only returns references to the data.
+        return self._model.state_dict()
+
+    def set_weights(self, weights):
+        with self.lock:
+            self._model.load_state_dict(weights)
+
+    def compute_gradients(self, batch):
+        """_backward generates the gradient in each model parameter.
+        This is taken out.
+
+        Args:
+            batch: Batch of data needed for gradient calculation.
+
+        Return:
+            gradients (list of np arrays): List of gradients
+            info (dict): Extra information (user-defined)"""
+        with self.lock:
+            self._backward(batch)
+            # Note that return values are just references;
+            # calling zero_grad will modify the values
+            return [p.grad.data.numpy() for p in self._model.parameters()], {}
+
+    def model_update(self, batch):
+        """Implements compute + apply
+
+        TODO(rliaw): Pytorch has nice caching property that doesn't require
+        full batch to be passed in. Can exploit that later"""
+        with self.lock:
+            self._backward(batch)
+            self.optimizer.step()
+
+    def _setup_graph(ob_space, action_space):
+        raise NotImplementedError
+
+    def _backward(self, batch):
+        """Implements the loss function and calculates the gradient.
+        Pytorch automatically generates a backward trace for each variable.
+        Assumption right now is that variables are moved, so the backward
+        trace is lost.
+
+        This function regenerates the backward trace and
+        caluclates the gradient."""
+        raise NotImplementedError
+
+    def get_initial_features(self):
+        return []
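
The compute_gradients/apply_gradients pair is what lets the A3C driver move gradients between processes as numpy arrays. A rough sketch of that round trip using a stand-in nn.Linear model; the model, data, and optimizer here are illustrative assumptions, not the classes ModelCatalog would return.

import torch
import torch.nn as nn

# Two copies of the same tiny model, standing in for a worker policy and the
# driver-side policy.
worker = nn.Linear(4, 2)
master = nn.Linear(4, 2)
master.load_state_dict(worker.state_dict())   # start from identical weights
optimizer = torch.optim.Adam(master.parameters(), lr=1e-4)

# Worker side: run a backward pass and ship gradients as numpy arrays,
# mirroring TorchPolicy.compute_gradients().
loss = worker(torch.randn(3, 4)).sum()
loss.backward()
grads = [p.grad.data.numpy() for p in worker.parameters()]

# Driver side: install the received gradients and step the optimizer,
# mirroring TorchPolicy.apply_gradients().
optimizer.zero_grad()
for g, p in zip(grads, master.parameters()):
    p.grad = torch.from_numpy(g)
optimizer.step()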

python/ray/rllib/models/catalog.py

Lines changed: 24 additions & 0 deletions
@@ -85,6 +85,30 @@ def get_model(inputs, num_outputs, options=dict()):
 
         return FullyConnectedNetwork(inputs, num_outputs, options)
 
+    @staticmethod
+    def get_torch_model(input_shape, num_outputs, options=dict()):
+        """Returns a PyTorch suitable model.
+
+        Args:
+            input_shape (tup): The input shape to the model.
+            num_outputs (int): The size of the output vector of the model.
+            options (dict): Optional args to pass to the model constructor.
+
+        Returns:
+            model (Model): Neural network model.
+        """
+        from ray.rllib.models.pytorch.fcnet import (
+            FullyConnectedNetwork as PyTorchFCNet)
+        from ray.rllib.models.pytorch.visionnet import (
+            VisionNetwork as PyTorchVisionNet)
+
+        obs_rank = len(input_shape) - 1
+
+        if obs_rank > 1:
+            return PyTorchVisionNet(input_shape, num_outputs, options)
+
+        return PyTorchFCNet(input_shape[0], num_outputs, options)
+
     @classmethod
     def get_preprocessor(cls, env, options=dict()):
         """Returns a suitable processor for the given environment.

python/ray/rllib/models/preprocessors.py

Lines changed: 5 additions & 5 deletions
@@ -30,15 +30,15 @@ def _init(self):
         self._grayscale = self._options.get("grayscale", False)
         self._zero_mean = self._options.get("zero_mean", True)
         self._dim = self._options.get("dim", 80)
-        self._pytorch = self._options.get("pytorch", False)
+        self._channel_major = self._options.get("channel_major", False)
         if self._grayscale:
            self.shape = (self._dim, self._dim, 1)
         else:
            self.shape = (self._dim, self._dim, 3)
 
-        # pytorch requires (# in-channels, row dim, col dim)
-        if self._pytorch:
-            self.shape = self.shape[::-1]
+        # channel_major requires (# in-channels, row dim, col dim)
+        if self._channel_major:
+            self.shape = self.shape[-1:] + self.shape[:-1]
 
     # TODO(ekl) why does this need to return an extra size-1 dim (the [None])
     def transform(self, observation):
@@ -59,7 +59,7 @@ def transform(self, observation):
             scaled = (scaled - 128) / 128
         else:
             scaled *= 1.0 / 255.0
-        if self._pytorch:
+        if self._channel_major:
             scaled = np.reshape(scaled, self.shape)
         return scaled
 
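A quick numpy sketch of the tuple rotation performed by the new channel_major branch, using the dim=42, three-channel shape from the defaults above; the array and prints are illustrative only.

import numpy as np

# Shape as built in _init for a non-grayscale frame: (dim, dim, channels).
shape = (42, 42, 3)
channel_major = shape[-1:] + shape[:-1]   # rotate so the channel axis comes first
print(channel_major)                      # (3, 42, 42)

frame = np.zeros(shape, dtype=np.float32)
print(np.reshape(frame, channel_major).shape)   # (3, 42, 42)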
