import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import random, numpy, argparse, logging, os
from collections import namedtuple
import numpy as np
import datetime, math
import gym

# Hyperparameters
MAX_EPI = 10000
MAX_STEP = 10000
SAVE_INTERVAL = 20
TARGET_UPDATE_INTERVAL = 20

BATCH_SIZE = 128
REPLAY_BUFFER_SIZE = 100000
REPLAY_START_SIZE = 2000

GAMMA = 0.95
EPSILON = 0.05  # constant epsilon, used only if the epsilon scheduler is not used
EPSILON_START = 1.
EPSILON_END = 0.05
EPSILON_DECAY = 10000
LR = 1e-4  # learning rate
N_MULTI_STEP = 3  # n-step return

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EpsilonScheduler():
    def __init__(self, eps_start, eps_final, eps_decay):
        """A scheduler for the epsilon-greedy exploration strategy.

        :param eps_start: starting value of epsilon, default 1. as purely random policy
        :type eps_start: float
        :param eps_final: final value of epsilon
        :type eps_final: float
        :param eps_decay: number of timesteps from eps_start to eps_final
        :type eps_decay: int
        """
        self.eps_start = eps_start
        self.eps_final = eps_final
        self.eps_decay = eps_decay
        self.epsilon = self.eps_start
        self.ini_frame_idx = 0
        self.current_frame_idx = 0

    def reset(self):
        """Reset the scheduler."""
        self.ini_frame_idx = self.current_frame_idx

    def step(self, frame_idx):
        """Exponentially decay epsilon from eps_start towards eps_final."""
        self.current_frame_idx = frame_idx
        delta_frame_idx = self.current_frame_idx - self.ini_frame_idx
        self.epsilon = self.eps_final + (self.eps_start - self.eps_final) * math.exp(-1. * delta_frame_idx / self.eps_decay)

    def get_epsilon(self):
        return self.epsilon

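# Illustrative sketch (hypothetical helper, not called by the training code):
# a quick way to inspect the exponential decay produced by EpsilonScheduler,
# using the EPSILON_START/EPSILON_END/EPSILON_DECAY constants defined above.
def _demo_epsilon_decay(frames=(0, 1000, 10000, 50000)):
    scheduler = EpsilonScheduler(EPSILON_START, EPSILON_END, EPSILON_DECAY)
    values = []
    for t in frames:
        scheduler.step(t)
        values.append(scheduler.get_epsilon())  # decays from ~1.0 towards EPSILON_END
    return values
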
class QNetwork(nn.Module):
    def __init__(self, act_shape, obs_shape, hidden_units=64):
        super(QNetwork, self).__init__()
        in_dim = obs_shape[0]
        out_dim = act_shape

        self.linear = nn.Sequential(
            nn.Linear(in_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, out_dim)
        )

    def forward(self, x):
        o = self.linear(x)
        return o

class QNetworkCNN(nn.Module):
    def __init__(self, num_actions, in_shape, out_channels=8, kernel_size=5, stride=1, hidden_units=256):
        super(QNetworkCNN, self).__init__()

        self.in_shape = in_shape
        in_channels = in_shape[0]

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, int(out_channels / 2), kernel_size, stride),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size, stride=2),
            nn.Conv2d(int(out_channels / 2), int(out_channels), kernel_size, stride),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size, stride=2)
        )
        self.conv.apply(self.init_weights)

        self.linear = nn.Sequential(
            nn.Linear(self.size_after_conv(), hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_actions)
        )

        self.linear.apply(self.init_weights)

    def init_weights(self, m):
        if type(m) == nn.Conv2d or type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)  # in-place initializer; bare xavier_uniform is deprecated
            m.bias.data.fill_(0.01)

    def size_after_conv(self):
        # Infer the flattened feature size by passing a dummy input through the conv stack
        x = torch.rand(1, *self.in_shape)
        o = self.conv(x)
        size = 1
        for i in o.shape[1:]:
            size *= i
        return int(size)

    def forward(self, x):
        x = self.conv(x)
        o = self.linear(x.view(x.size(0), -1))
        return o

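# Illustrative sketch (hypothetical, not used by the CartPole experiment below):
# QNetworkCNN expects image observations in CHW layout, e.g. a stack of four
# 84x84 grayscale frames. The shapes here are assumptions for illustration only.
def _demo_cnn_forward():
    net = QNetworkCNN(num_actions=4, in_shape=(4, 84, 84))
    q_values = net(torch.rand(2, 4, 84, 84))  # batch of 2 dummy observations
    return q_values.shape  # expected: torch.Size([2, 4])
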
transition = namedtuple('transition', 'state, next_state, action, reward, is_terminal')

class ReplayBuffer:
    '''
    Replay Buffer class to keep the agent's memories stored in a cyclic buffer.
    Ref: https://github.com/andri27-ts/Reinforcement-Learning/blob/c57064f747f51d1c495639c7413f5a2be01acd5f/Week3/buffers.py
    '''
    def __init__(self, buffer_size, n_multi_step, gamma):
        self.buffer = []
        self.buffer_size = buffer_size
        self.n_multi_step = n_multi_step
        self.gamma = gamma
        self.location = 0

    def __len__(self):
        return len(self.buffer)

    def add(self, samples):
        # Append when the buffer is not full, but overwrite the oldest entry when it is full
        wrap_tensor = lambda x: torch.tensor([x])
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(transition(*map(wrap_tensor, samples)))
        else:
            self.buffer[self.location] = transition(*map(wrap_tensor, samples))

        # Increment the buffer location
        self.location = (self.location + 1) % self.buffer_size

    def sample(self, batch_size):
        '''
        Sample batch_size memories from the buffer.
        NB: this handles the N-step return for N-step DQN.
        '''
        # randomly pick batch_size elements from the buffer
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        samples = []

        # for each sampled index
        for i in indices:
            sum_reward = 0
            states_look_ahead = self.buffer[i].next_state
            done_look_ahead = self.buffer[i].is_terminal

            # N-step look-ahead loop to accumulate the discounted reward and pick the new 'next_state' (of the n-th state)
            for n in range(self.n_multi_step):
                if len(self.buffer) > i + n:
                    # accumulate the n-th discounted reward
                    sum_reward += (self.gamma ** n) * self.buffer[i + n].reward
                    if self.buffer[i + n].is_terminal:
                        states_look_ahead = self.buffer[i + n].next_state
                        done_look_ahead = self.buffer[i + n].is_terminal
                        break
                    else:
                        states_look_ahead = self.buffer[i + n].next_state
                        done_look_ahead = self.buffer[i + n].is_terminal

            sample = transition(self.buffer[i].state, states_look_ahead, self.buffer[i].action, sum_reward, done_look_ahead)
            samples.append(sample)

        return samples

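# Illustrative sketch (hypothetical helper, not called anywhere): the quantity
# accumulated by ReplayBuffer.sample() above is the truncated n-step return
#     R_t^(n) = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1},
# which the learner later bootstraps with gamma^n * max_a Q_target(s_{t+n}, a).
def _n_step_return(rewards, gamma=GAMMA):
    total = 0.0
    for k, r in enumerate(rewards):
        total += (gamma ** k) * r  # discount grows with the look-ahead offset k
    return total
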
class DQN(object):
    def __init__(self, env):
        self.action_shape = env.action_space.n
        self.obs_shape = env.observation_space.shape
        self.eval_net, self.target_net = QNetwork(self.action_shape, self.obs_shape).to(device), QNetwork(self.action_shape, self.obs_shape).to(device)
        self.learn_step_counter = 0  # for target updating
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
        self.epsilon_scheduler = EpsilonScheduler(EPSILON_START, EPSILON_END, EPSILON_DECAY)
        self.updates = 0

    def choose_action(self, x):
        # x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0)).to(device)
        x = torch.unsqueeze(torch.FloatTensor(x), 0).to(device)
        # input only one sample
        # if np.random.uniform() > EPSILON: # greedy
        epsilon = self.epsilon_scheduler.get_epsilon()
        if np.random.uniform() > epsilon:  # greedy
            actions_value = self.eval_net.forward(x)
            action = torch.max(actions_value, 1)[1].data.cpu().numpy()[0]  # return the argmax
        else:  # random
            action = np.random.randint(0, self.action_shape)
        return action

    def learn(self, sample):
        # The batch is a list of namedtuples; the following operation returns the samples grouped by keys
        batch_samples = transition(*zip(*sample))

        # states, next_states are tensors of shape (BATCH_SIZE, *obs_shape), following PyTorch's batch-first convention
        # actions, rewards, is_terminal are tensors of shape (BATCH_SIZE, 1)
        states = torch.cat(batch_samples.state).float().to(device)
        next_states = torch.cat(batch_samples.next_state).float().to(device)
        actions = torch.cat(batch_samples.action).to(device)
        rewards = torch.cat(batch_samples.reward).float().to(device)
        is_terminal = torch.cat(batch_samples.is_terminal).to(device)
        # Obtain a batch of Q(S_t, A_t) and compute the forward pass.
        # Note: the policy network outputs Q-values for all the actions of a state, but all we need is the A_t taken at time t
        # in state S_t. Thus we gather along the columns and get the Q-values corresponding to S_t, A_t.
        # Q_s_a is of size (BATCH_SIZE, 1).
        Q = self.eval_net(states)
        Q_s_a = Q.gather(1, actions)

        # Obtain max_{a} Q(S_{t+1}, a) of any non-terminal state S_{t+1}. If S_{t+1} is terminal, Q(S_{t+1}, A_{t+1}) = 0.
        # Note: each row of the network's output corresponds to the actions of S_{t+1}. max(1)[0] gives the max action
        # values in each row (since this is a batch). The detach() detaches the target net's tensor from the computation
        # graph so that its gradient is not computed. Q_s_prime_a_prime is of size (BATCH_SIZE, 1).

        # Get the indices of next_states that are not terminal
        none_terminal_next_state_index = torch.tensor([i for i, is_term in enumerate(is_terminal) if is_term == 0], dtype=torch.int64, device=device)
        # Select those rows
        none_terminal_next_states = next_states.index_select(0, none_terminal_next_state_index)

        Q_s_prime_a_prime = torch.zeros(len(sample), 1, device=device)
        if len(none_terminal_next_states) != 0:
            Q_s_prime_a_prime[none_terminal_next_state_index] = self.target_net(none_terminal_next_states).detach().max(1)[0].unsqueeze(1)

        # Q_s_prime_a_prime = self.target_net(next_states).detach().max(1, keepdim=True)[0] # this one is simpler regardless of terminal state
        Q_s_prime_a_prime = (Q_s_prime_a_prime - Q_s_prime_a_prime.mean()) / (Q_s_prime_a_prime.std() + 1e-5)  # normalization

        # Compute the n-step target: the stored reward is already the discounted n-step return,
        # so the bootstrap term is scaled by GAMMA ** N_MULTI_STEP
        target = rewards + (GAMMA ** N_MULTI_STEP) * Q_s_prime_a_prime

        # Update with loss
        # loss = self.loss_func(Q_s_a, target.detach())
        loss = F.smooth_l1_loss(Q_s_a, target.detach())
        # Zero gradients, backprop, update the weights of policy_net
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.updates += 1
        if self.updates % TARGET_UPDATE_INTERVAL == 0:
            self.update_target()

        return loss.item()

    def save_model(self, model_path='model/dqn'):
        os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
        torch.save(self.eval_net.state_dict(), model_path)

    def update_target(self):
        """
        Update the target model when necessary.
        """
        self.target_net.load_state_dict(self.eval_net.state_dict())

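# Illustrative sketch (hypothetical helper, not part of the training loop below):
# reload the weights written by DQN.save_model() and roll out a purely greedy
# policy, assuming the classic gym API used in the rest of this script
# (env.reset() -> obs, env.step() -> (obs, reward, done, info)).
def _evaluate_greedy(env, model_path='model/dqn', episodes=5):
    net = QNetwork(env.action_space.n, env.observation_space.shape).to(device)
    net.load_state_dict(torch.load(model_path, map_location=device))
    returns = []
    for _ in range(episodes):
        s = env.reset()
        epi_r, done = 0., False
        while not done:
            with torch.no_grad():
                q = net(torch.FloatTensor(s).unsqueeze(0).to(device))
            a = q.argmax(dim=1).item()  # greedy action, no exploration
            s, r, done, _ = env.step(a)
            epi_r += r
        returns.append(epi_r)
    return returns
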
def rollout(env, model):
    r_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE, N_MULTI_STEP, GAMMA)
    log = []
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    os.makedirs('log', exist_ok=True)
    print('\nCollecting experience...')
    total_step = 0
    for epi in range(MAX_EPI):
        s = env.reset()
        epi_r = 0
        epi_loss = 0
        for step in range(MAX_STEP):
            # env.render()
            total_step += 1
            a = model.choose_action(s)
            s_, r, done, info = env.step(a)
            # r_buffer.add(torch.tensor([s]), torch.tensor([s_]), torch.tensor([[a]]), torch.tensor([[r]], dtype=torch.float), torch.tensor([[done]]))
            r_buffer.add([s, s_, [a], [r], [done]])
            model.epsilon_scheduler.step(total_step)
            epi_r += r
            if total_step > REPLAY_START_SIZE and len(r_buffer.buffer) >= BATCH_SIZE:
                sample = r_buffer.sample(BATCH_SIZE)
                loss = model.learn(sample)
                epi_loss += loss
            if done:
                break
            s = s_
        print('Ep: ', epi, '| Ep_r: ', epi_r, '| Steps: ', step, f'| Ep_Loss: {epi_loss:.4f}')
        log.append([epi, epi_r, step])
        if epi % SAVE_INTERVAL == 0:
            model.save_model()
            np.save('log/' + timestamp, log)

if __name__ == '__main__':
    env = gym.make('CartPole-v1')
    print(env.observation_space, env.action_space)
    model = DQN(env)
    rollout(env, model)