Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions dora/lekiwi_sim/graphs/mujoco_sim.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,21 @@ nodes:
############################################################
# Example of a rerun visualization node
############################################################
# TODO: Enable once lerobot is updated to 0.4.0 due to rerun version conflict.
# - id: rerun-viz
# # TODO: Use tag once there is a new release
# build: cargo install --git https://github.com/dora-rs/dora-hub dora-rerun
# path: dora-rerun
# inputs:
# front_camera:
# source: dora_lekiwi_client/image_front
# wrist_camera:
# source: dora_lekiwi_client/image_wrist
# observation_state:
# # ["arm_shoulder_pan.pos", "arm_shoulder_lift.pos", "arm_elbow_flex.pos", "arm_wrist_flex.pos", "arm_wrist_roll.pos", "arm_gripper.pos", "x.vel", "y.vel", "theta.vel"]
# source: dora_lekiwi_client/observation_state
# actions:
# source: dora_run_policy/actions
# env:
# OPERATING_MODE: SPAWN
- id: rerun-viz
build: cargo install --git https://github.com/dora-rs/dora-hub --tag 0.3.13 dora-rerun
path: dora-rerun
inputs:
front_camera:
source: dora_lekiwi_client/image_front
wrist_camera:
source: dora_lekiwi_client/image_wrist
observation_state:
# ["arm_shoulder_pan.pos", "arm_shoulder_lift.pos", "arm_elbow_flex.pos", "arm_wrist_flex.pos", "arm_wrist_roll.pos", "arm_gripper.pos", "x.vel", "y.vel", "theta.vel"]
source: dora_lekiwi_client/observation_state
actions:
source: dora_run_policy/actions
env:
OPERATING_MODE: SPAWN

###########################################################
# Example of a recorder node. The outputs are parquet files.
Expand Down
13 changes: 13 additions & 0 deletions dora/node_hub/dora_run_policy/dora_run_policy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pyarrow as pa
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.utils.control_utils import predict_action
from lerobot.utils.utils import get_safe_torch_device

Expand Down Expand Up @@ -114,6 +115,11 @@ def main() -> None:
policy.reset()
device_name = policy.config.device or "auto"
device = get_safe_torch_device(device_name)
# Build Policy Processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy.config,
pretrained_path=model_name,
)
except Exception as e:
raise RuntimeError(f"Failed to load policy '{model_name}': {e}") from None

Expand Down Expand Up @@ -171,9 +177,16 @@ def main() -> None:
observation_frame,
policy,
device,
preprocessor,
postprocessor,
policy.config.use_amp,
)

# As of LeRobot 0.4.0, the postprocessor returns a tensor that might have batch dimension
# Remove batch dimension if present
if raw_action.dim() > 1:
raw_action = raw_action.squeeze(0)

# Send action output
metadata = event["metadata"].copy()
metadata["primitive"] = "series"
Expand Down
2 changes: 1 addition & 1 deletion dora/node_hub/dora_run_policy/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ requires-python = ">=3.10"

dependencies = [
"dora-rs >= 0.3.13",
"lerobot==0.3.3"
"lerobot==0.4.0"
]

[build-system]
Expand Down
5 changes: 3 additions & 2 deletions packages/lekiwi_lerobot/lekiwi_lerobot/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from lerobot.utils.control_utils import (
init_keyboard_listener,
)
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun

FPS = 30
EPISODE_TIME_SEC = 180
Expand Down Expand Up @@ -108,7 +108,7 @@ def main() -> None:
robot.connect()
keyboard.connect()

_init_rerun(session_name="lekiwi_evaluate")
init_rerun(session_name="lekiwi_evaluate")

listener, events = init_keyboard_listener()

Expand Down Expand Up @@ -139,6 +139,7 @@ def main() -> None:
robot=robot,
events=events,
fps=FPS,
dataset=None, # Don't record during reset phase
keyboard_handler=keyboard,
arm_keyboard_handler=arm_keyboard_handler,
control_time_s=RESET_TIME_SEC,
Expand Down
30 changes: 22 additions & 8 deletions packages/lekiwi_lerobot/lekiwi_lerobot/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
KeyboardTeleop,
KeyboardTeleopConfig,
)
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import (
init_keyboard_listener,
)
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun

FPS = 30
EPISODE_TIME_SEC = 120
Expand Down Expand Up @@ -52,6 +53,12 @@ def main() -> None:
default="Unnamed task",
help="Task description to associate with each episode (default: 'Unnamed task').",
)
parser.add_argument(
"--no-viz",
action="store_false",
dest="visualize",
help="Disable Rerun visualization during recording.",
)

args = parser.parse_args()
if args.repo_id is None:
Expand All @@ -73,8 +80,8 @@ def main() -> None:
keyboard = KeyboardTeleop(keyboard_config)
arm_keyboard_handler = ArmTeleop()
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
logging.info(f"Recording the following observation features: {list(obs_features.keys())}")
logging.info(f"Recording the following action features: {list(action_features.keys())}")
dataset_features = {**action_features, **obs_features}
Expand All @@ -86,7 +93,7 @@ def main() -> None:
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=0,
image_writer_threads=4,
)

# To connect you already should have:
Expand All @@ -96,7 +103,11 @@ def main() -> None:
robot.connect()
keyboard.connect()

_init_rerun(session_name="lekiwi_record")
if args.visualize:
logging.info("Initializing Rerun for visualization.")
init_rerun(session_name="lekiwi_record")
else:
logging.info("Rerun visualization is disabled.")

listener, events = init_keyboard_listener()

Expand All @@ -117,21 +128,22 @@ def main() -> None:
arm_keyboard_handler=arm_keyboard_handler,
control_time_s=EPISODE_TIME_SEC,
single_task=args.task,
display_data=True,
display_data=args.visualize,
)

# Logic for reset env
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((recorded_episodes < args.episodes - 1) or events["rerecord_episode"]):
logging.info("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=None, # Don't record during reset phase
keyboard_handler=keyboard,
arm_keyboard_handler=arm_keyboard_handler,
control_time_s=RESET_TIME_SEC,
single_task=args.task,
display_data=True,
display_data=args.visualize,
)

if events["rerecord_episode"]:
Expand All @@ -141,10 +153,12 @@ def main() -> None:
dataset.clear_episode_buffer()
continue

logging.info(f"Saving episode number {recorded_episodes} to dataset.")
dataset.save_episode()
recorded_episodes += 1

# Upload to hub and clean up
dataset.finalize()
dataset.push_to_hub()

robot.disconnect()
Expand Down
12 changes: 9 additions & 3 deletions packages/lekiwi_lerobot/lekiwi_lerobot/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait


Expand Down Expand Up @@ -84,17 +85,22 @@ def main() -> None:
logging.info(f"Downloading dataset from {args.repo_id} into {args.directory}")
root = args.directory + "/" + args.repo_id.split("/")[-1]
root = root.replace("//", "/")
dataset = LeRobotDataset(args.repo_id, root=root, episodes=[args.episode])
episode_to_replay = args.episode
dataset = LeRobotDataset(args.repo_id, root=root, episodes=[episode_to_replay])
logging.info(f"Dataset stored at {root}")
actions = dataset.hf_dataset.select_columns("action")
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == episode_to_replay)
actions = episode_frames.select_columns(ACTION)

robot.connect()

if not robot.is_connected:
raise ValueError("Robot is not connected!")

logging.info(f"Replaying episode {args.episode} with {dataset.num_frames} frames.")
for idx in range(dataset.num_frames):
len_episodes_frames = len(episode_frames)
logging.info(f"Replaying episode {args.episode} with {len_episodes_frames} frames.")
for idx in range(len_episodes_frames):
t0 = time.perf_counter()

action = {name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])}
Expand Down
25 changes: 19 additions & 6 deletions packages/lekiwi_lerobot/lekiwi_lerobot/run_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener, predict_action
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import get_safe_torch_device
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data

FPS = 30

Expand Down Expand Up @@ -82,12 +84,16 @@ def main() -> None:
logging.info("Robot is connected.")

# Prepare for policy inference
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features}
device = get_safe_torch_device(policy.config.device)

_init_rerun(session_name="lekiwi_run_policy")
# Build Policy Processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy.config,
pretrained_path=args.policy,
)
init_rerun(session_name="lekiwi_run_policy")

listener, events = init_keyboard_listener()

Expand All @@ -97,16 +103,23 @@ def main() -> None:

observation = robot.get_observation()

observation_frame = build_dataset_frame(dataset_features, observation, prefix="observation")
observation_frame = build_dataset_frame(dataset_features, observation, prefix=OBS_STR)

action_values = predict_action(
observation_frame,
policy,
device,
preprocessor,
postprocessor,
policy.config.use_amp,
task=args.task,
robot_type=robot.robot_type,
)
# As of LeRobot 0.4.0, the postprocessor returns a tensor that might have batch dimension
# Remove batch dimension if present
if action_values.dim() > 1:
action_values = action_values.squeeze(0)

action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}

logging.debug(f"Predicted action: {action}")
Expand Down
34 changes: 25 additions & 9 deletions packages/lekiwi_lerobot/lekiwi_lerobot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import (
PolicyAction,
PolicyProcessorPipeline,
)
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.teleoperators.keyboard import (
KeyboardTeleop,
Expand All @@ -30,6 +34,8 @@ def record_loop(
keyboard_handler: KeyboardTeleop | None = None,
arm_keyboard_handler: ArmTeleop | None = None,
policy: PreTrainedPolicy | None = None,
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None,
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None,
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
Expand All @@ -52,6 +58,11 @@ def record_loop(
if control_time_s is None:
raise ValueError("A control time must be provided.")

if display_data:
logging.info("Visualizing data with Rerun.")
else:
logging.info("Not visualizing data.")

# if policy is given it needs cleaning up
if policy is not None:
policy.reset()
Expand All @@ -72,15 +83,19 @@ def record_loop(
raise ValueError("Dataset features must be defined if using a dataset or a policy.")
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") # type: ignore[union-attr]

if policy is not None:
if policy is not None and preprocessor is not None and postprocessor is not None:
action_values = predict_action(
observation_frame,
policy,
get_safe_torch_device(policy.config.device),
policy.config.use_amp,
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
)
if action_values.dim() > 1:
action_values = action_values.squeeze(0)
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
print(f"Predicted action: {action}")
elif policy is None and keyboard_handler is not None and arm_keyboard_handler is not None:
Expand All @@ -89,6 +104,8 @@ def record_loop(
arm_action = arm_keyboard_handler.from_keyboard_to_arm_action(pressed_keys)

action = {**base_action, **arm_action} # Merge base and arm actions
# TODO(francocipollone): We would probably want to use the teleop_action_processor here.
# action = teleop_action_processor((action, observation))
logging.debug("Sending action: %s", action)

else:
Expand All @@ -99,14 +116,13 @@ def record_loop(
)
continue

# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
# TODO(francocipollone): We would probably want to use the robot_action_processor here before sending the action
sent_action = robot.send_action(action)

if dataset is not None:
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame}
dataset.add_frame(frame, task=single_task)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)

if display_data:
log_rerun_data(observation, action)
Expand Down
2 changes: 1 addition & 1 deletion packages/lekiwi_lerobot/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"lekiwi_teleoperate>=0.1.0",
"lerobot==0.3.3",
"lerobot==0.4.0",
"numpy>=1.24.0",
]

Expand Down
Loading