Commit 9d12c3e

fix bcnnpg

1 parent d508138

2 files changed: +42 -40 lines changed


robot_nav/models/BPG/BCNNPG.py

Lines changed: 41 additions & 39 deletions
@@ -154,20 +154,20 @@ def act(self, state):
 
     # training cycle
     def train(
-        self,
-        replay_buffer,
-        iterations,
-        batch_size,
-        discount=0.99,
-        tau=0.005,
-        policy_noise=0.2,
-        noise_clip=0.5,
-        policy_freq=2,
-        max_lin_vel=0.5,
-        max_ang_vel=1,
-        goal_reward=100,
-        distance_norm=10,
-        time_step=0.3,
+        self,
+        replay_buffer,
+        iterations,
+        batch_size,
+        discount=0.99,
+        tau=0.005,
+        policy_noise=0.2,
+        noise_clip=0.5,
+        policy_freq=2,
+        max_lin_vel=0.5,
+        max_ang_vel=1,
+        goal_reward=100,
+        distance_norm=10,
+        time_step=0.3,
     ):
         av_Q = 0
         max_b = 0
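
For orientation, a hedged sketch of how a train() call with the signature above might look from a training script; model, replay_buffer, and the iteration/batch values are placeholders, not taken from this commit:

# Illustrative call only: model and replay_buffer are placeholders, and the
# numeric keyword values simply repeat the defaults shown in the signature.
model.train(
    replay_buffer,
    iterations=80,   # placeholder value
    batch_size=64,   # placeholder value
    discount=0.99,
    tau=0.005,
    policy_noise=0.2,
    noise_clip=0.5,
    policy_freq=2,
    max_lin_vel=0.5,
    max_ang_vel=1,
    goal_reward=100,
    distance_norm=10,
    time_step=0.3,
)
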
@@ -203,15 +203,13 @@ def train(
             next_action = (next_action + noise).clamp(-self.max_action, self.max_action)
 
             # Calculate the Q values from the critic-target network for the next state-action pair
-            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
+            target_Q = self.critic_target(next_state, next_action)
 
-            # Select the minimal Q value from the 2 calculated values
-            target_Q = torch.min(target_Q1, target_Q2)
             av_Q += torch.mean(target_Q)
             max_Q = max(max_Q, torch.max(target_Q))
-
             # Calculate the final Q value from the target network parameters by using Bellman equation
             target_Q = reward + ((1 - done) * discount * target_Q).detach()
+
             # Get the Q values of the basis networks with the current parameters
             current_Q = self.critic(state, action)
 
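
The removed lines are the clipped double-Q step used by TD3-style agents (take the element-wise minimum of two critic heads); after this change self.critic_target is treated as returning a single Q tensor. A minimal, self-contained sketch of the resulting target computation with dummy tensors (shapes are assumptions, not taken from the repository):

import torch

batch_size, discount = 4, 0.99
reward = torch.rand(batch_size, 1)    # immediate rewards
done = torch.zeros(batch_size, 1)     # 0 = non-terminal, 1 = terminal
target_Q = torch.rand(batch_size, 1)  # stands in for critic_target(next_state, next_action)

# Bellman backup with terminal masking, mirroring the diff; .detach() keeps
# gradients from flowing back into the target network.
target_Q = reward + ((1 - done) * discount * target_Q).detach()
print(target_Q.shape)  # torch.Size([4, 1])
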
@@ -224,15 +222,18 @@ def train(
                 distance_norm,
                 goal_reward,
                 reward,
+                done,
             )
-            max_b += max(max_b, torch.max(max_bound))
+            max_b = max(max_b, torch.max(max_bound))
             max_bound_loss_Q = current_Q - max_bound
             max_bound_loss_Q[max_bound_loss_Q < 0] = 0
             max_bound_loss_Q = torch.square(max_bound_loss_Q).mean()
+            max_bound_loss = max_bound_loss_Q
 
             # Calculate the loss between the current Q value and the target Q value
             loss_target_Q = F.mse_loss(current_Q, target_Q)
-            max_bound_loss = self.bound_weight * max_bound_loss_Q
+
+            max_bound_loss = self.bound_weight * max_bound_loss
             loss = loss_target_Q + max_bound_loss
             # Perform the gradient descent
             self.critic_optimizer.zero_grad()
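
Two fixes land here: the running statistic becomes a plain maximum (max_b = max(...)) instead of accumulating onto itself, and the bound penalty is routed through a separately named max_bound_loss before being scaled by self.bound_weight. The penalty itself is one-sided, so only Q-values above the analytic upper bound contribute. A self-contained, equivalent sketch of that penalty using torch.relu (bound_weight stands in for self.bound_weight):

import torch

bound_weight = 0.1  # stand-in for self.bound_weight
current_Q = torch.tensor([[1.5], [0.2], [3.0]])
max_bound = torch.tensor([[1.0], [1.0], [1.0]])

# Only the excess above the bound is penalized; samples already below it contribute zero.
excess = torch.relu(current_Q - max_bound)
max_bound_loss = bound_weight * excess.pow(2).mean()
print(max_bound_loss)  # tensor(0.1417)
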
@@ -242,7 +243,7 @@ def train(
             if it % policy_freq == 0:
                 # Maximize the actor output value by performing gradient descent on negative Q values
                 # (essentially perform gradient ascent)
-                actor_grad, _ = self.critic(state, self.actor(state))
+                actor_grad = self.critic(state, self.actor(state))
                 actor_grad = -actor_grad.mean()
                 self.actor_optimizer.zero_grad()
                 actor_grad.backward()
@@ -251,15 +252,15 @@ def train(
                 # Use soft update to update the actor-target network parameters by
                 # infusing small amount of current parameters
                 for param, target_param in zip(
-                    self.actor.parameters(), self.actor_target.parameters()
+                    self.actor.parameters(), self.actor_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
                     )
                 # Use soft update to update the critic-target network parameters by infusing
                 # small amount of current parameters
                 for param, target_param in zip(
-                    self.critic.parameters(), self.critic_target.parameters()
+                    self.critic.parameters(), self.critic_target.parameters()
                 ):
                     target_param.data.copy_(
                         tau * param.data + (1 - tau) * target_param.data
@@ -284,15 +285,16 @@ def train(
         self.save(filename=self.model_name, directory=self.save_directory)
 
     def get_max_bound(
-        self,
-        next_state,
-        discount,
-        max_ang_vel,
-        max_lin_vel,
-        time_step,
-        distance_norm,
-        goal_reward,
-        reward,
+        self,
+        next_state,
+        discount,
+        max_ang_vel,
+        max_lin_vel,
+        time_step,
+        distance_norm,
+        goal_reward,
+        reward,
+        done,
     ):
         cos = next_state[:, -4]
         sin = next_state[:, -3]
@@ -302,7 +304,7 @@ def get_max_bound(
         full_turn_steps = torch.floor(turn_steps.abs())
         turn_rew = [
             (
-                -1 * discount**step * max_ang_vel
+                -1 * discount ** step * max_ang_vel
                 if step
                 else torch.zeros(1, device=self.device)
             )
@@ -324,20 +326,20 @@ def get_max_bound(
         final_steps = torch.ceil(distances) + full_turn_steps
         inter_steps = torch.trunc(distances) + full_turn_steps
         final_discount = torch.tensor(
-            [discount**pw for pw in final_steps], device=self.device
+            [discount ** pw for pw in final_steps], device=self.device
         )
         final_rew = (
-            torch.ones_like(distances, device=self.device)
-            * goal_reward
-            * final_discount
+            torch.ones_like(distances, device=self.device)
+            * goal_reward
+            * final_discount
         )
 
         max_inter_steps = inter_steps.max()
         exponents = torch.arange(
             1, max_inter_steps + 1, dtype=torch.float32, device=self.device
         )
         discount_exponents = torch.tensor(
-            [discount**e for e in exponents], device=self.device
+            [discount ** e for e in exponents], device=self.device
         )
         inter_rew = torch.tensor(
             [
@@ -352,7 +354,7 @@ def get_max_bound(
             device=self.device,
         )
         max_future_rew = full_turn_rew + final_rew + inter_rew
-        max_bound = reward + max_future_rew.view(-1, 1)
+        max_bound = reward + (1 - done) * max_future_rew.view(-1, 1)
         return max_bound
 
     def save(self, filename, directory):
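
Passing done into get_max_bound mirrors the Bellman target in train(): for terminal transitions the projected future return is zeroed out, so the bound collapses to the immediate reward. A small self-contained illustration of the masking, with invented values:

import torch

reward = torch.tensor([[1.0], [100.0]])      # second sample reached the goal
max_future_rew = torch.tensor([50.0, 50.0])  # optimistic estimate of future return
done = torch.tensor([[0.0], [1.0]])          # episode ended for the second sample

max_bound = reward + (1 - done) * max_future_rew.view(-1, 1)
print(max_bound)  # [[51.], [100.]]: the terminal sample keeps only its immediate reward
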

robot_nav/train.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ def main(args=None):
     """Main training function"""
     action_dim = 2  # number of actions produced by the model
     max_action = 1  # maximum absolute value of output actions
-    state_dim = 25  # number of input values in the neural network (vector length of state input)
+    state_dim = 185  # number of input values in the neural network (vector length of state input)
     device = torch.device(
         "cuda" if torch.cuda.is_available() else "cpu"
     )  # using cuda if it is available, cpu otherwise
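
state_dim has to match the length of the flattened state vector the environment hands to the network; the jump from 25 to 185 is consistent with a larger observation (for example a denser laser scan), although this commit does not show the sensor configuration itself. A hypothetical fail-fast check a training script could run once at startup, not part of this commit (get_first_state() is a placeholder for however the project obtains the initial observation):

state_dim = 185

def get_first_state():
    # Placeholder: in a real script this would come from the simulator/environment.
    return [0.0] * 185

state = get_first_state()
assert len(state) == state_dim, (
    f"state_dim is {state_dim}, but the environment returned {len(state)} values"
)
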
