Skip to content

Commit 4e8adce

Browse files
Unnormalize actions with model#sample_actions in OctoInference (#5)
* Add support for a user-supplied OctoModel in OctoInference
* Simplify OctoInference

Co-authored-by: Xuanlin (Simon) Li <[email protected]>
1 parent 7fd9e22 commit 4e8adce

File tree

1 file changed

+24
-83
lines changed

1 file changed

+24
-83
lines changed

simpler_env/policies/octo/octo_model.py

Lines changed: 24 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
class OctoInference:
1717
def __init__(
1818
self,
19+
model: Optional[OctoModel] = None,
20+
dataset_id: Optional[str] = None,
1921
model_type: str = "octo-base",
2022
policy_setup: str = "widowx_bridge",
2123
horizon: int = 2,
@@ -27,87 +29,34 @@ def __init__(
2729
) -> None:
2830
os.environ["TOKENIZERS_PARALLELISM"] = "false"
2931
if policy_setup == "widowx_bridge":
30-
dataset_id = "bridge_dataset"
32+
dataset_id = "bridge_dataset" if dataset_id is None else dataset_id
3133
action_ensemble = True
3234
action_ensemble_temp = 0.0
3335
self.sticky_gripper_num_repeat = 1
3436
elif policy_setup == "google_robot":
35-
dataset_id = "fractal20220817_data"
37+
dataset_id = "fractal20220817_data" if dataset_id is None else dataset_id
3638
action_ensemble = True
3739
action_ensemble_temp = 0.0
3840
self.sticky_gripper_num_repeat = 15
3941
else:
4042
raise NotImplementedError(f"Policy setup {policy_setup} not supported for octo models.")
4143
self.policy_setup = policy_setup
44+
self.dataset_id = dataset_id
4245

43-
if model_type in ["octo-base", "octo-small"]:
46+
if model is not None:
47+
self.tokenizer, self.tokenizer_kwargs = None, None
48+
self.model = model
49+
self.action_mean = self.model.dataset_statistics[dataset_id]["action"]["mean"]
50+
self.action_std = self.model.dataset_statistics[dataset_id]["action"]["std"]
51+
elif model_type in ["octo-base", "octo-small"]:
4452
# released huggingface octo models
4553
self.model_type = f"hf://rail-berkeley/{model_type}"
4654
self.tokenizer, self.tokenizer_kwargs = None, None
4755
self.model = OctoModel.load_pretrained(self.model_type)
4856
self.action_mean = self.model.dataset_statistics[dataset_id]["action"]["mean"]
4957
self.action_std = self.model.dataset_statistics[dataset_id]["action"]["std"]
50-
self.automatic_task_creation = True
5158
else:
52-
# custom model path
53-
self.model_type = model_type
54-
self.tokenizer = AutoTokenizer.from_pretrained("t5-base")
55-
self.tokenizer_kwargs = {
56-
"max_length": 16,
57-
"padding": "max_length",
58-
"truncation": True,
59-
"return_tensors": "np",
60-
}
61-
self.model = tf.saved_model.load(self.model_type)
62-
if dataset_id == "bridge_dataset":
63-
self.action_mean = np.array(
64-
[
65-
0.00021161,
66-
0.00012614,
67-
-0.00017022,
68-
-0.00015062,
69-
-0.00023831,
70-
0.00025646,
71-
0.0,
72-
]
73-
)
74-
self.action_std = np.array(
75-
[
76-
0.00963721,
77-
0.0135066,
78-
0.01251861,
79-
0.02806791,
80-
0.03016905,
81-
0.07632624,
82-
1.0,
83-
]
84-
)
85-
elif dataset_id == "fractal20220817_data":
86-
self.action_mean = np.array(
87-
[
88-
0.00696389,
89-
0.00627008,
90-
-0.01263256,
91-
0.04330839,
92-
-0.00570499,
93-
0.00089247,
94-
0.0,
95-
]
96-
)
97-
self.action_std = np.array(
98-
[
99-
0.06925472,
100-
0.06019009,
101-
0.07354742,
102-
0.15605888,
103-
0.1316399,
104-
0.14593437,
105-
1.0,
106-
]
107-
)
108-
else:
109-
raise NotImplementedError(f"{dataset_id} not supported yet for custom octo model checkpoints.")
110-
self.automatic_task_creation = False
59+
raise NotImplementedError()
11160

11261
self.image_size = image_size
11362
self.action_scale = action_scale
@@ -165,10 +114,7 @@ def _obtain_image_history_and_mask(self) -> tuple[np.ndarray, np.ndarray]:
165114
return images, pad_mask
166115

167116
def reset(self, task_description: str) -> None:
168-
if self.automatic_task_creation:
169-
self.task = self.model.create_tasks(texts=[task_description])
170-
else:
171-
self.task = self.tokenizer(task_description, **self.tokenizer_kwargs)
117+
self.task = self.model.create_tasks(texts=[task_description])
172118
self.task_description = task_description
173119
self.image_history.clear()
174120
if self.action_ensemble:
@@ -209,25 +155,20 @@ def step(self, image: np.ndarray, task_description: Optional[str] = None, *args,
209155
self.rng, key = jax.random.split(self.rng) # each shape [2,]
210156
# print("octo local rng", self.rng, key)
211157

212-
if self.automatic_task_creation:
213-
input_observation = {"image_primary": images, "pad_mask": pad_mask}
214-
norm_raw_actions = self.model.sample_actions(input_observation, self.task, rng=key)
215-
else:
216-
input_observation = {"image_primary": images, "timestep_pad_mask": pad_mask}
217-
input_observation = {
218-
"observations": input_observation,
219-
"tasks": {"language_instruction": self.task},
220-
"rng": np.concatenate([self.rng, key]),
221-
}
222-
norm_raw_actions = self.model.lc_ws2(input_observation)[:, :, :7]
223-
norm_raw_actions = norm_raw_actions[0] # remove batch, becoming (action_pred_horizon, action_dim)
224-
assert norm_raw_actions.shape == (self.pred_action_horizon, 7)
158+
input_observation = {"image_primary": images, "timestep_pad_mask": pad_mask}
159+
raw_actions = self.model.sample_actions(
160+
input_observation,
161+
self.task,
162+
rng=key,
163+
unnormalization_statistics=self.model.dataset_statistics[self.dataset_id]["action"]
164+
)
165+
raw_actions = raw_actions[0] # remove batch, becoming (action_pred_horizon, action_dim)
225166

167+
assert raw_actions.shape == (self.pred_action_horizon, 7)
226168
if self.action_ensemble:
227-
norm_raw_actions = self.action_ensembler.ensemble_action(norm_raw_actions)
228-
norm_raw_actions = norm_raw_actions[None] # [1, 7]
169+
raw_actions = self.action_ensembler.ensemble_action(raw_actions)
170+
raw_actions = raw_actions[None] # [1, 7]
229171

230-
raw_actions = norm_raw_actions * self.action_std[None] + self.action_mean[None]
231172
raw_action = {
232173
"world_vector": np.array(raw_actions[0, :3]),
233174
"rotation_delta": np.array(raw_actions[0, 3:6]),

0 commit comments

Comments
 (0)