
Commit 10f460c

add td3 documentation
1 parent fdac46c commit 10f460c

File tree

1 file changed: +147 -3 lines changed


robot_nav/models/TD3/TD3.py

Lines changed: 147 additions & 3 deletions
@@ -11,6 +11,23 @@


 class Actor(nn.Module):
+    """
+    Actor network for the TD3 algorithm.
+
+    This neural network maps states to actions using a feedforward architecture with
+    LeakyReLU activations and a final Tanh output to bound the actions in [-1, 1].
+
+    Architecture:
+        Input: state_dim
+        Hidden Layer 1: 400 units, LeakyReLU
+        Hidden Layer 2: 300 units, LeakyReLU
+        Output Layer: action_dim, Tanh
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the action output.
+    """
+
     def __init__(self, state_dim, action_dim):
         super(Actor, self).__init__()
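
For orientation, the architecture this docstring describes corresponds to a layer stack along the following lines. This is a sketch only: the hidden sizes come from the docstring and the attribute names from the forward pass in the next hunk, while the real layer definitions sit outside this diff.

import torch.nn as nn

class ActorSketch(nn.Module):
    """Sketch of the documented actor architecture; not the repository's exact code."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.layer_1 = nn.Linear(state_dim, 400)   # state_dim -> 400
        self.layer_2 = nn.Linear(400, 300)         # 400 -> 300
        self.layer_3 = nn.Linear(300, action_dim)  # 300 -> action_dim
        self.tanh = nn.Tanh()                      # bounds the output to [-1, 1]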

@@ -22,13 +39,39 @@ def __init__(self, state_dim, action_dim):
         self.tanh = nn.Tanh()

     def forward(self, s):
+        """
+        Perform a forward pass through the actor network.
+
+        Args:
+            s (torch.Tensor): Input state tensor.
+
+        Returns:
+            torch.Tensor: Action output tensor after Tanh activation.
+        """
         s = F.leaky_relu(self.layer_1(s))
         s = F.leaky_relu(self.layer_2(s))
         a = self.tanh(self.layer_3(s))
         return a


 class Critic(nn.Module):
+    """
+    Critic network for the TD3 algorithm.
+
+    This class defines two Q-value estimators (Q1 and Q2) using separate subnetworks.
+    Each Q-network takes both state and action as input and outputs a scalar Q-value.
+
+    Architecture for each Q-network:
+        Input: state_dim and action_dim
+        - State pathway: Linear + LeakyReLU → 400 → 300
+        - Action pathway: Linear → 300
+        - Combined pathway: LeakyReLU(Linear(state) + Linear(action) + bias) → 1
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the input action.
+    """
+
     def __init__(self, state_dim, action_dim):
         super(Critic, self).__init__()

@@ -51,6 +94,18 @@ def __init__(self, state_dim, action_dim):
         torch.nn.init.kaiming_uniform_(self.layer_6.weight, nonlinearity="leaky_relu")

     def forward(self, s, a):
+        """
+        Perform a forward pass through both Q-networks.
+
+        Args:
+            s (torch.Tensor): Input state tensor.
+            a (torch.Tensor): Input action tensor.
+
+        Returns:
+            tuple:
+                - q1 (torch.Tensor): Output Q-value from the first critic network.
+                - q2 (torch.Tensor): Output Q-value from the second critic network.
+        """
         s1 = F.leaky_relu(self.layer_1(s))
         self.layer_2_s(s1)
         self.layer_2_a(a)
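
The twin-Q design documented in the last two hunks can be summarized with a self-contained sketch. The attribute names below are placeholders (the real module uses layer_1, layer_2_s, layer_2_a, layer_6, and so on, and the rest of its forward pass is outside this diff); only the pathway shapes and the final (q1, q2) return follow the docstrings.

import torch.nn as nn
import torch.nn.functional as F

class TwinQSketch(nn.Module):
    """Illustrative twin Q-network; attribute names are hypothetical."""

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Q1: state -> 400 -> 300, action -> 300, summed, then -> 1
        self.q1_s1 = nn.Linear(state_dim, 400)
        self.q1_s2 = nn.Linear(400, 300)
        self.q1_a = nn.Linear(action_dim, 300)
        self.q1_out = nn.Linear(300, 1)
        # Q2: an independent copy of the same pathway
        self.q2_s1 = nn.Linear(state_dim, 400)
        self.q2_s2 = nn.Linear(400, 300)
        self.q2_a = nn.Linear(action_dim, 300)
        self.q2_out = nn.Linear(300, 1)

    def forward(self, s, a):
        h1 = F.leaky_relu(self.q1_s1(s))
        h1 = F.leaky_relu(self.q1_s2(h1) + self.q1_a(a))  # combined state/action pathway
        h2 = F.leaky_relu(self.q2_s1(s))
        h2 = F.leaky_relu(self.q2_s2(h2) + self.q2_a(a))
        return self.q1_out(h1), self.q2_out(h2)
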
@@ -86,8 +141,28 @@ def __init__(
         use_max_bound=False,
         bound_weight=0.25,
     ):
-        # Initialize the Actor network
+        """
+        Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.
+
+        This class implements the TD3 reinforcement learning algorithm for continuous control.
+        It uses an Actor-Critic architecture with target networks and delayed policy updates.
+
+        Args:
+            state_dim (int): Dimension of the input state.
+            action_dim (int): Dimension of the action space.
+            max_action (float): Maximum allowed value for actions.
+            device (torch.device): Device to run the model on (CPU or CUDA).
+            lr (float, optional): Learning rate for both actor and critic. Default is 1e-4.
+            save_every (int, optional): Save the model every `save_every` iterations. Default is 0.
+            load_model (bool, optional): Whether to load the model from a checkpoint. Default is False.
+            save_directory (Path, optional): Directory to save model checkpoints.
+            model_name (str, optional): Name to use when saving/loading models.
+            load_directory (Path, optional): Directory to load model checkpoints from.
+            use_max_bound (bool, optional): Whether to apply maximum Q-value bounding during training.
+            bound_weight (float, optional): Weight for the max-bound loss penalty.
+        """
         self.device = device
+        # Initialize the Actor network
         self.actor = Actor(state_dim, action_dim).to(self.device)
         self.actor_target = Actor(state_dim, action_dim).to(self.device)
         self.actor_target.load_state_dict(self.actor.state_dict())
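
A hypothetical use of the constructor documented above, assuming the agent class is exported as TD3 from this module; the import path works if that assumption holds, and the dimensions and hyperparameter values are illustrative, not repository defaults.

import torch
from pathlib import Path
from robot_nav.models.TD3.TD3 import TD3  # assumed class name

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = TD3(
    state_dim=25,           # illustrative: scan bins + goal polar coordinates + last action
    action_dim=2,           # linear and angular velocity
    max_action=1.0,
    device=device,
    lr=1e-4,
    save_every=100,
    load_model=False,
    save_directory=Path("checkpoints"),
    model_name="td3_robot_nav",
)
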
@@ -113,6 +188,16 @@ def __init__(
         self.bound_weight = bound_weight

     def get_action(self, obs, add_noise):
+        """
+        Get an action from the current policy with optional exploration noise.
+
+        Args:
+            obs (np.ndarray): The current state observation.
+            add_noise (bool): Whether to add exploration noise.
+
+        Returns:
+            np.ndarray: The chosen action clipped to [-max_action, max_action].
+        """
         if add_noise:
             return (
                 self.act(obs) + np.random.normal(0, 0.2, size=self.action_dim)
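
Functionally, the branch shown above is the usual exploration pattern: query the deterministic policy, add zero-mean Gaussian noise (sigma = 0.2 in the surrounding context), and clip to the action bounds. A hedged standalone version, with max_action passed in explicitly since the attribute name on the agent is not shown in this hunk:

import numpy as np

def noisy_action(deterministic_action, max_action, sigma=0.2, rng=np.random):
    """Add exploration noise to a policy output and clip to [-max_action, max_action]."""
    noise = rng.normal(0.0, sigma, size=deterministic_action.shape)
    return np.clip(deterministic_action + noise, -max_action, max_action)
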
@@ -121,7 +206,15 @@ def get_action(self, obs, add_noise):
         return self.act(obs)

     def act(self, state):
-        # Function to get the action from the actor
+        """
+        Compute the action using the actor network without exploration noise.
+
+        Args:
+            state (np.ndarray): The current environment state.
+
+        Returns:
+            np.ndarray: The deterministic action predicted by the actor.
+        """
         state = torch.Tensor(state).to(self.device)
         return self.actor(state).cpu().data.numpy().flatten()

@@ -142,6 +235,24 @@ def train(
         distance_norm=10,
         time_step=0.3,
     ):
+        """
+        Train the TD3 agent using batches sampled from the replay buffer.
+
+        Args:
+            replay_buffer: The replay buffer to sample experiences from.
+            iterations (int): Number of training iterations to perform.
+            batch_size (int): Size of each mini-batch.
+            discount (float): Discount factor gamma for future rewards.
+            tau (float): Soft update rate for target networks.
+            policy_noise (float): Stddev of Gaussian noise added to target actions.
+            noise_clip (float): Maximum magnitude of noise added to target actions.
+            policy_freq (int): Frequency of policy (actor) updates.
+            max_lin_vel (float): Max linear velocity used for upper bound estimation.
+            max_ang_vel (float): Max angular velocity used for upper bound estimation.
+            goal_reward (float): Reward given for reaching the goal.
+            distance_norm (float): Distance normalization factor.
+            time_step (float): Time step used in upper bound calculations.
+        """
         av_Q = 0
         max_Q = -inf
         av_loss = 0
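
The hyperparameters documented here control the standard TD3 target computation: target policy smoothing with clipped noise, a clipped double-Q minimum over the twin critics, and a Bellman backup. A generic sketch of that target is below; it is not the repository's exact train() body (which also handles the optional max-Q bound penalty), and the callables are passed in explicitly rather than read from the agent.

import torch

def td3_critic_target(actor_target, critic_target, next_state, reward, done,
                      discount, policy_noise, noise_clip, max_action):
    """Generic TD3 target: target policy smoothing + clipped double-Q + Bellman backup."""
    with torch.no_grad():
        next_action = actor_target(next_state)
        noise = (torch.randn_like(next_action) * policy_noise).clamp(-noise_clip, noise_clip)
        next_action = (next_action + noise).clamp(-max_action, max_action)
        q1, q2 = critic_target(next_state, next_action)        # twin target critics
        target_q = torch.min(q1, q2)                           # clipped double-Q
        return reward + (1.0 - done) * discount * target_q     # Bellman backup

The critics are regressed toward this target every iteration, while the actor and the target networks are updated only every policy_freq iterations, with tau controlling the Polyak averaging of the targets.
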
@@ -248,6 +359,13 @@ def train(
            self.save(filename=self.model_name, directory=self.save_directory)

     def save(self, filename, directory):
+        """
+        Save the actor and critic networks (and their targets) to disk.
+
+        Args:
+            filename (str): Name to use when saving model files.
+            directory (Path): Directory where models should be saved.
+        """
         Path(directory).mkdir(parents=True, exist_ok=True)
         torch.save(self.actor.state_dict(), "%s/%s_actor.pth" % (directory, filename))
         torch.save(
@@ -261,6 +379,13 @@ def save(self, filename, directory):
         )

     def load(self, filename, directory):
+        """
+        Load the actor and critic networks (and their targets) from disk.
+
+        Args:
+            filename (str): Name used when saving the models.
+            directory (Path): Directory where models are saved.
+        """
         self.actor.load_state_dict(
             torch.load("%s/%s_actor.pth" % (directory, filename))
         )
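
A round trip with the two methods documented above might look like this; the file name and directory are examples, and agent refers to the hypothetical instance from the earlier constructor snippet.

from pathlib import Path

agent.save(filename="td3_robot_nav", directory=Path("checkpoints"))
agent.load(filename="td3_robot_nav", directory=Path("checkpoints"))
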
@@ -276,7 +401,26 @@ def load(self, filename, directory):
         print(f"Loaded weights from: {directory}")

     def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
-        # update the returned data from ROS into a form used for learning in the current model
+        """
+        Prepare the input state vector for training or inference.
+
+        Combines processed laser scan data, goal vector, and past action
+        into a normalized state input matching the input dimension.
+
+        Args:
+            latest_scan (list or np.ndarray): Laser scan data.
+            distance (float): Distance to goal.
+            cos (float): Cosine of the heading angle to goal.
+            sin (float): Sine of the heading angle to goal.
+            collision (bool): Whether a collision occurred.
+            goal (bool): Whether the goal has been reached.
+            action (list or np.ndarray): Last executed action [linear_vel, angular_vel].
+
+        Returns:
+            tuple:
+                - state (list): Prepared and normalized state vector.
+                - terminal (int): 1 if episode should terminate (goal or collision), else 0.
+        """
         latest_scan = np.array(latest_scan)

         inf_mask = np.isinf(latest_scan)
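
The docstring describes assembling a normalized state from the scan, the goal vector, and the last action, plus a terminal flag. A hedged sketch of that assembly is below; the max-range substitution and normalization constants are illustrative, not the repository's exact values.

import numpy as np

def prepare_state_sketch(latest_scan, distance, cos, sin, collision, goal, action,
                         max_range=10.0, distance_norm=10.0):
    """Illustrative state assembly following the documented inputs and outputs."""
    scan = np.array(latest_scan, dtype=np.float32)
    scan[np.isinf(scan)] = max_range            # replace inf readings with an assumed max range
    scan = (scan / max_range).tolist()          # normalize laser readings to [0, 1]
    state = scan + [distance / distance_norm, cos, sin] + list(action)
    terminal = 1 if collision or goal else 0    # episode ends on goal or collision
    return state, terminal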

0 commit comments
