
Commit 6352aa5

add docs
1 parent a8e83d1 commit 6352aa5

File tree

7 files changed: +744 -19 lines changed

robot_nav/models/PPO/PPO.py

Lines changed: 163 additions & 2 deletions
@@ -7,9 +7,23 @@
 from numpy import inf


-################################## PPO Policy ##################################
 class RolloutBuffer:
+    """
+    Buffer to store rollout data (transitions) for PPO training.
+
+    Attributes:
+        actions (list): Actions taken by the agent.
+        states (list): States observed by the agent.
+        logprobs (list): Log probabilities of the actions.
+        rewards (list): Rewards received from the environment.
+        state_values (list): Value estimates for the states.
+        is_terminals (list): Flags indicating episode termination.
+    """
+
     def __init__(self):
+        """
+        Initialize empty lists to store buffer elements.
+        """
         self.actions = []
         self.states = []
         self.logprobs = []
@@ -18,6 +32,9 @@ def __init__(self):
         self.is_terminals = []

     def clear(self):
+        """
+        Clear all stored data from the buffer.
+        """
         del self.actions[:]
         del self.states[:]
         del self.logprobs[:]
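
A minimal usage sketch for the buffer documented in this hunk (not part of the commit; the 25-element state below is an illustrative placeholder):

# Collect a few transitions, then reset the buffer between PPO updates.
buffer = RolloutBuffer()

# add() as shown in this diff stores only the state, reward, and terminal flag;
# actions, logprobs, and state_values appear to be appended elsewhere (e.g. by PPO.get_action).
buffer.add(state=[0.0] * 25, action=None, reward=1.0, terminal=False, next_state=None)

print(len(buffer.states), len(buffer.rewards), len(buffer.is_terminals))  # 1 1 1
buffer.clear()  # empties every list before the next rollout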
@@ -26,13 +43,44 @@ def clear(self):
         del self.is_terminals[:]

     def add(self, state, action, reward, terminal, next_state):
+        """
+        Add a transition to the buffer. (Partial implementation.)
+
+        Args:
+            state: The current observed state.
+            action: The action taken.
+            reward: The reward received after taking the action.
+            terminal (bool): Whether the episode terminated.
+            next_state: The resulting state after taking the action.
+        """
         self.states.append(state)
         self.rewards.append(reward)
         self.is_terminals.append(terminal)


 class ActorCritic(nn.Module):
+    """
+    Actor-Critic neural network model for PPO.
+
+    Attributes:
+        actor (nn.Sequential): Policy network (actor) to output action mean.
+        critic (nn.Sequential): Value network (critic) to predict state values.
+        action_var (Tensor): Diagonal covariance matrix for action distribution.
+        device (str): Device used for computation ('cpu' or 'cuda').
+        max_action (float): Clipping range for action values.
+    """
+
     def __init__(self, state_dim, action_dim, action_std_init, max_action, device):
+        """
+        Initialize the Actor and Critic networks.
+
+        Args:
+            state_dim (int): Dimension of the input state.
+            action_dim (int): Dimension of the action space.
+            action_std_init (float): Initial standard deviation of the action distribution.
+            max_action (float): Maximum value allowed for an action (clipping range).
+            device (str): Device to run the model on.
+        """
         super(ActorCritic, self).__init__()

         self.device = device
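
A hypothetical instantiation of the ActorCritic model using the constructor signature documented in this hunk; the dimensions, the 0.6 initial std, and the action bound are illustrative assumptions, not values taken from this commit:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed shapes for illustration: a 25-D state and a 2-D (linear, angular) action.
model = ActorCritic(
    state_dim=25,
    action_dim=2,
    action_std_init=0.6,
    max_action=1.0,
    device=device,
).to(device)

state = torch.zeros(1, 25, device=device)
action, logprob, value = model.act(state, sample=True)  # detached tensors, per the docstring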
@@ -61,15 +109,36 @@ def __init__(self, state_dim, action_dim, action_std_init, max_action, device):
         )

     def set_action_std(self, new_action_std):
+        """
+        Set a new standard deviation for the action distribution.
+
+        Args:
+            new_action_std (float): New standard deviation.
+        """
         self.action_var = torch.full(
             (self.action_dim,), new_action_std * new_action_std
         ).to(self.device)

     def forward(self):
+        """
+        Forward method is not implemented, as it's unused directly.
+
+        Raises:
+            NotImplementedError: Always raised when called.
+        """
         raise NotImplementedError

     def act(self, state, sample):
+        """
+        Compute an action, its log probability, and the state value.
+
+        Args:
+            state (Tensor): Input state tensor.
+            sample (bool): Whether to sample from the action distribution or use mean.

+        Returns:
+            Tuple[Tensor, Tensor, Tensor]: Sampled (or mean) action, log probability, and state value.
+        """
         action_mean = self.actor(state)
         cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
         dist = MultivariateNormal(action_mean, cov_mat)
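
The covariance handling in set_action_std and act follows the usual diagonal-Gaussian PPO recipe; a self-contained sketch of the same construction with illustrative numbers:

import torch
from torch.distributions import MultivariateNormal

action_dim, action_std = 2, 0.6
action_mean = torch.zeros(1, action_dim)  # stands in for the actor network output

action_var = torch.full((action_dim,), action_std * action_std)  # per-dimension variance
cov_mat = torch.diag(action_var).unsqueeze(dim=0)                # shape (1, action_dim, action_dim)
dist = MultivariateNormal(action_mean, cov_mat)

action = dist.sample()           # stochastic action (the sample=True path)
logprob = dist.log_prob(action)  # log-probability later used by the PPO ratio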
@@ -86,7 +155,16 @@ def act(self, state, sample):
         return action.detach(), action_logprob.detach(), state_val.detach()

     def evaluate(self, state, action):
+        """
+        Evaluate action log probabilities, entropy, and state values for given states and actions.
+
+        Args:
+            state (Tensor): Batch of states.
+            action (Tensor): Batch of actions.

+        Returns:
+            Tuple[Tensor, Tensor, Tensor]: Action log probabilities, state values, and distribution entropy.
+        """
         action_mean = self.actor(state)

         action_var = self.action_var.expand_as(action_mean)
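
A sketch of how evaluate() is consumed during an update, reusing the hypothetical model instance from the earlier sketch (batch sizes are illustrative):

import torch

old_states = torch.zeros(64, 25, device=device)   # stored rollout states
old_actions = torch.zeros(64, 2, device=device)   # stored rollout actions

logprobs, state_values, dist_entropy = model.evaluate(old_states, old_actions)
# One log-probability, value estimate, and entropy term per transition in the batch.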
@@ -105,6 +183,30 @@ def evaluate(self, state, action):


 class PPO:
+    """
+    Proximal Policy Optimization (PPO) implementation for continuous control tasks.
+
+    Attributes:
+        max_action (float): Maximum action value.
+        action_std (float): Standard deviation of the action distribution.
+        action_std_decay_rate (float): Rate at which to decay action standard deviation.
+        min_action_std (float): Minimum allowed action standard deviation.
+        state_dim (int): Dimension of the state space.
+        gamma (float): Discount factor for future rewards.
+        eps_clip (float): Clipping range for policy updates.
+        device (str): Device for model computation ('cpu' or 'cuda').
+        save_every (int): Interval (in iterations) for saving model checkpoints.
+        model_name (str): Name used when saving/loading model.
+        save_directory (Path): Directory to save model checkpoints.
+        iter_count (int): Number of training iterations completed.
+        buffer (RolloutBuffer): Buffer to store trajectories.
+        policy (ActorCritic): Current actor-critic network.
+        optimizer (torch.optim.Optimizer): Optimizer for actor and critic.
+        policy_old (ActorCritic): Old actor-critic network for computing PPO updates.
+        MseLoss (nn.Module): Mean squared error loss function.
+        writer (SummaryWriter): TensorBoard summary writer.
+    """
+
     def __init__(
         self,
         state_dim,
@@ -160,11 +262,24 @@ def __init__(
         self.writer = SummaryWriter(comment=model_name)

     def set_action_std(self, new_action_std):
+        """
+        Set a new standard deviation for the action distribution.
+
+        Args:
+            new_action_std (float): New standard deviation value.
+        """
         self.action_std = new_action_std
         self.policy.set_action_std(new_action_std)
         self.policy_old.set_action_std(new_action_std)

     def decay_action_std(self, action_std_decay_rate, min_action_std):
+        """
+        Decay the action standard deviation by a fixed rate, down to a minimum threshold.
+
+        Args:
+            action_std_decay_rate (float): Amount to reduce standard deviation by.
+            min_action_std (float): Minimum value for standard deviation.
+        """
         print(
             "--------------------------------------------------------------------------------------------"
         )
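
decay_action_std is typically called on a fixed schedule during training; a hedged sketch in which the interval and decay values are assumptions and ppo stands for an already-constructed PPO instance:

# Hypothetical schedule: shrink exploration noise every 50 episodes.
ACTION_STD_DECAY_RATE = 0.015
MIN_ACTION_STD = 0.1

for episode in range(1, 1001):
    # ... collect a rollout and run ppo.train(...) here ...
    if episode % 50 == 0:
        ppo.decay_action_std(ACTION_STD_DECAY_RATE, MIN_ACTION_STD)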
@@ -183,6 +298,16 @@ def decay_action_std(self, action_std_decay_rate, min_action_std):
         )

     def get_action(self, state, add_noise):
+        """
+        Sample an action using the current policy (optionally with noise), and store in buffer if noise is added.
+
+        Args:
+            state (array_like): Input state for the policy.
+            add_noise (bool): Whether to sample from the distribution (True) or use the deterministic mean (False).
+
+        Returns:
+            np.ndarray: Sampled action.
+        """

         with torch.no_grad():
             state = torch.FloatTensor(state).to(self.device)
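
A sketch of where get_action() sits in a rollout loop; env.reset/env.step are placeholders for whatever simulator the repository uses and are not part of this diff:

import numpy as np

state = env.reset()                       # placeholder environment API
for step in range(500):                   # illustrative episode length
    action = ppo.get_action(np.array(state), add_noise=True)  # stochastic during training
    state, reward, done, _ = env.step(action)                 # placeholder step signature
    if done:
        break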
@@ -197,6 +322,14 @@ def get_action(self, state, add_noise):
         return action.detach().cpu().numpy().flatten()

     def train(self, replay_buffer, iterations, batch_size):
+        """
+        Train the policy and value function using PPO loss based on the stored rollout buffer.
+
+        Args:
+            replay_buffer: Placeholder for compatibility (not used).
+            iterations (int): Number of epochs to optimize the policy per update.
+            batch_size (int): Batch size (not used; training uses the whole buffer).
+        """
         # Monte Carlo estimate of returns
         rewards = []
         discounted_reward = 0
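
train() begins with the Monte Carlo return estimate shown here and, per the eps_clip attribute, optimizes a clipped surrogate objective. A generic, self-contained PPO sketch of those two steps; the dummy tensors and the 0.5/0.01 coefficients are illustrative, not copied from this file:

import torch
import torch.nn.functional as F

gamma, eps_clip = 0.99, 0.2

# Illustrative rollout data (the real class reads these from its RolloutBuffer).
rewards = [1.0, 0.0, -1.0, 2.0]
is_terminals = [False, False, True, False]

# Monte Carlo estimate of returns: walk backwards, resetting at episode ends.
returns, discounted_reward = [], 0.0
for reward, terminal in zip(reversed(rewards), reversed(is_terminals)):
    if terminal:
        discounted_reward = 0.0
    discounted_reward = reward + gamma * discounted_reward
    returns.insert(0, discounted_reward)
returns = torch.tensor(returns, dtype=torch.float32)

# Clipped surrogate objective (dummy tensors stand in for evaluate() outputs).
logprobs = torch.zeros(4, requires_grad=True)
old_logprobs = torch.zeros(4)
state_values = torch.zeros(4, requires_grad=True)
dist_entropy = torch.zeros(4)

ratios = torch.exp(logprobs - old_logprobs.detach())
advantages = returns - state_values.detach()
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
loss = (
    -torch.min(surr1, surr2).mean()
    + 0.5 * F.mse_loss(state_values, returns)
    - 0.01 * dist_entropy.mean()
)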
@@ -288,7 +421,21 @@ def train(self, replay_buffer, iterations, batch_size):
         self.save(filename=self.model_name, directory=self.save_directory)

     def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
-        # update the returned data from ROS into a form used for learning in the current model
+        """
+        Convert raw sensor and navigation data into a normalized state vector for the policy.
+
+        Args:
+            latest_scan (list[float]): LIDAR scan data.
+            distance (float): Distance to the goal.
+            cos (float): Cosine of angle to the goal.
+            sin (float): Sine of angle to the goal.
+            collision (bool): Whether the robot has collided.
+            goal (bool): Whether the robot has reached the goal.
+            action (tuple[float, float]): Last action taken (linear and angular velocities).
+
+        Returns:
+            tuple[list[float], int]: Processed state vector and terminal flag (1 if terminal, else 0).
+        """
         latest_scan = np.array(latest_scan)

         inf_mask = np.isinf(latest_scan)
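
Only the opening lines of prepare_state() are visible in this hunk; a hedged sketch of the kind of preprocessing its docstring describes. The max range, normalization, and final layout are assumptions, not this file's actual code:

import numpy as np

def prepare_state_sketch(latest_scan, distance, cos, sin, collision, goal, action,
                         max_range=7.0):
    scan = np.array(latest_scan, dtype=np.float32)
    scan[np.isinf(scan)] = max_range   # cap rays that returned no hit
    scan = scan / max_range            # assumed normalization to [0, 1]

    state = scan.tolist() + [distance, cos, sin, action[0], action[1]]
    terminal = 1 if (collision or goal) else 0
    return state, terminal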
@@ -319,12 +466,26 @@ def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action):
         return state, terminal

     def save(self, filename, directory):
+        """
+        Save the current policy model to the specified directory.
+
+        Args:
+            filename (str): Base name of the model file.
+            directory (Path): Directory to save the model to.
+        """
         Path(directory).mkdir(parents=True, exist_ok=True)
         torch.save(
             self.policy_old.state_dict(), "%s/%s_policy.pth" % (directory, filename)
         )

     def load(self, filename, directory):
+        """
+        Load the policy model from a saved checkpoint.
+
+        Args:
+            filename (str): Base name of the model file.
+            directory (Path): Directory to load the model from.
+        """
         self.policy_old.load_state_dict(
             torch.load(
                 "%s/%s_policy.pth" % (directory, filename),
