|
| 1 | +# Copyright (c) 2022-2025, The Isaac Lab Project Developers. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# SPDX-License-Identifier: BSD-3-Clause |
| 5 | + |
| 6 | +"""Action utility functions for safe and smooth robot control.""" |
| 7 | + |
| 8 | +from __future__ import annotations |
| 9 | +import torch |
| 10 | +from typing import Optional |
| 11 | + |
| 12 | + |
class ActionSmoother:
    """Smooth actions with an exponential moving average (EMA).

    Every call to :meth:`smooth` blends the incoming actions with the output of
    the previous call::

        smoothed = smoothing_factor * actions + (1 - smoothing_factor) * prev_actions

    A ``smoothing_factor`` of 1 passes the actions through unchanged, while 0
    keeps returning the previous output.

    Note:
        A non-finite value (NaN/Inf) in ``actions`` is written into the stored
        history and therefore keeps affecting later outputs until
        :meth:`reset` is called for the affected environments.
    """

    def __init__(self, action_dim: int, num_envs: int, smoothing_factor: float = 0.7, device: str = "cuda"):
        """Create the smoother with an all-zero action history.

        Args:
            action_dim: Size of the second dimension of the history buffer.
            num_envs: Size of the first dimension of the history buffer.
            smoothing_factor: Weight given to the new actions; must lie in ``[0, 1]``.
            device: Torch device on which the history buffer is allocated.

        Raises:
            ValueError: If ``smoothing_factor`` is outside ``[0, 1]`` (including NaN).
        """
        # Outside [0, 1] the blend extrapolates instead of averaging; the
        # chained comparison is also False for NaN, so NaN is rejected too.
        if not 0.0 <= smoothing_factor <= 1.0:
            raise ValueError(f"smoothing_factor must be in [0, 1], got {smoothing_factor}")
        self.smoothing_factor = smoothing_factor
        self.prev_actions = torch.zeros((num_envs, action_dim), device=device)

    def smooth(self, actions: torch.Tensor) -> torch.Tensor:
        """Blend ``actions`` with the stored history and remember the result.

        Args:
            actions: New actions; must broadcast against the stored
                ``(num_envs, action_dim)`` history.

        Returns:
            The smoothed actions. Gradients, if any, still flow to ``actions``.
        """
        smoothed = self.smoothing_factor * actions + (1 - self.smoothing_factor) * self.prev_actions
        # Keep a detached copy as history: without detach(), actions that
        # require grad would chain every step's autograd graph onto the
        # previous one and keep all of them alive until the next reset.
        self.prev_actions = smoothed.detach().clone()
        return smoothed

    def reset(self, env_ids: Optional[torch.Tensor] = None):
        """Zero the stored history.

        Args:
            env_ids: Indices along the first dimension of the history to reset.
                When ``None``, the whole history is zeroed.
        """
        if env_ids is None:
            self.prev_actions.zero_()
        else:
            self.prev_actions[env_ids] = 0.0
| 30 | + |
| 31 | + |
class ActionClipper:
    """Clip actions to fixed bounds and optionally limit their rate of change.

    :meth:`clip` first clamps the actions to ``[action_low, action_high]``. If
    ``max_delta`` is set, the per-element change relative to the previously
    returned actions is then limited to ``[-max_delta, max_delta]`` and the
    result is clamped to the bounds once more.

    Note:
        When ``max_delta`` is set, a non-finite value (NaN/Inf) in ``actions``
        is written into the stored history and keeps affecting later outputs
        until :meth:`reset` is called for the affected environments.
    """

    def __init__(self, action_dim: int, num_envs: int, action_low: float = -1.0,
                 action_high: float = 1.0, max_delta: Optional[float] = None, device: str = "cuda"):
        """Create the clipper with an all-zero action history.

        Args:
            action_dim: Size of the second dimension of the history buffer.
            num_envs: Size of the first dimension of the history buffer.
            action_low: Lower bound applied to every action element.
            action_high: Upper bound applied to every action element.
            max_delta: Largest allowed absolute change per call, or ``None``
                to disable rate limiting.
            device: Torch device on which the history buffer is allocated.

        Raises:
            ValueError: If ``action_low > action_high`` or ``max_delta`` is
                negative (NaN is rejected in both cases).
        """
        # torch.clamp with min > max sets every element to max, so a
        # mis-ordered pair would silently produce constant output.
        if not action_low <= action_high:
            raise ValueError(f"action_low ({action_low}) must not exceed action_high ({action_high})")
        # A negative max_delta would invert the delta clamp in the same way.
        if max_delta is not None and not max_delta >= 0:
            raise ValueError(f"max_delta must be non-negative, got {max_delta}")
        self.action_low = action_low
        self.action_high = action_high
        self.max_delta = max_delta
        self.prev_actions = torch.zeros((num_envs, action_dim), device=device)

    def clip(self, actions: torch.Tensor) -> torch.Tensor:
        """Clamp ``actions`` to the bounds, rate-limit them, and remember the result.

        Args:
            actions: New actions; must broadcast against the stored
                ``(num_envs, action_dim)`` history when rate limiting is enabled.

        Returns:
            The clipped actions.
        """
        clipped = torch.clamp(actions, self.action_low, self.action_high)

        if self.max_delta is not None:
            delta = clipped - self.prev_actions
            delta = torch.clamp(delta, -self.max_delta, self.max_delta)
            clipped = self.prev_actions + delta
            # Re-apply the bounds: the history may lie outside them (e.g. the
            # zero history when 0 is not inside [action_low, action_high]).
            clipped = torch.clamp(clipped, self.action_low, self.action_high)

        # Keep a detached copy as history: without detach(), actions that
        # require grad would chain every step's autograd graph onto the
        # previous one and keep all of them alive until the next reset.
        self.prev_actions = clipped.detach().clone()
        return clipped

    def reset(self, env_ids: Optional[torch.Tensor] = None):
        """Zero the stored history.

        Args:
            env_ids: Indices along the first dimension of the history to reset.
                When ``None``, the whole history is zeroed.
        """
        if env_ids is None:
            self.prev_actions.zero_()
        else:
            self.prev_actions[env_ids] = 0.0
| 59 | + |
| 60 | + |
def validate_actions(actions: torch.Tensor, action_low: float = -1.0, action_high: float = 1.0) -> bool:
    """Report whether every action is finite and inside the closed bounds.

    Args:
        actions: Tensor of actions to inspect.
        action_low: Smallest value an element may take.
        action_high: Largest value an element may take.

    Returns:
        ``False`` if any element is NaN/Inf or lies outside
        ``[action_low, action_high]``; ``True`` otherwise.
    """
    # Reject NaN/Inf before looking at the range at all.
    all_finite = bool(torch.isfinite(actions).all())
    if not all_finite:
        return False

    # One combined mask of elements that violate either bound.
    too_low = actions < action_low
    too_high = actions > action_high
    any_out_of_range = bool((too_low | too_high).any())
    return not any_out_of_range
0 commit comments