@@ -128,7 +128,7 @@ def __init__(
         save_directory=Path("robot_nav/models/BPG/checkpoint"),
         model_name="BCNNTD3",
         load_directory=Path("robot_nav/models/BPG/checkpoint"),
-        bound_weight=8
+        bound_weight=8,
     ):
         # Initialize the Actor network
         self.bound_weight = bound_weight
@@ -170,20 +170,20 @@ def act(self, state):
 
     # training cycle
     def train(
-            self,
-            replay_buffer,
-            iterations,
-            batch_size,
-            discount=0.99,
-            tau=0.005,
-            policy_noise=0.2,
-            noise_clip=0.5,
-            policy_freq=2,
-            max_lin_vel=0.5,
-            max_ang_vel=1,
-            goal_reward=100,
-            distance_norm=10,
-            time_step=0.3,
+        self,
+        replay_buffer,
+        iterations,
+        batch_size,
+        discount=0.99,
+        tau=0.005,
+        policy_noise=0.2,
+        noise_clip=0.5,
+        policy_freq=2,
+        max_lin_vel=0.5,
+        max_ang_vel=1,
+        goal_reward=100,
+        distance_norm=10,
+        time_step=0.3,
     ):
         av_Q = 0
         max_b = 0
@@ -231,9 +231,16 @@ def train(
             # Get the Q values of the basis networks with the current parameters
             current_Q1, current_Q2 = self.critic(state, action)
 
-            max_bound = self.get_max_bound(next_state, discount, max_ang_vel, max_lin_vel, time_step, distance_norm,
-                                           goal_reward,
-                                           reward)
+            max_bound = self.get_max_bound(
+                next_state,
+                discount,
+                max_ang_vel,
+                max_lin_vel,
+                time_step,
+                distance_norm,
+                goal_reward,
+                reward,
+            )
             max_b += max(max_b, torch.max(max_bound))
             max_bound_loss_Q1 = current_Q1 - max_bound
             max_bound_loss_Q2 = current_Q2 - max_bound
@@ -265,15 +272,15 @@ def train(
                 # Use soft update to update the actor-target network parameters by
                 # infusing small amount of current parameters
                 for param, target_param in zip(
-                        self.actor.parameters(), self.actor_target.parameters()
+                    self.actor.parameters(), self.actor_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
                     )
                 # Use soft update to update the critic-target network parameters by infusing
                 # small amount of current parameters
                 for param, target_param in zip(
-                        self.critic.parameters(), self.critic_target.parameters()
+                    self.critic.parameters(), self.critic_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
@@ -297,16 +304,29 @@ def train(
         if self.save_every > 0 and self.iter_count % self.save_every == 0:
             self.save(filename=self.model_name, directory=self.save_directory)
 
-    def get_max_bound(self, next_state, discount, max_ang_vel, max_lin_vel, time_step, distance_norm, goal_reward,
-                      reward):
+    def get_max_bound(
+        self,
+        next_state,
+        discount,
+        max_ang_vel,
+        max_lin_vel,
+        time_step,
+        distance_norm,
+        goal_reward,
+        reward,
+    ):
         cos = next_state[:, -4]
         sin = next_state[:, -3]
         theta = torch.atan2(sin, cos)
 
         turn_steps = theta / (max_ang_vel * time_step)
         full_turn_steps = torch.floor(turn_steps.abs())
         turn_rew = [
-            -1 * discount ** step * max_ang_vel if step else torch.zeros(1, device=self.device)
+            (
+                -1 * discount ** step * max_ang_vel
+                if step
+                else torch.zeros(1, device=self.device)
+            )
             for step in full_turn_steps
         ]
         final_turn = turn_steps.abs() - full_turn_steps
@@ -325,18 +345,20 @@ def get_max_bound(self, next_state, discount, max_ang_vel, max_lin_vel, time_ste
         final_steps = torch.ceil(distances) + full_turn_steps
         inter_steps = torch.trunc(distances) + full_turn_steps
         final_discount = torch.tensor(
-                [discount ** pw for pw in final_steps], device=self.device
+            [discount ** pw for pw in final_steps], device=self.device
         )
         final_rew = (
-            torch.ones_like(distances, device=self.device) * goal_reward * final_discount
+            torch.ones_like(distances, device=self.device)
+            * goal_reward
+            * final_discount
         )
 
         max_inter_steps = inter_steps.max()
         exponents = torch.arange(
             1, max_inter_steps + 1, dtype=torch.float32, device=self.device
         )
         discount_exponents = torch.tensor(
-                [discount ** e for e in exponents], device=self.device
+            [discount ** e for e in exponents], device=self.device
         )
         inter_rew = torch.tensor(
             [
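
Note on what the reformatted lines compute, with a minimal runnable sketch. Only the soft-update expression and the current_Q1 - max_bound difference are taken from the hunks above; the toy nn.Linear critics, the example tensors, and the relu(...)**2 penalty scaled by bound_weight are illustrative assumptions, since the actual combination with the critic loss lies outside the lines shown.

# Sketch of the two mechanisms visible in the hunks above; assumptions are
# marked in the comments, and this is not the repository's exact code.
import torch
import torch.nn as nn

tau = 0.005        # soft-update rate, matching the default in train()
bound_weight = 8   # penalty weight, matching the default in __init__()

# Toy stand-ins for the critic and its target (assumption: the real networks differ).
critic = nn.Linear(4, 1)
critic_target = nn.Linear(4, 1)
critic_target.load_state_dict(critic.state_dict())

# Polyak soft update, same form as the zip(...) loops in the diff:
# blend a small fraction of the current parameters into the target network.
for param, target_param in zip(critic.parameters(), critic_target.parameters()):
    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

# Upper-bound penalty: the diff computes max_bound_loss_Q1 = current_Q1 - max_bound.
# ASSUMPTION: only the positive overshoot is penalized, squared, averaged, and
# scaled by bound_weight; how the file folds this into the critic loss is
# outside the lines shown.
current_Q1 = torch.tensor([[10.0], [50.0]])  # example Q estimates
max_bound = torch.tensor([[30.0], [30.0]])   # example per-sample upper bounds
max_bound_loss_Q1 = current_Q1 - max_bound
bound_penalty = bound_weight * torch.relu(max_bound_loss_Q1).pow(2).mean()
print(bound_penalty)  # only the sample above the bound contributes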