88from lerobot .datasets .utils import build_dataset_frame
99from lerobot .policies .pretrained import PreTrainedPolicy
1010from lerobot .processor import (
11- RobotAction ,
12- RobotObservation ,
13- RobotProcessorPipeline ,
11+ PolicyAction ,
12+ PolicyProcessorPipeline ,
1413)
1514from lerobot .robots .lekiwi .lekiwi_client import LeKiwiClient
1615from lerobot .teleoperators .keyboard import (
@@ -31,17 +30,12 @@ def record_loop(
3130 robot : LeKiwiClient ,
3231 events : dict [Any , Any ],
3332 fps : int ,
34- teleop_action_processor : RobotProcessorPipeline [
35- tuple [RobotAction , RobotObservation ], RobotAction
36- ], # runs after teleop
37- robot_action_processor : RobotProcessorPipeline [
38- tuple [RobotAction , RobotObservation ], RobotAction
39- ], # runs before robot
40- robot_observation_processor : RobotProcessorPipeline [RobotObservation , RobotObservation ], # runs after robot
4133 dataset : LeRobotDataset | None = None ,
4234 keyboard_handler : KeyboardTeleop | None = None ,
4335 arm_keyboard_handler : ArmTeleop | None = None ,
4436 policy : PreTrainedPolicy | None = None ,
37+ preprocessor : PolicyProcessorPipeline [dict [str , Any ], dict [str , Any ]] | None = None ,
38+ postprocessor : PolicyProcessorPipeline [PolicyAction , PolicyAction ] | None = None ,
4539 control_time_s : int | None = None ,
4640 single_task : str | None = None ,
4741 display_data : bool = False ,
@@ -64,6 +58,11 @@ def record_loop(
6458 if control_time_s is None :
6559 raise ValueError ("A control time must be provided." )
6660
61+ if display_data :
62+ logging .info ("Visualizing data with Rerun." )
63+ else :
64+ logging .info ("Not visualizing data." )
65+
6766 # if policy is given it needs cleaning up
6867 if policy is not None :
6968 policy .reset ()
@@ -84,17 +83,19 @@ def record_loop(
8483 raise ValueError ("Dataset features must be defined if using a dataset or a policy." )
8584 observation_frame = build_dataset_frame (dataset .features , observation , prefix = "observation" ) # type: ignore[union-attr]
8685
87- if policy is not None :
86+ if policy is not None and preprocessor is not None and postprocessor is not None :
8887 action_values = predict_action (
89- observation_frame ,
90- policy ,
91- get_safe_torch_device (policy .config .device ),
92- policy .config .use_amp ,
88+ observation = observation_frame ,
89+ policy = policy ,
90+ device = get_safe_torch_device (policy .config .device ),
9391 preprocessor = preprocessor ,
9492 postprocessor = postprocessor ,
93+ use_amp = policy .config .use_amp ,
9594 task = single_task ,
9695 robot_type = robot .robot_type ,
9796 )
97+ if action_values .dim () > 1 :
98+ action_values = action_values .squeeze (0 )
9899 action = {key : action_values [i ].item () for i , key in enumerate (robot .action_features )}
99100 print (f"Predicted action: { action } " )
100101 elif policy is None and keyboard_handler is not None and arm_keyboard_handler is not None :
@@ -103,6 +104,8 @@ def record_loop(
103104 arm_action = arm_keyboard_handler .from_keyboard_to_arm_action (pressed_keys )
104105
105106 action = {** base_action , ** arm_action } # Merge base and arm actions
107+ # TODO(francocipollone): We would probably want to use the teleop_action_processor here.
108+ # action = teleop_action_processor((action, observation))
106109 logging .debug ("Sending action: %s" , action )
107110
108111 else :
@@ -113,14 +116,13 @@ def record_loop(
113116 )
114117 continue
115118
116- # Action can eventually be clipped using `max_relative_target`,
117- # so action actually sent is saved in the dataset.
119+ # TODO(francocipollone): We would probably want to use the robot_action_processor here before sending the action
118120 sent_action = robot .send_action (action )
119121
120122 if dataset is not None :
121123 action_frame = build_dataset_frame (dataset .features , sent_action , prefix = "action" )
122- frame = {** observation_frame , ** action_frame }
123- dataset .add_frame (frame , task = single_task )
124+ frame = {** observation_frame , ** action_frame , "task" : single_task }
125+ dataset .add_frame (frame )
124126
125127 if display_data :
126128 log_rerun_data (observation , action )
0 commit comments