@@ -154,20 +154,20 @@ def act(self, state):
 
     # training cycle
     def train(
-        self,
-        replay_buffer,
-        iterations,
-        batch_size,
-        discount=0.99,
-        tau=0.005,
-        policy_noise=0.2,
-        noise_clip=0.5,
-        policy_freq=2,
-        max_lin_vel=0.5,
-        max_ang_vel=1,
-        goal_reward=100,
-        distance_norm=10,
-        time_step=0.3,
+        self,
+        replay_buffer,
+        iterations,
+        batch_size,
+        discount=0.99,
+        tau=0.005,
+        policy_noise=0.2,
+        noise_clip=0.5,
+        policy_freq=2,
+        max_lin_vel=0.5,
+        max_ang_vel=1,
+        goal_reward=100,
+        distance_norm=10,
+        time_step=0.3,
     ):
         av_Q = 0
         max_b = 0
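The keyword defaults above follow common TD3 practice (discount 0.99, soft-update rate 0.005, target-policy noise 0.2 clipped at 0.5, delayed policy updates every second critic step), plus robot-specific limits that feed the reward upper bound. Below is a hypothetical call site for this method; the `agent` and `replay_buffer` names and the iteration/batch values are placeholders, not part of this commit.

```python
def run_training(agent, replay_buffer):
    # Hypothetical invocation of the train() method shown above; `agent` and
    # `replay_buffer` are assumed names, and iterations/batch_size are placeholders.
    agent.train(
        replay_buffer,
        iterations=100,
        batch_size=64,
        discount=0.99,     # gamma in the Bellman backup
        tau=0.005,         # soft-update rate for the target networks
        policy_noise=0.2,  # std of the target-policy smoothing noise
        noise_clip=0.5,    # clamp applied to that noise
        policy_freq=2,     # delayed actor/target updates (TD3-style)
        max_lin_vel=0.5,   # robot limits used by the reward upper bound
        max_ang_vel=1,
        goal_reward=100,
        distance_norm=10,
        time_step=0.3,
    )
```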
@@ -203,15 +203,13 @@ def train(
             next_action = (next_action + noise).clamp(-self.max_action, self.max_action)
 
             # Calculate the Q values from the critic-target network for the next state-action pair
-            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
+            target_Q = self.critic_target(next_state, next_action)
 
-            # Select the minimal Q value from the 2 calculated values
-            target_Q = torch.min(target_Q1, target_Q2)
             av_Q += torch.mean(target_Q)
             max_Q = max(max_Q, torch.max(target_Q))
-
             # Calculate the final Q value from the target network parameters by using Bellman equation
             target_Q = reward + ((1 - done) * discount * target_Q).detach()
+
             # Get the Q values of the basis networks with the current parameters
             current_Q = self.critic(state, action)
 
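The removed lines implemented TD3's clipped double-Q target (the minimum over two target-critic heads); the commit switches to backing up a single target-critic output. A minimal sketch of both variants, assuming the critic-target call signatures shown in the diff:

```python
import torch


def bellman_target_single(critic_target, next_state, next_action, reward, done, discount=0.99):
    # Variant kept by this diff: a single target-critic output feeds the Bellman backup.
    target_q = critic_target(next_state, next_action)
    return reward + ((1 - done) * discount * target_q).detach()


def bellman_target_clipped_double(critic_target, next_state, next_action, reward, done, discount=0.99):
    # Variant removed by this diff (standard TD3): take the element-wise minimum of
    # the two target-critic heads to curb Q overestimation before the backup.
    q1, q2 = critic_target(next_state, next_action)
    return reward + ((1 - done) * discount * torch.min(q1, q2)).detach()
```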
@@ -224,15 +222,18 @@ def train(
                 distance_norm,
                 goal_reward,
                 reward,
+                done,
             )
-            max_b += max(max_b, torch.max(max_bound))
+            max_b = max(max_b, torch.max(max_bound))
             max_bound_loss_Q = current_Q - max_bound
             max_bound_loss_Q[max_bound_loss_Q < 0] = 0
             max_bound_loss_Q = torch.square(max_bound_loss_Q).mean()
+            max_bound_loss = max_bound_loss_Q
 
             # Calculate the loss between the current Q value and the target Q value
             loss_target_Q = F.mse_loss(current_Q, target_Q)
-            max_bound_loss = self.bound_weight * max_bound_loss_Q
+
+            max_bound_loss = self.bound_weight * max_bound_loss
             loss = loss_target_Q + max_bound_loss
             # Perform the gradient descent
             self.critic_optimizer.zero_grad()
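This hunk assembles the critic loss from two parts: the usual TD error against the Bellman target, and a one-sided penalty that pushes Q estimates down only where they exceed the analytic upper bound `max_bound`, scaled by `self.bound_weight`. A compact restatement of that loss, using `torch.clamp` in place of the in-place masking above:

```python
import torch
import torch.nn.functional as F


def critic_loss_with_upper_bound(current_q, target_q, max_bound, bound_weight):
    # Standard TD term between the online critic and the Bellman target.
    loss_target_q = F.mse_loss(current_q, target_q)
    # One-sided penalty: only the portion of the Q estimate that exceeds the
    # analytic upper bound is penalized; estimates below the bound are untouched.
    overshoot = torch.clamp(current_q - max_bound, min=0.0)
    return loss_target_q + bound_weight * torch.square(overshoot).mean()
```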
@@ -242,7 +243,7 @@ def train(
             if it % policy_freq == 0:
                 # Maximize the actor output value by performing gradient descent on negative Q values
                 # (essentially perform gradient ascent)
-                actor_grad, _ = self.critic(state, self.actor(state))
+                actor_grad = self.critic(state, self.actor(state))
                 actor_grad = -actor_grad.mean()
                 self.actor_optimizer.zero_grad()
                 actor_grad.backward()
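Because the critic now returns a single Q tensor (matching the target computation earlier in this diff), the actor update no longer unpacks and discards a second head; the policy is still improved by minimizing the negated mean Q. A minimal sketch:

```python
def actor_loss_single_head(critic, actor, state):
    # Deterministic policy-gradient objective used above: minimizing the negated
    # mean Q value performs gradient ascent on the critic's estimate.
    return -critic(state, actor(state)).mean()
```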
@@ -251,15 +252,15 @@ def train(
                 # Use soft update to update the actor-target network parameters by
                 # infusing small amount of current parameters
                 for param, target_param in zip(
-                    self.actor.parameters(), self.actor_target.parameters()
+                    self.actor.parameters(), self.actor_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
                     )
                 # Use soft update to update the critic-target network parameters by infusing
                 # small amount of current parameters
                 for param, target_param in zip(
-                    self.critic.parameters(), self.critic_target.parameters()
+                    self.critic.parameters(), self.critic_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
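Both loops apply the same Polyak averaging, target <- tau * online + (1 - tau) * target. A hypothetical helper capturing that update (the function name is an assumption, not part of the codebase):

```python
import torch


@torch.no_grad()
def soft_update(online_net, target_net, tau=0.005):
    # Polyak averaging, as applied above to both the actor and critic targets:
    # target <- tau * online + (1 - tau) * target.
    for param, target_param in zip(online_net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
```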
@@ -284,15 +285,16 @@ def train(
            self.save(filename=self.model_name, directory=self.save_directory)
 
    def get_max_bound(
-        self,
-        next_state,
-        discount,
-        max_ang_vel,
-        max_lin_vel,
-        time_step,
-        distance_norm,
-        goal_reward,
-        reward,
+        self,
+        next_state,
+        discount,
+        max_ang_vel,
+        max_lin_vel,
+        time_step,
+        distance_norm,
+        goal_reward,
+        reward,
+        done,
    ):
        cos = next_state[:, -4]
        sin = next_state[:, -3]
@@ -302,7 +304,7 @@ def get_max_bound(
        full_turn_steps = torch.floor(turn_steps.abs())
        turn_rew = [
            (
-                -1 * discount**step * max_ang_vel
+                -1 * discount**step * max_ang_vel
                if step
                else torch.zeros(1, device=self.device)
            )
@@ -324,20 +326,20 @@ def get_max_bound(
        final_steps = torch.ceil(distances) + full_turn_steps
        inter_steps = torch.trunc(distances) + full_turn_steps
        final_discount = torch.tensor(
-            [discount**pw for pw in final_steps], device=self.device
+            [discount**pw for pw in final_steps], device=self.device
        )
        final_rew = (
-            torch.ones_like(distances, device=self.device)
-            * goal_reward
-            * final_discount
+            torch.ones_like(distances, device=self.device)
+            * goal_reward
+            * final_discount
        )
 
        max_inter_steps = inter_steps.max()
        exponents = torch.arange(
            1, max_inter_steps + 1, dtype=torch.float32, device=self.device
        )
        discount_exponents = torch.tensor(
-            [discount**e for e in exponents], device=self.device
+            [discount**e for e in exponents], device=self.device
        )
        inter_rew = torch.tensor(
            [
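The bound construction above estimates the best achievable discounted return: reach the goal in the fewest possible time steps (translation plus turning at maximum velocity), collect `goal_reward` discounted by that step count, and add the discounted intermediate rewards. The sketch below covers the goal-reward term under that reading; the vectorized exponentiation is assumed to be equivalent in value to the `final_discount`/`final_rew` pair above.

```python
import torch


def discounted_goal_term(distances, full_turn_steps, goal_reward, discount=0.99):
    # Assumed reading of the code above: if the goal can be reached after
    # n = ceil(distances) + full_turn_steps time steps at maximum velocity,
    # the largest attainable goal reward is goal_reward * discount**n.
    final_steps = torch.ceil(distances) + full_turn_steps
    return goal_reward * discount**final_steps
```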
@@ -352,7 +354,7 @@ def get_max_bound(
            device=self.device,
        )
        max_future_rew = full_turn_rew + final_rew + inter_rew
-        max_bound = reward + max_future_rew.view(-1, 1)
+        max_bound = reward + (1 - done) * max_future_rew.view(-1, 1)
        return max_bound
 
    def save(self, filename, directory):
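The added `(1 - done)` factor mirrors the Bellman target earlier in the diff: for terminal transitions no future reward is attainable, so the upper bound collapses to the immediate reward. A minimal sketch of the masked bound:

```python
def bounded_return(reward, done, max_future_rew):
    # Terminal masking as added above: future reward is counted only for
    # non-terminal transitions; shapes follow the diff (column tensors).
    return reward + (1 - done) * max_future_rew.view(-1, 1)
```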