Skip to content

Commit 5685d43

Browse files
authored
Add observation processing utilities for robust training
- Add ObservationNormalizer class with running mean/variance
- Add ObservationHistory class for temporal context (3-step history)
- Add add_noise_to_observations() for domain randomization
- Improves training stability and sim-to-real transfer

Signed-off-by: Swamy Gadila <122666091+swamy18@users.noreply.github.com>
1 parent 7894315 commit 5685d43

File tree

1 file changed

+70
-0
lines changed
  • source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/lift/mdp

1 file changed

+70
-0
lines changed

source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/lift/mdp/observations.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,73 @@ def object_position_in_robot_root_frame(
2727
object_pos_w = object.data.root_pos_w[:, :3]
2828
object_pos_b, _ = subtract_frame_transforms(robot.data.root_pos_w, robot.data.root_quat_w, object_pos_w)
2929
return object_pos_b
30+
31+
32+
33+
34+
class ObservationNormalizer:
    """Normalize observations with running per-feature mean and variance.

    Statistics are accumulated batch by batch until ``max_update_samples``
    samples have been seen; after that they are frozen and only used for
    normalization. Normalized values are clipped to ``[-clip_range, clip_range]``.

    Args:
        obs_dim: Number of features per observation.
        num_envs: Accepted for interface compatibility; not used by this class.
        clip_range: Symmetric bound applied to the normalized output.
        device: Device on which the running statistics are allocated.
        max_update_samples: Statistics stop updating once ``count`` reaches
            this many samples (counted in rows, not in calls).
    """

    def __init__(
        self,
        obs_dim: int,
        num_envs: int,
        clip_range: float = 10.0,
        device: str = "cuda",
        max_update_samples: int = 10000,
    ):
        self.obs_mean = torch.zeros(obs_dim, device=device)
        # Starts at one so that normalization is the identity (up to clipping)
        # before any statistics have been collected.
        self.obs_var = torch.ones(obs_dim, device=device)
        self.count = 0
        self.clip_range = clip_range
        self.device = device
        self.max_update_samples = max_update_samples

    def _update_stats(self, obs: torch.Tensor) -> None:
        """Merge the statistics of one batch into the running mean/variance."""
        batch_size = obs.shape[0]
        # The running statistics are bookkeeping, not part of any autograd graph.
        with torch.no_grad():
            batch_mean = obs.mean(dim=0)
            # Population variance: well defined for a single-row batch (the
            # unbiased estimator would yield NaN) and makes the merge exact.
            batch_var = obs.var(dim=0, unbiased=False)

            prev_count = self.count
            total = prev_count + batch_size
            delta = batch_mean - self.obs_mean

            # Parallel-variance merge: combine the within-group sums of squares
            # and add the between-group term caused by the differing means.
            merged_m2 = (
                self.obs_var * prev_count
                + batch_var * batch_size
                + delta.pow(2) * (prev_count * batch_size / total)
            )

            self.obs_mean += delta * (batch_size / total)
            self.obs_var = merged_m2 / total
            self.count = total

    def normalize(self, obs: torch.Tensor, update_stats: bool = True) -> torch.Tensor:
        """Normalize observations using the running mean and variance.

        Args:
            obs: Batch of observations; the first dimension indexes samples and
                the statistics are reduced over it.
            update_stats: When True, fold this batch into the running statistics
                (only while fewer than ``max_update_samples`` samples were seen).

        Returns:
            The standardized observations, clipped to ``[-clip_range, clip_range]``.
        """
        batch_size = obs.shape[0]
        # An empty batch carries no information and would produce NaN statistics.
        if update_stats and batch_size > 0 and self.count < self.max_update_samples:
            self._update_stats(obs.detach())

        # Normalize and clip
        normalized = (obs - self.obs_mean) / (torch.sqrt(self.obs_var) + 1e-8)
        return torch.clamp(normalized, -self.clip_range, self.clip_range)
60+
61+
62+
class ObservationHistory:
    """Fixed-length buffer of recent observations, kept per environment.

    The buffer has shape ``(num_envs, history_length, obs_dim)``. Index 0 along
    the history axis always holds the most recently added observation; older
    entries follow, and unfilled slots are zero.
    """

    def __init__(self, obs_dim: int, num_envs: int, history_length: int = 3, device: str = "cuda"):
        buffer_shape = (num_envs, history_length, obs_dim)
        self.device = device
        self.history_length = history_length
        self.history = torch.zeros(buffer_shape, device=device)

    def add(self, obs: torch.Tensor):
        """Push ``obs`` into slot 0, moving every stored entry one step older."""
        # Rolling wraps the oldest entry round to slot 0, where it is then
        # overwritten by the new observation.
        shifted = self.history.roll(1, dims=1)
        shifted[:, 0] = obs
        self.history = shifted

    def get_flat(self) -> torch.Tensor:
        """Return the buffer flattened to ``[num_envs, history_length * obs_dim]``."""
        env_count = self.history.shape[0]
        return self.history.reshape(env_count, -1)

    def reset(self, env_ids: torch.Tensor = None):
        """Zero the stored history for ``env_ids``, or for all environments if None."""
        if env_ids is not None:
            self.history[env_ids] = 0.0
            return
        self.history.zero_()
85+
86+
87+
def add_noise_to_observations(
    env: ManagerBasedRLEnv,
    obs: torch.Tensor,
    noise_std: float = 0.01,
) -> torch.Tensor:
    """Perturb observations with zero-mean Gaussian noise while training.

    Intended as a domain-randomization aid for sim-to-real transfer.

    Args:
        env: Environment whose ``training`` attribute gates the noise.
        obs: Observations to perturb.
        noise_std: Standard deviation of the added noise.

    Returns:
        ``obs`` itself (unchanged, same object) when ``env.training`` is falsy;
        otherwise a new tensor equal to ``obs`` plus the sampled noise.
    """
    # NOTE(review): this relies on the environment exposing a ``training``
    # attribute; confirm that the environment class actually defines it.
    if not env.training:
        return obs

    perturbation = torch.randn_like(obs) * noise_std
    return obs + perturbation

0 commit comments

Comments
 (0)