 from numpy import inf


-################################## PPO Policy ##################################
 class RolloutBuffer:
+    """
+    Buffer to store rollout data (transitions) for PPO training.
+
+    Attributes:
+        actions (list): Actions taken by the agent.
+        states (list): States observed by the agent.
+        logprobs (list): Log probabilities of the actions.
+        rewards (list): Rewards received from the environment.
+        state_values (list): Value estimates for the states.
+        is_terminals (list): Flags indicating episode termination.
+    """
+
     def __init__(self):
+        """
+        Initialize empty lists to store buffer elements.
+        """
         self.actions = []
         self.states = []
         self.logprobs = []
@@ -18,6 +32,9 @@ def __init__(self):
         self.is_terminals = []

     def clear(self):
+        """
+        Clear all stored data from the buffer.
+        """
         del self.actions[:]
         del self.states[:]
         del self.logprobs[:]
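
The RolloutBuffer documented above is a plain container of parallel lists. A minimal usage sketch, assuming the class from this diff is in scope (the dummy values and the 4-element state are illustrative, not taken from the diff):

import numpy as np

# Sketch only (not part of the diff): fill the buffer with dummy transitions,
# then clear it, mirroring how the PPO agent uses it around an update.
buffer = RolloutBuffer()
for _ in range(3):
    buffer.states.append(np.zeros(4))       # dummy state vector
    buffer.actions.append(np.zeros(2))      # dummy action
    buffer.logprobs.append(0.0)
    buffer.state_values.append(0.0)
    buffer.rewards.append(0.0)
    buffer.is_terminals.append(False)
print(len(buffer.rewards))                  # -> 3
buffer.clear()                              # emptied after each PPO update
print(len(buffer.rewards))                  # -> 0
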
@@ -26,13 +43,44 @@ def clear(self):
         del self.is_terminals[:]

     def add(self, state, action, reward, terminal, next_state):
+        """
+        Add a transition to the buffer. (Partial implementation.)
+
+        Args:
+            state: The current observed state.
+            action: The action taken.
+            reward: The reward received after taking the action.
+            terminal (bool): Whether the episode terminated.
+            next_state: The resulting state after taking the action.
+        """
         self.states.append(state)
         self.rewards.append(reward)
         self.is_terminals.append(terminal)


 class ActorCritic(nn.Module):
+    """
+    Actor-Critic neural network model for PPO.
+
+    Attributes:
+        actor (nn.Sequential): Policy network (actor) to output the action mean.
+        critic (nn.Sequential): Value network (critic) to predict state values.
+        action_var (Tensor): Per-dimension action variances used to build the diagonal covariance matrix.
+        device (str): Device used for computation ('cpu' or 'cuda').
+        max_action (float): Clipping range for action values.
+    """
+
     def __init__(self, state_dim, action_dim, action_std_init, max_action, device):
+        """
+        Initialize the Actor and Critic networks.
+
+        Args:
+            state_dim (int): Dimension of the input state.
+            action_dim (int): Dimension of the action space.
+            action_std_init (float): Initial standard deviation of the action distribution.
+            max_action (float): Maximum value allowed for an action (clipping range).
+            device (str): Device to run the model on.
+        """
         super(ActorCritic, self).__init__()

         self.device = device
@@ -61,15 +109,36 @@ def __init__(self, state_dim, action_dim, action_std_init, max_action, device):
         )

     def set_action_std(self, new_action_std):
+        """
+        Set a new standard deviation for the action distribution.
+
+        Args:
+            new_action_std (float): New standard deviation.
+        """
         self.action_var = torch.full(
             (self.action_dim,), new_action_std * new_action_std
         ).to(self.device)

     def forward(self):
+        """
+        The forward method is not implemented, as it is not used directly.
+
+        Raises:
+            NotImplementedError: Always raised when called.
+        """
         raise NotImplementedError

     def act(self, state, sample):
+        """
+        Compute an action, its log probability, and the state value.
+
+        Args:
+            state (Tensor): Input state tensor.
+            sample (bool): Whether to sample from the action distribution or use the mean.

+        Returns:
+            Tuple[Tensor, Tensor, Tensor]: Sampled (or mean) action, log probability, and state value.
+        """
         action_mean = self.actor(state)
         cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
         dist = MultivariateNormal(action_mean, cov_mat)
@@ -86,7 +155,16 @@ def act(self, state, sample):
         return action.detach(), action_logprob.detach(), state_val.detach()

     def evaluate(self, state, action):
+        """
+        Evaluate action log probabilities, entropy, and state values for given states and actions.
+
+        Args:
+            state (Tensor): Batch of states.
+            action (Tensor): Batch of actions.

+        Returns:
+            Tuple[Tensor, Tensor, Tensor]: Action log probabilities, state values, and distribution entropy.
+        """
         action_mean = self.actor(state)

         action_var = self.action_var.expand_as(action_mean)
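
Both act() and evaluate() documented above build a multivariate Gaussian whose diagonal covariance comes from action_var. A small standalone sketch of that construction (the 2-D action space and the 0.6 standard deviation are illustrative, not taken from the diff):

import torch
from torch.distributions import MultivariateNormal

# Sketch only (not part of the diff): a 2-D action distribution with
# per-dimension variance 0.6**2, mirroring the lines shown in act().
action_mean = torch.zeros(1, 2)                      # e.g. [linear, angular] mean
action_var = torch.full((2,), 0.6 * 0.6)             # diagonal variances
cov_mat = torch.diag(action_var).unsqueeze(dim=0)    # shape (1, 2, 2)
dist = MultivariateNormal(action_mean, cov_mat)

sampled = dist.sample()                              # stochastic action (sample=True)
logprob = dist.log_prob(sampled)                     # stored alongside the action
deterministic = action_mean                          # used when sample=False
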
@@ -105,6 +183,30 @@ def evaluate(self, state, action):


 class PPO:
+    """
+    Proximal Policy Optimization (PPO) implementation for continuous control tasks.
+
+    Attributes:
+        max_action (float): Maximum action value.
+        action_std (float): Standard deviation of the action distribution.
+        action_std_decay_rate (float): Rate at which to decay action standard deviation.
+        min_action_std (float): Minimum allowed action standard deviation.
+        state_dim (int): Dimension of the state space.
+        gamma (float): Discount factor for future rewards.
+        eps_clip (float): Clipping range for policy updates.
+        device (str): Device for model computation ('cpu' or 'cuda').
+        save_every (int): Interval (in iterations) for saving model checkpoints.
+        model_name (str): Name used when saving/loading the model.
+        save_directory (Path): Directory to save model checkpoints.
+        iter_count (int): Number of training iterations completed.
+        buffer (RolloutBuffer): Buffer to store trajectories.
+        policy (ActorCritic): Current actor-critic network.
+        optimizer (torch.optim.Optimizer): Optimizer for actor and critic.
+        policy_old (ActorCritic): Old actor-critic network for computing PPO updates.
+        MseLoss (nn.Module): Mean squared error loss function.
+        writer (SummaryWriter): TensorBoard summary writer.
+    """
+
     def __init__(
         self,
         state_dim,
@@ -160,11 +262,24 @@ def __init__(
         self.writer = SummaryWriter(comment=model_name)

     def set_action_std(self, new_action_std):
+        """
+        Set a new standard deviation for the action distribution.
+
+        Args:
+            new_action_std (float): New standard deviation value.
+        """
         self.action_std = new_action_std
         self.policy.set_action_std(new_action_std)
         self.policy_old.set_action_std(new_action_std)

     def decay_action_std(self, action_std_decay_rate, min_action_std):
+        """
+        Decay the action standard deviation by a fixed rate, down to a minimum threshold.
+
+        Args:
+            action_std_decay_rate (float): Amount to reduce the standard deviation by.
+            min_action_std (float): Minimum value for the standard deviation.
+        """
         print(
             "--------------------------------------------------------------------------------------------"
         )
@@ -183,6 +298,16 @@ def decay_action_std(self, action_std_decay_rate, min_action_std):
         )

     def get_action(self, state, add_noise):
+        """
+        Sample an action using the current policy (optionally with noise), and store it in the buffer if noise is added.
+
+        Args:
+            state (array_like): Input state for the policy.
+            add_noise (bool): Whether to sample from the distribution (True) or use the deterministic mean (False).
+
+        Returns:
+            np.ndarray: Sampled action.
+        """

         with torch.no_grad():
             state = torch.FloatTensor(state).to(self.device)
@@ -197,6 +322,14 @@ def get_action(self, state, add_noise):
         return action.detach().cpu().numpy().flatten()

     def train(self, replay_buffer, iterations, batch_size):
+        """
+        Train the policy and value function using the PPO loss on the stored rollout buffer.
+
+        Args:
+            replay_buffer: Placeholder for compatibility (not used).
+            iterations (int): Number of epochs to optimize the policy per update.
+            batch_size (int): Batch size (not used; training uses the whole buffer).
+        """
         # Monte Carlo estimate of returns
         rewards = []
         discounted_reward = 0
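
The train() hunk above begins the Monte Carlo return estimate; the rest of the method is elided in this diff, so the following is only a sketch of the standard PPO recipe it follows (the rewards, ratios, advantages, and eps_clip value are made up for illustration):

import torch

# Sketch only (not part of the diff): discounted returns computed backwards
# over a rollout, as in the Monte Carlo estimate started above.
gamma = 0.99
rollout_rewards = [1.0, 0.0, -1.0]
rollout_terminals = [False, False, True]

returns = []
discounted_reward = 0
for reward, is_terminal in zip(reversed(rollout_rewards), reversed(rollout_terminals)):
    if is_terminal:
        discounted_reward = 0
    discounted_reward = reward + gamma * discounted_reward
    returns.insert(0, discounted_reward)

# Clipped surrogate objective on a batch of probability ratios and advantages.
eps_clip = 0.2
ratios = torch.tensor([1.1, 0.8, 1.3])
advantages = torch.tensor([0.5, -0.2, 1.0])
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
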
@@ -288,7 +421,21 @@ def train(self, replay_buffer, iterations, batch_size):
             self.save(filename=self.model_name, directory=self.save_directory)

     def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
-        # update the returned data from ROS into a form used for learning in the current model
+        """
+        Convert raw sensor and navigation data into a normalized state vector for the policy.
+
+        Args:
+            latest_scan (list[float]): LIDAR scan data.
+            distance (float): Distance to the goal.
+            cos (float): Cosine of angle to the goal.
+            sin (float): Sine of angle to the goal.
+            collision (bool): Whether the robot has collided.
+            goal (bool): Whether the robot has reached the goal.
+            action (tuple[float, float]): Last action taken (linear and angular velocities).
+
+        Returns:
+            tuple[list[float], int]: Processed state vector and terminal flag (1 if terminal, else 0).
+        """
         latest_scan = np.array(latest_scan)

         inf_mask = np.isinf(latest_scan)
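
prepare_state() is documented above, but its body is mostly elided in this diff. The sketch below only illustrates the general preprocessing pattern the visible lines start; the 10.0 range cap, the state layout, and all values are assumptions, not taken from the elided code:

import numpy as np

# Sketch only (not part of the diff): clamp infinite LIDAR returns and append
# goal geometry plus the last action into one flat state vector.
latest_scan = np.array([1.2, np.inf, 3.4])
inf_mask = np.isinf(latest_scan)
latest_scan[inf_mask] = 10.0                 # assumed maximum sensor range

distance, cos, sin = 2.5, 0.8, 0.6           # made-up goal geometry
action = (0.5, -0.1)                         # made-up last (linear, angular) action
state = latest_scan.tolist() + [distance, cos, sin] + list(action)

collision, goal = False, False
terminal = 1 if collision or goal else 0     # terminal on collision or goal reached
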
@@ -319,12 +466,26 @@ def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action
         return state, terminal

     def save(self, filename, directory):
+        """
+        Save the current policy model to the specified directory.
+
+        Args:
+            filename (str): Base name of the model file.
+            directory (Path): Directory to save the model to.
+        """
         Path(directory).mkdir(parents=True, exist_ok=True)
         torch.save(
             self.policy_old.state_dict(), "%s/%s_policy.pth" % (directory, filename)
         )

     def load(self, filename, directory):
+        """
+        Load the policy model from a saved checkpoint.
+
+        Args:
+            filename (str): Base name of the model file.
+            directory (Path): Directory to load the model from.
+        """
         self.policy_old.load_state_dict(
             torch.load(
                 "%s/%s_policy.pth" % (directory, filename),