diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py
index 897189483..ab4d74507 100644
--- a/sdv/multi_table/base.py
+++ b/sdv/multi_table/base.py
@@ -341,14 +341,18 @@ def update_transformers(self, table_name, column_name_to_transformer):
     def _store_and_convert_original_cols(self, data):
         list_of_changed_tables = []
         for table, dataframe in data.items():
-            self._original_table_columns[table] = dataframe.columns
-            for column in dataframe.columns:
+            data_columns = dataframe.columns
+            col_name_mapping = {str(col): col for col in data_columns}
+            reverse_col_name_mapping = {col: str(col) for col in data_columns}
+            self._original_table_columns[table] = col_name_mapping
+            dataframe = dataframe.rename(columns=reverse_col_name_mapping)
+            for column in data_columns:
                 if isinstance(column, int):
-                    dataframe.columns = dataframe.columns.astype(str)
                     list_of_changed_tables.append(table)
                     break
 
             data[table] = dataframe
+
         return list_of_changed_tables
 
     def _transform_helper(self, data):
@@ -392,7 +396,7 @@ def preprocess(self, data):
                     raise e
 
         for table in list_of_changed_tables:
-            data[table].columns = self._original_table_columns[table]
+            data[table] = data[table].rename(columns=self._original_table_columns[table])
 
         return processed_data
 
@@ -524,9 +528,16 @@ def sample(self, scale=1.0):
             total_columns += len(table.columns)
 
         table_columns = getattr(self, '_original_table_columns', {})
+
         for table in sampled_data:
+            table_data = sampled_data[table][self.get_metadata().get_column_names(table)]
             if table in table_columns:
-                sampled_data[table].columns = table_columns[table]
+                if isinstance(table_columns[table], dict):
+                    table_data = table_data.rename(columns=table_columns[table])
+                else:
+                    table_data.columns = table_columns[table]
+
+            sampled_data[table] = table_data
 
         SYNTHESIZER_LOGGER.info({
             'EVENT': 'Sample',
diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py
index 288e470ad..97c981ced 100644
--- a/tests/integration/multi_table/test_hma.py
+++ b/tests/integration/multi_table/test_hma.py
@@ -46,7 +46,7 @@ def test_hma(self):
         assert set(normal_sample) == {'characters', 'character_families', 'families'}
         assert set(increased_sample) == {'characters', 'character_families', 'families'}
         for table_name, table in normal_sample.items():
-            assert all(table.columns == data[table_name].columns)
+            assert set(table.columns) == set(data[table_name])
 
         for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()):
             assert increased_table.size > normal_table.size
@@ -72,7 +72,7 @@ def test_hma_metadata(self):
         assert set(normal_sample) == {'characters', 'character_families', 'families'}
         assert set(increased_sample) == {'characters', 'character_families', 'families'}
         for table_name, table in normal_sample.items():
-            assert all(table.columns == data[table_name].columns)
+            assert set(table.columns) == set(data[table_name])
 
         for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()):
             assert increased_table.size > normal_table.size
@@ -2172,3 +2172,52 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes():
             num_table_cols -= 1
 
         assert num_table_cols == estimated_num_columns[table_name]
+
+
+def test_column_order():
+    """Test that the column order of the synthetic data is the one of the metadata."""
+    # Setup
+    table_1 = pd.DataFrame({
+        'col_1': [1, 2, 3],
+        'col_3': [7, 8, 9],
+        'col_2': [4, 5, 6],
+    })
+    table_2 = pd.DataFrame({
+        'col_A': ['a', 'b', 'c'],
+        'col_B': ['d', 'e', 'f'],
+        'col_C': ['g', 'h', 'i'],
+    })
+    metadata = Metadata.load_from_dict({
+        'tables': {
+            'table_1': {
+                'columns': {
+                    'col_1': {'sdtype': 'numerical'},
+                    'col_2': {'sdtype': 'numerical'},
+                    'col_3': {'sdtype': 'numerical'},
+                },
+            },
+            'table_2': {
+                'columns': {
+                    'col_A': {'sdtype': 'categorical'},
+                    'col_B': {'sdtype': 'categorical'},
+                    'col_C': {'sdtype': 'categorical'},
+                },
+            },
+        }
+    })
+    data = {
+        'table_1': table_1,
+        'table_2': table_2,
+    }
+
+    synthesizer = HMASynthesizer(metadata)
+    synthesizer.fit(data)
+
+    # Run
+    synthetic_data = synthesizer.sample()
+
+    # Assert
+    table_1_column = list(synthetic_data['table_1'].columns)
+    assert table_1_column != list(data['table_1'].columns)
+    assert table_1_column == ['col_1', 'col_2', 'col_3']
+    assert list(synthetic_data['table_2'].columns) == ['col_A', 'col_B', 'col_C']
diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py
index 55491440e..050d90cdd 100644
--- a/tests/unit/multi_table/test_base.py
+++ b/tests/unit/multi_table/test_base.py
@@ -846,7 +846,7 @@ def test_preprocess_int_columns(self):
                     },
                 },
                 'second_table': {
-                    'columns': {'3': {'sdtype': 'id'}, 'str': {'sdtype': 'categorical'}}
+                    'columns': {'3': {'sdtype': 'id'}, 'another': {'sdtype': 'categorical'}}
                 },
             },
             'relationships': [
@@ -889,7 +889,7 @@ def test_preprocess_int_columns(self):
                 'another': ['John', 'Doe', 'John Doe'],
             }),
         }
-
+        assert set(multi_data['first_table'].columns) == set(corrected_frame['first_table'].columns)
         pd.testing.assert_frame_equal(multi_data['first_table'], corrected_frame['first_table'])
         pd.testing.assert_frame_equal(multi_data['second_table'], corrected_frame['second_table'])
 
@@ -1234,11 +1234,11 @@ def test_sample(self, mock_datetime, caplog):
         metadata = get_multi_table_metadata()
         instance = BaseMultiTableSynthesizer(metadata)
         instance._fitted = True
-        data = {
-            'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}),
-            'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}),
-        }
+        data = get_multi_table_data()
         instance._sample = Mock(return_value=data)
+        instance._original_table_columns = {
+            'nesreca': ['upravna_enota', 'id_nesreca', 'nesreca_val'],
+        }
         instance._reverse_transform_helper = Mock(return_value=data)
         synth_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
 
@@ -1256,9 +1256,9 @@ def test_sample(self, mock_datetime, caplog):
             'TIMESTAMP': '2024-04-19 16:20:10.037183',
             'SYNTHESIZER CLASS NAME': 'BaseMultiTableSynthesizer',
             'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
-            'TOTAL NUMBER OF TABLES': 2,
-            'TOTAL NUMBER OF ROWS': 6,
-            'TOTAL NUMBER OF COLUMNS': 4,
+            'TOTAL NUMBER OF TABLES': 3,
+            'TOTAL NUMBER OF ROWS': 12,
+            'TOTAL NUMBER OF COLUMNS': 8,
         })
 
     def test_get_learned_distributions_raises_an_unfitted_error(self):