1616class OctoInference :
1717 def __init__ (
1818 self ,
19+ model : Optional [OctoModel ] = None ,
20+ dataset_id : Optional [str ] = None ,
1921 model_type : str = "octo-base" ,
2022 policy_setup : str = "widowx_bridge" ,
2123 horizon : int = 2 ,
@@ -27,87 +29,34 @@ def __init__(
2729 ) -> None :
2830 os .environ ["TOKENIZERS_PARALLELISM" ] = "false"
2931 if policy_setup == "widowx_bridge" :
30- dataset_id = "bridge_dataset"
32+ dataset_id = "bridge_dataset" if dataset_id is None else dataset_id
3133 action_ensemble = True
3234 action_ensemble_temp = 0.0
3335 self .sticky_gripper_num_repeat = 1
3436 elif policy_setup == "google_robot" :
35- dataset_id = "fractal20220817_data"
37+ dataset_id = "fractal20220817_data" if dataset_id is None else dataset_id
3638 action_ensemble = True
3739 action_ensemble_temp = 0.0
3840 self .sticky_gripper_num_repeat = 15
3941 else :
4042 raise NotImplementedError (f"Policy setup { policy_setup } not supported for octo models." )
4143 self .policy_setup = policy_setup
44+ self .dataset_id = dataset_id
4245
43- if model_type in ["octo-base" , "octo-small" ]:
46+ if model is not None :
47+ self .tokenizer , self .tokenizer_kwargs = None , None
48+ self .model = model
49+ self .action_mean = self .model .dataset_statistics [dataset_id ]["action" ]["mean" ]
50+ self .action_std = self .model .dataset_statistics [dataset_id ]["action" ]["std" ]
51+ elif model_type in ["octo-base" , "octo-small" ]:
4452 # released huggingface octo models
4553 self .model_type = f"hf://rail-berkeley/{ model_type } "
4654 self .tokenizer , self .tokenizer_kwargs = None , None
4755 self .model = OctoModel .load_pretrained (self .model_type )
4856 self .action_mean = self .model .dataset_statistics [dataset_id ]["action" ]["mean" ]
4957 self .action_std = self .model .dataset_statistics [dataset_id ]["action" ]["std" ]
50- self .automatic_task_creation = True
5158 else :
52- # custom model path
53- self .model_type = model_type
54- self .tokenizer = AutoTokenizer .from_pretrained ("t5-base" )
55- self .tokenizer_kwargs = {
56- "max_length" : 16 ,
57- "padding" : "max_length" ,
58- "truncation" : True ,
59- "return_tensors" : "np" ,
60- }
61- self .model = tf .saved_model .load (self .model_type )
62- if dataset_id == "bridge_dataset" :
63- self .action_mean = np .array (
64- [
65- 0.00021161 ,
66- 0.00012614 ,
67- - 0.00017022 ,
68- - 0.00015062 ,
69- - 0.00023831 ,
70- 0.00025646 ,
71- 0.0 ,
72- ]
73- )
74- self .action_std = np .array (
75- [
76- 0.00963721 ,
77- 0.0135066 ,
78- 0.01251861 ,
79- 0.02806791 ,
80- 0.03016905 ,
81- 0.07632624 ,
82- 1.0 ,
83- ]
84- )
85- elif dataset_id == "fractal20220817_data" :
86- self .action_mean = np .array (
87- [
88- 0.00696389 ,
89- 0.00627008 ,
90- - 0.01263256 ,
91- 0.04330839 ,
92- - 0.00570499 ,
93- 0.00089247 ,
94- 0.0 ,
95- ]
96- )
97- self .action_std = np .array (
98- [
99- 0.06925472 ,
100- 0.06019009 ,
101- 0.07354742 ,
102- 0.15605888 ,
103- 0.1316399 ,
104- 0.14593437 ,
105- 1.0 ,
106- ]
107- )
108- else :
109- raise NotImplementedError (f"{ dataset_id } not supported yet for custom octo model checkpoints." )
110- self .automatic_task_creation = False
59+ raise NotImplementedError ()
11160
11261 self .image_size = image_size
11362 self .action_scale = action_scale
@@ -165,10 +114,7 @@ def _obtain_image_history_and_mask(self) -> tuple[np.ndarray, np.ndarray]:
165114 return images , pad_mask
166115
167116 def reset (self , task_description : str ) -> None :
168- if self .automatic_task_creation :
169- self .task = self .model .create_tasks (texts = [task_description ])
170- else :
171- self .task = self .tokenizer (task_description , ** self .tokenizer_kwargs )
117+ self .task = self .model .create_tasks (texts = [task_description ])
172118 self .task_description = task_description
173119 self .image_history .clear ()
174120 if self .action_ensemble :
@@ -209,25 +155,20 @@ def step(self, image: np.ndarray, task_description: Optional[str] = None, *args,
209155 self .rng , key = jax .random .split (self .rng ) # each shape [2,]
210156 # print("octo local rng", self.rng, key)
211157
212- if self .automatic_task_creation :
213- input_observation = {"image_primary" : images , "pad_mask" : pad_mask }
214- norm_raw_actions = self .model .sample_actions (input_observation , self .task , rng = key )
215- else :
216- input_observation = {"image_primary" : images , "timestep_pad_mask" : pad_mask }
217- input_observation = {
218- "observations" : input_observation ,
219- "tasks" : {"language_instruction" : self .task },
220- "rng" : np .concatenate ([self .rng , key ]),
221- }
222- norm_raw_actions = self .model .lc_ws2 (input_observation )[:, :, :7 ]
223- norm_raw_actions = norm_raw_actions [0 ] # remove batch, becoming (action_pred_horizon, action_dim)
224- assert norm_raw_actions .shape == (self .pred_action_horizon , 7 )
158+ input_observation = {"image_primary" : images , "timestep_pad_mask" : pad_mask }
159+ raw_actions = self .model .sample_actions (
160+ input_observation ,
161+ self .task ,
162+ rng = key ,
163+ unnormalization_statistics = self .model .dataset_statistics [self .dataset_id ]["action" ]
164+ )
165+ raw_actions = raw_actions [0 ] # remove batch, becoming (action_pred_horizon, action_dim)
225166
167+ assert raw_actions .shape == (self .pred_action_horizon , 7 )
226168 if self .action_ensemble :
227- norm_raw_actions = self .action_ensembler .ensemble_action (norm_raw_actions )
228- norm_raw_actions = norm_raw_actions [None ] # [1, 7]
169+ raw_actions = self .action_ensembler .ensemble_action (raw_actions )
170+ raw_actions = raw_actions [None ] # [1, 7]
229171
230- raw_actions = norm_raw_actions * self .action_std [None ] + self .action_mean [None ]
231172 raw_action = {
232173 "world_vector" : np .array (raw_actions [0 , :3 ]),
233174 "rotation_delta" : np .array (raw_actions [0 , 3 :6 ]),
0 commit comments