
Commit fdac46c

add cnntd3 documentation
1 parent 2554ae7 commit fdac46c

File tree

1 file changed: +153 -1 lines changed


robot_nav/models/CNNTD3/CNNTD3.py

Lines changed: 153 additions & 1 deletion
@@ -11,6 +11,24 @@
 
 
 class Actor(nn.Module):
+    """
+    Actor network for the CNNTD3 agent.
+
+    This network takes as input a state composed of laser scan data, goal position encoding,
+    and previous action. It processes the scan through a 1D CNN stack and embeds the other
+    inputs before merging all features through fully connected layers to output a continuous
+    action vector.
+
+    Args:
+        action_dim (int): The dimension of the action space.
+
+    Architecture:
+        - 1D CNN layers process the laser scan data.
+        - Fully connected layers embed the goal vector (distance, cos, sin) and last action.
+        - Combined features are passed through two fully connected layers with LeakyReLU.
+        - Final action output is scaled with Tanh to bound the values.
+    """
+
     def __init__(self, action_dim):
         super(Actor, self).__init__()
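
The docstring above describes the interface, while the layer definitions sit outside this hunk. As rough orientation only, here is a self-contained sketch of an actor with this shape; the layer sizes and scan length are hypothetical stand-ins, not the repository's actual values:

import torch
import torch.nn as nn

class SketchActor(nn.Module):
    """Illustrative only: all dimensions here are assumptions, not the repo's values."""

    def __init__(self, action_dim, scan_len=180):
        super().__init__()
        # 1D CNN stack over the raw laser scan, input shape (batch, 1, scan_len)
        self.cnn = nn.Sequential(
            nn.Conv1d(1, 4, kernel_size=8, stride=4),
            nn.LeakyReLU(),
            nn.Flatten(),
        )
        cnn_out = 4 * ((scan_len - 8) // 4 + 1)  # channels * conv output length
        # Embed the goal triple (distance, cos, sin) and the previous action
        self.goal_embed = nn.Linear(3, 16)
        self.action_embed = nn.Linear(2, 16)
        # Merge all features and map to a bounded action vector
        self.head = nn.Sequential(
            nn.Linear(cnn_out + 32, 128),
            nn.LeakyReLU(),
            nn.Linear(128, action_dim),
            nn.Tanh(),  # bounds each action component to [-1, 1]
        )

    def forward(self, s):
        laser, goal, act = s[:, :-5], s[:, -5:-2], s[:, -2:]
        features = self.cnn(laser.unsqueeze(1))
        merged = torch.cat([features, self.goal_embed(goal), self.action_embed(act)], dim=1)
        return self.head(merged)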

@@ -29,6 +47,17 @@ def __init__(self, action_dim):
         self.tanh = nn.Tanh()
 
     def forward(self, s):
+        """
+        Forward pass through the Actor network.
+
+        Args:
+            s (torch.Tensor): Input state tensor of shape (batch_size, state_dim).
+                The last 5 elements are [distance, cos, sin, lin_vel, ang_vel].
+
+        Returns:
+            torch.Tensor: Action tensor of shape (batch_size, action_dim),
+                with values in range [-1, 1] due to tanh activation.
+        """
         if len(s.shape) == 1:
             s = s.unsqueeze(0)
         laser = s[:, :-5]
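
The slicing convention documented above can be checked in isolation; the scan length of 180 below is an arbitrary choice for the demonstration:

import torch

scan_len = 180                  # illustrative; any beam count works
s = torch.rand(scan_len + 5)    # [scan..., distance, cos, sin, lin_vel, ang_vel]
s = s.unsqueeze(0)              # forward() adds the batch dimension the same way
laser = s[:, :-5]               # (1, 180): raw scan readings
goal = s[:, -5:-2]              # (1, 3): distance, cos, sin
act = s[:, -2:]                 # (1, 2): lin_vel, ang_vel
assert laser.shape[1] + goal.shape[1] + act.shape[1] == s.shape[1]
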
@@ -54,6 +83,24 @@ def forward(self, s):
 
 
 class Critic(nn.Module):
+    """
+    Critic network for the CNNTD3 agent.
+
+    The Critic estimates Q-values for state-action pairs using two separate sub-networks
+    (Q1 and Q2), as required by the TD3 algorithm. Each sub-network uses a combination of
+    CNN-extracted features, embedded goal and previous action features, and the current action.
+
+    Args:
+        action_dim (int): The dimension of the action space.
+
+    Architecture:
+        - Shared CNN layers process the laser scan input.
+        - Goal and previous action are embedded and concatenated.
+        - Each Q-network uses separate fully connected layers to produce scalar Q-values.
+        - Both Q-networks receive the full state and current action.
+        - Outputs two Q-value tensors (Q1, Q2) for TD3-style training and target smoothing.
+    """
+
     def __init__(self, action_dim):
         super(Critic, self).__init__()
         self.cnn1 = nn.Conv1d(1, 4, kernel_size=8, stride=4)
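
The nn.Conv1d(1, 4, kernel_size=8, stride=4) context line fixes the size of the flattened scan features: with no padding, an input of length L yields floor((L - 8) / 4) + 1 outputs per channel. A quick check (the 180-beam scan is illustrative):

import torch
import torch.nn as nn

conv = nn.Conv1d(1, 4, kernel_size=8, stride=4)  # same hyperparameters as the diff context
x = torch.rand(1, 1, 180)                        # (batch, channels, scan_len)
print(conv(x).shape)                             # torch.Size([1, 4, 44]); (180 - 8) // 4 + 1 == 44
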
@@ -82,6 +129,19 @@ def __init__(self, action_dim):
         torch.nn.init.kaiming_uniform_(self.layer_6.weight, nonlinearity="leaky_relu")
 
     def forward(self, s, action):
+        """
+        Forward pass through both Q-networks of the Critic.
+
+        Args:
+            s (torch.Tensor): Input state tensor of shape (batch_size, state_dim).
+                The last 5 elements are [distance, cos, sin, lin_vel, ang_vel].
+            action (torch.Tensor): Current action tensor of shape (batch_size, action_dim).
+
+        Returns:
+            tuple:
+                - q1 (torch.Tensor): First Q-value estimate (batch_size, 1).
+                - q2 (torch.Tensor): Second Q-value estimate (batch_size, 1).
+        """
         laser = s[:, :-5]
         goal = s[:, -5:-2]
         act = s[:, -2:]
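
The reason forward returns a pair is the clipped double-Q trick: TD3 bootstraps from the element-wise minimum of the two estimates to curb overestimation. With stand-in tensors:

import torch

# Stand-ins for the two critic heads, each of shape (batch_size, 1)
q1 = torch.tensor([[1.2], [0.7]])
q2 = torch.tensor([[0.9], [0.8]])
target_q = torch.min(q1, q2)  # tensor([[0.9], [0.7]]): the pessimistic estimate
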
@@ -118,6 +178,30 @@ def forward(self, s, action):
 
 
 # CNNTD3 network
 class CNNTD3(object):
+    """
+    CNNTD3 (Twin Delayed Deep Deterministic Policy Gradient with CNN-based inputs) agent for
+    continuous control tasks.
+
+    This class encapsulates the full implementation of the TD3 algorithm using neural network
+    architectures for the actor and critic, with optional bounding for critic outputs to
+    regularize learning. The agent is designed to train in environments where sensor
+    observations (e.g., LiDAR) are used for navigation tasks.
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the output action.
+        max_action (float): Maximum magnitude of the action.
+        device (torch.device): Torch device to use (CPU or GPU).
+        lr (float): Learning rate for both actor and critic optimizers.
+        save_every (int): Save the model every N training iterations (0 to disable).
+        load_model (bool): Whether to load a pre-trained model at initialization.
+        save_directory (Path): Directory for saving model checkpoints.
+        model_name (str): Base name for the saved model files.
+        load_directory (Path): Directory to load model checkpoints from (if `load_model=True`).
+        use_max_bound (bool): Whether to apply maximum Q-value bounding during training.
+        bound_weight (float): Weight of the bounding loss term in the total loss.
+    """
+
     def __init__(
         self,
         state_dim,
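
For orientation, constructing the agent with the documented signature could look roughly like this; every value below is a placeholder, and the import path simply mirrors the file's location in the repository:

import torch
from pathlib import Path

from robot_nav.models.CNNTD3.CNNTD3 import CNNTD3  # assumed import path

agent = CNNTD3(
    state_dim=185,     # e.g. 180 scan beams + the 5 appended elements
    action_dim=2,      # [lin_vel, ang_vel]
    max_action=1.0,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    lr=1e-4,
    save_every=100,
    load_model=False,
    save_directory=Path("models/CNNTD3"),
    model_name="cnntd3",
    load_directory=Path("models/CNNTD3"),
    use_max_bound=False,
    bound_weight=0.25,
)
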
@@ -160,6 +244,16 @@ def __init__(
         self.bound_weight = bound_weight
 
     def get_action(self, obs, add_noise):
+        """
+        Selects an action for a given observation.
+
+        Args:
+            obs (np.ndarray): The current observation/state.
+            add_noise (bool): Whether to add exploration noise to the action.
+
+        Returns:
+            np.ndarray: The selected action.
+        """
         if add_noise:
             return (
                 self.act(obs) + np.random.normal(0, 0.2, size=self.action_dim)
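
The visible context adds zero-mean Gaussian noise (sigma = 0.2) to the deterministic action for exploration; the rest of the expression falls outside this hunk. The pattern in isolation:

import numpy as np

action_dim = 2
deterministic = np.array([0.5, -0.3])   # stand-in for self.act(obs)
noisy = deterministic + np.random.normal(0, 0.2, size=action_dim)
# TD3 implementations typically clip the result back to the valid action
# range, e.g. np.clip(noisy, -1.0, 1.0); that part is not shown in the hunk.
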
@@ -168,6 +262,15 @@ def get_action(self, obs, add_noise):
         return self.act(obs)
 
     def act(self, state):
+        """
+        Computes the deterministic action from the actor network for a given state.
+
+        Args:
+            state (np.ndarray): Input state.
+
+        Returns:
+            np.ndarray: Action predicted by the actor network.
+        """
         # Function to get the action from the actor
         state = torch.Tensor(state).to(self.device)
         return self.actor(state).cpu().data.numpy().flatten()
 
@@ -189,6 +292,24 @@ def train(
         distance_norm=10,
         time_step=0.3,
     ):
+        """
+        Trains the CNNTD3 agent using sampled batches from the replay buffer.
+
+        Args:
+            replay_buffer (ReplayBuffer): Buffer storing environment transitions.
+            iterations (int): Number of training iterations.
+            batch_size (int): Size of each training batch.
+            discount (float): Discount factor for future rewards.
+            tau (float): Soft update rate for target networks.
+            policy_noise (float): Std. dev. of noise added to the target policy.
+            noise_clip (float): Clipping range (absolute value) for target policy noise.
+            policy_freq (int): Frequency of actor and target network updates.
+            max_lin_vel (float): Maximum linear velocity, used in bounding calculations.
+            max_ang_vel (float): Maximum angular velocity, used in bounding calculations.
+            goal_reward (float): Reward value for reaching the goal.
+            distance_norm (float): Normalization factor for distance in bounding.
+            time_step (float): Time delta between steps.
+        """
         av_Q = 0
         max_Q = -inf
         av_loss = 0
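
policy_noise, noise_clip, and policy_freq are the three standard TD3 ingredients: target policy smoothing, clipped noise, and delayed actor updates. A generic sketch of the target computation they govern (not the repository's exact code):

import torch

def td3_target(actor_target, critic_target, next_state, reward, done,
               discount, policy_noise, noise_clip, max_action):
    # Target policy smoothing: perturb the target action with clipped noise.
    next_action = actor_target(next_state)
    noise = (torch.randn_like(next_action) * policy_noise).clamp(-noise_clip, noise_clip)
    next_action = (next_action + noise).clamp(-max_action, max_action)
    # Clipped double-Q: bootstrap from the smaller of the two target estimates.
    q1, q2 = critic_target(next_state, next_action)
    target_q = torch.min(q1, q2)
    # Bellman backup; (1 - done) stops bootstrapping at terminal states.
    return reward + (1 - done) * discount * target_q
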
@@ -295,6 +416,13 @@ def train(
             self.save(filename=self.model_name, directory=self.save_directory)
 
     def save(self, filename, directory):
+        """
+        Saves the current model parameters to the specified directory.
+
+        Args:
+            filename (str): Base filename for saved files.
+            directory (Path): Directory to save the model files in.
+        """
         Path(directory).mkdir(parents=True, exist_ok=True)
         torch.save(self.actor.state_dict(), "%s/%s_actor.pth" % (directory, filename))
         torch.save(
 
@@ -308,6 +436,13 @@ def save(self, filename, directory):
         )
 
     def load(self, filename, directory):
+        """
+        Loads model parameters from the specified directory.
+
+        Args:
+            filename (str): Base filename of the saved files.
+            directory (Path): Directory to load the model files from.
+        """
         self.actor.load_state_dict(
             torch.load("%s/%s_actor.pth" % (directory, filename))
         )
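
Taken together, the two hunks above imply a "%s/%s_actor.pth"-style naming scheme, so a checkpoint round trip (reusing the hypothetical agent from the earlier instantiation sketch) would be:

from pathlib import Path

# Files land as <directory>/<filename>_actor.pth (plus the critic and any
# target networks saved the same way outside the visible hunk).
agent.save(filename="cnntd3", directory=Path("models/CNNTD3"))
agent.load(filename="cnntd3", directory=Path("models/CNNTD3"))  # prints "Loaded weights from: ..."
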
@@ -323,7 +458,24 @@ def load(self, filename, directory):
         print(f"Loaded weights from: {directory}")
 
     def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
-        # update the returned data from ROS into a form used for learning in the current model
+        """
+        Prepares the environment's raw sensor data and navigation variables into
+        a format suitable for learning.
+
+        Args:
+            latest_scan (list or np.ndarray): Raw scan data (e.g., LiDAR).
+            distance (float): Distance to the goal.
+            cos (float): Cosine of the heading angle to the goal.
+            sin (float): Sine of the heading angle to the goal.
+            collision (bool): Collision status (True if collided).
+            goal (bool): Goal-reached status.
+            action (list or np.ndarray): Last action taken [lin_vel, ang_vel].
+
+        Returns:
+            tuple:
+                - state (list): Normalized and concatenated state vector.
+                - terminal (int): Terminal flag (1 if collision or goal, else 0).
+        """
         latest_scan = np.array(latest_scan)
 
         inf_mask = np.isinf(latest_scan)
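
The hunk ends right after the infinity mask, so the normalization itself is not shown. A hedged sketch of what the documented contract implies; max_range and the exact normalization are assumptions, not the repository's code:

import numpy as np

def prepare_state_sketch(latest_scan, distance, cos, sin, collision, goal, action,
                         max_range=10.0, distance_norm=10.0):
    scan = np.array(latest_scan)
    scan[np.isinf(scan)] = max_range   # replace inf readings, as the hunk begins to do
    state = list(scan / max_range)     # assumed scan normalization
    state += [distance / distance_norm, cos, sin]
    state += list(action)              # last action: [lin_vel, ang_vel]
    terminal = int(collision or goal)  # 1 if the episode ended, else 0
    return state, terminal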
