Skip to content

Commit 219576d

Browse files
authored
Allow users to validate the DayZ parameters (#2671)
1 parent 0047a8d commit 219576d

File tree

12 files changed

+1117
-28
lines changed

12 files changed

+1117
-28
lines changed

sdv/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def _datetime_string_matches_format(value, datetime_format):
152152
if pd.isna(value):
153153
return True
154154
try:
155-
parsed = datetime.strptime(str(value), datetime_format)
155+
parsed = pd.to_datetime(str(value), format=datetime_format, errors='coerce')
156156
return value == parsed.strftime(datetime_format)
157157
except ValueError:
158158
return False

sdv/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,7 @@ class RefitWarning(UserWarning):
9191
Warning to be raised if a change to a synthesizer requires the synthesizer
9292
to be refit for the change to be applied.
9393
"""
94+
95+
96+
class SynthesizerProcessingError(Exception):
    """Raised when the parameters given to a synthesizer are not valid."""

sdv/multi_table/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,7 @@
33
from sdv.multi_table.hma import HMASynthesizer
44
from sdv.multi_table.dayz import DayZSynthesizer
55

6-
__all__ = ('HMASynthesizer', 'DayZSynthesizer')
6+
__all__ = (
7+
'DayZSynthesizer',
8+
'HMASynthesizer',
9+
)

sdv/multi_table/_dayz_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def detect_relationship_parameters(data, metadata):
1616
metadata (Metadata): The metadata object.
1717
1818
Returns:
19-
dict: A dictionary containing the detected parameters.
19+
list: A list of dictionaries, one per relationship, containing the detected parameters.
2020
"""
21-
relationship_parameters = {}
21+
relationship_parameters = []
2222
for relationship in metadata.relationships:
2323
rel_tuple = (
2424
relationship['parent_table_name'],
@@ -29,10 +29,14 @@ def detect_relationship_parameters(data, metadata):
2929
cardinality_table = pd.DataFrame(index=data[rel_tuple[0]][rel_tuple[2]].copy())
3030
cardinality_table['cardinality'] = data[rel_tuple[1]][rel_tuple[3]].value_counts()
3131
cardinality_table = cardinality_table.fillna(0)
32-
relationship_parameters[json.dumps(rel_tuple)] = {
32+
relationship_parameters.append({
33+
'parent_table_name': rel_tuple[0],
34+
'child_table_name': rel_tuple[1],
35+
'parent_primary_key': rel_tuple[2],
36+
'child_foreign_key': rel_tuple[3],
3337
'min_cardinality': cardinality_table['cardinality'].min(),
3438
'max_cardinality': cardinality_table['cardinality'].max(),
35-
}
39+
})
3640

3741
return relationship_parameters
3842

sdv/multi_table/dayz.py

Lines changed: 161 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,162 @@
11
"""Multi-Table DayZ parameter detection and creation."""
22

3-
from sdv.errors import SynthesizerInputError
3+
from sdv.errors import SynthesizerInputError, SynthesizerProcessingError
44
from sdv.multi_table._dayz_utils import create_parameters_multi_table
5+
from sdv.single_table.dayz import _validate_parameter_structure, _validate_tables_parameter
6+
7+
# Keys that every relationship entry must provide.
REQUIRED_RELATIONSHIP_KEYS = [
    'parent_table_name',
    'child_table_name',
    'parent_primary_key',
    'child_foreign_key',
]

# Every key a relationship entry may contain: the required keys plus the
# optional cardinality settings.
RELATIONSHIP_PARAMETER_KEYS = [
    *REQUIRED_RELATIONSHIP_KEYS,
    'min_cardinality',
    'max_cardinality',
]

# Row count used for a table when no 'num_rows' value is supplied for it.
DEFAULT_NUM_ROWS = 1000
19+
20+
21+
def _validate_min_cardinality(relationship):
    """Check that the relationship's 'min_cardinality' is a non-negative integer.

    Args:
        relationship (dict):
            Relationship parameters; the 'min_cardinality' key must be present.

    Raises:
        SynthesizerProcessingError:
            If the value is not an ``int`` or is below zero.
    """
    value = relationship['min_cardinality']
    if isinstance(value, int) and value >= 0:
        return

    raise SynthesizerProcessingError(
        f"Invalid 'min_cardinality' parameter ({value}). The "
        "'min_cardinality' parameter must be an integer greater than or equal to zero."
    )
29+
30+
31+
def _validate_max_cardinality(relationship):
    """Check that the relationship's 'max_cardinality' is a strictly positive integer.

    Args:
        relationship (dict):
            Relationship parameters; the 'max_cardinality' key must be present.

    Raises:
        SynthesizerProcessingError:
            If the value is not an ``int`` or is not greater than zero.
    """
    value = relationship['max_cardinality']
    if isinstance(value, int) and value > 0:
        return

    raise SynthesizerProcessingError(
        f"Invalid 'max_cardinality' parameter ({value}). The "
        "'max_cardinality' parameter must be an integer greater than zero."
    )
39+
40+
41+
def _validate_cardinality_bounds(relationship):
    """Check that 'min_cardinality' does not exceed 'max_cardinality'.

    Args:
        relationship (dict):
            Relationship parameters containing both cardinality keys.

    Raises:
        SynthesizerProcessingError:
            If 'min_cardinality' is greater than 'max_cardinality'.
    """
    lower = relationship['min_cardinality']
    upper = relationship['max_cardinality']
    if lower > upper:
        raise SynthesizerProcessingError(
            "Invalid cardinality parameters, the 'min_cardinality' must be less than or "
            "equal to the 'max_cardinality'."
        )
48+
49+
50+
def _validate_relationship_structure(dayz_parameters):
    """Validate the structure of the 'relationships' entry of the DayZ parameters.

    Checks that 'relationships' (when present) is a list of dictionaries, that each
    dictionary only uses known keys and provides every required key, and that any
    cardinality values it carries are well formed.

    Args:
        dayz_parameters (dict):
            The DayZ parameter dictionary.

    Raises:
        SynthesizerProcessingError:
            If 'relationships' is not a list, if an entry is not a dictionary, if an
            entry has unexpected or missing keys, or if a cardinality value is invalid.
    """
    relationships = dayz_parameters.get('relationships', [])
    if not isinstance(relationships, list):
        raise SynthesizerProcessingError("The 'relationships' parameter value must be a list.")

    for relationship in relationships:
        # Without this guard a non-dict entry would surface as an AttributeError on
        # ``.keys()`` instead of a validation error.
        if not isinstance(relationship, dict):
            raise SynthesizerProcessingError(
                "Each entry of the 'relationships' parameter must be a dictionary."
            )

        unknown_keys = relationship.keys() - set(RELATIONSHIP_PARAMETER_KEYS)
        if unknown_keys:
            # Sets have no stable order; sort so the message is the same on every run.
            # ``str`` keeps the join safe when a key is not a string.
            unknown_keys = "', '".join(sorted(map(str, unknown_keys)))
            msg = f"Relationship parameter contains unexpected key(s) '{unknown_keys}'."
            raise SynthesizerProcessingError(msg)

        missing_keys = set(REQUIRED_RELATIONSHIP_KEYS) - relationship.keys()
        if missing_keys:
            missing_keys = "', '".join(sorted(missing_keys))
            msg = f"Relationship parameter missing required key(s) '{missing_keys}'."
            raise SynthesizerProcessingError(msg)

        has_min = 'min_cardinality' in relationship
        has_max = 'max_cardinality' in relationship
        if has_min:
            _validate_min_cardinality(relationship)
        if has_max:
            _validate_max_cardinality(relationship)
        # The bounds comparison only makes sense once both values passed their own checks.
        if has_min and has_max:
            _validate_cardinality_bounds(relationship)
78+
79+
80+
def _validate_cardinality(relationship_parameters, parent_num_rows, child_num_rows):
    """Check that a relationship's cardinality is compatible with the table sizes.

    Args:
        relationship_parameters (dict):
            Relationship parameters; 'min_cardinality' defaults to 0 and
            'max_cardinality' is optional.
        parent_num_rows (int or None):
            Number of rows of the parent table. A missing or falsy value falls back
            to ``DEFAULT_NUM_ROWS``.
        child_num_rows (int or None):
            Number of rows of the child table. A missing or falsy value falls back
            to ``DEFAULT_NUM_ROWS``.

    Raises:
        SynthesizerProcessingError:
            If the child table has fewer rows than the minimum cardinality needs, or
            more rows than the maximum cardinality allows.
    """
    parent_rows = parent_num_rows or DEFAULT_NUM_ROWS
    child_rows = child_num_rows or DEFAULT_NUM_ROWS
    lower = relationship_parameters.get('min_cardinality', 0)
    upper = relationship_parameters.get('max_cardinality')

    smallest_child = lower * parent_rows
    if upper:
        largest_child = upper * parent_rows
    else:
        # No maximum cardinality set: there is no upper limit to enforce.
        largest_child = None

    if child_rows < smallest_child:
        raise SynthesizerProcessingError(
            f'Invalid cardinality parameters for relationship {relationship_parameters}. '
            f'Minimum cardinality requires child table to be at least {smallest_child} rows.'
        )

    if largest_child:
        # NOTE(review): a child table of exactly ``largest_child`` rows is accepted
        # here, although the message below says "less than".
        if child_rows > largest_child:
            raise SynthesizerProcessingError(
                f'Invalid cardinality parameters for relationship {relationship_parameters}. '
                f'Maximum cardinality requires child table to be less than {largest_child} rows.'
            )
103+
104+
105+
def _validate_relationship_parameters(metadata, dayz_parameters):
    """Validate each relationship entry against the metadata and the table sizes.

    Every entry must match a relationship listed in the metadata, may appear only
    once, and must have a cardinality compatible with the 'num_rows' of its tables.

    Args:
        metadata (sdv.Metadata):
            Metadata for the data.
        dayz_parameters (dict):
            The DayZ parameter dictionary.

    Raises:
        SynthesizerProcessingError:
            If an entry is not in the metadata, is duplicated, or has a cardinality
            that does not fit the number of rows.
    """
    already_seen = []
    for parameters in dayz_parameters.get('relationships', []):
        # Keep only the identifying keys so the entry can be compared with the
        # relationships listed in the metadata.
        identity = {key: parameters[key] for key in parameters if key in REQUIRED_RELATIONSHIP_KEYS}

        if identity not in metadata.relationships:
            raise SynthesizerProcessingError(
                'Invalid relationship parameter: '
                f'relationship {identity} does not exist in the metadata.'
            )

        if identity in already_seen:
            raise SynthesizerProcessingError(
                'Invalid relationship parameter: '
                f'multiple entries for relationship {identity} in parameters.'
            )

        already_seen.append(identity)

        table_parameters = dayz_parameters.get('tables', {})
        parent_rows = table_parameters.get(identity['parent_table_name'], {}).get('num_rows')
        child_rows = table_parameters.get(identity['child_table_name'], {}).get('num_rows')
        _validate_cardinality(parameters, parent_rows, child_rows)
134+
135+
136+
def _validate_parameters(metadata, parameters):
    """Run every validation step for a DayZSynthesizer parameters dictionary.

    The metadata is validated first, then the parameter structure, and finally the
    table and relationship settings are checked against the metadata.

    Args:
        metadata (sdv.Metadata):
            Metadata for the data.
        parameters (dict):
            The DayZ parameter dictionary.
    """
    metadata.validate()

    # Structural checks: these only look at the parameters themselves.
    _validate_parameter_structure(parameters)
    _validate_relationship_structure(parameters)

    # Checks that compare the parameters with the metadata.
    _validate_tables_parameter(metadata, parameters)
    _validate_relationship_parameters(metadata, parameters)
5150

6151

7152
class DayZSynthesizer:
8153
"""Multi-Table DayZSynthesizer for public SDV."""
9154

10155
def __init__(self, metadata, locales=['en_US']):
    """Always raise: this class cannot be instantiated in public SDV.

    Raises:
        SynthesizerInputError:
            On every call.
    """
    message = (
        "Only the 'DayZSynthesizer.create_parameters' and the "
        'DayZSynthesizer.validate_parameters methods are an SDV public feature. To '
        'define and use a DayZSynthesizer object you must have SDV-Enterprise.'
    )
    raise SynthesizerInputError(message)
15161

16162
@classmethod
@@ -26,3 +172,15 @@ def create_parameters(cls, data, metadata, output_filename=None):
26172
dict: The created parameters.
27173
"""
28174
return create_parameters_multi_table(data, metadata, output_filename)
175+
176+
@staticmethod
def validate_parameters(metadata, parameters):
    """Check a DayZSynthesizer parameters dictionary against the metadata.

    Delegates to the module-level ``_validate_parameters`` helper and returns nothing.

    Args:
        metadata (sdv.Metadata):
            Metadata for the data.
        parameters (dict):
            The DayZ parameter dictionary.
    """
    _validate_parameters(metadata, parameters)

sdv/single_table/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sdv.single_table.dayz import DayZSynthesizer
77

88
__all__ = (
9+
'DayZSynthesizer',
910
'GaussianCopulaSynthesizer',
1011
'CTGANSynthesizer',
1112
'TVAESynthesizer',

0 commit comments

Comments
 (0)