Skip to content

Commit 6a443fc

Browse files
fealhoR-Palazzo
andauthored
[DayZ Parameters] Make create_parameters use a fall back to default parameters if parameters cannot be detected (#2715)
Co-authored-by: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com>
1 parent 86e63bf commit 6a443fc

File tree

2 files changed

+170
-29
lines changed

2 files changed

+170
-29
lines changed

sdv/single_table/_dayz_utils.py

Lines changed: 79 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,79 @@ def detect_table_parameters(data):
1818
return {'num_rows': len(data)}
1919

2020

21+
def _compute_missing_values_proportion(series):
22+
"""Compute missing value proportion with a safe fallback for empty series."""
23+
if len(series) == 0:
24+
return 0.0
25+
26+
value = float(series.isna().mean())
27+
return 0.0 if pd.isna(value) else value
28+
29+
30+
def _detect_numerical_column_parameters(series):
31+
"""Detect numerical-specific parameters with fallbacks when undetectable.
32+
33+
Returns only keys that can be reliably detected (no None values).
34+
"""
35+
params = {}
36+
non_null = series.dropna()
37+
if non_null.empty:
38+
return params
39+
40+
try:
41+
num_decimal_digits = learn_rounding_digits(series)
42+
if isinstance(num_decimal_digits, int) and num_decimal_digits >= 0:
43+
params['num_decimal_digits'] = num_decimal_digits
44+
except Exception:
45+
pass
46+
47+
min_value = non_null.min()
48+
max_value = non_null.max()
49+
if not pd.isna(min_value):
50+
params['min_value'] = min_value.item() if hasattr(min_value, 'item') else float(min_value)
51+
if not pd.isna(max_value):
52+
params['max_value'] = max_value.item() if hasattr(max_value, 'item') else float(max_value)
53+
54+
return params
55+
56+
57+
def _detect_datetime_column_parameters(series, column_metadata):
58+
"""Detect datetime-specific parameters with fallbacks when undetectable.
59+
60+
Returns only keys that can be reliably detected (no None values).
61+
"""
62+
params = {}
63+
datetime_format = column_metadata.get('datetime_format', None)
64+
if datetime_format:
65+
datetime_column = pd.to_datetime(series, format=datetime_format, errors='coerce')
66+
else:
67+
datetime_column = pd.to_datetime(series, errors='coerce')
68+
69+
non_na = datetime_column[~pd.isna(datetime_column)]
70+
if non_na.empty:
71+
return params
72+
73+
start_dt = non_na.min()
74+
end_dt = non_na.max()
75+
if datetime_format:
76+
params['start_timestamp'] = start_dt.strftime(datetime_format)
77+
params['end_timestamp'] = end_dt.strftime(datetime_format)
78+
else:
79+
params['start_timestamp'] = start_dt.strftime('%Y-%m-%d %H:%M:%S')
80+
params['end_timestamp'] = end_dt.strftime('%Y-%m-%d %H:%M:%S')
81+
82+
return params
83+
84+
85+
def _detect_categorical_column_parameters(series):
86+
"""Detect categorical/boolean parameters."""
87+
categorical_values = series.dropna().unique()
88+
if len(categorical_values) == 0:
89+
return {}
90+
91+
return {'category_values': categorical_values.tolist()}
92+
93+
2194
def detect_column_parameters(data, metadata, table_name):
2295
"""Detect all column-level Dayz parameters.
2396
@@ -37,40 +110,17 @@ def detect_column_parameters(data, metadata, table_name):
37110
table_metadata = metadata.tables[table_name]
38111
column_parameters = {}
39112
for column_name, column_metadata in table_metadata.columns.items():
40-
column_parameters[column_name] = {}
41113
sdtype = column_metadata['sdtype']
114+
params = {}
42115
if sdtype == 'numerical':
43-
column_parameters[column_name] = {
44-
'num_decimal_digits': learn_rounding_digits(data[column_name]),
45-
'min_value': data[column_name].min(),
46-
'max_value': data[column_name].max(),
47-
}
116+
params.update(_detect_numerical_column_parameters(data[column_name]))
48117
elif sdtype == 'datetime':
49-
datetime_format = column_metadata.get('datetime_format', None)
50-
if datetime_format:
51-
datetime_column = pd.to_datetime(
52-
data[column_name], format=datetime_format, errors='coerce'
53-
)
54-
start_timestamp = datetime_column.min().strftime(datetime_format)
55-
end_timestamp = datetime_column.max().strftime(datetime_format)
56-
57-
else:
58-
datetime_column = pd.to_datetime(data[column_name], errors='coerce')
59-
start_timestamp = str(datetime_column.min())
60-
end_timestamp = str(datetime_column.max())
61-
62-
column_parameters[column_name] = {
63-
'start_timestamp': start_timestamp,
64-
'end_timestamp': end_timestamp,
65-
}
118+
params.update(_detect_datetime_column_parameters(data[column_name], column_metadata))
66119
elif sdtype == 'categorical':
67-
column_parameters[column_name] = {
68-
'category_values': data[column_name].dropna().unique().tolist()
69-
}
120+
params.update(_detect_categorical_column_parameters(data[column_name]))
70121

71-
column_parameters[column_name]['missing_values_proportion'] = float(
72-
data[column_name].isna().mean()
73-
)
122+
params['missing_values_proportion'] = _compute_missing_values_proportion(data[column_name])
123+
column_parameters[column_name] = params
74124

75125
return {'columns': column_parameters}
76126

tests/unit/single_table/test_dayz.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
from unittest.mock import call, patch
33

4+
import numpy as np
45
import pandas as pd
56
import pytest
67

@@ -432,3 +433,93 @@ def test__validate_parameters_errors_with_relationships(self):
432433
)
433434
with pytest.raises(SynthesizerProcessingError, match=expected_error_msg):
434435
DayZSynthesizer.validate_parameters(metadata, dayz_parameters)
436+
437+
def test_create_parameters_returns_valid_defaults(self):
438+
"""Test create_parameters returns valid defaults."""
439+
# Setup
440+
data = pd.DataFrame({'col': [np.nan]})
441+
metadata = Metadata.detect_from_dataframe(data)
442+
443+
# Run
444+
params = DayZSynthesizer.create_parameters(data, metadata)
445+
446+
# Assert
447+
assert params == {
448+
'tables': {
449+
'table': {
450+
'columns': {
451+
'col': {'missing_values_proportion': 1.0},
452+
},
453+
'num_rows': 1,
454+
},
455+
},
456+
'DAYZ_SPEC_VERSION': 'V1',
457+
}
458+
459+
def test_create_parameters_all_null_categorical_column(self):
460+
"""Categorical column with all nulls should not have the category_values key parameter."""
461+
# Setup
462+
data = pd.DataFrame({'col': [None, None, np.nan, pd.NA]})
463+
metadata = Metadata.detect_from_dataframe(data)
464+
465+
# Run
466+
params = DayZSynthesizer.create_parameters(data, metadata)
467+
468+
# Assert
469+
assert params == {
470+
'tables': {
471+
'table': {
472+
'columns': {
473+
'col': {'missing_values_proportion': 1.0},
474+
},
475+
'num_rows': 4,
476+
},
477+
},
478+
'DAYZ_SPEC_VERSION': 'V1',
479+
}
480+
481+
def test_create_parameters_all_null_numerical_column(self):
482+
"""Numerical column with all nulls should produce empty min/max values."""
483+
# Setup
484+
data = pd.DataFrame({'col': [np.nan]})
485+
metadata = Metadata()
486+
metadata.add_table('table')
487+
metadata.add_column('col', 'table', sdtype='numerical')
488+
489+
# Run
490+
params = DayZSynthesizer.create_parameters(data, metadata)
491+
492+
# Assert
493+
assert params == {
494+
'tables': {
495+
'table': {
496+
'columns': {
497+
'col': {'missing_values_proportion': 1.0},
498+
},
499+
'num_rows': 1,
500+
},
501+
},
502+
'DAYZ_SPEC_VERSION': 'V1',
503+
}
504+
505+
def test_create_parameters_all_null_datetime_column(self):
506+
"""Datetime column with all nulls should omit start/end timestamps."""
507+
# Setup
508+
data = pd.DataFrame({'col': pd.to_datetime([None, None])})
509+
metadata = Metadata.detect_from_dataframe(data)
510+
511+
# Run
512+
params = DayZSynthesizer.create_parameters(data, metadata)
513+
514+
# Assert
515+
assert params == {
516+
'tables': {
517+
'table': {
518+
'columns': {
519+
'col': {'missing_values_proportion': 1.0},
520+
},
521+
'num_rows': 2,
522+
},
523+
},
524+
'DAYZ_SPEC_VERSION': 'V1',
525+
}

0 commit comments

Comments
 (0)