Description
Training Mjlab-Velocity-Rough-Unitree-Go2 crashes with RuntimeError: normal expects all elements of std >= 0.0 within the first few iterations. The issue occurs both when training from scratch and when fine-tuning from a Flat model checkpoint. The Mjlab-Velocity-Flat-Unitree-Go2 task trains successfully with the same setup.
Environment
- 10× NVIDIA RTX A4000 (multi-GPU via torchrunx)
- 56 CPU cores
- `--env.scene.num-envs=8192`
- Ubuntu, Python 3.11, PyTorch with CUDA
Steps to Reproduce
```
python scripts/train.py Mjlab-Velocity-Rough-Unitree-Go2 \
  --gpu-ids 0 1 2 3 4 5 6 7 8 9 \
  --env.scene.num-envs=8192
```
Error
```
RuntimeError: normal expects all elements of std >= 0.0
  File "mjlab/rsl_rl/modules/actor_critic.py", line 146, in act
    return self.distribution.sample()
  File "mjlab/rsl_rl/algorithms/ppo.py", line 249, in update
    self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
```
Crashes at iteration 1-2 when training from scratch. Metrics before the crash:
- Mean reward: -5.07
- Mean episode length: 20.55
- Episode_Termination/fell_over: 26.75
- Episode_Termination/illegal_contact: 375.67
- Curriculum/terrain_levels: 2.13 → 0.00
Root Cause Analysis
Four issues contribute to the crash:
1. `num_waves` type mismatch in mjlab/terrains/config.py
```python
# Line 64: causes a tyro type error on launch
SubTerrainSceneCfg(num_waves=4, ...)  # int, but tyro expects float
```
Fix: `num_waves=4` → `num_waves=4.0`
2. No clamp on action noise std (actor_critic.py)
When noise_std_type="scalar" (the default), self.std is a raw nn.Parameter with no lower bound, so the optimizer can push it negative or to NaN:
```python
# Line 89: no constraint
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
# Line 134: used directly
std = self.std.expand_as(mean)
# Line 140: Normal crashes if std < 0 or NaN
self.distribution = Normal(mean, std)
```
The same issue exists in actor_critic_recurrent.py.
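To illustrate, here is a minimal standalone repro of the crash and of the proposed clamp. This is not the mjlab code; the values are made up, and `validate_args=False` is an assumption matching setups where distribution argument validation is disabled, which is why the error surfaces inside sample() rather than at construction:

```python
import torch
from torch.distributions import Normal

mean = torch.zeros(4)
# One entry driven negative, as an unconstrained nn.Parameter can be
std = torch.tensor([0.5, 0.5, -0.1, 0.5])

# With validation off, construction succeeds but sample() hits the
# kernel check: "normal expects all elements of std >= 0.0"
dist = Normal(mean, std, validate_args=False)
try:
    dist.sample()
    crashed = False
except RuntimeError:
    crashed = True

# Proposed guard: clamp, then sanitize NaN/Inf, before building the distribution
safe_std = torch.nan_to_num(torch.clamp(std, min=1e-6),
                            nan=1.0, posinf=1.0, neginf=1e-6)
safe_dist = Normal(mean, safe_std, validate_args=False)
sample = safe_dist.sample()
```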
3. Ratio overflow in PPO surrogate loss (ppo.py)
```python
# Line 297: exp() can overflow to Inf → NaN propagates to the loss
ratio = torch.exp(actions_log_prob_batch - old_actions_log_prob_batch)
```
On Rough terrain the robot falls almost immediately (episode length ≈ 20), producing extreme penalties. During a PPO update with 5 epochs × 4 minibatches = 20 optimizer steps, the policy diverges enough that log_prob_new - log_prob_old overflows exp().
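The overflow is easy to reproduce in isolation. The log-prob values below are hypothetical; in float32, exp() overflows to Inf once its argument exceeds roughly 88:

```python
import torch

old_lp = torch.zeros(3)
new_lp = torch.tensor([0.1, 50.0, 200.0])  # diverged policy after many steps

raw_ratio = torch.exp(new_lp - old_lp)  # last entry overflows to inf

# Proposed guard: bound the log-ratio before exponentiating
clamped_ratio = torch.exp(torch.clamp(new_lp - old_lp, -20.0, 20.0))
```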
4. NaN gradient propagation in multi-GPU reduction (ppo.py)
```python
# Line 453: all_reduce(SUM): one GPU with NaN poisons all the others
torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
```
If any single GPU produces NaN gradients (from the overflow above), all_reduce propagates NaN to every GPU, corrupting all parameters irreversibly.
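The poisoning effect can be simulated without launching a distributed job; the stacked sum below stands in for what all_reduce(SUM) computes across ranks, with hypothetical gradient tensors:

```python
import torch

# Flattened gradients from three ranks; rank 1 overflowed to NaN
grads = [torch.ones(4), torch.full((4,), float("nan")), torch.ones(4)]

poisoned = torch.stack(grads).sum(dim=0)  # equivalent of all_reduce(SUM)

# Proposed guard: zero out non-finite entries on each rank before reducing
cleaned = [torch.nan_to_num(g, nan=0.0, posinf=0.0, neginf=0.0) for g in grads]
reduced = torch.stack(cleaned).sum(dim=0)
```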
Proposed Fix
PR: #8
- Clamp std and handle NaN in update_distribution() (actor_critic.py, actor_critic_recurrent.py):
  ```python
  std = torch.clamp(std, min=1e-6)
  std = torch.nan_to_num(std, nan=1.0, posinf=1.0, neginf=1e-6)
  mean = torch.nan_to_num(mean, nan=0.0)
  self.distribution = Normal(mean, std)
  ```
- Clamp the log-ratio before exp() (ppo.py):
  ```python
  log_ratio = actions_log_prob_batch - old_actions_log_prob_batch
  log_ratio = torch.clamp(log_ratio, -20.0, 20.0)
  ratio = torch.exp(log_ratio)
  ```
- Skip the optimizer step on a non-finite loss (ppo.py):
  ```python
  if not torch.isfinite(loss):
      self.optimizer.zero_grad()  # drop this minibatch
  else:
      ...  # reduce_parameters, clip_grad_norm, optimizer.step
  ```
- Clean NaN gradients before all_reduce (ppo.py):
  ```python
  all_grads = torch.nan_to_num(all_grads, nan=0.0, posinf=0.0, neginf=0.0)
  torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
  ```
- Fix the `num_waves` type (mjlab/terrains/config.py): `num_waves=4.0`