 from numpy import inf
 from torch.utils.tensorboard import SummaryWriter

+from robot_nav.utils import get_max_bound
+

 class Actor(nn.Module):
     def __init__(self, state_dim, action_dim):
@@ -81,6 +83,8 @@ def __init__(
         save_directory=Path("robot_nav/models/TD3/checkpoint"),
         model_name="TD3",
         load_directory=Path("robot_nav/models/TD3/checkpoint"),
+        use_max_bound=False,
+        bound_weight=0.25,
     ):
         # Initialize the Actor network
         self.device = device
@@ -98,13 +102,15 @@ def __init__(
         self.action_dim = action_dim
         self.max_action = max_action
         self.state_dim = state_dim
-        self.writer = SummaryWriter()
+        self.writer = SummaryWriter(comment=model_name)
         self.iter_count = 0
         if load_model:
             self.load(filename=model_name, directory=load_directory)
         self.save_every = save_every
         self.model_name = model_name
         self.save_directory = save_directory
+        self.use_max_bound = use_max_bound
+        self.bound_weight = bound_weight

     def get_action(self, obs, add_noise):
         if add_noise:
@@ -130,6 +136,11 @@ def train(
         policy_noise=0.2,
         noise_clip=0.5,
         policy_freq=2,
+        max_lin_vel=0.5,
+        max_ang_vel=1,
+        goal_reward=100,
+        distance_norm=10,
+        time_step=0.3,
     ):
         av_Q = 0
         max_Q = -inf
@@ -177,6 +188,25 @@ def train(
         # Calculate the loss between the current Q value and the target Q value
         loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

+        if self.use_max_bound:
+            max_bound = get_max_bound(
+                next_state,
+                discount,
+                max_ang_vel,
+                max_lin_vel,
+                time_step,
+                distance_norm,
+                goal_reward,
+                reward,
+                done,
+                self.device,
+            )
+            max_excess_Q1 = F.relu(current_Q1 - max_bound)
+            max_excess_Q2 = F.relu(current_Q2 - max_bound)
+            max_bound_loss = (max_excess_Q1 ** 2).mean() + (max_excess_Q2 ** 2).mean()
+            # Add loss for Q values exceeding maximum possible upper bound
+            loss += self.bound_weight * max_bound_loss
+
         # Perform the gradient descent
         self.critic_optimizer.zero_grad()
         loss.backward()
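The new block penalizes only the portion of each Q estimate that exceeds an analytically derived upper bound, leaving the standard TD3 critic loss untouched otherwise. Below is a minimal, self-contained sketch of that squared-ReLU penalty on toy tensors; the `max_bound` values are hand-picked placeholders standing in for the per-sample bound returned by `robot_nav.utils.get_max_bound`, whose implementation is not shown in this diff.

```python
import torch
import torch.nn.functional as F

# Toy critic estimates and a hypothetical per-sample upper bound
# (in the diff, the bound comes from get_max_bound).
current_Q1 = torch.tensor([[90.0], [120.0]])
current_Q2 = torch.tensor([[95.0], [105.0]])
max_bound = torch.tensor([[100.0], [100.0]])

# Only the excess above the bound is penalized (squared ReLU),
# so under-estimates contribute nothing to the loss.
max_excess_Q1 = F.relu(current_Q1 - max_bound)
max_excess_Q2 = F.relu(current_Q2 - max_bound)
max_bound_loss = (max_excess_Q1 ** 2).mean() + (max_excess_Q2 ** 2).mean()

bound_weight = 0.25  # default of the new constructor argument
penalty = bound_weight * max_bound_loss  # term added to the critic loss
print(penalty)  # tensor(53.1250): ((0 + 400)/2 + (0 + 25)/2) * 0.25
```

Because the penalty is zero whenever the critic stays below the bound, it acts purely as an over-estimation regularizer weighted by `bound_weight`.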