Skip to content

Commit f4596f8

Browse files
Fix replay script.
Signed-off-by: Franco Cipollone <franco.c@ekumenlabs.com>
1 parent 24862c6 commit f4596f8

File tree

8 files changed

+77
-37
lines changed

8 files changed

+77
-37
lines changed

dora/node_hub/dora_run_policy/dora_run_policy/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pyarrow as pa
1414
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
1515
from lerobot.policies.act.modeling_act import ACTPolicy
16+
from lerobot.policies.factory import make_pre_post_processors
1617
from lerobot.utils.control_utils import predict_action
1718
from lerobot.utils.utils import get_safe_torch_device
1819

@@ -114,6 +115,11 @@ def main() -> None:
114115
policy.reset()
115116
device_name = policy.config.device or "auto"
116117
device = get_safe_torch_device(device_name)
118+
# Build Policy Processors
119+
preprocessor, postprocessor = make_pre_post_processors(
120+
policy_cfg=policy.config,
121+
pretrained_path=model_name,
122+
)
117123
except Exception as e:
118124
raise RuntimeError(f"Failed to load policy '{model_name}': {e}") from None
119125

@@ -171,9 +177,16 @@ def main() -> None:
171177
observation_frame,
172178
policy,
173179
device,
180+
preprocessor,
181+
postprocessor,
174182
policy.config.use_amp,
175183
)
176184

185+
# As of LeRobot 0.4.0, the postprocessor returns a tensor that might have batch dimension
186+
# Remove batch dimension if present
187+
if raw_action.dim() > 1:
188+
raw_action = raw_action.squeeze(0)
189+
177190
# Send action output
178191
metadata = event["metadata"].copy()
179192
metadata["primitive"] = "series"

dora/node_hub/dora_run_policy/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ requires-python = ">=3.10"
99

1010
dependencies = [
1111
"dora-rs >= 0.3.13",
12-
"lerobot==0.3.3"
12+
"lerobot==0.4.0"
1313
]
1414

1515
[build-system]

packages/lekiwi_lerobot/lekiwi_lerobot/evaluate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def main() -> None:
139139
robot=robot,
140140
events=events,
141141
fps=FPS,
142+
dataset=None, # Don't record during reset phase
142143
keyboard_handler=keyboard,
143144
arm_keyboard_handler=arm_keyboard_handler,
144145
control_time_s=RESET_TIME_SEC,

packages/lekiwi_lerobot/lekiwi_lerobot/record.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from lekiwi_teleoperate.teleoperate.arm import ArmTeleop
66
from lerobot.datasets.lerobot_dataset import LeRobotDataset
77
from lerobot.datasets.utils import hw_to_dataset_features
8-
from lerobot.processor import make_default_processors
98
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
109
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
1110
from lerobot.teleoperators.keyboard import (
1211
KeyboardTeleop,
1312
KeyboardTeleopConfig,
1413
)
14+
from lerobot.utils.constants import ACTION, OBS_STR
1515
from lerobot.utils.control_utils import (
1616
init_keyboard_listener,
1717
)
@@ -53,6 +53,12 @@ def main() -> None:
5353
default="Unnamed task",
5454
help="Task description to associate with each episode (default: 'Unnamed task').",
5555
)
56+
parser.add_argument(
57+
"--no-viz",
58+
action="store_false",
59+
dest="visualize",
60+
help="Disable Rerun visualization during recording.",
61+
)
5662

5763
args = parser.parse_args()
5864
if args.repo_id is None:
@@ -74,8 +80,8 @@ def main() -> None:
7480
keyboard = KeyboardTeleop(keyboard_config)
7581
arm_keyboard_handler = ArmTeleop()
7682
# Configure the dataset features
77-
action_features = hw_to_dataset_features(robot.action_features, "action")
78-
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
83+
action_features = hw_to_dataset_features(robot.action_features, ACTION)
84+
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
7985
logging.info(f"Recording the following observation features: {list(obs_features.keys())}")
8086
logging.info(f"Recording the following action features: {list(action_features.keys())}")
8187
dataset_features = {**action_features, **obs_features}
@@ -87,7 +93,7 @@ def main() -> None:
8793
features=dataset_features,
8894
robot_type=robot.name,
8995
use_videos=True,
90-
image_writer_threads=0,
96+
image_writer_threads=4,
9197
)
9298

9399
# To connect you already should have:
@@ -97,8 +103,11 @@ def main() -> None:
97103
robot.connect()
98104
keyboard.connect()
99105

100-
init_rerun(session_name="lekiwi_record")
101-
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
106+
if args.visualize:
107+
logging.info("Initializing Rerun for visualization.")
108+
init_rerun(session_name="lekiwi_record")
109+
else:
110+
logging.info("Rerun visualization is disabled.")
102111

103112
listener, events = init_keyboard_listener()
104113

@@ -119,21 +128,22 @@ def main() -> None:
119128
arm_keyboard_handler=arm_keyboard_handler,
120129
control_time_s=EPISODE_TIME_SEC,
121130
single_task=args.task,
122-
display_data=True,
131+
display_data=args.visualize,
123132
)
124133

125-
# Logic for reset env
134+
# Reset the environment if not stopping or re-recording
126135
if not events["stop_recording"] and ((recorded_episodes < args.episodes - 1) or events["rerecord_episode"]):
127136
logging.info("Reset the environment")
128137
record_loop(
129138
robot=robot,
130139
events=events,
131140
fps=FPS,
141+
dataset=None, # Don't record during reset phase
132142
keyboard_handler=keyboard,
133143
arm_keyboard_handler=arm_keyboard_handler,
134144
control_time_s=RESET_TIME_SEC,
135145
single_task=args.task,
136-
display_data=True,
146+
display_data=args.visualize,
137147
)
138148

139149
if events["rerecord_episode"]:
@@ -143,10 +153,12 @@ def main() -> None:
143153
dataset.clear_episode_buffer()
144154
continue
145155

156+
logging.info(f"Saving episode number {recorded_episodes} to dataset.")
146157
dataset.save_episode()
147158
recorded_episodes += 1
148159

149160
# Upload to hub and clean up
161+
dataset.finalize()
150162
dataset.push_to_hub()
151163

152164
robot.disconnect()

packages/lekiwi_lerobot/lekiwi_lerobot/replay.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from lerobot.datasets.lerobot_dataset import LeRobotDataset
66
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
77
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
8+
from lerobot.utils.constants import ACTION
89
from lerobot.utils.robot_utils import busy_wait
910

1011

@@ -84,17 +85,22 @@ def main() -> None:
8485
logging.info(f"Downloading dataset from {args.repo_id} into {args.directory}")
8586
root = args.directory + "/" + args.repo_id.split("/")[-1]
8687
root = root.replace("//", "/")
87-
dataset = LeRobotDataset(args.repo_id, root=root, episodes=[args.episode])
88+
episode_to_replay = args.episode
89+
dataset = LeRobotDataset(args.repo_id, root=root, episodes=[episode_to_replay])
8890
logging.info(f"Dataset stored at {root}")
8991
actions = dataset.hf_dataset.select_columns("action")
92+
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
93+
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == episode_to_replay)
94+
actions = episode_frames.select_columns(ACTION)
9095

9196
robot.connect()
9297

9398
if not robot.is_connected:
9499
raise ValueError("Robot is not connected!")
95100

96-
logging.info(f"Replaying episode {args.episode} with {dataset.num_frames} frames.")
97-
for idx in range(dataset.num_frames):
101+
len_episodes_frames = len(episode_frames)
102+
logging.info(f"Replaying episode {args.episode} with {len_episodes_frames} frames.")
103+
for idx in range(len_episodes_frames):
98104
t0 = time.perf_counter()
99105

100106
action = {name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])}

packages/lekiwi_lerobot/lekiwi_lerobot/run_policy.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from lerobot.policies.factory import make_pre_post_processors
88
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
99
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
10+
from lerobot.utils.constants import ACTION, OBS_STR
1011
from lerobot.utils.control_utils import init_keyboard_listener, predict_action
1112
from lerobot.utils.robot_utils import busy_wait
1213
from lerobot.utils.utils import get_safe_torch_device
@@ -83,14 +84,14 @@ def main() -> None:
8384
logging.info("Robot is connected.")
8485

8586
# Prepare for policy inference
86-
action_features = hw_to_dataset_features(robot.action_features, "action")
87-
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
87+
action_features = hw_to_dataset_features(robot.action_features, ACTION)
88+
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
8889
dataset_features = {**action_features, **obs_features}
8990
device = get_safe_torch_device(policy.config.device)
9091
# Build Policy Processors
9192
preprocessor, postprocessor = make_pre_post_processors(
9293
policy_cfg=policy.config,
93-
pretrained_path=None,
94+
pretrained_path=args.policy,
9495
)
9596
init_rerun(session_name="lekiwi_run_policy")
9697

@@ -102,7 +103,7 @@ def main() -> None:
102103

103104
observation = robot.get_observation()
104105

105-
observation_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
106+
observation_frame = build_dataset_frame(dataset_features, observation, prefix=OBS_STR)
106107

107108
action_values = predict_action(
108109
observation_frame,
@@ -114,6 +115,11 @@ def main() -> None:
114115
task=args.task,
115116
robot_type=robot.robot_type,
116117
)
118+
# As of LeRobot 0.4.0, the postprocessor returns a tensor that might have batch dimension
119+
# Remove batch dimension if present
120+
if action_values.dim() > 1:
121+
action_values = action_values.squeeze(0)
122+
117123
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
118124

119125
logging.debug(f"Predicted action: {action}")

packages/lekiwi_lerobot/lekiwi_lerobot/utils.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
from lerobot.datasets.utils import build_dataset_frame
99
from lerobot.policies.pretrained import PreTrainedPolicy
1010
from lerobot.processor import (
11-
RobotAction,
12-
RobotObservation,
13-
RobotProcessorPipeline,
11+
PolicyAction,
12+
PolicyProcessorPipeline,
1413
)
1514
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
1615
from lerobot.teleoperators.keyboard import (
@@ -31,17 +30,12 @@ def record_loop(
3130
robot: LeKiwiClient,
3231
events: dict[Any, Any],
3332
fps: int,
34-
teleop_action_processor: RobotProcessorPipeline[
35-
tuple[RobotAction, RobotObservation], RobotAction
36-
], # runs after teleop
37-
robot_action_processor: RobotProcessorPipeline[
38-
tuple[RobotAction, RobotObservation], RobotAction
39-
], # runs before robot
40-
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation], # runs after robot
4133
dataset: LeRobotDataset | None = None,
4234
keyboard_handler: KeyboardTeleop | None = None,
4335
arm_keyboard_handler: ArmTeleop | None = None,
4436
policy: PreTrainedPolicy | None = None,
37+
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None,
38+
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None,
4539
control_time_s: int | None = None,
4640
single_task: str | None = None,
4741
display_data: bool = False,
@@ -64,6 +58,11 @@ def record_loop(
6458
if control_time_s is None:
6559
raise ValueError("A control time must be provided.")
6660

61+
if display_data:
62+
logging.info("Visualizing data with Rerun.")
63+
else:
64+
logging.info("Not visualizing data.")
65+
6766
# if policy is given it needs cleaning up
6867
if policy is not None:
6968
policy.reset()
@@ -84,17 +83,19 @@ def record_loop(
8483
raise ValueError("Dataset features must be defined if using a dataset or a policy.")
8584
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") # type: ignore[union-attr]
8685

87-
if policy is not None:
86+
if policy is not None and preprocessor is not None and postprocessor is not None:
8887
action_values = predict_action(
89-
observation_frame,
90-
policy,
91-
get_safe_torch_device(policy.config.device),
92-
policy.config.use_amp,
88+
observation=observation_frame,
89+
policy=policy,
90+
device=get_safe_torch_device(policy.config.device),
9391
preprocessor=preprocessor,
9492
postprocessor=postprocessor,
93+
use_amp=policy.config.use_amp,
9594
task=single_task,
9695
robot_type=robot.robot_type,
9796
)
97+
if action_values.dim() > 1:
98+
action_values = action_values.squeeze(0)
9899
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
99100
print(f"Predicted action: {action}")
100101
elif policy is None and keyboard_handler is not None and arm_keyboard_handler is not None:
@@ -103,6 +104,8 @@ def record_loop(
103104
arm_action = arm_keyboard_handler.from_keyboard_to_arm_action(pressed_keys)
104105

105106
action = {**base_action, **arm_action} # Merge base and arm actions
107+
# TODO(francocipollone): We would probably want to use the teleop_action_processor here.
108+
# action = teleop_action_processor((action, observation))
106109
logging.debug("Sending action: %s", action)
107110

108111
else:
@@ -113,14 +116,13 @@ def record_loop(
113116
)
114117
continue
115118

116-
# Action can eventually be clipped using `max_relative_target`,
117-
# so action actually sent is saved in the dataset.
119+
# TODO(francocipollone): We would probably want to use the robot_action_processor here before sending the action
118120
sent_action = robot.send_action(action)
119121

120122
if dataset is not None:
121123
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
122-
frame = {**observation_frame, **action_frame}
123-
dataset.add_frame(frame, task=single_task)
124+
frame = {**observation_frame, **action_frame, "task": single_task}
125+
dataset.add_frame(frame)
124126

125127
if display_data:
126128
log_rerun_data(observation, action)

packages/lekiwi_sim/lekiwi_sim/lekiwi_sim_host.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def main() -> None:
106106
watchdog_active = False
107107
except zmq.Again:
108108
if not watchdog_active:
109-
logging.warning("No command available")
109+
logging.debug("No command available")
110110
except Exception as e:
111111
logging.error("Message fetching failed: %s", e)
112112

0 commit comments

Comments
 (0)