Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,18 @@ def update_transformers(self, table_name, column_name_to_transformer):
def _store_and_convert_original_cols(self, data):
    """Convert every table's column labels to strings, remembering the originals.

    For each table, a ``{str(label): label}`` mapping is stored in
    ``self._original_table_columns`` so that the original labels can be put back
    with ``DataFrame.rename``. ``data`` is updated in place: each entry is replaced
    by a renamed copy, so the dataframe objects passed in are not modified.

    Args:
        data (dict):
            Dictionary mapping each table name to a ``pandas.DataFrame``.

    Returns:
        list:
            Names of the tables that had at least one integer column label.
    """
    list_of_changed_tables = []
    for table, dataframe in data.items():
        data_columns = dataframe.columns

        # NOTE: two labels with the same string form (e.g. ``1`` and ``'1'``)
        # collapse into a single key here; the later label wins.
        col_name_mapping = {str(col): col for col in data_columns}
        reverse_col_name_mapping = {col: str(col) for col in data_columns}
        self._original_table_columns[table] = col_name_mapping

        # ``rename`` returns a new dataframe, leaving the caller's object untouched.
        dataframe = dataframe.rename(columns=reverse_col_name_mapping)

        if any(isinstance(column, int) for column in data_columns):
            dataframe.columns = dataframe.columns.astype(str)
            list_of_changed_tables.append(table)

        data[table] = dataframe

    return list_of_changed_tables

def _transform_helper(self, data):
Expand Down Expand Up @@ -392,7 +396,7 @@ def preprocess(self, data):
raise e

for table in list_of_changed_tables:
data[table].columns = self._original_table_columns[table]
data[table] = data[table].rename(columns=self._original_table_columns[table])

return processed_data

Expand Down Expand Up @@ -524,9 +528,16 @@ def sample(self, scale=1.0):
total_columns += len(table.columns)

table_columns = getattr(self, '_original_table_columns', {})

for table in sampled_data:
table_data = sampled_data[table][self.get_metadata().get_column_names(table)]
if table in table_columns:
sampled_data[table].columns = table_columns[table]
if isinstance(table_columns[table], dict):
table_data = table_data.rename(columns=table_columns[table])
else:
table_data.columns = table_columns[table]

sampled_data[table] = table_data

SYNTHESIZER_LOGGER.info({
'EVENT': 'Sample',
Expand Down
53 changes: 51 additions & 2 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_hma(self):
assert set(normal_sample) == {'characters', 'character_families', 'families'}
assert set(increased_sample) == {'characters', 'character_families', 'families'}
for table_name, table in normal_sample.items():
assert all(table.columns == data[table_name].columns)
assert set(table.columns) == set(data[table_name])

for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()):
assert increased_table.size > normal_table.size
Expand All @@ -72,7 +72,7 @@ def test_hma_metadata(self):
assert set(normal_sample) == {'characters', 'character_families', 'families'}
assert set(increased_sample) == {'characters', 'character_families', 'families'}
for table_name, table in normal_sample.items():
assert all(table.columns == data[table_name].columns)
assert set(table.columns) == set(data[table_name])

for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()):
assert increased_table.size > normal_table.size
Expand Down Expand Up @@ -2172,3 +2172,52 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes():
num_table_cols -= 1

assert num_table_cols == estimated_num_columns[table_name]


def test_column_order():
    """Test that sampled tables use the metadata's column order, not the input data's."""
    # Setup
    # ``table_1`` deliberately lists ``col_3`` before ``col_2`` so that its column
    # order differs from the order declared in the metadata below.
    table_1 = pd.DataFrame({
        'col_1': [1, 2, 3],
        'col_3': [7, 8, 9],
        'col_2': [4, 5, 6],
    })
    table_2 = pd.DataFrame({
        'col_A': ['a', 'b', 'c'],
        'col_B': ['d', 'e', 'f'],
        'col_C': ['g', 'h', 'i'],
    })
    metadata = Metadata.load_from_dict({
        'tables': {
            'table_1': {
                'columns': {
                    'col_1': {'sdtype': 'numerical'},
                    'col_2': {'sdtype': 'numerical'},
                    'col_3': {'sdtype': 'numerical'},
                },
            },
            'table_2': {
                'columns': {
                    'col_A': {'sdtype': 'categorical'},
                    'col_B': {'sdtype': 'categorical'},
                    'col_C': {'sdtype': 'categorical'},
                },
            },
        }
    })
    data = {
        'table_1': table_1,
        'table_2': table_2,
    }

    synthesizer = HMASynthesizer(metadata)
    synthesizer.fit(data)

    # Run
    synthetic_data = synthesizer.sample()

    # Assert
    table_1_columns = list(synthetic_data['table_1'].columns)
    # The sampled order must differ from the shuffled input order...
    assert table_1_columns != list(data['table_1'].columns)
    # ...and match the order in which the metadata declares the columns.
    assert table_1_columns == ['col_1', 'col_2', 'col_3']
    assert list(synthetic_data['table_2'].columns) == ['col_A', 'col_B', 'col_C']
18 changes: 9 additions & 9 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def test_preprocess_int_columns(self):
},
},
'second_table': {
'columns': {'3': {'sdtype': 'id'}, 'str': {'sdtype': 'categorical'}}
'columns': {'3': {'sdtype': 'id'}, 'another': {'sdtype': 'categorical'}}
},
},
'relationships': [
Expand Down Expand Up @@ -889,7 +889,7 @@ def test_preprocess_int_columns(self):
'another': ['John', 'Doe', 'John Doe'],
}),
}

assert set(multi_data['first_table'].columns) == set(corrected_frame['first_table'].columns)
pd.testing.assert_frame_equal(multi_data['first_table'], corrected_frame['first_table'])
pd.testing.assert_frame_equal(multi_data['second_table'], corrected_frame['second_table'])

Expand Down Expand Up @@ -1234,11 +1234,11 @@ def test_sample(self, mock_datetime, caplog):
metadata = get_multi_table_metadata()
instance = BaseMultiTableSynthesizer(metadata)
instance._fitted = True
data = {
'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}),
'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}),
}
data = get_multi_table_data()
instance._sample = Mock(return_value=data)
instance._original_table_columns = {
'nesreca': ['upravna_enota', 'id_nesreca', 'nesreca_val'],
}
instance._reverse_transform_helper = Mock(return_value=data)

synth_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
Expand All @@ -1256,9 +1256,9 @@ def test_sample(self, mock_datetime, caplog):
'TIMESTAMP': '2024-04-19 16:20:10.037183',
'SYNTHESIZER CLASS NAME': 'BaseMultiTableSynthesizer',
'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
'TOTAL NUMBER OF TABLES': 2,
'TOTAL NUMBER OF ROWS': 6,
'TOTAL NUMBER OF COLUMNS': 4,
'TOTAL NUMBER OF TABLES': 3,
'TOTAL NUMBER OF ROWS': 12,
'TOTAL NUMBER OF COLUMNS': 8,
})

def test_get_learned_distributions_raises_an_unfitted_error(self):
Expand Down