Skip to content

Commit 69551e1

Browse files
authored
Fix handling of wrong-order keys
Fix wrong-order keys in unflattened observations This should just be gracefully handled, but was broken before. The tests explicitly check that everything works by constructing `OrderedDict`s with the "wrong" order. Also, run flake8 and copy the CI's yapf output (which differs from my local yapf output).
1 parent 7cbbb5c commit 69551e1

File tree

2 files changed

+69
-47
lines changed

2 files changed

+69
-47
lines changed

src/akro/dict.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,7 @@ def flatten(self, x):
5454
5555
"""
5656
return np.concatenate(
57-
[
58-
space.flatten(xi)
59-
for space, xi in zip(self.spaces.values(), x.values())
60-
],
57+
[space.flatten(x[key]) for key, space in self.spaces.items()],
6158
axis=-1,
6259
)
6360

@@ -112,9 +109,8 @@ def flatten_with_keys(self, x, keys):
112109
"""
113110
return np.concatenate(
114111
[
115-
self.spaces[key].flatten(xi)
116-
for key, xi in zip(self.spaces.keys(), x.values())
117-
if key in keys
112+
space.flatten(x[key])
113+
for key, space in self.spaces.items() if key in keys
118114
],
119115
axis=-1,
120116
)
@@ -129,10 +125,14 @@ def unflatten_with_keys(self, x, keys):
129125
collections.OrderedDict
130126
131127
"""
132-
dims = np.array([self.spaces[key].flat_dim for key in keys])
128+
dims = np.array([
129+
space.flat_dim for key, space in self.spaces.items() if key in keys
130+
])
133131
flat_x = np.split(x, np.cumsum(dims)[:-1])
134-
return collections.OrderedDict([(key, self.spaces[key].unflatten(xi))
135-
for key, xi in zip(keys, flat_x)])
132+
return collections.OrderedDict(
133+
[(key, space.unflatten(xi))
134+
for (key, space), xi in zip(self.spaces.items(), flat_x)
135+
if key in keys])
136136

137137
@requires_tf
138138
def to_tf_placeholder(self, name, batch_dims):

tests/akro/test_dict.py

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import numpy as np
66

7+
from akro import Box
78
from akro import Dict
89
from akro import Discrete
9-
from akro import Box
1010
from akro import tf
1111
from akro import theano
1212
from akro.requires import requires_tf, requires_theano
@@ -27,70 +27,92 @@ def test_pickleable(self):
2727
assert round_trip.contains(sample)
2828

2929
def test_flat_dim(self):
30-
d = Dict(collections.OrderedDict(position=Box(0, 10, (2,)),
31-
velocity=Box(0, 10, (3,))))
30+
d = Dict(
31+
collections.OrderedDict(
32+
position=Box(0, 10, (2, )), velocity=Box(0, 10, (3, ))))
3233
assert d.flat_dim == 5
3334

3435
def test_flat_dim_with_keys(self):
35-
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
36-
('velocity', Box(0, 10, (3,)))]))
36+
d = Dict(
37+
collections.OrderedDict([('position', Box(0, 10, (2, ))),
38+
('velocity', Box(0, 10, (3, )))]))
3739
assert d.flat_dim_with_keys(['position']) == 2
3840

3941
def test_flatten(self):
40-
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
41-
('velocity', Box(0, 10, (3,)))]))
42+
d = Dict(
43+
collections.OrderedDict([('position', Box(0, 10, (2, ))),
44+
('velocity', Box(0, 10, (3, )))]))
4245
f = np.array([1., 2., 3., 4., 5.])
43-
s = collections.OrderedDict(position=np.array([1., 2.]),
44-
velocity=np.array([3., 4., 5.]))
46+
# Keys are intentionally in the "wrong" order.
47+
s = collections.OrderedDict([('velocity', np.array([3., 4., 5.])),
48+
('position', np.array([1., 2.]))])
4549
assert (d.flatten(s) == f).all()
4650

4751
def test_unflatten(self):
48-
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
49-
('velocity', Box(0, 10, (3,)))]))
52+
d = Dict(
53+
collections.OrderedDict([('position', Box(0, 10, (2, ))),
54+
('velocity', Box(0, 10, (3, )))]))
5055
f = np.array([1., 2., 3., 4., 5.])
51-
s = collections.OrderedDict(position=np.array([1., 2.]),
52-
velocity=np.array([3., 4., 5.]))
56+
# Keys are intentionally in the "wrong" order.
57+
s = collections.OrderedDict([('velocity', np.array([3., 4., 5.])),
58+
('position', np.array([1., 2.]))])
5359
assert all((s[k] == v).all() for k, v in d.unflatten(f).items())
5460

5561
def test_flatten_n(self):
56-
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
57-
('velocity', Box(0, 10, (3,)))]))
58-
f = np.array([[1., 2., 3., 4., 5.],
59-
[6., 7., 8., 9., 0.]])
60-
s = [collections.OrderedDict(position=np.array([1., 2.]),
61-
velocity=np.array([3., 4., 5.])),
62-
collections.OrderedDict(position=np.array([6., 7.]),
63-
velocity=np.array([8., 9., 0.]))]
62+
d = Dict(
63+
collections.OrderedDict([('position', Box(0, 10, (2, ))),
64+
('velocity', Box(0, 10, (3, )))]))
65+
f = np.array([[1., 2., 3., 4., 5.], [6., 7., 8., 9., 0.]])
66+
# Keys are intentionally in the "wrong" order.
67+
s = [
68+
collections.OrderedDict([('velocity', np.array([3., 4., 5.])),
69+
('position', np.array([1., 2.]))]),
70+
collections.OrderedDict([('velocity', np.array([8., 9., 0.])),
71+
('position', np.array([6., 7.]))])
72+
]
6473
assert (d.flatten_n(s) == f).all()
6574

6675
def test_unflatten_n(self):
67-
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
68-
('velocity', Box(0, 10, (3,)))]))
69-
f = np.array([[1., 2., 3., 4., 5.],
70-
[6., 7., 8., 9., 0.]])
71-
s = [collections.OrderedDict(position=np.array([1., 2.]),
72-
velocity=np.array([3., 4., 5.])),
73-
collections.OrderedDict(position=np.array([6., 7.]),
74-
velocity=np.array([8., 9., 0.]))]
76+
d = Dict(
77+
collections.OrderedDict([('position', Box(0, 10, (2, ))),
78+
('velocity', Box(0, 10, (3, )))]))
79+
f = np.array([[1., 2., 3., 4., 5.], [6., 7., 8., 9., 0.]])
80+
# Keys are intentionally in the "wrong" order.
81+
s = [
82+
collections.OrderedDict([('velocity', np.array([3., 4., 5.])),
83+
('position', np.array([1., 2.]))]),
84+
collections.OrderedDict([('velocity', np.array([8., 9., 0.])),
85+
('position', np.array([6., 7.]))])
86+
]
7587
for i, fi in enumerate(d.unflatten_n(f)):
7688
assert all((s[i][k] == v).all() for k, v in fi.items())
7789

7890
def test_flatten_with_keys(self):
79-
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
80-
('velocity', Box(0, 10, (3,)))]))
91+
d = Dict(
92+
collections.OrderedDict([('position', Box(0, 10, (2, ))),
93+
('velocity', Box(0, 10, (3, )))]))
8194
f = np.array([3., 4., 5.])
82-
s = collections.OrderedDict(position=np.array([1., 2.]),
83-
velocity=np.array([3., 4., 5.]))
95+
f_full = np.array([1., 2., 3., 4., 5.])
96+
# Keys are intentionally in the "wrong" order.
97+
s = collections.OrderedDict([('velocity', np.array([3., 4., 5.])),
98+
('position', np.array([1., 2.]))])
8499
assert (d.flatten_with_keys(s, ['velocity']) == f).all()
100+
assert (d.flatten_with_keys(s,
101+
['velocity', 'position']) == f_full).all()
85102

86103
def test_unflatten_with_keys(self):
87-
d = Dict(collections.OrderedDict([('position', Box(0, 10, (2,))),
88-
('velocity', Box(0, 10, (3,)))]))
104+
d = Dict(
105+
collections.OrderedDict([('position', Box(0, 10, (2, ))),
106+
('velocity', Box(0, 10, (3, )))]))
89107
f = np.array([3., 4., 5.])
90-
s = collections.OrderedDict(position=np.array([1., 2.]),
91-
velocity=np.array([3., 4., 5.]))
108+
f_full = np.array([1., 2., 3., 4., 5.])
109+
# Keys are intentionally in the "wrong" order.
110+
s = collections.OrderedDict([('velocity', np.array([3., 4., 5.])),
111+
('position', np.array([1., 2.]))])
92112
assert all((s[k] == v).all()
93113
for k, v in d.unflatten_with_keys(f, ['velocity']).items())
114+
assert all((s[k] == v).all() for k, v in d.unflatten_with_keys(
115+
f_full, ['velocity', 'position']).items())
94116

95117
@requires_tf
96118
def test_convert_tf(self):

0 commit comments

Comments
 (0)