Skip to content

Commit 79508b7

Browse files
committed
Add threshold
1 parent 88c74d4 commit 79508b7

File tree

3 files changed

+244
-4
lines changed

3 files changed

+244
-4
lines changed

sdmetrics/column_pairs/statistical/contingency_similarity.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Contingency Similarity Metric."""
22

3+
import numpy as np
34
import pandas as pd
5+
from scipy.stats.contingency import association
46

57
from sdmetrics.column_pairs.base import ColumnPairsMetric
68
from sdmetrics.goal import Goal
@@ -28,7 +30,12 @@ class ContingencySimilarity(ColumnPairsMetric):
2830

2931
@staticmethod
3032
def _validate_inputs(
31-
real_data, synthetic_data, continuous_column_names, num_discrete_bins, num_rows_subsample
33+
real_data,
34+
synthetic_data,
35+
continuous_column_names,
36+
num_discrete_bins,
37+
num_rows_subsample,
38+
real_association_threshold,
3239
):
3340
for data in [real_data, synthetic_data]:
3441
if not isinstance(data, pd.DataFrame) or len(data.columns) != 2:
@@ -53,6 +60,14 @@ def _validate_inputs(
5360
if not isinstance(num_rows_subsample, int) or num_rows_subsample <= 0:
5461
raise ValueError('`num_rows_subsample` must be an integer greater than zero.')
5562

63+
if (
64+
not isinstance(real_association_threshold, (int, float))
65+
or real_association_threshold < 0
66+
):
67+
raise ValueError(
68+
'`real_association_threshold` must be a number greater than or equal to zero.'
69+
)
70+
5671
@classmethod
5772
def compute_breakdown(
5873
cls,
@@ -61,6 +76,7 @@ def compute_breakdown(
6176
continuous_column_names=None,
6277
num_discrete_bins=10,
6378
num_rows_subsample=None,
79+
real_association_threshold=0,
6480
):
6581
"""Compute the breakdown of this metric."""
6682
cls._validate_inputs(
@@ -69,6 +85,7 @@ def compute_breakdown(
6985
continuous_column_names,
7086
num_discrete_bins,
7187
num_rows_subsample,
88+
real_association_threshold,
7289
)
7390
columns = real_data.columns[:2]
7491

@@ -84,7 +101,14 @@ def compute_breakdown(
84101
real[column], synthetic[column], num_discrete_bins=num_discrete_bins
85102
)
86103

87-
contingency_real = real.groupby(list(columns), dropna=False).size() / len(real)
104+
contingency_real_counts = real.groupby(list(columns), dropna=False).size()
105+
if real_association_threshold > 0:
106+
contingency_2d = contingency_real_counts.unstack(fill_value=0) # noqa: PD010
107+
real_cramer = association(contingency_2d.values, method='cramer')
108+
if real_cramer < real_association_threshold:
109+
return {'score': np.nan}
110+
111+
contingency_real = contingency_real_counts / len(real)
88112
contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len(
89113
synthetic
90114
)
@@ -103,6 +127,7 @@ def compute(
103127
continuous_column_names=None,
104128
num_discrete_bins=10,
105129
num_rows_subsample=None,
130+
real_association_threshold=0,
106131
):
107132
"""Compare the contingency similarity of two discrete columns.
108133
@@ -120,17 +145,23 @@ def compute(
120145
num_rows_subsample (int, optional):
121146
The number of rows to subsample from the real and synthetic data before computing
122147
the metric. Defaults to ``None``.
148+
real_association_threshold (float, optional):
149+
The minimum Cramer's V association score required in the real data for the
150+
metric to be computed. If the real data's association is below this threshold,
151+
the metric returns NaN. Defaults to 0 (no threshold).
123152
124153
Returns:
125154
float:
126-
The contingency similarity of the two columns.
155+
The contingency similarity of the two columns, or NaN if the real data's
156+
association is below the threshold.
127157
"""
128158
return cls.compute_breakdown(
129159
real_data,
130160
synthetic_data,
131161
continuous_column_names,
132162
num_discrete_bins,
133163
num_rows_subsample,
164+
real_association_threshold,
134165
)['score']
135166

136167
@classmethod

tests/readme_test/README.md

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
<div align="center">
2+
<br/>
3+
<p align="center">
4+
<i>This repository is part of <a href="https://sdv.dev">The Synthetic Data Vault Project</a>, a project from <a href="https://datacebo.com">DataCebo</a>.</i>
5+
</p>
6+
7+
[![Development Status](https://img.shields.io/badge/Development%20Status-2%20--%20Pre--Alpha-yellow)](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
8+
[![PyPI Shield](https://img.shields.io/pypi/v/sdmetrics.svg)](https://pypi.python.org/pypi/sdmetrics)
9+
[![Downloads](https://pepy.tech/badge/sdmetrics)](https://pepy.tech/project/sdmetrics)
10+
[![Tests](https://github.com/sdv-dev/SDMetrics/workflows/Run%20Tests/badge.svg)](https://github.com/sdv-dev/SDMetrics/actions?query=workflow%3A%22Run+Tests%22+branch%3Amain)
11+
[![Coverage Status](https://codecov.io/gh/sdv-dev/SDMetrics/branch/main/graph/badge.svg)](https://codecov.io/gh/sdv-dev/SDMetrics)
12+
[![Slack](https://img.shields.io/badge/Community-Slack-blue?style=plastic&logo=slack)](https://bit.ly/sdv-slack-invite)
13+
[![Tutorial](https://img.shields.io/badge/Demo-Get%20started-orange?style=plastic&logo=googlecolab)](https://bit.ly/sdmetrics-demo)
14+
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.14279167.svg)](https://doi.org/10.5281/zenodo.14279167)
15+
16+
<div align="left">
17+
<br/>
18+
<p align="center">
19+
<a href="https://github.com/sdv-dev/SDV">
20+
<img align="center" width=40% src="https://github.com/sdv-dev/SDV/blob/stable/docs/images/SDMetrics-DataCebo.png"></img>
21+
</a>
22+
</p>
23+
</div>
24+
25+
</div>
26+
27+
# Overview
28+
29+
The SDMetrics library evaluates synthetic data by comparing it to the real data that you're trying to mimic. It includes a variety of metrics to capture different aspects of the data, for example **quality and privacy**. It also includes reports that you can run to generate insights, visualize data and share with your team.
30+
31+
The SDMetrics library is **model-agnostic**, meaning you can use any synthetic data. The library does not need to know how you created the data.
32+
33+
<img align="center" src="docs/images/column_comparison.png"></img>
34+
35+
# Install
36+
37+
Install SDMetrics using pip or conda. We recommend using a virtual environment to avoid conflicts with other software on your device.
38+
39+
```bash
40+
pip install sdmetrics
41+
```
42+
43+
```bash
44+
conda install -c conda-forge sdmetrics
45+
```
46+
47+
For more information about using SDMetrics, visit the [SDMetrics Documentation](https://docs.sdv.dev/sdmetrics).
48+
49+
# Usage
50+
51+
Get started with **SDMetrics Reports** using some demo data,
52+
53+
```python
54+
from sdmetrics import load_demo
55+
from sdmetrics.reports.single_table import QualityReport
56+
57+
real_data, synthetic_data, metadata = load_demo(modality='single_table')
58+
59+
my_report = QualityReport()
60+
my_report.generate(real_data, synthetic_data, metadata)
61+
```
62+
```
63+
Creating report: 100%|██████████| 4/4 [00:00<00:00, 5.22it/s]
64+
65+
Overall Quality Score: 82.84%
66+
67+
Properties:
68+
Column Shapes: 82.78%
69+
Column Pair Trends: 82.9%
70+
```
71+
72+
Once you generate the report, you can drill down on the details and visualize the results.
73+
74+
```python
75+
my_report.get_visualization(property_name='Column Pair Trends')
76+
```
77+
<img align="center" src="docs/images/column_pairs.png"></img>
78+
79+
Save the report and share it with your team.
80+
```python
81+
my_report.save(filepath='demo_data_quality_report.pkl')
82+
83+
# load it at any point in the future
84+
my_report = QualityReport.load(filepath='demo_data_quality_report.pkl')
85+
```
86+
87+
**Want more metrics?** You can also manually apply any of the metrics in this library to your data.
88+
89+
```python
90+
# calculate whether the synthetic data respects the min/max bounds
91+
# set by the real data
92+
from sdmetrics.single_column import BoundaryAdherence
93+
94+
BoundaryAdherence.compute(
95+
real_data['start_date'],
96+
synthetic_data['start_date']
97+
)
98+
```
99+
```
100+
0.8503937007874016
101+
```
102+
103+
```python
104+
# calculate whether the synthetic data is new or whether it's an exact copy of the real data
105+
from sdmetrics.single_table import NewRowSynthesis
106+
107+
NewRowSynthesis.compute(
108+
real_data,
109+
synthetic_data,
110+
metadata
111+
)
112+
```
113+
```
114+
1.0
115+
```
116+
117+
# What's next?
118+
119+
To learn more about the reports and metrics, visit the [SDMetrics Documentation](https://docs.sdv.dev/sdmetrics).
120+
121+
---
122+
123+
124+
<div align="center">
125+
<a href="https://datacebo.com"><img align="center" width=40% src="https://github.com/sdv-dev/SDV/blob/stable/docs/images/DataCebo.png"></img></a>
126+
</div>
127+
<br/>
128+
<br/>
129+
130+
[The Synthetic Data Vault Project](https://sdv.dev) was first created at MIT's [Data to AI Lab](
131+
https://dai.lids.mit.edu/) in 2016. After 4 years of research and enterprise traction, we
132+
created [DataCebo](https://datacebo.com) in 2020 with the goal of growing the project.
133+
Today, DataCebo is the proud developer of SDV, the largest ecosystem for
134+
synthetic data generation & evaluation. It is home to multiple libraries that support synthetic
135+
data, including:
136+
137+
* 🔄 Data discovery & transformation. Reverse the transforms to reproduce realistic data.
138+
* 🧠 Multiple machine learning models -- ranging from Copulas to Deep Learning -- to create tabular,
139+
multi-table, and time-series data.
140+
* 📊 Measuring quality and privacy of synthetic data, and comparing different synthetic data
141+
generation models.
142+
143+
[Get started using the SDV package](https://sdv.dev/SDV/getting_started/install.html) -- a fully
144+
integrated solution and your one-stop shop for synthetic data. Or, use the standalone libraries
145+
for specific needs.

tests/unit/column_pairs/statistical/test_contingency_similarity.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test__validate_inputs(self):
2727
continuous_column_names=None,
2828
num_discrete_bins=10,
2929
num_rows_subsample=3,
30+
real_association_threshold=0,
3031
)
3132
expected_bad_data = re.escape('The data must be a pandas DataFrame with two columns.')
3233
with pytest.raises(ValueError, match=expected_bad_data):
@@ -36,6 +37,7 @@ def test__validate_inputs(self):
3637
continuous_column_names=None,
3738
num_discrete_bins=10,
3839
num_rows_subsample=3,
40+
real_association_threshold=0,
3941
)
4042

4143
expected_mismatch_columns_error = re.escape(
@@ -48,6 +50,7 @@ def test__validate_inputs(self):
4850
continuous_column_names=None,
4951
num_discrete_bins=10,
5052
num_rows_subsample=3,
53+
real_association_threshold=0,
5154
)
5255

5356
expected_bad_continous_column_error = re.escape(
@@ -60,6 +63,7 @@ def test__validate_inputs(self):
6063
continuous_column_names=bad_continous_columns,
6164
num_discrete_bins=10,
6265
num_rows_subsample=3,
66+
real_association_threshold=0,
6367
)
6468

6569
expected_bad_num_discrete_bins_error = re.escape(
@@ -72,6 +76,7 @@ def test__validate_inputs(self):
7276
continuous_column_names=['col1'],
7377
num_discrete_bins=bad_num_discrete_bins,
7478
num_rows_subsample=3,
79+
real_association_threshold=0,
7580
)
7681
expected_bad_num_rows_subsample_error = re.escape(
7782
'`num_rows_subsample` must be an integer greater than zero.'
@@ -83,6 +88,20 @@ def test__validate_inputs(self):
8388
continuous_column_names=['col1'],
8489
num_discrete_bins=10,
8590
num_rows_subsample=bad_num_rows_subsample,
91+
real_association_threshold=0,
92+
)
93+
94+
expected_bad_threshold_error = re.escape(
95+
'`real_association_threshold` must be a number greater than or equal to zero.'
96+
)
97+
with pytest.raises(ValueError, match=expected_bad_threshold_error):
98+
ContingencySimilarity._validate_inputs(
99+
real_data=real_data,
100+
synthetic_data=synthetic_data,
101+
continuous_column_names=['col1'],
102+
num_discrete_bins=10,
103+
num_rows_subsample=3,
104+
real_association_threshold=-0.1,
86105
)
87106

88107
@patch(
@@ -99,7 +118,7 @@ def test_compute_mock(self, compute_breakdown_mock):
99118
score = ContingencySimilarity.compute(real_data, synthetic_data)
100119

101120
# Assert
102-
compute_breakdown_mock.assert_called_once_with(real_data, synthetic_data, None, 10, None)
121+
compute_breakdown_mock.assert_called_once_with(real_data, synthetic_data, None, 10, None, 0)
103122
assert score == 0.25
104123

105124
@patch(
@@ -134,6 +153,7 @@ def test_compute_breakdown(self, validate_inputs_mock):
134153
None,
135154
10,
136155
None,
156+
0,
137157
)
138158
assert result == {'score': expected_score}
139159

@@ -218,3 +238,47 @@ def test_no_runtime_warning_raised(self):
218238
ContingencySimilarity.compute(
219239
real_data=real_data[['A', 'B']], synthetic_data=synthetic_data[['A', 'B']]
220240
)
241+
242+
def test_real_association_threshold_returns_nan(self):
243+
"""Test that NaN is returned when real association is below threshold."""
244+
# Setup
245+
real_data = pd.DataFrame({
246+
'col1': np.random.choice(['A', 'B', 'C'], size=100),
247+
'col2': np.random.choice(['X', 'Y', 'Z'], size=100),
248+
})
249+
synthetic_data = pd.DataFrame({
250+
'col1': np.random.choice(['A', 'B', 'C'], size=100),
251+
'col2': np.random.choice(['X', 'Y', 'Z'], size=100),
252+
})
253+
254+
# Run
255+
result = ContingencySimilarity.compute(
256+
real_data=real_data,
257+
synthetic_data=synthetic_data,
258+
real_association_threshold=0.3,
259+
)
260+
261+
# Assert
262+
assert np.isnan(result)
263+
264+
def test_real_association_threshold_computes_normally(self):
265+
"""Test that metric computes normally when real association exceeds threshold."""
266+
# Setup
267+
real_data = pd.DataFrame({
268+
'col1': ['A'] * 50 + ['B'] * 50,
269+
'col2': ['X'] * 48 + ['Y'] * 2 + ['Y'] * 48 + ['X'] * 2,
270+
})
271+
synthetic_data = pd.DataFrame({
272+
'col1': ['A'] * 50 + ['B'] * 50,
273+
'col2': ['X'] * 45 + ['Y'] * 5 + ['Y'] * 45 + ['X'] * 5,
274+
})
275+
276+
# Run
277+
result = ContingencySimilarity.compute(
278+
real_data=real_data,
279+
synthetic_data=synthetic_data,
280+
real_association_threshold=0.3,
281+
)
282+
283+
# Assert
284+
assert 0 <= result <= 1

0 commit comments

Comments
 (0)