Skip to content

Commit 6f15e19

Browse files
committed
wip
1 parent 848b708 commit 6f15e19

File tree

4 files changed

+37
-18
lines changed

4 files changed

+37
-18
lines changed

sdv/multi_table/hma.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,8 +779,19 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f
779779
return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)
780780

781781
def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name):
782+
parent_primary_key = self.metadata.tables[parent_name].primary_key
783+
parent_id_values = parent_table[parent_primary_key].dropna().unique()
782784
for foreign_key in self.metadata._get_foreign_keys(parent_name, child_name):
783-
if foreign_key not in child_table:
785+
needs_assignment = True
786+
if foreign_key in child_table:
787+
child_column = child_table[foreign_key]
788+
if not child_column.dropna().empty:
789+
# check if child column (FK) contains IDs in parent table
790+
is_valid = child_column.dropna().isin(parent_id_values).all()
791+
if is_valid:
792+
needs_assignment = False
793+
794+
if needs_assignment:
784795
parent_ids = self._find_parent_ids(
785796
child_table=child_table,
786797
parent_table=parent_table,

tests/integration/multi_table/conftest.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,6 @@ def data_metadata_multiple_foreign_keys():
234234
})
235235
assert data['child']['parent_1_id'].equals(data['parent']['parent_id'])
236236
assert data['child']['parent_2_id'].equals(data['second_parent']['parent_id'])
237-
metadata.validate()
238-
metadata.validate_data(data)
239237
return data, metadata
240238

241239

@@ -259,6 +257,4 @@ def data_metadata_multiple_foreign_keys_subset(data_metadata_multiple_foreign_ke
259257
}
260258
assert set(data['child']['parent_1_id']).issubset(set(data['parent']['parent_id']))
261259
assert set(data['child']['parent_2_id']).issubset(set(data['second_parent']['parent_id']))
262-
metadata.validate()
263-
metadata.validate_data(data)
264260
return data, metadata

tests/integration/multi_table/test_hma.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2961,14 +2961,16 @@ def test_multiple_fks(self, data_metadata_multiple_foreign_keys):
29612961
synthetic_data = synthesizer.sample(scale=1.0)
29622962

29632963
# Assert
2964-
for each_parent_id in synthetic_data['child']['parent_1_id'].tolist():
2965-
assert each_parent_id in set(synthetic_data['parent']['parent_id'])
2966-
for each_parent_id in synthetic_data['child']['parent_2_id'].tolist():
2967-
assert each_parent_id in set(synthetic_data['second_parent']['parent_id'])
2964+
assert set(synthetic_data['child']['parent_1_id']).issubset(
2965+
synthetic_data['parent']['parent_id']
2966+
)
2967+
assert set(synthetic_data['child']['parent_2_id']).issubset(
2968+
synthetic_data['second_parent']['parent_id']
2969+
)
29682970
synthesizer.validate(synthetic_data)
29692971

2970-
def test_multiple_fks_mismatched(self, data_metadata_multiple_foreign_keys_subset):
2971-
"""Test support for parent and child with multiple foreign keys."""
2972+
def test_multiple_fks_subset(self, data_metadata_multiple_foreign_keys_subset):
2973+
"""Test support for parent and child with multiple foreign keys (subset in child)."""
29722974
# Setup
29732975
data, metadata = data_metadata_multiple_foreign_keys_subset
29742976

@@ -2978,8 +2980,10 @@ def test_multiple_fks_mismatched(self, data_metadata_multiple_foreign_keys_subse
29782980
synthetic_data = synthesizer.sample(scale=1.0)
29792981

29802982
# Assert
2981-
for each_parent_id in synthetic_data['child']['parent_1_id'].tolist():
2982-
assert each_parent_id in set(synthetic_data['parent']['parent_id'])
2983-
for each_parent_id in synthetic_data['child']['parent_2_id'].tolist():
2984-
assert each_parent_id in set(synthetic_data['second_parent']['parent_id'])
2983+
assert set(synthetic_data['child']['parent_1_id']).issubset(
2984+
synthetic_data['parent']['parent_id']
2985+
)
2986+
assert set(synthetic_data['child']['parent_2_id']).issubset(
2987+
synthetic_data['second_parent']['parent_id']
2988+
)
29852989
synthesizer.validate(synthetic_data)

tests/unit/multi_table/test_hma.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -869,9 +869,17 @@ def test__add_foreign_key_columns(self):
869869
"""Test that the ``_add_foreign_key_columns`` method adds foreign keys."""
870870
# Setup
871871
instance = Mock()
872-
metadata = Mock()
873-
metadata._get_foreign_keys.return_value = ['primary_user_id', 'secondary_user_id']
874-
instance.metadata = metadata
872+
mock_users_metadata = Mock()
873+
mock_users_metadata.primary_key = 'user_id'
874+
mock_transactions_metadata = Mock()
875+
mock_transactions_metadata.primary_key = 'transaction_id'
876+
mock_metadata = Mock()
877+
mock_metadata.tables = {
878+
'users': mock_users_metadata,
879+
'transactions': mock_transactions_metadata,
880+
}
881+
mock_metadata._get_foreign_keys.return_value = ['primary_user_id', 'secondary_user_id']
882+
instance.metadata = mock_metadata
875883

876884
instance._find_parent_ids.return_value = pd.Series([2, 1, 2], name='secondary_user_id')
877885

0 commit comments

Comments
 (0)