@@ -31,7 +31,7 @@ def object_ee_distance(
3131 object_cfg : SceneEntityCfg = SceneEntityCfg ("object" ),
3232 ee_frame_cfg : SceneEntityCfg = SceneEntityCfg ("ee_frame" ),
3333) -> torch .Tensor :
34- """Reward the agent for reaching the object using tanh-kernel."""
34+ """Reward the agent for reaching the object using tanh-kernel with improvements ."""
3535 # extract the used quantities (to enable type-hinting)
3636 object : RigidObject = env .scene [object_cfg .name ]
3737 ee_frame : FrameTransformer = env .scene [ee_frame_cfg .name ]
@@ -41,8 +41,16 @@ def object_ee_distance(
4141 ee_w = ee_frame .data .target_pos_w [..., 0 , :]
4242 # Distance of the end-effector to the object: (num_envs,)
4343 object_ee_distance = torch .norm (cube_pos_w - ee_w , dim = 1 )
44-
45- return 1 - torch .tanh (object_ee_distance / std )
44+
45+ # IMPROVEMENT 1: Add adaptive scaling based on episode progress
46+ episode_progress = env .episode_length_buf .float () / env .max_episode_length
47+ std_adaptive = std * (1.0 + 0.1 * episode_progress )
48+
49+ # IMPROVEMENT 2: Calculate reward with clipping to prevent extreme values
50+ reward = 1 - torch .tanh (object_ee_distance / std_adaptive )
51+ reward = torch .clamp (reward , 0.0 , 1.0 )
52+
53+ return reward
4654
4755
4856def object_goal_distance (
@@ -53,7 +61,7 @@ def object_goal_distance(
5361 robot_cfg : SceneEntityCfg = SceneEntityCfg ("robot" ),
5462 object_cfg : SceneEntityCfg = SceneEntityCfg ("object" ),
5563) -> torch .Tensor :
56- """Reward the agent for tracking the goal pose using tanh-kernel."""
64+ """Reward the agent for tracking the goal pose using tanh-kernel with improvements ."""
5765 # extract the used quantities (to enable type-hinting)
5866 robot : RigidObject = env .scene [robot_cfg .name ]
5967 object : RigidObject = env .scene [object_cfg .name ]
@@ -63,5 +71,57 @@ def object_goal_distance(
6371 des_pos_w , _ = combine_frame_transforms (robot .data .root_pos_w , robot .data .root_quat_w , des_pos_b )
6472 # distance of the end-effector to the object: (num_envs,)
6573 distance = torch .norm (des_pos_w - object .data .root_pos_w , dim = 1 )
66- # rewarded if the object is lifted above the threshold
67- return (object .data .root_pos_w [:, 2 ] > minimal_height ) * (1 - torch .tanh (distance / std ))
74+
75+ # IMPROVEMENT 1: Check if object is lifted
76+ is_lifted = object .data .root_pos_w [:, 2 ] > minimal_height
77+
78+ # IMPROVEMENT 2: Add velocity stability bonus
79+ velocity = torch .norm (object .data .root_lin_vel_w , dim = 1 )
80+ velocity_bonus = torch .exp (- 2.0 * velocity ) # Reward stability
81+
82+ # IMPROVEMENT 3: Combined reward with clipping
83+ distance_reward = 1 - torch .tanh (distance / std )
84+ combined_reward = is_lifted .float () * distance_reward * velocity_bonus
85+ combined_reward = torch .clamp (combined_reward , 0.0 , 1.0 )
86+
87+ return combined_reward
88+
89+
def action_smoothness_penalty(
    env: ManagerBasedRLEnv,
    penalty_scale: float = 0.01,
) -> torch.Tensor:
    """Penalize large changes between consecutive actions to encourage smooth motion.

    The previous action is cached on ``env`` as ``_prev_actions``. On the very
    first call there is nothing to compare against, so the current action is
    cached and a zero penalty is returned.

    NOTE(review): the cached actions are never cleared on episode reset, so the
    first step of a new episode is penalized against the last action of the
    previous one — confirm whether the env exposes a reset hook to clear this.

    Args:
        env: The RL environment; must expose ``action_manager.action`` of shape
            (num_envs, action_dim), plus ``num_envs`` and ``device``.
        penalty_scale: Scale factor applied to the L2 norm of the action change.

    Returns:
        Tensor of shape (num_envs,) with non-positive penalty values.
    """
    current = env.action_manager.action
    if not hasattr(env, "_prev_actions"):
        # First call: cache the *current* action, not zeros. The original code
        # cached zeros, which made the second call penalize the absolute action
        # magnitude (norm(action - 0)) instead of the change between steps.
        env._prev_actions = current.clone()
        return torch.zeros(env.num_envs, device=env.device)

    action_diff = torch.norm(current - env._prev_actions, dim=1)
    env._prev_actions = current.clone()

    return -penalty_scale * action_diff
104+
105+
def grasp_success_bonus(
    env: ManagerBasedRLEnv,
    bonus_value: float = 2.0,
    object_cfg: SceneEntityCfg = SceneEntityCfg("object"),
    ee_frame_cfg: SceneEntityCfg = SceneEntityCfg("ee_frame"),
    distance_threshold: float = 0.05,
    velocity_threshold: float = 0.1,
) -> torch.Tensor:
    """Give a sparse bonus when the object sits close to the end-effector and is stable.

    A grasp is deemed successful when the object-to-end-effector distance is
    below ``distance_threshold`` AND the object's linear speed is below
    ``velocity_threshold`` (low speed is used as a proxy for a non-slipping
    grasp). The previously hard-coded 0.05 m / 0.1 m/s thresholds are now
    parameters with the same defaults, so existing callers are unaffected.

    Args:
        env: The RL environment containing the scene entities.
        bonus_value: Reward granted to each environment with a successful grasp.
        object_cfg: Scene-entity config naming the grasped object.
        ee_frame_cfg: Scene-entity config naming the end-effector frame.
        distance_threshold: Max object-to-gripper distance (meters) to count as grasped.
        velocity_threshold: Max object linear speed (m/s) to count as stable.

    Returns:
        Tensor of shape (num_envs,): ``bonus_value`` where the grasp is
        successful, 0.0 elsewhere.
    """
    object: RigidObject = env.scene[object_cfg.name]
    ee_frame: FrameTransformer = env.scene[ee_frame_cfg.name]

    # Object-to-gripper distance. The [..., 0, :] index selects the first
    # target frame of the transformer — assumes the EE frame is target 0, as
    # in the other reward terms of this file; confirm against the frame config.
    cube_pos_w = object.data.root_pos_w
    ee_w = ee_frame.data.target_pos_w[..., 0, :]
    distance = torch.norm(cube_pos_w - ee_w, dim=1)

    # Linear speed of the object in the world frame.
    velocity = torch.norm(object.data.root_lin_vel_w, dim=1)

    successful_grasp = (distance < distance_threshold) & (velocity < velocity_threshold)

    # Boolean mask * scalar keeps device/dtype handling implicit and avoids
    # allocating the two per-branch tensors the original torch.where built.
    return successful_grasp.float() * bonus_value
0 commit comments