Commit 7879a34

add sac discrete

1 parent a7a6871 commit 7879a34
2 files changed: +333 -0 lines changed

sac_discrete.py

Lines changed: 333 additions & 0 deletions
'''
Soft Actor-Critic version 2
using target Q instead of V net: 2 Q nets, 2 target Q nets, 1 policy net
add alpha loss compared with version 1
paper: https://arxiv.org/pdf/1812.05905.pdf

Discrete version reference:
https://towardsdatascience.com/adapting-soft-actor-critic-for-discrete-action-spaces-a20614d4a50a
'''
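
# Summary of the discrete-action SAC objectives implemented in SAC_Trainer.update() below
# (notation only, no additional functionality):
#   soft value:   V(s') = sum_a pi(a|s') * ( min(Q1_tgt, Q2_tgt)(s', a) - alpha * log pi(a|s') )
#   Q target:     y = r + (1 - done) * gamma * V(s')
#   policy loss:  J_pi = E_s[ sum_a pi(a|s) * ( alpha * log pi(a|s) - min(Q1, Q2)(s, a) ) ]
#   alpha loss:   J_alpha = -E[ log_alpha * (log pi(a|s) + target_entropy) ]
# With a finite action set these expectations are computed exactly from the softmax
# probabilities instead of by sampling/reparameterization as in continuous SAC.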


import math
import random

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

from IPython.display import clear_output
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display

import argparse
import time

GPU = True
device_idx = 0
if GPU:
    device = torch.device("cuda:" + str(device_idx) if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")
print(device)


parser = argparse.ArgumentParser(description='Train or test neural net motor controller.')
parser.add_argument('--train', dest='train', action='store_true', default=False)
parser.add_argument('--test', dest='test', action='store_true', default=False)

args = parser.parse_args()


class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = int((self.position + 1) % self.capacity)  # as a ring buffer

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))  # stack for each element
        '''
        the * serves as unpack: sum(a,b) <=> batch=(a,b), sum(*batch) ;
        zip: a=[1,2], b=[2,3], zip(a,b) => [(1, 2), (2, 3)] ;
        the map serves as mapping the function on each list element: map(square, [2,3]) => [4,9] ;
        np.stack((1,2)) => array([1, 2])
        '''
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)
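
# Usage sketch (shapes for CartPole-v1 with a batch of 2; illustrative only):
#   replay_buffer.push(state, action, reward, next_state, done)   # one transition per env step
#   s, a, r, ns, d = replay_buffer.sample(2)
#   s.shape == (2, 4); a.shape == (2,); r.shape == (2,); d.shape == (2,)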

class SoftQNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super(SoftQNetwork, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        # self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.linear4 = nn.Linear(hidden_size, num_actions)

        self.linear4.weight.data.uniform_(-init_w, init_w)
        self.linear4.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = torch.tanh(self.linear1(state))
        x = torch.tanh(self.linear2(x))
        # x = torch.tanh(self.linear3(x))
        x = self.linear4(x)
        return x
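
# Note: unlike continuous SAC, where the Q-network takes a (state, action) pair, this discrete
# Q-network maps a state to a vector of Q-values, one per action, so expectations over the
# policy can be taken by summing over the action dimension.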


class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super(PolicyNetwork, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        # self.linear3 = nn.Linear(hidden_size, hidden_size)
        # self.linear4 = nn.Linear(hidden_size, hidden_size)

        self.output = nn.Linear(hidden_size, num_actions)

        self.num_actions = num_actions

    def forward(self, state, softmax_dim=0):
        x = torch.tanh(self.linear1(state))
        x = torch.tanh(self.linear2(x))
        # x = torch.tanh(self.linear3(x))
        # x = torch.tanh(self.linear4(x))

        probs = F.softmax(self.output(x), dim=softmax_dim)

        return probs

    def evaluate(self, state, epsilon=1e-6):
        '''
        return the log-probabilities of all actions for the given state wrt the policy network;
        '''
        probs = self.forward(state, softmax_dim=-1)
        log_probs = torch.log(probs + epsilon)  # epsilon keeps the log finite when a probability is exactly zero
        return log_probs

    def get_action(self, state, deterministic):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        probs = self.forward(state)
        dist = Categorical(probs)

        if deterministic:
            action = np.argmax(probs.detach().cpu().numpy())
        else:
            action = dist.sample().squeeze().detach().cpu().numpy()
        return action
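
# get_action() returns a NumPy integer compatible with env.step(): the argmax action when
# deterministic=True, otherwise a sample from the categorical distribution defined by the policy.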


class SAC_Trainer():
    def __init__(self, replay_buffer, hidden_dim):
        self.replay_buffer = replay_buffer

        self.soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.log_alpha = torch.zeros(1, dtype=torch.float32, requires_grad=True, device=device)
        self.alpha = self.log_alpha.exp()  # initialize alpha (= 1 here) so the first update() call can use it
        print('Soft Q Network (1,2): ', self.soft_q_net1)
        print('Policy Network: ', self.policy_net)

        for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
            target_param.data.copy_(param.data)

        self.soft_q_criterion1 = nn.MSELoss()
        self.soft_q_criterion2 = nn.MSELoss()

        soft_q_lr = 3e-4
        policy_lr = 3e-4
        alpha_lr = 3e-4

        self.soft_q_optimizer1 = optim.Adam(self.soft_q_net1.parameters(), lr=soft_q_lr)
        self.soft_q_optimizer2 = optim.Adam(self.soft_q_net2.parameters(), lr=soft_q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)

    def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy=-2, gamma=0.99, soft_tau=1e-2):
        state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)
        # print('sample:', state, action, reward, done)

        state = torch.FloatTensor(state).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        action = torch.Tensor(action).to(torch.int64).to(device)
        reward = torch.FloatTensor(reward).unsqueeze(1).to(device)  # reward is a single value; unsqueeze() adds one dim to make it [reward] at the sample dim
        done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device)
        predicted_q_value1 = self.soft_q_net1(state)
        predicted_q_value1 = predicted_q_value1.gather(1, action.unsqueeze(-1))
        predicted_q_value2 = self.soft_q_net2(state)
        predicted_q_value2 = predicted_q_value2.gather(1, action.unsqueeze(-1))
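        # Shapes (CartPole-v1, batch_size=256): soft_q_net(state) is [256, 2]; gather(1, action.unsqueeze(-1))
        # picks the Q-value of the action actually taken, giving [256, 1] to match target_q_value below.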
        log_prob = self.policy_net.evaluate(state)
        next_log_prob = self.policy_net.evaluate(next_state)
        # reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(dim=0) + 1e-6)  # normalize with batch mean and std; plus a small number to prevent numerical problems

        # Training Q Function
        # print((next_log_prob.exp()*self.target_soft_q_net2(next_state)).shape, next_log_prob.shape)
        target_q_min = (next_log_prob.exp() * (torch.min(self.target_soft_q_net1(next_state), self.target_soft_q_net2(next_state)) - self.alpha * next_log_prob)).sum(dim=-1).unsqueeze(-1)
        target_q_value = reward + (1 - done) * gamma * target_q_min  # if done==1, only reward
        q_value_loss1 = self.soft_q_criterion1(predicted_q_value1, target_q_value.detach())  # detach: no gradients for the variable
        q_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value.detach())

        self.soft_q_optimizer1.zero_grad()
        q_value_loss1.backward()
        self.soft_q_optimizer1.step()
        self.soft_q_optimizer2.zero_grad()
        q_value_loss2.backward()
        self.soft_q_optimizer2.step()

        # Training Policy Function
        predicted_new_q_value = torch.min(self.soft_q_net1(state), self.soft_q_net2(state))
        policy_loss = (log_prob.exp() * (self.alpha * log_prob - predicted_new_q_value)).sum(dim=-1).mean()
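        # log_prob.exp() recovers the action probabilities, so this is the exact expectation
        # E_{a~pi}[ alpha * log pi(a|s) - min(Q1, Q2)(s, a) ], minimized wrt the policy parameters.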

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # print('q loss: ', q_value_loss1, q_value_loss2)
        # print('policy loss: ', policy_loss)

        # Soft update the target value net
        for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
            target_param.data.copy_(  # copy data value into target parameters
                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
            )
        for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
            target_param.data.copy_(  # copy data value into target parameters
                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
            )
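        # Polyak averaging: theta_target <- (1 - soft_tau) * theta_target + soft_tau * theta,
        # with soft_tau=1e-2, so the target Q-networks track the online networks slowly.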

        # Updating alpha wrt entropy
        # alpha = 0.0  # trade-off between exploration (max entropy) and exploitation (max Q)
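        # When auto_entropy is enabled, alpha is learned so that the policy entropy is pushed
        # towards target_entropy; otherwise alpha is held fixed at 1.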
        if auto_entropy is True:
            alpha_loss = -(self.log_alpha * (log_prob + target_entropy).detach()).mean()
            # print('alpha loss: ', alpha_loss)
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.alpha = self.log_alpha.exp()
        else:
            self.alpha = 1.
            alpha_loss = 0

        return predicted_new_q_value.mean()

    def save_model(self, path):
        torch.save(self.soft_q_net1.state_dict(), path + '_q1')
        torch.save(self.soft_q_net2.state_dict(), path + '_q2')
        torch.save(self.policy_net.state_dict(), path + '_policy')

    def load_model(self, path):
        self.soft_q_net1.load_state_dict(torch.load(path + '_q1'))
        self.soft_q_net2.load_state_dict(torch.load(path + '_q2'))
        self.policy_net.load_state_dict(torch.load(path + '_policy'))

        self.soft_q_net1.eval()
        self.soft_q_net2.eval()
        self.policy_net.eval()


def plot(rewards):
    clear_output(True)
    plt.figure(figsize=(20, 5))
    plt.plot(rewards)
    plt.savefig('sac_v2.png')
    # plt.show()


replay_buffer_size = 1e6
replay_buffer = ReplayBuffer(replay_buffer_size)

# choose env
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n  # discrete

# hyper-parameters for RL training
max_episodes = 10000
max_steps = 100
frame_idx = 0
batch_size = 256
update_itr = 1
AUTO_ENTROPY = False
DETERMINISTIC = False
hidden_dim = 64
rewards = []
model_path = './model/sac_discrete_v2'  # note: the './model' directory must exist before save_model() is called
# target_entropy = -1. * action_dim
target_entropy = 0.98 * -np.log(1 / action_dim)
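# 0.98 * -log(1/action_dim) is 98% of the maximum possible policy entropy log(action_dim);
# it only takes effect when update() is called with auto_entropy=True (AUTO_ENTROPY is False
# here, so alpha stays fixed at 1 and this target is unused).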

sac_trainer = SAC_Trainer(replay_buffer, hidden_dim=hidden_dim)

if __name__ == '__main__':
    if args.train:
        # training loop
        for eps in range(max_episodes):
            state = env.reset()
            episode_reward = 0

            for step in range(max_steps):
                action = sac_trainer.policy_net.get_action(state, deterministic=DETERMINISTIC)
                next_state, reward, done, _ = env.step(action)
                # env.render()

                replay_buffer.push(state, action, reward, next_state, done)

                state = next_state
                episode_reward += reward
                frame_idx += 1

                if len(replay_buffer) > batch_size:
                    for i in range(update_itr):
                        _ = sac_trainer.update(batch_size, reward_scale=1., auto_entropy=AUTO_ENTROPY, target_entropy=target_entropy)
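                # i.e. one gradient update per environment step (update_itr=1) once the buffer
                # holds more than batch_size transitions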

                if done:
                    break

            if eps % 20 == 0 and eps > 0:  # plot and model saving interval
                plot(rewards)
                np.save('rewards', rewards)
                sac_trainer.save_model(model_path)
            print('Episode: ', eps, '| Episode Reward: ', episode_reward, '| Episode Length: ', step)
            rewards.append(episode_reward)
        sac_trainer.save_model(model_path)

    if args.test:
        sac_trainer.load_model(model_path)
        for eps in range(10):
            state = env.reset()
            episode_reward = 0

            for step in range(max_steps):
                action = sac_trainer.policy_net.get_action(state, deterministic=DETERMINISTIC)
                next_state, reward, done, _ = env.step(action)
                env.render()

                episode_reward += reward
                state = next_state

            print('Episode: ', eps, '| Episode Reward: ', episode_reward)

sac_v2.png

115 KB

0 commit comments
