Skip to content

Commit c51dc98

Browse files
committed
Add integration test for benchmark multi table end to end
1 parent d986dfd commit c51dc98

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

tests/integration/test_benchmark.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,3 +878,49 @@ def test_benchmark_multi_table_with_output_destination_multiple_runs(tmp_path):
878878
# Validate the stored results match returned results
879879
pd.testing.assert_frame_equal(result_1, saved_result_1, check_dtype=False)
880880
pd.testing.assert_frame_equal(result_2, saved_result_2, check_dtype=False)
881+
882+
883+
def test_benchmark_multi_table_basic_synthesizers():
884+
"""Integration test: run HMASynthesizer + MultiTableUniformSynthesizer on fake_hotels."""
885+
output = benchmark_multi_table(
886+
synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'],
887+
sdv_datasets=['fake_hotels'],
888+
compute_quality_score=True,
889+
compute_diagnostic_score=True,
890+
limit_dataset_size=True,
891+
show_progress=False,
892+
timeout=30,
893+
)
894+
895+
# Assert
896+
assert isinstance(output, pd.DataFrame)
897+
assert not output.empty
898+
899+
# Required SDGym benchmark output columns
900+
for col in [
901+
'Synthesizer',
902+
'Train_Time',
903+
'Sample_Time',
904+
'Quality_Score',
905+
'Diagnostic_Score',
906+
]:
907+
assert col in output.columns
908+
909+
synths = sorted(output['Synthesizer'].unique())
910+
assert synths == [
911+
'HMASynthesizer',
912+
]
913+
914+
diagnostic_rank = (
915+
output.groupby('Synthesizer').Diagnostic_Score.mean().sort_values().index.tolist()
916+
)
917+
918+
assert diagnostic_rank == [
919+
'HMASynthesizer',
920+
]
921+
922+
quality_rank = output.groupby('Synthesizer').Quality_Score.mean().sort_values().index.tolist()
923+
924+
assert quality_rank == [
925+
'HMASynthesizer',
926+
]

0 commit comments

Comments
 (0)