111111SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS
112112
113113
114- def _validate_inputs (output_filepath , detailed_results_folder , synthesizers , custom_synthesizers ):
114+ def _validate_output_filepath_and_detailed_results_folder (output_filepath , detailed_results_folder ):
115115 if output_filepath and os .path .exists (output_filepath ):
116116 raise ValueError (
117117 f'{ output_filepath } already exists. Please provide a file that does not already exist.'
@@ -123,15 +123,65 @@ def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, cus
123123 'Please provide a folder that does not already exist.'
124124 )
125125
126- duplicates = get_duplicates (synthesizers ) if synthesizers else set ()
127- if custom_synthesizers :
128- duplicates .update (get_duplicates (custom_synthesizers ))
129- if len (duplicates ) > 0 :
126+
127+ def _import_and_validate_synthesizers (synthesizers , custom_synthesizers , modality ):
128+ """Import user-provided synthesizer and validate modality and uniqueness.
129+
130+ This function takes lists of synthesizer, imports them as synthesizer classes,
131+ and validates two conditions:
132+ - Modality match – all synthesizers must match the expected `modality`.
133+ A `ValueError` is raised if any synthesizer has a different modality
134+ flag.
135+
136+ - Uniqueness – duplicate synthesizer across the two input lists
137+ (`synthesizers` and `custom_synthesizers`) are not allowed. A
138+ `ValueError` is raised if duplicates are found.
139+
140+ Args:
141+ synthesizers (list | None):
142+ A list of synthesizer strings or classes. May be ``None``, in which case it
143+ is treated as an empty list.
144+ custom_synthesizers (list | None):
145+ A list of custom synthesizer.
146+ modality (str):
147+ The required modality that all synthesizers must match.
148+
149+ Returns:
150+ list:
151+ A list of synthesizer classes.
152+
153+ Raises:
154+ ValueError:
155+ If any synthesizer does not match the expected modality.
156+ ValueError:
157+ If duplicate synthesizer are found across the provided lists.
158+ """
159+ # Get list of synthesizer objects
160+ synthesizers = synthesizers or []
161+ custom_synthesizers = custom_synthesizers or []
162+ resolved_synthesizers = get_synthesizers (synthesizers + custom_synthesizers )
163+ mismatched = [
164+ synth ['synthesizer' ]
165+ for synth in resolved_synthesizers
166+ if synth ['synthesizer' ]._MODALITY_FLAG != modality
167+ ]
168+ if mismatched :
169+ raise ValueError (
170+ f"Synthesizers must be of modality '{ modality } '. "
171+ "Found this synthesizers that don't match: "
172+ f'{ ", " .join ([type (synth ).__name__ for synth in mismatched ])} '
173+ )
174+
175+ # Check duplicate input values
176+ duplicates = get_duplicates (synthesizers + custom_synthesizers )
177+ if duplicates :
130178 raise ValueError (
131179 'Synthesizers must be unique. Please remove repeated values in the `synthesizers` '
132180 'and `custom_synthesizers` parameters.'
133181 )
134182
183+ return resolved_synthesizers
184+
135185
136186def _create_detailed_results_directory (detailed_results_folder ):
137187 if detailed_results_folder and not is_s3_path (detailed_results_folder ):
@@ -276,15 +326,9 @@ def _generate_job_args_list(
276326 compute_diagnostic_score ,
277327 compute_privacy_score ,
278328 synthesizers ,
279- custom_synthesizers ,
280329 s3_client ,
281330 modality ,
282331):
283- # Get list of synthesizer objects
284- synthesizers = [] if synthesizers is None else synthesizers
285- custom_synthesizers = [] if custom_synthesizers is None else custom_synthesizers
286- synthesizers = get_synthesizers (synthesizers + custom_synthesizers )
287-
288332 # Get list of dataset paths
289333 aws_access_key_id = os .getenv ('AWS_ACCESS_KEY_ID' )
290334 aws_secret_access_key_key = os .getenv ('AWS_SECRET_ACCESS_KEY' )
@@ -427,7 +471,7 @@ def _compute_scores(
427471 sdmetrics_metadata = metadata
428472
429473 if len (metrics ) > 0 :
430- metrics , metric_kwargs = get_metrics (metrics , modality = modality . replace ( '_' , '-' ) )
474+ metrics , metric_kwargs = get_metrics (metrics , modality = modality )
431475 scores = []
432476 output ['scores' ] = scores
433477 for metric_name , metric in metrics .items ():
@@ -1130,8 +1174,8 @@ def _write_metainfo_file(synthesizers, job_args_list, modality, result_writer=No
11301174 }
11311175
11321176 for synthesizer in synthesizers :
1133- if synthesizer not in SDV_SYNTHESIZERS :
1134- ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY .get (synthesizer )
1177+ if synthesizer [ 'name' ] not in SDV_SYNTHESIZERS :
1178+ ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY .get (synthesizer [ 'name' ] )
11351179 if ext_lib :
11361180 library_version = version (ext_lib )
11371181 metadata [f'{ ext_lib } _version' ] = library_version
@@ -1150,20 +1194,17 @@ def _update_metainfo_file(run_file, result_writer=None):
11501194 result_writer .write_yaml (update , run_file , append = True )
11511195
11521196
1153- def _ensure_uniform_included (synthesizers ):
1154- if UniformSynthesizer not in synthesizers and UniformSynthesizer .__name__ not in synthesizers :
1155- LOGGER .info ('Adding UniformSynthesizer to list of synthesizers.' )
1156- synthesizers .append ('UniformSynthesizer' )
1157-
1197+ def _ensure_uniform_included (synthesizers , modality ):
1198+ uniform_class = UniformSynthesizer
1199+ if modality == 'multi_table' :
1200+ uniform_class = MultiTableUniformSynthesizer
11581201
1159- def _ensure_multi_table_uniform_is_included (synthesizers ):
11601202 uniform_not_included = bool (
1161- MultiTableUniformSynthesizer not in synthesizers
1162- and MultiTableUniformSynthesizer .__name__ not in synthesizers
1203+ uniform_class not in synthesizers and uniform_class .__name__ not in synthesizers
11631204 )
11641205 if uniform_not_included :
1165- LOGGER .info ('Adding MultiTableUniformSynthesizer to the list of synthesizers.' )
1166- synthesizers .append ('MultiTableUniformSynthesizer' )
1206+ LOGGER .info (f 'Adding { uniform_class . __name__ } to the list of synthesizers.' )
1207+ synthesizers .append (uniform_class . __name__ )
11671208
11681209
11691210def _fill_adjusted_scores_with_none (scores ):
@@ -1331,7 +1372,7 @@ def benchmark_single_table(
13311372 if not synthesizers :
13321373 synthesizers = []
13331374
1334- _ensure_uniform_included (synthesizers )
1375+ _ensure_uniform_included (synthesizers , 'single_table' )
13351376 result_writer = LocalResultsWriter ()
13361377 if run_on_ec2 :
13371378 print ("This will create an instance for the current AWS user's account." ) # noqa
@@ -1343,21 +1384,25 @@ def benchmark_single_table(
13431384
13441385 return None
13451386
1346- _validate_inputs (output_filepath , detailed_results_folder , synthesizers , custom_synthesizers )
1347- _create_detailed_results_directory (detailed_results_folder )
1348- job_args_list = _generate_job_args_list (
1349- limit_dataset_size ,
1350- sdv_datasets ,
1351- additional_datasets_folder ,
1352- sdmetrics ,
1353- detailed_results_folder ,
1354- timeout ,
1355- output_destination ,
1356- compute_quality_score ,
1357- compute_diagnostic_score ,
1358- compute_privacy_score ,
1387+ _validate_output_filepath_and_detailed_results_folder (output_filepath , detailed_results_folder )
1388+ synthesizers = _import_and_validate_synthesizers (
13591389 synthesizers ,
13601390 custom_synthesizers ,
1391+ 'single_table' ,
1392+ )
1393+ _create_detailed_results_directory (detailed_results_folder )
1394+ job_args_list = _generate_job_args_list (
1395+ limit_dataset_size = limit_dataset_size ,
1396+ sdv_datasets = sdv_datasets ,
1397+ additional_datasets_folder = additional_datasets_folder ,
1398+ sdmetrics = sdmetrics ,
1399+ detailed_results_folder = detailed_results_folder ,
1400+ timeout = timeout ,
1401+ output_destination = output_destination ,
1402+ compute_quality_score = compute_quality_score ,
1403+ compute_diagnostic_score = compute_diagnostic_score ,
1404+ compute_privacy_score = compute_privacy_score ,
1405+ synthesizers = synthesizers ,
13611406 s3_client = None ,
13621407 modality = 'single_table' ,
13631408 )
@@ -1650,7 +1695,13 @@ def benchmark_single_table_aws(
16501695 if not synthesizers :
16511696 synthesizers = []
16521697
1653- _ensure_uniform_included (synthesizers )
1698+ _ensure_uniform_included (synthesizers , 'single_table' )
1699+ synthesizers = _import_and_validate_synthesizers (
1700+ synthesizers = synthesizers ,
1701+ custom_synthesizers = None ,
1702+ modality = 'single_table' ,
1703+ )
1704+
16541705 job_args_list = _generate_job_args_list (
16551706 limit_dataset_size = limit_dataset_size ,
16561707 sdv_datasets = sdv_datasets ,
@@ -1663,7 +1714,6 @@ def benchmark_single_table_aws(
16631714 compute_privacy_score = compute_privacy_score ,
16641715 synthesizers = synthesizers ,
16651716 detailed_results_folder = None ,
1666- custom_synthesizers = None ,
16671717 s3_client = s3_client ,
16681718 modality = 'single_table' ,
16691719 )
@@ -1745,20 +1795,19 @@ def benchmark_multi_table(
17451795
17461796 Returns:
17471797 pandas.DataFrame:
1748- A table containing one row per synthesizer + dataset + metric .
1798+ A table containing one row per synthesizer + dataset.
17491799 """
17501800 _validate_output_destination (output_destination )
17511801 if not synthesizers :
17521802 synthesizers = []
17531803
1754- _ensure_multi_table_uniform_is_included (synthesizers )
1804+ _ensure_uniform_included (synthesizers , 'multi_table' )
17551805 result_writer = LocalResultsWriter ()
17561806
1757- _validate_inputs (
1758- output_filepath = None ,
1759- detailed_results_folder = None ,
1760- synthesizers = synthesizers ,
1761- custom_synthesizers = custom_synthesizers ,
1807+ synthesizers = _import_and_validate_synthesizers (
1808+ synthesizers ,
1809+ custom_synthesizers ,
1810+ 'multi_table' ,
17621811 )
17631812 job_args_list = _generate_job_args_list (
17641813 limit_dataset_size = limit_dataset_size ,
@@ -1772,7 +1821,6 @@ def benchmark_multi_table(
17721821 compute_diagnostic_score = compute_diagnostic_score ,
17731822 compute_privacy_score = None ,
17741823 synthesizers = synthesizers ,
1775- custom_synthesizers = custom_synthesizers ,
17761824 s3_client = None ,
17771825 modality = 'multi_table' ,
17781826 )
0 commit comments