Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions sdmetrics/column_pairs/statistical/correlation_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,33 @@ def _validate_data_not_constant(cls, real_data, synthetic_data):
cls._raise_constant_data_error(synthetic_columns, 'synthetic data')

@classmethod
def compute_breakdown(cls, real_data, synthetic_data, coefficient='Pearson'):
def compute_breakdown(
cls, real_data, synthetic_data, coefficient='Pearson', real_correlation_threshold=0
):
"""Compare the breakdown of correlation similarity of two continuous columns.

Args:
real_data (Union[numpy.ndarray, pandas.Series]):
The values from the real dataset.
synthetic_data (Union[numpy.ndarray, pandas.Series]):
The values from the synthetic dataset.
coefficient (str):
The correlation coefficient to use. Either 'Pearson' or 'Spearman'.
Default is 'Pearson'.
real_correlation_threshold (float):
The minimum absolute correlation value for the real data to be considered
correlated. Default is 0.

Returns:
dict:
A dict containing the score, and the real and synthetic metric values.
"""
if (
not isinstance(real_correlation_threshold, (int, float))
or not 0 <= real_correlation_threshold <= 1
):
raise ValueError('real_correlation_threshold must be a number between 0 and 1.')

real_data = real_data.copy()
synthetic_data = synthetic_data.copy()

Expand Down Expand Up @@ -101,10 +115,13 @@ def compute_breakdown(cls, real_data, synthetic_data, coefficient='Pearson'):
)

correlation_real, _ = correlation_fn(real_data[column1], real_data[column2])
if np.abs(correlation_real) <= real_correlation_threshold:
return {'score': np.nan, 'real': correlation_real, 'synthetic': np.nan}

correlation_synthetic, _ = correlation_fn(synthetic_data[column1], synthetic_data[column2])

if np.isnan(correlation_real) or np.isnan(correlation_synthetic):
return {'score': np.nan}
return {'score': np.nan, 'real': correlation_real, 'synthetic': correlation_synthetic}

return {
'score': 1 - abs(correlation_real - correlation_synthetic) / 2,
Expand All @@ -113,20 +130,30 @@ def compute_breakdown(cls, real_data, synthetic_data, coefficient='Pearson'):
}

@classmethod
def compute(
    cls, real_data, synthetic_data, coefficient='Pearson', real_correlation_threshold=0
):
    """Compare the correlation similarity of two continuous columns.

    This is a thin wrapper around ``compute_breakdown`` that returns only the
    ``'score'`` entry of the breakdown.

    Args:
        real_data (Union[numpy.ndarray, pandas.Series]):
            The values from the real dataset.
        synthetic_data (Union[numpy.ndarray, pandas.Series]):
            The values from the synthetic dataset.
        coefficient (str):
            The correlation coefficient to use. Either 'Pearson' or 'Spearman'.
            Default is 'Pearson'.
        real_correlation_threshold (float):
            The minimum absolute correlation value for the real data to be considered
            correlated. ``compute_breakdown`` reports a NaN score when the absolute
            real correlation does not exceed this value. Default is 0.

    Returns:
        float:
            The correlation similarity of the two columns.
    """
    # Forward positionally, in the same order as the ``compute_breakdown`` signature.
    breakdown = cls.compute_breakdown(
        real_data, synthetic_data, coefficient, real_correlation_threshold
    )
    return breakdown['score']

@classmethod
def normalize(cls, raw_score):
Expand Down
93 changes: 86 additions & 7 deletions tests/unit/column_pairs/statistical/test_correlation_similarity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import re
from datetime import datetime
from unittest.mock import Mock, call, patch

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -115,7 +117,78 @@ def test_compute_breakdown_constant_input(self):
with pytest.raises(ConstantInputError, match=error_msg):
metric.compute_breakdown(real_data, synthetic_data, coefficient='Pearson')

def test_compute(self):
@pytest.mark.parametrize(
    'real_correlation_threshold, score',
    [
        (0, 0.9008941765855183),
        (0.35, 0.9008941765855183),
        (0.498212, np.nan),
        (0.75, np.nan),
    ],
)
def test_compute_breakdown_with_real_correlation_threshold(
    self, real_correlation_threshold, score
):
    """Test ``compute_breakdown`` for several ``real_correlation_threshold`` values.

    The real columns have a Pearson correlation of about 0.498212 and the
    synthetic columns one of about 0.3. The expected score is NaN for the
    thresholds that the real correlation does not exceed, and the usual
    similarity score otherwise.
    """
    # Setup
    shared_col1 = [1.0, 2.0, 3.0, 4.0]
    real_data = pd.DataFrame({
        'col1': shared_col1,
        'col2': [0.2, -1.0895, -0.6425, 1.5365],
    })
    synthetic_data = pd.DataFrame({
        'col1': shared_col1,
        'col2': [0.616536, -1.216536, -0.916536, 1.516536],
    })
    metric = CorrelationSimilarity()

    # Run
    breakdown = metric.compute_breakdown(
        real_data,
        synthetic_data,
        coefficient='Pearson',
        real_correlation_threshold=real_correlation_threshold,
    )

    # Assert
    # ``equal_nan=True`` lets one comparison cover both the numeric and the NaN cases.
    assert np.isclose(breakdown['score'], score, atol=1e-6, equal_nan=True)

def test_compute_breakdown_invalid_real_correlation_threshold(self):
    """Test that an invalid ``real_correlation_threshold`` raises a ``ValueError``.

    Both an out-of-range number (-0.1) and a non-numeric value (None) must be rejected
    with the same error message.
    """
    # Setup
    metric = CorrelationSimilarity()
    real_data = pd.DataFrame({'col1': [1.0, 2.0, 3.0], 'col2': [2.0, 3.0, 4.0]})
    synthetic_data = pd.DataFrame({'col1': [0.9, 1.8, 3.1], 'col2': [2, 3, 4]})
    expected_error = re.escape('real_correlation_threshold must be a number between 0 and 1.')
    invalid_thresholds = (-0.1, None)

    # Run and Assert
    for invalid_threshold in invalid_thresholds:
        with pytest.raises(ValueError, match=expected_error):
            metric.compute_breakdown(
                real_data,
                synthetic_data,
                coefficient='Pearson',
                real_correlation_threshold=invalid_threshold,
            )

@patch(
'sdmetrics.column_pairs.statistical.correlation_similarity.CorrelationSimilarity.compute_breakdown'
)
def test_compute(self, compute_breakdown_mock):
"""Test the ``compute`` method.

Expect that the selected coefficient is used to compare the real and synthetic data.
Expand All @@ -134,17 +207,23 @@ def test_compute(self):
test_score = 0.2
score_breakdown = {'score': test_score}
metric = CorrelationSimilarity()
compute_breakdown_mock.return_value = score_breakdown
real_data = Mock()
synthetic_data = Mock()

# Run
with patch.object(
CorrelationSimilarity,
'compute_breakdown',
return_value=score_breakdown,
):
result = metric.compute(Mock(), Mock(), coefficient='Pearson')
result = metric.compute(
real_data, synthetic_data, coefficient='Pearson', real_correlation_threshold=0.6
)

# Assert
assert result == test_score
compute_breakdown_mock.assert_called_once_with(
real_data,
synthetic_data,
'Pearson',
0.6,
)

@patch('sdmetrics.column_pairs.statistical.correlation_similarity.ColumnPairsMetric.normalize')
def test_normalize(self, normalize_mock):
Expand Down