|
1 | 1 | """Multi-Table DayZ parameter detection and creation.""" |
2 | 2 |
|
| 3 | +import json |
| 4 | + |
| 5 | +import pandas as pd |
| 6 | + |
3 | 7 | from sdv.cag._utils import _is_list_of_type |
4 | 8 | from sdv.errors import SynthesizerInputError, SynthesizerProcessingError |
5 | | -from sdv.multi_table._dayz_utils import create_parameters_multi_table |
6 | | -from sdv.single_table.dayz import _validate_parameter_structure, _validate_tables_parameter |
| 9 | +from sdv.single_table.dayz import ( |
| 10 | + _validate_parameter_structure, |
| 11 | + _validate_tables_parameter, |
| 12 | + create_parameters, |
| 13 | +) |
7 | 14 |
|
8 | 15 | REQUIRED_RELATIONSHIP_KEYS = [ |
9 | 16 | 'parent_table_name', |
|
19 | 26 | DEFAULT_NUM_ROWS = 1000 |
20 | 27 |
|
21 | 28 |
|
| 29 | +def _detect_relationship_parameters(data, metadata): |
| 30 | + """Detect all relationship-level for the DayZ parameters. |
| 31 | +
|
| 32 | + The relationship-level parameters are: |
| 33 | + - The min and max cardinality |
| 34 | +
|
| 35 | + Args: |
| 36 | + data (dict[str, pd.DataFrame]): The input data. |
| 37 | + metadata (Metadata): The metadata object. |
| 38 | +
|
| 39 | + Returns: |
| 40 | + dict: A list containing the detected parameters. |
| 41 | + """ |
| 42 | + relationship_parameters = [] |
| 43 | + for relationship in metadata.relationships: |
| 44 | + rel_tuple = ( |
| 45 | + relationship['parent_table_name'], |
| 46 | + relationship['child_table_name'], |
| 47 | + relationship['parent_primary_key'], |
| 48 | + relationship['child_foreign_key'], |
| 49 | + ) |
| 50 | + cardinality_table = pd.DataFrame(index=data[rel_tuple[0]][rel_tuple[2]].copy()) |
| 51 | + cardinality_table['cardinality'] = data[rel_tuple[1]][rel_tuple[3]].value_counts() |
| 52 | + cardinality_table = cardinality_table.fillna(0) |
| 53 | + relationship_parameters.append({ |
| 54 | + 'parent_table_name': rel_tuple[0], |
| 55 | + 'child_table_name': rel_tuple[1], |
| 56 | + 'parent_primary_key': rel_tuple[2], |
| 57 | + 'child_foreign_key': rel_tuple[3], |
| 58 | + 'min_cardinality': cardinality_table['cardinality'].min(), |
| 59 | + 'max_cardinality': cardinality_table['cardinality'].max(), |
| 60 | + }) |
| 61 | + |
| 62 | + return relationship_parameters |
| 63 | + |
| 64 | + |
| 65 | +def create_parameters_multi_table(data, metadata, output_filename): |
| 66 | + """Create parameters for the DayZSynthesizer.""" |
| 67 | + parameters = create_parameters(data, metadata, None) |
| 68 | + parameters['relationships'] = _detect_relationship_parameters(data, metadata) |
| 69 | + if output_filename: |
| 70 | + with open(output_filename, 'w') as f: |
| 71 | + json.dump(parameters, f, indent=4) |
| 72 | + |
| 73 | + return parameters |
| 74 | + |
| 75 | + |
22 | 76 | def _validate_min_cardinality(relationship): |
23 | 77 | min_cardinality = relationship['min_cardinality'] |
24 | 78 | if not isinstance(min_cardinality, int) or min_cardinality < 0: |
|
0 commit comments