Skip to content

Commit d5f1773

Browse files
authored
Add improved termination conditions with stability and safety checks
- Add object_reached_goal_with_stability() function with velocity check - Add object_dropped() function to detect failed grasps early - Add object_out_of_bounds() function to prevent unproductive exploration - Keeps original object_reached_goal() for backward compatibility Signed-off-by: Swamy Gadila <122666091+swamy18@users.noreply.github.com>
1 parent 76995d8 commit d5f1773

File tree

1 file changed

+86
-2
lines changed
  • source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/lift/mdp

1 file changed

+86
-2
lines changed

source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/lift/mdp/terminations.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def object_reached_goal(
3737
threshold: The threshold for the object to reach the goal position. Defaults to 0.02.
3838
robot_cfg: The robot configuration. Defaults to SceneEntityCfg("robot").
3939
object_cfg: The object configuration. Defaults to SceneEntityCfg("object").
40-
4140
"""
4241
# extract the used quantities (to enable type-hinting)
4342
robot: RigidObject = env.scene[robot_cfg.name]
@@ -48,6 +47,91 @@ def object_reached_goal(
4847
des_pos_w, _ = combine_frame_transforms(robot.data.root_pos_w, robot.data.root_quat_w, des_pos_b)
4948
# distance of the end-effector to the object: (num_envs,)
5049
distance = torch.norm(des_pos_w - object.data.root_pos_w[:, :3], dim=1)
51-
5250
# rewarded if the object is lifted above the threshold
5351
return distance < threshold
52+
53+
54+
def object_reached_goal_with_stability(
55+
env: ManagerBasedRLEnv,
56+
command_name: str = "object_pose",
57+
position_threshold: float = 0.02,
58+
velocity_threshold: float = 0.1,
59+
robot_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
60+
object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
61+
) -> torch.Tensor:
62+
"""NEW: Termination condition with velocity stability check.
63+
64+
This ensures the object has not only reached the goal but is also stable (low velocity).
65+
66+
Args:
67+
env: The environment.
68+
command_name: The name of the command that is used to control the object.
69+
position_threshold: The position threshold for goal reach. Defaults to 0.02.
70+
velocity_threshold: The velocity threshold for stability. Defaults to 0.1 m/s.
71+
robot_cfg: The robot configuration. Defaults to SceneEntityCfg("robot").
72+
object_cfg: The object configuration. Defaults to SceneEntityCfg("object").
73+
"""
74+
# extract the used quantities
75+
robot: RigidObject = env.scene[robot_cfg.name]
76+
object: RigidObject = env.scene[object_cfg.name]
77+
command = env.command_manager.get_command(command_name)
78+
79+
# compute the desired position in the world frame
80+
des_pos_b = command[:, :3]
81+
des_pos_w, _ = combine_frame_transforms(robot.data.root_pos_w, robot.data.root_quat_w, des_pos_b)
82+
83+
# IMPROVEMENT 1: Check position distance
84+
distance = torch.norm(des_pos_w - object.data.root_pos_w[:, :3], dim=1)
85+
position_reached = distance < position_threshold
86+
87+
# IMPROVEMENT 2: Check velocity for stability
88+
velocity = torch.norm(object.data.root_lin_vel_w, dim=1)
89+
is_stable = velocity < velocity_threshold
90+
91+
# Both conditions must be met
92+
return position_reached & is_stable
93+
94+
95+
def object_dropped(
96+
env: ManagerBasedRLEnv,
97+
minimal_height: float = 0.05,
98+
object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
99+
) -> torch.Tensor:
100+
"""NEW: Termination if object is dropped (falls below minimal height).
101+
102+
This helps identify failed grasps early.
103+
104+
Args:
105+
env: The environment.
106+
minimal_height: The minimal height threshold. Defaults to 0.05.
107+
object_cfg: The object configuration. Defaults to SceneEntityCfg("object").
108+
"""
109+
object: RigidObject = env.scene[object_cfg.name]
110+
return object.data.root_pos_w[:, 2] < minimal_height
111+
112+
113+
def object_out_of_bounds(
114+
env: ManagerBasedRLEnv,
115+
max_distance: float = 2.0,
116+
robot_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
117+
object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
118+
) -> torch.Tensor:
119+
"""NEW: Termination if object moves too far from robot.
120+
121+
Prevents unproductive exploration.
122+
123+
Args:
124+
env: The environment.
125+
max_distance: Maximum allowed distance from robot. Defaults to 2.0 meters.
126+
robot_cfg: The robot configuration. Defaults to SceneEntityCfg("robot").
127+
object_cfg: The object configuration. Defaults to SceneEntityCfg("object").
128+
"""
129+
robot: RigidObject = env.scene[robot_cfg.name]
130+
object: RigidObject = env.scene[object_cfg.name]
131+
132+
# Calculate horizontal distance (ignore z-axis)
133+
robot_pos_xy = robot.data.root_pos_w[:, :2]
134+
object_pos_xy = object.data.root_pos_w[:, :2]
135+
distance = torch.norm(object_pos_xy - robot_pos_xy, dim=1)
136+
137+
return distance > max_distance

0 commit comments

Comments
 (0)