
Commit 6174295

fix qmix
1 parent e5d72ab commit 6174295

1 file changed

qmix.py

Lines changed: 16 additions & 13 deletions
@@ -178,10 +178,10 @@ def get_action(self, state, last_action, hidden_in, deterministic=False):
         @brief:
             for each distributed agent, generate action for one step given input data
         @params:
-            state: [#batch, #feature*action_shape]
-            action: [#batch, action_shape]
+            state: [n_agents, n_feature]
+            last_action: [n_agents, action_shape]
         '''
-        state = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(device)  # add #sequence and #batch print(last_action.shape)
+        state = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(device)  # add #sequence and #batch: [#batch, #sequence, n_agents, n_feature]
         last_action = torch.LongTensor(
             last_action).unsqueeze(0).unsqueeze(0).to(device)  # add #sequence and #batch: [#batch, #sequence, n_agents, action_shape]
         hidden_in = hidden_in.unsqueeze(1)  # add #batch: [#batch, n_agents, hidden_dim]
@@ -192,7 +192,7 @@ def get_action(self, state, last_action, hidden_in, deterministic=False):
             action = np.argmax(agent_outs.detach().cpu().numpy(), axis=-1)
         else:
             action = dist.sample().squeeze(0).squeeze(0).detach().cpu().numpy()  # squeeze the added #batch and #sequence dimension
-        return action, hidden_out
+        return action, hidden_out  # [n_agents, action_shape]

 class QMix(nn.Module):
     def __init__(self, state_dim, n_agents, action_shape, embed_dim=64, hypernet_embed=128, abs=True):
@@ -279,12 +279,13 @@ def b(self, states):


 class QMix_Trainer():
-    def __init__(self, replay_buffer, n_agents, state_dim, action_shape, action_dim, hidden_dim, hypernet_dim, lr=0.001, logger=None):
+    def __init__(self, replay_buffer, n_agents, state_dim, action_shape, action_dim, hidden_dim, hypernet_dim, target_update_interval, lr=0.001, logger=None):
         self.replay_buffer = replay_buffer

         self.action_dim = action_dim
         self.action_shape = action_shape
         self.n_agents = n_agents
+        self.target_update_interval = target_update_interval
         self.agent = RNNAgent(state_dim, action_shape,
                               action_dim, hidden_dim).to(device)
         self.target_agent = RNNAgent(
@@ -296,6 +297,7 @@ def __init__(self, replay_buffer, n_agents, state_dim, action_shape, action_dim,
             hidden_dim, hypernet_dim).to(device)

         self._update_targets()
+        self.update_cnt = 0

         self.criterion = nn.MSELoss()

@@ -336,15 +338,14 @@ def update(self, batch_size):
         next_state = torch.FloatTensor(next_state).to(device)
         action = torch.LongTensor(action).to(device)  # [#batch, #sequence, #agents, #action_shape]
         last_action = torch.LongTensor(last_action).to(device)
-        # reward is scalar, add 1 dim to be [reward] at the same dim
-        reward = torch.FloatTensor(reward).unsqueeze(-1).to(device)
+        reward = torch.FloatTensor(reward).unsqueeze(-1).to(device)  # reward is scalar, add 1 dim to be [reward] at the same dim

         agent_outs, _ = self.agent(state, last_action, hidden_in)  # [#batch, #sequence, #agent, action_shape, num_actions]

         chosen_action_qvals = torch.gather(  # [#batch, #sequence, #agent, action_shape]
             agent_outs, dim=-1, index=action.unsqueeze(-1)).squeeze(-1)

-        qtot = self.mixer(chosen_action_qvals, state)
+        qtot = self.mixer(chosen_action_qvals, state)  # [#batch, #sequence, 1]

         # target q
         target_agent_outs, _ = self.target_agent(next_state, action, hidden_out)
@@ -359,11 +360,16 @@ def update(self, batch_size):
         loss.backward()
         self.optimizer.step()

+        self.update_cnt += 1
+        if self.update_cnt % self.target_update_interval == 0:
+            self._update_targets()
+
         return loss.item()

     def _build_td_lambda_targets(self, rewards, target_qs, gamma=0.99, td_lambda=0.6):
         '''
         @params:
+            rewards: [#batch, #sequence, 1]
             target_qs: [#batch, #sequence, 1]
         '''
         ret = target_qs.new_zeros(*target_qs.shape)
@@ -400,6 +406,7 @@ def load_model(self, path):
 update_iter = 1
 batch_size = 2
 save_interval = 10
+target_update_interval = 10
 model_path = 'model/qmix'

 env = entombed_cooperative_v2  # this is not a valid env, reward seems to be zero-sum; for QMIX we need same reward for all agents
@@ -412,7 +419,7 @@ def load_model(self, path):
 print(state_dim, action_dim, n_agents)

 replay_buffer = ReplayBufferGRU(replay_buffer_size)
-learner = QMix_Trainer(replay_buffer, n_agents, state_dim, action_shape, action_dim, hidden_dim, hypernet_dim)
+learner = QMix_Trainer(replay_buffer, n_agents, state_dim, action_shape, action_dim, hidden_dim, hypernet_dim, target_update_interval)

 loss = None

@@ -448,10 +455,6 @@ def load_model(self, path):

         state = next_state
         last_action = action
-        # print("episode_state shape {}".format(np.array(episode_state).shape))
-        # print("episode_action shape {}".format(np.array(episode_action).shape))
-        # print("episode_last_action {}".format(np.array(episode_last_action)))
-        # print("episode_last_action shape {}".format(np.array(episode_last_action).shape))

         # break the episode
         if np.any(done):
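
For context on the change itself: the commit wires a periodic hard target-network update into the trainer (update_cnt, target_update_interval, _update_targets). The pattern amounts to the minimal sketch below; SimpleTrainer and OnlineNet are illustrative stand-ins rather than the classes in qmix.py, and _update_targets is assumed to do a hard parameter copy.

# Minimal sketch of the periodic hard target update this commit introduces.
# SimpleTrainer/OnlineNet are hypothetical stand-ins, not classes from qmix.py.
import copy

import torch
import torch.nn as nn


class OnlineNet(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return self.fc(x)


class SimpleTrainer:
    def __init__(self, target_update_interval=10):
        self.net = OnlineNet()
        self.target_net = copy.deepcopy(self.net)    # start both networks in sync
        self.target_update_interval = target_update_interval
        self.update_cnt = 0                          # counts calls to update()

    def _update_targets(self):
        # hard copy: target parameters <- online parameters
        self.target_net.load_state_dict(self.net.state_dict())

    def update(self):
        # ... loss computation, backward pass and optimizer.step() go here ...
        self.update_cnt += 1
        if self.update_cnt % self.target_update_interval == 0:
            self._update_targets()

Counting gradient updates rather than environment steps keeps the target network frozen for a fixed number of learner iterations, the usual way to stabilise bootstrapped Q targets.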
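The new docstring line also pins down the shapes fed to _build_td_lambda_targets. As a point of reference, the common PyMARL-style backward recursion for those shapes looks roughly like the sketch below; it omits the mask/termination handling a full implementation needs and is not necessarily identical to the code in qmix.py.

# Sketch of TD(lambda) return targets for [#batch, #sequence, 1] tensors.
# Standard backward recursion; an assumption, not a copy of qmix.py's method.
import torch


def build_td_lambda_targets(rewards, target_qs, gamma=0.99, td_lambda=0.6):
    # rewards:   [#batch, #sequence, 1]
    # target_qs: [#batch, #sequence, 1]
    ret = target_qs.new_zeros(*target_qs.shape)
    ret[:, -1] = target_qs[:, -1]  # bootstrap from the final target Q-value
    # walk backwards, blending n-step returns with weight td_lambda
    for t in reversed(range(ret.shape[1] - 1)):
        ret[:, t] = td_lambda * gamma * ret[:, t + 1] + (
            rewards[:, t] + (1 - td_lambda) * gamma * target_qs[:, t + 1]
        )
    return ret


# example usage with dummy tensors
targets = build_td_lambda_targets(torch.zeros(2, 5, 1), torch.ones(2, 5, 1))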
