Skip to content

Commit 37527fe

Browse files
authored
Ensure datetime_format warning is only raised once (#2737)
1 parent 526cd12 commit 37527fe

File tree

5 files changed

+82
-3
lines changed

5 files changed

+82
-3
lines changed

sdv/metadata/single_table.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,7 @@ def validate_data(self, data, sdtype_warnings=None):
12961296
A warning is being raised if ``datetime_format`` is missing from a column represented
12971297
as ``object`` in the dataframe and its sdtype is ``datetime``.
12981298
"""
1299+
_datetime_format_warning_flag = sdtype_warnings is not None
12991300
sdtype_warnings = sdtype_warnings if sdtype_warnings is not None else defaultdict(list)
13001301
if not isinstance(data, pd.DataFrame):
13011302
raise ValueError(f'Data must be a DataFrame, not a {type(data)}.')
@@ -1315,7 +1316,7 @@ def validate_data(self, data, sdtype_warnings=None):
13151316
errors += self._validate_column_data(data[column], sdtype_warnings)
13161317

13171318
errors += self._validate_primary_key(data)
1318-
if sdtype_warnings is not None and len(sdtype_warnings):
1319+
if (not _datetime_format_warning_flag) and len(sdtype_warnings):
13191320
df = pd.DataFrame(sdtype_warnings)
13201321
message = (
13211322
"No 'datetime_format' is present in the metadata for the following columns:\n"

sdv/multi_table/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def _initialize_models(self):
9090
self._table_synthesizers[table_name] = self._synthesizer(
9191
metadata=metadata, **synthesizer_parameters
9292
)
93+
# Mark synthesizer as embedded in a multi-table setting
94+
# so it can suppress datetime_format warnings that are aggregated here
95+
self._table_synthesizers[table_name]._suppress_datetime_format_warning = True
9396
self._table_synthesizers[table_name]._data_processor.table_name = table_name
9497

9598
def _get_pbar_args(self, **kwargs):
@@ -404,6 +407,8 @@ def set_table_parameters(self, table_name, table_parameters):
404407
self._table_synthesizers[table_name] = self._synthesizer(
405408
metadata=table_metadata, **table_parameters
406409
)
410+
# Mark synthesizer as embedded in a multi-table setting to avoid duplicate datetime warnings
411+
self._table_synthesizers[table_name]._suppress_datetime_format_warning = True
407412
self._table_synthesizers[table_name]._data_processor.table_name = table_name
408413
self._table_parameters[table_name].update(deepcopy(table_parameters))
409414

sdv/single_table/base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,18 @@ def validate(self, data):
561561
data (pandas.DataFrame):
562562
The data to validate.
563563
"""
564-
self._original_metadata.validate_data({self._table_name: data})
564+
# Suppress duplicate datetime_format warning only when this single-table synthesizer
565+
# is embedded inside a multi-table synthesizer
566+
if getattr(self, '_suppress_datetime_format_warning', False):
567+
with warnings.catch_warnings():
568+
warnings.filterwarnings(
569+
'ignore',
570+
message=r"No 'datetime_format' is present.*",
571+
category=UserWarning,
572+
)
573+
self._original_metadata.validate_data({self._table_name: data})
574+
else:
575+
self._original_metadata.validate_data({self._table_name: data})
565576
self._validate_transform_constraints(data, enforce_constraint_fitting=True)
566577

567578
# Retaining the logic of returning errors and raising them here to maintain consistency

sdv/utils/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import datetime
44
import sys
5+
import warnings
56
from copy import deepcopy
67

78
import cloudpickle
@@ -55,7 +56,14 @@ def drop_unknown_references(data, metadata, drop_missing_values=False, verbose=T
5556
})
5657
metadata.validate()
5758
try:
58-
metadata.validate_data(data)
59+
# Suppress duplicate datetime_format warnings during referential integrity validation.
60+
with warnings.catch_warnings():
61+
warnings.filterwarnings(
62+
'ignore',
63+
message=r"No 'datetime_format' is present.*",
64+
category=UserWarning,
65+
)
66+
metadata.validate_data(data)
5967
if drop_missing_values:
6068
_validate_foreign_keys_not_null(metadata, data)
6169

tests/integration/multi_table/test_hma.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2802,3 +2802,57 @@ def test_end_to_end_with_constraints():
28022802

28032803
# Assert
28042804
synthesizer.validate(synthetic_data)
2805+
2806+
2807+
def test_datetime_warning_doesnt_repeat():
2808+
"""Test that the datetime warning doesn't repeat GH#2739."""
2809+
# Setup
2810+
composite_data = {
2811+
'main': pd.DataFrame({
2812+
'pk': [1, 2, 3, 4, 5],
2813+
'denormalized_primary_key_1': [1, 1, 2, 2, 5],
2814+
'denormalized_primary_key_2': ['a', 'a', 'b', 'c', 'c'],
2815+
'denormalized_column': [
2816+
'2020-01-01',
2817+
'2020-01-01',
2818+
'2020-01-02',
2819+
'2020-01-02',
2820+
'2020-01-03',
2821+
],
2822+
'other_col': ['2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'],
2823+
})
2824+
}
2825+
2826+
composite_metadata = Metadata.load_from_dict({
2827+
'tables': {
2828+
'main': {
2829+
'columns': {
2830+
'pk': {'sdtype': 'id'},
2831+
'denormalized_primary_key_1': {'sdtype': 'id'},
2832+
'denormalized_primary_key_2': {'sdtype': 'categorical'},
2833+
'denormalized_column': {'sdtype': 'datetime'},
2834+
'other_col': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
2835+
},
2836+
'primary_key': 'pk',
2837+
},
2838+
},
2839+
})
2840+
2841+
comp_synth = HMASynthesizer(composite_metadata)
2842+
2843+
# Run
2844+
with warnings.catch_warnings(record=True) as w:
2845+
warnings.simplefilter('always')
2846+
comp_synth.fit(composite_data)
2847+
comp_synth.sample(1)
2848+
2849+
# Assert
2850+
msg = (
2851+
"No 'datetime_format' is present in the metadata for the following columns:\n"
2852+
' Table Name Column Name sdtype datetime_format\n'
2853+
' main denormalized_column datetime None\n'
2854+
'Without this specification, SDV may not be able to accurately parse the data. '
2855+
"We recommend adding datetime formats using 'update_column'."
2856+
)
2857+
matching_warnings = [warning for warning in w if str(warning.message) == msg]
2858+
assert len(matching_warnings) == 1

0 commit comments

Comments
 (0)