Skip to content

Commit 39f060e

Browse files
authored
Incorrect column name ordering for Multi-Table Synthesizer (#2295)
1 parent 20c1f28 commit 39f060e

File tree

3 files changed

+76
-16
lines changed

3 files changed

+76
-16
lines changed

sdv/multi_table/base.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -341,14 +341,18 @@ def update_transformers(self, table_name, column_name_to_transformer):
341341
def _store_and_convert_original_cols(self, data):
342342
list_of_changed_tables = []
343343
for table, dataframe in data.items():
344-
self._original_table_columns[table] = dataframe.columns
345-
for column in dataframe.columns:
344+
data_columns = dataframe.columns
345+
col_name_mapping = {str(col): col for col in data_columns}
346+
reverse_col_name_mapping = {col: str(col) for col in data_columns}
347+
self._original_table_columns[table] = col_name_mapping
348+
dataframe = dataframe.rename(columns=reverse_col_name_mapping)
349+
for column in data_columns:
346350
if isinstance(column, int):
347-
dataframe.columns = dataframe.columns.astype(str)
348351
list_of_changed_tables.append(table)
349352
break
350353

351354
data[table] = dataframe
355+
352356
return list_of_changed_tables
353357

354358
def _transform_helper(self, data):
@@ -392,7 +396,7 @@ def preprocess(self, data):
392396
raise e
393397

394398
for table in list_of_changed_tables:
395-
data[table].columns = self._original_table_columns[table]
399+
data[table] = data[table].rename(columns=self._original_table_columns[table])
396400

397401
return processed_data
398402

@@ -524,9 +528,16 @@ def sample(self, scale=1.0):
524528
total_columns += len(table.columns)
525529

526530
table_columns = getattr(self, '_original_table_columns', {})
531+
527532
for table in sampled_data:
533+
table_data = sampled_data[table][self.get_metadata().get_column_names(table)]
528534
if table in table_columns:
529-
sampled_data[table].columns = table_columns[table]
535+
if isinstance(table_columns[table], dict):
536+
table_data = table_data.rename(columns=table_columns[table])
537+
else:
538+
table_data.columns = table_columns[table]
539+
540+
sampled_data[table] = table_data
530541

531542
SYNTHESIZER_LOGGER.info({
532543
'EVENT': 'Sample',

tests/integration/multi_table/test_hma.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_hma(self):
4646
assert set(normal_sample) == {'characters', 'character_families', 'families'}
4747
assert set(increased_sample) == {'characters', 'character_families', 'families'}
4848
for table_name, table in normal_sample.items():
49-
assert all(table.columns == data[table_name].columns)
49+
assert set(table.columns) == set(data[table_name])
5050

5151
for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()):
5252
assert increased_table.size > normal_table.size
@@ -72,7 +72,7 @@ def test_hma_metadata(self):
7272
assert set(normal_sample) == {'characters', 'character_families', 'families'}
7373
assert set(increased_sample) == {'characters', 'character_families', 'families'}
7474
for table_name, table in normal_sample.items():
75-
assert all(table.columns == data[table_name].columns)
75+
assert set(table.columns) == set(data[table_name])
7676

7777
for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()):
7878
assert increased_table.size > normal_table.size
@@ -2172,3 +2172,52 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes():
21722172
num_table_cols -= 1
21732173

21742174
assert num_table_cols == estimated_num_columns[table_name]
2175+
2176+
2177+
def test_column_order():
2178+
"""Test that the column order of the synthetic data is the one of the metadata."""
2179+
# Setup
2180+
table_1 = pd.DataFrame({
2181+
'col_1': [1, 2, 3],
2182+
'col_3': [7, 8, 9],
2183+
'col_2': [4, 5, 6],
2184+
})
2185+
table_2 = pd.DataFrame({
2186+
'col_A': ['a', 'b', 'c'],
2187+
'col_B': ['d', 'e', 'f'],
2188+
'col_C': ['g', 'h', 'i'],
2189+
})
2190+
metadata = Metadata.load_from_dict({
2191+
'tables': {
2192+
'table_1': {
2193+
'columns': {
2194+
'col_1': {'sdtype': 'numerical'},
2195+
'col_2': {'sdtype': 'numerical'},
2196+
'col_3': {'sdtype': 'numerical'},
2197+
},
2198+
},
2199+
'table_2': {
2200+
'columns': {
2201+
'col_A': {'sdtype': 'categorical'},
2202+
'col_B': {'sdtype': 'categorical'},
2203+
'col_C': {'sdtype': 'categorical'},
2204+
},
2205+
},
2206+
}
2207+
})
2208+
data = {
2209+
'table_1': table_1,
2210+
'table_2': table_2,
2211+
}
2212+
2213+
synthesizer = HMASynthesizer(metadata)
2214+
synthesizer.fit(data)
2215+
2216+
# Run
2217+
synthetic_data = synthesizer.sample()
2218+
2219+
# Assert
2220+
table_1_column = list(synthetic_data['table_1'].columns)
2221+
assert table_1_column != list(data['table_1'].columns)
2222+
assert table_1_column == ['col_1', 'col_2', 'col_3']
2223+
assert list(synthetic_data['table_2'].columns) == ['col_A', 'col_B', 'col_C']

tests/unit/multi_table/test_base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,7 @@ def test_preprocess_int_columns(self):
846846
},
847847
},
848848
'second_table': {
849-
'columns': {'3': {'sdtype': 'id'}, 'str': {'sdtype': 'categorical'}}
849+
'columns': {'3': {'sdtype': 'id'}, 'another': {'sdtype': 'categorical'}}
850850
},
851851
},
852852
'relationships': [
@@ -889,7 +889,7 @@ def test_preprocess_int_columns(self):
889889
'another': ['John', 'Doe', 'John Doe'],
890890
}),
891891
}
892-
892+
assert set(multi_data['first_table'].columns) == set(corrected_frame['first_table'].columns)
893893
pd.testing.assert_frame_equal(multi_data['first_table'], corrected_frame['first_table'])
894894
pd.testing.assert_frame_equal(multi_data['second_table'], corrected_frame['second_table'])
895895

@@ -1234,11 +1234,11 @@ def test_sample(self, mock_datetime, caplog):
12341234
metadata = get_multi_table_metadata()
12351235
instance = BaseMultiTableSynthesizer(metadata)
12361236
instance._fitted = True
1237-
data = {
1238-
'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}),
1239-
'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}),
1240-
}
1237+
data = get_multi_table_data()
12411238
instance._sample = Mock(return_value=data)
1239+
instance._original_table_columns = {
1240+
'nesreca': ['upravna_enota', 'id_nesreca', 'nesreca_val'],
1241+
}
12421242
instance._reverse_transform_helper = Mock(return_value=data)
12431243

12441244
synth_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
@@ -1256,9 +1256,9 @@ def test_sample(self, mock_datetime, caplog):
12561256
'TIMESTAMP': '2024-04-19 16:20:10.037183',
12571257
'SYNTHESIZER CLASS NAME': 'BaseMultiTableSynthesizer',
12581258
'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
1259-
'TOTAL NUMBER OF TABLES': 2,
1260-
'TOTAL NUMBER OF ROWS': 6,
1261-
'TOTAL NUMBER OF COLUMNS': 4,
1259+
'TOTAL NUMBER OF TABLES': 3,
1260+
'TOTAL NUMBER OF ROWS': 12,
1261+
'TOTAL NUMBER OF COLUMNS': 8,
12621262
})
12631263

12641264
def test_get_learned_distributions_raises_an_unfitted_error(self):

0 commit comments

Comments
 (0)