@@ -47,11 +47,11 @@ def add(self, state, action, reward, terminal, next_state):
         Add a transition to the buffer. (Partial implementation.)

         Args:
-            state: The current observed state.
-            action: The action taken.
-            reward: The reward received after taking the action.
+            state (list or np.array): The current observed state.
+            action (list or np.array): The action taken.
+            reward (float): The reward received after taking the action.
             terminal (bool): Whether the episode terminated.
-            next_state: The resulting state after taking the action.
+            next_state (list or np.array): The resulting state after taking the action.
         """
         self.states.append(state)
         self.rewards.append(reward)
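The hunk above documents the buffer's add() method. Below is a minimal, self-contained sketch of such a buffer, assuming list-backed storage as suggested by the visible self.states/self.rewards appends; the class name and the extra actions/terminals/next_states lists are assumptions, not the repository's code.

import numpy as np


class SketchBuffer:
    """Stand-in buffer exposing the documented add() signature."""

    def __init__(self):
        self.states, self.actions, self.rewards = [], [], []
        self.terminals, self.next_states = [], []

    def add(self, state, action, reward, terminal, next_state):
        # Append one transition; plain lists keep insertion order for later batching.
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.terminals.append(terminal)
        self.next_states.append(next_state)


buf = SketchBuffer()
buf.add(np.zeros(4, dtype=np.float32), [0.5, -0.1], 1.0, False, np.zeros(4, dtype=np.float32))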
@@ -137,7 +137,7 @@ def act(self, state, sample):
            sample (bool): Whether to sample from the action distribution or use mean.

         Returns:
-            Tuple[Tensor, Tensor, Tensor]: Sampled (or mean) action, log probability, and state value.
+            (Tuple[Tensor, Tensor, Tensor]): Sampled (or mean) action, log probability, and state value.
         """
         action_mean = self.actor(state)
         cov_mat = torch.diag(self.action_var).unsqueeze(dim=0)
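Since act() is documented to return an action, its log probability, and a state value, and the visible code builds a diagonal covariance matrix, the sampling step presumably uses a multivariate Gaussian. A hedged, self-contained sketch under that assumption; the tensors below are placeholders for self.actor(state) and self.action_var.

import torch
from torch.distributions import MultivariateNormal

action_mean = torch.zeros(1, 2)                    # placeholder for self.actor(state)
action_var = torch.full((2,), 0.36)                # placeholder for self.action_var
cov_mat = torch.diag(action_var).unsqueeze(dim=0)  # (1, 2, 2) diagonal covariance

dist = MultivariateNormal(action_mean, cov_mat)
action = dist.sample()                 # with sample=False the mean would be used instead
action_logprob = dist.log_prob(action)
# The third returned tensor, the state value, would come from a critic network.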
@@ -163,7 +163,7 @@ def evaluate(self, state, action):
            action (Tensor): Batch of actions.

         Returns:
-            Tuple[Tensor, Tensor, Tensor]: Action log probabilities, state values, and distribution entropy.
+            (Tuple[Tensor, Tensor, Tensor]): Action log probabilities, state values, and distribution entropy.
         """
         action_mean = self.actor(state)
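evaluate() is documented to return per-sample log probabilities, state values, and distribution entropy. A hedged sketch of how those quantities are typically computed for a batch under the same multivariate-Gaussian assumption; all tensors here are placeholders, not the repository's data.

import torch
from torch.distributions import MultivariateNormal

batch_action_mean = torch.zeros(8, 2)              # placeholder for self.actor(state)
action_var = torch.full((2,), 0.36).expand_as(batch_action_mean)
cov_mat = torch.diag_embed(action_var)             # one (2, 2) covariance per sample
dist = MultivariateNormal(batch_action_mean, cov_mat)

actions = torch.randn(8, 2)                        # placeholder batch of actions
logprobs = dist.log_prob(actions)                  # shape (8,)
entropy = dist.entropy()                           # shape (8,)
# State values for the batch would again come from a critic network.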
@@ -306,7 +306,7 @@ def get_action(self, state, add_noise):
            add_noise (bool): Whether to sample from the distribution (True) or use the deterministic mean (False).

         Returns:
-            np.ndarray: Sampled action.
+            (np.ndarray): Sampled action.
         """

         with torch.no_grad():
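A hedged sketch of the gradient-free inference pattern get_action() appears to use. Only torch.no_grad() and the np.ndarray return type come from the diff; the network, its dimensions, and the noise model are assumptions for illustration.

import numpy as np
import torch

actor = torch.nn.Linear(4, 2)   # stand-in for the real policy network


def get_action_sketch(state, add_noise):
    state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():                      # inference only, no gradients tracked
        mean = actor(state_t)
        if add_noise:
            mean = mean + 0.1 * torch.randn_like(mean)   # assumed noise scale
    return mean.squeeze(0).numpy()             # np.ndarray, as the docstring states


a = get_action_sketch(np.zeros(4, dtype=np.float32), add_noise=True)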
@@ -326,7 +326,7 @@ def train(self, replay_buffer, iterations, batch_size):
         Train the policy and value function using PPO loss based on the stored rollout buffer.

         Args:
-            replay_buffer: Placeholder for compatibility (not used).
+            replay_buffer (object): Placeholder for compatibility (not used).
             iterations (int): Number of epochs to optimize the policy per update.
             batch_size (int): Batch size (not used; training uses the whole buffer).
         """
@@ -434,7 +434,7 @@ def prepare_state(self, latest_scan, distance, cos, sin, collision, goal, action
            action (tuple[float, float]): Last action taken (linear and angular velocities).

         Returns:
-            tuple[list[float], int]: Processed state vector and terminal flag (1 if terminal, else 0).
+            (tuple[list[float], int]): Processed state vector and terminal flag (1 if terminal, else 0).
         """
         latest_scan = np.array(latest_scan)
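A hedged sketch of the contract prepare_state() documents: build a flat state vector from the laser scan plus pose, goal, and last-action terms, and flag the step terminal. Only the signature, the np.array conversion, and the (state, flag) return come from the diff; the normalization and the collision-or-goal terminal rule are assumptions.

import numpy as np


def prepare_state_sketch(latest_scan, distance, cos, sin, collision, goal, action):
    latest_scan = np.array(latest_scan)
    scan = (np.clip(latest_scan, 0.0, 10.0) / 10.0).tolist()   # assumed range normalization
    state = scan + [distance, cos, sin] + list(action)
    terminal = 1 if (collision or goal) else 0                  # 1 if terminal, else 0
    return state, terminal


state, term = prepare_state_sketch([4.0] * 20, 2.5, 1.0, 0.0, False, False, (0.3, 0.0))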