Skip to content

Commit f3af6f2

Browse files
Bump lerobot to 0.4.0. (#36)
* Bump lerobot to 0.4.0. Signed-off-by: Franco Cipollone <franco.c@ekumenlabs.com> * Fix replay script. Signed-off-by: Franco Cipollone <franco.c@ekumenlabs.com> * Fixes rerun Signed-off-by: Franco Cipollone <franco.c@ekumenlabs.com> --------- Signed-off-by: Franco Cipollone <franco.c@ekumenlabs.com>
1 parent 1268722 commit f3af6f2

File tree

15 files changed

+175
-133
lines changed

15 files changed

+175
-133
lines changed

dora/lekiwi_sim/graphs/mujoco_sim.yml

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,21 @@ nodes:
4141
############################################################
4242
# Example of a rerun visualization node
4343
############################################################
44-
# TODO: Enable once lerobot is updated to 0.4.0 due to rerun version conflict.
45-
# - id: rerun-viz
46-
# # TODO: Use tag once there is a new release
47-
# build: cargo install --git https://github.com/dora-rs/dora-hub dora-rerun
48-
# path: dora-rerun
49-
# inputs:
50-
# front_camera:
51-
# source: dora_lekiwi_client/image_front
52-
# wrist_camera:
53-
# source: dora_lekiwi_client/image_wrist
54-
# observation_state:
55-
# # ["arm_shoulder_pan.pos", "arm_shoulder_lift.pos", "arm_elbow_flex.pos", "arm_wrist_flex.pos", "arm_wrist_roll.pos", "arm_gripper.pos", "x.vel", "y.vel", "theta.vel"]
56-
# source: dora_lekiwi_client/observation_state
57-
# actions:
58-
# source: dora_run_policy/actions
59-
# env:
60-
# OPERATING_MODE: SPAWN
44+
- id: rerun-viz
45+
build: cargo install --git https://github.com/dora-rs/dora-hub --tag 0.3.13 dora-rerun
46+
path: dora-rerun
47+
inputs:
48+
front_camera:
49+
source: dora_lekiwi_client/image_front
50+
wrist_camera:
51+
source: dora_lekiwi_client/image_wrist
52+
observation_state:
53+
# ["arm_shoulder_pan.pos", "arm_shoulder_lift.pos", "arm_elbow_flex.pos", "arm_wrist_flex.pos", "arm_wrist_roll.pos", "arm_gripper.pos", "x.vel", "y.vel", "theta.vel"]
54+
source: dora_lekiwi_client/observation_state
55+
actions:
56+
source: dora_run_policy/actions
57+
env:
58+
OPERATING_MODE: SPAWN
6159

6260
###########################################################
6361
# Example of a recorder node. The outputs are parquet files.

dora/node_hub/dora_run_policy/dora_run_policy/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pyarrow as pa
1414
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
1515
from lerobot.policies.act.modeling_act import ACTPolicy
16+
from lerobot.policies.factory import make_pre_post_processors
1617
from lerobot.utils.control_utils import predict_action
1718
from lerobot.utils.utils import get_safe_torch_device
1819

@@ -114,6 +115,11 @@ def main() -> None:
114115
policy.reset()
115116
device_name = policy.config.device or "auto"
116117
device = get_safe_torch_device(device_name)
118+
# Build Policy Processors
119+
preprocessor, postprocessor = make_pre_post_processors(
120+
policy_cfg=policy.config,
121+
pretrained_path=model_name,
122+
)
117123
except Exception as e:
118124
raise RuntimeError(f"Failed to load policy '{model_name}': {e}") from None
119125

@@ -171,9 +177,16 @@ def main() -> None:
171177
observation_frame,
172178
policy,
173179
device,
180+
preprocessor,
181+
postprocessor,
174182
policy.config.use_amp,
175183
)
176184

185+
# As of LeRobot 0.4.0, the postprocessor returns a tensor that might have batch dimension
186+
# Remove batch dimension if present
187+
if raw_action.dim() > 1:
188+
raw_action = raw_action.squeeze(0)
189+
177190
# Send action output
178191
metadata = event["metadata"].copy()
179192
metadata["primitive"] = "series"

dora/node_hub/dora_run_policy/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ requires-python = ">=3.10"
99

1010
dependencies = [
1111
"dora-rs >= 0.3.13",
12-
"lerobot==0.3.3"
12+
"lerobot==0.4.0"
1313
]
1414

1515
[build-system]

packages/lekiwi_lerobot/lekiwi_lerobot/evaluate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from lerobot.utils.control_utils import (
1616
init_keyboard_listener,
1717
)
18-
from lerobot.utils.visualization_utils import _init_rerun
18+
from lerobot.utils.visualization_utils import init_rerun
1919

2020
FPS = 30
2121
EPISODE_TIME_SEC = 180
@@ -108,7 +108,7 @@ def main() -> None:
108108
robot.connect()
109109
keyboard.connect()
110110

111-
_init_rerun(session_name="lekiwi_evaluate")
111+
init_rerun(session_name="lekiwi_evaluate")
112112

113113
listener, events = init_keyboard_listener()
114114

@@ -139,6 +139,7 @@ def main() -> None:
139139
robot=robot,
140140
events=events,
141141
fps=FPS,
142+
dataset=None, # Don't record during reset phase
142143
keyboard_handler=keyboard,
143144
arm_keyboard_handler=arm_keyboard_handler,
144145
control_time_s=RESET_TIME_SEC,

packages/lekiwi_lerobot/lekiwi_lerobot/record.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
KeyboardTeleop,
1212
KeyboardTeleopConfig,
1313
)
14+
from lerobot.utils.constants import ACTION, OBS_STR
1415
from lerobot.utils.control_utils import (
1516
init_keyboard_listener,
1617
)
17-
from lerobot.utils.visualization_utils import _init_rerun
18+
from lerobot.utils.visualization_utils import init_rerun
1819

1920
FPS = 30
2021
EPISODE_TIME_SEC = 120
@@ -52,6 +53,12 @@ def main() -> None:
5253
default="Unnamed task",
5354
help="Task description to associate with each episode (default: 'Unnamed task').",
5455
)
56+
parser.add_argument(
57+
"--no-viz",
58+
action="store_false",
59+
dest="visualize",
60+
help="Disable Rerun visualization during recording.",
61+
)
5562

5663
args = parser.parse_args()
5764
if args.repo_id is None:
@@ -73,8 +80,8 @@ def main() -> None:
7380
keyboard = KeyboardTeleop(keyboard_config)
7481
arm_keyboard_handler = ArmTeleop()
7582
# Configure the dataset features
76-
action_features = hw_to_dataset_features(robot.action_features, "action")
77-
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
83+
action_features = hw_to_dataset_features(robot.action_features, ACTION)
84+
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
7885
logging.info(f"Recording the following observation features: {list(obs_features.keys())}")
7986
logging.info(f"Recording the following action features: {list(action_features.keys())}")
8087
dataset_features = {**action_features, **obs_features}
@@ -86,7 +93,7 @@ def main() -> None:
8693
features=dataset_features,
8794
robot_type=robot.name,
8895
use_videos=True,
89-
image_writer_threads=0,
96+
image_writer_threads=4,
9097
)
9198

9299
# To connect you already should have:
@@ -96,7 +103,11 @@ def main() -> None:
96103
robot.connect()
97104
keyboard.connect()
98105

99-
_init_rerun(session_name="lekiwi_record")
106+
if args.visualize:
107+
logging.info("Initializing Rerun for visualization.")
108+
init_rerun(session_name="lekiwi_record")
109+
else:
110+
logging.info("Rerun visualization is disabled.")
100111

101112
listener, events = init_keyboard_listener()
102113

@@ -117,21 +128,22 @@ def main() -> None:
117128
arm_keyboard_handler=arm_keyboard_handler,
118129
control_time_s=EPISODE_TIME_SEC,
119130
single_task=args.task,
120-
display_data=True,
131+
display_data=args.visualize,
121132
)
122133

123-
# Logic for reset env
134+
# Reset the environment if not stopping or re-recording
124135
if not events["stop_recording"] and ((recorded_episodes < args.episodes - 1) or events["rerecord_episode"]):
125136
logging.info("Reset the environment")
126137
record_loop(
127138
robot=robot,
128139
events=events,
129140
fps=FPS,
141+
dataset=None, # Don't record during reset phase
130142
keyboard_handler=keyboard,
131143
arm_keyboard_handler=arm_keyboard_handler,
132144
control_time_s=RESET_TIME_SEC,
133145
single_task=args.task,
134-
display_data=True,
146+
display_data=args.visualize,
135147
)
136148

137149
if events["rerecord_episode"]:
@@ -141,10 +153,12 @@ def main() -> None:
141153
dataset.clear_episode_buffer()
142154
continue
143155

156+
logging.info(f"Saving episode number {recorded_episodes} to dataset.")
144157
dataset.save_episode()
145158
recorded_episodes += 1
146159

147160
# Upload to hub and clean up
161+
dataset.finalize()
148162
dataset.push_to_hub()
149163

150164
robot.disconnect()

packages/lekiwi_lerobot/lekiwi_lerobot/replay.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from lerobot.datasets.lerobot_dataset import LeRobotDataset
66
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
77
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
8+
from lerobot.utils.constants import ACTION
89
from lerobot.utils.robot_utils import busy_wait
910

1011

@@ -84,17 +85,22 @@ def main() -> None:
8485
logging.info(f"Downloading dataset from {args.repo_id} into {args.directory}")
8586
root = args.directory + "/" + args.repo_id.split("/")[-1]
8687
root = root.replace("//", "/")
87-
dataset = LeRobotDataset(args.repo_id, root=root, episodes=[args.episode])
88+
episode_to_replay = args.episode
89+
dataset = LeRobotDataset(args.repo_id, root=root, episodes=[episode_to_replay])
8890
logging.info(f"Dataset stored at {root}")
8991
actions = dataset.hf_dataset.select_columns("action")
92+
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
93+
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == episode_to_replay)
94+
actions = episode_frames.select_columns(ACTION)
9095

9196
robot.connect()
9297

9398
if not robot.is_connected:
9499
raise ValueError("Robot is not connected!")
95100

96-
logging.info(f"Replaying episode {args.episode} with {dataset.num_frames} frames.")
97-
for idx in range(dataset.num_frames):
101+
len_episodes_frames = len(episode_frames)
102+
logging.info(f"Replaying episode {args.episode} with {len_episodes_frames} frames.")
103+
for idx in range(len_episodes_frames):
98104
t0 = time.perf_counter()
99105

100106
action = {name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])}

packages/lekiwi_lerobot/lekiwi_lerobot/run_policy.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
66
from lerobot.policies.act.modeling_act import ACTPolicy
7+
from lerobot.policies.factory import make_pre_post_processors
78
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
89
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
10+
from lerobot.utils.constants import ACTION, OBS_STR
911
from lerobot.utils.control_utils import init_keyboard_listener, predict_action
1012
from lerobot.utils.robot_utils import busy_wait
1113
from lerobot.utils.utils import get_safe_torch_device
12-
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
14+
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
1315

1416
FPS = 30
1517

@@ -82,12 +84,16 @@ def main() -> None:
8284
logging.info("Robot is connected.")
8385

8486
# Prepare for policy inference
85-
action_features = hw_to_dataset_features(robot.action_features, "action")
86-
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
87+
action_features = hw_to_dataset_features(robot.action_features, ACTION)
88+
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
8789
dataset_features = {**action_features, **obs_features}
8890
device = get_safe_torch_device(policy.config.device)
89-
90-
_init_rerun(session_name="lekiwi_run_policy")
91+
# Build Policy Processors
92+
preprocessor, postprocessor = make_pre_post_processors(
93+
policy_cfg=policy.config,
94+
pretrained_path=args.policy,
95+
)
96+
init_rerun(session_name="lekiwi_run_policy")
9197

9298
listener, events = init_keyboard_listener()
9399

@@ -97,16 +103,23 @@ def main() -> None:
97103

98104
observation = robot.get_observation()
99105

100-
observation_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
106+
observation_frame = build_dataset_frame(dataset_features, observation, prefix=OBS_STR)
101107

102108
action_values = predict_action(
103109
observation_frame,
104110
policy,
105111
device,
112+
preprocessor,
113+
postprocessor,
106114
policy.config.use_amp,
107115
task=args.task,
108116
robot_type=robot.robot_type,
109117
)
118+
# As of LeRobot 0.4.0, the postprocessor returns a tensor that might have batch dimension
119+
# Remove batch dimension if present
120+
if action_values.dim() > 1:
121+
action_values = action_values.squeeze(0)
122+
110123
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
111124

112125
logging.debug(f"Predicted action: {action}")

packages/lekiwi_lerobot/lekiwi_lerobot/utils.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
from lerobot.datasets.lerobot_dataset import LeRobotDataset
88
from lerobot.datasets.utils import build_dataset_frame
99
from lerobot.policies.pretrained import PreTrainedPolicy
10+
from lerobot.processor import (
11+
PolicyAction,
12+
PolicyProcessorPipeline,
13+
)
1014
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
1115
from lerobot.teleoperators.keyboard import (
1216
KeyboardTeleop,
@@ -30,6 +34,8 @@ def record_loop(
3034
keyboard_handler: KeyboardTeleop | None = None,
3135
arm_keyboard_handler: ArmTeleop | None = None,
3236
policy: PreTrainedPolicy | None = None,
37+
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None,
38+
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None,
3339
control_time_s: int | None = None,
3440
single_task: str | None = None,
3541
display_data: bool = False,
@@ -52,6 +58,11 @@ def record_loop(
5258
if control_time_s is None:
5359
raise ValueError("A control time must be provided.")
5460

61+
if display_data:
62+
logging.info("Visualizing data with Rerun.")
63+
else:
64+
logging.info("Not visualizing data.")
65+
5566
# if policy is given it needs cleaning up
5667
if policy is not None:
5768
policy.reset()
@@ -72,15 +83,19 @@ def record_loop(
7283
raise ValueError("Dataset features must be defined if using a dataset or a policy.")
7384
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") # type: ignore[union-attr]
7485

75-
if policy is not None:
86+
if policy is not None and preprocessor is not None and postprocessor is not None:
7687
action_values = predict_action(
77-
observation_frame,
78-
policy,
79-
get_safe_torch_device(policy.config.device),
80-
policy.config.use_amp,
88+
observation=observation_frame,
89+
policy=policy,
90+
device=get_safe_torch_device(policy.config.device),
91+
preprocessor=preprocessor,
92+
postprocessor=postprocessor,
93+
use_amp=policy.config.use_amp,
8194
task=single_task,
8295
robot_type=robot.robot_type,
8396
)
97+
if action_values.dim() > 1:
98+
action_values = action_values.squeeze(0)
8499
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
85100
print(f"Predicted action: {action}")
86101
elif policy is None and keyboard_handler is not None and arm_keyboard_handler is not None:
@@ -89,6 +104,8 @@ def record_loop(
89104
arm_action = arm_keyboard_handler.from_keyboard_to_arm_action(pressed_keys)
90105

91106
action = {**base_action, **arm_action} # Merge base and arm actions
107+
# TODO(francocipollone): We would probably want to use the teleop_action_processor here.
108+
# action = teleop_action_processor((action, observation))
92109
logging.debug("Sending action: %s", action)
93110

94111
else:
@@ -99,14 +116,13 @@ def record_loop(
99116
)
100117
continue
101118

102-
# Action can eventually be clipped using `max_relative_target`,
103-
# so action actually sent is saved in the dataset.
119+
# TODO(francocipollone): We would probably want to use the robot_action_processor here before sending the action
104120
sent_action = robot.send_action(action)
105121

106122
if dataset is not None:
107123
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
108-
frame = {**observation_frame, **action_frame}
109-
dataset.add_frame(frame, task=single_task)
124+
frame = {**observation_frame, **action_frame, "task": single_task}
125+
dataset.add_frame(frame)
110126

111127
if display_data:
112128
log_rerun_data(observation, action)

packages/lekiwi_lerobot/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ readme = "README.md"
77
requires-python = ">=3.11"
88
dependencies = [
99
"lekiwi_teleoperate>=0.1.0",
10-
"lerobot==0.3.3",
10+
"lerobot==0.4.0",
1111
"numpy>=1.24.0",
1212
]
1313

0 commit comments

Comments
 (0)