Skip to content

Commit e6576b0

Browse files
committed
wip
1 parent c87b595 commit e6576b0

File tree

3 files changed

+15
-22
lines changed

3 files changed

+15
-22
lines changed

sdv/multi_table/base.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -466,18 +466,17 @@ def _validate_table_name(self, table_name):
466466
)
467467

468468
def _assign_table_transformers(self, synthesizer, table_name, table_data):
469-
"""Update the ``synthesizer`` to ignore the foreign keys while preprocessing the data."""
470-
synthesizer.auto_assign_transformers(table_data)
469+
"""Update the ``synthesizer`` to ignore the foreign keys while preprocessing the data.
471470
472-
# foreign_key_columns = self.metadata._get_all_foreign_keys(table_name)
473-
# column_name_to_transformers = {column_name: None for column_name in foreign_key_columns}
474-
# keep it in it's raw form, HMA can use it group later on,
475-
# prevent pre-process FK columns to numerical
471+
If the foreign key is also the primary key for the table, then no transformer assigned
472+
for that column.
476473
474+
"""
475+
synthesizer.auto_assign_transformers(table_data)
477476
column_name_to_transformers = {}
478477
primary_key = self.metadata.tables[table_name].primary_key
479478
for relation in self.metadata.relationships:
480-
column_name = deepcopy(relation['child_foreign_key'])
479+
column_name = relation['child_foreign_key']
481480
if (
482481
relation['child_table_name'] == table_name
483482
and relation['child_foreign_key'] != primary_key

sdv/multi_table/hma.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,6 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
349349

350350
extension_rows = []
351351
foreign_key_columns = self.metadata._get_all_foreign_keys(child_name)
352-
353-
# foreign_key_values = child_table.index.unique()
354-
# only do if FK not the primary key
355352
foreign_key_values = child_table[foreign_key].unique()
356353
child_table = child_table.set_index(foreign_key)
357354

@@ -451,10 +448,6 @@ def _augment_table(self, table, tables, table_name):
451448

452449
foreign_keys = self.metadata._get_foreign_keys(table_name, child_name)
453450

454-
primary_key = self.metadata.tables[child_name].primary_key
455-
if primary_key in foreign_keys:
456-
child_table = child_table.reset_index(drop=False)
457-
# check here
458451
for foreign_key in foreign_keys:
459452
progress_bar_desc = (
460453
f'({self._learned_relationships + 1}/{len(self.metadata.relationships)}) '
@@ -500,15 +493,20 @@ def _augment_tables(self, processed_data):
500493
processed_data (dict):
501494
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
502495
"""
503-
# data processor sets index
504496
augmented_data = deepcopy(processed_data)
505497
self._augmented_tables = []
506498
self._learned_relationships = 0
507499
parent_map = self.metadata._get_parent_map()
508500
self._print(text='Learning relationships:')
501+
502+
for table_name in processed_data:
503+
foreign_keys = self.metadata._get_all_foreign_keys(table_name)
504+
primary_key = self.metadata.tables[table_name].primary_key
505+
if primary_key in foreign_keys:
506+
augmented_data[table_name] = augmented_data[table_name].reset_index(drop=False)
507+
509508
for table_name in processed_data:
510509
if not parent_map.get(table_name):
511-
# only changing the child tables
512510
self._augment_table(augmented_data[table_name], augmented_data, table_name)
513511

514512
LOGGER.info('Augmentation Complete')
@@ -528,10 +526,6 @@ def _pop_foreign_keys(self, table_data, table_name):
528526
A dictionary mapping with the foreign key and it's values within the table.
529527
"""
530528
foreign_keys = self.metadata._get_all_foreign_keys(table_name)
531-
primary_key = self.metadata.tables[table_name].primary_key
532-
if primary_key in foreign_keys:
533-
table_data = table_data.reset_index(drop=False)
534-
535529
keys = {}
536530
for fk in foreign_keys:
537531
keys[fk] = table_data.pop(fk).to_numpy()

sdv/sampling/hierarchical_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num
102102
if len(sampled_rows):
103103
parent_key = self.metadata.tables[parent_name].primary_key
104104
if foreign_key in sampled_rows:
105-
# If foreign key is in sampeld rows raises `SettingWithCopyWarning`
105+
# If foreign key is in sampled rows raises `SettingWithCopyWarning`
106106
row_indices = sampled_rows.index
107-
sampled_rows[foreign_key].iloc[row_indices] = parent_row[parent_key]
107+
sampled_rows.loc[row_indices, foreign_key] = parent_row[parent_key]
108108
else:
109109
sampled_rows[foreign_key] = (
110110
parent_row[parent_key] if parent_row is not None else np.nan

0 commit comments

Comments
 (0)