|
1 | 1 | """Multi-Table DayZ parameter detection and creation.""" |
2 | 2 |
|
| 3 | +import json |
| 4 | + |
| 5 | +import pandas as pd |
| 6 | + |
| 7 | +from sdv.cag._utils import _is_list_of_type |
3 | 8 | from sdv.errors import SynthesizerInputError, SynthesizerProcessingError |
4 | | -from sdv.multi_table._dayz_utils import create_parameters_multi_table |
5 | | -from sdv.single_table.dayz import _validate_parameter_structure, _validate_tables_parameter |
| 9 | +from sdv.single_table.dayz import ( |
| 10 | + _validate_parameter_structure, |
| 11 | + _validate_tables_parameter, |
| 12 | + create_parameters, |
| 13 | +) |
6 | 14 |
|
7 | 15 | REQUIRED_RELATIONSHIP_KEYS = [ |
8 | 16 | 'parent_table_name', |
|
18 | 26 | DEFAULT_NUM_ROWS = 1000 |
19 | 27 |
|
20 | 28 |
|
| 29 | +def _detect_relationship_parameters(data, metadata): |
| 30 | + """Detect all relationship-level for the DayZ parameters. |
| 31 | +
|
| 32 | + The relationship-level parameters are: |
| 33 | + - The min and max cardinality |
| 34 | +
|
| 35 | + Args: |
| 36 | + data (dict[str, pd.DataFrame]): The input data. |
| 37 | + metadata (Metadata): The metadata object. |
| 38 | +
|
| 39 | + Returns: |
| 40 | + dict: A list containing the detected parameters. |
| 41 | + """ |
| 42 | + relationship_parameters = [] |
| 43 | + for relationship in metadata.relationships: |
| 44 | + rel_tuple = ( |
| 45 | + relationship['parent_table_name'], |
| 46 | + relationship['child_table_name'], |
| 47 | + relationship['parent_primary_key'], |
| 48 | + relationship['child_foreign_key'], |
| 49 | + ) |
| 50 | + cardinality_table = pd.DataFrame(index=data[rel_tuple[0]][rel_tuple[2]].copy()) |
| 51 | + cardinality_table['cardinality'] = data[rel_tuple[1]][rel_tuple[3]].value_counts() |
| 52 | + cardinality_table = cardinality_table.fillna(0) |
| 53 | + relationship_parameters.append({ |
| 54 | + 'parent_table_name': rel_tuple[0], |
| 55 | + 'child_table_name': rel_tuple[1], |
| 56 | + 'parent_primary_key': rel_tuple[2], |
| 57 | + 'child_foreign_key': rel_tuple[3], |
| 58 | + 'min_cardinality': cardinality_table['cardinality'].min(), |
| 59 | + 'max_cardinality': cardinality_table['cardinality'].max(), |
| 60 | + }) |
| 61 | + |
| 62 | + return relationship_parameters |
| 63 | + |
| 64 | + |
| 65 | +def create_parameters_multi_table(data, metadata, output_filename): |
| 66 | + """Create parameters for the DayZSynthesizer.""" |
| 67 | + parameters = create_parameters(data, metadata, None) |
| 68 | + parameters['relationships'] = _detect_relationship_parameters(data, metadata) |
| 69 | + if output_filename: |
| 70 | + with open(output_filename, 'w') as f: |
| 71 | + json.dump(parameters, f, indent=4) |
| 72 | + |
| 73 | + return parameters |
| 74 | + |
| 75 | + |
21 | 76 | def _validate_min_cardinality(relationship): |
22 | 77 | min_cardinality = relationship['min_cardinality'] |
23 | 78 | if not isinstance(min_cardinality, int) or min_cardinality < 0: |
@@ -48,8 +103,10 @@ def _validate_cardinality_bounds(relationship): |
48 | 103 |
|
49 | 104 |
|
50 | 105 | def _validate_relationship_structure(dayz_parameters): |
51 | | - if not isinstance(dayz_parameters.get('relationships', []), list): |
52 | | - raise SynthesizerProcessingError("The 'relationships' parameter value must be a list.") |
| 106 | + if not _is_list_of_type(dayz_parameters.get('relationships', []), dict): |
| 107 | + raise SynthesizerProcessingError( |
| 108 | + "The 'relationships' parameter value must be a list of dictionaries." |
| 109 | + ) |
53 | 110 |
|
54 | 111 | for relationship in dayz_parameters.get('relationships', []): |
55 | 112 | unknown_relationship_parameters = relationship.keys() - set(RELATIONSHIP_PARAMETER_KEYS) |
@@ -160,18 +217,18 @@ def __init__(self, metadata, locales=['en_US']): |
160 | 217 | ) |
161 | 218 |
|
162 | 219 | @classmethod |
163 | | - def create_parameters(cls, data, metadata, output_filename=None): |
| 220 | + def create_parameters(cls, data, metadata, filepath=None): |
164 | 221 | """Create parameters for the DayZSynthesizer. |
165 | 222 |
|
166 | 223 | Args: |
167 | 224 | data (dict[str, pd.DataFrame]): The input data. |
168 | 225 | metadata (Metadata): The metadata object. |
169 | | - output_filename (str, optional): The output filename for the parameters. |
| 226 | + filepath (str, optional): The output filename for the parameters. |
170 | 227 |
|
171 | 228 | Returns: |
172 | 229 | dict: The created parameters. |
173 | 230 | """ |
174 | | - return create_parameters_multi_table(data, metadata, output_filename) |
| 231 | + return create_parameters_multi_table(data, metadata, filepath) |
175 | 232 |
|
176 | 233 | @staticmethod |
177 | 234 | def validate_parameters(metadata, parameters): |
|
0 commit comments