
Commit f7da32f

add per buffer
1 parent c7e3300 commit f7da32f

4 files changed: +350 −5 lines changed
Binary file (1.77 KB) not shown.

common/buffers.py

Lines changed: 50 additions & 5 deletions
@@ -36,6 +36,56 @@ def __len__(
    def get_length(self):
        return len(self.buffer)

class ReplayBufferPER:
    """
    Replay buffer with Prioritized Experience Replay (PER),
    using the TD error as the sampling weight. This is a simple version without a sum tree.

    Reference:
    https://github.com/Felhof/DiscreteSAC/blob/main/utilities/ReplayBuffer.py
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
        self.weights = np.zeros(int(capacity))
        self.max_weight = 10**-2
        self.delta = 10**-4

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.weights[self.position] = self.max_weight  # a new sample gets the current maximum weight

        self.position = int((self.position + 1) % self.capacity)  # used as a ring buffer

    def sample(self, batch_size):
        set_weights = self.weights[:self.position] + self.delta
        probabilities = set_weights / sum(set_weights)
        self.indices = np.random.choice(range(self.position), batch_size, p=probabilities, replace=False)
        batch = np.array(self.buffer)[list(self.indices)]
        state, action, reward, next_state, done = map(np.stack,
                                                      zip(*batch))  # stack for each element
        '''
        the * serves as unpack: sum(a, b) <=> batch=(a, b), sum(*batch);
        zip: a=[1,2], b=[2,3], zip(a, b) => [(1, 2), (2, 3)];
        map applies the function to each element: map(square, [2, 3]) => [4, 9];
        np.stack((1, 2)) => array([1, 2])
        '''
        return state, action, reward, next_state, done

    def update_weights(self, prediction_errors):
        max_error = max(prediction_errors)
        self.max_weight = max(self.max_weight, max_error)
        self.weights[self.indices] = prediction_errors

    def __len__(
            self):  # note: len(replay_buffer) is not available through a multiprocessing manager proxy; use get_length() there instead
        return len(self.buffer)

    def get_length(self):
        return len(self.buffer)

class ReplayBufferLSTM:
    """
@@ -73,9 +123,6 @@ def __len__(
    def get_length(self):
        return len(self.buffer)

class ReplayBufferLSTM2:
    """
    Replay buffer for agent with LSTM network additionally storing previous action,
@@ -128,8 +175,6 @@ def get_length(self):
        return len(self.buffer)


class ReplayBufferGRU:
    """
    Replay buffer for agent with GRU network additionally storing previous action,
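
As a quick sanity check of the new class above, here is a minimal usage sketch of ReplayBufferPER outside the SAC trainer. The capacity, batch size, scalar toy transitions, and the random placeholder TD errors are illustrative assumptions, not values from this commit; only push, sample, and update_weights come from the class itself.

import numpy as np
from common.buffers import ReplayBufferPER

buffer = ReplayBufferPER(capacity=1000)

# fill with toy 1-D transitions; each new sample enters with the current max weight
for _ in range(64):
    s, a, r, s_next, d = np.random.rand(), np.random.randint(2), np.random.rand(), np.random.rand(), False
    buffer.push(s, a, r, s_next, d)

# sampling probability is proportional to (weight + delta); the sampled indices are cached in buffer.indices
state, action, reward, next_state, done = buffer.sample(batch_size=8)

# after computing TD errors for this batch, write them back as the new priorities
td_errors = np.abs(np.random.randn(8))  # placeholder for real TD errors
buffer.update_weights(td_errors)
print(buffer.weights[buffer.indices])   # these transitions are now more (or less) likely to be replayed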

sac_discrete_per.py

Lines changed: 300 additions & 0 deletions
@@ -0,0 +1,300 @@
'''
Soft Actor-Critic version 2
using target Q instead of V net: 2 Q nets, 2 target Q nets, 1 policy net
adds the alpha loss compared with version 1
paper: https://arxiv.org/pdf/1812.05905.pdf

Discrete version reference:
https://towardsdatascience.com/adapting-soft-actor-critic-for-discrete-action-spaces-a20614d4a50a
'''


import random
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from IPython.display import clear_output
import matplotlib.pyplot as plt
import argparse
from common.buffers import ReplayBufferPER

GPU = True
device_idx = 0
if GPU:
    device = torch.device("cuda:" + str(device_idx) if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")
print(device)


parser = argparse.ArgumentParser(description='Train or test neural net motor controller.')
parser.add_argument('--train', dest='train', action='store_true', default=False)
parser.add_argument('--test', dest='test', action='store_true', default=False)

args = parser.parse_args()

class SoftQNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3):
        super(SoftQNetwork, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        # self.linear3 = nn.Linear(hidden_size, hidden_size)
        self.linear4 = nn.Linear(hidden_size, num_actions)

        self.linear4.weight.data.uniform_(-init_w, init_w)
        self.linear4.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.tanh(self.linear1(state))
        x = F.tanh(self.linear2(x))
        # x = F.tanh(self.linear3(x))
        x = self.linear4(x)
        return x


class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super(PolicyNetwork, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        # self.linear3 = nn.Linear(hidden_size, hidden_size)
        # self.linear4 = nn.Linear(hidden_size, hidden_size)

        self.output = nn.Linear(hidden_size, num_actions)

        self.num_actions = num_actions

    def forward(self, state, softmax_dim=-1):
        x = F.tanh(self.linear1(state))
        x = F.tanh(self.linear2(x))
        # x = F.tanh(self.linear3(x))
        # x = F.tanh(self.linear4(x))

        probs = F.softmax(self.output(x), dim=softmax_dim)

        return probs

    def evaluate(self, state, epsilon=1e-6):
        '''
        return the log-probabilities of all actions for the given state(s) under the current policy
        '''
        probs = self.forward(state, softmax_dim=-1)
        log_probs = torch.log(probs)
        return log_probs

    def get_action(self, state, deterministic):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        probs = self.forward(state)
        dist = Categorical(probs)

        if deterministic:
            action = np.argmax(probs.detach().cpu().numpy())
        else:
            action = dist.sample().squeeze().detach().cpu().numpy()
        return action


class SAC_Trainer():
    def __init__(self, replay_buffer, hidden_dim):
        self.replay_buffer = replay_buffer

        self.soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_soft_q_net1 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_soft_q_net2 = SoftQNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.policy_net = PolicyNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.log_alpha = torch.zeros(1, dtype=torch.float32, requires_grad=True, device=device)
        print('Soft Q Network (1,2): ', self.soft_q_net1)
        print('Policy Network: ', self.policy_net)

        for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
            target_param.data.copy_(param.data)
        for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
            target_param.data.copy_(param.data)

        self.soft_q_criterion1 = nn.MSELoss(reduction="none")
        self.soft_q_criterion2 = nn.MSELoss(reduction="none")

        soft_q_lr = 3e-4
        policy_lr = 3e-4
        alpha_lr = 3e-4

        self.soft_q_optimizer1 = optim.Adam(self.soft_q_net1.parameters(), lr=soft_q_lr)
        self.soft_q_optimizer2 = optim.Adam(self.soft_q_net2.parameters(), lr=soft_q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)


    def update(self, batch_size, reward_scale=10., auto_entropy=True, target_entropy=-2, gamma=0.99, soft_tau=1e-2):
        state, action, reward, next_state, done = self.replay_buffer.sample(batch_size)
        # print('sample:', state, action, reward, done)

        state = torch.FloatTensor(state).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        action = torch.Tensor(action).to(torch.int64).to(device)
        reward = torch.FloatTensor(reward).unsqueeze(1).to(device)  # reward is a single value; unsqueeze() adds one dim to make it [reward] along the sample dim
        done = torch.FloatTensor(np.float32(done)).unsqueeze(1).to(device)
        predicted_q_value1 = self.soft_q_net1(state)
        predicted_q_value1 = predicted_q_value1.gather(1, action.unsqueeze(-1))
        predicted_q_value2 = self.soft_q_net2(state)
        predicted_q_value2 = predicted_q_value2.gather(1, action.unsqueeze(-1))
        log_prob = self.policy_net.evaluate(state)
        next_log_prob = self.policy_net.evaluate(next_state)
        # reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(dim=0) + 1e-6)  # normalize with batch mean and std; add a small number to prevent numerical problems

        # Training Q Function
        self.alpha = self.log_alpha.exp()
        target_q_min = (next_log_prob.exp() * (torch.min(self.target_soft_q_net1(next_state), self.target_soft_q_net2(next_state)) - self.alpha * next_log_prob)).sum(dim=-1).unsqueeze(-1)
        target_q_value = reward + (1 - done) * gamma * target_q_min  # if done==1, only reward
        q_value_loss1 = self.soft_q_criterion1(predicted_q_value1, target_q_value.detach())  # detach: no gradients for the target
        q_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value.detach())
        weight_update = [min(l1.item(), l2.item()) for l1, l2 in zip(q_value_loss1, q_value_loss2)]
        self.replay_buffer.update_weights(weight_update)  # update sample weights with the TD error

        self.soft_q_optimizer1.zero_grad()
        q_value_loss1.mean().backward()
        self.soft_q_optimizer1.step()
        self.soft_q_optimizer2.zero_grad()
        q_value_loss2.mean().backward()
        self.soft_q_optimizer2.step()

        # Training Policy Function
        predicted_new_q_value = torch.min(self.soft_q_net1(state), self.soft_q_net2(state))
        policy_loss = (log_prob.exp() * (self.alpha * log_prob - predicted_new_q_value)).sum(dim=-1).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Updating alpha wrt entropy
        # alpha = 0.0  # trade-off between exploration (max entropy) and exploitation (max Q)
        if auto_entropy is True:
            alpha_loss = -(self.log_alpha * (log_prob + target_entropy).detach()).mean()
            # print('alpha loss: ', alpha_loss)
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
        else:
            self.alpha = 1.
            alpha_loss = 0

        # print('q loss: ', q_value_loss1.item(), q_value_loss2.item())
        # print('policy loss: ', policy_loss.item())

        # Soft update the target value net
        for target_param, param in zip(self.target_soft_q_net1.parameters(), self.soft_q_net1.parameters()):
            target_param.data.copy_(  # copy data value into target parameters
                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
            )
        for target_param, param in zip(self.target_soft_q_net2.parameters(), self.soft_q_net2.parameters()):
            target_param.data.copy_(  # copy data value into target parameters
                target_param.data * (1.0 - soft_tau) + param.data * soft_tau
            )

        return predicted_new_q_value.mean()


    def save_model(self, path):
        torch.save(self.soft_q_net1.state_dict(), path + '_q1')
        torch.save(self.soft_q_net2.state_dict(), path + '_q2')
        torch.save(self.policy_net.state_dict(), path + '_policy')

    def load_model(self, path):
        self.soft_q_net1.load_state_dict(torch.load(path + '_q1'))
        self.soft_q_net2.load_state_dict(torch.load(path + '_q2'))
        self.policy_net.load_state_dict(torch.load(path + '_policy'))

        self.soft_q_net1.eval()
        self.soft_q_net2.eval()
        self.policy_net.eval()


def plot(rewards):
    clear_output(True)
    plt.figure(figsize=(20, 5))
    plt.plot(rewards)
    plt.savefig('sac_v2.png')
    # plt.show()


replay_buffer_size = 1e6
replay_buffer = ReplayBufferPER(replay_buffer_size)

# choose env
env = gym.make('CartPole-v1')

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n  # discrete

# hyper-parameters for RL training
max_episodes = 10000
max_steps = 200
frame_idx = 0
batch_size = 256
update_itr = 1
AUTO_ENTROPY = True
DETERMINISTIC = False
hidden_dim = 64
rewards = []
model_path = './model/sac_discrete_v2'
target_entropy = -1. * action_dim
# target_entropy = 0.98 * -np.log(1 / action_dim)

sac_trainer = SAC_Trainer(replay_buffer, hidden_dim=hidden_dim)

if __name__ == '__main__':
    if args.train:
        # training loop
        for eps in range(max_episodes):
            state = env.reset()
            episode_reward = 0

            for step in range(max_steps):
                action = sac_trainer.policy_net.get_action(state, deterministic=DETERMINISTIC)
                next_state, reward, done, _ = env.step(action)
                # env.render()

                replay_buffer.push(state, action, reward, next_state, done)

                state = next_state
                episode_reward += reward
                frame_idx += 1

                if len(replay_buffer) > batch_size:
                    for i in range(update_itr):
                        _ = sac_trainer.update(batch_size, reward_scale=1., auto_entropy=AUTO_ENTROPY, target_entropy=target_entropy)

                if done:
                    break

            if eps % 20 == 0 and eps > 0:  # plot and model saving interval
                plot(rewards)
                np.save('rewards', rewards)
                sac_trainer.save_model(model_path)
            print('Episode: ', eps, '| Episode Reward: ', episode_reward, '| Episode Length: ', step)
            rewards.append(episode_reward)
        sac_trainer.save_model(model_path)

    if args.test:
        sac_trainer.load_model(model_path)
        for eps in range(10):
            state = env.reset()
            episode_reward = 0

            for step in range(max_steps):
                action = sac_trainer.policy_net.get_action(state, deterministic=DETERMINISTIC)
                next_state, reward, done, _ = env.step(action)
                env.render()

                episode_reward += reward
                state = next_state

            print('Episode: ', eps, '| Episode Reward: ', episode_reward)
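
The part of this file that is specific to the new PER buffer is the weight_update step in SAC_Trainer.update: the two critics' losses are kept per-sample (reduction="none"), the smaller of the two errors is taken for each transition, and those values are written back into the buffer as priorities. The snippet below isolates that coupling; the tensor values and shapes are made up for illustration, and the buffer's delta constant is simply repeated locally.

import numpy as np
import torch
import torch.nn as nn

criterion = nn.MSELoss(reduction="none")
predicted_q1 = torch.tensor([[1.0], [2.0], [0.5], [3.0]])   # Q-values of the taken actions, shape [batch, 1]
predicted_q2 = torch.tensor([[1.2], [1.8], [0.4], [2.5]])
target_q = torch.tensor([[1.1], [1.0], [0.5], [1.0]])       # stands in for reward + (1 - done) * gamma * target_q_min

loss1 = criterion(predicted_q1, target_q)                   # elementwise squared TD errors, shape [batch, 1]
loss2 = criterion(predicted_q2, target_q)

# per-sample priority = the smaller of the two critics' errors, as in SAC_Trainer.update
priorities = [min(l1.item(), l2.item()) for l1, l2 in zip(loss1, loss2)]

# inside ReplayBufferPER these priorities overwrite weights[indices]; sampling probabilities
# then become proportional to (weight + delta), so large-error transitions are replayed more often
delta = 1e-4
probs = (np.array(priorities) + delta) / (np.array(priorities) + delta).sum()
print(priorities)
print(probs)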

sac_v2.png

-6.42 KB
