Commit 07bcb41

pass max bound by init
1 parent e4763c8 commit 07bcb41

4 files changed: +12, -4 lines


robot_nav/models/BPG/BCNNPG.py
Lines changed: 3 additions & 1 deletion

@@ -112,8 +112,10 @@ def __init__(
         save_directory=Path("robot_nav/models/BPG/checkpoint"),
         model_name="BCNNPG",
         load_directory=Path("robot_nav/models/BPG/checkpoint"),
+        bound_weight=8
     ):
         # Initialize the Actor network
+        self.bound_weight = bound_weight
         self.device = device
         self.actor = Actor(action_dim).to(self.device)
         self.actor_target = Actor(action_dim).to(self.device)
@@ -223,7 +225,7 @@ def train(

         # Calculate the loss between the current Q value and the target Q value
         loss_target_Q = F.mse_loss(current_Q, target_Q)
-        max_bound_loss = 10 * max_bound_loss_Q
+        max_bound_loss = self.bound_weight * max_bound_loss_Q
         loss = loss_target_Q + max_bound_loss
         # Perform the gradient descent
         self.critic_optimizer.zero_grad()
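In effect, the commit turns the previously hard-coded factor of 10 on the max-bound penalty into a constructor argument, bound_weight, with a new default of 8. Below is a minimal, hypothetical sketch of the critic loss as it reads after this change; how max_bound_loss_Q is actually computed is not shown in this diff, so the clamped-overshoot form and the max_Q bound used here are assumptions for illustration only.

import torch
import torch.nn.functional as F

def critic_loss_sketch(current_Q, target_Q, max_Q, bound_weight=8):
    # TD term: distance between the critic estimate and the bootstrapped target
    loss_target_Q = F.mse_loss(current_Q, target_Q)
    # Assumed form of the max-bound penalty: squared overshoot of Q above an
    # assumed upper bound max_Q (the diff only shows that this term is scaled
    # by bound_weight, not how it is computed)
    max_bound_loss_Q = torch.clamp(current_Q - max_Q, min=0.0).pow(2).mean()
    # As in the diff: scale by the configurable bound_weight instead of the literal 10
    max_bound_loss = bound_weight * max_bound_loss_Q
    return loss_target_Q + max_bound_loss

# Example with dummy tensors:
# loss = critic_loss_sketch(torch.randn(64, 1), torch.randn(64, 1), max_Q=100.0)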

robot_nav/models/BPG/BCNNTD3.py
Lines changed: 3 additions & 1 deletion

@@ -128,8 +128,10 @@ def __init__(
         save_directory=Path("robot_nav/models/BPG/checkpoint"),
         model_name="BCNNTD3",
         load_directory=Path("robot_nav/models/BPG/checkpoint"),
+        bound_weight=8
     ):
         # Initialize the Actor network
+        self.bound_weight = bound_weight
         self.device = device
         self.actor = Actor(action_dim).to(self.device)
         self.actor_target = Actor(action_dim).to(self.device)
@@ -244,7 +246,7 @@ def train(
         loss_target_Q = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
             current_Q2, target_Q
         )
-        max_bound_loss = 10 * (max_bound_loss_Q1 + max_bound_loss_Q2)
+        max_bound_loss = self.bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2)
         loss = loss_target_Q + max_bound_loss
         # Perform the gradient descent
         self.critic_optimizer.zero_grad()

robot_nav/models/BPG/BPG.py
Lines changed: 3 additions & 1 deletion

@@ -65,8 +65,10 @@ def __init__(
         save_directory=Path("robot_nav/models/BPG/checkpoint"),
         model_name="BPG",
         load_directory=Path("robot_nav/models/BPG/checkpoint"),
+        bound_weight=8
     ):
         # Initialize the Actor network
+        self.bound_weight = bound_weight
         self.device = device
         self.actor = Actor(state_dim, action_dim).to(self.device)
         self.actor_target = Actor(state_dim, action_dim).to(self.device)
@@ -175,7 +177,7 @@ def train(
         # Calculate the loss between the current Q value and the target Q value
         loss_target_Q = F.mse_loss(current_Q, target_Q)

-        max_bound_loss = 10 * max_bound_loss
+        max_bound_loss = self.bound_weight * max_bound_loss
         loss = loss_target_Q + max_bound_loss
         # Perform the gradient descent
         self.critic_optimizer.zero_grad()

robot_nav/models/BPG/BTD3.py
Lines changed: 3 additions & 1 deletion

@@ -82,8 +82,10 @@ def __init__(
         save_directory=Path("robot_nav/models/BPG/checkpoint"),
         model_name="BTD3",
         load_directory=Path("robot_nav/models/BPG/checkpoint"),
+        bound_weight = 8
     ):
         # Initialize the Actor network
+        self.bound_weight = bound_weight
         self.device = device
         self.actor = Actor(state_dim, action_dim).to(self.device)
         self.actor_target = Actor(state_dim, action_dim).to(self.device)
@@ -197,7 +199,7 @@ def train(
         loss_target_Q = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
             current_Q2, target_Q
         )
-        max_bound_loss = 10 * (max_bound_loss_Q1 + max_bound_loss_Q2)
+        max_bound_loss = self.bound_weight * (max_bound_loss_Q1 + max_bound_loss_Q2)
        loss = loss_target_Q + max_bound_loss
         # Perform the gradient descent
         self.critic_optimizer.zero_grad()
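Across all four models the pattern is identical: the weight that used to be the literal 10 in the critic update is now read from self.bound_weight, which defaults to 8, so callers who want the pre-commit behaviour have to pass the old value explicitly. A hypothetical usage sketch follows; the class name and the remaining constructor arguments (state_dim, action_dim, device) are assumptions, since the full __init__ signature is not part of this diff.

from robot_nav.models.BPG.BTD3 import BTD3  # assumed class name in this module

model = BTD3(
    state_dim=25,       # assumed observation dimensionality
    action_dim=2,       # assumed action dimensionality
    device="cuda",      # assumed argument; __init__ sets self.device from a device value
    bound_weight=10,    # reproduce the factor that was hard-coded before this commit
)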
