1
+
2
+ import numpy as np
3
+ import gym
4
+
5
+
6
class Dict2TupleWrapper:
    """Wrap the PettingZoo envs to have a similar style as LaserFrame in NFSP.

    The wrapped env is expected to exchange per-agent dicts (actions in;
    observations, rewards, dones and infos out). This wrapper exposes the
    same data as sequences ordered like the underlying dicts, and takes
    actions as a sequence aligned with ``self.agents``.
    """

    def __init__(self, env, keep_info=False):
        """Store the env and mirror its spaces and agent list.

        Args:
            env: Environment to wrap. It must expose ``num_agents``,
                ``observation_space``, ``action_space``,
                ``observation_spaces``, ``action_spaces`` and either
                ``agents`` or ``unwrapped.agents``.
            keep_info: If True, ``step`` returns the info dict unchanged
                instead of a list of its values.
        """
        super().__init__()
        self.env = env
        self.num_agents = env.num_agents
        self.keep_info = keep_info  # if True keep info as dict

        old_shape = env.observation_space.shape
        if len(old_shape) > 1:  # image
            # Observations are emitted through ``np.swapaxes(obs, 2, 0)``,
            # which maps (H, W, C) to (C, W, H); declare the space in that
            # same layout so it matches what reset()/step() return.
            # NOTE(review): low=0.0/high=1.0 combined with dtype=uint8 looks
            # inconsistent with raw pixel values - confirm the intended range.
            # NOTE(review): a 2-D observation shape also takes this branch,
            # but ``observation_swapaxis`` needs at least 3 axes.
            self.observation_space = gym.spaces.Box(
                low=0.0,
                high=1.0,
                shape=(old_shape[-1], old_shape[1], old_shape[0]),
                dtype=np.uint8,
            )
            self.obs_type = 'rgb_image'
        else:
            self.observation_space = env.observation_space
            self.obs_type = 'ram'

        self.action_space = env.action_space
        self.observation_spaces = env.observation_spaces
        self.action_spaces = env.action_spaces

        # Both pettingzoo and slimevolley can work with this: fall back to
        # the unwrapped env when the attribute is not exposed directly.
        try:
            self.agents = env.agents
        except AttributeError:
            self.agents = env.unwrapped.agents

    @property
    def unwrapped(self):
        """Return the env passed to the constructor."""
        return self.env

    @property
    def spec(self):
        """Return the ``spec`` of the wrapped env."""
        return self.env.spec

    def observation_swapaxis(self, observation):
        """Swap axes 0 and 2 of every per-agent observation.

        Args:
            observation: Sequence of arrays, one per agent, each with at
                least three axes.

        Returns:
            Tuple with one axis-swapped array per input observation.
        """
        return tuple(np.swapaxes(agent_obs, 2, 0) for agent_obs in observation)

    def _format_observations(self, obs_dict):
        """Convert an observation dict into the tuple layout of this wrapper."""
        observations = tuple(obs_dict.values())
        if self.obs_type == 'ram':
            return observations
        return self.observation_swapaxis(observations)

    def reset(self):
        """Reset the wrapped env and return the observations as a tuple."""
        return self._format_observations(self.env.reset())

    def step(self, actions):
        """Advance the wrapped env by one step.

        Args:
            actions: Sequence of actions, paired positionally with
                ``self.agents``.

        Returns:
            ``(observations, rewards, dones, info)`` where observations is a
            tuple, rewards and dones are lists, and info is a list of the
            info values, or the raw info dict when ``keep_info`` is True.
        """
        action_dict = dict(zip(self.agents, actions))
        obs, rewards, dones, infos = self.env.step(action_dict)

        o = self._format_observations(obs)
        r = list(rewards.values())
        d = list(dones.values())
        # keep_info is a special case for VectorEnv, which wants the dict.
        info = infos if self.keep_info else list(infos.values())

        # NOTE: ``_zerosum_filter`` is available but is not applied to ``r``.
        return o, r, d, info

    def _zerosum_filter(self, r):
        """Make a two-player reward pair zero-sum, in place.

        Added for making a non-zero-sum game zero-sum, e.g. tennis_v2. When
        the rewards do not sum to zero, the first non-zero reward is negated
        and written to the other player's slot.

        Args:
            r: Mutable sequence of two rewards.

        Returns:
            The same sequence ``r`` after the in-place adjustment.
        """
        if np.sum(r) != 0:
            nonzero_idx = np.nonzero(r)[0][0]
            # ``1 - idx`` selects the opponent; this only works for 2 players.
            r[1 - nonzero_idx] = -r[nonzero_idx]
        return r

    def seed(self, seed):
        """Seed the wrapped env and NumPy's global random state."""
        self.env.seed(seed)
        np.random.seed(seed)

    def render(self):
        """Render the wrapped env and return whatever it returns."""
        return self.env.render()

    def close(self):
        """Close the wrapped env."""
        self.env.close()
0 commit comments