@@ -154,22 +154,23 @@ def act(self, state):

    # training cycle
    def train(
-        self,
-        replay_buffer,
-        iterations,
-        batch_size,
-        discount=0.99,
-        tau=0.005,
-        policy_noise=0.2,
-        noise_clip=0.5,
-        policy_freq=2,
-        max_lin_vel=0.5,
-        max_ang_vel=1,
-        goal_reward=100,
-        distance_norm=10,
-        time_step=0.3,
+        self,
+        replay_buffer,
+        iterations,
+        batch_size,
+        discount=0.99,
+        tau=0.005,
+        policy_noise=0.2,
+        noise_clip=0.5,
+        policy_freq=2,
+        max_lin_vel=0.5,
+        max_ang_vel=1,
+        goal_reward=100,
+        distance_norm=10,
+        time_step=0.3,
    ):
        av_Q = 0
+        av_bound = 0
        max_b = 0
        max_Q = -inf
        av_loss = 0
@@ -225,11 +226,10 @@ def train(
                done,
            )
            max_b = max(max_b, torch.max(max_bound))
-            max_bound_loss_Q = current_Q - max_bound
-            max_bound_loss_Q[max_bound_loss_Q < 0] = 0
-            max_bound_loss_Q = torch.square(max_bound_loss_Q).mean()
-            max_bound_loss = max_bound_loss_Q
+            av_bound += torch.mean(max_bound)

+            max_bound_Q = torch.min(current_Q, max_bound)
+            max_bound_loss = F.mse_loss(current_Q, max_bound_Q)

            # Calculate the loss between the current Q value and the target Q value
            loss_target_Q = F.mse_loss(current_Q, target_Q)

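Note on the hunk above: the old code penalized only the positive part of current_Q - max_bound (a clamped, squared penalty), while the new code clips current_Q at max_bound, takes an MSE against the clipped copy, and also accumulates av_bound for logging. If max_bound is treated as a constant target, the two loss formulations coincide; the following standalone sketch (not part of the patch, made-up numbers) checks this:

import torch
import torch.nn.functional as F

current_Q = torch.tensor([1.0, 5.0, -2.0, 9.0])
max_bound = torch.tensor([3.0, 3.0, 3.0, 3.0])

# Old formulation: square only the amount by which Q exceeds the bound.
old_loss = torch.square(torch.clamp(current_Q - max_bound, min=0.0)).mean()

# New formulation: MSE between Q and Q clipped at the bound.
new_loss = F.mse_loss(current_Q, torch.min(current_Q, max_bound))

assert torch.isclose(old_loss, new_loss)  # both equal 10.0 for these numbers
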
@@ -244,6 +244,7 @@ def train(
                # Maximize the actor output value by performing gradient descent on negative Q values
                # (essentially perform gradient ascent)
                actor_grad = self.critic(state, self.actor(state))
+                actor_grad = torch.min(actor_grad, max_bound)
                actor_grad = -actor_grad.mean()
                self.actor_optimizer.zero_grad()
                actor_grad.backward()
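Note on the hunk above: taking torch.min(actor_grad, max_bound) caps the critic's estimate at the analytic upper bound before the negated mean is used for gradient ascent, so values already above the bound contribute no extra gradient to the actor. A standalone sketch with made-up numbers (not part of the patch):

import torch

q_values = torch.tensor([2.0, 7.0, 4.0], requires_grad=True)
max_bound = torch.tensor([5.0, 5.0, 5.0])

# Negated mean of the capped values, as in the actor update above.
objective = -torch.min(q_values, max_bound).mean()
objective.backward()
print(q_values.grad)  # tensor([-0.3333, 0.0000, -0.3333]); the capped entry gets no gradient
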
@@ -252,15 +253,15 @@ def train(
                # Use soft update to update the actor-target network parameters by
                # infusing small amount of current parameters
                for param, target_param in zip(
-                    self.actor.parameters(), self.actor_target.parameters()
+                    self.actor.parameters(), self.actor_target.parameters()
                ):
                    target_param.data.copy_(
                        tau * param.data + (1 - tau) * target_param.data
                    )
                # Use soft update to update the critic-target network parameters by infusing
                # small amount of current parameters
                for param, target_param in zip(
-                    self.critic.parameters(), self.critic_target.parameters()
+                    self.critic.parameters(), self.critic_target.parameters()
                ):
                    target_param.data.copy_(
                        tau * param.data + (1 - tau) * target_param.data
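Note on the hunk above: the removed and re-added zip arguments are identical here, so the change appears to be whitespace-only. Both loops perform the soft (Polyak) target update target <- tau * param + (1 - tau) * target with tau = 0.005. A minimal standalone sketch of the same update, assuming two identically shaped networks:

import torch.nn as nn

tau = 0.005
net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)

# Nudge each target parameter a small step toward the online parameter.
for param, target_param in zip(net.parameters(), target_net.parameters()):
    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
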
@@ -279,22 +280,25 @@ def train(
                "train/av_max_bound_loss", av_max_bound_loss / iterations, self.iter_count
            )
        self.writer.add_scalar("train/avg_Q", av_Q / iterations, self.iter_count)
+        self.writer.add_scalar(
+            "train/avg_bound", av_bound / iterations, self.iter_count
+        )
        self.writer.add_scalar("train/max_b", max_b, self.iter_count)
        self.writer.add_scalar("train/max_Q", max_Q, self.iter_count)
        if self.save_every > 0 and self.iter_count % self.save_every == 0:
            self.save(filename=self.model_name, directory=self.save_directory)

    def get_max_bound(
-        self,
-        next_state,
-        discount,
-        max_ang_vel,
-        max_lin_vel,
-        time_step,
-        distance_norm,
-        goal_reward,
-        reward,
-        done,
+        self,
+        next_state,
+        discount,
+        max_ang_vel,
+        max_lin_vel,
+        time_step,
+        distance_norm,
+        goal_reward,
+        reward,
+        done,
    ):
        cos = next_state[:, -4]
        sin = next_state[:, -3]
@@ -304,7 +308,7 @@ def get_max_bound(
        full_turn_steps = torch.floor(turn_steps.abs())
        turn_rew = [
            (
-                -1 * discount**step * max_ang_vel
+                -1 * discount**step * max_ang_vel
                if step
                else torch.zeros(1, device=self.device)
            )
@@ -326,20 +330,20 @@ def get_max_bound(
        final_steps = torch.ceil(distances) + full_turn_steps
        inter_steps = torch.trunc(distances) + full_turn_steps
        final_discount = torch.tensor(
-            [discount**pw for pw in final_steps], device=self.device
+            [discount**pw for pw in final_steps], device=self.device
        )
        final_rew = (
-            torch.ones_like(distances, device=self.device)
-            * goal_reward
-            * final_discount
+            torch.ones_like(distances, device=self.device)
+            * goal_reward
+            * final_discount
        )

        max_inter_steps = inter_steps.max()
        exponents = torch.arange(
            1, max_inter_steps + 1, dtype=torch.float32, device=self.device
        )
        discount_exponents = torch.tensor(
-            [discount**e for e in exponents], device=self.device
+            [discount**e for e in exponents], device=self.device
        )
        inter_rew = torch.tensor(
            [
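Note on get_max_bound: the list comprehensions in the hunks above build per-sample geometric discount factors discount**k from step counts stored in tensors. A standalone sketch (not a proposed change) showing that broadcasting the scalar base gives the same factors in vectorized form:

import torch

discount = 0.99
final_steps = torch.tensor([3.0, 7.0, 0.0])  # made-up step counts

loop_version = torch.tensor([discount**pw for pw in final_steps])
vectorized = discount**final_steps  # scalar base broadcast over the tensor exponent
assert torch.allclose(loop_version, vectorized)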