Skip to content

Commit 62f9e8d

Browse files
committed
add policy typehints
1 parent 430b004 commit 62f9e8d

File tree

3 files changed

+52
-49
lines changed

3 files changed

+52
-49
lines changed

simpler_env/policies/octo/octo_model.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import deque
2-
from typing import Optional
2+
from typing import Optional, Sequence
33
import os
44

55
import jax
@@ -16,15 +16,15 @@
1616
class OctoInference:
1717
def __init__(
1818
self,
19-
model_type="octo-base",
20-
policy_setup="widowx_bridge",
21-
horizon=2,
22-
pred_action_horizon=4,
23-
exec_horizon=1,
24-
image_size=256,
25-
action_scale=1.0,
26-
init_rng=0,
27-
):
19+
model_type: str = "octo-base",
20+
policy_setup: str = "widowx_bridge",
21+
horizon: int = 2,
22+
pred_action_horizon: int = 4,
23+
exec_horizon: int = 1,
24+
image_size: int = 256,
25+
action_scale: float = 1.0,
26+
init_rng: int = 0,
27+
) -> None:
2828
os.environ["TOKENIZERS_PARALLELISM"] = "false"
2929
if policy_setup == "widowx_bridge":
3030
dataset_id = "bridge_dataset"
@@ -118,7 +118,7 @@ def __init__(
118118
self.action_ensemble_temp = action_ensemble_temp
119119
self.rng = jax.random.PRNGKey(init_rng)
120120
for _ in range(5):
121-
# to match octo server's inference seeds
121+
# the purpose of this for loop is just to match octo server's inference seeds
122122
self.rng, _key = jax.random.split(self.rng) # each shape [2,]
123123

124124
self.sticky_action_is_on = False
@@ -136,7 +136,7 @@ def __init__(
136136
self.action_ensembler = None
137137
self.num_image_history = 0
138138

139-
def _resize_image(self, image):
139+
def _resize_image(self, image: np.ndarray) -> np.ndarray:
140140
image = tf.image.resize(
141141
image,
142142
size=(self.image_size, self.image_size),
@@ -146,24 +146,25 @@ def _resize_image(self, image):
146146
image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8).numpy()
147147
return image
148148

149-
def _add_image_to_history(self, image):
149+
def _add_image_to_history(self, image: np.ndarray) -> None:
150150
self.image_history.append(image)
151+
# Alternative implementation below; for real eval, however, filling the entire buffer at the first step does not appear to be necessary
151152
# if self.num_image_history == 0:
152153
# self.image_history.extend([image] * self.horizon)
153154
# else:
154155
# self.image_history.append(image)
155156
self.num_image_history = min(self.num_image_history + 1, self.horizon)
156157

157-
def _obtain_image_history_and_mask(self):
158+
def _obtain_image_history_and_mask(self) -> tuple[np.ndarray, np.ndarray]:
158159
images = np.stack(self.image_history, axis=0)
159160
horizon = len(self.image_history)
160-
pad_mask = np.ones(horizon, dtype=np.float64) # note: this is not np.bool
161+
pad_mask = np.ones(horizon, dtype=np.float64) # note: this should be of float type, not a bool type
161162
pad_mask[: horizon - min(horizon, self.num_image_history)] = 0
162-
# pad_mask = np.ones(self.horizon, dtype=np.float64) # note: this is not np.bool
163+
# pad_mask = np.ones(self.horizon, dtype=np.float64) # note: this should be of float type, not a bool type
163164
# pad_mask[:self.horizon - self.num_image_history] = 0
164165
return images, pad_mask
165166

166-
def reset(self, task_description):
167+
def reset(self, task_description: str) -> None:
167168
if self.automatic_task_creation:
168169
self.task = self.model.create_tasks(texts=[task_description])
169170
else:
@@ -180,7 +181,7 @@ def reset(self, task_description):
180181
# self.gripper_is_closed = False
181182
self.previous_gripper_action = None
182183

183-
def step(self, image, task_description: Optional[str] = None, *args, **kwargs):
184+
def step(self, image: np.ndarray, task_description: Optional[str] = None, *args, **kwargs) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
184185
"""
185186
Input:
186187
image: np.ndarray of shape (H, W, 3), uint8
@@ -299,7 +300,7 @@ def step(self, image, task_description: Optional[str] = None, *args, **kwargs):
299300

300301
return raw_action, action
301302

302-
def visualize_epoch(self, predicted_raw_actions, images, save_path):
303+
def visualize_epoch(self, predicted_raw_actions: Sequence[np.ndarray], images: Sequence[np.ndarray], save_path: str) -> None:
303304
images = [self._resize_image(image) for image in images]
304305
ACTION_DIM_LABELS = ["x", "y", "z", "roll", "pitch", "yaw", "grasp"]
305306

simpler_env/policies/octo/octo_server_model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from base64 import b64decode, b64encode
2-
from typing import Optional
2+
from typing import Optional, Sequence, Any
33
import json
44
import time
55
import urllib
@@ -71,11 +71,11 @@ def patch():
7171
class OctoServerInference:
7272
def __init__(
7373
self,
74-
model_type="octo-base",
75-
policy_setup="widowx_bridge",
76-
image_size=256,
77-
action_scale=1.0,
78-
):
74+
model_type: str = "octo-base",
75+
policy_setup: str = "widowx_bridge",
76+
image_size: str = 256,
77+
action_scale: float = 1.0,
78+
) -> None:
7979
if policy_setup == "widowx_bridge":
8080
self.sticky_gripper_num_repeat = 1
8181
self.dataset_name = "bridge_dataset"
@@ -97,7 +97,7 @@ def __init__(
9797
self.action_scale = action_scale
9898
self.task = None
9999

100-
def _resize_image(self, image):
100+
def _resize_image(self, image: np.ndarray) -> np.ndarray:
101101
image = tf.image.resize(
102102
image,
103103
size=(self.image_size, self.image_size),
@@ -107,7 +107,7 @@ def _resize_image(self, image):
107107
image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8).numpy()
108108
return image
109109

110-
def reset(self, task_description):
110+
def reset(self, task_description: str) -> None:
111111
self.task = task_description
112112
self.sticky_action_is_on = False
113113
self.gripper_action_repeat = 0
@@ -120,7 +120,7 @@ def reset(self, task_description):
120120
)
121121
time.sleep(1.0)
122122

123-
def _get_fake_pay_load(self, image_primary, text, modality="l"):
123+
def _get_fake_pay_load(self, image_primary: np.ndarray, text: str, modality: str = "l") -> dict:
124124
payload = {
125125
"dataset_name": self.dataset_name,
126126
"observation": {
@@ -133,7 +133,7 @@ def _get_fake_pay_load(self, image_primary, text, modality="l"):
133133
fake_pay_load = {"use_this": dumps(payload)}
134134
return fake_pay_load
135135

136-
def _query_for_action(self, image_primary, text, goal, modality="l"):
136+
def _query_for_action(self, image_primary: np.ndarray, text: str, goal: Optional[Any], modality="l") -> list:
137137
del goal
138138
# _ = requests.post(urllib.parse.urljoin("http://ari.bair.berkeley.edu:8000", "reset"),)
139139
fake_pay_load = self._get_fake_pay_load(image_primary, text, modality)
@@ -145,7 +145,7 @@ def _query_for_action(self, image_primary, text, goal, modality="l"):
145145
# print(reply)
146146
return loads(reply)
147147

148-
def step(self, image, task_description: Optional[str] = None, *args, **kwargs):
148+
def step(self, image: np.ndarray, task_description: Optional[str] = None, *args, **kwargs) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
149149
"""
150150
Input:
151151
image: np.ndarray of shape (H, W, 3), uint8
@@ -239,7 +239,7 @@ def step(self, image, task_description: Optional[str] = None, *args, **kwargs):
239239

240240
return raw_action, action
241241

242-
def visualize_epoch(self, predicted_raw_actions, images, save_path):
242+
def visualize_epoch(self, predicted_raw_actions: Sequence[np.ndarray], images: Sequence[np.ndarray], save_path: str):
243243
images = [self._resize_image(image) for image in images]
244244
ACTION_DIM_LABELS = ["x", "y", "z", "yaw", "pitch", "roll", "grasp"]
245245

simpler_env/policies/rt1/rt1_model.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import defaultdict
2-
from typing import Optional
2+
from typing import Optional, Sequence
33

44
import matplotlib.pyplot as plt
55
import numpy as np
@@ -14,13 +14,13 @@
1414
class RT1Inference:
1515
def __init__(
1616
self,
17-
saved_model_path="rt_1_x_tf_trained_for_002272480_step",
18-
lang_embed_model_path="https://tfhub.dev/google/universal-sentence-encoder-large/5",
19-
image_width=320,
20-
image_height=256,
21-
action_scale=1.0,
22-
policy_setup="google_robot",
23-
):
17+
saved_model_path: str = "rt_1_x_tf_trained_for_002272480_step",
18+
lang_embed_model_path: str = "https://tfhub.dev/google/universal-sentence-encoder-large/5",
19+
image_width: int = 320,
20+
image_height: int = 256,
21+
action_scale: float = 1.0,
22+
policy_setup: str = "google_robot",
23+
) -> None:
2424
self.lang_embed_model = hub.load(lang_embed_model_path)
2525
self.tfa_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
2626
model_path=saved_model_path,
@@ -53,10 +53,10 @@ def __init__(
5353

5454
@staticmethod
5555
def _rescale_action_with_bound(
56-
actions: np.ndarray,
56+
actions: np.ndarray | tf.Tensor,
5757
low: float,
5858
high: float,
59-
safety_margin: float = 0,
59+
safety_margin: float = 0.0,
6060
post_scaling_max: float = 1.0,
6161
post_scaling_min: float = -1.0,
6262
) -> np.ndarray:
@@ -68,7 +68,7 @@ def _rescale_action_with_bound(
6868
post_scaling_max - safety_margin,
6969
)
7070

71-
def _unnormalize_action_widowx_bridge(self, action):
71+
def _unnormalize_action_widowx_bridge(self, action: dict[str, np.ndarray | tf.Tensor]) -> dict[str, np.ndarray]:
7272
action["world_vector"] = self._rescale_action_with_bound(
7373
action["world_vector"],
7474
low=-1.75,
@@ -85,7 +85,7 @@ def _unnormalize_action_widowx_bridge(self, action):
8585
)
8686
return action
8787

88-
def _initialize_model(self):
88+
def _initialize_model(self) -> None:
8989
# Perform one step of inference using dummy input to trace the TensorFlow graph
9090
# Obtain a dummy observation, where the features are all 0
9191
self.observation = tf_agents.specs.zero_spec_nest(
@@ -98,25 +98,25 @@ def _initialize_model(self):
9898
# Run inference using the policy
9999
_action = self.tfa_policy.action(self.tfa_time_step, self.policy_state)
100100

101-
def _resize_image(self, image):
101+
def _resize_image(self, image: np.ndarray | tf.Tensor) -> tf.Tensor:
102102
image = tf.image.resize_with_pad(image, target_width=self.image_width, target_height=self.image_height)
103103
image = tf.cast(image, tf.uint8)
104104
return image
105105

106-
def _initialize_task_description(self, task_description):
106+
def _initialize_task_description(self, task_description: Optional[str] = None) -> None:
107107
if task_description is not None:
108108
self.task_description = task_description
109109
self.task_description_embedding = self.lang_embed_model([task_description])[0]
110110
else:
111111
self.task_description = ""
112112
self.task_description_embedding = tf.zeros((512,), dtype=tf.float32)
113113

114-
def reset(self, task_description):
114+
def reset(self, task_description: str) -> None:
115115
self._initialize_model()
116116
self._initialize_task_description(task_description)
117117

118118
@staticmethod
119-
def _small_action_filter_google_robot(raw_action, arm_movement=False, gripper=True):
119+
def _small_action_filter_google_robot(raw_action: dict[str, np.ndarray | tf.Tensor], arm_movement: bool = False, gripper: bool = True) -> dict[str, np.ndarray | tf.Tensor]:
120120
# small action filtering for google robot
121121
if arm_movement:
122122
raw_action["world_vector"] = tf.where(
@@ -147,7 +147,7 @@ def _small_action_filter_google_robot(raw_action, arm_movement=False, gripper=Tr
147147
)
148148
return raw_action
149149

150-
def step(self, image, task_description: Optional[str] = None):
150+
def step(self, image: np.ndarray, task_description: Optional[str] = None) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
151151
"""
152152
Input:
153153
image: np.ndarray of shape (H, W, 3), uint8
@@ -179,6 +179,8 @@ def step(self, image, task_description: Optional[str] = None):
179179
raw_action = self._small_action_filter_google_robot(raw_action, arm_movement=False, gripper=True)
180180
if self.unnormalize_action:
181181
raw_action = self.unnormalize_action_fxn(raw_action)
182+
for k in raw_action.keys():
183+
raw_action[k] = np.asarray(raw_action[k])
182184

183185
# process raw_action to obtain the action to be sent to the maniskill2 environment
184186
action = {}
@@ -227,7 +229,7 @@ def step(self, image, task_description: Optional[str] = None):
227229

228230
return raw_action, action
229231

230-
def visualize_epoch(self, predicted_raw_actions, images, save_path):
232+
def visualize_epoch(self, predicted_raw_actions: Sequence[np.ndarray], images: Sequence[np.ndarray], save_path: str) -> None:
231233
images = [self._resize_image(image) for image in images]
232234
predicted_action_name_to_values_over_time = defaultdict(list)
233235
figure_layout = [

0 commit comments

Comments
 (0)