11from collections import defaultdict
2- from typing import Optional
2+ from typing import Optional , Sequence
33
44import matplotlib .pyplot as plt
55import numpy as np
1414class RT1Inference :
1515 def __init__ (
1616 self ,
17- saved_model_path = "rt_1_x_tf_trained_for_002272480_step" ,
18- lang_embed_model_path = "https://tfhub.dev/google/universal-sentence-encoder-large/5" ,
19- image_width = 320 ,
20- image_height = 256 ,
21- action_scale = 1.0 ,
22- policy_setup = "google_robot" ,
23- ):
17+ saved_model_path : str = "rt_1_x_tf_trained_for_002272480_step" ,
18+ lang_embed_model_path : str = "https://tfhub.dev/google/universal-sentence-encoder-large/5" ,
19+ image_width : int = 320 ,
20+ image_height : int = 256 ,
21+ action_scale : float = 1.0 ,
22+ policy_setup : str = "google_robot" ,
23+ ) -> None :
2424 self .lang_embed_model = hub .load (lang_embed_model_path )
2525 self .tfa_policy = py_tf_eager_policy .SavedModelPyTFEagerPolicy (
2626 model_path = saved_model_path ,
@@ -53,10 +53,10 @@ def __init__(
5353
5454 @staticmethod
5555 def _rescale_action_with_bound (
56- actions : np .ndarray ,
56+ actions : np .ndarray | tf . Tensor ,
5757 low : float ,
5858 high : float ,
59- safety_margin : float = 0 ,
59+ safety_margin : float = 0.0 ,
6060 post_scaling_max : float = 1.0 ,
6161 post_scaling_min : float = - 1.0 ,
6262 ) -> np .ndarray :
@@ -68,7 +68,7 @@ def _rescale_action_with_bound(
6868 post_scaling_max - safety_margin ,
6969 )
7070
71- def _unnormalize_action_widowx_bridge (self , action ) :
71+ def _unnormalize_action_widowx_bridge (self , action : dict [ str , np . ndarray | tf . Tensor ]) -> dict [ str , np . ndarray ] :
7272 action ["world_vector" ] = self ._rescale_action_with_bound (
7373 action ["world_vector" ],
7474 low = - 1.75 ,
@@ -85,7 +85,7 @@ def _unnormalize_action_widowx_bridge(self, action):
8585 )
8686 return action
8787
88- def _initialize_model (self ):
88+ def _initialize_model (self ) -> None :
8989 # Perform one step of inference using dummy input to trace the TensorFlow graph
9090 # Obtain a dummy observation, where the features are all 0
9191 self .observation = tf_agents .specs .zero_spec_nest (
@@ -98,25 +98,25 @@ def _initialize_model(self):
9898 # Run inference using the policy
9999 _action = self .tfa_policy .action (self .tfa_time_step , self .policy_state )
100100
101- def _resize_image (self , image ) :
101+ def _resize_image (self , image : np . ndarray | tf . Tensor ) -> tf . Tensor :
102102 image = tf .image .resize_with_pad (image , target_width = self .image_width , target_height = self .image_height )
103103 image = tf .cast (image , tf .uint8 )
104104 return image
105105
106- def _initialize_task_description (self , task_description ) :
106+ def _initialize_task_description (self , task_description : Optional [ str ] = None ) -> None :
107107 if task_description is not None :
108108 self .task_description = task_description
109109 self .task_description_embedding = self .lang_embed_model ([task_description ])[0 ]
110110 else :
111111 self .task_description = ""
112112 self .task_description_embedding = tf .zeros ((512 ,), dtype = tf .float32 )
113113
114- def reset (self , task_description ) :
114+ def reset (self , task_description : str ) -> None :
115115 self ._initialize_model ()
116116 self ._initialize_task_description (task_description )
117117
118118 @staticmethod
119- def _small_action_filter_google_robot (raw_action , arm_movement = False , gripper = True ):
119+ def _small_action_filter_google_robot (raw_action : dict [ str , np . ndarray | tf . Tensor ], arm_movement : bool = False , gripper : bool = True ) -> dict [ str , np . ndarray | tf . Tensor ] :
120120 # small action filtering for google robot
121121 if arm_movement :
122122 raw_action ["world_vector" ] = tf .where (
@@ -147,7 +147,7 @@ def _small_action_filter_google_robot(raw_action, arm_movement=False, gripper=Tr
147147 )
148148 return raw_action
149149
150- def step (self , image , task_description : Optional [str ] = None ):
150+ def step (self , image : np . ndarray , task_description : Optional [str ] = None ) -> tuple [ dict [ str , np . ndarray ], dict [ str , np . ndarray ]] :
151151 """
152152 Input:
153153 image: np.ndarray of shape (H, W, 3), uint8
@@ -179,6 +179,8 @@ def step(self, image, task_description: Optional[str] = None):
179179 raw_action = self ._small_action_filter_google_robot (raw_action , arm_movement = False , gripper = True )
180180 if self .unnormalize_action :
181181 raw_action = self .unnormalize_action_fxn (raw_action )
182+ for k in raw_action .keys ():
183+ raw_action [k ] = np .asarray (raw_action [k ])
182184
183185 # process raw_action to obtain the action to be sent to the maniskill2 environment
184186 action = {}
@@ -227,7 +229,7 @@ def step(self, image, task_description: Optional[str] = None):
227229
228230 return raw_action , action
229231
230- def visualize_epoch (self , predicted_raw_actions , images , save_path ) :
232+ def visualize_epoch (self , predicted_raw_actions : Sequence [ np . ndarray ] , images : Sequence [ np . ndarray ] , save_path : str ) -> None :
231233 images = [self ._resize_image (image ) for image in images ]
232234 predicted_action_name_to_values_over_time = defaultdict (list )
233235 figure_layout = [
0 commit comments