@@ -178,10 +178,10 @@ def get_action(self, state, last_action, hidden_in, deterministic=False):
         @brief:
             for each distributed agent, generate action for one step given input data
         @params:
-            state: [#batch, #feature*action_shape]
-            action: [#batch, action_shape]
+            state: [n_agents, n_feature]
+            last_action: [n_agents, action_shape]
         '''
-        state = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(device)  # add #sequence and #batch print(last_action.shape)
+        state = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(device)  # add #sequence and #batch: [#batch, #sequence, n_agents, n_feature]
         last_action = torch.LongTensor(
             last_action).unsqueeze(0).unsqueeze(0).to(device)  # add #sequence and #batch: [#batch, #sequence, n_agents, action_shape]
         hidden_in = hidden_in.unsqueeze(1)  # add #batch: [#batch, n_agents, hidden_dim]
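For reference, the two unsqueeze(0) calls above turn a per-step observation of shape [n_agents, n_feature] into [1, 1, n_agents, n_feature], i.e. a singleton batch and sequence dimension for the recurrent agent. A standalone shape check, not part of the commit; n_agents=2 and n_feature=4 are arbitrary:

    import torch

    n_agents, n_feature = 2, 4                 # arbitrary sizes for the check
    state = torch.zeros(n_agents, n_feature)   # raw per-step observation
    batched = state.unsqueeze(0).unsqueeze(0)  # add #batch and #sequence dims
    print(batched.shape)                       # torch.Size([1, 1, 2, 4])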
@@ -192,7 +192,7 @@ def get_action(self, state, last_action, hidden_in, deterministic=False):
             action = np.argmax(agent_outs.detach().cpu().numpy(), axis=-1)
         else:
             action = dist.sample().squeeze(0).squeeze(0).detach().cpu().numpy()  # squeeze the added #batch and #sequence dimension
-        return action, hidden_out
+        return action, hidden_out  # [n_agents, action_shape]
 
 class QMix(nn.Module):
     def __init__(self, state_dim, n_agents, action_shape, embed_dim=64, hypernet_embed=128, abs=True):
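The two branches above differ only in how the action index is chosen: argmax over the last (num_actions) dimension for the deterministic path, a draw from dist for the stochastic path. A sketch of the same pattern, assuming dist is a torch.distributions.Categorical built from agent_outs (its construction is not shown in this hunk) and using arbitrary sizes:

    import torch
    from torch.distributions import Categorical

    # illustrative shapes only: 2 agents, 1 action head, 6 discrete actions
    agent_outs = torch.randn(1, 1, 2, 1, 6)        # [#batch, #sequence, n_agents, action_shape, num_actions]
    dist = Categorical(logits=agent_outs)          # assumption: not how the repo necessarily builds dist

    greedy = agent_outs.argmax(dim=-1)             # deterministic: index of the max Q-value, [1, 1, 2, 1]
    sampled = dist.sample().squeeze(0).squeeze(0)  # stochastic: drop #batch and #sequence -> [2, 1]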
@@ -279,12 +279,13 @@ def b(self, states):
 
 
 class QMix_Trainer():
-    def __init__(self, replay_buffer, n_agents, state_dim, action_shape, action_dim, hidden_dim, hypernet_dim, lr=0.001, logger=None):
+    def __init__(self, replay_buffer, n_agents, state_dim, action_shape, action_dim, hidden_dim, hypernet_dim, target_update_interval, lr=0.001, logger=None):
         self.replay_buffer = replay_buffer
 
         self.action_dim = action_dim
         self.action_shape = action_shape
         self.n_agents = n_agents
+        self.target_update_interval = target_update_interval
         self.agent = RNNAgent(state_dim, action_shape,
                               action_dim, hidden_dim).to(device)
         self.target_agent = RNNAgent(
@@ -296,6 +297,7 @@ def __init__(self, replay_buffer, n_agents, state_dim, action_shape, action_dim,
             hidden_dim, hypernet_dim).to(device)
 
         self._update_targets()
+        self.update_cnt = 0
 
         self.criterion = nn.MSELoss()
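The _update_targets() call here, and the periodic re-sync added later in update(), is the usual hard target-network copy. Its body is not part of this diff; a typical sketch of the idea, with hypothetical names mirroring the trainer's attributes:

    import torch.nn as nn

    def hard_update(target: nn.Module, source: nn.Module) -> None:
        # copy the online network's weights into the target network in one shot
        target.load_state_dict(source.state_dict())

    # usage inside the trainer (hypothetical; assumes a target mixer exists, which the diff does not show):
    #   hard_update(self.target_agent, self.agent)
    #   hard_update(self.target_mixer, self.mixer)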
@@ -336,15 +338,14 @@ def update(self, batch_size):
         next_state = torch.FloatTensor(next_state).to(device)
         action = torch.LongTensor(action).to(device)  # [#batch, #sequence, #agents, #action_shape]
         last_action = torch.LongTensor(last_action).to(device)
-        # reward is scalar, add 1 dim to be [reward] at the same dim
-        reward = torch.FloatTensor(reward).unsqueeze(-1).to(device)
+        reward = torch.FloatTensor(reward).unsqueeze(-1).to(device)  # reward is scalar, add 1 dim to be [reward] at the same dim
 
         agent_outs, _ = self.agent(state, last_action, hidden_in)  # [#batch, #sequence, #agent, action_shape, num_actions]
 
         chosen_action_qvals = torch.gather(  # [#batch, #sequence, #agent, action_shape]
             agent_outs, dim=-1, index=action.unsqueeze(-1)).squeeze(-1)
 
-        qtot = self.mixer(chosen_action_qvals, state)
+        qtot = self.mixer(chosen_action_qvals, state)  # [#batch, #sequence, 1]
 
         # target q
         target_agent_outs, _ = self.target_agent(next_state, action, hidden_out)
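The torch.gather call above picks, for each agent, the Q-value of the action that was actually taken by indexing the last (num_actions) dimension with the stored action indices. A tiny self-contained example of the same pattern, not part of the commit, with shapes chosen arbitrarily:

    import torch

    q = torch.tensor([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])        # [n_agents, num_actions]
    a = torch.tensor([2, 0])                   # chosen action index per agent
    chosen = torch.gather(q, dim=-1, index=a.unsqueeze(-1)).squeeze(-1)
    print(chosen)                              # tensor([3., 4.])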
@@ -359,11 +360,16 @@ def update(self, batch_size):
         loss.backward()
         self.optimizer.step()
 
+        self.update_cnt += 1
+        if self.update_cnt % self.target_update_interval == 0:
+            self._update_targets()
+
         return loss.item()
 
     def _build_td_lambda_targets(self, rewards, target_qs, gamma=0.99, td_lambda=0.6):
         '''
         @params:
+            rewards: [#batch, #sequence, 1]
             target_qs: [#batch, #sequence, 1]
         '''
         ret = target_qs.new_zeros(*target_qs.shape)
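The newly documented rewards tensor feeds the TD(lambda) backward recursion G_t = r_t + gamma * ((1 - lambda) * Q_target_{t+1} + lambda * G_{t+1}). A minimal sketch of that recursion, assuming no terminal masking (only rewards and target_qs appear in the docstring); the body in the repository may differ:

    import torch

    def td_lambda_targets(rewards, target_qs, gamma=0.99, td_lambda=0.6):
        # rewards, target_qs: [#batch, #sequence, 1]
        ret = target_qs.new_zeros(*target_qs.shape)
        ret[:, -1] = target_qs[:, -1]          # bootstrap from the final step
        for t in range(rewards.shape[1] - 2, -1, -1):
            ret[:, t] = rewards[:, t] + gamma * (
                (1 - td_lambda) * target_qs[:, t + 1] + td_lambda * ret[:, t + 1])
        return ret

    # e.g. td_lambda_targets(torch.zeros(1, 5, 1), torch.ones(1, 5, 1)) returns a [1, 5, 1] tensor of targets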
@@ -400,6 +406,7 @@ def load_model(self, path):
 update_iter = 1
 batch_size = 2
 save_interval = 10
+target_update_interval = 10
 model_path = 'model/qmix'
 
 env = entombed_cooperative_v2  # this is not a valid env, reward seems to be zero-sum; for QMIX we need same reward for all agents
@@ -412,7 +419,7 @@ def load_model(self, path):
 print(state_dim, action_dim, n_agents)
 
 replay_buffer = ReplayBufferGRU(replay_buffer_size)
-learner = QMix_Trainer(replay_buffer, n_agents, state_dim, action_shape, action_dim, hidden_dim, hypernet_dim)
+learner = QMix_Trainer(replay_buffer, n_agents, state_dim, action_shape, action_dim, hidden_dim, hypernet_dim, target_update_interval)
 
 loss = None
@@ -448,10 +455,6 @@ def load_model(self, path):
 
         state = next_state
         last_action = action
-        # print("episode_state shape {}".format(np.array(episode_state).shape))
-        # print("episode_action shape {}".format(np.array(episode_action).shape))
-        # print("episode_last_action {}".format(np.array(episode_last_action)))
-        # print("episode_last_action shape {}".format(np.array(episode_last_action).shape))
 
         # break the episode
         if np.any(done):