From 7f309a36db8e9bd08e22ad2b01ab8a3e9519f51b Mon Sep 17 00:00:00 2001 From: Kelian Massa Date: Fri, 14 Mar 2025 14:21:48 +0200 Subject: [PATCH] Fix: fallback to full range if IQR is zero in RobustScaler --- gluon_utils/scalers/robust_scaler.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/gluon_utils/scalers/robust_scaler.py b/gluon_utils/scalers/robust_scaler.py index 15cde41..a6617d8 100644 --- a/gluon_utils/scalers/robust_scaler.py +++ b/gluon_utils/scalers/robust_scaler.py @@ -22,7 +22,7 @@ class RobustScaler(Scaler): """ Computes a scaling factor by removing the median and scaling by the - interquartile range (IQR). + interquartile range (IQR) or, if IQR is 0, by the full range. Parameters ---------- @@ -61,9 +61,18 @@ def __call__( q3 = torch.nanquantile(observed_data, 0.75, dim=self.dim, keepdim=True) iqr = q3 - q1 + # Compute full range as a fallback if IQR is 0 + data_min = torch.nanquantile(observed_data, 0, dim=self.dim, keepdim=True) + data_max = torch.nanquantile(observed_data, 1, dim=self.dim, keepdim=True) + full_range = data_max - data_min + # if observed data is all zeros, nanmedian returns nan loc = torch.where(torch.isnan(med), torch.zeros_like(med), med) - scale = torch.where(torch.isnan(iqr), torch.ones_like(iqr), iqr) + scale = torch.where( + torch.isnan(iqr), + torch.ones_like(iqr), + torch.where(iqr == 0, full_range, iqr) + ) scale = torch.maximum(scale, torch.full_like(iqr, self.minimum_scale)) scaled_data = (data - loc) / scale