Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ def validate_data(self, data, sdtype_warnings=None):
A warning is being raised if ``datetime_format`` is missing from a column represented
as ``object`` in the dataframe and its sdtype is ``datetime``.
"""
_datetime_format_warning_flag = sdtype_warnings is not None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it possible for other warnings to show up in this list? Should we filter through the sdtype_warnings and specifically find the ones with datetime in them?

sdtype_warnings = sdtype_warnings if sdtype_warnings is not None else defaultdict(list)
if not isinstance(data, pd.DataFrame):
raise ValueError(f'Data must be a DataFrame, not a {type(data)}.')
Expand All @@ -1315,7 +1316,7 @@ def validate_data(self, data, sdtype_warnings=None):
errors += self._validate_column_data(data[column], sdtype_warnings)

errors += self._validate_primary_key(data)
if sdtype_warnings is not None and len(sdtype_warnings):
if (not _datetime_format_warning_flag) and len(sdtype_warnings):
df = pd.DataFrame(sdtype_warnings)
message = (
"No 'datetime_format' is present in the metadata for the following columns:\n"
Expand Down
5 changes: 5 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def _initialize_models(self):
self._table_synthesizers[table_name] = self._synthesizer(
metadata=metadata, **synthesizer_parameters
)
# Mark synthesizer as embedded in a multi-table setting
# so it can suppress datetime_format warnings that are aggregated here
self._table_synthesizers[table_name]._suppress_datetime_format_warning = True
self._table_synthesizers[table_name]._data_processor.table_name = table_name

def _get_pbar_args(self, **kwargs):
Expand Down Expand Up @@ -404,6 +407,8 @@ def set_table_parameters(self, table_name, table_parameters):
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata, **table_parameters
)
# Mark synthesizer as embedded in a multi-table setting to avoid duplicate datetime warnings
self._table_synthesizers[table_name]._suppress_datetime_format_warning = True
self._table_synthesizers[table_name]._data_processor.table_name = table_name
self._table_parameters[table_name].update(deepcopy(table_parameters))

Expand Down
13 changes: 12 additions & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,18 @@ def validate(self, data):
data (pandas.DataFrame):
The data to validate.
"""
self._original_metadata.validate_data({self._table_name: data})
# Suppress duplicate datetime_format warning only when this single-table synthesizer
# is embedded inside a multi-table synthesizer
if getattr(self, '_suppress_datetime_format_warning', False):
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore',
message=r"No 'datetime_format' is present.*",
category=UserWarning,
)
self._original_metadata.validate_data({self._table_name: data})
else:
self._original_metadata.validate_data({self._table_name: data})
self._validate_transform_constraints(data, enforce_constraint_fitting=True)

# Retaining the logic of returning errors and raising them here to maintain consistency
Expand Down
10 changes: 9 additions & 1 deletion sdv/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import sys
import warnings
from copy import deepcopy

import cloudpickle
Expand Down Expand Up @@ -55,7 +56,14 @@ def drop_unknown_references(data, metadata, drop_missing_values=False, verbose=T
})
metadata.validate()
try:
metadata.validate_data(data)
# Suppress duplicate datetime_format warnings during referential integrity validation.
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore',
message=r"No 'datetime_format' is present.*",
category=UserWarning,
)
metadata.validate_data(data)
if drop_missing_values:
_validate_foreign_keys_not_null(metadata, data)

Expand Down
54 changes: 54 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2802,3 +2802,57 @@ def test_end_to_end_with_constraints():

# Assert
synthesizer.validate(synthetic_data)


def test_datetime_warning_doesnt_repeat():
    """Check the missing ``datetime_format`` warning is emitted exactly once (GH#2739).

    A single-table dataset is fitted and sampled with ``HMASynthesizer``. One
    datetime column has no ``datetime_format`` in the metadata, the other does,
    so exactly one recorded warning should carry the expected message.
    """
    # Setup
    main_table = pd.DataFrame({
        'pk': [1, 2, 3, 4, 5],
        'denormalized_primary_key_1': [1, 1, 2, 2, 5],
        'denormalized_primary_key_2': ['a', 'a', 'b', 'c', 'c'],
        'denormalized_column': [
            '2020-01-01',
            '2020-01-01',
            '2020-01-02',
            '2020-01-02',
            '2020-01-03',
        ],
        'other_col': ['2020-01-01', '2020-01-02', '2020-01-03', '2020-01-04', '2020-01-05'],
    })
    data = {'main': main_table}

    # Only ``other_col`` declares a datetime_format; ``denormalized_column`` does not.
    main_columns = {
        'pk': {'sdtype': 'id'},
        'denormalized_primary_key_1': {'sdtype': 'id'},
        'denormalized_primary_key_2': {'sdtype': 'categorical'},
        'denormalized_column': {'sdtype': 'datetime'},
        'other_col': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
    }
    metadata = Metadata.load_from_dict({
        'tables': {
            'main': {
                'columns': main_columns,
                'primary_key': 'pk',
            },
        },
    })
    synthesizer = HMASynthesizer(metadata)

    expected_message = (
        "No 'datetime_format' is present in the metadata for the following columns:\n"
        ' Table Name Column Name sdtype datetime_format\n'
        ' main denormalized_column datetime None\n'
        'Without this specification, SDV may not be able to accurately parse the data. '
        "We recommend adding datetime formats using 'update_column'."
    )

    # Run
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including ones Python would otherwise de-duplicate.
        warnings.simplefilter('always')
        synthesizer.fit(data)
        synthesizer.sample(1)

    # Assert
    occurrences = sum(1 for entry in caught if str(entry.message) == expected_message)
    assert occurrences == 1
Loading