Commit 2554ae7

add ddpg documentation
1 parent c40c0f3 commit 2554ae7

File tree

1 file changed
+143 -2 lines changed

robot_nav/models/DDPG/DDPG.py

Lines changed: 143 additions & 2 deletions
@@ -11,6 +11,23 @@
 
 
 class Actor(nn.Module):
+    """
+    Actor network for the DDPG algorithm.
+
+    This network maps input states to actions using a fully connected feedforward architecture.
+    It uses Leaky ReLU activations in the hidden layers and a tanh activation at the output
+    to ensure the output actions are in the range [-1, 1].
+
+    Architecture:
+        - Linear(state_dim → 400) + LeakyReLU
+        - Linear(400 → 300) + LeakyReLU
+        - Linear(300 → action_dim) + Tanh
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the output action space.
+    """
+
     def __init__(self, state_dim, action_dim):
         super(Actor, self).__init__()
 
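For reference, the layout documented in this docstring corresponds to a module along the following lines. This is an illustrative sketch, not the file's exact code: the layer names layer_1/layer_2/layer_3, the Leaky ReLU activations, and the tanh output appear in the visible hunks, while the __init__ body shown here is an assumption.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ActorSketch(nn.Module):
    """Minimal actor with the documented 400/300 hidden layout (illustrative only)."""

    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        self.layer_1 = nn.Linear(state_dim, 400)   # state_dim -> 400
        self.layer_2 = nn.Linear(400, 300)         # 400 -> 300
        self.layer_3 = nn.Linear(300, action_dim)  # 300 -> action_dim
        self.tanh = nn.Tanh()

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        s = F.leaky_relu(self.layer_1(s))          # hidden activations: Leaky ReLU
        s = F.leaky_relu(self.layer_2(s))
        return self.tanh(self.layer_3(s))          # bound actions to [-1, 1]
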
@@ -22,13 +39,41 @@ def __init__(self, state_dim, action_dim):
         self.tanh = nn.Tanh()
 
     def forward(self, s):
+        """
+        Forward pass of the actor network.
+
+        Args:
+            s (torch.Tensor): Input state tensor of shape (batch_size, state_dim).
+
+        Returns:
+            torch.Tensor: Output action tensor of shape (batch_size, action_dim), scaled to [-1, 1].
+        """
         s = F.leaky_relu(self.layer_1(s))
         s = F.leaky_relu(self.layer_2(s))
         a = self.tanh(self.layer_3(s))
         return a
 
 
 class Critic(nn.Module):
+    """
+    Critic network for the DDPG algorithm.
+
+    This network evaluates the Q-value of a given state-action pair. It separately processes
+    state and action inputs through linear layers, combines them, and passes the result through
+    another linear layer to predict a scalar Q-value.
+
+    Architecture:
+        - Linear(state_dim → 400) + LeakyReLU
+        - Linear(400 → 300) [state branch]
+        - Linear(action_dim → 300) [action branch]
+        - Combine both branches, apply LeakyReLU
+        - Linear(300 → 1) for Q-value output
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the input action.
+    """
+
     def __init__(self, state_dim, action_dim):
         super(Critic, self).__init__()
 
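The two-branch combine described in the Critic docstring can be sketched as follows. The layer names layer_1, layer_2_s, layer_2_a, and layer_3 and the Kaiming initialization of layer_3 appear in the visible hunks; the way the branches are combined below (a simple sum) is an assumption, since that part of the file is not shown in this diff.

import torch
import torch.nn as nn
import torch.nn.functional as F


class CriticSketch(nn.Module):
    """Minimal critic with the documented two-branch layout (illustrative only)."""

    def __init__(self, state_dim: int, action_dim: int):
        super().__init__()
        self.layer_1 = nn.Linear(state_dim, 400)     # state: state_dim -> 400
        self.layer_2_s = nn.Linear(400, 300)         # state branch: 400 -> 300
        self.layer_2_a = nn.Linear(action_dim, 300)  # action branch: action_dim -> 300
        self.layer_3 = nn.Linear(300, 1)             # combined features -> scalar Q-value
        nn.init.kaiming_uniform_(self.layer_3.weight, nonlinearity="leaky_relu")

    def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        s1 = F.leaky_relu(self.layer_1(s))
        # Combine the state and action branches, then apply Leaky ReLU (combine is assumed).
        q = F.leaky_relu(self.layer_2_s(s1) + self.layer_2_a(a))
        return self.layer_3(q)
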
@@ -42,6 +87,16 @@ def __init__(self, state_dim, action_dim):
         torch.nn.init.kaiming_uniform_(self.layer_3.weight, nonlinearity="leaky_relu")
 
     def forward(self, s, a):
+        """
+        Forward pass of the critic network.
+
+        Args:
+            s (torch.Tensor): State tensor of shape (batch_size, state_dim).
+            a (torch.Tensor): Action tensor of shape (batch_size, action_dim).
+
+        Returns:
+            torch.Tensor: Q-value tensor of shape (batch_size, 1).
+        """
         s1 = F.leaky_relu(self.layer_1(s))
         self.layer_2_s(s1)
         self.layer_2_a(a)
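As a quick check of the documented shapes, the sketch above maps a batch of states and actions to a (batch_size, 1) Q-value tensor. The sizes below are placeholders, not values from the repository:

state_dim, action_dim, batch_size = 25, 2, 16    # placeholder dimensions
critic = CriticSketch(state_dim, action_dim)
s = torch.randn(batch_size, state_dim)
a = torch.randn(batch_size, action_dim)
q = critic(s, a)
print(q.shape)  # torch.Size([16, 1])
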
@@ -55,6 +110,28 @@ def forward(self, s, a):
 
 # DDPG network
 class DDPG(object):
+    """
+    Deep Deterministic Policy Gradient (DDPG) agent implementation.
+
+    This class encapsulates the actor-critic learning framework using DDPG, which is suitable
+    for continuous action spaces. It supports training, action selection, model saving/loading,
+    and state preparation for a reinforcement learning agent, specifically designed for robot navigation.
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the action space.
+        max_action (float): Maximum action value allowed.
+        device (torch.device): Computation device (CPU or GPU).
+        lr (float): Learning rate for the optimizers. Default is 1e-4.
+        save_every (int): Frequency of saving the model in training iterations. 0 means no saving. Default is 0.
+        load_model (bool): Flag indicating whether to load a model from disk. Default is False.
+        save_directory (Path): Directory to save the model checkpoints. Default is "robot_nav/models/DDPG/checkpoint".
+        model_name (str): Name used for saving and TensorBoard logging. Default is "DDPG".
+        load_directory (Path): Directory to load model checkpoints from. Default is "robot_nav/models/DDPG/checkpoint".
+        use_max_bound (bool): Whether to enforce a learned upper bound on the Q-value. Default is False.
+        bound_weight (float): Weight of the upper bound loss penalty. Default is 0.25.
+    """
+
     def __init__(
         self,
         state_dim,
@@ -97,6 +174,16 @@ def __init__(
         self.bound_weight = bound_weight
 
     def get_action(self, obs, add_noise):
+        """
+        Selects an action based on the observation.
+
+        Args:
+            obs (np.array): The current state observation.
+            add_noise (bool): Whether to add exploration noise to the action.
+
+        Returns:
+            np.array: Action selected by the actor network.
+        """
         if add_noise:
             return (
                 self.act(obs) + np.random.normal(0, 0.2, size=self.action_dim)
@@ -105,7 +192,15 @@ def get_action(self, obs, add_noise):
         return self.act(obs)
 
     def act(self, state):
-        # Function to get the action from the actor
+        """
+        Computes the action for a given state using the actor network.
+
+        Args:
+            state (np.array): Environment state.
+
+        Returns:
+            np.array: Action values as output by the actor network.
+        """
         state = torch.Tensor(state).to(self.device)
         return self.actor(state).cpu().data.numpy().flatten()
 
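Together, get_action and act implement the usual DDPG exploration scheme: the deterministic actor output plus zero-mean Gaussian noise (sigma = 0.2 in the hunk above). A standalone sketch of that step follows; clipping the noisy action back to [-1, 1] is an assumption based on the actor's tanh output, since the end of get_action is not shown in this diff.

import numpy as np


def noisy_action(deterministic_action: np.ndarray, sigma: float = 0.2) -> np.ndarray:
    """Add Gaussian exploration noise and keep the action in [-1, 1] (clip assumed)."""
    noise = np.random.normal(0.0, sigma, size=deterministic_action.shape)
    return np.clip(deterministic_action + noise, -1.0, 1.0)


# During data collection: noisy_action(agent.act(obs))
# During evaluation:      agent.act(obs)
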
@@ -126,6 +221,24 @@ def train(
         distance_norm=10,
         time_step=0.3,
     ):
+        """
+        Trains the actor and critic networks using a replay buffer and soft target updates.
+
+        Args:
+            replay_buffer (object): Replay buffer object with a sample_batch method.
+            iterations (int): Number of training iterations.
+            batch_size (int): Size of each training batch.
+            discount (float): Discount factor for future rewards.
+            tau (float): Soft update factor for target networks.
+            policy_noise (float): Standard deviation of noise added to target policy.
+            noise_clip (float): Maximum value to clip target policy noise.
+            policy_freq (int): Frequency of actor and target updates.
+            max_lin_vel (float): Maximum linear velocity, used in Q-bound calculation.
+            max_ang_vel (float): Maximum angular velocity, used in Q-bound calculation.
+            goal_reward (float): Reward given upon reaching goal.
+            distance_norm (float): Distance normalization factor.
+            time_step (float): Time step used in max bound calculation.
+        """
         av_Q = 0
         max_Q = -inf
         av_loss = 0
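The tau argument documented here drives the standard DDPG soft (Polyak) target update, in which the target networks slowly track the online networks. A generic sketch of that step, not the file's exact code (the target attribute names in the comments are illustrative):

import torch


def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float) -> None:
    """Move each target parameter a small step of size tau toward the online network."""
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)


# Typically called once per update cycle for both networks, e.g.:
# soft_update(actor_target, actor, tau)
# soft_update(critic_target, critic, tau)
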
@@ -229,6 +342,13 @@ def train(
             self.save(filename=self.model_name, directory=self.save_directory)
 
     def save(self, filename, directory):
+        """
+        Saves the model parameters to disk.
+
+        Args:
+            filename (str): Base filename for saving the model components.
+            directory (str or Path): Directory where the model files will be saved.
+        """
         Path(directory).mkdir(parents=True, exist_ok=True)
         torch.save(self.actor.state_dict(), "%s/%s_actor.pth" % (directory, filename))
         torch.save(
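With the defaults listed in the class docstring and the "%s/%s_actor.pth" pattern shown here, a save call lands the actor weights at a predictable path. The file names of the remaining components are not visible in this hunk and are only assumed to follow the same pattern:

agent.save(filename="DDPG", directory="robot_nav/models/DDPG/checkpoint")
# -> robot_nav/models/DDPG/checkpoint/DDPG_actor.pth
#    (other components assumed to follow the same "%s/%s_<component>.pth" naming)
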
@@ -242,6 +362,13 @@ def save(self, filename, directory):
         )
 
     def load(self, filename, directory):
+        """
+        Loads model parameters from disk.
+
+        Args:
+            filename (str): Base filename used for loading model components.
+            directory (str or Path): Directory to load the model files from.
+        """
         self.actor.load_state_dict(
             torch.load("%s/%s_actor.pth" % (directory, filename))
         )
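Loading is symmetric to saving; a usage sketch assuming the documented defaults (the load_model and load_directory constructor arguments above suggest the same thing can also happen automatically at construction time when load_model=True):

agent.load(filename="DDPG", directory="robot_nav/models/DDPG/checkpoint")
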
@@ -257,7 +384,21 @@ def load(self, filename, directory):
         print(f"Loaded weights from: {directory}")
 
     def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
-        # update the returned data from ROS into a form used for learning in the current model
+        """
+        Processes raw sensor input and additional information into a normalized state representation.
+
+        Args:
+            latest_scan (list or np.array): Raw LIDAR or laser scan data.
+            distance (float): Distance to the goal.
+            cos (float): Cosine of the angle to the goal.
+            sin (float): Sine of the angle to the goal.
+            collision (bool): Whether a collision has occurred.
+            goal (bool): Whether the goal has been reached.
+            action (list or np.array): The action taken in the previous step.
+
+        Returns:
+            tuple: (state vector, terminal flag)
+        """
         latest_scan = np.array(latest_scan)
 
         inf_mask = np.isinf(latest_scan)
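The hunk only shows the start of the preprocessing (converting the scan and masking infinite ranges). A hedged sketch of what a routine with the documented signature and return value typically does; the scan range cap, normalization, and state layout below are assumptions, not the file's code:

import numpy as np


def prepare_state_sketch(latest_scan, distance, cos, sin, collision, goal, action,
                         max_range=10.0):
    """Illustrative only: clamp infinite ranges and assemble (state, terminal)."""
    scan = np.array(latest_scan, dtype=np.float32)
    scan[np.isinf(scan)] = max_range                  # replace 'inf' readings (assumed cap)
    scan = scan / max_range                           # normalization scheme is an assumption

    state = np.concatenate(
        [scan, [distance, cos, sin], np.asarray(action, dtype=np.float32)]
    )
    terminal = bool(collision or goal)                # episode ends on collision or goal
    return state, terminal
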
