diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index e3c099402..60d2b11f8 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -404,6 +404,7 @@ def set_table_parameters(self, table_name, table_parameters): self._table_synthesizers[table_name] = self._synthesizer( metadata=table_metadata, **table_parameters ) + self._table_synthesizers[table_name]._data_processor.table_name = table_name self._table_parameters[table_name].update(deepcopy(table_parameters)) def _validate_all_tables(self, data): diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 117e276e6..a779a347a 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -53,7 +53,7 @@ def _get_num_data_columns(metadata): columns_per_table = {} for table_name, table in metadata.tables.items(): key_columns = metadata._get_all_keys(table_name) - columns_per_table[table_name] = sum([ + num_data_columns = sum([ 1 for col_name, col_meta in table.columns.items() if ( @@ -61,6 +61,8 @@ def _get_num_data_columns(metadata): or (col_name not in key_columns and col_meta.get('pii', False) is False) ) ]) + num_extended_columns = 0 + columns_per_table[table_name] = [num_data_columns, num_extended_columns] return columns_per_table @@ -85,18 +87,29 @@ def _get_num_extended_columns( table_name, cls.DEFAULT_SYNTHESIZER_KWARGS['default_distribution'] ) - num_parameters = cls.DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS[distribution] - + num_params_data = cls.DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS[distribution] + num_params_extended = cls.DISTRIBUTIONS_TO_NUM_PARAMETER_COLUMNS[ + DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION + ] num_rows_columns = len(metadata._get_foreign_keys(parent_table, table_name)) - # no parameter columns are generated if there are no data columns - num_data_columns = columns_per_table[table_name] - if num_data_columns == 0: + # no parameter columns are generated if there are no data or extended columns + num_data_columns = columns_per_table[table_name][0] + num_extended_columns = columns_per_table[table_name][1] + + if (num_data_columns + num_extended_columns) == 0: return num_rows_columns - num_parameters_columns = num_rows_columns * num_data_columns * num_parameters + num_parameters_columns = (num_rows_columns * num_data_columns * num_params_data) + ( + num_rows_columns * num_extended_columns * num_params_extended + ) - num_correlation_columns = num_rows_columns * (num_data_columns - 1) * num_data_columns // 2 + num_correlation_columns = ( + num_rows_columns + * (num_data_columns + num_extended_columns - 1) + * (num_data_columns + num_extended_columns) + // 2 + ) return num_correlation_columns + num_rows_columns + num_parameters_columns @@ -118,9 +131,11 @@ def _estimate_columns_traversal( """ for child_name in metadata._get_child_map()[table_name]: if child_name not in visited: - cls._estimate_columns_traversal(metadata, child_name, columns_per_table, visited) + cls._estimate_columns_traversal( + metadata, child_name, columns_per_table, visited, distributions + ) - columns_per_table[table_name] += cls._get_num_extended_columns( + columns_per_table[table_name][1] += cls._get_num_extended_columns( metadata, child_name, table_name, columns_per_table, distributions ) @@ -157,7 +172,9 @@ def _estimate_num_columns(cls, metadata, distributions=None): metadata, table_name, columns_per_table, visited, distributions ) - return columns_per_table + return { + table_name: sum(columns_list) for table_name, columns_list in columns_per_table.items() + } def __init__(self, metadata, locales=['en_US'], verbose=True): BaseMultiTableSynthesizer.__init__(self, metadata, locales=locales) @@ -173,6 +190,11 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): BaseHierarchicalSampler.__init__( self, self.metadata, self._table_synthesizers, self._table_sizes ) + child_tables = set() + for relationship in metadata.relationships: + child_tables.add(relationship['child_table_name']) + for child_table_name in child_tables: + self.set_table_parameters(child_table_name, {'default_distribution': 'norm'}) self._print_estimate_warning() def set_table_parameters(self, table_name, table_parameters): @@ -238,7 +260,7 @@ def _print_estimate_warning(self): for table, est_cols in self._estimate_num_columns(self.metadata, distributions).items(): entry = [] entry.append(table) - entry.append(metadata_columns[table]) + entry.append(sum(metadata_columns[table])) total_est_cols += est_cols entry.append(est_cols) print_table.append(entry) @@ -679,6 +701,9 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): parameters = self._extract_parameters(row, table_name, foreign_key) table_meta = self._table_synthesizers[table_name].get_metadata() synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name]) + extended_columns = getattr(self, '_parent_extended_columns', {}).get(table_name, []) + if extended_columns: + self._set_extended_columns_distributions(synthesizer, table_name, extended_columns) synthesizer._set_parameters(parameters) try: likelihoods[parent_id] = synthesizer._get_likelihood(table_rows) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 33e38aafd..e900f1c43 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -2610,9 +2610,10 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(): }) synthesizer = HMASynthesizer(metadata) synthesizer._finalize = Mock(return_value=data) + distributions = synthesizer._get_distributions() # Run estimation - estimated_num_columns = synthesizer._estimate_num_columns(metadata) + estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions) # Run actual modeling synthesizer.fit(data) diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 19ba309aa..b58f1c8ed 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -70,12 +70,12 @@ def test_simplify_schema(capsys): # Assert expected_message_before = re.compile( r'PerformanceAlert: Using the HMASynthesizer on this metadata schema is not recommended\.' - r' To model this data, HMA will generate a large number of columns\. \(173818 columns\)\s+' + r' To model this data, HMA will generate a large number of columns\. \(135934 columns\)\s+' r'Table Name\s*#\s*Columns in Metadata\s*Est # Columns\s*' r'match_stats\s*24\s*24\s*' - r'matches\s*39\s*412\s*' - r'players\s*5\s*378\s*' - r'teams\s*1\s*173004\s*' + r'matches\s*39\s*364\s*' + r'players\s*5\s*330\s*' + r'teams\s*1\s*135216\s*' r'We recommend simplifying your metadata schema using ' r"'sdv.utils.poc.simplify_schema'\.\s*" r'If this is not possible, please visit ' diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index a26222e06..f95fbc286 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -26,8 +26,8 @@ def test___init__(self): assert isinstance(instance._table_synthesizers['oseba'], GaussianCopulaSynthesizer) assert isinstance(instance._table_synthesizers['upravna_enota'], GaussianCopulaSynthesizer) assert instance._table_parameters == { - 'nesreca': {'default_distribution': 'beta'}, - 'oseba': {'default_distribution': 'beta'}, + 'nesreca': {'default_distribution': 'norm'}, + 'oseba': {'default_distribution': 'norm'}, 'upravna_enota': {'default_distribution': 'beta'}, } instance.metadata.validate.assert_called_once_with() @@ -70,8 +70,6 @@ def test__get_extension(self): # Assert expected = pd.DataFrame({ - '__nesreca__upravna_enota__univariates__id_nesreca__a': [1.0, 1.0, 1.0, 1.0], - '__nesreca__upravna_enota__univariates__id_nesreca__b': [1.0, 1.0, 1.0, 1.0], '__nesreca__upravna_enota__univariates__id_nesreca__loc': [0.0, 1.0, 2.0, 3.0], '__nesreca__upravna_enota__univariates__id_nesreca__scale': [np.nan] * 4, '__nesreca__upravna_enota__num_rows': [1.0, 1.0, 1.0, 1.0], @@ -187,12 +185,8 @@ def test__augment_table(self): 'nesreca_val': [0, 1, 2, 3], 'value': [0, 1, 2, 3], '__oseba__id_nesreca__correlation__0__0': [0.0] * 4, - '__oseba__id_nesreca__univariates__oseba_val__a': [1.0] * 4, - '__oseba__id_nesreca__univariates__oseba_val__b': [1.0] * 4, '__oseba__id_nesreca__univariates__oseba_val__loc': [0.0, 1.0, 2.0, 3.0], '__oseba__id_nesreca__univariates__oseba_val__scale': [1e-6] * 4, - '__oseba__id_nesreca__univariates__oseba_value__a': [1.0] * 4, - '__oseba__id_nesreca__univariates__oseba_value__b': [1.0] * 4, '__oseba__id_nesreca__univariates__oseba_value__loc': [0.0, 1.0, 2.0, 3.0], '__oseba__id_nesreca__univariates__oseba_value__scale': [1e-6] * 4, '__oseba__id_nesreca__num_rows': [1.0] * 4, @@ -877,9 +871,10 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): }) synthesizer = HMASynthesizer(metadata) synthesizer._finalize = Mock(return_value=data) + distributions = synthesizer._get_distributions() # Run estimation - estimated_num_columns = synthesizer._estimate_num_columns(metadata) + estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions) # Run actual modeling synthesizer.fit(data) @@ -1152,9 +1147,10 @@ def test__estimate_num_columns_to_be_modeled(self): }) synthesizer = HMASynthesizer(metadata) synthesizer._finalize = Mock(return_value=data) + distributions = synthesizer._get_distributions() # Run estimation - estimated_num_columns = synthesizer._estimate_num_columns(metadata) + estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions) # Run actual modeling synthesizer.fit(data) @@ -1264,9 +1260,10 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): }) synthesizer = HMASynthesizer(metadata) synthesizer._finalize = Mock(return_value=data) + distributions = synthesizer._get_distributions() # Run estimation - estimated_num_columns = synthesizer._estimate_num_columns(metadata) + estimated_num_columns = synthesizer._estimate_num_columns(metadata, distributions) # Run actual modeling synthesizer.fit(data)