@@ -27,3 +27,73 @@ def object_position_in_robot_root_frame(
2727 object_pos_w = object .data .root_pos_w [:, :3 ]
2828 object_pos_b , _ = subtract_frame_transforms (robot .data .root_pos_w , robot .data .root_quat_w , object_pos_w )
2929 return object_pos_b
30+
31+
32+
33+
34+ class ObservationNormalizer :
35+ """NEW: Normalizes observations for stable training."""
36+
37+ def __init__ (self , obs_dim : int , num_envs : int , clip_range : float = 10.0 , device : str = "cuda" ):
38+ self .obs_mean = torch .zeros (obs_dim , device = device )
39+ self .obs_var = torch .ones (obs_dim , device = device )
40+ self .count = 0
41+ self .clip_range = clip_range
42+ self .device = device
43+
44+ def normalize (self , obs : torch .Tensor , update_stats : bool = True ) -> torch .Tensor :
45+ """Normalize observations using running mean and variance."""
46+ if update_stats and self .count < 10000 : # Update stats for first 10k steps
47+ batch_mean = obs .mean (dim = 0 )
48+ batch_var = obs .var (dim = 0 )
49+
50+ # Update running statistics
51+ self .count += obs .shape [0 ]
52+ delta = batch_mean - self .obs_mean
53+ self .obs_mean += delta * obs .shape [0 ] / self .count
54+ self .obs_var = (self .obs_var * (self .count - obs .shape [0 ]) +
55+ batch_var * obs .shape [0 ]) / self .count
56+
57+ # Normalize and clip
58+ normalized = (obs - self .obs_mean ) / (torch .sqrt (self .obs_var ) + 1e-8 )
59+ return torch .clamp (normalized , - self .clip_range , self .clip_range )
60+
61+
62+ class ObservationHistory :
63+ """NEW: Maintains history of observations for temporal context."""
64+
65+ def __init__ (self , obs_dim : int , num_envs : int , history_length : int = 3 , device : str = "cuda" ):
66+ self .history_length = history_length
67+ self .history = torch .zeros ((num_envs , history_length , obs_dim ), device = device )
68+ self .device = device
69+
70+ def add (self , obs : torch .Tensor ):
71+ """Add new observation and shift history."""
72+ self .history = torch .roll (self .history , shifts = 1 , dims = 1 )
73+ self .history [:, 0 ] = obs
74+
75+ def get_flat (self ) -> torch .Tensor :
76+ """Get flattened history [num_envs, history_length * obs_dim]."""
77+ return self .history .reshape (self .history .shape [0 ], - 1 )
78+
79+ def reset (self , env_ids : torch .Tensor = None ):
80+ """Reset history for specific environments."""
81+ if env_ids is None :
82+ self .history .zero_ ()
83+ else :
84+ self .history [env_ids ] = 0.0
85+
86+
87+ def add_noise_to_observations (
88+ env : ManagerBasedRLEnv ,
89+ obs : torch .Tensor ,
90+ noise_std : float = 0.01 ,
91+ ) -> torch .Tensor :
92+ """NEW: Add domain randomization noise to observations.
93+
94+ Helps with sim-to-real transfer.
95+ """
96+ if env .training :
97+ noise = torch .randn_like (obs ) * noise_std
98+ return obs + noise
99+ return obs
0 commit comments