Skip to content

Commit 87acd14

Browse files
committed
Fix minor octo issues from pull request #5
1 parent 4e8adce commit 87acd14

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

simpler_env/policies/octo/octo_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,13 @@ def step(self, image: np.ndarray, task_description: Optional[str] = None, *args,
155155
self.rng, key = jax.random.split(self.rng) # each shape [2,]
156156
# print("octo local rng", self.rng, key)
157157

158-
input_observation = {"image_primary": images, "timestep_pad_mask": pad_mask}
159-
raw_actions = self.model.sample_actions(
158+
input_observation = {"image_primary": images, "pad_mask": pad_mask}
159+
norm_raw_actions = self.model.sample_actions(
160160
input_observation,
161161
self.task,
162162
rng=key,
163-
unnormalization_statistics=self.model.dataset_statistics[self.dataset_id]["action"]
164163
)
164+
raw_actions = norm_raw_actions * self.action_std[None] + self.action_mean[None]
165165
raw_actions = raw_actions[0] # remove batch, becoming (action_pred_horizon, action_dim)
166166

167167
assert raw_actions.shape == (self.pred_action_horizon, 7)

0 commit comments

Comments
 (0)