Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 58 additions & 15 deletions ignite/metrics/regression/spearman_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,58 @@
from ignite.metrics.regression._base import _check_output_shapes, _check_output_types


def _get_ranks(x: Tensor) -> Tensor:
"""Calculates ranks with average method for ties natively in PyTorch."""
n = x.size(0)
# Get sorted indices and the inverse mapping
sorter = torch.argsort(x)
inv_sorter = torch.empty(n, dtype=torch.long, device=x.device)
inv_sorter[sorter] = torch.arange(n, device=x.device)

x_sorted = x[sorter]
# Find ties
obs = torch.cat([torch.tensor([True], device=x.device), x_sorted[1:] != x_sorted[:-1]])
dense_ranks = torch.cumsum(obs, dim=0)

# Calculate average ranks for ties
count = torch.cat([torch.nonzero(obs).flatten(), torch.tensor([n], device=x.device)])
repetitions = count[1:] - count[:-1]

# Use cumsum of repetitions to find the range of ranks for each unique value
right = torch.cumsum(repetitions, dim=0)
left = right - repetitions + 1
avg_ranks = (left + right).double() / 2.0

# Map back to original order
return avg_ranks[dense_ranks - 1][inv_sorter]


def _spearman_r(predictions: Tensor, targets: Tensor) -> float:
    """Compute the Spearman rank correlation coefficient between two tensors.

    Both inputs are flattened; Spearman's rho is then the Pearson correlation
    of their average-method ranks.

    Args:
        predictions: predicted values; flattened before ranking.
        targets: ground-truth values; flattened before ranking.

    Returns:
        The correlation coefficient as a Python float. ``nan`` is returned when
        either input contains a NaN, or when either input is constant (zero
        rank variance makes the correlation undefined) — matching
        ``scipy.stats.spearmanr`` behavior.
    """
    preds_flat = predictions.flatten()
    targets_flat = targets.flatten()

    # Ranks are undefined in the presence of NaN; propagate NaN like scipy does.
    if torch.isnan(preds_flat).any() or torch.isnan(targets_flat).any():
        return float("nan")

    # Native PyTorch ranking (average method for ties).
    r_preds = _get_ranks(preds_flat)
    r_targets = _get_ranks(targets_flat)

    # Correlation of ranks (Pearson correlation).
    mu_x = torch.mean(r_preds)
    mu_y = torch.mean(r_targets)

    diff_x = r_preds - mu_x
    diff_y = r_targets - mu_y

    norm_x = torch.norm(diff_x, 2)
    norm_y = torch.norm(diff_y, 2)

    # Constant input -> zero rank variance -> division by zero; report NaN.
    if norm_x == 0 or norm_y == 0:
        return float("nan")

    r = torch.sum(diff_x * diff_y) / (norm_x * norm_y)
    return r.item()


class SpearmanRankCorrelation(EpochMetric):
Expand All @@ -30,8 +75,7 @@ class SpearmanRankCorrelation(EpochMetric):
where :math:`A` and :math:`P` are the ground truth and predicted value,
and :math:`R[X]` is the ranking value of :math:`X`.

The computation of this metric is implemented with
`scipy.stats.spearmanr <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html>`_.
The computation of this metric is implemented natively in PyTorch.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
Expand Down Expand Up @@ -76,6 +120,10 @@ class SpearmanRankCorrelation(EpochMetric):
0.7142857142857143

.. versionadded:: 0.5.2

.. versionchanged:: 0.5.5
Implementation updated to use a native PyTorch computation for rank calculation and
correlation, removing the dependency on SciPy.
"""

def __init__(
Expand All @@ -85,11 +133,6 @@ def __init__(
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
) -> None:
try:
from scipy.stats import spearmanr # noqa: F401
except ImportError:
raise ModuleNotFoundError("This module requires scipy to be installed.")

super().__init__(_spearman_r, output_transform, check_compute_fn, device, skip_unrolling)

def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
Expand All @@ -99,10 +142,10 @@ def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
if y.ndim == 1:
y = y.unsqueeze(1)

_check_output_shapes(output)
_check_output_types(output)
_check_output_shapes((y_pred, y))
_check_output_types((y_pred, y))

super().update(output)
super().update((y_pred, y))

def compute(self) -> float:
if len(self._predictions) < 1 or len(self._targets) < 1:
Expand Down
39 changes: 36 additions & 3 deletions tests/ignite/metrics/regression/test_spearman_correlation.py
Comment thread
vfdev-5 marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics.regression import SpearmanRankCorrelation
from ignite.metrics.regression.spearman_correlation import _get_ranks


def test_zero_sample():
Expand Down Expand Up @@ -67,8 +68,10 @@ def test_spearman_correlation(available_device):
all_preds.append(x)
all_targets.append(ground_truth)

pred_cat = torch.cat(all_preds).numpy()
target_cat = torch.cat(all_targets).numpy()
pred_cat = torch.cat(all_preds).cpu().numpy()
target_cat = torch.cat(all_targets).cpu().numpy()

# Convert only for computing the expected value
expected = spearmanr(pred_cat, target_cat).statistic
assert m.compute() == pytest.approx(expected, rel=1e-4)

Expand Down Expand Up @@ -105,7 +108,7 @@ def update_fn(engine: Engine, batch):
corr = engine.run(data, max_epochs=1).metrics["spearman_corr"]

# Convert only for computing the expected value
expected = spearmanr(y_pred.numpy().ravel(), y.numpy().ravel()).statistic
expected = spearmanr(y_pred.cpu().numpy().ravel(), y.cpu().numpy().ravel()).statistic

assert pytest.approx(expected, rel=2e-4) == corr

Expand Down Expand Up @@ -182,3 +185,33 @@ def test_integration(self, n_epochs: int):
np_ans = spearmanr(np_y_pred, np_y).statistic

assert pytest.approx(np_ans, rel=tol) == res


def test_nan_inputs():
    # A NaN anywhere in the predictions must propagate to a NaN correlation.
    m = SpearmanRankCorrelation()

    preds = torch.tensor([1.0, float("nan"), 3.0])
    targets = torch.tensor([1.0, 2.0, 3.0])
    m.update((preds, targets))

    result = m.compute()
    assert torch.isnan(torch.tensor(result))


def test_constant_inputs():
    # Constant predictions have zero rank variance, so the correlation is undefined (NaN).
    m = SpearmanRankCorrelation()

    preds = torch.full((4,), 5.0)
    targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
    m.update((preds, targets))

    result = m.compute()
    assert torch.isnan(torch.tensor(result))


def test_average_rank_logic():
    # The tied 20.0s occupy sorted positions 2 and 3, so both get the mean rank 2.5.
    values = torch.tensor([10.0, 20.0, 20.0, 30.0])
    expected = torch.tensor([1.0, 2.5, 2.5, 4.0], dtype=torch.double)

    assert torch.allclose(_get_ranks(values), expected)
Loading