Skip to content

Commit c87b595

Browse files
committed
wip
1 parent e4f0369 commit c87b595

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

sdv/multi_table/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,8 @@ def _assign_table_transformers(self, synthesizer, table_name, table_data):
471471

472472
# foreign_key_columns = self.metadata._get_all_foreign_keys(table_name)
473473
# column_name_to_transformers = {column_name: None for column_name in foreign_key_columns}
474+
# keep it in it's raw form, HMA can use it group later on,
475+
# prevent pre-process FK columns to numerical
474476

475477
column_name_to_transformers = {}
476478
primary_key = self.metadata.tables[table_name].primary_key

sdv/multi_table/hma.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,9 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
349349

350350
extension_rows = []
351351
foreign_key_columns = self.metadata._get_all_foreign_keys(child_name)
352+
353+
# foreign_key_values = child_table.index.unique()
354+
# only do if FK not the primary key
352355
foreign_key_values = child_table[foreign_key].unique()
353356
child_table = child_table.set_index(foreign_key)
354357

@@ -447,6 +450,11 @@ def _augment_table(self, table, tables, table_name):
447450
child_table = tables[child_name]
448451

449452
foreign_keys = self.metadata._get_foreign_keys(table_name, child_name)
453+
454+
primary_key = self.metadata.tables[child_name].primary_key
455+
if primary_key in foreign_keys:
456+
child_table = child_table.reset_index(drop=False)
457+
# check here
450458
for foreign_key in foreign_keys:
451459
progress_bar_desc = (
452460
f'({self._learned_relationships + 1}/{len(self.metadata.relationships)}) '
@@ -492,13 +500,15 @@ def _augment_tables(self, processed_data):
492500
processed_data (dict):
493501
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
494502
"""
503+
# data processor sets index
495504
augmented_data = deepcopy(processed_data)
496505
self._augmented_tables = []
497506
self._learned_relationships = 0
498507
parent_map = self.metadata._get_parent_map()
499508
self._print(text='Learning relationships:')
500509
for table_name in processed_data:
501510
if not parent_map.get(table_name):
511+
# only changing the child tables
502512
self._augment_table(augmented_data[table_name], augmented_data, table_name)
503513

504514
LOGGER.info('Augmentation Complete')
@@ -518,6 +528,10 @@ def _pop_foreign_keys(self, table_data, table_name):
518528
A dictionary mapping with the foreign key and it's values within the table.
519529
"""
520530
foreign_keys = self.metadata._get_all_foreign_keys(table_name)
531+
primary_key = self.metadata.tables[table_name].primary_key
532+
if primary_key in foreign_keys:
533+
table_data = table_data.reset_index(drop=False)
534+
521535
keys = {}
522536
for fk in foreign_keys:
523537
keys[fk] = table_data.pop(fk).to_numpy()

tests/integration/multi_table/test_hma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2903,7 +2903,7 @@ def data_metadata_1_to_1():
29032903
metadata = Metadata.load_from_dict(metadata_dict)
29042904
metadata.validate()
29052905
metadata.validate_data(data)
2906-
metadata.remove_primary_key('rooms')
2906+
# metadata.remove_primary_key('rooms')
29072907
return data, metadata
29082908

29092909

0 commit comments

Comments
 (0)