Skip to content

Commit af14a5d

Browse files
committed
add additional test
1 parent 684eace commit af14a5d

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

sdv/multi_table/hma.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ def _augment_tables(self, processed_data):
499499
parent_map = self.metadata._get_parent_map()
500500
self._print(text='Learning relationships:')
501501

502+
# Reset index for tables where foreign key is also a primary key
502503
for table_name in processed_data:
503504
foreign_keys = self.metadata._get_all_foreign_keys(table_name)
504505
primary_key = self.metadata.tables[table_name].primary_key

tests/integration/multi_table/test_hma.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2903,7 +2903,6 @@ def data_metadata_1_to_1():
29032903
metadata = Metadata.load_from_dict(metadata_dict)
29042904
metadata.validate()
29052905
metadata.validate_data(data)
2906-
# metadata.remove_primary_key('rooms')
29072906
return data, metadata
29082907

29092908

@@ -2920,5 +2919,60 @@ def test_hma_1_to_1(data_metadata_1_to_1):
29202919
assert synthetic_data['guests']['guest_email'].equals(synthetic_data['guests']['guest_email'])
29212920

29222921

2923-
def test_hma_1_to_1_or_0(data_metadata_1_to_1):
2924-
pass
2922+
def test_hma_1_to_1_or_0():
2923+
# Setup
2924+
data = {
2925+
'users': pd.DataFrame({
2926+
'user_id': range(10),
2927+
'date_joined': [
2928+
'2024-01-01',
2929+
'2024-02-01',
2930+
'2024-03-01',
2931+
'2024-04-01',
2932+
'2024-05-01',
2933+
]
2934+
* 2,
2935+
}),
2936+
'survey_response': pd.DataFrame({
2937+
'user_id': range(9),
2938+
'age': [11, 22, 33, 44, 55, 66, 77, 88, 99],
2939+
}),
2940+
}
2941+
metadata = Metadata.load_from_dict({
2942+
'tables': {
2943+
'users': {
2944+
'columns': {
2945+
'user_id': {'sdtype': 'id'},
2946+
'date_joined': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
2947+
},
2948+
'primary_key': 'user_id',
2949+
},
2950+
'survey_response': {
2951+
'columns': {
2952+
'user_id': {'sdtype': 'id'},
2953+
'age': {'sdtype': 'numerical'},
2954+
},
2955+
'primary_key': 'user_id',
2956+
},
2957+
},
2958+
'relationships': [
2959+
{
2960+
'parent_table_name': 'users',
2961+
'parent_primary_key': 'user_id',
2962+
'child_table_name': 'survey_response',
2963+
'child_foreign_key': 'user_id',
2964+
}
2965+
],
2966+
})
2967+
metadata.validate()
2968+
metadata.validate_data(data)
2969+
2970+
# Run
2971+
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
2972+
synthesizer.fit(data)
2973+
synthetic_data = synthesizer.sample(scale=1)
2974+
2975+
# Assert
2976+
assert set(synthetic_data['users']['user_id']).issuperset(
2977+
set(synthetic_data['survey_response']['user_id'])
2978+
)

0 commit comments

Comments
 (0)