Skip to content

Commit 451e50b

Browse files
committed
fix the bounds loss
1 parent e0c5596 commit 451e50b

File tree

4 files changed

+14
-14
lines changed

4 files changed

+14
-14
lines changed

robot_nav/models/BPG/BCNNPG.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,8 @@ def train(
230230
max_b = max(max_b, torch.max(max_bound))
231231
av_bound += torch.mean(max_bound)
232232

233-
max_bound_Q = torch.min(current_Q, max_bound)
234-
max_bound_loss = F.mse_loss(current_Q, max_bound_Q)
233+
max_excess_Q = F.relu(current_Q - max_bound)
234+
max_bound_loss = (max_excess_Q ** 2).mean()
235235
# Calculate the loss between the current Q value and the target Q value
236236
loss_target_Q = F.mse_loss(current_Q, target_Q)
237237

robot_nav/models/BPG/BCNNTD3.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,16 +247,16 @@ def train(
247247
)
248248
max_b += max(max_b, torch.max(max_bound))
249249
av_bound += torch.mean(max_bound)
250-
max_bound_Q1 = torch.min(current_Q1, max_bound)
251-
max_bound_loss_Q1 = F.mse_loss(current_Q1, max_bound_Q1)
252-
max_bound_Q2 = torch.min(current_Q2, max_bound)
253-
max_bound_loss_Q2 = F.mse_loss(current_Q2, max_bound_Q2)
250+
max_excess_Q1 = F.relu(current_Q1 - max_bound)
251+
max_bound_loss_Q1 = (max_excess_Q1 ** 2).mean()
252+
max_excess_Q2 = F.relu(current_Q2 - max_bound)
253+
max_bound_loss_Q2 = (max_excess_Q2 ** 2).mean()
254+
max_bound_loss = self.bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2)
254255

255256
# Calculate the loss between the current Q value and the target Q value
256257
loss_target_Q = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
257258
current_Q2, target_Q
258259
)
259-
max_bound_loss = self.bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2)
260260
loss = loss_target_Q + max_bound_loss
261261
# Perform the gradient descent
262262
self.critic_optimizer.zero_grad()

robot_nav/models/BPG/BPG.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,8 @@ def train(
183183
max_b = max(max_b, torch.max(max_bound))
184184
av_bound += torch.mean(max_bound)
185185

186-
max_bound_Q = torch.min(current_Q, max_bound)
187-
max_bound_loss = F.mse_loss(current_Q, max_bound_Q)
186+
max_excess_Q = F.relu(current_Q - max_bound)
187+
max_bound_loss = (max_excess_Q ** 2).mean()
188188
# Calculate the loss between the current Q value and the target Q value
189189
loss_target_Q = F.mse_loss(current_Q, target_Q)
190190

robot_nav/models/BPG/BTD3.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,16 +201,16 @@ def train(
201201
)
202202
max_b += max(max_b, torch.max(max_bound))
203203
av_bound += torch.mean(max_bound)
204-
max_bound_Q1 = torch.min(current_Q1, max_bound)
205-
max_bound_loss_Q1 = F.mse_loss(current_Q1, max_bound_Q1)
206-
max_bound_Q2 = torch.min(current_Q2, max_bound)
207-
max_bound_loss_Q2 = F.mse_loss(current_Q2, max_bound_Q2)
204+
max_excess_Q1 = F.relu(current_Q1 - max_bound)
205+
max_bound_loss_Q1 = (max_excess_Q1 ** 2).mean()
206+
max_excess_Q2 = F.relu(current_Q2 - max_bound)
207+
max_bound_loss_Q2 = (max_excess_Q2 ** 2).mean()
208+
max_bound_loss = self.bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2)
208209

209210
# Calculate the loss between the current Q value and the target Q value
210211
loss_target_Q = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
211212
current_Q2, target_Q
212213
)
213-
max_bound_loss = self.bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2)
214214
loss = loss_target_Q + max_bound_loss
215215
# Perform the gradient descent
216216
self.critic_optimizer.zero_grad()

0 commit comments

Comments
 (0)