Skip to content

Commit 4d56fe7

Browse files
authored
Add warning when unable to turn off rounding scheme for a column (#2279)
1 parent 0fe0123 commit 4d56fe7

File tree

4 files changed

+68
-5
lines changed

4 files changed

+68
-5
lines changed

sdv/single_table/base.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,8 @@ def _validate_transformers(self, column_name_to_transformer):
244244
f"Transformer for column '{column}' has already been fit on data."
245245
)
246246

247-
def _warn_for_update_transformers(self, column_name_to_transformer):
248-
"""Raise warnings for update_transformers.
247+
def _warn_quality_and_performance(self, column_name_to_transformer):
248+
"""Raise warning if the quality/performance may be impacted.
249249
250250
Args:
251251
column_name_to_transformer (dict):
@@ -259,6 +259,24 @@ def _warn_for_update_transformers(self, column_name_to_transformer):
259259
'might impact the quality of your synthetic data.'
260260
)
261261

262+
def _warn_unable_to_enforce_rounding(self, column_name_to_transformer):
263+
if self.enforce_rounding:
264+
invalid_columns = []
265+
for column, transformer in column_name_to_transformer.items():
266+
if (
267+
hasattr(transformer, 'learn_rounding_scheme')
268+
and not transformer.learn_rounding_scheme
269+
):
270+
invalid_columns.append(column)
271+
272+
if invalid_columns:
273+
warnings.warn(
274+
f'Unable to turn off rounding scheme for column(s) {invalid_columns}, '
275+
'because the overall synthesizer is enforcing rounding. We '
276+
"recommend setting the synthesizer's 'enforce_rounding' "
277+
'parameter to False.'
278+
)
279+
262280
def update_transformers(self, column_name_to_transformer):
263281
"""Update any of the transformers assigned to each of the column names.
264282
@@ -267,7 +285,8 @@ def update_transformers(self, column_name_to_transformer):
267285
Dict mapping column names to transformers to be used for that column.
268286
"""
269287
self._validate_transformers(column_name_to_transformer)
270-
self._warn_for_update_transformers(column_name_to_transformer)
288+
self._warn_quality_and_performance(column_name_to_transformer)
289+
self._warn_unable_to_enforce_rounding(column_name_to_transformer)
271290
self._data_processor.update_transformers(column_name_to_transformer)
272291
if self._fitted:
273292
msg = 'For this change to take effect, please refit the synthesizer using `fit`.'

sdv/single_table/copulas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ def _fit_model(self, processed_data):
161161
warnings.filterwarnings('ignore', module='scipy')
162162
self._model.fit(processed_data)
163163

164-
def _warn_for_update_transformers(self, column_name_to_transformer):
165-
"""Raise warnings for update_transformers.
164+
def _warn_quality_and_performance(self, column_name_to_transformer):
165+
"""Raise warning if the quality/performance may be impacted.
166166
167167
Args:
168168
column_name_to_transformer (dict):

tests/integration/single_table/test_base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,3 +844,22 @@ def test_fit_int_primary_key_regex_includes_zero(synthesizer_class, regex):
844844
)
845845
with pytest.raises(SynthesizerInputError, match=message):
846846
instance.fit(data)
847+
848+
849+
@patch('sdv.single_table.base.warnings')
def test_update_transformers(warning_mock):
    """Test the rounding warning is emitted when a column's transformer disables rounding.

    The ``warnings`` module used by ``sdv.single_table.base`` is patched so the exact
    message passed to ``warnings.warn`` can be checked.
    """
    # Setup
    data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')
    expected_message = (
        "Unable to turn off rounding scheme for column(s) ['amenities_fee'], because the overall "
        "synthesizer is enforcing rounding. We recommend setting the synthesizer's "
        "'enforce_rounding' parameter to False."
    )

    # Run
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.auto_assign_transformers(data)
    no_rounding_formatter = FloatFormatter(learn_rounding_scheme=False)
    synthesizer.update_transformers({'amenities_fee': no_rounding_formatter})

    # Assert
    warning_mock.warn.assert_called_once_with(expected_message)

tests/unit/single_table/test_base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,31 @@ def test_update_transformers(self):
834834
assert isinstance(field_transformers['col1'], GaussianNormalizer)
835835
assert isinstance(field_transformers['col2'], GaussianNormalizer)
836836

837+
def test_update_transformers_warns_rounding(self):
    """Test ``update_transformers`` warns for the columns that disable rounding.

    Three transformers are passed in; only the two created with
    ``learn_rounding_scheme=False`` are expected to be named in the warning.
    """
    # Setup
    rounding_flags = {'col1': False, 'col2': True, 'col3': False}
    column_name_to_transformer = {
        column: GaussianNormalizer(learn_rounding_scheme=flag)
        for column, flag in rounding_flags.items()
    }
    instance = BaseSingleTableSynthesizer(Metadata())
    # Stub out the collaborators so only the rounding warning logic runs.
    instance._validate_transformers = MagicMock()
    instance._warn_quality_and_performance = MagicMock()
    instance._data_processor = MagicMock()
    instance.enforce_rounding = True
    instance._fitted = False

    expected_warning = re.escape(
        "Unable to turn off rounding scheme for column(s) ['col1', 'col3'], "
        'because the overall synthesizer is enforcing rounding. We recommend '
        "setting the synthesizer's 'enforce_rounding' parameter to False."
    )

    # Run and Assert
    with pytest.warns(UserWarning, match=expected_warning):
        instance.update_transformers(column_name_to_transformer)
861+
837862
@patch('sdv.single_table.base.DataProcessor')
838863
def test__set_random_state(self, mock_data_processor):
839864
"""Test that ``_model.set_random_state`` is being called with the input value.

0 commit comments

Comments
 (0)