Skip to content

Commit 7cbbb5c

Browse files
committed
Fix akro.Dict to support existing APIs
Several of the APIs were missing from `akro.Dict` (and untested). I've copied their implementations from the old `akro.tf.Dict`. I also rewrote `akro.Dict.unflatten_from_keys` to make more sense. This shouldn't be a problem, since it isn't used in garage (but is provided for completeness). Wrap sub-spaces in Dict and Tuple spaces Pin gym==0.12.4 We require given gym version, since otherwise gym.spaces.Space is not defined, which causes akro.Space to fail to import. Increase version to v0.0.6
1 parent d701827 commit 7cbbb5c

File tree

5 files changed

+140
-21
lines changed

5 files changed

+140
-21
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
v0.0.5-dev
1+
v0.0.6-dev

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Required dependencies
55
required = [
66
# Please keep alphabetized
7-
'gym',
7+
'gym==0.12.4',
88
'numpy',
99
]
1010

src/akro/dict.py

Lines changed: 74 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
This Space produces samples which are dicts, where the values of those dicts
44
are drawn from the values of this Space.
55
"""
6+
import collections
7+
68
import gym.spaces
9+
import numpy as np
710

11+
import akro
812
from akro.requires import requires_tf, requires_theano
913
from akro.space import Space
1014

@@ -18,12 +22,27 @@ class Dict(gym.spaces.Dict, Space):
1822
"velocity": spaces.Discrete(3)})
1923
"""
2024

21-
@property # pragma: no cover
25+
def __init__(self, spaces=None, **kwargs):
26+
super().__init__(spaces, **kwargs)
27+
self.spaces = (collections.OrderedDict(
28+
[(k, akro.from_gym(s)) for k, s in self.spaces.items()]))
29+
30+
@property
2231
def flat_dim(self):
2332
"""Return the length of the flattened vector of the space."""
24-
raise NotImplementedError
33+
return sum([space.flat_dim for _, space in self.spaces.items()])
34+
35+
def flat_dim_with_keys(self, keys):
36+
"""
37+
Return a flat dimension of the spaces specified by the keys.
38+
39+
Returns:
40+
sum (int)
2541
26-
def flatten(self, x): # pragma: no cover
42+
"""
43+
return sum([self.spaces[key].flat_dim for key in keys])
44+
45+
def flatten(self, x):
2746
"""Return an observation of x with collapsed values.
2847
2948
Args:
@@ -34,21 +53,31 @@ def flatten(self, x): # pragma: no cover
3453
Keys are unchanged.
3554
3655
"""
37-
raise NotImplementedError
38-
39-
def unflatten(self, x): # pragma: no cover
56+
return np.concatenate(
57+
[
58+
space.flatten(xi)
59+
for space, xi in zip(self.spaces.values(), x.values())
60+
],
61+
axis=-1,
62+
)
63+
64+
def unflatten(self, x):
4065
"""Return an unflattened observation x.
4166
4267
Args:
4368
x (:obj:`Iterable`): The object to unflatten.
4469
4570
Returns:
46-
np.ndarray: An array of x in the shape of self.shape.
71+
collections.OrderedDict
4772
4873
"""
49-
raise NotImplementedError
74+
dims = np.array([s.flat_dim for s in self.spaces.values()])
75+
flat_x = np.split(x, np.cumsum(dims)[:-1])
76+
return collections.OrderedDict(
77+
[(key, self.spaces[key].unflatten(xi))
78+
for key, xi in zip(self.spaces.keys(), flat_x)])
5079

51-
def flatten_n(self, xs): # pragma: no cover
80+
def flatten_n(self, xs):
5281
"""Return flattened observations xs.
5382
5483
Args:
@@ -59,20 +88,51 @@ def flatten_n(self, xs): # pragma: no cover
5988
its first element.
6089
6190
"""
62-
raise NotImplementedError
91+
return np.array([self.flatten(x) for x in xs])
6392

64-
def unflatten_n(self, xs): # pragma: no cover
93+
def unflatten_n(self, xs):
6594
"""Return unflattened observations xs.
6695
6796
Args:
6897
xs (:obj:`Iterable`): The object to reshape and unflatten
6998
7099
Returns:
71-
np.ndarray: An array of xs in a shape inferred by the size of
72-
its first element and self.shape.
100+
List[OrderedDict]
101+
102+
"""
103+
return [self.unflatten(x) for x in xs]
104+
105+
def flatten_with_keys(self, x, keys):
106+
"""
107+
Return flattened obs of spaces specified by the keys using x.
108+
109+
Returns:
110+
list
111+
112+
"""
113+
return np.concatenate(
114+
[
115+
self.spaces[key].flatten(xi)
116+
for key, xi in zip(self.spaces.keys(), x.values())
117+
if key in keys
118+
],
119+
axis=-1,
120+
)
121+
122+
def unflatten_with_keys(self, x, keys):
123+
"""
124+
Return an unflattened observation.
125+
126+
This is the inverse of `flatten_with_keys`.
127+
128+
Returns:
129+
collections.OrderedDict
73130
74131
"""
75-
raise NotImplementedError
132+
dims = np.array([self.spaces[key].flat_dim for key in keys])
133+
flat_x = np.split(x, np.cumsum(dims)[:-1])
134+
return collections.OrderedDict([(key, self.spaces[key].unflatten(xi))
135+
for key, xi in zip(keys, flat_x)])
76136

77137
@requires_tf
78138
def to_tf_placeholder(self, name, batch_dims):

src/akro/tuple.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77
import gym.spaces
88
import numpy as np
99

10+
import akro
1011
from akro.requires import requires_tf, requires_theano
1112
from akro.space import Space
1213

1314

1415
class Tuple(gym.spaces.Tuple, Space):
1516
"""A Tuple of Spaces which produces samples which are Tuples of samples."""
1617

18+
def __init__(self, spaces):
19+
super().__init__([akro.from_gym(space) for space in spaces])
20+
1721
@property
1822
def flat_dim(self):
1923
"""Return the length of the flattened vector of the space."""

tests/akro/test_dict.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import collections
12
import pickle
23
import unittest
34

5+
import numpy as np
6+
47
from akro import Dict
58
from akro import Discrete
9+
from akro import Box
610
from akro import tf
711
from akro import theano
812
from akro.requires import requires_tf, requires_theano
@@ -23,19 +27,70 @@ def test_pickleable(self):
2327
assert round_trip.contains(sample)
2428

2529
def test_flat_dim(self):
26-
pass
30+
d = Dict(collections.OrderedDict(position=Box(0, 10, (2,)),
31+
velocity=Box(0, 10, (3,))))
32+
assert d.flat_dim == 5
33+
34+
def test_flat_dim_with_keys(self):
35+
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
36+
('velocity', Box(0, 10, (3,)))]))
37+
assert d.flat_dim_with_keys(['position']) == 2
2738

2839
def test_flatten(self):
29-
pass
40+
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
41+
('velocity', Box(0, 10, (3,)))]))
42+
f = np.array([1., 2., 3., 4., 5.])
43+
s = collections.OrderedDict(position=np.array([1., 2.]),
44+
velocity=np.array([3., 4., 5.]))
45+
assert (d.flatten(s) == f).all()
3046

3147
def test_unflatten(self):
32-
pass
48+
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
49+
('velocity', Box(0, 10, (3,)))]))
50+
f = np.array([1., 2., 3., 4., 5.])
51+
s = collections.OrderedDict(position=np.array([1., 2.]),
52+
velocity=np.array([3., 4., 5.]))
53+
assert all((s[k] == v).all() for k, v in d.unflatten(f).items())
3354

3455
def test_flatten_n(self):
35-
pass
56+
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
57+
('velocity', Box(0, 10, (3,)))]))
58+
f = np.array([[1., 2., 3., 4., 5.],
59+
[6., 7., 8., 9., 0.]])
60+
s = [collections.OrderedDict(position=np.array([1., 2.]),
61+
velocity=np.array([3., 4., 5.])),
62+
collections.OrderedDict(position=np.array([6., 7.]),
63+
velocity=np.array([8., 9., 0.]))]
64+
assert (d.flatten_n(s) == f).all()
3665

3766
def test_unflatten_n(self):
38-
pass
67+
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
68+
('velocity', Box(0, 10, (3,)))]))
69+
f = np.array([[1., 2., 3., 4., 5.],
70+
[6., 7., 8., 9., 0.]])
71+
s = [collections.OrderedDict(position=np.array([1., 2.]),
72+
velocity=np.array([3., 4., 5.])),
73+
collections.OrderedDict(position=np.array([6., 7.]),
74+
velocity=np.array([8., 9., 0.]))]
75+
for i, fi in enumerate(d.unflatten_n(f)):
76+
assert all((s[i][k] == v).all() for k, v in fi.items())
77+
78+
def test_flatten_with_keys(self):
79+
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
80+
('velocity', Box(0, 10, (3,)))]))
81+
f = np.array([3., 4., 5.])
82+
s = collections.OrderedDict(position=np.array([1., 2.]),
83+
velocity=np.array([3., 4., 5.]))
84+
assert (d.flatten_with_keys(s, ['velocity']) == f).all()
85+
86+
def test_unflatten_with_keys(self):
87+
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
88+
('velocity', Box(0, 10, (3,)))]))
89+
f = np.array([3., 4., 5.])
90+
s = collections.OrderedDict(position=np.array([1., 2.]),
91+
velocity=np.array([3., 4., 5.]))
92+
assert all((s[k] == v).all()
93+
for k, v in d.unflatten_with_keys(f, ['velocity']).items())
3994

4095
@requires_tf
4196
def test_convert_tf(self):

0 commit comments

Comments
 (0)