Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions sdmetrics/column_pairs/statistical/contingency_similarity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Contingency Similarity Metric."""

import numpy as np
import pandas as pd
from scipy.stats.contingency import association

from sdmetrics.column_pairs.base import ColumnPairsMetric
from sdmetrics.goal import Goal
Expand Down Expand Up @@ -28,7 +30,12 @@ class ContingencySimilarity(ColumnPairsMetric):

@staticmethod
def _validate_inputs(
real_data, synthetic_data, continuous_column_names, num_discrete_bins, num_rows_subsample
real_data,
synthetic_data,
continuous_column_names,
num_discrete_bins,
num_rows_subsample,
real_association_threshold,
):
for data in [real_data, synthetic_data]:
if not isinstance(data, pd.DataFrame) or len(data.columns) != 2:
Expand All @@ -53,6 +60,13 @@ def _validate_inputs(
if not isinstance(num_rows_subsample, int) or num_rows_subsample <= 0:
raise ValueError('`num_rows_subsample` must be an integer greater than zero.')

if (
not isinstance(real_association_threshold, (int, float))
or real_association_threshold < 0
or real_association_threshold > 1
):
raise ValueError('real_association_threshold must be a number between 0 and 1.')

@classmethod
def compute_breakdown(
cls,
Expand All @@ -61,6 +75,7 @@ def compute_breakdown(
continuous_column_names=None,
num_discrete_bins=10,
num_rows_subsample=None,
real_association_threshold=0,
):
"""Compute the breakdown of this metric."""
cls._validate_inputs(
Expand All @@ -69,6 +84,7 @@ def compute_breakdown(
continuous_column_names,
num_discrete_bins,
num_rows_subsample,
real_association_threshold,
)
columns = real_data.columns[:2]

Expand All @@ -84,7 +100,14 @@ def compute_breakdown(
real[column], synthetic[column], num_discrete_bins=num_discrete_bins
)

contingency_real = real.groupby(list(columns), dropna=False).size() / len(real)
contingency_real_counts = real.groupby(list(columns), dropna=False).size()
if real_association_threshold > 0:
contingency_2d = contingency_real_counts.unstack(fill_value=0) # noqa: PD010
real_cramer = association(contingency_2d.values, method='cramer')
if real_cramer <= real_association_threshold:
return {'score': np.nan}

contingency_real = contingency_real_counts / len(real)
contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len(
synthetic
)
Expand All @@ -103,6 +126,7 @@ def compute(
continuous_column_names=None,
num_discrete_bins=10,
num_rows_subsample=None,
real_association_threshold=0,
):
"""Compare the contingency similarity of two discrete columns.

Expand All @@ -120,17 +144,23 @@ def compute(
num_rows_subsample (int, optional):
The number of rows to subsample from the real and synthetic data before computing
the metric. Defaults to ``None``.
real_association_threshold (float, optional):
The minimum Cramer's V association score required in the real data for the
metric to be computed. If the real data's association is below this threshold,
the metric returns NaN. Defaults to 0 (no threshold).

Returns:
float:
The contingency similarity of the two columns.
The contingency similarity of the two columns, or NaN if the real data's
association is below the threshold.
"""
return cls.compute_breakdown(
real_data,
synthetic_data,
continuous_column_names,
num_discrete_bins,
num_rows_subsample,
real_association_threshold,
)['score']

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test__validate_inputs(self):
continuous_column_names=None,
num_discrete_bins=10,
num_rows_subsample=3,
real_association_threshold=0,
)
expected_bad_data = re.escape('The data must be a pandas DataFrame with two columns.')
with pytest.raises(ValueError, match=expected_bad_data):
Expand All @@ -36,6 +37,7 @@ def test__validate_inputs(self):
continuous_column_names=None,
num_discrete_bins=10,
num_rows_subsample=3,
real_association_threshold=0,
)

expected_mismatch_columns_error = re.escape(
Expand All @@ -48,6 +50,7 @@ def test__validate_inputs(self):
continuous_column_names=None,
num_discrete_bins=10,
num_rows_subsample=3,
real_association_threshold=0,
)

expected_bad_continous_column_error = re.escape(
Expand All @@ -60,6 +63,7 @@ def test__validate_inputs(self):
continuous_column_names=bad_continous_columns,
num_discrete_bins=10,
num_rows_subsample=3,
real_association_threshold=0,
)

expected_bad_num_discrete_bins_error = re.escape(
Expand All @@ -72,6 +76,7 @@ def test__validate_inputs(self):
continuous_column_names=['col1'],
num_discrete_bins=bad_num_discrete_bins,
num_rows_subsample=3,
real_association_threshold=0,
)
expected_bad_num_rows_subsample_error = re.escape(
'`num_rows_subsample` must be an integer greater than zero.'
Expand All @@ -83,6 +88,20 @@ def test__validate_inputs(self):
continuous_column_names=['col1'],
num_discrete_bins=10,
num_rows_subsample=bad_num_rows_subsample,
real_association_threshold=0,
)

expected_bad_threshold_error = re.escape(
'real_association_threshold must be a number between 0 and 1.'
)
with pytest.raises(ValueError, match=expected_bad_threshold_error):
ContingencySimilarity._validate_inputs(
real_data=real_data,
synthetic_data=synthetic_data,
continuous_column_names=['col1'],
num_discrete_bins=10,
num_rows_subsample=3,
real_association_threshold=-0.1,
)

@patch(
Expand All @@ -99,7 +118,7 @@ def test_compute_mock(self, compute_breakdown_mock):
score = ContingencySimilarity.compute(real_data, synthetic_data)

# Assert
compute_breakdown_mock.assert_called_once_with(real_data, synthetic_data, None, 10, None)
compute_breakdown_mock.assert_called_once_with(real_data, synthetic_data, None, 10, None, 0)
assert score == 0.25

@patch(
Expand Down Expand Up @@ -134,6 +153,7 @@ def test_compute_breakdown(self, validate_inputs_mock):
None,
10,
None,
0,
)
assert result == {'score': expected_score}

Expand Down Expand Up @@ -218,3 +238,47 @@ def test_no_runtime_warning_raised(self):
ContingencySimilarity.compute(
real_data=real_data[['A', 'B']], synthetic_data=synthetic_data[['A', 'B']]
)

def test_real_association_threshold_returns_nan(self):
    """Test that NaN is returned when real association is below threshold.

    The original version drew unseeded random categories, which made the test
    flaky: two independent random columns can, by chance, show a sample
    Cramer's V above 0.3. Building an exactly uniform joint distribution over
    the 3x3 category product makes Cramer's V exactly 0, deterministically
    below any positive threshold.
    """
    # Setup: every (col1, col2) pair appears the same number of times, so the
    # contingency table is perfectly uniform and the association is exactly 0.
    pairs = [(c1, c2) for c1 in ['A', 'B', 'C'] for c2 in ['X', 'Y', 'Z']]
    real_data = pd.DataFrame(pairs * 12, columns=['col1', 'col2'])
    synthetic_data = pd.DataFrame(pairs * 12, columns=['col1', 'col2'])

    # Run
    result = ContingencySimilarity.compute(
        real_data=real_data,
        synthetic_data=synthetic_data,
        real_association_threshold=0.3,
    )

    # Assert
    assert np.isnan(result)

def test_real_association_threshold_computes_normally(self):
    """Test that metric computes normally when real association exceeds threshold."""
    # Setup: strongly associated columns in both datasets, so the real data's
    # Cramer's V is well above the 0.3 threshold and the score is computed.
    half = 50
    real_col2 = ['X'] * 48 + ['Y'] * 2 + ['Y'] * 48 + ['X'] * 2
    synth_col2 = ['X'] * 45 + ['Y'] * 5 + ['Y'] * 45 + ['X'] * 5
    real_data = pd.DataFrame({
        'col1': ['A'] * half + ['B'] * half,
        'col2': real_col2,
    })
    synthetic_data = pd.DataFrame({
        'col1': ['A'] * half + ['B'] * half,
        'col2': synth_col2,
    })

    # Run
    result = ContingencySimilarity.compute(
        real_data=real_data,
        synthetic_data=synthetic_data,
        real_association_threshold=0.3,
    )

    # Assert: a real score (not NaN) in the metric's valid range.
    assert 0 <= result <= 1