Commit c472caa

Add learning_strategy parameter to OneHotEncoding constraint (#2658)
1 parent a400b4e commit c472caa
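
For orientation, here is a minimal usage sketch of the parameter this commit introduces. The constructor signature comes straight from the diff below; the surrounding data/metadata plumbing is omitted, and the `sdv.cag` import path is assumed from the module location:

```python
# Minimal sketch of the new parameter; only the constructor calls are taken
# from this commit's diff, everything else is context.
from sdv.cag import OneHotEncoding  # assumed export path for this module

# Default strategy, unchanged behavior: each one-hot column is learned separately.
default_constraint = OneHotEncoding(column_names=['a', 'b', 'c'])

# New strategy: collapse 'a', 'b', 'c' into one categorical column while the
# model learns, then expand back to one-hot columns at sampling time.
categorical_constraint = OneHotEncoding(
    column_names=['a', 'b', 'c'],
    learning_strategy='categorical',
)
```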

File tree

3 files changed: +339 -19 lines changed


sdv/cag/one_hot_encoding.py

Lines changed: 64 additions & 18 deletions
```diff
@@ -4,10 +4,12 @@

 import numpy as np

+from sdv._utils import _create_unique_name
 from sdv.cag._errors import ConstraintNotMetError
 from sdv.cag._utils import (
     _get_is_valid_dict,
     _is_list_of_type,
+    _remove_columns_from_metadata,
     _validate_table_and_column_names,
     _validate_table_name_if_defined,
 )
```
```diff
@@ -30,20 +32,30 @@ class OneHotEncoding(BaseConstraint):
         table_name (str, optional):
             The name of the table that contains the columns. Optional if the
             data is only a single table. Defaults to None.
+        learning_strategy (str, optional):
+            Strategy for how the model should learn the one-hot fields. Supported values:
+            - 'one_hot' (default): Learn each one-hot column separately.
+            - 'categorical': Internally collapse the one-hot columns into a single categorical
+              column for the model to learn, then expand back to one-hot at sampling time.
     """

     @staticmethod
-    def _validate_init_inputs(column_names, table_name):
+    def _validate_init_inputs(column_names, table_name, learning_strategy):
         if not _is_list_of_type(column_names):
             raise ValueError('`column_names` must be a list of strings.')

         _validate_table_name_if_defined(table_name)

-    def __init__(self, column_names, table_name=None):
+        if learning_strategy not in ['one_hot', 'categorical']:
+            raise ValueError("`learning_strategy` must be either 'one_hot' or 'categorical'.")
+
+    def __init__(self, column_names, table_name=None, learning_strategy='one_hot'):
         super().__init__()
-        self._validate_init_inputs(column_names, table_name)
+        self._validate_init_inputs(column_names, table_name, learning_strategy)
         self._column_names = column_names
         self.table_name = table_name
+        self.learning_strategy = learning_strategy
+        self._categorical_column = '#'.join(self._column_names)

     def _validate_constraint_with_metadata(self, metadata):
         """Validate the constraint is compatible with the provided metadata.
```
```diff
@@ -94,12 +106,25 @@ def _fit(self, data, metadata):

     def _get_updated_metadata(self, metadata):
         table_name = self._get_single_table_name(metadata)
-        metadata = deepcopy(metadata)
-        for column in self._column_names:
-            if metadata.tables[table_name].columns[column]['sdtype'] in ['categorical', 'boolean']:
-                metadata.tables[table_name].columns[column]['sdtype'] = 'numerical'
-
-        return metadata
+        if self.learning_strategy == 'categorical':
+            self._categorical_column = _create_unique_name(
+                self._categorical_column, metadata.tables[table_name].columns
+            )
+            md = metadata.to_dict()
+            md['tables'][table_name]['columns'][self._categorical_column] = {
+                'sdtype': 'categorical'
+            }
+            return _remove_columns_from_metadata(md, table_name, columns_to_drop=self._column_names)
+
+        else:
+            metadata = deepcopy(metadata)
+            for column in self._column_names:
+                if metadata.tables[table_name].columns[column]['sdtype'] in [
+                    'categorical',
+                    'boolean',
+                ]:
+                    metadata.tables[table_name].columns[column]['sdtype'] = 'numerical'
+            return metadata

     def _transform(self, data):
         """Transform the data.
```
```diff
@@ -113,9 +138,15 @@ def _transform(self, data):
             Transformed data.
         """
         table_name = self._get_single_table_name(self.metadata)
-        one_hot_data = data[table_name][self._column_names]
-        one_hot_data = np.where(one_hot_data == 0, EPSILON, 1 - EPSILON)
-        data[table_name][self._column_names] = one_hot_data
+        if self.learning_strategy == 'categorical':
+            table_data = data[table_name]
+            categories = table_data[self._column_names].idxmax(axis=1)
+            table_data[self._categorical_column] = categories
+            data[table_name] = table_data.drop(self._column_names, axis=1)
+        else:
+            one_hot_data = data[table_name][self._column_names]
+            one_hot_data = np.where(one_hot_data == 0, EPSILON, 1 - EPSILON)
+            data[table_name][self._column_names] = one_hot_data

         return data

```
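The two branches of `_transform` do different things: the default 'one_hot' branch keeps all columns but nudges exact 0/1 values to EPSILON and 1 - EPSILON so the model never sees the boundary values, while the 'categorical' branch collapses each row to the name of its active column via `idxmax`. A standalone sketch (the EPSILON value here is an assumption, not the module's constant):

```python
# Standalone sketch of both transform branches on a toy one-hot frame.
import numpy as np
import pandas as pd

EPSILON = np.finfo(np.float32).eps  # assumption; stands in for the module constant
df = pd.DataFrame({'a': [1, 0, 0], 'b': [0, 1, 0], 'c': [0, 0, 1]})

# 'one_hot' branch: same columns, values nudged off exact 0 and 1.
nudged = np.where(df == 0, EPSILON, 1 - EPSILON)

# 'categorical' branch: each row becomes the name of its winning column.
categories = df.idxmax(axis=1)
print(categories.tolist())  # ['a', 'b', 'c']
```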
```diff
@@ -134,13 +165,28 @@ def _reverse_transform(self, data):
         """
         table_name = self._get_single_table_name(self.metadata)
         table_data = data[table_name]
-        one_hot_data = table_data[self._column_names]
-        transformed_data = np.zeros_like(one_hot_data.to_numpy())
-        max_category_indices = np.argmax(one_hot_data.to_numpy(), axis=1)
-        transformed_data[np.arange(len(one_hot_data)), max_category_indices] = 1
-        table_data[self._column_names] = transformed_data
-        data[table_name] = table_data

+        if self.learning_strategy == 'categorical':
+            categories = table_data.pop(self._categorical_column)
+            num_rows = len(table_data)
+            num_cols = len(self._column_names)
+            transformed = np.zeros((num_rows, num_cols), dtype=float)
+
+            column_to_index = {name: idx for idx, name in enumerate(self._column_names)}
+            indices = categories.map(lambda x: column_to_index[x]).to_numpy()
+            transformed[np.arange(num_rows), indices] = 1
+
+            for idx, col in enumerate(self._column_names):
+                table_data[col] = transformed[:, idx]
+
+        else:
+            one_hot_data = table_data[self._column_names]
+            transformed_data = np.zeros_like(one_hot_data.to_numpy())
+            max_category_indices = np.argmax(one_hot_data.to_numpy(), axis=1)
+            transformed_data[np.arange(len(one_hot_data)), max_category_indices] = 1
+            table_data[self._column_names] = transformed_data
+
+        data[table_name] = table_data
     return data

     def _is_valid(self, data, metadata):
```
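
And the categorical reverse path: each sampled category name maps back to a column index, and a fresh 0/1 matrix is rebuilt from those indices. A self-contained sketch mirroring the code above:

```python
# Sketch of the categorical reverse transform: rebuild one-hot columns from
# sampled category names, mirroring the index-based expansion above.
import numpy as np
import pandas as pd

column_names = ['a', 'b', 'c']
categories = pd.Series(['b', 'a', 'b'])  # e.g. values sampled for 'a#b#c'

column_to_index = {name: idx for idx, name in enumerate(column_names)}
indices = categories.map(column_to_index).to_numpy()

one_hot = np.zeros((len(categories), len(column_names)), dtype=float)
one_hot[np.arange(len(categories)), indices] = 1
print(one_hot)  # [[0. 1. 0.], [1. 0. 0.], [0. 1. 0.]]
```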

tests/integration/cag/test_one_hot_encoding.py

Lines changed: 142 additions & 1 deletion
```diff
@@ -8,7 +8,7 @@
 from sdv.cag._errors import ConstraintNotMetError
 from sdv.metadata import Metadata
 from sdv.single_table import GaussianCopulaSynthesizer
-from tests.utils import run_copula, run_hma
+from tests.utils import run_constraint, run_copula, run_hma


 @pytest.fixture()
```
```diff
@@ -201,3 +201,144 @@ def test_end_to_end_boolean():
     assert (samples.sum(axis=1) == 1).all()
     for col in columns:
         assert sorted(samples[col].unique().tolist()) == [0, 1]
+
+
+def test_end_to_end_categorical_single(data, metadata):
+    """End-to-end with learning_strategy='categorical' for single-table data."""
+    # Setup
+    constraint = OneHotEncoding(column_names=['a', 'b', 'c'], learning_strategy='categorical')
+
+    # Run
+    synthesizer = run_copula(data, metadata, [constraint])
+    synthetic_data = synthesizer.sample(200)
+    synthesizer.validate_constraints(synthetic_data=synthetic_data)
+
+    # Assert
+    assert set(synthetic_data.columns) == {'a', 'b', 'c'}
+    for col in ['a', 'b', 'c']:
+        assert set(synthetic_data[col]) == {0, 1}
+    assert (synthetic_data[['a', 'b', 'c']].sum(axis=1) == 1).all()
+
+
+def test_end_to_end_categorical_single_raises(data, metadata):
+    """Invalid synthetic data should raise with learning_strategy='categorical'."""
+    # Setup
+    invalid_data = pd.DataFrame({
+        'a': [1, 2, 0],
+        'b': [0, 1, np.nan],
+        'c': [0, 0, 3],
+    })
+    constraint = OneHotEncoding(column_names=['a', 'b', 'c'], learning_strategy='categorical')
+
+    # Run and Assert
+    msg = re.escape(
+        "Data is not valid for the 'OneHotEncoding' constraint in table 'table':\n"
+        '   a    b  c\n'
+        '1  2  1.0  0\n'
+        '2  0  NaN  3'
+    )
+    with pytest.raises(ConstraintNotMetError, match=msg):
+        run_copula(invalid_data, metadata, [constraint])
+
+    # Run and Assert
+    msg = re.escape('The one hot encoding requirement is not met for row indices: 1, 2')
+    with pytest.raises(ConstraintNotMetError, match=msg):
+        synthesizer = run_copula(data, metadata, [constraint])
+        synthesizer.validate_constraints(synthetic_data=invalid_data)
+
+
+def test_end_to_end_categorical_multi(data_multi, metadata_multi):
+    """End-to-end with learning_strategy='categorical' for multi-table data."""
+    # Setup
+    constraint = OneHotEncoding(
+        column_names=['a', 'b', 'c'], table_name='table1', learning_strategy='categorical'
+    )
+
+    # Run
+    synthesizer = run_hma(data_multi, metadata_multi, [constraint])
+    synthetic = synthesizer.sample(200)
+    synthesizer.validate_constraints(synthetic_data=synthetic)
+
+    # Assert
+    assert set(synthetic['table1'].columns) == {'a', 'b', 'c'}
+    for col in ['a', 'b', 'c']:
+        assert set(synthetic['table1'][col]) == {0, 1}
+    assert (synthetic['table1'][['a', 'b', 'c']].sum(axis=1) == 1).all()
+
+
+def test_end_to_end_categorical_multi_raises(data_multi, metadata_multi):
+    """Invalid multi-table synthetic data should raise with learning_strategy='categorical'."""
+    # Setup
+    constraint = OneHotEncoding(
+        column_names=['a', 'b', 'c'], table_name='table1', learning_strategy='categorical'
+    )
+    invalid = {
+        'table1': pd.DataFrame({
+            'a': [1, 2, 0],
+            'b': [0, 1, np.nan],
+            'c': [0, 0, 3],
+        }),
+        'table2': pd.DataFrame({'id': range(5)}),
+    }
+
+    # Run and Assert
+    msg = re.escape(
+        "Data is not valid for the 'OneHotEncoding' constraint in table 'table1':\n"
+        '   a    b  c\n'
+        '1  2  1.0  0\n'
+        '2  0  NaN  3'
+    )
+    with pytest.raises(ConstraintNotMetError, match=msg):
+        run_hma(invalid, metadata_multi, [constraint])
+
+    # Run and Assert
+    msg = "Table 'table1': The one hot encoding requirement is not met for row indices: 1, 2."
+    with pytest.raises(ConstraintNotMetError, match=msg):
+        synthesizer = run_hma(data_multi, metadata_multi, [constraint])
+        synthesizer.validate_constraints(synthetic_data=invalid)
+
+
+def test_constraint_pipeline_categorical_single(data, metadata):
+    """Constraint pipeline behavior for categorical strategy (single table)."""
+    # Setup
+    constraint = OneHotEncoding(column_names=['a', 'b', 'c'], learning_strategy='categorical')
+
+    # Run
+    updated_metadata, transformed, reverse_transformed = run_constraint(constraint, data, metadata)
+
+    # Assert metadata
+    assert updated_metadata.get_column_names() == ['a#b#c']
+
+    # Assert transform
+    assert transformed.shape[1] == 1
+    assert not any(col in transformed.columns for col in ['a', 'b', 'c'])
+    assert set(transformed.columns) == {'a#b#c'}
+
+    # Assert reverse_transform
+    assert set(reverse_transformed.columns) == {'a', 'b', 'c'}
+    assert (reverse_transformed[['a', 'b', 'c']].sum(axis=1) == 1).all()
+    assert set(reverse_transformed.columns) == {'a', 'b', 'c'}
+
+
+def test_constraint_pipeline_categorical_multi(data_multi, metadata_multi):
+    """Constraint pipeline behavior for categorical strategy (multi table)."""
+    # Setup
+    orig_cols = ['a', 'b', 'c']
+    constraint = OneHotEncoding(
+        column_names=orig_cols, table_name='table1', learning_strategy='categorical'
+    )
+
+    # Run
+    updated_metadata, transformed, reverse_transformed = run_constraint(
+        constraint, data_multi, metadata_multi
+    )
+
+    # Assert metadata
+    assert updated_metadata.tables['table1'].get_column_names() == ['a#b#c']
+
+    # Assert transform
+    assert list(transformed['table1'].columns) != orig_cols
+    assert transformed['table1'].shape[1] == 1
+    assert list(transformed['table2'].columns) == list(data_multi['table2'].columns)
+
+    # Assert reverse_transform
+    assert set(reverse_transformed['table1'].columns) == set(orig_cols)
+    assert (reverse_transformed['table1'][orig_cols].sum(axis=1) == 1).all()
```
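
The pipeline tests above rely on a `run_constraint` helper from `tests.utils`. Its implementation is not part of this diff; judging purely from how it is called, it is roughly equivalent to this hypothetical sketch:

```python
# Hypothetical reconstruction of tests.utils.run_constraint, inferred solely
# from its call sites above; the real helper (and the public constraint
# method names it uses) may differ.
from copy import deepcopy


def run_constraint(constraint, data, metadata):
    """Fit a constraint, then return (updated_metadata, transformed, reversed)."""
    constraint = deepcopy(constraint)
    constraint.fit(data=data, metadata=metadata)  # assumed public fit API
    updated_metadata = constraint.get_updated_metadata(metadata)
    transformed = constraint.transform(data)
    reverse_transformed = constraint.reverse_transform(transformed)
    return updated_metadata, transformed, reverse_transformed
```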
