Skip to content

Commit f91f373

Browse files
committed
add int tests with multiple FKs
1 parent 51fcdc8 commit f91f373

File tree

3 files changed

+197
-74
lines changed

3 files changed

+197
-74
lines changed

sdv/multi_table/hma.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
350350
extension_rows = []
351351
primary_key = self.metadata.tables[child_name].primary_key
352352
foreign_key_columns = self.metadata._get_all_foreign_keys(child_name)
353-
if primary_key and primary_key in foreign_key_columns:
353+
primary_key_is_a_foreign_key = primary_key and primary_key in foreign_key_columns
354+
if primary_key_is_a_foreign_key and foreign_key == primary_key:
354355
# data processor will set index of each table to the PK for table
355356
foreign_key_values = child_table.index.unique()
356357
else:

tests/integration/multi_table/conftest.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import random
12
from copy import deepcopy
23

34
import pandas as pd
@@ -173,3 +174,91 @@ def data_metadata_1_to_1_to_1_subset_to_subset(data_metadata_1_to_1_subset_diamo
173174
]
174175
metadata = Metadata.load_from_dict(metadata_dict)
175176
return data, metadata
177+
178+
179+
@pytest.fixture
180+
def data_metadata_multiple_foreign_keys():
181+
parent_1_ids = range(0, 10)
182+
parent_2_ids = range(10, 20)
183+
parent = pd.DataFrame({
184+
'parent_id': parent_1_ids,
185+
'col_categorical': random.choices(['A', 'B', 'C', 'D', 'E'], k=10),
186+
})
187+
child = pd.DataFrame({
188+
'parent_1_id': parent_1_ids,
189+
'parent_2_id': parent_2_ids,
190+
'col_numerical': [10.2, 20.3] * 5,
191+
})
192+
second_parent = pd.DataFrame({'parent_id': parent_2_ids, 'col_boolean': [True, False] * 5})
193+
data = {
194+
'parent': parent,
195+
'child': child,
196+
'second_parent': second_parent,
197+
}
198+
metadata = Metadata.load_from_dict({
199+
'tables': {
200+
'parent': {
201+
'columns': {
202+
'parent_id': {'sdtype': 'id'},
203+
'col_categorical': {'sdtype': 'categorical'},
204+
},
205+
'primary_key': 'parent_id',
206+
},
207+
'child': {
208+
'columns': {
209+
'parent_1_id': {'sdtype': 'id'},
210+
'parent_2_id': {'sdtype': 'id'},
211+
'col_numerical': {'sdtype': 'numerical'},
212+
},
213+
'primary_key': 'parent_1_id',
214+
},
215+
'second_parent': {
216+
'columns': {'parent_id': {'sdtype': 'id'}, 'col_boolean': {'sdtype': 'boolean'}},
217+
'primary_key': 'parent_id',
218+
},
219+
},
220+
'relationships': [
221+
{
222+
'parent_table_name': 'parent',
223+
'child_table_name': 'child',
224+
'parent_primary_key': 'parent_id',
225+
'child_foreign_key': 'parent_1_id',
226+
},
227+
{
228+
'parent_table_name': 'second_parent',
229+
'child_table_name': 'child',
230+
'parent_primary_key': 'parent_id',
231+
'child_foreign_key': 'parent_2_id',
232+
},
233+
],
234+
})
235+
assert data['child']['parent_1_id'].equals(data['parent']['parent_id'])
236+
assert data['child']['parent_2_id'].equals(data['second_parent']['parent_id'])
237+
metadata.validate()
238+
metadata.validate_data(data)
239+
return data, metadata
240+
241+
242+
@pytest.fixture
243+
def data_metadata_multiple_foreign_keys_subset(data_metadata_multiple_foreign_keys):
244+
_, metadata = data_metadata_multiple_foreign_keys
245+
parent = pd.DataFrame({'parent_id': [1, 2, 3], 'col_categorical': ['A', 'B', 'C']})
246+
child = pd.DataFrame({
247+
'parent_1_id': [1, 2],
248+
'parent_2_id': [1, 1],
249+
'col_numerical': [100.5, 101.6],
250+
})
251+
second_parent = pd.DataFrame({
252+
'parent_id': [1, 2],
253+
'col_boolean': [True, False],
254+
})
255+
data = {
256+
'parent': parent,
257+
'child': child,
258+
'second_parent': second_parent,
259+
}
260+
assert set(data['child']['parent_1_id']).issubset(set(data['parent']['parent_id']))
261+
assert set(data['child']['parent_2_id']).issubset(set(data['second_parent']['parent_id']))
262+
metadata.validate()
263+
metadata.validate_data(data)
264+
return data, metadata

tests/integration/multi_table/test_hma.py

Lines changed: 106 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2858,95 +2858,128 @@ def test_datetime_warning_doesnt_repeat():
28582858
assert len(matching_warnings) == 1
28592859

28602860

2861-
def test_hma_1_to_1(data_metadata_1_to_1):
2862-
"""Test HMA handles PK to PK relationship (1 to 1) and synthetic data matching cardinality."""
2863-
# Setup
2864-
data, metadata = data_metadata_1_to_1
2861+
class TestPrimaryKeyToPrimaryKey:
2862+
def test_1_to_1(self, data_metadata_1_to_1):
2863+
"""Test HMA handles PK to PK relationship (1 to 1) and synthetic data matching cardinality."""
2864+
# Setup
2865+
data, metadata = data_metadata_1_to_1
28652866

2866-
# Run
2867-
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2868-
with warnings.catch_warnings(record=True) as caught_warnings:
2867+
# Run
2868+
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2869+
with warnings.catch_warnings(record=True) as caught_warnings:
2870+
synthesizer.fit(data)
2871+
synthetic_data = synthesizer.sample(scale=1)
2872+
2873+
# Assert
2874+
assert synthetic_data['guests']['guest_email'].equals(
2875+
synthetic_data['rooms']['guest_email']
2876+
)
2877+
synthesizer.validate(synthetic_data)
2878+
for msg in caught_warnings:
2879+
assert 'ChainedAssignmentError' not in str(msg.message)
2880+
2881+
def test_1_to_1_or_0(self, data_metadata_1_to_1_or_0):
2882+
"""Test HMA handles PK to PK relationship (1 to 1/0) and synthetic data matching cardinality."""
2883+
# Setup
2884+
data, metadata = data_metadata_1_to_1_or_0
2885+
2886+
# Run
2887+
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
28692888
synthesizer.fit(data)
28702889
synthetic_data = synthesizer.sample(scale=1)
28712890

2872-
# Assert
2873-
assert synthetic_data['guests']['guest_email'].equals(synthetic_data['rooms']['guest_email'])
2874-
synthesizer.validate(synthetic_data)
2875-
for msg in caught_warnings:
2876-
assert 'ChainedAssignmentError' not in str(msg.message)
2891+
# Assert
2892+
assert set(synthetic_data['users']['user_id']).issuperset(
2893+
set(synthetic_data['survey_response']['user_id'])
2894+
)
2895+
synthesizer.validate(synthetic_data)
28772896

2897+
def test_1_to_1_or_0_not_superset(self, data_metadata_1_to_1_or_0):
2898+
"""Test error is raised if PK to PK relationship but parent is not a superset."""
2899+
# Setup
2900+
data, metadata = data_metadata_1_to_1_or_0
2901+
metadata.remove_relationship(parent_table_name='users', child_table_name='survey_response')
2902+
metadata.add_relationship(
2903+
parent_table_name='survey_response',
2904+
parent_primary_key='user_id',
2905+
child_table_name='users',
2906+
child_foreign_key='user_id',
2907+
)
2908+
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2909+
match_ = re.escape("Error: foreign key column 'user_id' contains unknown references: (9).")
28782910

2879-
def test_hma_1_to_1_or_0(data_metadata_1_to_1_or_0):
2880-
"""Test HMA handles PK to PK relationship (1 to 1/0) and synthetic data matching cardinality."""
2881-
# Setup
2882-
data, metadata = data_metadata_1_to_1_or_0
2911+
# Run and Assert
2912+
with pytest.raises(InvalidDataError, match=match_):
2913+
synthesizer.fit(data)
28832914

2884-
# Run
2885-
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2886-
synthesizer.fit(data)
2887-
synthetic_data = synthesizer.sample(scale=1)
2915+
def test_1_to_1_to_1_subset_to_subset(self, data_metadata_1_to_1_to_1_subset_to_subset):
2916+
"""Test PK to PK to PK, with the 2nd and 3rd table having a subset."""
2917+
# Setup
2918+
data, metadata = data_metadata_1_to_1_to_1_subset_to_subset
28882919

2889-
# Assert
2890-
assert set(synthetic_data['users']['user_id']).issuperset(
2891-
set(synthetic_data['survey_response']['user_id'])
2892-
)
2893-
synthesizer.validate(synthetic_data)
2920+
# Run
2921+
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2922+
synthesizer.fit(data)
2923+
synthetic_data = synthesizer.sample(scale=1.0)
28942924

2925+
# Assert
2926+
assert set(synthetic_data['guests']['guest_email']).issuperset(
2927+
set(synthetic_data['guests_pii']['guest_email'])
2928+
)
2929+
assert set(synthetic_data['guests_pii']['guest_email']).issuperset(
2930+
set(synthetic_data['rooms']['guest_email'])
2931+
)
2932+
synthesizer.validate(synthetic_data)
28952933

2896-
def test_hma_1_to_1_or_0_not_superset(data_metadata_1_to_1_or_0):
2897-
"""Test error is raised if PK to PK relationship but parent is not a superset."""
2898-
# Setup
2899-
data, metadata = data_metadata_1_to_1_or_0
2900-
metadata.remove_relationship(parent_table_name='users', child_table_name='survey_response')
2901-
metadata.add_relationship(
2902-
parent_table_name='survey_response',
2903-
parent_primary_key='user_id',
2904-
child_table_name='users',
2905-
child_foreign_key='user_id',
2906-
)
2907-
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2908-
match_ = re.escape("Error: foreign key column 'user_id' contains unknown references: (9).")
2934+
def test_1_to_1_to_1_diamond(self, data_metadata_1_to_1_subset_diamond):
2935+
"""Test PK to PK to PK in a diamond relationship."""
2936+
# Setup
2937+
data, metadata = data_metadata_1_to_1_subset_diamond
29092938

2910-
# Run and Assert
2911-
with pytest.raises(InvalidDataError, match=match_):
2939+
# Run
2940+
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
29122941
synthesizer.fit(data)
2942+
synthetic_data = synthesizer.sample(scale=1.0)
29132943

2944+
# Assert
2945+
assert set(synthetic_data['guests']['guest_email']).issuperset(
2946+
set(synthetic_data['guests_pii']['guest_email'])
2947+
)
2948+
assert set(synthetic_data['guests']['guest_email']).issuperset(
2949+
set(synthetic_data['rooms']['guest_email'])
2950+
)
2951+
synthesizer.validate(synthetic_data)
29142952

2915-
def test_1_to_1_to_1_subset_to_subset(data_metadata_1_to_1_to_1_subset_to_subset):
2916-
"""Test PK to PK to PK, with the 2nd and 3rd table having a subset."""
2917-
# Setup
2918-
data, metadata = data_metadata_1_to_1_to_1_subset_to_subset
2919-
2920-
# Run
2921-
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2922-
synthesizer.fit(data)
2923-
synthetic_data = synthesizer.sample(scale=1.0)
2953+
def test_multiple_fks(self, data_metadata_multiple_foreign_keys):
2954+
"""Test support for parent and child with multiple foreign keys."""
2955+
# Setup
2956+
data, metadata = data_metadata_multiple_foreign_keys
29242957

2925-
# Assert
2926-
assert set(synthetic_data['guests']['guest_email']).issuperset(
2927-
set(synthetic_data['guests_pii']['guest_email'])
2928-
)
2929-
assert set(synthetic_data['guests_pii']['guest_email']).issuperset(
2930-
set(synthetic_data['rooms']['guest_email'])
2931-
)
2932-
synthesizer.validate(synthetic_data)
2958+
# Run
2959+
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2960+
synthesizer.fit(data)
2961+
synthetic_data = synthesizer.sample(scale=1.0)
29332962

2963+
# Assert
2964+
for each_parent_id in synthetic_data['child']['parent_1_id'].tolist():
2965+
assert each_parent_id in set(synthetic_data['parent']['parent_id'])
2966+
for each_parent_id in synthetic_data['child']['parent_2_id'].tolist():
2967+
assert each_parent_id in set(synthetic_data['second_parent']['parent_id'])
2968+
synthesizer.validate(synthetic_data)
29342969

2935-
def test_1_to_1_to_1_diamond(data_metadata_1_to_1_subset_diamond):
2936-
"""Test PK to PK to PK in a diamond relationship."""
2937-
# Setup
2938-
data, metadata = data_metadata_1_to_1_subset_diamond
2970+
def test_multiple_fks_mismatched(self, data_metadata_multiple_foreign_keys_mismatched):
2971+
"""Test support for parent and child with multiple foreign keys."""
2972+
# Setup
2973+
data, metadata = data_metadata_multiple_foreign_keys_mismatched
29392974

2940-
# Run
2941-
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2942-
synthesizer.fit(data)
2943-
synthetic_data = synthesizer.sample(scale=1.0)
2975+
# Run
2976+
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2977+
synthesizer.fit(data)
2978+
synthetic_data = synthesizer.sample(scale=1.0)
29442979

2945-
# Assert
2946-
assert set(synthetic_data['guests']['guest_email']).issuperset(
2947-
set(synthetic_data['guests_pii']['guest_email'])
2948-
)
2949-
assert set(synthetic_data['guests']['guest_email']).issuperset(
2950-
set(synthetic_data['rooms']['guest_email'])
2951-
)
2952-
synthesizer.validate(synthetic_data)
2980+
# Assert
2981+
for each_parent_id in synthetic_data['child']['parent_1_id'].tolist():
2982+
assert each_parent_id in set(synthetic_data['parent']['parent_id'])
2983+
for each_parent_id in synthetic_data['child']['parent_2_id'].tolist():
2984+
assert each_parent_id in set(synthetic_data['second_parent']['parent_id'])
2985+
synthesizer.validate(synthetic_data)

0 commit comments

Comments
 (0)