Commit 42e3a73

fix sac discrete
1 parent 6790257 commit 42e3a73

File tree

sac_discrete.py
sac_v2.png

2 files changed: +16 −26 lines

2 files changed

+16
-26
lines changed

sac_discrete.py

Lines changed: 16 additions & 26 deletions
@@ -9,25 +9,17 @@
 '''
 
 
-import math
 import random
-
 import gym
 import numpy as np
-
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
 from torch.distributions import Categorical
-
 from IPython.display import clear_output
 import matplotlib.pyplot as plt
-from matplotlib import animation
-from IPython.display import display
-
 import argparse
-import time
 
 GPU = True
 device_idx = 0
@@ -104,7 +96,7 @@ def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_mi
 
         self.num_actions = num_actions
 
-    def forward(self, state, softmax_dim=0):
+    def forward(self, state, softmax_dim=-1):
         x = F.tanh(self.linear1(state))
         x = F.tanh(self.linear2(x))
         # x = F.tanh(self.linear3(x))
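Why softmax_dim=-1: with a batch of states the policy logits have shape [batch_size, num_actions], so normalizing over dim=0 would mix probabilities across the batch rather than across the actions of each state. A minimal sketch of the difference (tensor names here are illustrative, not from the file):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 2)                # 4 states, 2 discrete actions
probs_batch = F.softmax(logits, dim=0)    # each column sums to 1 across states (not a policy)
probs_action = F.softmax(logits, dim=-1)  # each row sums to 1 across actions (per-state policy)
assert torch.allclose(probs_action.sum(dim=-1), torch.ones(4))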
@@ -183,7 +175,7 @@ def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy
         # reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(dim=0) + 1e-6) # normalize with batch mean and std; plus a small number to prevent numerical problem
 
         # Training Q Function
-        # print((next_log_prob.exp()*self.target_soft_q_net2(next_state)).shape, next_log_prob.shape)
+        self.alpha = self.log_alpha.exp()
         target_q_min = (next_log_prob.exp() * (torch.min(self.target_soft_q_net1(next_state), self.target_soft_q_net2(next_state)) - self.alpha * next_log_prob)).sum(dim=-1).unsqueeze(-1)
         target_q_value = reward + (1 - done) * gamma * target_q_min # if done==1, only reward
         q_value_loss1 = self.soft_q_criterion1(predicted_q_value1, target_q_value.detach()) # detach: no gradients for the variable
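The target_q_min line above computes the discrete-action soft value as an exact expectation over actions rather than by sampling. A standalone sketch of the same quantity, assuming q1_next, q2_next, and next_log_prob are [batch, num_actions] tensors and reward and done are [batch, 1]; the helper name soft_q_target is not part of the file:

import torch

def soft_q_target(q1_next, q2_next, next_log_prob, reward, done, alpha, gamma=0.99):
    next_prob = next_log_prob.exp()
    # E_{a ~ pi}[ min(Q1, Q2) - alpha * log pi(a|s') ], computed in closed form over all actions
    v_next = (next_prob * (torch.min(q1_next, q2_next) - alpha * next_log_prob)).sum(dim=-1, keepdim=True)
    return reward + (1.0 - done) * gamma * v_next  # if done == 1, only the reward remains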
@@ -203,19 +195,6 @@ def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy
         self.policy_optimizer.zero_grad()
         policy_loss.backward()
         self.policy_optimizer.step()
-
-        # print('q loss: ', q_value_loss1, q_value_loss2)
-        # print('policy loss: ', policy_loss )
-
-        # Soft update the target value net
-        for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
-            target_param.data.copy_( # copy data value into target parameters
-                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
-            )
-        for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
-            target_param.data.copy_( # copy data value into target parameters
-                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
-            )
 
         # Updating alpha wrt entropy
         # alpha = 0.0 # trade-off between exploration (max entropy) and exploitation (max Q)
@@ -225,10 +204,22 @@ def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy
             self.alpha_optimizer.zero_grad()
             alpha_loss.backward()
             self.alpha_optimizer.step()
-            self.alpha = self.log_alpha.exp()
         else:
             self.alpha = 1.
             alpha_loss = 0
+
+        # print('q loss: ', q_value_loss1.item(), q_value_loss2.item())
+        # print('policy loss: ', policy_loss.item() )
+
+        # Soft update the target value net
+        for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
+            target_param.data.copy_( # copy data value into target parameters
+                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
+            )
+        for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
+            target_param.data.copy_( # copy data value into target parameters
+                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
+            )
 
         return predicted_new_q_value.mean()
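The two Polyak-averaging loops now run once at the end of update(), after the Q, policy, and alpha steps. They are equivalent to a small helper applied to each target network; this is a sketch only, and soft_update is not a name defined in the file:

def soft_update(target_net, source_net, soft_tau):
    # target <- (1 - soft_tau) * target + soft_tau * source, applied in place
    for target_param, param in zip(target_net.parameters(), source_net.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - soft_tau) + param.data * soft_tau)

# e.g. soft_update(self.target_soft_q_net1, self.soft_q_net1, soft_tau)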

@@ -266,7 +257,7 @@ def plot(rewards):
 
 # hyper-parameters for RL training
 max_episodes = 10000
-max_steps = 100
+max_steps = 200
 frame_idx = 0
 batch_size = 256
 update_itr = 1
@@ -287,7 +278,6 @@ def plot(rewards):
         state = env.reset()
         episode_reward = 0
 
-
         for step in range(max_steps):
             action = sac_trainer.policy_net.get_action(state, deterministic = DETERMINISTIC)
             next_state, reward, done, _ = env.step(action)

sac_v2.png

-64.9 KB