Skip to content

Commit d16f31c

Browse files
Implement datetime downcasting for accurate column comparison in constraints (#2293)
1 parent 39f060e commit d16f31c

File tree

5 files changed

+834
-4
lines changed

5 files changed

+834
-4
lines changed

sdv/constraints/tabular.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
get_datetime_diff,
5151
get_mappable_combination,
5252
logit,
53+
match_datetime_precision,
5354
matches_datetime_format,
5455
revert_nans_columns,
5556
sigmoid,
@@ -484,6 +485,15 @@ def is_valid(self, table_data):
484485
low = cast_to_datetime64(low, self._low_datetime_format)
485486
high = cast_to_datetime64(high, self._high_datetime_format)
486487

488+
format_matches = bool(self._low_datetime_format == self._high_datetime_format)
489+
if not format_matches:
490+
low, high = match_datetime_precision(
491+
low=low,
492+
high=high,
493+
low_datetime_format=self._low_datetime_format,
494+
high_datetime_format=self._high_datetime_format,
495+
)
496+
487497
valid = pd.isna(low) | pd.isna(high) | self._operator(high, low)
488498
return valid
489499

sdv/constraints/utils.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,39 @@
11
"""Constraint utility functions."""
22

3+
import re
34
from datetime import datetime
45
from decimal import Decimal
56

67
import numpy as np
78
import pandas as pd
89

10+
PRECISION_LEVELS = {
11+
'%Y': 1, # Year
12+
'%y': 1, # Year without century (same precision as %Y)
13+
'%B': 2, # Full month name
14+
'%b': 2, # Abbreviated month name
15+
'%m': 2, # Month as a number
16+
'%d': 3, # Day of the month
17+
'%j': 3, # Day of the year
18+
'%U': 3, # Week number (Sunday-starting)
19+
'%W': 3, # Week number (Monday-starting)
20+
'%A': 3, # Full weekday name
21+
'%a': 3, # Abbreviated weekday name
22+
'%w': 3, # Weekday as a decimal
23+
'%H': 4, # Hour (24-hour clock)
24+
'%I': 4, # Hour (12-hour clock)
25+
'%M': 5, # Minute
26+
'%S': 6, # Second
27+
'%f': 7, # Microsecond
28+
# Formats that don't add precision
29+
'%p': 0, # AM/PM
30+
'%z': 0, # UTC offset
31+
'%Z': 0, # Time zone name
32+
'%c': 0, # Locale-based date/time
33+
'%x': 0, # Locale-based date
34+
'%X': 0, # Locale-based time
35+
}
36+
937

1038
def cast_to_datetime64(value, datetime_format=None):
1139
"""Cast a given value to a ``numpy.datetime64`` format.
@@ -199,6 +227,14 @@ def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=
199227
low = cast_to_datetime64(low, low_datetime_format)
200228
high = cast_to_datetime64(high, high_datetime_format)
201229

230+
if low_datetime_format != high_datetime_format:
231+
low, high = match_datetime_precision(
232+
low=low,
233+
high=high,
234+
low_datetime_format=low_datetime_format,
235+
high_datetime_format=high_datetime_format,
236+
)
237+
202238
diff_column = high - low
203239
nan_mask = pd.isna(diff_column)
204240
diff_column = diff_column.astype(np.float64)
@@ -221,3 +257,98 @@ def get_mappable_combination(combination):
221257
A mappable combination of values.
222258
"""
223259
return tuple(None if pd.isna(x) else x for x in combination)
260+
261+
262+
def match_datetime_precision(low, high, low_datetime_format, high_datetime_format):
263+
"""Match `low` or `high` datetime array to the lower precision format.
264+
265+
Args:
266+
low (np.ndarray):
267+
Array of datetime values for the low column.
268+
high (np.ndarray):
269+
Array of datetime values for the high column.
270+
low_datetime_format (str):
271+
The datetime format of the `low` column.
272+
high_datetime_format (str):
273+
The datetime format of the `high` column.
274+
275+
Returns:
276+
Tuple[np.ndarray, np.ndarray]:
277+
Adjusted `low` and `high` arrays where the higher precision format is
278+
downcasted to the lower precision format.
279+
"""
280+
lower_precision_format = get_lower_precision_format(low_datetime_format, high_datetime_format)
281+
if lower_precision_format == high_datetime_format:
282+
low = downcast_datetime_to_lower_precision(low, lower_precision_format)
283+
else:
284+
high = downcast_datetime_to_lower_precision(high, lower_precision_format)
285+
286+
return low, high
287+
288+
289+
def get_datetime_format_precision(format_str):
290+
"""Return the precision level of a datetime format string."""
291+
# Find all format codes in the format string
292+
found_formats = re.findall(r'%[A-Za-z]', format_str)
293+
found_levels = (
294+
PRECISION_LEVELS.get(found_format)
295+
for found_format in found_formats
296+
if found_format in PRECISION_LEVELS
297+
)
298+
299+
return max(found_levels, default=0)
300+
301+
302+
def get_lower_precision_format(primary_format, secondary_format):
303+
"""Compare two datetime format strings and return the one with lower precision.
304+
305+
Args:
306+
primary_format (str):
307+
The first datetime format string to compare.
308+
low_precision_format (str):
309+
The second datetime format string to compare.
310+
311+
Returns:
312+
str:
313+
The datetime format string with the lower precision level.
314+
"""
315+
primary_level = get_datetime_format_precision(primary_format)
316+
secondary_level = get_datetime_format_precision(secondary_format)
317+
if primary_level >= secondary_level:
318+
return secondary_format
319+
320+
return primary_format
321+
322+
323+
def downcast_datetime_to_lower_precision(data, target_format):
324+
"""Convert a datetime string from a higher-precision format to a lower-precision format.
325+
326+
Args:
327+
data (np.array):
328+
The data to cast to the `target_format`.
329+
target_format (str):
330+
The datetime string to downcast.
331+
332+
Returns:
333+
str: The datetime string in the lower precision format.
334+
"""
335+
downcasted_data = format_datetime_array(data, target_format)
336+
return cast_to_datetime64(downcasted_data, target_format)
337+
338+
339+
def format_datetime_array(datetime_array, target_format):
340+
"""Format each element in a numpy datetime64 array to a specified string format.
341+
342+
Args:
343+
datetime_array (np.ndarray):
344+
Array of datetime64[ns] elements.
345+
target_format (str):
346+
The datetime format to cast each element to.
347+
348+
Returns:
349+
np.ndarray: Array of formatted datetime strings.
350+
"""
351+
return np.array([
352+
pd.to_datetime(date).strftime(target_format) if not pd.isna(date) else pd.NaT
353+
for date in datetime_array
354+
])

0 commit comments

Comments
 (0)