Commit 0ba5711
committed: add utils documentation
1 parent 10f460c commit 0ba5711

1 file changed: +65 -0 lines changed

robot_nav/utils.py

Lines changed: 65 additions & 0 deletions
@@ -9,6 +9,16 @@


 class Pretraining:
+    """
+    Handles loading of offline experience data and pretraining of a reinforcement learning model.
+
+    Attributes:
+        file_names (List[str]): List of YAML files containing pre-recorded environment samples.
+        model (object): The model with `prepare_state` and `train` methods.
+        replay_buffer (object): The buffer used to store experiences for training.
+        reward_function (callable): Function to compute the reward from the environment state.
+    """
+
     def __init__(
         self,
         file_names: List[str],
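
The attributes above define the contract Pretraining expects from its collaborators. Below is a minimal sketch of that contract with hypothetical stand-ins: the real model and buffer classes live elsewhere in the repository, and every constructor argument name other than file_names is inferred from the Attributes list rather than taken from the source.

from robot_nav.utils import Pretraining  # the class documented in this commit


class StubBuffer:
    """Stand-in for the documented replay_buffer role: it only stores experiences."""

    def __init__(self):
        self.samples = []

    def add(self, *transition):  # hypothetical method name
        self.samples.append(transition)


class StubModel:
    """Stand-in exposing the `prepare_state` and `train` methods the docstring requires."""

    def prepare_state(self, *raw_sample):
        return list(raw_sample)  # turn a recorded sample into a state vector

    def train(self, replay_buffer, iterations, batch_size):
        pass  # a real model would sample batches from the buffer and update its networks


def stub_reward(state):
    return 0.0  # the real reward_function derives a reward from the environment state


pretraining = Pretraining(
    file_names=["robot_nav/assets/data.yml"],
    model=StubModel(),
    replay_buffer=StubBuffer(),
    reward_function=stub_reward,
)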
@@ -22,6 +32,12 @@ def __init__(
         self.reward_function = reward_function

     def load_buffer(self):
+        """
+        Load samples from the specified files and populate the replay buffer.
+
+        Returns:
+            object: The populated replay buffer.
+        """
         for file_name in self.file_names:
             print("Loading file: ", file_name)
             with open(file_name, "r") as file:
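
Since load_buffer returns the populated buffer, a typical call (continuing the hypothetical stub setup above) is simply:

replay_buffer = pretraining.load_buffer()  # reads every YAML file listed in file_names
print(len(replay_buffer.samples))          # the stub buffer now holds the recorded transitions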
@@ -76,6 +92,15 @@ def train(
         iterations,
         batch_size,
     ):
+        """
+        Run pretraining on the model using the replay buffer.
+
+        Args:
+            pretraining_iterations (int): Number of outer loop iterations for pretraining.
+            replay_buffer (object): Buffer to sample training batches from.
+            iterations (int): Number of training steps per pretraining iteration.
+            batch_size (int): Batch size used during training.
+        """
         print("Running Pretraining")
         for _ in tqdm(range(pretraining_iterations)):
             self.model.train(
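
The two loop levels are easy to misread: pretraining_iterations drives the outer loop of this method, while iterations is passed straight through to model.train as the number of training steps per outer pass. A hedged usage sketch with illustrative values, not repository defaults:

pretraining.train(
    pretraining_iterations=10,   # outer loop in Pretraining.train
    replay_buffer=replay_buffer,
    iterations=100,              # inner training steps per outer iteration
    batch_size=64,
)
# roughly 10 * 100 = 1000 model updates in total (illustrative numbers)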
@@ -99,6 +124,25 @@ def get_buffer(
     file_names=["robot_nav/assets/data.yml"],
     history_len=10,
 ):
+    """
+    Get or construct the replay buffer depending on model type and training configuration.
+
+    Args:
+        model (object): The RL model, can be PPO, RCPG, or other.
+        sim (object): Simulation environment with a `get_reward` function.
+        load_saved_buffer (bool): Whether to load experiences from file.
+        pretrain (bool): Whether to run pretraining using the buffer.
+        pretraining_iterations (int): Number of outer loop iterations for pretraining.
+        training_iterations (int): Number of iterations in each training loop.
+        batch_size (int): Size of the training batch.
+        buffer_size (int, optional): Maximum size of the buffer. Defaults to 50000.
+        random_seed (int, optional): Seed for reproducibility. Defaults to 666.
+        file_names (List[str], optional): List of YAML data file paths. Defaults to ["robot_nav/assets/data.yml"].
+        history_len (int, optional): Used for RCPG buffer configuration. Defaults to 10.
+
+    Returns:
+        object: The initialized and optionally pre-populated replay buffer.
+    """
     if isinstance(model, PPO):
         return model.buffer
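
As documented, get_buffer short-circuits for PPO (which carries its own rollout buffer) and otherwise builds a replay buffer, optionally filling and pretraining it from the recorded YAML files. A hedged call sketch, where model and sim stand for objects constructed elsewhere in the training script:

from robot_nav.utils import get_buffer

replay_buffer = get_buffer(
    model=model,               # a PPO model would simply return its own buffer here
    sim=sim,                   # simulation providing get_reward
    load_saved_buffer=True,    # read robot_nav/assets/data.yml instead of starting empty
    pretrain=True,             # run pretraining before returning the buffer
    pretraining_iterations=10,
    training_iterations=100,
    batch_size=64,
)                              # buffer_size, random_seed, file_names, history_len keep their defaults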

@@ -147,6 +191,27 @@ def get_max_bound(
     done,
     device,
 ):
+    """
+    Estimate the maximum possible return (upper bound) from the next state onward.
+
+    This is used in constrained RL or safe policy optimization where a conservative
+    estimate of return is useful for policy updates.
+
+    Args:
+        next_state (torch.Tensor): Tensor of next state observations.
+        discount (float): Discount factor for future rewards.
+        max_ang_vel (float): Maximum angular velocity of the agent.
+        max_lin_vel (float): Maximum linear velocity of the agent.
+        time_step (float): Duration of one time step.
+        distance_norm (float): Normalization factor for distance.
+        goal_reward (float): Reward received upon reaching the goal.
+        reward (torch.Tensor): Immediate reward from the environment.
+        done (torch.Tensor): Binary tensor indicating episode termination.
+        device (torch.device): PyTorch device for computation.
+
+    Returns:
+        torch.Tensor: Maximum return bound for each sample in the batch.
+    """
     next_state = next_state.clone()  # Prevents in-place modifications
     reward = reward.clone()  # Ensures original reward is unchanged
     done = done.clone()
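
One way to read these parameters (an interpretation of the docstring, not necessarily the repository's exact formula): the fastest the robot can close the remaining distance to the goal is max_lin_vel * time_step per step, so the best achievable future return is the goal reward discounted by that minimum number of steps. A self-contained sketch of such a bound:

import torch


def max_return_bound_sketch(distance_to_goal, discount, max_lin_vel, time_step,
                            goal_reward, reward, done):
    """Illustrative upper bound on return; not the repository's actual get_max_bound."""
    # Fewest steps to the goal if the robot drives straight at full linear speed.
    min_steps = torch.ceil(distance_to_goal / (max_lin_vel * time_step))
    # Best case: collect the goal reward after min_steps and nothing before it.
    best_future = (discount ** min_steps) * goal_reward
    # Finished episodes contribute no future return beyond the immediate reward.
    return reward + (1.0 - done.float()) * best_future


bound = max_return_bound_sketch(
    distance_to_goal=torch.tensor([2.0, 0.5]),  # meters left to the goal (hypothetical)
    discount=0.99, max_lin_vel=1.0, time_step=0.1, goal_reward=100.0,
    reward=torch.tensor([-0.1, -0.1]), done=torch.tensor([0.0, 1.0]),
)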
