@@ -349,6 +349,9 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
349349
350350 extension_rows = []
351351 foreign_key_columns = self .metadata ._get_all_foreign_keys (child_name )
352+
353+ # foreign_key_values = child_table.index.unique()
354+ # only do if FK not the primary key
352355 foreign_key_values = child_table [foreign_key ].unique ()
353356 child_table = child_table .set_index (foreign_key )
354357
@@ -447,6 +450,11 @@ def _augment_table(self, table, tables, table_name):
447450 child_table = tables [child_name ]
448451
449452 foreign_keys = self .metadata ._get_foreign_keys (table_name , child_name )
453+
454+ primary_key = self .metadata .tables [child_name ].primary_key
455+ if primary_key in foreign_keys :
456+ child_table = child_table .reset_index (drop = False )
457+ # check here
450458 for foreign_key in foreign_keys :
451459 progress_bar_desc = (
452460 f'({ self ._learned_relationships + 1 } /{ len (self .metadata .relationships )} ) '
@@ -492,13 +500,15 @@ def _augment_tables(self, processed_data):
492500 processed_data (dict):
493501 Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
494502 """
503+ # data processor sets index
495504 augmented_data = deepcopy (processed_data )
496505 self ._augmented_tables = []
497506 self ._learned_relationships = 0
498507 parent_map = self .metadata ._get_parent_map ()
499508 self ._print (text = 'Learning relationships:' )
500509 for table_name in processed_data :
501510 if not parent_map .get (table_name ):
511+ # only changing the child tables
502512 self ._augment_table (augmented_data [table_name ], augmented_data , table_name )
503513
504514 LOGGER .info ('Augmentation Complete' )
@@ -518,6 +528,10 @@ def _pop_foreign_keys(self, table_data, table_name):
518528 A dictionary mapping with the foreign key and it's values within the table.
519529 """
520530 foreign_keys = self .metadata ._get_all_foreign_keys (table_name )
531+ primary_key = self .metadata .tables [table_name ].primary_key
532+ if primary_key in foreign_keys :
533+ table_data = table_data .reset_index (drop = False )
534+
521535 keys = {}
522536 for fk in foreign_keys :
523537 keys [fk ] = table_data .pop (fk ).to_numpy ()
0 commit comments