File tree Expand file tree Collapse file tree 4 files changed +14
-14
lines changed Expand file tree Collapse file tree 4 files changed +14
-14
lines changed Original file line number Diff line number Diff line change @@ -230,8 +230,8 @@ def train(
230
230
max_b = max (max_b , torch .max (max_bound ))
231
231
av_bound += torch .mean (max_bound )
232
232
233
- max_bound_Q = torch . min (current_Q , max_bound )
234
- max_bound_loss = F . mse_loss ( current_Q , max_bound_Q )
233
+ max_excess_Q = F . relu (current_Q - max_bound )
234
+ max_bound_loss = ( max_excess_Q ** 2 ). mean ( )
235
235
# Calculate the loss between the current Q value and the target Q value
236
236
loss_target_Q = F .mse_loss (current_Q , target_Q )
237
237
Original file line number Diff line number Diff line change @@ -247,16 +247,16 @@ def train(
247
247
)
248
248
max_b += max (max_b , torch .max (max_bound ))
249
249
av_bound += torch .mean (max_bound )
250
- max_bound_Q1 = torch .min (current_Q1 , max_bound )
251
- max_bound_loss_Q1 = F .mse_loss (current_Q1 , max_bound_Q1 )
252
- max_bound_Q2 = torch .min (current_Q2 , max_bound )
253
- max_bound_loss_Q2 = F .mse_loss (current_Q2 , max_bound_Q2 )
250
+ max_excess_Q1 = F .relu (current_Q1 - max_bound )
251
+ max_bound_loss_Q1 = (max_excess_Q1 ** 2 ).mean ()
252
+ max_excess_Q2 = F .relu (current_Q2 - max_bound )
253
+ max_bound_loss_Q2 = (max_excess_Q2 ** 2 ).mean ()
254
+ max_bound_loss = self .bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2 )
254
255
255
256
# Calculate the loss between the current Q value and the target Q value
256
257
loss_target_Q = F .mse_loss (current_Q1 , target_Q ) + F .mse_loss (
257
258
current_Q2 , target_Q
258
259
)
259
- max_bound_loss = self .bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2 )
260
260
loss = loss_target_Q + max_bound_loss
261
261
# Perform the gradient descent
262
262
self .critic_optimizer .zero_grad ()
Original file line number Diff line number Diff line change @@ -183,8 +183,8 @@ def train(
183
183
max_b = max (max_b , torch .max (max_bound ))
184
184
av_bound += torch .mean (max_bound )
185
185
186
- max_bound_Q = torch . min (current_Q , max_bound )
187
- max_bound_loss = F . mse_loss ( current_Q , max_bound_Q )
186
+ max_excess_Q = F . relu (current_Q - max_bound )
187
+ max_bound_loss = ( max_excess_Q ** 2 ). mean ( )
188
188
# Calculate the loss between the current Q value and the target Q value
189
189
loss_target_Q = F .mse_loss (current_Q , target_Q )
190
190
Original file line number Diff line number Diff line change @@ -201,16 +201,16 @@ def train(
201
201
)
202
202
max_b += max (max_b , torch .max (max_bound ))
203
203
av_bound += torch .mean (max_bound )
204
- max_bound_Q1 = torch .min (current_Q1 , max_bound )
205
- max_bound_loss_Q1 = F .mse_loss (current_Q1 , max_bound_Q1 )
206
- max_bound_Q2 = torch .min (current_Q2 , max_bound )
207
- max_bound_loss_Q2 = F .mse_loss (current_Q2 , max_bound_Q2 )
204
+ max_excess_Q1 = F .relu (current_Q1 - max_bound )
205
+ max_bound_loss_Q1 = (max_excess_Q1 ** 2 ).mean ()
206
+ max_excess_Q2 = F .relu (current_Q2 - max_bound )
207
+ max_bound_loss_Q2 = (max_excess_Q2 ** 2 ).mean ()
208
+ max_bound_loss = self .bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2 )
208
209
209
210
# Calculate the loss between the current Q value and the target Q value
210
211
loss_target_Q = F .mse_loss (current_Q1 , target_Q ) + F .mse_loss (
211
212
current_Q2 , target_Q
212
213
)
213
- max_bound_loss = self .bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2 )
214
214
loss = loss_target_Q + max_bound_loss
215
215
# Perform the gradient descent
216
216
self .critic_optimizer .zero_grad ()
You can’t perform that action at this time.
0 commit comments