Skip to content

Commit 68e9565

Browse files
committed
use existing logic
1 parent ae72623 commit 68e9565

File tree

2 files changed

+62
-7
lines changed

2 files changed

+62
-7
lines changed

sdv/multi_table/hma.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -779,18 +779,38 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f
779779
return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)
780780

781781
def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name):
782+
"""Add foreign key columns in the child table.
783+
784+
This function adds foreign key columns to a child table.
785+
If the foreign key column does not exist in the child table, it adds the column.
786+
If the foreign key column already exists in the child table (e.g., when it is also a PK)
787+
and it contains invalid references (FKs not found in parent table), it overwrites the
788+
foreign key values (from the parent table's PK).
789+
790+
Args:
791+
child_table (pd.DataFrame): The child table which may or may not contain the FK columns.
792+
parent_table (pd.DataFrame): The parent table.
793+
child_name (str): The name of the child table in the metadata.
794+
parent_name (str): The name of the parent table in the metadata.
795+
796+
Returns:
797+
None: The child_table is modified in-place.
798+
"""
782799
parent_primary_key = self.metadata.tables[parent_name].primary_key
783800
parent_id_values = None
784801
for foreign_key in self.metadata._get_foreign_keys(parent_name, child_name):
785-
needs_assignment = True
786-
if foreign_key in child_table:
787-
child_column = child_table[foreign_key]
788-
if not child_column.dropna().empty:
802+
needs_assignment = foreign_key not in child_table
803+
804+
if not needs_assignment:
805+
child_column = child_table[foreign_key].dropna()
806+
if child_column.empty:
807+
needs_assignment = True
808+
else:
789809
if parent_id_values is None:
790810
parent_id_values = parent_table[parent_primary_key].dropna().unique()
791-
is_valid = child_column.dropna().isin(parent_id_values).all()
792-
if is_valid:
793-
needs_assignment = False
811+
812+
if not child_column.isin(parent_id_values).all():
813+
needs_assignment = True
794814

795815
if needs_assignment:
796816
parent_ids = self._find_parent_ids(

tests/unit/multi_table/test_hma.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,41 @@ def test__add_foreign_key_columns(self):
912912
pd.testing.assert_frame_equal(expected_parent_table, parent_table)
913913
pd.testing.assert_frame_equal(expected_child_table, child_table)
914914

915+
def test__add_foreign_key_columns_fk_already_in_child(self):
916+
"""Test that ``_add_foreign_key_columns`` does not add a FK already in child table."""
917+
# Setup
918+
instance = Mock()
919+
mock_child_metadata = Mock()
920+
mock_child_metadata.primary_key = 'parent_id'
921+
mock_parent_metadata = Mock()
922+
mock_parent_metadata.primary_key = 'parent_id_pk'
923+
mock_metadata = Mock()
924+
mock_metadata.tables = {
925+
'child': mock_child_metadata,
926+
'parent': mock_parent_metadata,
927+
}
928+
mock_metadata._get_foreign_keys.return_value = ['parent_id']
929+
instance.metadata = mock_metadata
930+
child_table = pd.DataFrame({'parent_id': [1, 2, 3], 'col_num': [10.1, 11.2, 12.3]})
931+
parent_table = pd.DataFrame({'parent_id_pk': [1, 2, 3, 4], 'col_cat': ['A', 'B', 'B', 'C']})
932+
933+
# Run
934+
HMASynthesizer._add_foreign_key_columns(
935+
instance, child_table, parent_table, child_name='child', parent_name='parent'
936+
)
937+
938+
# Assert
939+
expected_parent_table = pd.DataFrame({
940+
'parent_id_pk': pd.Series([1, 2, 3, 4], dtype=np.int64),
941+
'col_cat': pd.Series(['A', 'B', 'B', 'C'], dtype=object),
942+
})
943+
expected_child_table = pd.DataFrame({
944+
'parent_id': pd.Series([1, 2, 3], dtype=np.int64),
945+
'col_num': pd.Series([10.1, 11.2, 12.3], dtype='float64'),
946+
})
947+
pd.testing.assert_frame_equal(expected_parent_table, parent_table)
948+
pd.testing.assert_frame_equal(expected_child_table, child_table)
949+
915950
def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self):
916951
"""Test it when there are two relationships between a parent and a child tables.
917952

0 commit comments

Comments
 (0)