Commit 54f2b26

fix max bound error
1 parent c49514f

File tree

6 files changed: +174 additions, -78 deletions

robot_nav/models/BPG/BCNNPG.py

Lines changed: 48 additions & 26 deletions
@@ -112,7 +112,7 @@ def __init__(
         save_directory=Path("robot_nav/models/BPG/checkpoint"),
         model_name="BCNNPG",
         load_directory=Path("robot_nav/models/BPG/checkpoint"),
-        bound_weight=8
+        bound_weight=8,
     ):
         # Initialize the Actor network
         self.bound_weight = bound_weight
@@ -154,20 +154,20 @@ def act(self, state):

     # training cycle
     def train(
-        self,
-        replay_buffer,
-        iterations,
-        batch_size,
-        discount=0.99,
-        tau=0.005,
-        policy_noise=0.2,
-        noise_clip=0.5,
-        policy_freq=2,
-        max_lin_vel=0.5,
-        max_ang_vel=1,
-        goal_reward=100,
-        distance_norm=10,
-        time_step=0.3,
+        self,
+        replay_buffer,
+        iterations,
+        batch_size,
+        discount=0.99,
+        tau=0.005,
+        policy_noise=0.2,
+        noise_clip=0.5,
+        policy_freq=2,
+        max_lin_vel=0.5,
+        max_ang_vel=1,
+        goal_reward=100,
+        distance_norm=10,
+        time_step=0.3,
     ):
         av_Q = 0
         max_b = 0
@@ -215,9 +215,16 @@ def train(
             # Get the Q values of the basis networks with the current parameters
             current_Q = self.critic(state, action)

-            max_bound = self.get_max_bound(next_state, discount, max_ang_vel, max_lin_vel, time_step, distance_norm,
-                                           goal_reward,
-                                           reward)
+            max_bound = self.get_max_bound(
+                next_state,
+                discount,
+                max_ang_vel,
+                max_lin_vel,
+                time_step,
+                distance_norm,
+                goal_reward,
+                reward,
+            )
             max_b += max(max_b, torch.max(max_bound))
             max_bound_loss_Q = current_Q - max_bound
             max_bound_loss_Q[max_bound_loss_Q < 0] = 0
@@ -244,15 +251,15 @@ def train(
                 # Use soft update to update the actor-target network parameters by
                 # infusing small amount of current parameters
                 for param, target_param in zip(
-                    self.actor.parameters(), self.actor_target.parameters()
+                    self.actor.parameters(), self.actor_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
                     )
                 # Use soft update to update the critic-target network parameters by infusing
                 # small amount of current parameters
                 for param, target_param in zip(
-                    self.critic.parameters(), self.critic_target.parameters()
+                    self.critic.parameters(), self.critic_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
@@ -276,16 +283,29 @@ def train(
         if self.save_every > 0 and self.iter_count % self.save_every == 0:
             self.save(filename=self.model_name, directory=self.save_directory)

-    def get_max_bound(self, next_state, discount, max_ang_vel, max_lin_vel, time_step, distance_norm, goal_reward,
-                      reward):
+    def get_max_bound(
+        self,
+        next_state,
+        discount,
+        max_ang_vel,
+        max_lin_vel,
+        time_step,
+        distance_norm,
+        goal_reward,
+        reward,
+    ):
         cos = next_state[:, -4]
         sin = next_state[:, -3]
         theta = torch.atan2(sin, cos)

         turn_steps = theta / (max_ang_vel * time_step)
         full_turn_steps = torch.floor(turn_steps.abs())
         turn_rew = [
-            -1 * discount ** step * max_ang_vel if step else torch.zeros(1, device=self.device)
+            (
+                -1 * discount**step * max_ang_vel
+                if step
+                else torch.zeros(1, device=self.device)
+            )
             for step in full_turn_steps
         ]
         final_turn = turn_steps.abs() - full_turn_steps
@@ -304,18 +324,20 @@ def get_max_bound(self, next_state, discount, max_ang_vel, max_lin_vel, time_ste
         final_steps = torch.ceil(distances) + full_turn_steps
         inter_steps = torch.trunc(distances) + full_turn_steps
         final_discount = torch.tensor(
-            [discount ** pw for pw in final_steps], device=self.device
+            [discount**pw for pw in final_steps], device=self.device
         )
         final_rew = (
-            torch.ones_like(distances, device=self.device) * goal_reward * final_discount
+            torch.ones_like(distances, device=self.device)
+            * goal_reward
+            * final_discount
         )

         max_inter_steps = inter_steps.max()
         exponents = torch.arange(
             1, max_inter_steps + 1, dtype=torch.float32, device=self.device
         )
         discount_exponents = torch.tensor(
-            [discount ** e for e in exponents], device=self.device
+            [discount**e for e in exponents], device=self.device
         )
         inter_rew = torch.tensor(
             [

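For context, a minimal standalone sketch of the overshoot penalty these hunks reformat: Q-values that exceed the analytic upper bound are clamped and squared (the clamp is visible above as max_bound_loss_Q[max_bound_loss_Q < 0] = 0; the squared-mean reduction appears as context in the BPG.py hunks further down). How the penalty enters the critic loss, presumably scaled by bound_weight, is outside the shown hunks, so that last line is only a guess.

import torch


def bound_penalty(current_q: torch.Tensor, max_bound: torch.Tensor) -> torch.Tensor:
    # Keep only the part of Q that overshoots the upper bound, then square-mean it.
    overshoot = (current_q - max_bound).clamp(min=0)
    return overshoot.pow(2).mean()


# Hypothetical usage (the combined critic loss is not part of this diff):
# critic_loss = td_loss + bound_weight * bound_penalty(current_Q, max_bound)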
robot_nav/models/BPG/BCNNTD3.py

Lines changed: 48 additions & 26 deletions
@@ -128,7 +128,7 @@ def __init__(
         save_directory=Path("robot_nav/models/BPG/checkpoint"),
         model_name="BCNNTD3",
         load_directory=Path("robot_nav/models/BPG/checkpoint"),
-        bound_weight=8
+        bound_weight=8,
     ):
         # Initialize the Actor network
         self.bound_weight = bound_weight
@@ -170,20 +170,20 @@ def act(self, state):

     # training cycle
     def train(
-        self,
-        replay_buffer,
-        iterations,
-        batch_size,
-        discount=0.99,
-        tau=0.005,
-        policy_noise=0.2,
-        noise_clip=0.5,
-        policy_freq=2,
-        max_lin_vel=0.5,
-        max_ang_vel=1,
-        goal_reward=100,
-        distance_norm=10,
-        time_step=0.3,
+        self,
+        replay_buffer,
+        iterations,
+        batch_size,
+        discount=0.99,
+        tau=0.005,
+        policy_noise=0.2,
+        noise_clip=0.5,
+        policy_freq=2,
+        max_lin_vel=0.5,
+        max_ang_vel=1,
+        goal_reward=100,
+        distance_norm=10,
+        time_step=0.3,
     ):
         av_Q = 0
         max_b = 0
@@ -231,9 +231,16 @@ def train(
             # Get the Q values of the basis networks with the current parameters
             current_Q1, current_Q2 = self.critic(state, action)

-            max_bound = self.get_max_bound(next_state, discount, max_ang_vel, max_lin_vel, time_step, distance_norm,
-                                           goal_reward,
-                                           reward)
+            max_bound = self.get_max_bound(
+                next_state,
+                discount,
+                max_ang_vel,
+                max_lin_vel,
+                time_step,
+                distance_norm,
+                goal_reward,
+                reward,
+            )
             max_b += max(max_b, torch.max(max_bound))
             max_bound_loss_Q1 = current_Q1 - max_bound
             max_bound_loss_Q2 = current_Q2 - max_bound
@@ -265,15 +272,15 @@ def train(
                 # Use soft update to update the actor-target network parameters by
                 # infusing small amount of current parameters
                 for param, target_param in zip(
-                    self.actor.parameters(), self.actor_target.parameters()
+                    self.actor.parameters(), self.actor_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
                     )
                 # Use soft update to update the critic-target network parameters by infusing
                 # small amount of current parameters
                 for param, target_param in zip(
-                    self.critic.parameters(), self.critic_target.parameters()
+                    self.critic.parameters(), self.critic_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
@@ -297,16 +304,29 @@ def train(
         if self.save_every > 0 and self.iter_count % self.save_every == 0:
             self.save(filename=self.model_name, directory=self.save_directory)

-    def get_max_bound(self, next_state, discount, max_ang_vel, max_lin_vel, time_step, distance_norm, goal_reward,
-                      reward):
+    def get_max_bound(
+        self,
+        next_state,
+        discount,
+        max_ang_vel,
+        max_lin_vel,
+        time_step,
+        distance_norm,
+        goal_reward,
+        reward,
+    ):
         cos = next_state[:, -4]
         sin = next_state[:, -3]
         theta = torch.atan2(sin, cos)

         turn_steps = theta / (max_ang_vel * time_step)
         full_turn_steps = torch.floor(turn_steps.abs())
         turn_rew = [
-            -1 * discount ** step * max_ang_vel if step else torch.zeros(1, device=self.device)
+            (
+                -1 * discount**step * max_ang_vel
+                if step
+                else torch.zeros(1, device=self.device)
+            )
             for step in full_turn_steps
         ]
         final_turn = turn_steps.abs() - full_turn_steps
@@ -325,18 +345,20 @@ def get_max_bound(self, next_state, discount, max_ang_vel, max_lin_vel, time_ste
         final_steps = torch.ceil(distances) + full_turn_steps
         inter_steps = torch.trunc(distances) + full_turn_steps
         final_discount = torch.tensor(
-            [discount ** pw for pw in final_steps], device=self.device
+            [discount**pw for pw in final_steps], device=self.device
         )
         final_rew = (
-            torch.ones_like(distances, device=self.device) * goal_reward * final_discount
+            torch.ones_like(distances, device=self.device)
+            * goal_reward
+            * final_discount
         )

         max_inter_steps = inter_steps.max()
         exponents = torch.arange(
             1, max_inter_steps + 1, dtype=torch.float32, device=self.device
         )
         discount_exponents = torch.tensor(
-            [discount ** e for e in exponents], device=self.device
+            [discount**e for e in exponents], device=self.device
         )
         inter_rew = torch.tensor(
             [

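As a side note on what get_max_bound computes in all three files, here is a small self-contained sketch of its heading-to-turn-steps logic, with assumed batch shapes; next_state[:, -4] and next_state[:, -3] are read as cos and sin of the heading error, as in the hunks above. The list handling is simplified (plain floats instead of per-element tensors), so this is an illustration rather than the module's code.

import torch

# Toy batch: 4 next states with the heading encoded in the last-but-3/-4 features.
next_state = torch.randn(4, 10)
discount, max_ang_vel, time_step = 0.99, 1.0, 0.3

cos = next_state[:, -4]
sin = next_state[:, -3]
theta = torch.atan2(sin, cos)                    # heading error in radians
turn_steps = theta / (max_ang_vel * time_step)   # control steps needed to face the goal
full_turn_steps = torch.floor(turn_steps.abs())  # whole steps spent purely turning
# Each full turning step is charged max_ang_vel, discounted by its step index,
# mirroring the turn_rew list comprehension in the diffs.
turn_rew = [
    -(discount ** float(step)) * max_ang_vel if step else 0.0
    for step in full_turn_steps
]
print(turn_rew)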
robot_nav/models/BPG/BPG.py

Lines changed: 41 additions & 15 deletions
@@ -65,7 +65,7 @@ def __init__(
         save_directory=Path("robot_nav/models/BPG/checkpoint"),
         model_name="BPG",
         load_directory=Path("robot_nav/models/BPG/checkpoint"),
-        bound_weight=8
+        bound_weight=8,
     ):
         # Initialize the Actor network
         self.bound_weight = bound_weight
@@ -116,11 +116,11 @@ def train(
         policy_noise=0.2,
         noise_clip=0.5,
         policy_freq=2,
-        max_lin_vel = 0.5,
-        max_ang_vel = 1,
-        goal_reward = 100,
-        distance_norm = 10,
-        time_step = 0.3,
+        max_lin_vel=0.5,
+        max_ang_vel=1,
+        goal_reward=100,
+        distance_norm=10,
+        time_step=0.3,
     ):
         av_Q = 0
         max_b = 0
@@ -166,9 +166,18 @@ def train(
             # Get the Q values of the basis networks with the current parameters
             current_Q = self.critic(state, action)

-            max_bound = self.get_max_bound(next_state, discount, max_ang_vel, max_lin_vel, time_step, distance_norm, goal_reward,
-                                           reward)
-            max_b += max(max_b, torch.max(max_bound))
+            max_bound = self.get_max_bound(
+                next_state,
+                discount,
+                max_ang_vel,
+                max_lin_vel,
+                time_step,
+                distance_norm,
+                goal_reward,
+                reward,
+                done,
+            )
+            max_b = max(max_b, torch.max(max_bound))
             max_bound_loss_Q = current_Q - max_bound
             max_bound_loss_Q[max_bound_loss_Q < 0] = 0
             max_bound_loss_Q = torch.square(max_bound_loss_Q).mean()
@@ -228,15 +237,30 @@ def train(
         if self.save_every > 0 and self.iter_count % self.save_every == 0:
             self.save(filename=self.model_name, directory=self.save_directory)

-    def get_max_bound(self, next_state, discount, max_ang_vel, max_lin_vel, time_step, distance_norm, goal_reward, reward):
+    def get_max_bound(
+        self,
+        next_state,
+        discount,
+        max_ang_vel,
+        max_lin_vel,
+        time_step,
+        distance_norm,
+        goal_reward,
+        reward,
+        done,
+    ):
         cos = next_state[:, -4]
         sin = next_state[:, -3]
         theta = torch.atan2(sin, cos)

         turn_steps = theta / (max_ang_vel * time_step)
         full_turn_steps = torch.floor(turn_steps.abs())
         turn_rew = [
-            -1 * discount ** step * max_ang_vel if step else torch.zeros(1, device=self.device)
+            (
+                -1 * discount**step * max_ang_vel
+                if step
+                else torch.zeros(1, device=self.device)
+            )
             for step in full_turn_steps
         ]
         final_turn = turn_steps.abs() - full_turn_steps
@@ -255,18 +279,20 @@ def get_max_bound(self, next_state, discount, max_ang_vel, max_lin_vel, time_ste
         final_steps = torch.ceil(distances) + full_turn_steps
         inter_steps = torch.trunc(distances) + full_turn_steps
         final_discount = torch.tensor(
-            [discount ** pw for pw in final_steps], device=self.device
+            [discount**pw for pw in final_steps], device=self.device
         )
         final_rew = (
-            torch.ones_like(distances, device=self.device) * goal_reward * final_discount
+            torch.ones_like(distances, device=self.device)
+            * goal_reward
+            * final_discount
         )

         max_inter_steps = inter_steps.max()
         exponents = torch.arange(
             1, max_inter_steps + 1, dtype=torch.float32, device=self.device
         )
         discount_exponents = torch.tensor(
-            [discount ** e for e in exponents], device=self.device
+            [discount**e for e in exponents], device=self.device
         )
         inter_rew = torch.tensor(
             [
@@ -281,7 +307,7 @@ def get_max_bound(self, next_state, discount, max_ang_vel, max_lin_vel, time_ste
             device=self.device,
         )
         max_future_rew = full_turn_rew + final_rew + inter_rew
-        max_bound = reward + max_future_rew.view(-1, 1)
+        max_bound = reward + (1 - done) * max_future_rew.view(-1, 1)
         return max_bound

     def save(self, filename, directory):

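The substantive fix lands in BPG.py: terminal transitions should not be credited with any discounted future return, so the future-reward bound is now masked with (1 - done). A minimal sketch of the before/after behaviour, with assumed (batch, 1) shapes for reward and done:

import torch

reward = torch.tensor([[100.0], [1.5]])      # immediate rewards
done = torch.tensor([[1.0], [0.0]])          # 1 = episode ended on this transition
max_future_rew = torch.tensor([0.0, 42.0])   # analytic bound on the future return

# Before the fix: max_bound = reward + max_future_rew.view(-1, 1)
# After the fix, terminal samples keep only their immediate reward:
max_bound = reward + (1 - done) * max_future_rew.view(-1, 1)
print(max_bound)  # [[100.0], [43.5]]

The same hunk also changes the running statistic from max_b += max(max_b, torch.max(max_bound)) to max_b = max(max_b, torch.max(max_bound)), so the tracked value stays a true maximum instead of an ever-growing sum; the BCNNPG and BCNNTD3 diffs above keep the += form in this commit.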