class Actor(nn.Module):
+    """
+    Actor network for the CNNTD3 agent.
+
+    This network takes as input a state composed of laser scan data, goal position encoding,
+    and previous action. It processes the scan through a 1D CNN stack and embeds the other
+    inputs before merging all features through fully connected layers to output a continuous
+    action vector.
+
+    Args:
+        action_dim (int): The dimension of the action space.
+
+    Architecture:
+        - 1D CNN layers process the laser scan data.
+        - Fully connected layers embed the goal vector (cos, sin, distance) and last action.
+        - Combined features are passed through two fully connected layers with LeakyReLU.
+        - Final action output is scaled with Tanh to bound the values.
+    """
+
    def __init__(self, action_dim):
        super(Actor, self).__init__()

@@ -29,6 +47,17 @@ def __init__(self, action_dim):
        self.tanh = nn.Tanh()

    def forward(self, s):
+        """
+        Forward pass through the Actor network.
+
+        Args:
+            s (torch.Tensor): Input state tensor of shape (batch_size, state_dim).
+                The last 5 elements are [distance, cos, sin, lin_vel, ang_vel].
+
+        Returns:
+            torch.Tensor: Action tensor of shape (batch_size, action_dim),
+                with values in range [-1, 1] due to tanh activation.
+        """
        if len(s.shape) == 1:
            s = s.unsqueeze(0)
        laser = s[:, :-5]
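For reference, the slicing in this forward pass relies on the state layout documented above: all leading entries are laser readings and the last five are scalar features. A minimal illustration of that split (the 180-beam scan length is an assumed placeholder, not taken from this diff):

import torch

s = torch.rand(1, 185)        # assumed: 180 laser beams + [distance, cos, sin, lin_vel, ang_vel]
laser = s[:, :-5]             # (1, 180) scan, later reshaped for the 1D CNN stack
goal = s[:, -5:-2]            # (1, 3)  distance, cos, sin
last_action = s[:, -2:]       # (1, 2)  lin_vel, ang_vel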
@@ -54,6 +83,24 @@ def forward(self, s):


class Critic(nn.Module):
+    """
+    Critic network for the CNNTD3 agent.
+
+    The Critic estimates Q-values for state-action pairs using two separate sub-networks
+    (Q1 and Q2), as required by the TD3 algorithm. Each sub-network uses a combination of
+    CNN-extracted features, embedded goal and previous action features, and the current action.
+
+    Args:
+        action_dim (int): The dimension of the action space.
+
+    Architecture:
+        - Shared CNN layers process the laser scan input.
+        - Goal and previous action are embedded and concatenated.
+        - Each Q-network uses separate fully connected layers to produce scalar Q-values.
+        - Both Q-networks receive the full state and current action.
+        - Outputs two Q-value tensors (Q1, Q2) for TD3-style training and target smoothing.
+    """
+
    def __init__(self, action_dim):
        super(Critic, self).__init__()
        self.cnn1 = nn.Conv1d(1, 4, kernel_size=8, stride=4)
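As a quick sanity check on this CNN front end, the first convolution shown above shrinks the scan by roughly a factor of four. A small sketch of the shape arithmetic (the 180-beam input length is an assumption, not from this diff):

import torch
import torch.nn as nn

cnn1 = nn.Conv1d(1, 4, kernel_size=8, stride=4)  # first critic conv layer, as in the diff
scan = torch.rand(1, 1, 180)                     # (batch, channels, beams); 180 is assumed
out = cnn1(scan)
# L_out = floor((180 - 8) / 4) + 1 = 44, so out.shape == torch.Size([1, 4, 44])
print(out.shape)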
@@ -82,6 +129,19 @@ def __init__(self, action_dim):
        torch.nn.init.kaiming_uniform_(self.layer_6.weight, nonlinearity="leaky_relu")

    def forward(self, s, action):
+        """
+        Forward pass through both Q-networks of the Critic.
+
+        Args:
+            s (torch.Tensor): Input state tensor of shape (batch_size, state_dim).
+                The last 5 elements are [distance, cos, sin, lin_vel, ang_vel].
+            action (torch.Tensor): Current action tensor of shape (batch_size, action_dim).
+
+        Returns:
+            tuple:
+                - q1 (torch.Tensor): First Q-value estimate (batch_size, 1).
+                - q2 (torch.Tensor): Second Q-value estimate (batch_size, 1).
+        """
        laser = s[:, :-5]
        goal = s[:, -5:-2]
        act = s[:, -2:]
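The reason the Critic returns both heads is the clipped double-Q trick of TD3: the Bellman target is built from the minimum of the two target-critic estimates. A minimal sketch of that step under assumed names (`actor_target` and `critic_target` are placeholders for the target networks, not objects defined in this diff):

import torch

def td3_target(critic_target, actor_target, next_state, reward, done, discount=0.99):
    # Clipped double-Q: take the smaller of the two target estimates to curb overestimation.
    with torch.no_grad():
        next_action = actor_target(next_state)            # full TD3 also adds clipped noise here
        q1, q2 = critic_target(next_state, next_action)
        target_q = torch.min(q1, q2)
        return reward + discount * (1.0 - done) * target_q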
@@ -118,6 +178,30 @@ def forward(self, s, action):

# CNNTD3 network
class CNNTD3(object):
+    """
+    CNNTD3 (Twin Delayed Deep Deterministic Policy Gradient with CNN-based inputs) agent for
+    continuous control tasks.
+
+    This class encapsulates the full implementation of the TD3 algorithm using neural network
+    architectures for the actor and critic, with optional bounding for critic outputs to
+    regularize learning. The agent is designed to train in environments where sensor
+    observations (e.g., LiDAR) are used for navigation tasks.
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the output action.
+        max_action (float): Maximum magnitude of the action.
+        device (torch.device): Torch device to use (CPU or GPU).
+        lr (float): Learning rate for both actor and critic optimizers.
+        save_every (int): Save model every N training iterations (0 to disable).
+        load_model (bool): Whether to load a pre-trained model at initialization.
+        save_directory (Path): Path to the directory for saving model checkpoints.
+        model_name (str): Base name for the saved model files.
+        load_directory (Path): Path to load model checkpoints from (if `load_model=True`).
+        use_max_bound (bool): Whether to apply maximum Q-value bounding during training.
+        bound_weight (float): Weight for the bounding loss term in total loss.
+    """
+
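A hedged construction sketch based on the argument list above; the values are placeholders and the keyword names are assumed to match the docstring (the import path depends on the repository layout):

from pathlib import Path
import torch

# from <this module> import CNNTD3   # module path depends on the repo layout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = CNNTD3(
    state_dim=185,                 # assumed: 180 laser beams + 5 scalar features
    action_dim=2,                  # [lin_vel, ang_vel]
    max_action=1.0,
    device=device,
    lr=1e-4,                       # placeholder value
    save_every=100,
    load_model=False,
    save_directory=Path("models"),
    model_name="cnntd3",
    load_directory=Path("models"),
    use_max_bound=False,
    bound_weight=0.1,              # placeholder value
)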
    def __init__(
        self,
        state_dim,
@@ -160,6 +244,16 @@ def __init__(
        self.bound_weight = bound_weight

    def get_action(self, obs, add_noise):
+        """
+        Selects an action for a given observation.
+
+        Args:
+            obs (np.ndarray): The current observation/state.
+            add_noise (bool): Whether to add exploration noise to the action.
+
+        Returns:
+            np.ndarray: The selected action.
+        """
        if add_noise:
            return (
                self.act(obs) + np.random.normal(0, 0.2, size=self.action_dim)
@@ -168,6 +262,15 @@ def get_action(self, obs, add_noise):
        return self.act(obs)

    def act(self, state):
+        """
+        Computes the deterministic action from the actor network for a given state.
+
+        Args:
+            state (np.ndarray): Input state.
+
+        Returns:
+            np.ndarray: Action predicted by the actor network.
+        """
        # Function to get the action from the actor
        state = torch.Tensor(state).to(self.device)
        return self.actor(state).cpu().data.numpy().flatten()
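Taken together, `get_action` and `act` give the usual split between noisy data collection and deterministic evaluation. A small usage sketch (assuming `agent` is a constructed CNNTD3 instance, as in the earlier sketch):

def select_action(agent, obs, evaluate=False):
    # Exploration noise is added while collecting experience; the deterministic
    # actor output is used when evaluating the policy.
    return agent.get_action(obs, add_noise=not evaluate)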
@@ -189,6 +292,24 @@ def train(
        distance_norm=10,
        time_step=0.3,
    ):
+        """
+        Trains the CNNTD3 agent using sampled batches from the replay buffer.
+
+        Args:
+            replay_buffer (ReplayBuffer): Buffer storing environment transitions.
+            iterations (int): Number of training iterations.
+            batch_size (int): Size of each training batch.
+            discount (float): Discount factor for future rewards.
+            tau (float): Soft update rate for target networks.
+            policy_noise (float): Std. dev. of noise added to target policy.
+            noise_clip (float): Maximum value for target policy noise.
+            policy_freq (int): Frequency of actor and target network updates.
+            max_lin_vel (float): Maximum linear velocity for bounding calculations.
+            max_ang_vel (float): Maximum angular velocity for bounding calculations.
+            goal_reward (float): Reward value for reaching the goal.
+            distance_norm (float): Normalization factor for distance in bounding.
+            time_step (float): Time delta between steps.
+        """
        av_Q = 0
        max_Q = -inf
        av_loss = 0
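In a typical training loop, `train` is called after each episode with the number of steps just collected. A hedged sketch of that call; the hyperparameter values are common TD3 defaults, not values taken from this diff:

def run_training_phase(agent, replay_buffer, episode_steps):
    # One gradient iteration per collected environment step is a common choice.
    agent.train(
        replay_buffer=replay_buffer,
        iterations=episode_steps,
        batch_size=64,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        policy_freq=2,
    )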
@@ -295,6 +416,13 @@ def train(
            self.save(filename=self.model_name, directory=self.save_directory)

    def save(self, filename, directory):
+        """
+        Saves the current model parameters to the specified directory.
+
+        Args:
+            filename (str): Base filename for saved files.
+            directory (Path): Path to save the model files.
+        """
        Path(directory).mkdir(parents=True, exist_ok=True)
        torch.save(self.actor.state_dict(), "%s/%s_actor.pth" % (directory, filename))
        torch.save(
@@ -308,6 +436,13 @@ def save(self, filename, directory):
        )

    def load(self, filename, directory):
+        """
+        Loads model parameters from the specified directory.
+
+        Args:
+            filename (str): Base filename for saved files.
+            directory (Path): Path to load the model files from.
+        """
        self.actor.load_state_dict(
            torch.load("%s/%s_actor.pth" % (directory, filename))
        )
@@ -323,7 +458,24 @@ def load(self, filename, directory):
        print(f"Loaded weights from: {directory}")

    def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
-        # update the returned data from ROS into a form used for learning in the current model
+        """
+        Prepares the environment's raw sensor data and navigation variables into
+        a format suitable for learning.
+
+        Args:
+            latest_scan (list or np.ndarray): Raw scan data (e.g., LiDAR).
+            distance (float): Distance to goal.
+            cos (float): Cosine of heading angle to goal.
+            sin (float): Sine of heading angle to goal.
+            collision (bool): Collision status (True if collided).
+            goal (bool): Goal reached status.
+            action (list or np.ndarray): Last action taken [lin_vel, ang_vel].
+
+        Returns:
+            tuple:
+                - state (list): Normalized and concatenated state vector.
+                - terminal (int): Terminal flag (1 if collision or goal, else 0).
+        """
        latest_scan = np.array(latest_scan)

        inf_mask = np.isinf(latest_scan)