
Commit 2c23124

solve numerically unstable case in discrete SAC

1 parent 095a331

File tree

4 files changed: +19, -4 lines
sac.png
sac_discrete.py
sac_discrete_per.py
sac_v2.png

sac.png

Binary file changed (47.3 KB)

sac_discrete.py

Lines changed: 13 additions & 3 deletions
@@ -106,12 +106,17 @@ def forward(self, state, softmax_dim=-1):
 
         return probs
 
-    def evaluate(self, state, epsilon=1e-6):
+    def evaluate(self, state, epsilon=1e-8):
         '''
         generate sampled action with state as input wrt the policy network;
         '''
         probs = self.forward(state, softmax_dim=-1)
         log_probs = torch.log(probs)
+
+        # Avoid numerical instability. Ref: https://github.com/ku2482/sac-discrete.pytorch/blob/40c9d246621e658750e0a03001325006da57f2d4/sacd/model.py#L98
+        z = (probs == 0.0).float() * epsilon
+        log_probs = torch.log(probs + z)
+
         return log_probs
 
     def get_action(self, state, deterministic):
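Why this matters: with a softmax policy, some action probabilities can underflow to exactly 0, so torch.log(probs) returns -inf; in the policy loss further down, that -inf is multiplied by log_prob.exp() == 0, which yields NaN. The new z term adds epsilon only to the entries that are exactly zero, leaving all other log-probabilities untouched. A minimal standalone sketch of the trick on a toy tensor (illustration only, not repository code):

    # Illustration only: reproduce and fix the log(0) case on a toy distribution.
    import torch

    probs = torch.tensor([0.7, 0.3, 0.0])      # softmax output with an exactly-zero entry
    print(torch.log(probs))                    # last entry is -inf; 0 * -inf later becomes NaN

    epsilon = 1e-8
    z = (probs == 0.0).float() * epsilon       # epsilon only where the probability is exactly zero
    log_probs = torch.log(probs + z)           # finite everywhere; non-zero entries unchanged
    print(log_probs)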
@@ -171,7 +176,8 @@ def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy
         predicted_q_value2 = self.soft_q_net2(state)
         predicted_q_value2 = predicted_q_value2.gather(1, action.unsqueeze(-1))
         log_prob = self.policy_net.evaluate(state)
-        next_log_prob = self.policy_net.evaluate(next_state)
+        with torch.no_grad():
+            next_log_prob = self.policy_net.evaluate(next_state)
         # reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(dim=0) + 1e-6) # normalize with batch mean and std; plus a small number to prevent numerical problem
 
         # Training Q Function
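Here next_log_prob only feeds the Bellman target of the Q-function update, so it should not carry a gradient path back into the policy network. A small self-contained toy (a stand-in linear layer, not repository code) showing the effect of the no_grad wrapper:

    # Illustration only: evaluating the next state under no_grad detaches the target
    # from the policy network's computation graph.
    import torch
    import torch.nn as nn

    policy_net = nn.Linear(4, 3)            # stand-in for the discrete policy head
    next_state = torch.randn(8, 4)

    with torch.no_grad():
        next_log_prob = torch.log_softmax(policy_net(next_state), dim=-1)

    print(next_log_prob.requires_grad)      # False: a Q-loss built from this target
                                            # cannot backpropagate into policy_net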
@@ -189,8 +195,12 @@ def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy
         self.soft_q_optimizer2.step()
 
         # Training Policy Function
-        predicted_new_q_value = torch.min(self.soft_q_net1(state),self.soft_q_net2(state))
+        with torch.no_grad():
+            predicted_new_q_value = torch.min(self.soft_q_net1(state),self.soft_q_net2(state))
         policy_loss = (log_prob.exp()*(self.alpha * log_prob - predicted_new_q_value)).sum(dim=-1).mean()
+        if torch.isnan(policy_loss):
+            print(log_prob, predicted_new_q_value, state)
+            print('q: ', q_value_loss1, q_value_loss2, target_q_value, target_q_min, next_log_prob, predicted_q_value1, predicted_q_value2)
 
         self.policy_optimizer.zero_grad()
         policy_loss.backward()
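The same reasoning applies to the policy update: the clipped double-Q estimate only weights the policy objective, so it is computed under no_grad, and the NaN check prints the intermediate tensors if the loss still blows up. A self-contained toy (stand-in linear networks and an assumed alpha value, not repository code) showing that the wrapper confines the gradient to the policy:

    # Illustration only: with the Q estimate detached, policy_loss.backward()
    # updates the policy parameters but leaves the Q-networks untouched.
    import torch
    import torch.nn as nn

    n_states, n_actions = 4, 3
    policy_net  = nn.Linear(n_states, n_actions)   # stand-in policy head
    soft_q_net1 = nn.Linear(n_states, n_actions)   # stand-in Q-networks
    soft_q_net2 = nn.Linear(n_states, n_actions)
    alpha = 0.2                                    # assumed entropy temperature
    state = torch.randn(8, n_states)

    log_prob = torch.log_softmax(policy_net(state), dim=-1)
    with torch.no_grad():
        predicted_new_q_value = torch.min(soft_q_net1(state), soft_q_net2(state))

    # Expectation over actions is taken explicitly via the action probabilities log_prob.exp().
    policy_loss = (log_prob.exp() * (alpha * log_prob - predicted_new_q_value)).sum(dim=-1).mean()
    policy_loss.backward()

    print(policy_net.weight.grad is None)          # False: the policy receives gradients
    print(soft_q_net1.weight.grad is None)         # True: the Q-networks do not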

sac_discrete_per.py

Lines changed: 6 additions & 1 deletion
@@ -80,12 +80,17 @@ def forward(self, state, softmax_dim=-1):
 
         return probs
 
-    def evaluate(self, state, epsilon=1e-6):
+    def evaluate(self, state, epsilon=1e-8):
         '''
         generate sampled action with state as input wrt the policy network;
         '''
         probs = self.forward(state, softmax_dim=-1)
         log_probs = torch.log(probs)
+
+        # Avoid numerical instability. Ref: https://github.com/ku2482/sac-discrete.pytorch/blob/40c9d246621e658750e0a03001325006da57f2d4/sacd/model.py#L98
+        z = (probs == 0.0).float() * epsilon
+        log_probs = torch.log(probs + z)
+
         return log_probs
 
     def get_action(self, state, deterministic):

sac_v2.png

Binary file changed (13.5 KB)
