Skip to content

Commit 120a540

Browse files
committed
add qmix
1 parent d1af180 commit 120a540

File tree

2 files changed

+565
-0
lines changed

2 files changed

+565
-0
lines changed

common/wrappers.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
2+
import numpy as np
3+
import gym
4+
5+
6+
class Dict2TupleWrapper():
    """Wrap PettingZoo-style dict-based envs to have a similar style as
    LaserFrame in NFSP.

    The wrapped env returns observations/rewards/dones/infos as dicts keyed
    by agent name; this wrapper converts them to tuples/lists ordered like
    ``self.agents`` so single-agent-style training loops can consume them.
    """

    def __init__(self, env, keep_info=False):
        """
        Args:
            env: a PettingZoo (or SlimeVolley) parallel env exposing
                dict-valued ``reset``/``step``.
            keep_info: if True, pass ``infos`` through unchanged as a dict
                (a special case for VectorEnv); otherwise convert to a list.
        """
        super(Dict2TupleWrapper, self).__init__()
        self.env = env
        self.num_agents = env.num_agents
        self.keep_info = keep_info  # if True keep info as dict
        if len(env.observation_space.shape) > 1:  # image observation, assumed (H, W, C)
            old_shape = env.observation_space.shape
            # Advertise channel-first (C, H, W) since observations are
            # swapped to that layout in reset()/step().
            self.observation_space = gym.spaces.Box(
                low=0.0, high=1.0,
                shape=(old_shape[-1], old_shape[0], old_shape[1]),
                dtype=np.uint8)
            self.obs_type = 'rgb_image'
        else:
            self.observation_space = env.observation_space
            self.obs_type = 'ram'
        self.action_space = env.action_space
        self.observation_spaces = env.observation_spaces
        self.action_spaces = env.action_spaces
        # Both pettingzoo and slimevolley can work with this; fall back to the
        # unwrapped env when `agents` is not exposed directly.
        try:
            self.agents = env.agents
        except AttributeError:  # was a bare `except:`, which hid unrelated errors
            self.agents = env.unwrapped.agents

    @property
    def unwrapped(self):
        """Return the underlying wrapped environment."""
        return self.env

    @property
    def spec(self):
        """Expose the wrapped env's spec."""
        return self.env.spec

    def observation_swapaxis(self, observation):
        """Swap both agents' image observations from (H, W, C) to (C, H, W)."""
        return (np.swapaxes(observation[0], 2, 0), np.swapaxes(observation[1], 2, 0))

    def reset(self):
        """Reset the env; return per-agent observations as a tuple."""
        obs_dict = self.env.reset()
        if self.obs_type == 'ram':
            return tuple(obs_dict.values())
        else:
            return self.observation_swapaxis(tuple(obs_dict.values()))

    def step(self, actions):
        """Step the env with per-agent actions.

        Args:
            actions: iterable of actions ordered like ``self.agents``.

        Returns:
            Tuple ``(obs, rewards, dones, info)`` where obs is a tuple,
            rewards and dones are lists, and info is a dict when
            ``keep_info`` is True, else a list.
        """
        actions = {agent_name: action for agent_name, action in zip(self.agents, actions)}
        obs, rewards, dones, infos = self.env.step(actions)
        if self.obs_type == 'ram':
            o = tuple(obs.values())
        else:
            o = self.observation_swapaxis(tuple(obs.values()))
        r = list(rewards.values())
        d = list(dones.values())
        if self.keep_info:  # a special case for VectorEnv
            info = infos
        else:
            info = list(infos.values())
        # r = self._zerosum_filter(r)

        return o, r, d, info

    def _zerosum_filter(self, r):
        """Force a two-player reward vector to be zero-sum.

        Added for making a non-zero-sum game zero-sum, e.g. tennis_v2.
        NOTE: assumes exactly two agents (indexes ``1 - nonzero_idx``) and
        mutates ``r`` in place.
        """
        if np.sum(r) != 0:
            nonzero_idx = np.nonzero(r)[0][0]
            r[1 - nonzero_idx] = -r[nonzero_idx]
        return r

    def seed(self, seed):
        """Seed the wrapped env; also seeds the *global* NumPy RNG."""
        self.env.seed(seed)
        np.random.seed(seed)

    def render(self):
        """Render the wrapped env."""
        self.env.render()

    def close(self):
        """Close the wrapped env."""
        self.env.close()

0 commit comments

Comments
 (0)