Skip to content

Commit 18cd2e5

Browse files
Raise UserWarnings for Unused Numerical Distributions when using GaussianCopulaSynthesizer (#2301)
1 parent 5741ee5 commit 18cd2e5

File tree

7 files changed

+62
-27
lines changed

7 files changed

+62
-27
lines changed

sdv/single_table/copulagan.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from sdv.single_table.copulas import GaussianCopulaSynthesizer
99
from sdv.single_table.ctgan import CTGANSynthesizer
1010
from sdv.single_table.utils import (
11-
log_numerical_distributions_error,
1211
validate_numerical_distributions,
12+
warn_missing_numerical_distributions,
1313
)
1414

1515
LOGGER = logging.getLogger(__name__)
@@ -204,10 +204,7 @@ def _fit(self, processed_data):
204204
processed_data (pandas.DataFrame):
205205
Data to be learned.
206206
"""
207-
log_numerical_distributions_error(
208-
self.numerical_distributions, processed_data.columns, LOGGER
209-
)
210-
207+
warn_missing_numerical_distributions(self.numerical_distributions, processed_data.columns)
211208
gaussian_normalizer_config = self._create_gaussian_normalizer_config(processed_data)
212209
self._gaussian_normalizer_hyper_transformer = rdt.HyperTransformer()
213210
self._gaussian_normalizer_hyper_transformer.set_config(gaussian_normalizer_config)

sdv/single_table/copulas.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from sdv.single_table.base import BaseSingleTableSynthesizer
1818
from sdv.single_table.utils import (
1919
flatten_dict,
20-
log_numerical_distributions_error,
2120
unflatten_dict,
2221
validate_numerical_distributions,
22+
warn_missing_numerical_distributions,
2323
)
2424

2525
LOGGER = logging.getLogger(__name__)
@@ -132,9 +132,7 @@ def _fit(self, processed_data):
132132
processed_data (pandas.DataFrame):
133133
Data to be learned.
134134
"""
135-
log_numerical_distributions_error(
136-
self.numerical_distributions, processed_data.columns, LOGGER
137-
)
135+
warn_missing_numerical_distributions(self.numerical_distributions, processed_data.columns)
138136
self._num_rows = self._learn_num_rows(processed_data)
139137
numerical_distributions = self._get_numerical_distributions(processed_data)
140138
self._model = self._initialize_model(numerical_distributions)

sdv/single_table/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,12 @@ def validate_numerical_distributions(numerical_distributions, metadata_columns):
330330
)
331331

332332

333-
def log_numerical_distributions_error(numerical_distributions, processed_data_columns, logger):
334-
"""Log error when numerical distributions columns don't exist anymore."""
333+
def warn_missing_numerical_distributions(numerical_distributions, processed_data_columns):
334+
"""Raise a `UserWarning` when numerical distribution columns don't exist anymore."""
335335
unseen_columns = numerical_distributions.keys() - set(processed_data_columns)
336336
for column in unseen_columns:
337-
logger.info(
338-
f"Requested distribution '{numerical_distributions[column]}' "
339-
f"cannot be applied to column '{column}' because it no longer "
340-
'exists after preprocessing.'
337+
warnings.warn(
338+
f"Cannot use distribution '{numerical_distributions[column]}' for column "
339+
f"'{column}' because the column is not statistically modeled.",
340+
UserWarning,
341341
)

tests/integration/single_table/test_copulas.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,3 +500,25 @@ def test_support_nullable_pandas_dtypes():
500500
assert (synthetic_data.dtypes == data.dtypes).all()
501501
assert (synthetic_data['Float32'] == synthetic_data['Float32'].round(1)).all(skipna=True)
502502
assert (synthetic_data['Float64'] == synthetic_data['Float64'].round(3)).all(skipna=True)
503+
504+
505+
def test_user_warning_for_unused_numerical_distribution():
506+
"""Ensure that a `UserWarning` is raised when a numerical distribution is not applied.
507+
508+
This test verifies that the synthesizer warns the user if a specified numerical
509+
distribution is not used because the corresponding column does not exist or is not
510+
modeled after preprocessing.
511+
"""
512+
# Setup
513+
data, metadata = download_demo('single_table', 'fake_hotel_guests')
514+
synthesizer = GaussianCopulaSynthesizer(
515+
metadata, numerical_distributions={'credit_card_number': 'beta'}
516+
)
517+
518+
# Run and Assert
519+
message = (
520+
"Cannot use distribution 'beta' for column 'credit_card_number' because the column is not "
521+
'statistically modeled.'
522+
)
523+
with pytest.warns(UserWarning, match=message):
524+
synthesizer.fit(data)

tests/unit/single_table/test_copulagan.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,10 @@ def test__create_gaussian_normalizer_config(self, mock_rdt):
263263
assert config == expected_config
264264
assert mock_rdt.transformers.GaussianNormalizer.call_args_list == expected_calls
265265

266-
@patch('sdv.single_table.copulagan.LOGGER')
266+
@patch('sdv.single_table.utils.warnings')
267267
@patch('sdv.single_table.copulagan.CTGANSynthesizer._fit')
268268
@patch('sdv.single_table.copulagan.rdt')
269-
def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger):
269+
def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_warnings):
270270
"""Test a warning is raised.
271271
272272
A warning should be raised if the columns passed in ``numerical_distributions``
@@ -284,10 +284,11 @@ def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger):
284284
instance._fit(processed_data)
285285

286286
# Assert
287-
mock_logger.info.assert_called_once_with(
288-
"Requested distribution 'gamma' cannot be applied to column 'col' "
289-
'because it no longer exists after preprocessing.'
287+
warning_message = (
288+
"Cannot use distribution 'gamma' for column 'col' because the column is not "
289+
'statistically modeled.'
290290
)
291+
mock_warnings.warn.assert_called_once_with(warning_message, UserWarning)
291292

292293
@patch('sdv.single_table.copulagan.CTGANSynthesizer._fit')
293294
@patch('sdv.single_table.copulagan.rdt')

tests/unit/single_table/test_copulas.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ def test_get_parameters(self):
159159
'default_distribution': 'beta',
160160
}
161161

162-
@patch('sdv.single_table.copulas.LOGGER')
163-
def test__fit_logging(self, mock_logger):
164-
"""Test a message is logged.
162+
@patch('sdv.single_table.utils.warnings')
163+
def test__fit_warning_numerical_distributions(self, mock_warnings):
164+
"""Test that a warning is shown when fitting numerical distributions on a dropped column.
165165
166-
A message should be logged if the columns passed in ``numerical_distributions``
166+
A warning message should be printed if the columns passed in ``numerical_distributions``
167167
were renamed/dropped during preprocessing.
168168
"""
169169
# Setup
@@ -180,10 +180,11 @@ def test__fit_logging(self, mock_logger):
180180
instance._fit(processed_data)
181181

182182
# Assert
183-
mock_logger.info.assert_called_once_with(
184-
"Requested distribution 'gamma' cannot be applied to column 'col' "
185-
'because it no longer exists after preprocessing.'
183+
warning_message = (
184+
"Cannot use distribution 'gamma' for column 'col' because the column is not "
185+
'statistically modeled.'
186186
)
187+
mock_warnings.warn.assert_called_once_with(warning_message, UserWarning)
187188

188189
@patch('sdv.single_table.copulas.warnings')
189190
@patch('sdv.single_table.copulas.multivariate')

tests/unit/single_table/test_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
handle_sampling_error,
1515
unflatten_dict,
1616
validate_file_path,
17+
warn_missing_numerical_distributions,
1718
)
1819

1920

@@ -328,3 +329,18 @@ def test_validate_file_path(mock_open):
328329
assert output_path in result
329330
assert none_result is None
330331
mock_open.assert_called_once_with(result, 'w+')
332+
333+
334+
def test_warn_missing_numerical_distributions():
335+
"""Test the warn_missing_numerical_distributions function."""
336+
# Setup
337+
numerical_distributions = {'age': 'beta', 'height': 'uniform'}
338+
processed_data_columns = ['height', 'weight']
339+
340+
# Run and Assert
341+
message = (
342+
"Cannot use distribution 'beta' for column 'age' because the column is not "
343+
'statistically modeled.'
344+
)
345+
with pytest.warns(UserWarning, match=message):
346+
warn_missing_numerical_distributions(numerical_distributions, processed_data_columns)

0 commit comments

Comments
 (0)