@@ -37,7 +37,6 @@ def object_reached_goal(
3737 threshold: The threshold for the object to reach the goal position. Defaults to 0.02.
3838 robot_cfg: The robot configuration. Defaults to SceneEntityCfg("robot").
3939 object_cfg: The object configuration. Defaults to SceneEntityCfg("object").
40-
4140 """
4241 # extract the used quantities (to enable type-hinting)
4342 robot : RigidObject = env .scene [robot_cfg .name ]
@@ -48,6 +47,91 @@ def object_reached_goal(
4847 des_pos_w , _ = combine_frame_transforms (robot .data .root_pos_w , robot .data .root_quat_w , des_pos_b )
4948 # distance of the end-effector to the object: (num_envs,)
5049 distance = torch .norm (des_pos_w - object .data .root_pos_w [:, :3 ], dim = 1 )
51-
5250 # rewarded if the object is lifted above the threshold
5351 return distance < threshold
52+
53+
54+ def object_reached_goal_with_stability (
55+ env : ManagerBasedRLEnv ,
56+ command_name : str = "object_pose" ,
57+ position_threshold : float = 0.02 ,
58+ velocity_threshold : float = 0.1 ,
59+ robot_cfg : SceneEntityCfg = SceneEntityCfg ("robot" ),
60+ object_cfg : SceneEntityCfg = SceneEntityCfg ("object" ),
61+ ) -> torch .Tensor :
62+ """NEW: Termination condition with velocity stability check.
63+
64+ This ensures the object has not only reached the goal but is also stable (low velocity).
65+
66+ Args:
67+ env: The environment.
68+ command_name: The name of the command that is used to control the object.
69+ position_threshold: The position threshold for goal reach. Defaults to 0.02.
70+ velocity_threshold: The velocity threshold for stability. Defaults to 0.1 m/s.
71+ robot_cfg: The robot configuration. Defaults to SceneEntityCfg("robot").
72+ object_cfg: The object configuration. Defaults to SceneEntityCfg("object").
73+ """
74+ # extract the used quantities
75+ robot : RigidObject = env .scene [robot_cfg .name ]
76+ object : RigidObject = env .scene [object_cfg .name ]
77+ command = env .command_manager .get_command (command_name )
78+
79+ # compute the desired position in the world frame
80+ des_pos_b = command [:, :3 ]
81+ des_pos_w , _ = combine_frame_transforms (robot .data .root_pos_w , robot .data .root_quat_w , des_pos_b )
82+
83+ # IMPROVEMENT 1: Check position distance
84+ distance = torch .norm (des_pos_w - object .data .root_pos_w [:, :3 ], dim = 1 )
85+ position_reached = distance < position_threshold
86+
87+ # IMPROVEMENT 2: Check velocity for stability
88+ velocity = torch .norm (object .data .root_lin_vel_w , dim = 1 )
89+ is_stable = velocity < velocity_threshold
90+
91+ # Both conditions must be met
92+ return position_reached & is_stable
93+
94+
95+ def object_dropped (
96+ env : ManagerBasedRLEnv ,
97+ minimal_height : float = 0.05 ,
98+ object_cfg : SceneEntityCfg = SceneEntityCfg ("object" ),
99+ ) -> torch .Tensor :
100+ """NEW: Termination if object is dropped (falls below minimal height).
101+
102+ This helps identify failed grasps early.
103+
104+ Args:
105+ env: The environment.
106+ minimal_height: The minimal height threshold. Defaults to 0.05.
107+ object_cfg: The object configuration. Defaults to SceneEntityCfg("object").
108+ """
109+ object : RigidObject = env .scene [object_cfg .name ]
110+ return object .data .root_pos_w [:, 2 ] < minimal_height
111+
112+
113+ def object_out_of_bounds (
114+ env : ManagerBasedRLEnv ,
115+ max_distance : float = 2.0 ,
116+ robot_cfg : SceneEntityCfg = SceneEntityCfg ("robot" ),
117+ object_cfg : SceneEntityCfg = SceneEntityCfg ("object" ),
118+ ) -> torch .Tensor :
119+ """NEW: Termination if object moves too far from robot.
120+
121+ Prevents unproductive exploration.
122+
123+ Args:
124+ env: The environment.
125+ max_distance: Maximum allowed distance from robot. Defaults to 2.0 meters.
126+ robot_cfg: The robot configuration. Defaults to SceneEntityCfg("robot").
127+ object_cfg: The object configuration. Defaults to SceneEntityCfg("object").
128+ """
129+ robot : RigidObject = env .scene [robot_cfg .name ]
130+ object : RigidObject = env .scene [object_cfg .name ]
131+
132+ # Calculate horizontal distance (ignore z-axis)
133+ robot_pos_xy = robot .data .root_pos_w [:, :2 ]
134+ object_pos_xy = object .data .root_pos_w [:, :2 ]
135+ distance = torch .norm (object_pos_xy - robot_pos_xy , dim = 1 )
136+
137+ return distance > max_distance
0 commit comments