class Actor(nn.Module):
+    """
+    Actor network for the DDPG algorithm.
+
+    This network maps input states to actions using a fully connected feedforward architecture.
+    It uses Leaky ReLU activations in the hidden layers and a tanh activation at the output
+    to ensure the output actions are in the range [-1, 1].
+
+    Architecture:
+        - Linear(state_dim → 400) + LeakyReLU
+        - Linear(400 → 300) + LeakyReLU
+        - Linear(300 → action_dim) + Tanh
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the output action space.
+    """
+
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()

@@ -22,13 +39,41 @@ def __init__(self, state_dim, action_dim):
        self.tanh = nn.Tanh()

    def forward(self, s):
+        """
+        Forward pass of the actor network.
+
+        Args:
+            s (torch.Tensor): Input state tensor of shape (batch_size, state_dim).
+
+        Returns:
+            torch.Tensor: Output action tensor of shape (batch_size, action_dim), scaled to [-1, 1].
+        """
        s = F.leaky_relu(self.layer_1(s))
        s = F.leaky_relu(self.layer_2(s))
        a = self.tanh(self.layer_3(s))
        return a
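# Illustrative sketch, not lines of the patched file: one way the Actor above could be
# exercised, assuming example dimensions (state_dim=25, action_dim=2) and the torch import
# already used elsewhere in this module.
import torch

actor = Actor(state_dim=25, action_dim=2)
dummy_state = torch.randn(8, 25)     # batch of 8 example states
actions = actor(dummy_state)         # shape (8, 2), each entry in [-1, 1] due to tanh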


class Critic(nn.Module):
+    """
+    Critic network for the DDPG algorithm.
+
+    This network evaluates the Q-value of a given state-action pair. It separately processes
+    state and action inputs through linear layers, combines them, and passes the result through
+    another linear layer to predict a scalar Q-value.
+
+    Architecture:
+        - Linear(state_dim → 400) + LeakyReLU
+        - Linear(400 → 300) [state branch]
+        - Linear(action_dim → 300) [action branch]
+        - Combine both branches, apply LeakyReLU
+        - Linear(300 → 1) for Q-value output
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the input action.
+    """
+
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()

@@ -42,6 +87,16 @@ def __init__(self, state_dim, action_dim):
        torch.nn.init.kaiming_uniform_(self.layer_3.weight, nonlinearity="leaky_relu")

    def forward(self, s, a):
+        """
+        Forward pass of the critic network.
+
+        Args:
+            s (torch.Tensor): State tensor of shape (batch_size, state_dim).
+            a (torch.Tensor): Action tensor of shape (batch_size, action_dim).
+
+        Returns:
+            torch.Tensor: Q-value tensor of shape (batch_size, 1).
+        """
        s1 = F.leaky_relu(self.layer_1(s))
        self.layer_2_s(s1)
        self.layer_2_a(a)
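# Illustrative sketch, not lines of the patched file: querying the Critic above for a
# Q-value with made-up batch and dimension sizes (the remainder of forward() is not shown
# in this diff).
import torch

critic = Critic(state_dim=25, action_dim=2)
q = critic(torch.randn(8, 25), torch.randn(8, 2))   # expected shape: (8, 1)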
@@ -55,6 +110,28 @@ def forward(self, s, a):

# DDPG network
class DDPG(object):
+    """
+    Deep Deterministic Policy Gradient (DDPG) agent implementation.
+
+    This class encapsulates the actor-critic learning framework using DDPG, which is suitable
+    for continuous action spaces. It supports training, action selection, model saving/loading,
+    and state preparation for a reinforcement learning agent, specifically designed for robot navigation.
+
+    Args:
+        state_dim (int): Dimension of the input state.
+        action_dim (int): Dimension of the action space.
+        max_action (float): Maximum action value allowed.
+        device (torch.device): Computation device (CPU or GPU).
+        lr (float): Learning rate for the optimizers. Default is 1e-4.
+        save_every (int): Frequency of saving the model, in training iterations. 0 means no saving. Default is 0.
+        load_model (bool): Flag indicating whether to load a model from disk. Default is False.
+        save_directory (Path): Directory to save the model checkpoints. Default is "robot_nav/models/DDPG/checkpoint".
+        model_name (str): Name used for saving and TensorBoard logging. Default is "DDPG".
+        load_directory (Path): Directory to load model checkpoints from. Default is "robot_nav/models/DDPG/checkpoint".
+        use_max_bound (bool): Whether to enforce an upper bound on the Q-value during training. Default is False.
+        bound_weight (float): Weight of the upper-bound loss penalty. Default is 0.25.
+    """
+
    def __init__(
        self,
        state_dim,
@@ -97,6 +174,16 @@ def __init__(
        self.bound_weight = bound_weight
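# Illustrative sketch, not lines of the patched file: constructing the agent with the
# required arguments from the class docstring; dimensions and device are example values.
import torch

agent = DDPG(
    state_dim=25,
    action_dim=2,
    max_action=1.0,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)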

    def get_action(self, obs, add_noise):
+        """
+        Selects an action based on the observation.
+
+        Args:
+            obs (np.array): The current state observation.
+            add_noise (bool): Whether to add exploration noise to the action.
+
+        Returns:
+            np.array: Action selected by the actor network.
+        """
        if add_noise:
            return (
                self.act(obs) + np.random.normal(0, 0.2, size=self.action_dim)
@@ -105,7 +192,15 @@ def get_action(self, obs, add_noise):
        return self.act(obs)

    def act(self, state):
-        # Function to get the action from the actor
+        """
+        Computes the action for a given state using the actor network.
+
+        Args:
+            state (np.array): Environment state.
+
+        Returns:
+            np.array: Action values as output by the actor network.
+        """
        state = torch.Tensor(state).to(self.device)
        return self.actor(state).cpu().data.numpy().flatten()
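# Illustrative sketch, not lines of the patched file: action selection with and without
# exploration noise, reusing the example `agent` and state dimension assumed above.
import numpy as np

obs = np.zeros(25, dtype=np.float32)                    # placeholder observation
train_action = agent.get_action(obs, add_noise=True)    # Gaussian noise added for exploration
eval_action = agent.get_action(obs, add_noise=False)    # deterministic actor output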

@@ -126,6 +221,24 @@ def train(
        distance_norm=10,
        time_step=0.3,
    ):
+        """
+        Trains the actor and critic networks using a replay buffer and soft target updates.
+
+        Args:
+            replay_buffer (object): Replay buffer object with a sample_batch method.
+            iterations (int): Number of training iterations.
+            batch_size (int): Size of each training batch.
+            discount (float): Discount factor for future rewards.
+            tau (float): Soft update factor for the target networks.
+            policy_noise (float): Standard deviation of the noise added to the target policy.
+            noise_clip (float): Maximum absolute value to which the target policy noise is clipped.
+            policy_freq (int): Frequency of actor and target network updates.
+            max_lin_vel (float): Maximum linear velocity, used in the Q-bound calculation.
+            max_ang_vel (float): Maximum angular velocity, used in the Q-bound calculation.
+            goal_reward (float): Reward given upon reaching the goal.
+            distance_norm (float): Distance normalization factor.
+            time_step (float): Time step used in the max-bound calculation.
+        """
        av_Q = 0
        max_Q = -inf
        av_loss = 0
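# Illustrative sketch, not lines of the patched file: the "soft target updates" named in the
# docstring above conventionally take the Polyak-averaging form below; the actual update in
# this class lives in the elided portion of train().
import torch.nn as nn

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    # target_param <- tau * source_param + (1 - tau) * target_param
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        t_param.data.copy_(tau * s_param.data + (1.0 - tau) * t_param.data)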
@@ -229,6 +342,13 @@ def train(
            self.save(filename=self.model_name, directory=self.save_directory)

    def save(self, filename, directory):
+        """
+        Saves the model parameters to disk.
+
+        Args:
+            filename (str): Base filename for saving the model components.
+            directory (str or Path): Directory where the model files will be saved.
+        """
        Path(directory).mkdir(parents=True, exist_ok=True)
        torch.save(self.actor.state_dict(), "%s/%s_actor.pth" % (directory, filename))
        torch.save(
@@ -242,6 +362,13 @@ def save(self, filename, directory):
        )

    def load(self, filename, directory):
+        """
+        Loads model parameters from disk.
+
+        Args:
+            filename (str): Base filename used for loading model components.
+            directory (str or Path): Directory to load the model files from.
+        """
        self.actor.load_state_dict(
            torch.load("%s/%s_actor.pth" % (directory, filename))
        )
@@ -257,7 +384,21 @@ def load(self, filename, directory):
        print(f"Loaded weights from: {directory}")
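# Illustrative sketch, not lines of the patched file: a checkpoint round-trip with the two
# methods above, using the default directory mentioned in the class docstring.
agent.save(filename="DDPG", directory="robot_nav/models/DDPG/checkpoint")
agent.load(filename="DDPG", directory="robot_nav/models/DDPG/checkpoint")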

    def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
-        # update the returned data from ROS into a form used for learning in the current model
+        """
+        Processes raw sensor input and additional information into a normalized state representation.
+
+        Args:
+            latest_scan (list or np.array): Raw LIDAR or laser scan data.
+            distance (float): Distance to the goal.
+            cos (float): Cosine of the angle to the goal.
+            sin (float): Sine of the angle to the goal.
+            collision (bool): Whether a collision has occurred.
+            goal (bool): Whether the goal has been reached.
+            action (list or np.array): The action taken in the previous step.
+
+        Returns:
+            tuple: (state vector, terminal flag)
+        """
        latest_scan = np.array(latest_scan)

        inf_mask = np.isinf(latest_scan)
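# Illustrative sketch, not lines of the patched file: packaging one environment step into a
# training sample; the scan length, goal geometry, and action values are made-up examples.
state, terminal = agent.prepare_state(
    latest_scan=[3.5] * 20,     # example laser readings
    distance=1.2,
    cos=0.9,
    sin=0.44,
    collision=False,
    goal=False,
    action=[0.3, 0.0],
)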