class Pretraining:
+    """
+    Handles loading of offline experience data and pretraining of a reinforcement learning model.
+
+    Attributes:
+        file_names (List[str]): List of YAML files containing pre-recorded environment samples.
+        model (object): The model with `prepare_state` and `train` methods.
+        replay_buffer (object): The buffer used to store experiences for training.
+        reward_function (callable): Function to compute the reward from the environment state.
+    """
+
    def __init__(
        self,
        file_names: List[str],
@@ -22,6 +32,12 @@ def __init__(
        self.reward_function = reward_function

    def load_buffer(self):
+        """
+        Load samples from the specified files and populate the replay buffer.
+
+        Returns:
+            object: The populated replay buffer.
+        """
        for file_name in self.file_names:
            print("Loading file: ", file_name)
            with open(file_name, "r") as file:
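The docstrings above describe the offline-data path end to end; below is a minimal construction sketch. The `my_model`, `my_buffer`, and `my_sim` names and the constructor keyword names are assumptions inferred from the attribute list, not code from this repository.

# Sketch only: my_model, my_buffer, and my_sim are placeholders for whatever
# model, replay buffer, and simulator the project actually instantiates.
pretraining = Pretraining(
    file_names=["robot_nav/assets/data.yml"],
    model=my_model,                     # needs prepare_state(...) and train(...)
    replay_buffer=my_buffer,            # experience buffer to populate
    reward_function=my_sim.get_reward,  # computes reward from the environment state
)
replay_buffer = pretraining.load_buffer()  # parse the YAML samples into the buffer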
@@ -76,6 +92,15 @@ def train(
        iterations,
        batch_size,
    ):
+        """
+        Run pretraining on the model using the replay buffer.
+
+        Args:
+            pretraining_iterations (int): Number of outer loop iterations for pretraining.
+            replay_buffer (object): Buffer to sample training batches from.
+            iterations (int): Number of training steps per pretraining iteration.
+            batch_size (int): Batch size used during training.
+        """
        print("Running Pretraining")
        for _ in tqdm(range(pretraining_iterations)):
            self.model.train(
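With the arguments documented above, the pretraining loop can be invoked as in the sketch below; the numeric values are illustrative placeholders, not defaults from this code.

# Sketch only: iteration counts and batch size are illustrative placeholders.
pretraining.train(
    pretraining_iterations=10,    # outer passes over the loaded data
    replay_buffer=replay_buffer,  # buffer returned by load_buffer()
    iterations=100,               # training steps per outer iteration
    batch_size=64,
)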
@@ -99,6 +124,25 @@ def get_buffer(
    file_names=["robot_nav/assets/data.yml"],
    history_len=10,
):
+    """
+    Get or construct the replay buffer depending on model type and training configuration.
+
+    Args:
+        model (object): The RL model, can be PPO, RCPG, or other.
+        sim (object): Simulation environment with a `get_reward` function.
+        load_saved_buffer (bool): Whether to load experiences from file.
+        pretrain (bool): Whether to run pretraining using the buffer.
+        pretraining_iterations (int): Number of outer loop iterations for pretraining.
+        training_iterations (int): Number of iterations in each training loop.
+        batch_size (int): Size of the training batch.
+        buffer_size (int, optional): Maximum size of the buffer. Defaults to 50000.
+        random_seed (int, optional): Seed for reproducibility. Defaults to 666.
+        file_names (List[str], optional): List of YAML data file paths. Defaults to ["robot_nav/assets/data.yml"].
+        history_len (int, optional): Used for RCPG buffer configuration. Defaults to 10.
+
+    Returns:
+        object: The initialized and optionally pre-populated replay buffer.
+    """
    if isinstance(model, PPO):
        return model.buffer
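The helper documented above is the usual entry point for setting up a buffer; here is a hedged call sketch, with `model` and `sim` standing in for the model and simulator objects the training script already holds and the numeric values chosen only for illustration.

# Sketch only: model and sim are whatever RL model and simulation environment
# the training script provides; keyword names mirror the docstring above.
replay_buffer = get_buffer(
    model=model,
    sim=sim,
    load_saved_buffer=True,   # read pre-recorded YAML experiences
    pretrain=True,            # run the pretraining loop on that data
    pretraining_iterations=10,
    training_iterations=100,
    batch_size=64,
)

For a PPO model the function short-circuits and simply returns the model's own buffer, as the `isinstance` check above shows.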
@@ -147,6 +191,27 @@ def get_max_bound(
    done,
    device,
):
+    """
+    Estimate the maximum possible return (upper bound) from the next state onward.
+
+    This is used in constrained RL or safe policy optimization where a conservative
+    estimate of return is useful for policy updates.
+
+    Args:
+        next_state (torch.Tensor): Tensor of next state observations.
+        discount (float): Discount factor for future rewards.
+        max_ang_vel (float): Maximum angular velocity of the agent.
+        max_lin_vel (float): Maximum linear velocity of the agent.
+        time_step (float): Duration of one time step.
+        distance_norm (float): Normalization factor for distance.
+        goal_reward (float): Reward received upon reaching the goal.
+        reward (torch.Tensor): Immediate reward from the environment.
+        done (torch.Tensor): Binary tensor indicating episode termination.
+        device (torch.device): PyTorch device for computation.
+
+    Returns:
+        torch.Tensor: Maximum return bound for each sample in the batch.
+    """
    next_state = next_state.clone()  # Prevents in-place modifications
    reward = reward.clone()  # Ensures original reward is unchanged
    done = done.clone()
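To make the docstring concrete, here is one plausible way such an optimistic bound could be assembled from these inputs. It is a sketch of the idea only, not necessarily the exact formula this function implements; `dist_to_goal` (the goal distance taken from `next_state`) is an assumed intermediate quantity.

import torch

def sketch_max_bound(dist_to_goal, discount, max_lin_vel, time_step,
                     distance_norm, goal_reward, reward, done):
    # Best case: the robot drives straight at the goal at maximum linear velocity,
    # so the fewest steps needed is (un-normalized distance) / (max_lin_vel * time_step).
    steps_to_goal = torch.ceil(dist_to_goal * distance_norm / (max_lin_vel * time_step))
    # The goal reward, discounted over those steps, is the most future return a
    # non-terminal sample could still collect.
    best_future = (discount ** steps_to_goal) * goal_reward
    # Terminal samples keep only their immediate reward.
    return reward + (1 - done) * best_future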