Commit 93c311c

Add DisclosureProtection metric (#678)
1 parent 682c85b commit 93c311c

File tree

5 files changed: +682 -0 lines changed

sdmetrics/single_table/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -67,6 +67,7 @@
     CategoricalRF,
     CategoricalSVM,
 )
+from sdmetrics.single_table.privacy.disclosure_protection import DisclosureProtection
 from sdmetrics.single_table.privacy.ensemble import CategoricalEnsemble
 from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
 from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
@@ -109,6 +110,7 @@
     'CategoricalCAP',
     'CategoricalZeroCAP',
     'CategoricalGeneralizedCAP',
+    'DisclosureProtection',
     'NumericalMLP',
     'NumericalLR',
     'NumericalSVR',

sdmetrics/single_table/privacy/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
     CategoricalRF,
     CategoricalSVM,
 )
+from sdmetrics.single_table.privacy.disclosure_protection import DisclosureProtection
 from sdmetrics.single_table.privacy.ensemble import CategoricalEnsemble
 from sdmetrics.single_table.privacy.numerical_sklearn import NumericalLR, NumericalMLP, NumericalSVR
 from sdmetrics.single_table.privacy.radius_nearest_neighbor import NumericalRadiusNearestNeighbor
@@ -26,6 +27,7 @@
     'CategoricalRF',
     'CategoricalSVM',
     'CategoricalZeroCAP',
+    'DisclosureProtection',
     'NumericalLR',
     'NumericalMLP',
     'NumericalPrivacyMetric',
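
With both __init__.py updates in place, the new metric is exported from the top-level single-table namespace as well as the privacy subpackage. A quick sanity-check sketch (assuming a build that includes this commit):

    from sdmetrics.single_table import DisclosureProtection
    from sdmetrics.single_table.privacy import DisclosureProtection as PrivacyDisclosureProtection

    # Both __init__.py files import the same class, so the two names should be identical.
    assert DisclosureProtection is PrivacyDisclosureProtection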
sdmetrics/single_table/privacy/disclosure_protection.py

Lines changed: 245 additions & 0 deletions

@@ -0,0 +1,245 @@
"""Disclosure protection metrics."""

import numpy as np
import pandas as pd

from sdmetrics.goal import Goal
from sdmetrics.single_table.base import SingleTableMetric
from sdmetrics.single_table.privacy.cap import (
    CategoricalCAP,
    CategoricalGeneralizedCAP,
    CategoricalZeroCAP,
)

CAP_METHODS = {
    'CAP': CategoricalCAP,
    'ZERO_CAP': CategoricalZeroCAP,
    'GENERALIZED_CAP': CategoricalGeneralizedCAP,
}


class DisclosureProtection(SingleTableMetric):
    """The DisclosureProtection metric."""

    goal = Goal.MAXIMIZE
    min_value = 0
    max_value = 1

    @classmethod
    def _validate_inputs(
        cls,
        real_data,
        synthetic_data,
        known_column_names,
        sensitive_column_names,
        computation_method,
        continuous_column_names,
        num_discrete_bins,
    ):
        if not isinstance(real_data, pd.DataFrame) or not isinstance(synthetic_data, pd.DataFrame):
            raise ValueError('Real and synthetic data must be pandas DataFrames.')

        if len(known_column_names) == 0:
            raise ValueError('Must provide at least 1 known column name.')
        elif not set(real_data.columns).issuperset(set(known_column_names)):
            missing = "', '".join(set(known_column_names) - set(real_data.columns))
            raise ValueError(f"Known column(s) '{missing}' are missing from the real data.")

        if len(sensitive_column_names) == 0:
            raise ValueError('Must provide at least 1 sensitive column name.')
        elif not set(real_data.columns).issuperset(set(sensitive_column_names)):
            missing = "', '".join(set(sensitive_column_names) - set(real_data.columns))
            raise ValueError(f"Sensitive column(s) '{missing}' are missing from the real data.")

        if computation_method.upper() not in CAP_METHODS.keys():
            raise ValueError(
                f"Unknown computation method '{computation_method}'. "
                f"Please use one of 'cap', 'zero_cap', or 'generalized_cap'."
            )

        if continuous_column_names is not None and not set(real_data.columns).issuperset(
            set(continuous_column_names)
        ):
            missing = "', '".join(set(continuous_column_names) - set(real_data.columns))
            raise ValueError(f"Continuous column(s) '{missing}' are missing from the real data.")

        if not isinstance(num_discrete_bins, int) or num_discrete_bins <= 0:
            raise ValueError('`num_discrete_bins` must be an integer greater than zero.')

        super()._validate_inputs(real_data, synthetic_data)

    @classmethod
    def _get_null_categories(cls, real_data, synthetic_data, columns):
        base_null_value = '__NULL_VALUE__'
        null_category_map = {}
        for col in columns:
            null_value = base_null_value
            categories = set(real_data[col].unique()).union(set(synthetic_data[col].unique()))
            while null_value in categories:
                null_value += '_'

            null_category_map[col] = null_value

        return null_category_map

    @classmethod
    def _discretize_column(cls, real_column, synthetic_column, num_bins):
        bin_labels = [str(x) for x in range(num_bins)]
        real_binned, bins = pd.cut(
            pd.to_numeric(real_column.to_numpy()), num_bins, labels=bin_labels, retbins=True
        )
        bins[0], bins[-1] = -np.inf, np.inf
        synthetic_binned = pd.cut(
            pd.to_numeric(synthetic_column.to_numpy()), bins, labels=bin_labels
        )

        return real_binned.to_numpy(), synthetic_binned.to_numpy()

    @classmethod
    def _compute_baseline(cls, real_data, sensitive_column_names):
        unique_categories_prod = np.prod([
            real_data[col].nunique(dropna=False) for col in sensitive_column_names
        ])
        return 1 - float(1 / unique_categories_prod)

    @classmethod
    def compute_breakdown(
        cls,
        real_data,
        synthetic_data,
        known_column_names,
        sensitive_column_names,
        computation_method='cap',
        continuous_column_names=None,
        num_discrete_bins=10,
    ):
        """Compute this metric breakdown.

        Args:
            real_data (pd.DataFrame):
                A pd.DataFrame with the real data.
            synthetic_data (pd.DataFrame):
                A pd.DataFrame with the synthetic data.
            known_column_names (list[str]):
                A list with the string names of the columns that an attacker may already know.
            sensitive_column_names (list[str]):
                A list with the string names of the columns that an attacker wants to guess
                (but does not already know).
            computation_method (str, optional):
                The type of computation we'll use to simulate the attack. Options are:
                    - 'cap': Use the CAP method described in the original paper.
                    - 'generalized_cap': Use the generalized CAP method.
                    - 'zero_cap': Use the zero cap method.
                Defaults to 'cap'.
            continuous_column_names (list[str], optional):
                A list of column names that represent continuous values (as opposed to discrete
                values). These columns will be discretized. Defaults to None.
            num_discrete_bins (int, optional):
                Number of bins to discretize continuous columns into. Defaults to 10.

        Returns:
            dict
                Mapping of the metric output with the keys:
                    - 'score': The overall score for the metric.
                    - 'cap_protection': The protection score from the selected computation method.
                    - 'baseline_protection': The baseline protection for the columns.
        """
        cls._validate_inputs(
            real_data,
            synthetic_data,
            known_column_names,
            sensitive_column_names,
            computation_method,
            continuous_column_names,
            num_discrete_bins,
        )
        computation_method = computation_method.upper()
        real_data = real_data.copy()
        synthetic_data = synthetic_data.copy()

        # Discretize continuous columns
        if continuous_column_names is not None:
            for col_name in continuous_column_names:
                real_data[col_name], synthetic_data[col_name] = cls._discretize_column(
                    real_data[col_name], synthetic_data[col_name], num_discrete_bins
                )

        # Convert null values to their own category
        null_category_map = cls._get_null_categories(
            real_data, synthetic_data, known_column_names + sensitive_column_names
        )
        real_data = real_data.fillna(null_category_map)
        synthetic_data = synthetic_data.fillna(null_category_map)

        # Compute baseline
        baseline_protection = cls._compute_baseline(real_data, sensitive_column_names)

        # Compute CAP metric
        cap_metric = CAP_METHODS.get(computation_method)
        cap_protection = cap_metric.compute(
            real_data,
            synthetic_data,
            key_fields=known_column_names,
            sensitive_fields=sensitive_column_names,
        )

        if baseline_protection == 0:
            score = 0 if cap_protection == 0 else 1
        else:
            score = min(cap_protection / baseline_protection, 1)

        return {
            'score': score,
            'cap_protection': cap_protection,
            'baseline_protection': baseline_protection,
        }

    @classmethod
    def compute(
        cls,
        real_data,
        synthetic_data,
        known_column_names,
        sensitive_column_names,
        computation_method='cap',
        continuous_column_names=None,
        num_discrete_bins=10,
    ):
        """Compute the DisclosureProtection metric.

        Args:
            real_data (pd.DataFrame):
                A pd.DataFrame with the real data.
            synthetic_data (pd.DataFrame):
                A pd.DataFrame with the synthetic data.
            known_column_names (list[str]):
                A list with the string names of the columns that an attacker may already know.
            sensitive_column_names (list[str]):
                A list with the string names of the columns that an attacker wants to guess
                (but does not already know).
            computation_method (str, optional):
                The type of computation we'll use to simulate the attack. Options are:
                    - 'cap': Use the CAP method described in the original paper.
                    - 'generalized_cap': Use the generalized CAP method.
                    - 'zero_cap': Use the zero cap method.
                Defaults to 'cap'.
            continuous_column_names (list[str], optional):
                A list of column names that represent continuous values (as opposed to discrete
                values). These columns will be discretized. Defaults to None.
            num_discrete_bins (int, optional):
                Number of bins to discretize continuous columns into. Defaults to 10.

        Returns:
            float:
                The score for the DisclosureProtection metric.
        """
        score_breakdown = cls.compute_breakdown(
            real_data,
            synthetic_data,
            known_column_names,
            sensitive_column_names,
            computation_method,
            continuous_column_names,
            num_discrete_bins,
        )
        return score_breakdown['score']
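
For reference, a minimal usage sketch of the new metric. The DataFrames, column names, and parameter values below are made up for illustration; the metric is invoked exactly as defined above.

    import pandas as pd

    from sdmetrics.single_table import DisclosureProtection

    # Hypothetical toy data: 'age' and 'zip' are columns an attacker is assumed to
    # already know; 'diagnosis' is the sensitive column they are trying to guess.
    real_data = pd.DataFrame({
        'age': [31, 45, 28, 62, 45],
        'zip': ['10001', '10001', '94103', '94103', '10001'],
        'diagnosis': ['A', 'B', 'A', 'C', 'B'],
    })
    synthetic_data = pd.DataFrame({
        'age': [30, 44, 29, 60, 46],
        'zip': ['10001', '94103', '94103', '10001', '10001'],
        'diagnosis': ['A', 'A', 'B', 'C', 'B'],
    })

    breakdown = DisclosureProtection.compute_breakdown(
        real_data,
        synthetic_data,
        known_column_names=['age', 'zip'],
        sensitive_column_names=['diagnosis'],
        computation_method='cap',
        continuous_column_names=['age'],  # numeric, so it is binned into num_discrete_bins bins
        num_discrete_bins=5,
    )

    # breakdown contains 'score', 'cap_protection', and 'baseline_protection'. With three
    # distinct 'diagnosis' values, baseline_protection = 1 - 1/3 ≈ 0.67, and the final
    # score is cap_protection / baseline_protection, capped at 1.
    print(breakdown)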
