diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index fc396def5..6f0e1aaa9 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -1296,6 +1296,7 @@ def validate_data(self, data, sdtype_warnings=None): A warning is being raised if ``datetime_format`` is missing from a column represented as ``object`` in the dataframe and its sdtype is ``datetime``. """ + _datetime_format_warning_flag = sdtype_warnings is not None sdtype_warnings = sdtype_warnings if sdtype_warnings is not None else defaultdict(list) if not isinstance(data, pd.DataFrame): raise ValueError(f'Data must be a DataFrame, not a {type(data)}.') @@ -1315,7 +1316,7 @@ def validate_data(self, data, sdtype_warnings=None): errors += self._validate_column_data(data[column], sdtype_warnings) errors += self._validate_primary_key(data) - if sdtype_warnings is not None and len(sdtype_warnings): + if (not _datetime_format_warning_flag) and len(sdtype_warnings): df = pd.DataFrame(sdtype_warnings) message = ( "No 'datetime_format' is present in the metadata for the following columns:\n" diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 60d2b11f8..5150f5f5b 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -90,6 +90,9 @@ def _initialize_models(self): self._table_synthesizers[table_name] = self._synthesizer( metadata=metadata, **synthesizer_parameters ) + # Mark synthesizer as embedded in a multi-table setting + # so it can suppress datetime_format warnings that are aggregated here + self._table_synthesizers[table_name]._suppress_datetime_format_warning = True self._table_synthesizers[table_name]._data_processor.table_name = table_name def _get_pbar_args(self, **kwargs): @@ -404,6 +407,8 @@ def set_table_parameters(self, table_name, table_parameters): self._table_synthesizers[table_name] = self._synthesizer( metadata=table_metadata, **table_parameters ) + # Mark synthesizer as embedded in a multi-table setting to avoid duplicate 
datetime warnings + self._table_synthesizers[table_name]._suppress_datetime_format_warning = True self._table_synthesizers[table_name]._data_processor.table_name = table_name self._table_parameters[table_name].update(deepcopy(table_parameters)) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 8629b6343..eac21e425 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -561,7 +561,18 @@ def validate(self, data): data (pandas.DataFrame): The data to validate. """ - self._original_metadata.validate_data({self._table_name: data}) + # Suppress duplicate datetime_format warning only when this single-table synthesizer + # is embedded inside a multi-table synthesizer + if getattr(self, '_suppress_datetime_format_warning', False): + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', + message=r"No 'datetime_format' is present.*", + category=UserWarning, + ) + self._original_metadata.validate_data({self._table_name: data}) + else: + self._original_metadata.validate_data({self._table_name: data}) self._validate_transform_constraints(data, enforce_constraint_fitting=True) # Retaining the logic of returning errors and raising them here to maintain consistency diff --git a/sdv/utils/utils.py b/sdv/utils/utils.py index fc7dcb2fa..56c3ba6b2 100644 --- a/sdv/utils/utils.py +++ b/sdv/utils/utils.py @@ -2,6 +2,7 @@ import datetime import sys +import warnings from copy import deepcopy import cloudpickle @@ -55,7 +56,14 @@ def drop_unknown_references(data, metadata, drop_missing_values=False, verbose=T }) metadata.validate() try: - metadata.validate_data(data) + # Suppress duplicate datetime_format warnings during referential integrity validation. 
+ with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', + message=r"No 'datetime_format' is present.*", + category=UserWarning, + ) + metadata.validate_data(data) if drop_missing_values: _validate_foreign_keys_not_null(metadata, data) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 50b370b4e..95e9f6167 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -2802,3 +2802,57 @@ def test_end_to_end_with_constraints(): # Assert synthesizer.validate(synthetic_data) + + +def test_datetime_warning_doesnt_repeat(): + """Test that the datetime warning doesn't repeat GH#2739.""" + # Setup + composite_data = { + 'main': pd.DataFrame({ + 'pk': [1, 2, 3, 4, 5], + 'denormalized_primary_key_1': [1, 1, 2, 2, 5], + 'denormalized_primary_key_2': ['a', 'a', 'b', 'c', 'c'], + 'denormalized_column': [ + '2020-01-01', + '2020-01-01', + '2020-01-02', + '2020-01-02', + '2020-01-03', + ], + 'other_col': ['2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'], + }) + } + + composite_metadata = Metadata.load_from_dict({ + 'tables': { + 'main': { + 'columns': { + 'pk': {'sdtype': 'id'}, + 'denormalized_primary_key_1': {'sdtype': 'id'}, + 'denormalized_primary_key_2': {'sdtype': 'categorical'}, + 'denormalized_column': {'sdtype': 'datetime'}, + 'other_col': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, + }, + 'primary_key': 'pk', + }, + }, + }) + + comp_synth = HMASynthesizer(composite_metadata) + + # Run + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + comp_synth.fit(composite_data) + comp_synth.sample(1) + + # Assert + msg = ( + "No 'datetime_format' is present in the metadata for the following columns:\n" + ' Table Name Column Name sdtype datetime_format\n' + ' main denormalized_column datetime None\n' + 'Without this specification, SDV may not be able to accurately parse the data. 
' + "We recommend adding datetime formats using 'update_column'." + ) + matching_warnings = [warning for warning in w if str(warning.message) == msg] + assert len(matching_warnings) == 1