@@ -878,3 +878,49 @@ def test_benchmark_multi_table_with_output_destination_multiple_runs(tmp_path):
878878 # Validate the stored results match returned results
879879 pd .testing .assert_frame_equal (result_1 , saved_result_1 , check_dtype = False )
880880 pd .testing .assert_frame_equal (result_2 , saved_result_2 , check_dtype = False )
881+
882+
883+ def test_benchmark_multi_table_basic_synthesizers ():
884+ """Integration test: run HMASynthesizer + MultiTableUniformSynthesizer on fake_hotels."""
885+ output = benchmark_multi_table (
886+ synthesizers = ['HMASynthesizer' , 'MultiTableUniformSynthesizer' ],
887+ sdv_datasets = ['fake_hotels' ],
888+ compute_quality_score = True ,
889+ compute_diagnostic_score = True ,
890+ limit_dataset_size = True ,
891+ show_progress = False ,
892+ timeout = 30 ,
893+ )
894+
895+ # Assert
896+ assert isinstance (output , pd .DataFrame )
897+ assert not output .empty
898+
899+ # Required SDGym benchmark output columns
900+ for col in [
901+ 'Synthesizer' ,
902+ 'Train_Time' ,
903+ 'Sample_Time' ,
904+ 'Quality_Score' ,
905+ 'Diagnostic_Score' ,
906+ ]:
907+ assert col in output .columns
908+
909+ synths = sorted (output ['Synthesizer' ].unique ())
910+ assert synths == [
911+ 'HMASynthesizer' ,
912+ ]
913+
914+ diagnostic_rank = (
915+ output .groupby ('Synthesizer' ).Diagnostic_Score .mean ().sort_values ().index .tolist ()
916+ )
917+
918+ assert diagnostic_rank == [
919+ 'HMASynthesizer' ,
920+ ]
921+
922+ quality_rank = output .groupby ('Synthesizer' ).Quality_Score .mean ().sort_values ().index .tolist ()
923+
924+ assert quality_rank == [
925+ 'HMASynthesizer' ,
926+ ]
0 commit comments