
Multi-GPU training crashes on Rough terrain task (NaN in PPO) #9

@diaskabdualiev

Description

Training Mjlab-Velocity-Rough-Unitree-Go2 crashes with RuntimeError: normal expects all elements of std >= 0.0 within the first few iterations. The issue occurs both when training from scratch and when fine-tuning from a Flat model checkpoint. The Mjlab-Velocity-Flat-Unitree-Go2 task trains successfully with the same setup.

Environment

  • 10× NVIDIA RTX A4000 (multi-GPU via torchrunx)
  • 56 CPU cores
  • --env.scene.num-envs=8192
  • Ubuntu, Python 3.11, PyTorch with CUDA

Steps to Reproduce

python scripts/train.py Mjlab-Velocity-Rough-Unitree-Go2 \
  --gpu-ids 0 1 2 3 4 5 6 7 8 9 \
  --env.scene.num-envs=8192

Error

RuntimeError: normal expects all elements of std >= 0.0
  File "mjlab/rsl_rl/modules/actor_critic.py", line 146, in act
    return self.distribution.sample()
  File "mjlab/rsl_rl/algorithms/ppo.py", line 249, in update
    self.policy.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])

Crashes at iteration 1-2 when training from scratch. Metrics before crash:

Mean reward: -5.07
Mean episode length: 20.55
Episode_Termination/fell_over: 26.75
Episode_Termination/illegal_contact: 375.67
Curriculum/terrain_levels: 2.13 → 0.00

Root Cause Analysis

Four issues contribute to the crash:

1. num_waves type mismatch in mjlab/terrains/config.py

# Line 64 — causes tyro type error on launch
SubTerrainSceneCfg(num_waves=4, ...)  # int, but tyro expects float

Fix: num_waves=4 → num_waves=4.0

2. No clamp on action noise std (actor_critic.py)

When noise_std_type="scalar" (default), self.std is a raw nn.Parameter with no lower bound. The optimizer can push it negative or to NaN:

# Line 89: no constraint
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))

# Line 134: used directly
std = self.std.expand_as(mean)

# Line 140: Normal crashes if std < 0 or NaN
self.distribution = Normal(mean, std)

Same issue exists in actor_critic_recurrent.py.
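The failure mode can be reproduced without torch at all: gradient descent on an unconstrained scalar has nothing stopping it from crossing zero. A toy plain-Python sketch (the init value, learning rate, and constant gradient are hypothetical, chosen only to show the sign flip):

```python
# Toy SGD on an unconstrained "std" parameter, mimicking the raw
# nn.Parameter in actor_critic.py. If the loss keeps pushing std down,
# nothing stops it from going negative -- the state Normal() rejects.
std = 0.1   # stand-in for init_noise_std
lr = 0.05
for _ in range(5):
    grad = 1.0          # pretend the loss gradient keeps shrinking std
    std -= lr * grad

print(std)  # negative: torch Normal(mean, std).sample() would now raise
```

In the real module the same drift happens across optimizer steps; once self.std has any negative or NaN entry, the next call to Normal(mean, std) fails with exactly the error in this report.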

3. Ratio overflow in PPO surrogate loss (ppo.py)

# Line 297: exp() can overflow to Inf → NaN propagates to loss
ratio = torch.exp(actions_log_prob_batch - old_actions_log_prob_batch)

On Rough terrain the robot falls immediately (episode_length ~20), producing extreme penalties. During PPO update with 5 epochs × 4 minibatches = 20 optimizer steps, the policy diverges enough that log_prob_new - log_prob_old overflows exp().
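The overflow threshold is easy to see with plain floats: float64 exp() overflows just above ~709.8 (CPython raises OverflowError; torch.exp instead silently returns Inf, which then turns into NaN in the loss). Clamping the log-ratio to ±20 keeps the ratio finite, bounded by e^20 ≈ 4.85e8. A minimal sketch with a hypothetical diverged log-ratio:

```python
import math

# Unclamped: a badly diverged policy can produce huge log-ratios.
log_ratio = 800.0
try:
    ratio = math.exp(log_ratio)   # float64 overflows above ~709.8
except OverflowError:
    ratio = float("inf")          # torch.exp would return inf here

# Clamped to [-20, 20]: the ratio stays finite and the surrogate
# loss stays well-defined.
clamped = max(-20.0, min(20.0, log_ratio))
safe_ratio = math.exp(clamped)    # e^20, finite

print(ratio, safe_ratio)
```

Since PPO clips the ratio to 1 ± clip_param anyway, capping the log-ratio at ±20 does not change the loss for any policy update that PPO would actually accept.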

4. NaN gradient propagation in multi-GPU reduction (ppo.py)

# Line 453: all_reduce(SUM) — one GPU with NaN poisons all others
torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)

If any single GPU produces NaN gradients (from overflow), all_reduce propagates NaN to all GPUs, corrupting all parameters irreversibly.
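The poisoning follows directly from IEEE 754 arithmetic: NaN plus anything is NaN, so a SUM reduction spreads one bad rank's gradients to every rank. A stdlib sketch simulating the 10-GPU reduction (values are illustrative):

```python
import math

# Simulate gradient summation across 10 GPUs: one NaN contaminates all.
grads = [0.01] * 9 + [float("nan")]   # rank 9 overflowed
reduced = sum(grads)                  # all_reduce(SUM) analogue
print(math.isnan(reduced))            # True: every rank now sees NaN

# With nan_to_num applied before the reduction, the bad rank
# contributes zero and the other nine ranks' gradients survive.
cleaned = [0.0 if math.isnan(g) else g for g in grads]
print(sum(cleaned))                   # ~0.09, finite
```

Zeroing a rank's gradients for one step is a far milder failure than corrupting the shared parameters, which is unrecoverable without restarting from a checkpoint.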

Proposed Fix

PR: #8

  1. Clamp std + handle NaN in update_distribution() (actor_critic.py, actor_critic_recurrent.py):

    std = torch.clamp(std, min=1e-6)
    std = torch.nan_to_num(std, nan=1.0, posinf=1.0, neginf=1e-6)
    mean = torch.nan_to_num(mean, nan=0.0)
    self.distribution = Normal(mean, std)
  2. Clamp log-ratio before exp() (ppo.py):

    log_ratio = actions_log_prob_batch - old_actions_log_prob_batch
    log_ratio = torch.clamp(log_ratio, -20.0, 20.0)
    ratio = torch.exp(log_ratio)
  3. Skip optimizer step on NaN loss (ppo.py):

    if not torch.isfinite(loss):
        self.optimizer.zero_grad()
    else:
        # reduce_parameters, clip_grad_norm, optimizer.step
  4. Clean NaN gradients before all_reduce (ppo.py):

    all_grads = torch.nan_to_num(all_grads, nan=0.0, posinf=0.0, neginf=0.0)
    torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM)
  5. Fix num_waves type (mjlab/terrains/config.py):

    num_waves=4.0
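The guard in fix 3 reduces to a small reusable pattern; a plain-Python sketch (the helper name safe_step is hypothetical, not from the PR):

```python
import math

def safe_step(loss, apply_update):
    """Skip the optimizer step when the loss is non-finite (fix 3 pattern)."""
    if not math.isfinite(loss):
        return False      # analogue of optimizer.zero_grad() and bail out
    apply_update()        # reduce_parameters, clip_grad_norm, optimizer.step
    return True

applied = []
safe_step(float("nan"), lambda: applied.append("step"))  # skipped
safe_step(0.42, lambda: applied.append("step"))          # applied
print(applied)  # only the finite loss triggered an update
```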
