
Commit e55a95f

committed
add dqn multi-step reward
1 parent a6136fd commit e55a95f

File tree

1 file changed

+309 -0 lines changed

dqn_multistep.py

Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import random, numpy, argparse, logging, os
from collections import namedtuple
import numpy as np
import datetime, math
import gym

# Hyper Parameters
MAX_EPI = 10000
MAX_STEP = 10000
SAVE_INTERVAL = 20
TARGET_UPDATE_INTERVAL = 20

BATCH_SIZE = 128
REPLAY_BUFFER_SIZE = 100000
REPLAY_START_SIZE = 2000

GAMMA = 0.95
EPSILON = 0.05  # if not using epsilon scheduler, use a constant
EPSILON_START = 1.
EPSILON_END = 0.05
EPSILON_DECAY = 10000
LR = 1e-4  # learning rate
N_MULTI_STEP = 3  # n-step return

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EpsilonScheduler():
    def __init__(self, eps_start, eps_final, eps_decay):
        """A scheduler for the epsilon-greedy strategy.

        :param eps_start: starting value of epsilon; default 1. gives a purely random policy
        :type eps_start: float
        :param eps_final: final value of epsilon
        :type eps_final: float
        :param eps_decay: number of timesteps from eps_start to eps_final
        :type eps_decay: int
        """
        self.eps_start = eps_start
        self.eps_final = eps_final
        self.eps_decay = eps_decay
        self.epsilon = self.eps_start
        self.ini_frame_idx = 0
        self.current_frame_idx = 0

    def reset(self, ):
        """ Reset the scheduler """
        self.ini_frame_idx = self.current_frame_idx

    def step(self, frame_idx):
        self.current_frame_idx = frame_idx
        delta_frame_idx = self.current_frame_idx - self.ini_frame_idx
        self.epsilon = self.eps_final + (self.eps_start - self.eps_final) * math.exp(-1. * delta_frame_idx / self.eps_decay)

    def get_epsilon(self):
        return self.epsilon


class QNetwork(nn.Module):
    def __init__(self, act_shape, obs_shape, hidden_units=64):
        super(QNetwork, self).__init__()
        in_dim = obs_shape[0]
        out_dim = act_shape

        self.linear = nn.Sequential(
            nn.Linear(in_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, out_dim)
        )

    def forward(self, x):
        o = self.linear(x)
        return o

class QNetworkCNN(nn.Module):
    def __init__(self, num_actions, in_shape, out_channels=8, kernel_size=5, stride=1, hidden_units=256):
        super(QNetworkCNN, self).__init__()

        self.in_shape = in_shape
        in_channels = in_shape[0]

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, int(out_channels/2), kernel_size, stride),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size, stride=2),
            nn.Conv2d(int(out_channels/2), int(out_channels), kernel_size, stride),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size, stride=2)
        )
        self.conv.apply(self.init_weights)

        self.linear = nn.Sequential(
            nn.Linear(self.size_after_conv(), hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_actions)
        )

        self.linear.apply(self.init_weights)

    def init_weights(self, m):
        if type(m) == nn.Conv2d or type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)  # in-place initializer; xavier_uniform is deprecated
            m.bias.data.fill_(0.01)

    def size_after_conv(self,):
        # Run a dummy forward pass to infer the flattened feature size after the conv stack
        x = torch.rand(1, *self.in_shape)
        o = self.conv(x)
        size = 1
        for i in o.shape[1:]:
            size *= i
        return int(size)

    def forward(self, x):
        x = self.conv(x)
        o = self.linear(x.view(x.size(0), -1))
        return o

transition = namedtuple('transition', 'state, next_state, action, reward, is_terminal')

class ReplayBuffer:
    '''
    Replay buffer that keeps the agent's memories in a circular list.
    Ref: https://github.com/andri27-ts/Reinforcement-Learning/blob/c57064f747f51d1c495639c7413f5a2be01acd5f/Week3/buffers.py
    '''
    def __init__(self, buffer_size, n_multi_step, gamma):
        self.buffer = []
        self.buffer_size = buffer_size
        self.n_multi_step = n_multi_step
        self.gamma = gamma
        self.location = 0

    def __len__(self):
        return len(self.buffer)

    def add(self, samples):
        # Append when the buffer is not full, but overwrite when the buffer is full
        wrap_tensor = lambda x: torch.tensor([x])
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(transition(*map(wrap_tensor, samples)))
        else:
            self.buffer[self.location] = transition(*map(wrap_tensor, samples))

        # Increment the buffer location
        self.location = (self.location + 1) % self.buffer_size

    def sample(self, batch_size):
        '''
        Sample batch_size memories from the buffer.
        NB: it handles the N-step DQN return.
        '''
        # randomly pick batch_size elements from the buffer
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        samples = []

        # for each sampled index
        for i in indices:
            sum_reward = 0
            states_look_ahead = self.buffer[i].next_state
            done_look_ahead = self.buffer[i].is_terminal

            # N-step look-ahead loop to accumulate the discounted reward and pick the new 'next_state' (of the n-th step)
            for n in range(self.n_multi_step):
                if len(self.buffer) > i + n:
                    # accumulate the n-th discounted reward
                    sum_reward += (self.gamma**n) * self.buffer[i+n].reward
                    if self.buffer[i+n].is_terminal:
                        states_look_ahead = self.buffer[i+n].next_state
                        done_look_ahead = self.buffer[i+n].is_terminal
                        break
                    else:
                        states_look_ahead = self.buffer[i+n].next_state
                        done_look_ahead = self.buffer[i+n].is_terminal

            sample = transition(self.buffer[i].state, states_look_ahead, self.buffer[i].action, sum_reward, done_look_ahead)
            samples.append(sample)

        return samples

class DQN(object):
    def __init__(self, env):
        self.action_shape = env.action_space.n
        self.obs_shape = env.observation_space.shape
        self.eval_net, self.target_net = QNetwork(self.action_shape, self.obs_shape).to(device), QNetwork(self.action_shape, self.obs_shape).to(device)
        self.learn_step_counter = 0  # for target updating
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
        self.epsilon_scheduler = EpsilonScheduler(EPSILON_START, EPSILON_END, EPSILON_DECAY)
        self.updates = 0

    def choose_action(self, x):
        # x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0)).to(device)
        x = torch.unsqueeze(torch.FloatTensor(x), 0).to(device)
        # input only one sample
        # if np.random.uniform() > EPSILON:  # greedy, constant-epsilon variant
        epsilon = self.epsilon_scheduler.get_epsilon()
        if np.random.uniform() > epsilon:  # greedy
            actions_value = self.eval_net.forward(x)
            action = torch.max(actions_value, 1)[1].data.cpu().numpy()[0]  # return the argmax
        else:  # random
            action = np.random.randint(0, self.action_shape)
        return action

    def learn(self, sample,):
        # The batch is a list of namedtuples; the following operation returns the samples grouped by key
        batch_samples = transition(*zip(*sample))

        # states, next_states are tensors of shape (BATCH_SIZE, obs_dim) for vector observations
        # (NCHW format when using QNetworkCNN); actions, rewards, is_terminal are of shape (BATCH_SIZE, 1)
        states = torch.cat(batch_samples.state).float().to(device)
        next_states = torch.cat(batch_samples.next_state).float().to(device)
        actions = torch.cat(batch_samples.action).to(device)
        rewards = torch.cat(batch_samples.reward).float().to(device)
        is_terminal = torch.cat(batch_samples.is_terminal).to(device)

        # Obtain a batch of Q(S_t, A_t) and compute the forward pass.
        # Note: the policy network outputs Q-values for all actions of a state, but all we need is the A_t taken
        # at time t in state S_t. Thus we gather along the columns and get the Q-values corresponding to (S_t, A_t).
        # Q_s_a is of size (BATCH_SIZE, 1).
        Q = self.eval_net(states)
        Q_s_a = Q.gather(1, actions)

        # Obtain max_{a} Q(S_{t+n}, a) for every non-terminal look-ahead state S_{t+n}; if S_{t+n} is terminal,
        # the value stays 0. max(1)[0] gives the max action value in each row (since this is a batch), and detach()
        # keeps gradients from flowing through the target net. Q_s_prime_a_prime is of size (BATCH_SIZE, 1).

        # Get the indices of next_states that are not terminal
        none_terminal_next_state_index = torch.tensor([i for i, is_term in enumerate(is_terminal) if is_term == 0], dtype=torch.int64, device=device)
        # Select those rows
        none_terminal_next_states = next_states.index_select(0, none_terminal_next_state_index)

        Q_s_prime_a_prime = torch.zeros(len(sample), 1, device=device)
        if len(none_terminal_next_states) != 0:
            Q_s_prime_a_prime[none_terminal_next_state_index] = self.target_net(none_terminal_next_states).detach().max(1)[0].unsqueeze(1)

        # Q_s_prime_a_prime = self.target_net(next_states).detach().max(1, keepdim=True)[0]  # simpler, but ignores terminal states
        Q_s_prime_a_prime = (Q_s_prime_a_prime - Q_s_prime_a_prime.mean()) / (Q_s_prime_a_prime.std() + 1e-5)  # normalization

        # Compute the n-step target: the buffer reward already holds sum_{k=0}^{n-1} gamma^k * r_{t+k},
        # so the bootstrap term is discounted by gamma^n rather than gamma.
        target = rewards + (GAMMA ** N_MULTI_STEP) * Q_s_prime_a_prime

        # Update with loss
        # loss = self.loss_func(Q_s_a, target.detach())
        loss = F.smooth_l1_loss(Q_s_a, target.detach())
        # Zero gradients, backprop, update the weights of policy_net
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.updates += 1
        if self.updates % TARGET_UPDATE_INTERVAL == 0:
            self.update_target()

        return loss.item()

    def save_model(self, model_path=None):
        if model_path is None:
            model_path = 'model/dqn'
        os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)  # make sure the directory exists
        torch.save(self.eval_net.state_dict(), model_path)

    def update_target(self, ):
        """
        Update the target model when necessary.
        """
        self.target_net.load_state_dict(self.eval_net.state_dict())

def rollout(env, model):
    r_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE, N_MULTI_STEP, GAMMA)
    log = []
    os.makedirs('log', exist_ok=True)  # make sure the log directory exists before saving
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    print('\nCollecting experience...')
    total_step = 0
    for epi in range(MAX_EPI):
        s = env.reset()
        epi_r = 0
        epi_loss = 0
        for step in range(MAX_STEP):
            # env.render()
            total_step += 1
            a = model.choose_action(s)
            s_, r, done, info = env.step(a)
            # r_buffer.add(torch.tensor([s]), torch.tensor([s_]), torch.tensor([[a]]), torch.tensor([[r]], dtype=torch.float), torch.tensor([[done]]))
            r_buffer.add([s, s_, [a], [r], [done]])
            model.epsilon_scheduler.step(total_step)
            epi_r += r
            if total_step > REPLAY_START_SIZE and len(r_buffer.buffer) >= BATCH_SIZE:
                sample = r_buffer.sample(BATCH_SIZE)
                loss = model.learn(sample)
                epi_loss += loss
            if done:
                break
            s = s_
        print('Ep: ', epi, '| Ep_r: ', epi_r, '| Steps: ', step, f'| Ep_Loss: {epi_loss:.4f}')
        log.append([epi, epi_r, step])
        if epi % SAVE_INTERVAL == 0:
            model.save_model()
            np.save('log/'+timestamp, log)

if __name__ == '__main__':
    env = gym.make('CartPole-v1')  # assumes the classic gym API: reset() -> obs, step() -> (obs, reward, done, info)
    print(env.observation_space, env.action_space)
    model = DQN(env)
    rollout(env, model)
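
For reference, a minimal sketch of the n-step target that ReplayBuffer.sample and DQN.learn assemble between them; the rewards and bootstrap value below are toy numbers chosen only for illustration, not taken from the code above:

# Illustration only: how the n-step target is put together (toy numbers).
GAMMA = 0.95
N_MULTI_STEP = 3

rewards = [1.0, 1.0, 1.0]   # r_t, r_{t+1}, r_{t+2}, as gathered by the look-ahead loop in ReplayBuffer.sample
bootstrap_q = 2.5           # stand-in for max_a Q_target(s_{t+3}, a), computed in DQN.learn

# ReplayBuffer.sample stores sum_{k=0}^{n-1} gamma^k * r_{t+k} as the transition reward ...
sum_reward = sum((GAMMA ** k) * r for k, r in enumerate(rewards))
# ... so DQN.learn bootstraps with gamma^n (not gamma) to keep the discounting consistent:
target = sum_reward + (GAMMA ** N_MULTI_STEP) * bootstrap_q
print(sum_reward, target)   # ~2.8525 and ~4.996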

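And a similarly self-contained sketch of the exponential decay implemented by EpsilonScheduler.step, using the EPSILON_START / EPSILON_END / EPSILON_DECAY values defined above (the frame indices are arbitrary sample points):

# Illustration only: the schedule eps_final + (eps_start - eps_final) * exp(-frame / eps_decay).
import math

eps_start, eps_final, eps_decay = 1.0, 0.05, 10000
for frame_idx in (0, 1000, 10000, 50000):
    epsilon = eps_final + (eps_start - eps_final) * math.exp(-frame_idx / eps_decay)
    print(frame_idx, round(epsilon, 3))
# 0 1.0, 1000 0.91, 10000 0.399, 50000 0.056
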
0 commit comments