44
55import numpy as np
66
7+ from akro import Box
78from akro import Dict
89from akro import Discrete
9- from akro import Box
1010from akro import tf
1111from akro import theano
1212from akro .requires import requires_tf , requires_theano
@@ -27,70 +27,92 @@ def test_pickleable(self):
2727 assert round_trip .contains (sample )
2828
2929 def test_flat_dim (self ):
30- d = Dict (collections .OrderedDict (position = Box (0 , 10 , (2 ,)),
31- velocity = Box (0 , 10 , (3 ,))))
30+ d = Dict (
31+ collections .OrderedDict (
32+ position = Box (0 , 10 , (2 , )), velocity = Box (0 , 10 , (3 , ))))
3233 assert d .flat_dim == 5
3334
3435 def test_flat_dim_with_keys (self ):
35- d = Dict (collections .OrderedDict ([('position' , Box (0 , 10 , (2 ,))),
36- ('velocity' , Box (0 , 10 , (3 ,)))]))
36+ d = Dict (
37+ collections .OrderedDict ([('position' , Box (0 , 10 , (2 , ))),
38+ ('velocity' , Box (0 , 10 , (3 , )))]))
3739 assert d .flat_dim_with_keys (['position' ]) == 2
3840
3941 def test_flatten (self ):
40- d = Dict (collections .OrderedDict ([('position' , Box (0 , 10 , (2 ,))),
41- ('velocity' , Box (0 , 10 , (3 ,)))]))
42+ d = Dict (
43+ collections .OrderedDict ([('position' , Box (0 , 10 , (2 , ))),
44+ ('velocity' , Box (0 , 10 , (3 , )))]))
4245 f = np .array ([1. , 2. , 3. , 4. , 5. ])
43- s = collections .OrderedDict (position = np .array ([1. , 2. ]),
44- velocity = np .array ([3. , 4. , 5. ]))
46+ # Keys are intentionally in the "wrong" order.
47+ s = collections .OrderedDict ([('velocity' , np .array ([3. , 4. , 5. ])),
48+ ('position' , np .array ([1. , 2. ]))])
4549 assert (d .flatten (s ) == f ).all ()
4650
4751 def test_unflatten (self ):
48- d = Dict (collections .OrderedDict ([('position' , Box (0 , 10 , (2 ,))),
49- ('velocity' , Box (0 , 10 , (3 ,)))]))
52+ d = Dict (
53+ collections .OrderedDict ([('position' , Box (0 , 10 , (2 , ))),
54+ ('velocity' , Box (0 , 10 , (3 , )))]))
5055 f = np .array ([1. , 2. , 3. , 4. , 5. ])
51- s = collections .OrderedDict (position = np .array ([1. , 2. ]),
52- velocity = np .array ([3. , 4. , 5. ]))
56+ # Keys are intentionally in the "wrong" order.
57+ s = collections .OrderedDict ([('velocity' , np .array ([3. , 4. , 5. ])),
58+ ('position' , np .array ([1. , 2. ]))])
5359 assert all ((s [k ] == v ).all () for k , v in d .unflatten (f ).items ())
5460
5561 def test_flatten_n (self ):
56- d = Dict (collections .OrderedDict ([('position' , Box (0 , 10 , (2 ,))),
57- ('velocity' , Box (0 , 10 , (3 ,)))]))
58- f = np .array ([[1. , 2. , 3. , 4. , 5. ],
59- [6. , 7. , 8. , 9. , 0. ]])
60- s = [collections .OrderedDict (position = np .array ([1. , 2. ]),
61- velocity = np .array ([3. , 4. , 5. ])),
62- collections .OrderedDict (position = np .array ([6. , 7. ]),
63- velocity = np .array ([8. , 9. , 0. ]))]
62+ d = Dict (
63+ collections .OrderedDict ([('position' , Box (0 , 10 , (2 , ))),
64+ ('velocity' , Box (0 , 10 , (3 , )))]))
65+ f = np .array ([[1. , 2. , 3. , 4. , 5. ], [6. , 7. , 8. , 9. , 0. ]])
66+ # Keys are intentionally in the "wrong" order.
67+ s = [
68+ collections .OrderedDict ([('velocity' , np .array ([3. , 4. , 5. ])),
69+ ('position' , np .array ([1. , 2. ]))]),
70+ collections .OrderedDict ([('velocity' , np .array ([8. , 9. , 0. ])),
71+ ('position' , np .array ([6. , 7. ]))])
72+ ]
6473 assert (d .flatten_n (s ) == f ).all ()
6574
6675 def test_unflatten_n (self ):
67- d = Dict (collections .OrderedDict ([('position' , Box (0 , 10 , (2 ,))),
68- ('velocity' , Box (0 , 10 , (3 ,)))]))
69- f = np .array ([[1. , 2. , 3. , 4. , 5. ],
70- [6. , 7. , 8. , 9. , 0. ]])
71- s = [collections .OrderedDict (position = np .array ([1. , 2. ]),
72- velocity = np .array ([3. , 4. , 5. ])),
73- collections .OrderedDict (position = np .array ([6. , 7. ]),
74- velocity = np .array ([8. , 9. , 0. ]))]
76+ d = Dict (
77+ collections .OrderedDict ([('position' , Box (0 , 10 , (2 , ))),
78+ ('velocity' , Box (0 , 10 , (3 , )))]))
79+ f = np .array ([[1. , 2. , 3. , 4. , 5. ], [6. , 7. , 8. , 9. , 0. ]])
80+ # Keys are intentionally in the "wrong" order.
81+ s = [
82+ collections .OrderedDict ([('velocity' , np .array ([3. , 4. , 5. ])),
83+ ('position' , np .array ([1. , 2. ]))]),
84+ collections .OrderedDict ([('velocity' , np .array ([8. , 9. , 0. ])),
85+ ('position' , np .array ([6. , 7. ]))])
86+ ]
7587 for i , fi in enumerate (d .unflatten_n (f )):
7688 assert all ((s [i ][k ] == v ).all () for k , v in fi .items ())
7789
7890 def test_flatten_with_keys (self ):
79- d = Dict (collections .OrderedDict ([('position' , Box (0 , 10 , (2 ,))),
80- ('velocity' , Box (0 , 10 , (3 ,)))]))
91+ d = Dict (
92+ collections .OrderedDict ([('position' , Box (0 , 10 , (2 , ))),
93+ ('velocity' , Box (0 , 10 , (3 , )))]))
8194 f = np .array ([3. , 4. , 5. ])
82- s = collections .OrderedDict (position = np .array ([1. , 2. ]),
83- velocity = np .array ([3. , 4. , 5. ]))
95+ f_full = np .array ([1. , 2. , 3. , 4. , 5. ])
96+ # Keys are intentionally in the "wrong" order.
97+ s = collections .OrderedDict ([('velocity' , np .array ([3. , 4. , 5. ])),
98+ ('position' , np .array ([1. , 2. ]))])
8499 assert (d .flatten_with_keys (s , ['velocity' ]) == f ).all ()
100+ assert (d .flatten_with_keys (s ,
101+ ['velocity' , 'position' ]) == f_full ).all ()
85102
86103 def test_unflatten_with_keys (self ):
87- d = Dict (collections .OrderedDict ([('position' , Box (0 , 10 , (2 ,))),
88- ('velocity' , Box (0 , 10 , (3 ,)))]))
104+ d = Dict (
105+ collections .OrderedDict ([('position' , Box (0 , 10 , (2 , ))),
106+ ('velocity' , Box (0 , 10 , (3 , )))]))
89107 f = np .array ([3. , 4. , 5. ])
90- s = collections .OrderedDict (position = np .array ([1. , 2. ]),
91- velocity = np .array ([3. , 4. , 5. ]))
108+ f_full = np .array ([1. , 2. , 3. , 4. , 5. ])
109+ # Keys are intentionally in the "wrong" order.
110+ s = collections .OrderedDict ([('velocity' , np .array ([3. , 4. , 5. ])),
111+ ('position' , np .array ([1. , 2. ]))])
92112 assert all ((s [k ] == v ).all ()
93113 for k , v in d .unflatten_with_keys (f , ['velocity' ]).items ())
114+ assert all ((s [k ] == v ).all () for k , v in d .unflatten_with_keys (
115+ f_full , ['velocity' , 'position' ]).items ())
94116
95117 @requires_tf
96118 def test_convert_tf (self ):
0 commit comments