Skip to content

Commit 0a4b46a

Browse files
authored
Add a threshold to the CorrelationSimilarity metric (#822)
1 parent 88c74d4 commit 0a4b46a

File tree

2 files changed

+117
-11
lines changed

2 files changed

+117
-11
lines changed

sdmetrics/column_pairs/statistical/correlation_similarity.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,33 @@ def _validate_data_not_constant(cls, real_data, synthetic_data):
5555
cls._raise_constant_data_error(synthetic_columns, 'synthetic data')
5656

5757
@classmethod
58-
def compute_breakdown(cls, real_data, synthetic_data, coefficient='Pearson'):
58+
def compute_breakdown(
59+
cls, real_data, synthetic_data, coefficient='Pearson', real_correlation_threshold=0
60+
):
5961
"""Compare the breakdown of correlation similarity of two continuous columns.
6062
6163
Args:
6264
real_data (Union[numpy.ndarray, pandas.Series]):
6365
The values from the real dataset.
6466
synthetic_data (Union[numpy.ndarray, pandas.Series]):
6567
The values from the synthetic dataset.
68+
coefficient (str):
69+
The correlation coefficient to use. Either 'Pearson' or 'Spearman'.
70+
Default is 'Pearson'.
71+
real_correlation_threshold (float):
72+
The minimum absolute correlation value for the real data to be considered
73+
correlated. A pair whose absolute real correlation is at or below this value gets a NaN score. Default is 0.
6674
6775
Returns:
6876
dict:
6977
A dict containing the score, and the real and synthetic metric values.
7078
"""
79+
if (
80+
not isinstance(real_correlation_threshold, (int, float))
81+
or not 0 <= real_correlation_threshold <= 1
82+
):
83+
raise ValueError('real_correlation_threshold must be a number between 0 and 1.')
84+
7185
real_data = real_data.copy()
7286
synthetic_data = synthetic_data.copy()
7387

@@ -101,10 +115,13 @@ def compute_breakdown(cls, real_data, synthetic_data, coefficient='Pearson'):
101115
)
102116

103117
correlation_real, _ = correlation_fn(real_data[column1], real_data[column2])
118+
if np.abs(correlation_real) <= real_correlation_threshold:
119+
return {'score': np.nan, 'real': correlation_real, 'synthetic': np.nan}
120+
104121
correlation_synthetic, _ = correlation_fn(synthetic_data[column1], synthetic_data[column2])
105122

106123
if np.isnan(correlation_real) or np.isnan(correlation_synthetic):
107-
return {'score': np.nan}
124+
return {'score': np.nan, 'real': correlation_real, 'synthetic': correlation_synthetic}
108125

109126
return {
110127
'score': 1 - abs(correlation_real - correlation_synthetic) / 2,
@@ -113,20 +130,30 @@ def compute_breakdown(cls, real_data, synthetic_data, coefficient='Pearson'):
113130
}
114131

115132
@classmethod
def compute(
    cls, real_data, synthetic_data, coefficient='Pearson', real_correlation_threshold=0
):
    """Return the correlation similarity score of two continuous columns.

    Thin wrapper around ``compute_breakdown``: every argument is forwarded
    positionally and only the ``'score'`` entry of the breakdown is returned.

    Args:
        real_data (Union[numpy.ndarray, pandas.Series]):
            The values from the real dataset.
        synthetic_data (Union[numpy.ndarray, pandas.Series]):
            The values from the synthetic dataset.
        coefficient (str):
            Which correlation coefficient to use, 'Pearson' or 'Spearman'.
            Defaults to 'Pearson'.
        real_correlation_threshold (float):
            Threshold on the absolute correlation of the real data. When the
            real correlation is at or below it, ``compute_breakdown`` reports
            a NaN score. Defaults to 0.

    Returns:
        float:
            The correlation similarity of the two columns.
    """
    # Arguments stay positional on purpose: callers/tests may assert on the
    # exact call signature used for ``compute_breakdown``.
    breakdown = cls.compute_breakdown(
        real_data, synthetic_data, coefficient, real_correlation_threshold
    )
    return breakdown['score']
130157

131158
@classmethod
132159
def normalize(cls, raw_score):

tests/unit/column_pairs/statistical/test_correlation_similarity.py

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import re
12
from datetime import datetime
23
from unittest.mock import Mock, call, patch
34

5+
import numpy as np
46
import pandas as pd
57
import pytest
68

@@ -115,7 +117,78 @@ def test_compute_breakdown_constant_input(self):
115117
with pytest.raises(ConstantInputError, match=error_msg):
116118
metric.compute_breakdown(real_data, synthetic_data, coefficient='Pearson')
117119

118-
def test_compute(self):
120+
@pytest.mark.parametrize(
    'real_correlation_threshold, score',
    [
        (0, 0.9008941765855183),
        (0.35, 0.9008941765855183),
        (0.498212, np.nan),
        (0.75, np.nan),
    ],
)
def test_compute_breakdown_with_real_correlation_threshold(
    self, real_correlation_threshold, score
):
    """Check ``compute_breakdown`` for several `real_correlation_threshold` values.

    The fixtures are built so that the real data has a correlation of roughly
    0.498212 and the synthetic data a correlation of roughly 0.3. Thresholds
    below the real correlation must yield the regular score; thresholds at or
    above it must yield NaN.
    """
    # Setup
    real_data = pd.DataFrame({
        'col1': [1.0, 2.0, 3.0, 4.0],
        'col2': [0.2, -1.0895, -0.6425, 1.5365],
    })
    synthetic_data = pd.DataFrame({
        'col1': [1.0, 2.0, 3.0, 4.0],
        'col2': [0.616536, -1.216536, -0.916536, 1.516536],
    })
    metric = CorrelationSimilarity()

    # Run
    result = metric.compute_breakdown(
        real_data,
        synthetic_data,
        coefficient='Pearson',
        real_correlation_threshold=real_correlation_threshold,
    )

    # Assert
    # NaN never compares equal to itself, so the NaN cases need their own check.
    if np.isnan(score):
        assert np.isnan(result['score'])
    else:
        assert np.isclose(result['score'], score, atol=1e-6)
162+
163+
def test_compute_breakdown_invalid_real_correlation_threshold(self):
    """Check that an invalid `real_correlation_threshold` raises a ``ValueError``.

    Two invalid inputs are exercised in turn: a negative number (-0.1) and a
    non-numeric value (None).
    """
    # Setup
    real_data = pd.DataFrame({'col1': [1.0, 2.0, 3.0], 'col2': [2.0, 3.0, 4.0]})
    synthetic_data = pd.DataFrame({'col1': [0.9, 1.8, 3.1], 'col2': [2, 3, 4]})
    expected_error = re.escape('real_correlation_threshold must be a number between 0 and 1.')
    metric = CorrelationSimilarity()

    # Run and Assert
    for invalid_threshold in (-0.1, None):
        with pytest.raises(ValueError, match=expected_error):
            metric.compute_breakdown(
                real_data,
                synthetic_data,
                coefficient='Pearson',
                real_correlation_threshold=invalid_threshold,
            )
187+
188+
@patch(
189+
'sdmetrics.column_pairs.statistical.correlation_similarity.CorrelationSimilarity.compute_breakdown'
190+
)
191+
def test_compute(self, compute_breakdown_mock):
119192
"""Test the ``compute`` method.
120193
121194
Expect that the selected coefficient is used to compare the real and synthetic data.
@@ -134,17 +207,23 @@ def test_compute(self):
134207
test_score = 0.2
135208
score_breakdown = {'score': test_score}
136209
metric = CorrelationSimilarity()
210+
compute_breakdown_mock.return_value = score_breakdown
211+
real_data = Mock()
212+
synthetic_data = Mock()
137213

138214
# Run
139-
with patch.object(
140-
CorrelationSimilarity,
141-
'compute_breakdown',
142-
return_value=score_breakdown,
143-
):
144-
result = metric.compute(Mock(), Mock(), coefficient='Pearson')
215+
result = metric.compute(
216+
real_data, synthetic_data, coefficient='Pearson', real_correlation_threshold=0.6
217+
)
145218

146219
# Assert
147220
assert result == test_score
221+
compute_breakdown_mock.assert_called_once_with(
222+
real_data,
223+
synthetic_data,
224+
'Pearson',
225+
0.6,
226+
)
148227

149228
@patch('sdmetrics.column_pairs.statistical.correlation_similarity.ColumnPairsMetric.normalize')
150229
def test_normalize(self, normalize_mock):

0 commit comments

Comments
 (0)