77from lerobot .datasets .lerobot_dataset import LeRobotDataset
88from lerobot .datasets .utils import build_dataset_frame
99from lerobot .policies .pretrained import PreTrainedPolicy
10+ from lerobot .processor import (
11+ PolicyAction ,
12+ PolicyProcessorPipeline ,
13+ )
1014from lerobot .robots .lekiwi .lekiwi_client import LeKiwiClient
1115from lerobot .teleoperators .keyboard import (
1216 KeyboardTeleop ,
@@ -30,6 +34,8 @@ def record_loop(
3034 keyboard_handler : KeyboardTeleop | None = None ,
3135 arm_keyboard_handler : ArmTeleop | None = None ,
3236 policy : PreTrainedPolicy | None = None ,
37+ preprocessor : PolicyProcessorPipeline [dict [str , Any ], dict [str , Any ]] | None = None ,
38+ postprocessor : PolicyProcessorPipeline [PolicyAction , PolicyAction ] | None = None ,
3339 control_time_s : int | None = None ,
3440 single_task : str | None = None ,
3541 display_data : bool = False ,
@@ -52,6 +58,11 @@ def record_loop(
5258 if control_time_s is None :
5359 raise ValueError ("A control time must be provided." )
5460
61+ if display_data :
62+ logging .info ("Visualizing data with Rerun." )
63+ else :
64+ logging .info ("Not visualizing data." )
65+
5566 # if policy is given it needs cleaning up
5667 if policy is not None :
5768 policy .reset ()
@@ -72,15 +83,19 @@ def record_loop(
7283 raise ValueError ("Dataset features must be defined if using a dataset or a policy." )
7384 observation_frame = build_dataset_frame (dataset .features , observation , prefix = "observation" ) # type: ignore[union-attr]
7485
75- if policy is not None :
86+ if policy is not None and preprocessor is not None and postprocessor is not None :
7687 action_values = predict_action (
77- observation_frame ,
78- policy ,
79- get_safe_torch_device (policy .config .device ),
80- policy .config .use_amp ,
88+ observation = observation_frame ,
89+ policy = policy ,
90+ device = get_safe_torch_device (policy .config .device ),
91+ preprocessor = preprocessor ,
92+ postprocessor = postprocessor ,
93+ use_amp = policy .config .use_amp ,
8194 task = single_task ,
8295 robot_type = robot .robot_type ,
8396 )
97+ if action_values .dim () > 1 :
98+ action_values = action_values .squeeze (0 )
8499 action = {key : action_values [i ].item () for i , key in enumerate (robot .action_features )}
85100 print (f"Predicted action: { action } " )
86101 elif policy is None and keyboard_handler is not None and arm_keyboard_handler is not None :
@@ -89,6 +104,8 @@ def record_loop(
89104 arm_action = arm_keyboard_handler .from_keyboard_to_arm_action (pressed_keys )
90105
91106 action = {** base_action , ** arm_action } # Merge base and arm actions
107+ # TODO(francocipollone): We would probably want to use the teleop_action_processor here.
108+ # action = teleop_action_processor((action, observation))
92109 logging .debug ("Sending action: %s" , action )
93110
94111 else :
@@ -99,14 +116,13 @@ def record_loop(
99116 )
100117 continue
101118
102- # Action can eventually be clipped using `max_relative_target`,
103- # so action actually sent is saved in the dataset.
119+ # TODO(francocipollone): We would probably want to use the robot_action_processor here before sending the action
104120 sent_action = robot .send_action (action )
105121
106122 if dataset is not None :
107123 action_frame = build_dataset_frame (dataset .features , sent_action , prefix = "action" )
108- frame = {** observation_frame , ** action_frame }
109- dataset .add_frame (frame , task = single_task )
124+ frame = {** observation_frame , ** action_frame , "task" : single_task }
125+ dataset .add_frame (frame )
110126
111127 if display_data :
112128 log_rerun_data (observation , action )
0 commit comments