|
| 1 | +import json |
1 | 2 | import re |
2 | 3 | from unittest.mock import call, patch |
3 | 4 |
|
|
9 | 10 | from sdv.metadata import Metadata |
10 | 11 | from sdv.multi_table.dayz import ( |
11 | 12 | DayZSynthesizer, |
| 13 | + _detect_relationship_parameters, |
12 | 14 | _validate_cardinality, |
13 | 15 | _validate_parameters, |
14 | 16 | _validate_relationship_parameters, |
15 | 17 | _validate_relationship_structure, |
| 18 | + create_parameters_multi_table, |
16 | 19 | ) |
17 | 20 |
|
18 | 21 |
|
@@ -52,6 +55,109 @@ def metadata(): |
52 | 55 | }) |
53 | 56 |
|
54 | 57 |
|
def test__detect_relationship_parameters():
    """Test that `_detect_relationship_parameters` adds the cardinality bounds.

    In the child table, parents 1, 2 and 3 are referenced two, three and one
    times, one row has no parent, and parents 4 and 5 are never referenced.
    The expected output is the metadata relationship extended with
    `min_cardinality` 0 and `max_cardinality` 3.
    """
    # Setup
    relationship = {
        'parent_table_name': 'parent',
        'child_table_name': 'child',
        'parent_primary_key': 'parent_id',
        'child_foreign_key': 'parent_id',
    }
    tables = {
        'parent': {'columns': {'parent_id': {'sdtype': 'id'}}, 'primary_key': 'parent_id'},
        'child': {
            'columns': {'child_id': {'sdtype': 'id'}, 'parent_id': {'sdtype': 'id'}},
            'primary_key': 'child_id',
        },
    }
    # The metadata receives its own copy so the expectation below is built from
    # a dict that the loader never touched.
    metadata = Metadata.load_from_dict({
        'tables': tables,
        'relationships': [dict(relationship)],
    })
    data = {
        'parent': pd.DataFrame({'parent_id': [1, 2, 3, 4, 5]}),
        'child': pd.DataFrame({
            'child_id': [10, 11, 12, 13, 14, 15, 16],
            'parent_id': [1, 1, 2, 2, 2, 3, None],
        }),
    }
    expected = [{**relationship, 'min_cardinality': 0, 'max_cardinality': 3}]

    # Run
    result = _detect_relationship_parameters(data, metadata)

    # Assert
    assert result == expected
| 101 | + |
| 102 | + |
@patch('sdv.multi_table.dayz._detect_relationship_parameters')
@patch('sdv.multi_table.dayz.create_parameters')
def test_create_parameters_multi_table(mock_create_parameters, mock_detect_relationship, tmp_path):
    """Test `create_parameters_multi_table` adds the relationships and writes the JSON.

    Both `create_parameters` and `_detect_relationship_parameters` are patched, so
    the test only checks how their outputs are combined, how they are called and
    that the file written to `output_filename` holds the returned parameters.
    """

    # Setup
    def table_parameters():
        # Return a fresh dict on each call so the expectation never aliases the
        # object handed to the mock.
        return {
            'DAYZ_SPEC_VERSION': 'V1',
            'tables': {
                'table_name': {
                    'num_rows': 100,
                    'columns': {
                        'col1': {'missing_values_proportion': 0.1},
                        'col2': {'missing_values_proportion': 0.2},
                    },
                }
            },
        }

    def relationship_parameters():
        return {
            '["parent_table", "child_table", "parent_pk", "child_fk"]': {
                'min_cardinality': 0,
                'max_cardinality': 10,
            }
        }

    data, metadata = pd.DataFrame(), Metadata()
    output_path = tmp_path / 'output.json'
    mock_create_parameters.return_value = table_parameters()
    mock_detect_relationship.return_value = relationship_parameters()
    expected = {**table_parameters(), 'relationships': relationship_parameters()}

    # Run
    result = create_parameters_multi_table(data, metadata, str(output_path))

    # Assert
    mock_create_parameters.assert_called_once_with(data, metadata, None)
    mock_detect_relationship.assert_called_once_with(data, metadata)
    assert result == expected
    assert result == mock_create_parameters.return_value
    output = json.loads(output_path.read_text())
    assert output == result
| 159 | + |
| 160 | + |
55 | 161 | def test__validate_relationship_structure(): |
56 | 162 | """Test validating the relationship parameters structure.""" |
57 | 163 | # Setup |
|
0 commit comments