class Actor(nn.Module):
+ """
+ Actor network for the TD3 algorithm.
+
+ This neural network maps states to actions using a feedforward architecture with
+ LeakyReLU activations and a final Tanh output to bound the actions in [-1, 1].
+
+ Architecture:
+     Input: state_dim
+     Hidden Layer 1: 400 units, LeakyReLU
+     Hidden Layer 2: 300 units, LeakyReLU
+     Output Layer: action_dim, Tanh
+
+ Args:
+     state_dim (int): Dimension of the input state.
+     action_dim (int): Dimension of the action output.
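+
+ Example (illustrative only; the dimensions below are placeholders, not values from this project):
+     >>> actor = Actor(state_dim=25, action_dim=2)
+     >>> action = actor(torch.randn(1, 25))  # shape (1, 2), bounded to [-1, 1] by Tanh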
+ """
+
def __init__(self, state_dim, action_dim):
super(Actor, self).__init__()
@@ -22,13 +39,39 @@ def __init__(self, state_dim, action_dim):
self.tanh = nn.Tanh()

def forward(self, s):
+ """
+ Perform a forward pass through the actor network.
+
+ Args:
+     s (torch.Tensor): Input state tensor.
+
+ Returns:
+     torch.Tensor: Action output tensor after Tanh activation.
+ """
s = F.leaky_relu(self.layer_1(s))
s = F.leaky_relu(self.layer_2(s))
a = self.tanh(self.layer_3(s))
return a


class Critic(nn.Module):
+ """
+ Critic network for the TD3 algorithm.
+
+ This class defines two Q-value estimators (Q1 and Q2) using separate subnetworks.
+ Each Q-network takes both state and action as input and outputs a scalar Q-value.
+
+ Architecture for each Q-network:
+     Input: state_dim and action_dim
+     - State pathway: Linear + LeakyReLU → 400 → 300
+     - Action pathway: Linear → 300
+     - Combined pathway: LeakyReLU(Linear(state) + Linear(action) + bias) → 1
+
+ Args:
+     state_dim (int): Dimension of the input state.
+     action_dim (int): Dimension of the input action.
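+
+ Example (illustrative only; batch size and dimensions are placeholders):
+     >>> critic = Critic(state_dim=25, action_dim=2)
+     >>> q1, q2 = critic(torch.randn(8, 25), torch.randn(8, 2))  # two Q-value tensors of shape (8, 1)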
+ """
+
def __init__(self, state_dim, action_dim):
super(Critic, self).__init__()
@@ -51,6 +94,18 @@ def __init__(self, state_dim, action_dim):
torch.nn.init.kaiming_uniform_(self.layer_6.weight, nonlinearity="leaky_relu")

def forward(self, s, a):
+ """
+ Perform a forward pass through both Q-networks.
+
+ Args:
+     s (torch.Tensor): Input state tensor.
+     a (torch.Tensor): Input action tensor.
+
+ Returns:
+     tuple:
+         - q1 (torch.Tensor): Output Q-value from the first critic network.
+         - q2 (torch.Tensor): Output Q-value from the second critic network.
+ """
s1 = F.leaky_relu(self.layer_1(s))
self.layer_2_s(s1)
self.layer_2_a(a)
@@ -86,8 +141,28 @@ def __init__(
use_max_bound=False,
bound_weight=0.25,
):
- # Initialize the Actor network
+ """
+ Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.
+
+ This class implements the TD3 reinforcement learning algorithm for continuous control.
+ It uses an Actor-Critic architecture with target networks and delayed policy updates.
+
+ Args:
+     state_dim (int): Dimension of the input state.
+     action_dim (int): Dimension of the action space.
+     max_action (float): Maximum allowed value for actions.
+     device (torch.device): Device to run the model on (CPU or CUDA).
+     lr (float, optional): Learning rate for both actor and critic. Default is 1e-4.
+     save_every (int, optional): Save model every `save_every` iterations. Default is 0.
+     load_model (bool, optional): Whether to load model from checkpoint. Default is False.
+     save_directory (Path, optional): Directory to save model checkpoints.
+     model_name (str, optional): Name to use when saving/loading models.
+     load_directory (Path, optional): Directory to load model checkpoints from.
+     use_max_bound (bool, optional): Whether to apply maximum Q-value bounding during training.
+     bound_weight (float, optional): Weight for the max-bound loss penalty.
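+
+ Example (illustrative only; assumes the enclosing class is named `TD3` and uses placeholder dimensions):
+     >>> agent = TD3(state_dim=25, action_dim=2, max_action=1.0,
+     ...             device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))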
+ """
self.device = device
+ # Initialize the Actor network
self.actor = Actor(state_dim, action_dim).to(self.device)
self.actor_target = Actor(state_dim, action_dim).to(self.device)
self.actor_target.load_state_dict(self.actor.state_dict())
@@ -113,6 +188,16 @@ def __init__(
self.bound_weight = bound_weight

def get_action(self, obs, add_noise):
+ """
+ Get an action from the current policy with optional exploration noise.
+
+ Args:
+     obs (np.ndarray): The current state observation.
+     add_noise (bool): Whether to add exploration noise.
+
+ Returns:
+     np.ndarray: The chosen action clipped to [-max_action, max_action].
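+
+ Example (illustrative only; `agent` and the observation size are placeholders):
+     >>> obs = np.zeros(25, dtype=np.float32)
+     >>> action = agent.get_action(obs, add_noise=True)  # noisy action for exploration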
+ """
if add_noise:
return (
self.act(obs) + np.random.normal(0, 0.2, size=self.action_dim)
@@ -121,7 +206,15 @@ def get_action(self, obs, add_noise):
return self.act(obs)

def act(self, state):
- # Function to get the action from the actor
+ """
+ Compute the action using the actor network without exploration noise.
+
+ Args:
+     state (np.ndarray): The current environment state.
+
+ Returns:
+     np.ndarray: The deterministic action predicted by the actor.
+ """
state = torch.Tensor(state).to(self.device)
return self.actor(state).cpu().data.numpy().flatten()
@@ -142,6 +235,24 @@ def train(
distance_norm=10,
time_step=0.3,
):
+ """
+ Train the TD3 agent using batches sampled from the replay buffer.
+
+ Args:
+     replay_buffer: The replay buffer to sample experiences from.
+     iterations (int): Number of training iterations to perform.
+     batch_size (int): Size of each mini-batch.
+     discount (float): Discount factor gamma for future rewards.
+     tau (float): Soft update rate for target networks.
+     policy_noise (float): Stddev of Gaussian noise added to target actions.
+     noise_clip (float): Maximum magnitude of noise added to target actions.
+     policy_freq (int): Frequency of policy (actor) updates.
+     max_lin_vel (float): Max linear velocity used for upper bound estimation.
+     max_ang_vel (float): Max angular velocity used for upper bound estimation.
+     goal_reward (float): Reward given for reaching the goal.
+     distance_norm (float): Distance normalization factor.
+     time_step (float): Time step used in upper bound calculations.
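+
+ Example (illustrative only; `agent` and `replay_buffer` are assumed to already exist):
+     >>> agent.train(replay_buffer, iterations=100, batch_size=64)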
+ """
av_Q = 0
max_Q = -inf
av_loss = 0
@@ -248,6 +359,13 @@ def train(
self.save(filename=self.model_name, directory=self.save_directory)

def save(self, filename, directory):
+ """
+ Save the actor and critic networks (and their targets) to disk.
+
+ Args:
+     filename (str): Name to use when saving model files.
+     directory (Path): Directory where models should be saved.
+ """
Path(directory).mkdir(parents=True, exist_ok=True)
torch.save(self.actor.state_dict(), "%s/%s_actor.pth" % (directory, filename))
torch.save(
@@ -261,6 +379,13 @@ def save(self, filename, directory):
)

def load(self, filename, directory):
+ """
+ Load the actor and critic networks (and their targets) from disk.
+
+ Args:
+     filename (str): Name used when saving the models.
+     directory (Path): Directory where models are saved.
+ """
self.actor.load_state_dict(
torch.load("%s/%s_actor.pth" % (directory, filename))
)
@@ -276,7 +401,26 @@ def load(self, filename, directory):
print(f"Loaded weights from: {directory}")

def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
- # update the returned data from ROS into a form used for learning in the current model
+ """
+ Prepare the input state vector for training or inference.
+
+ Combines processed laser scan data, the goal vector, and the past action
+ into a normalized state vector that matches the model's input dimension.
+
+ Args:
+     latest_scan (list or np.ndarray): Laser scan data.
+     distance (float): Distance to goal.
+     cos (float): Cosine of the heading angle to goal.
+     sin (float): Sine of the heading angle to goal.
+     collision (bool): Whether a collision occurred.
+     goal (bool): Whether the goal has been reached.
+     action (list or np.ndarray): Last executed action [linear_vel, angular_vel].
+
+ Returns:
+     tuple:
+         - state (list): Prepared and normalized state vector.
+         - terminal (int): 1 if episode should terminate (goal or collision), else 0.
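+
+ Example (illustrative only; scan length and values are placeholders):
+     >>> scan = [1.0] * 180
+     >>> state, terminal = agent.prepare_state(scan, distance=2.5, cos=1.0, sin=0.0,
+     ...                                       collision=False, goal=False, action=[0.0, 0.0])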
+ """
latest_scan = np.array(latest_scan)

inf_mask = np.isinf(latest_scan)