@@ -106,12 +106,17 @@ def forward(self, state, softmax_dim=-1):

        return probs

-    def evaluate(self, state, epsilon=1e-6):
+    def evaluate(self, state, epsilon=1e-8):
        '''
        generate sampled action with state as input wrt the policy network;
        '''
        probs = self.forward(state, softmax_dim=-1)
        log_probs = torch.log(probs)
+
+        # Avoid numerical instability. Ref: https://github.com/ku2482/sac-discrete.pytorch/blob/40c9d246621e658750e0a03001325006da57f2d4/sacd/model.py#L98
+        z = (probs == 0.0).float() * epsilon
+        log_probs = torch.log(probs + z)
+
        return log_probs

    def get_action(self, state, deterministic):
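
Note (not part of the diff): below is a minimal standalone sketch of the zero-probability masking used in evaluate() above, assuming only torch; the tensor values are illustrative.

import torch

# log(0) is -inf, which turns the entropy term into NaN downstream, so entries
# that are exactly zero get bumped by a tiny epsilon before the log is taken.
probs = torch.tensor([0.7, 0.3, 0.0])   # illustrative action distribution with a zero entry
epsilon = 1e-8
z = (probs == 0.0).float() * epsilon    # epsilon only where probs is exactly 0
log_probs = torch.log(probs + z)        # finite everywhere: tensor([-0.3567, -1.2040, -18.4207])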
@@ -171,7 +176,8 @@ def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy
        predicted_q_value2 = self.soft_q_net2(state)
        predicted_q_value2 = predicted_q_value2.gather(1, action.unsqueeze(-1))
        log_prob = self.policy_net.evaluate(state)
-        next_log_prob = self.policy_net.evaluate(next_state)
+        with torch.no_grad():
+            next_log_prob = self.policy_net.evaluate(next_state)
        # reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(dim=0) + 1e-6) # normalize with batch mean and std; plus a small number to prevent numerical problem

        # Training Q Function
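
Note (not part of the diff): wrapping the next-state evaluation in torch.no_grad() keeps target-side quantities out of the autograd graph, so the soft Bellman target is treated as a constant. A minimal sketch of the effect, using a stand-in linear layer rather than the repo's actual policy network:

import torch

policy_net = torch.nn.Linear(4, 2)      # stand-in for the policy network
next_state = torch.randn(8, 4)          # illustrative batch of next states

with torch.no_grad():
    next_log_prob = torch.log_softmax(policy_net(next_state), dim=-1)

print(next_log_prob.requires_grad)      # False: no gradients flow back from the target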
@@ -189,8 +195,12 @@ def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy
        self.soft_q_optimizer2.step()

        # Training Policy Function
-        predicted_new_q_value = torch.min(self.soft_q_net1(state), self.soft_q_net2(state))
+        with torch.no_grad():
+            predicted_new_q_value = torch.min(self.soft_q_net1(state), self.soft_q_net2(state))
        policy_loss = (log_prob.exp() * (self.alpha * log_prob - predicted_new_q_value)).sum(dim=-1).mean()
+        if torch.isnan(policy_loss):
+            print(log_prob, predicted_new_q_value, state)
+            print('q: ', q_value_loss1, q_value_loss2, target_q_value, target_q_min, next_log_prob, predicted_q_value1, predicted_q_value2)

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
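
Note (not part of the diff): with discrete actions the policy loss above takes the expectation over actions in closed form by weighting with the action probabilities (log_prob.exp()), while the Q-values are held under no_grad so only the policy receives gradients. A self-contained sketch with illustrative shapes and alpha:

import torch

batch_size, n_actions, alpha = 8, 4, 0.2
logits = torch.randn(batch_size, n_actions, requires_grad=True)   # stand-in policy output
log_prob = torch.log_softmax(logits, dim=-1)
with torch.no_grad():                                              # Q-values act as a fixed signal here
    q_value = torch.randn(batch_size, n_actions)

# E_{a~pi}[ alpha * log pi(a|s) - Q(s,a) ], averaged over the batch
policy_loss = (log_prob.exp() * (alpha * log_prob - q_value)).sum(dim=-1).mean()
policy_loss.backward()                                             # gradients reach only the policy logits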