diff --git a/Makefile b/Makefile
index 30a13ece..dce4a909 100644
--- a/Makefile
+++ b/Makefile
@@ -93,7 +93,11 @@ fix-lint:
 # TEST TARGETS
 .PHONY: test-unit
 test-unit: ## run tests quickly with the default Python
-	python -m pytest --cov=sdgym
+	invoke unit
+
+.PHONY: test-integration
+test-integration: ## run the integration tests with the default Python
+	invoke integration
 
 .PHONY: test-readme
 test-readme: ## run the readme snippets
@@ -102,7 +106,7 @@ test-readme: ## run the readme snippets
 	rm -rf tests/readme_test
 
 .PHONY: test
-test: test-unit test-readme ## test everything that needs test dependencies
+test: test-unit test-integration ## test everything that needs test dependencies
 
 .PHONY: test-devel
 test-devel: lint ## test everything that needs development dependencies
diff --git a/sdgym/__init__.py b/sdgym/__init__.py
index 38698105..4104c9f8 100644
--- a/sdgym/__init__.py
+++ b/sdgym/__init__.py
@@ -12,7 +12,11 @@
 
 import logging
 
-from sdgym.benchmark import benchmark_single_table, benchmark_single_table_aws
+from sdgym.benchmark import (
+    benchmark_multi_table,
+    benchmark_single_table,
+    benchmark_single_table_aws,
+)
 from sdgym.cli.collect import collect_results
 from sdgym.cli.summary import make_summary_spreadsheet
 from sdgym.dataset_explorer import DatasetExplorer
@@ -31,12 +35,13 @@
 __all__ = [
     'DatasetExplorer',
     'ResultsExplorer',
+    'benchmark_multi_table',
    'benchmark_single_table',
     'benchmark_single_table_aws',
     'collect_results',
-    'create_synthesizer_variant',
-    'create_single_table_synthesizer',
     'create_multi_table_synthesizer',
+    'create_single_table_synthesizer',
+    'create_synthesizer_variant',
     'load_dataset',
     'make_summary_spreadsheet',
 ]
diff --git a/sdgym/_dataset_utils.py b/sdgym/_dataset_utils.py
index 3e52998e..c9de2e21 100644
--- a/sdgym/_dataset_utils.py
+++ b/sdgym/_dataset_utils.py
@@ -7,9 +7,14 @@
 
 import numpy as np
 import pandas as pd
+from sdv.metadata import Metadata
+from sdv.utils import poc
 
 LOGGER = logging.getLogger(__name__)
 
+MAX_NUM_COLUMNS = 10
+MAX_NUM_ROWS = 1000
+
 
 def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
     """Generic parser for numeric values with logging and NaN fallback."""
@@ -23,6 +28,65 @@ def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
     return np.nan
 
 
+def _filter_columns(columns, mandatory_columns):
+    """Given a dictionary of columns and a list of mandatory ones, return a filtered subset."""
+    mandatory_columns = [m_col for m_col in mandatory_columns if m_col in columns]
+    optional_columns = [col for col in columns if col not in mandatory_columns]
+    keep_columns = mandatory_columns + optional_columns[:MAX_NUM_COLUMNS]
+    return {col: columns[col] for col in keep_columns if col in columns}
+
+
+def _get_multi_table_dataset_subset(data, metadata_dict):
+    """Create a smaller, referentially consistent subset of multi-table data.
+
+    This function keeps every mandatory column (table keys) in each table
+    plus at most 10 of the remaining columns, then trims the underlying
+    DataFrames to match the updated metadata. Finally, it uses SDV's
+    multi-table utility to sample up to 1,000 rows from the largest table
+    and a consistent subset of rows from all related tables while
+    preserving referential integrity.
+
+    Args:
+        data (dict):
+            A dictionary where keys are table names and values are DataFrames
+            representing tables.
+        metadata_dict (dict):
+            Metadata dictionary containing schema information for each table.
+
+    Returns:
+        tuple:
+            A tuple containing:
+                - dict: The subset of the input data with reduced columns and rows.
+                - dict: The updated metadata dictionary reflecting the reduced column sets.
+    """
+    metadata = Metadata.load_from_dict(metadata_dict)
+    for table_name, table in metadata.tables.items():
+        table_columns = table.columns
+        mandatory_columns = list(metadata._get_all_keys(table_name))
+        subset_column_schema = _filter_columns(
+            columns=table_columns, mandatory_columns=mandatory_columns
+        )
+        metadata_dict['tables'][table_name]['columns'] = subset_column_schema
+
+    # Re-load the metadata object that will be used with the `SDV` utility function
+    metadata = Metadata.load_from_dict(metadata_dict)
+    largest_table_name = max(data, key=lambda table_name: len(data[table_name]))
+
+    # Trim the data to contain only the subset of columns
+    for table_name, table in metadata.tables.items():
+        data[table_name] = data[table_name][list(table.columns)]
+
+    # Subsample the data maintaining the referential integrity
+    data = poc.get_random_subset(
+        data=data,
+        metadata=metadata,
+        main_table_name=largest_table_name,
+        num_rows=MAX_NUM_ROWS,
+        verbose=False,
+    )
+    return data, metadata_dict
+
+
 def _get_dataset_subset(data, metadata_dict, modality):
     """Limit the size of a dataset for faster evaluation or testing.
 
@@ -31,52 +95,37 @@ def _get_dataset_subset(data, metadata_dict, modality):
     columns—such as sequence indices and keys in sequential datasets—are
     always retained.
 
     Args:
-        data (pd.DataFrame):
+        data (pd.DataFrame or dict):
             The dataset to be reduced.
         metadata_dict (dict):
-            A dictionary containing the dataset's metadata.
+            A dictionary representing the dataset's metadata.
         modality (str):
-            The dataset modality. Must be one of: ``'single_table'``, ``'sequential'``.
+            The dataset modality. Must be one of: ``'single_table'``, ``'sequential'``
+            or ``'multi_table'``.
 
     Returns:
         tuple[pd.DataFrame, dict]:
             A tuple containing:
-                - The reduced dataset as a DataFrame.
+                - The reduced dataset as a DataFrame or dictionary of DataFrames.
                 - The updated metadata dictionary reflecting any removed columns.
-
-    Raises:
-        ValueError:
-            If the provided modality is ``'multi_table'``.
""" if modality == 'multi_table': - raise ValueError('limit_dataset_size is not supported for multi-table datasets.') + return _get_multi_table_dataset_subset(data, metadata_dict) - max_rows, max_columns = (1000, 10) tables = metadata_dict.get('tables', {}) mandatory_columns = [] table_name, table_info = next(iter(tables.items())) - columns = table_info.get('columns', {}) - keep_columns = list(columns) - if modality == 'sequential': - seq_index = table_info.get('sequence_index') - seq_key = table_info.get('sequence_key') - mandatory_columns = [col for col in (seq_index, seq_key) if col] - optional_columns = [col for col in columns if col not in mandatory_columns] + seq_index = table_info.get('sequence_index') + seq_key = table_info.get('sequence_key') + mandatory_columns = [column for column in (seq_index, seq_key) if column] + filtered = _filter_columns(columns=columns, mandatory_columns=mandatory_columns) - # If we have too many columns, drop extras but never mandatory ones - if len(columns) > max_columns: - keep_count = max_columns - len(mandatory_columns) - keep_columns = mandatory_columns + optional_columns[:keep_count] - table_info['columns'] = { - column_name: column_definition - for column_name, column_definition in columns.items() - if column_name in keep_columns - } - - data = data[list(keep_columns)] + table_info['columns'] = filtered + data = data[list(filtered)] + max_rows = min(MAX_NUM_ROWS, len(data)) data = data.sample(max_rows) + return data, metadata_dict diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index c07a684a..aa68b6d9 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -51,7 +51,7 @@ write_csv, write_file, ) -from sdgym.synthesizers import UniformSynthesizer +from sdgym.synthesizers import MultiTableUniformSynthesizer, UniformSynthesizer from sdgym.synthesizers.base import BaselineSynthesizer from sdgym.utils import ( calculate_score_time, @@ -66,8 +66,13 @@ ) LOGGER = logging.getLogger(__name__) -DEFAULT_SYNTHESIZERS = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer', 'UniformSynthesizer'] -DEFAULT_DATASETS = [ +DEFAULT_SINGLE_TABLE_SYNTHESIZERS = [ + 'GaussianCopulaSynthesizer', + 'CTGANSynthesizer', + 'UniformSynthesizer', +] +DEFAULT_MULTI_TABLE_SYNTHESIZERS = ['MultiTableUniformSynthesizer', 'HMASynthesizer'] +DEFAULT_SINGLE_TABLE_DATASETS = [ 'adult', 'alarm', 'census', @@ -78,6 +83,16 @@ 'intrusion', 'news', ] +DEFAULT_MULTI_TABLE_DATASETS = [ + 'NBA', + 'financial', + 'Student_loan', + 'Biodegradability', + 'fake_hotels', + 'restbase', + 'airbnb-simplified', +] + N_BYTES_IN_MB = 1000 * 1000 EXTERNAL_SYNTHESIZER_TO_LIBRARY = { 'RealTabFormerSynthesizer': 'realtabformer', @@ -91,9 +106,12 @@ 'CopulaGANSynthesizer', 'TVAESynthesizer', ] +SDV_MULTI_TABLE_SYNTHESIZERS = ['HMASynthesizer'] + +SDV_SYNTHESIZERS = SDV_SINGLE_TABLE_SYNTHESIZERS + SDV_MULTI_TABLE_SYNTHESIZERS -def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers): +def _validate_output_filepath_and_detailed_results_folder(output_filepath, detailed_results_folder): if output_filepath and os.path.exists(output_filepath): raise ValueError( f'{output_filepath} already exists. Please provide a file that does not already exist.' @@ -105,15 +123,65 @@ def _validate_inputs(output_filepath, detailed_results_folder, synthesizers, cus 'Please provide a folder that does not already exist.' 
         )
 
-    duplicates = get_duplicates(synthesizers) if synthesizers else {}
-    if custom_synthesizers:
-        duplicates.update(get_duplicates(custom_synthesizers))
-    if len(duplicates) > 0:
+
+def _import_and_validate_synthesizers(synthesizers, custom_synthesizers, modality):
+    """Import user-provided synthesizers and validate modality and uniqueness.
+
+    This function takes lists of synthesizers, imports them as synthesizer classes,
+    and validates two conditions:
+    - Modality match – all synthesizers must match the expected `modality`.
+      A `ValueError` is raised if any synthesizer has a different modality
+      flag.
+
+    - Uniqueness – duplicate synthesizers across the two input lists
+      (`synthesizers` and `custom_synthesizers`) are not allowed. A
+      `ValueError` is raised if duplicates are found.
+
+    Args:
+        synthesizers (list | None):
+            A list of synthesizer strings or classes. May be ``None``, in which case it
+            is treated as an empty list.
+        custom_synthesizers (list | None):
+            A list of custom synthesizers. May be ``None``.
+        modality (str):
+            The required modality that all synthesizers must match.
+
+    Returns:
+        list:
+            A list of synthesizer classes.
+
+    Raises:
+        ValueError:
+            If any synthesizer does not match the expected modality.
+        ValueError:
+            If duplicate synthesizers are found across the provided lists.
+    """
+    # Get list of synthesizer objects
+    synthesizers = synthesizers or []
+    custom_synthesizers = custom_synthesizers or []
+    resolved_synthesizers = get_synthesizers(synthesizers + custom_synthesizers)
+    mismatched = [
+        synth['synthesizer']
+        for synth in resolved_synthesizers
+        if synth['synthesizer']._MODALITY_FLAG != modality
+    ]
+    if mismatched:
+        raise ValueError(
+            f"Synthesizers must be of modality '{modality}'. "
+            "Found these synthesizers that don't match: "
+            f'{", ".join([type(synth).__name__ for synth in mismatched])}'
+        )
+
+    # Check duplicate input values
+    duplicates = get_duplicates(synthesizers + custom_synthesizers)
+    if duplicates:
         raise ValueError(
-            'Synthesizers must be unique. Please remove repeated values in the `synthesizers` '
-            'and `custom_synthesizers` parameters.'
+            'Synthesizers must be unique. Please remove repeated values in the provided '
+            'synthesizers.'
         )
+    return resolved_synthesizers
+
 
 def _create_detailed_results_directory(detailed_results_folder):
     if detailed_results_folder and not is_s3_path(detailed_results_folder):
@@ -135,6 +203,7 @@ def _get_metainfo_increment(top_folder, s3_client=None):
         if match:
             # Extract numeric suffix (e.g. metainfo(3).yaml → 3) or 0 if plain metainfo.yaml
             increments.append(int(match.group(1)) if match.group(1) else 0)
+
     except Exception:
         LOGGER.info(first_file_message)
         return 0  # start with (0) if error
@@ -187,7 +256,13 @@ def _setup_output_destination_aws(output_destination, synthesizers, datasets, s3
     return paths
 
 
-def _setup_output_destination(output_destination, synthesizers, datasets, s3_client=None):
+def _setup_output_destination(
+    output_destination,
+    synthesizers,
+    datasets,
+    modality,
+    s3_client=None,
+):
     """Set up the output destination for the benchmark results.
 
     Args:
@@ -198,6 +273,10 @@ def _setup_output_destination(output_destination, synthesizers, datasets, s3_cli
             The list of synthesizers to benchmark.
         datasets (list):
             The list of datasets to benchmark.
+        modality (str):
+            The dataset modality to load (e.g., 'single_table' or 'multi_table').
+        s3_client (boto3.session.Session.client or None):
+            The s3 client that can be used to read / write to s3. Defaults to ``None``.
""" if s3_client: return _setup_output_destination_aws(output_destination, synthesizers, datasets, s3_client) @@ -208,11 +287,12 @@ def _setup_output_destination(output_destination, synthesizers, datasets, s3_cli output_path = Path(output_destination) output_path.mkdir(parents=True, exist_ok=True) today = datetime.today().strftime('%m_%d_%Y') - top_folder = output_path / f'SDGym_results_{today}' + top_folder = output_path / modality / f'SDGym_results_{today}' top_folder.mkdir(parents=True, exist_ok=True) increment = _get_metainfo_increment(top_folder) suffix = f'({increment})' if increment >= 1 else '' paths = defaultdict(dict) + synthetic_data_extension = 'zip' if modality == 'multi_table' else 'csv' for dataset in datasets: dataset_folder = top_folder / f'{dataset}_{today}' dataset_folder.mkdir(parents=True, exist_ok=True) @@ -223,7 +303,9 @@ def _setup_output_destination(output_destination, synthesizers, datasets, s3_cli synth_folder.mkdir(parents=True, exist_ok=True) paths[dataset][final_synth_name] = { 'synthesizer': str(synth_folder / f'{final_synth_name}.pkl'), - 'synthetic_data': str(synth_folder / f'{final_synth_name}_synthetic_data.csv'), + 'synthetic_data': str( + synth_folder / f'{final_synth_name}_synthetic_data.{synthetic_data_extension}' + ), 'benchmark_result': str(synth_folder / f'{final_synth_name}_benchmark_result.csv'), 'metainfo': str(top_folder / f'metainfo{suffix}.yaml'), 'results': str(top_folder / f'results{suffix}.csv'), @@ -244,14 +326,9 @@ def _generate_job_args_list( compute_diagnostic_score, compute_privacy_score, synthesizers, - custom_synthesizers, s3_client, + modality, ): - # Get list of synthesizer objects - synthesizers = [] if synthesizers is None else synthesizers - custom_synthesizers = [] if custom_synthesizers is None else custom_synthesizers - synthesizers = get_synthesizers(synthesizers + custom_synthesizers) - # Get list of dataset paths aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID') aws_secret_access_key_key = os.getenv('AWS_SECRET_ACCESS_KEY') @@ -259,7 +336,7 @@ def _generate_job_args_list( [] if sdv_datasets is None else get_dataset_paths( - modality='single_table', + modality=modality, datasets=sdv_datasets, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key_key, @@ -269,11 +346,11 @@ def _generate_job_args_list( [] if additional_datasets_folder is None else get_dataset_paths( - modality='single_table', + modality=modality, bucket=( additional_datasets_folder if is_s3_path(additional_datasets_folder) - else os.path.join(additional_datasets_folder, 'single_table') + else os.path.join(additional_datasets_folder, modality) ), aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key_key, @@ -283,7 +360,7 @@ def _generate_job_args_list( synthesizer_names = [synthesizer['name'] for synthesizer in synthesizers] dataset_names = [dataset.name for dataset in datasets] paths = _setup_output_destination( - output_destination, synthesizer_names, dataset_names, s3_client=s3_client + output_destination, synthesizer_names, dataset_names, modality=modality, s3_client=s3_client ) job_tuples = [] for dataset in datasets: @@ -301,9 +378,7 @@ def _generate_job_args_list( job_args_list = [] for synthesizer, dataset in job_tuples: - data, metadata_dict = load_dataset( - 'single_table', dataset, limit_dataset_size=limit_dataset_size - ) + data, metadata_dict = load_dataset(modality, dataset, limit_dataset_size=limit_dataset_size) path = paths.get(dataset.name, {}).get(synthesizer['name'], None) args = ( 
synthesizer, @@ -316,7 +391,7 @@ def _generate_job_args_list( compute_diagnostic_score, compute_privacy_score, dataset.name, - 'single_table', + modality, path, ) job_args_list.append(args) @@ -324,7 +399,14 @@ def _generate_job_args_list( return job_args_list -def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, result_writer=None): +def _synthesize( + synthesizer_dict, + real_data, + metadata, + synthesizer_path=None, + result_writer=None, + modality=None, +): synthesizer = synthesizer_dict['synthesizer'] if isinstance(synthesizer, type): assert issubclass(synthesizer, BaselineSynthesizer), ( @@ -339,7 +421,6 @@ def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, re get_synthesizer = synthesizer.get_trained_synthesizer sample_from_synthesizer = synthesizer.sample_from_synthesizer data = real_data.copy() - num_samples = len(data) tracemalloc.start() fitted_synthesizer = None @@ -353,17 +434,28 @@ def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, re synthesizer_size = len(cloudpickle.dumps(fitted_synthesizer)) / N_BYTES_IN_MB train_end = get_utc_now() train_time = train_end - start - synthetic_data = sample_from_synthesizer(fitted_synthesizer, num_samples) + + if modality == 'multi_table': + synthetic_data = sample_from_synthesizer(fitted_synthesizer, 1.0) + else: + synthetic_data = sample_from_synthesizer(fitted_synthesizer, n_samples=len(data)) + sample_end = get_utc_now() sample_time = sample_end - train_end peak_memory = tracemalloc.get_traced_memory()[1] / N_BYTES_IN_MB if synthesizer_path is not None and result_writer is not None: - result_writer.write_dataframe(synthetic_data, synthesizer_path['synthetic_data']) internal_synthesizer = getattr( fitted_synthesizer, '_internal_synthesizer', fitted_synthesizer ) result_writer.write_pickle(internal_synthesizer, synthesizer_path['synthesizer']) + if modality == 'multi_table': + result_writer.write_zipped_dataframes( + synthetic_data, synthesizer_path['synthetic_data'] + ) + + else: + result_writer.write_dataframe(synthetic_data, synthesizer_path['synthetic_data']) return synthetic_data, train_time, sample_time, synthesizer_size, peak_memory @@ -406,9 +498,13 @@ def _compute_scores( dataset_name, ): metrics = metrics or [] - sdmetrics_metadata = convert_metadata_to_sdmetrics(metadata) + if modality == 'single_table': + sdmetrics_metadata = convert_metadata_to_sdmetrics(metadata) + else: + sdmetrics_metadata = metadata + if len(metrics) > 0: - metrics, metric_kwargs = get_metrics(metrics, modality='single-table') + metrics, metric_kwargs = get_metrics(metrics, modality=modality) scores = [] output['scores'] = scores for metric_name, metric in metrics.items(): @@ -519,11 +615,12 @@ def _score( try: synthetic_data, train_time, sample_time, synthesizer_size, peak_memory = _synthesize( - synthesizer, - data.copy(), - metadata, + synthesizer_dict=synthesizer, + real_data=data.copy(), + metadata=metadata, synthesizer_path=synthesizer_path, result_writer=result_writer, + modality=modality, ) output['synthetic_data'] = synthetic_data @@ -849,6 +946,7 @@ def _run_jobs(multi_processing_config, job_args_list, show_progress, result_writ job_args_list = [job_args + (result_writer,) for job_args in job_args_list] if workers in (0, 1): scores = map(_run_job, job_args_list) + elif workers != 'dask': pool = concurrent.futures.ProcessPoolExecutor(workers) scores = pool.map(_run_job, job_args_list) @@ -1101,7 +1199,7 @@ def _validate_output_destination(output_destination, 
aws_keys=None): ) -def _write_metainfo_file(synthesizers, job_args_list, result_writer=None): +def _write_metainfo_file(synthesizers, job_args_list, modality, result_writer=None): jobs = [[job[-3], job[0]['name']] for job in job_args_list] if not job_args_list or not job_args_list[0][-1]: return @@ -1118,17 +1216,20 @@ def _write_metainfo_file(synthesizers, job_args_list, result_writer=None): date_str = date_match.group(1) metadata = { 'run_id': f'run_{date_str}_{increment}', + 'modality': modality, 'starting_date': datetime.today().strftime('%m_%d_%Y %H:%M:%S'), 'completed_date': None, 'sdgym_version': version('sdgym'), 'jobs': jobs, } + for synthesizer in synthesizers: - if synthesizer not in SDV_SINGLE_TABLE_SYNTHESIZERS: - ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY.get(synthesizer) + if synthesizer['name'] not in SDV_SYNTHESIZERS: + ext_lib = EXTERNAL_SYNTHESIZER_TO_LIBRARY.get(synthesizer['name']) if ext_lib: library_version = version(ext_lib) metadata[f'{ext_lib}_version'] = library_version + elif 'sdv' not in metadata.keys(): metadata['sdv_version'] = version('sdv') @@ -1143,10 +1244,17 @@ def _update_metainfo_file(run_file, result_writer=None): result_writer.write_yaml(update, run_file, append=True) -def _ensure_uniform_included(synthesizers): - if UniformSynthesizer not in synthesizers and UniformSynthesizer.__name__ not in synthesizers: - LOGGER.info('Adding UniformSynthesizer to list of synthesizers.') - synthesizers.append('UniformSynthesizer') +def _ensure_uniform_included(synthesizers, modality): + uniform_class = UniformSynthesizer + if modality == 'multi_table': + uniform_class = MultiTableUniformSynthesizer + + uniform_not_included = bool( + uniform_class not in synthesizers and uniform_class.__name__ not in synthesizers + ) + if uniform_not_included: + LOGGER.info(f'Adding {uniform_class.__name__} to the list of synthesizers.') + synthesizers.append(uniform_class.__name__) def _fill_adjusted_scores_with_none(scores): @@ -1224,9 +1332,9 @@ def _add_adjusted_scores(scores, timeout): def benchmark_single_table( - synthesizers=DEFAULT_SYNTHESIZERS, + synthesizers=DEFAULT_SINGLE_TABLE_SYNTHESIZERS, custom_synthesizers=None, - sdv_datasets=DEFAULT_DATASETS, + sdv_datasets=DEFAULT_SINGLE_TABLE_DATASETS, additional_datasets_folder=None, limit_dataset_size=False, compute_quality_score=True, @@ -1290,7 +1398,7 @@ def benchmark_single_table( SDGym_results_/ results.csv _/ - meta.yaml + metainfo.yaml / synthesizer.pkl synthetic_data.csv @@ -1331,7 +1439,7 @@ def benchmark_single_table( if not synthesizers: synthesizers = [] - _ensure_uniform_included(synthesizers) + _ensure_uniform_included(synthesizers, 'single_table') result_writer = LocalResultsWriter() if run_on_ec2: print("This will create an instance for the current AWS user's account.") # noqa @@ -1340,27 +1448,38 @@ def benchmark_single_table( _create_instance_on_ec2(script_content) else: raise ValueError('In order to run on EC2, please provide an S3 folder output.') + return None - _validate_inputs(output_filepath, detailed_results_folder, synthesizers, custom_synthesizers) - _create_detailed_results_directory(detailed_results_folder) - job_args_list = _generate_job_args_list( - limit_dataset_size, - sdv_datasets, - additional_datasets_folder, - sdmetrics, - detailed_results_folder, - timeout, - output_destination, - compute_quality_score, - compute_diagnostic_score, - compute_privacy_score, + _validate_output_filepath_and_detailed_results_folder(output_filepath, detailed_results_folder) + synthesizers = 
_import_and_validate_synthesizers(
         synthesizers,
         custom_synthesizers,
+        'single_table',
+    )
+    _create_detailed_results_directory(detailed_results_folder)
+    job_args_list = _generate_job_args_list(
+        limit_dataset_size=limit_dataset_size,
+        sdv_datasets=sdv_datasets,
+        additional_datasets_folder=additional_datasets_folder,
+        sdmetrics=sdmetrics,
+        detailed_results_folder=detailed_results_folder,
+        timeout=timeout,
+        output_destination=output_destination,
+        compute_quality_score=compute_quality_score,
+        compute_diagnostic_score=compute_diagnostic_score,
+        compute_privacy_score=compute_privacy_score,
+        synthesizers=synthesizers,
         s3_client=None,
+        modality='single_table',
     )
-    _write_metainfo_file(synthesizers, job_args_list, result_writer)
+    _write_metainfo_file(
+        synthesizers=synthesizers,
+        job_args_list=job_args_list,
+        modality='single_table',
+        result_writer=result_writer,
+    )
 
     if job_args_list:
         scores = _run_jobs(multi_processing_config, job_args_list, show_progress, result_writer)
@@ -1465,7 +1584,7 @@ def _get_s3_script_content(
 response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
 job_args_list = cloudpickle.loads(response['Body'].read())
 result_writer = S3ResultsWriter(s3_client=s3_client)
-_write_metainfo_file({synthesizers}, job_args_list, result_writer)
+_write_metainfo_file({synthesizers}, job_args_list, 'single_table', result_writer)
 scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
 metainfo_filename = job_args_list[0][-1]['metainfo']
 _update_metainfo_file(metainfo_filename, result_writer)
@@ -1576,8 +1695,8 @@ def benchmark_single_table_aws(
     output_destination,
     aws_access_key_id=None,
     aws_secret_access_key=None,
-    synthesizers=DEFAULT_SYNTHESIZERS,
-    sdv_datasets=DEFAULT_DATASETS,
+    synthesizers=DEFAULT_SINGLE_TABLE_SYNTHESIZERS,
+    sdv_datasets=DEFAULT_SINGLE_TABLE_DATASETS,
     additional_datasets_folder=None,
     limit_dataset_size=False,
     compute_quality_score=True,
@@ -1643,7 +1762,13 @@ def benchmark_single_table_aws(
     if not synthesizers:
         synthesizers = []
 
-    _ensure_uniform_included(synthesizers)
+    _ensure_uniform_included(synthesizers, 'single_table')
+    synthesizers = _import_and_validate_synthesizers(
+        synthesizers=synthesizers,
+        custom_synthesizers=None,
+        modality='single_table',
+    )
+
     job_args_list = _generate_job_args_list(
         limit_dataset_size=limit_dataset_size,
         sdv_datasets=sdv_datasets,
@@ -1656,8 +1781,8 @@ def benchmark_single_table_aws(
         compute_privacy_score=compute_privacy_score,
         synthesizers=synthesizers,
         detailed_results_folder=None,
-        custom_synthesizers=None,
         s3_client=s3_client,
+        modality='single_table',
     )
     if not job_args_list:
         return _get_empty_dataframe(
@@ -1675,3 +1800,120 @@ def benchmark_single_table_aws(
         aws_access_key_id=aws_access_key_id,
         aws_secret_access_key=aws_secret_access_key,
     )
+
+
+def benchmark_multi_table(
+    synthesizers=DEFAULT_MULTI_TABLE_SYNTHESIZERS,
+    custom_synthesizers=None,
+    sdv_datasets=DEFAULT_MULTI_TABLE_DATASETS,
+    additional_datasets_folder=None,
+    limit_dataset_size=False,
+    compute_quality_score=True,
+    compute_diagnostic_score=True,
+    timeout=None,
+    output_destination=None,
+    show_progress=False,
+):
+    """Run the SDGym benchmark on multi-table datasets.
+
+    Args:
+        synthesizers (list[string]):
+            The synthesizer(s) to evaluate. Defaults to ``HMASynthesizer`` and
+            ``MultiTableUniformSynthesizer``.
+        custom_synthesizers (list[class] or ``None``):
+            A list of custom synthesizer classes to use.
+            These can be completely custom or they can be synthesizer variants (the output
+            from ``create_multi_table_synthesizer`` or ``create_synthesizer_variant``).
+            Defaults to ``None``.
+        sdv_datasets (list[str] or ``None``):
+            Names of the SDV demo datasets to use for the benchmark. Defaults to
+            ``[NBA, financial, Student_loan, Biodegradability, fake_hotels, restbase,
+            airbnb-simplified]``. Use ``None`` to disable using any sdv datasets.
+        additional_datasets_folder (str or ``None``):
+            The path to a folder (local or an S3 bucket). Datasets found in this folder are
+            run in addition to the SDV datasets. If ``None``, no additional datasets are used.
+        limit_dataset_size (bool):
+            Use this flag to limit the size of the datasets for faster evaluation. If ``True``,
+            limit every table to at most 1,000 rows (sampled while preserving referential
+            integrity) and to its key columns plus at most 10 additional columns.
+        compute_quality_score (bool):
+            Whether or not to evaluate an overall quality score. Defaults to ``True``.
+        compute_diagnostic_score (bool):
+            Whether or not to evaluate an overall diagnostic score. Defaults to ``True``.
+        timeout (int or ``None``):
+            The maximum number of seconds to wait for synthetic data creation. If ``None``, no
+            timeout is enforced.
+        output_destination (str or ``None``):
+            The path to the output directory where results will be saved. If ``None``, no
+            output is saved. The results are saved with the following structure:
+                output_destination/
+                    run_.yaml
+                    SDGym_results_/
+                        results.csv
+                        _/
+                            metainfo.yaml
+                            /
+                                synthesizer.pkl
+                                synthetic_data.zip
+        show_progress (bool):
+            Whether to use tqdm to keep track of the progress. Defaults to ``False``.
+
+    Returns:
+        pandas.DataFrame:
+            A table containing one row per synthesizer + dataset.
+    """
+    _validate_output_destination(output_destination)
+    if not synthesizers:
+        synthesizers = []
+
+    _ensure_uniform_included(synthesizers, 'multi_table')
+    result_writer = LocalResultsWriter()
+
+    synthesizers = _import_and_validate_synthesizers(
+        synthesizers,
+        custom_synthesizers,
+        'multi_table',
+    )
+    job_args_list = _generate_job_args_list(
+        limit_dataset_size=limit_dataset_size,
+        sdv_datasets=sdv_datasets,
+        additional_datasets_folder=additional_datasets_folder,
+        sdmetrics=None,
+        detailed_results_folder=None,
+        timeout=timeout,
+        output_destination=output_destination,
+        compute_quality_score=compute_quality_score,
+        compute_diagnostic_score=compute_diagnostic_score,
+        compute_privacy_score=None,
+        synthesizers=synthesizers,
+        s3_client=None,
+        modality='multi_table',
+    )
+
+    _write_metainfo_file(
+        synthesizers=synthesizers,
+        job_args_list=job_args_list,
+        modality='multi_table',
+        result_writer=result_writer,
+    )
+    if job_args_list:
+        scores = _run_jobs(
+            multi_processing_config=None,
+            job_args_list=job_args_list,
+            show_progress=show_progress,
+            result_writer=result_writer,
+        )
+
+    # If no synthesizers/datasets are passed, return an empty dataframe
+    else:
+        scores = _get_empty_dataframe(
+            compute_diagnostic_score=compute_diagnostic_score,
+            compute_quality_score=compute_quality_score,
+            compute_privacy_score=None,
+            sdmetrics=None,
+        )
+
+    if output_destination and job_args_list:
+        metainfo_filename = job_args_list[0][-1]['metainfo']
+        _update_metainfo_file(metainfo_filename, result_writer)
+
+    return scores
diff --git a/sdgym/dataset_explorer.py b/sdgym/dataset_explorer.py
index 4d7f7410..574fc88d 100644
--- a/sdgym/dataset_explorer.py
+++ b/sdgym/dataset_explorer.py
@@ -205,7 +205,7 @@ def _load_and_summarize_datasets(self, modality):
 
         Args:
             modality (str):
-                The 
dataset modality to load (e.g., 'single-table' or 'multi-table'). + The dataset modality to load (e.g., 'single_table' or 'multi_table'). Returns: list[dict]: diff --git a/sdgym/metrics.py b/sdgym/metrics.py index f3a0e9ed..74f53d64 100644 --- a/sdgym/metrics.py +++ b/sdgym/metrics.py @@ -80,15 +80,15 @@ def normalize(self, raw_score): ], } DATA_MODALITY_METRICS = { - 'single-table': [ + 'single_table': [ 'CSTest', 'KSComplement', ], - 'multi-table': [ + 'multi_table': [ 'CSTest', 'KSComplement', ], - 'timeseries': [ + 'sequential': [ 'TSFClassifierEfficacy', 'LSTMClassifierEfficacy', 'TSFCDetection', @@ -104,17 +104,17 @@ def get_metrics(metrics, modality): metrics (list): List of strings or tuples ``(metric, metric_args)`` describing the metrics. modality (str): - It must be ``'single-table'``, ``'multi-table'`` or ``'timeseries'``. + It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``. Returns: list, kwargs: A list of metrics for the given modality, and their corresponding kwargs. """ - if modality == 'multi-table': + if modality == 'multi_table': metric_classes = sdmetrics.multi_table.MultiTableMetric.get_subclasses() - elif modality == 'single-table': + elif modality == 'single_table': metric_classes = sdmetrics.single_table.SingleTableMetric.get_subclasses() - elif modality == 'timeseries': + elif modality == 'sequential': metric_classes = sdmetrics.timeseries.TimeSeriesMetric.get_subclasses() if not metrics: diff --git a/sdgym/result_explorer/result_explorer.py b/sdgym/result_explorer/result_explorer.py index c7c18833..d3fdc7b7 100644 --- a/sdgym/result_explorer/result_explorer.py +++ b/sdgym/result_explorer/result_explorer.py @@ -2,7 +2,7 @@ import os -from sdgym.benchmark import DEFAULT_DATASETS +from sdgym.benchmark import DEFAULT_SINGLE_TABLE_DATASETS from sdgym.datasets import load_dataset from sdgym.result_explorer.result_handler import LocalResultsHandler, S3ResultsHandler from sdgym.s3 import _get_s3_client, is_s3_path @@ -62,7 +62,7 @@ def load_synthetic_data(self, results_folder_name, dataset_name, synthesizer_nam def load_real_data(self, dataset_name): """Load the real data for a given dataset.""" - if dataset_name not in DEFAULT_DATASETS: + if dataset_name not in DEFAULT_SINGLE_TABLE_DATASETS: raise ValueError( f"Dataset '{dataset_name}' is not a SDGym dataset. " 'Please provide a valid dataset name.' 
diff --git a/sdgym/result_writer.py b/sdgym/result_writer.py
index 718871d7..f4d025d9 100644
--- a/sdgym/result_writer.py
+++ b/sdgym/result_writer.py
@@ -1,6 +1,7 @@
 """Results writer for SDGym benchmark."""
 
 import io
+import zipfile
 from abc import ABC, abstractmethod
 from pathlib import Path
 
@@ -35,6 +36,14 @@ def write_yaml(self, data, file_path, append=False):
 class LocalResultsWriter:
     """Local results writer for saving results to the local filesystem."""
 
+    def write_zipped_dataframes(self, data, file_path, index=False):
+        """Write a dictionary of DataFrames to a ZIP file, one CSV per table."""
+        with zipfile.ZipFile(file_path, mode='w', compression=zipfile.ZIP_DEFLATED) as zf:
+            for table_name, table in data.items():
+                buf = io.StringIO()
+                table.to_csv(buf, index=index)
+                zf.writestr(f'{table_name}.csv', buf.getvalue())
+
     def write_dataframe(self, data, file_path, append=False, index=False):
         """Write a DataFrame to a CSV file."""
         file_path = Path(file_path)
diff --git a/sdgym/synthesizers/base.py b/sdgym/synthesizers/base.py
index 0aae184f..e9d7950f 100644
--- a/sdgym/synthesizers/base.py
+++ b/sdgym/synthesizers/base.py
@@ -148,11 +148,10 @@ def sample_from_synthesizer(self, synthesizer, scale=1.0):
             synthesizer (obj):
                 The synthesizer object to sample data from.
             scale (float):
-                The scale of data to sample.
-                Defaults to 1.0.
+                The scale of data to sample. Defaults to 1.0.
 
         Returns:
             dict:
                 The sampled data. A dict mapping table name to DataFrame.
         """
-        return self._sample_from_synthesizer(synthesizer, scale=scale)
+        return self._sample_from_synthesizer(synthesizer, scale)
diff --git a/sdgym/synthesizers/sdv.py b/sdgym/synthesizers/sdv.py
index 9fd9418e..d6f90b1f 100644
--- a/sdgym/synthesizers/sdv.py
+++ b/sdgym/synthesizers/sdv.py
@@ -19,6 +19,8 @@
     'multi_table': multi_table,
 }
 
+MODEL_KWARGS = {'HMASynthesizer': {'verbose': False}}
+
 
 def _get_sdv_synthesizers(modality):
     _validate_modality(modality)
@@ -81,6 +83,7 @@ def _create_sdv_class(sdv_name):
     current_module = sys.modules[__name__]
     modality = _get_modality(sdv_name)
     base_class = MultiTableBaselineSynthesizer if modality == 'multi_table' else BaselineSynthesizer
+    model_kwargs = MODEL_KWARGS.get(sdv_name, {})
     synthesizer_class = type(
         sdv_name,
         (base_class,),
@@ -88,7 +91,7 @@ def _create_sdv_class(sdv_name):
             '__module__': __name__,
             'SDV_NAME': sdv_name,
             '_MODALITY_FLAG': modality,
-            '_MODEL_KWARGS': {},
+            '_MODEL_KWARGS': model_kwargs,
             '_fit': _fit,
             '_sample_from_synthesizer': _sample_from_synthesizer,
         },
diff --git a/sdgym/utils.py b/sdgym/utils.py
index b6ff1b47..2ff2f9f4 100644
--- a/sdgym/utils.py
+++ b/sdgym/utils.py
@@ -74,6 +74,7 @@ def get_synthesizers(synthesizers):
             synthesizer_name = getattr(synthesizer, '__name__', 'undefined')
         else:
             synthesizer_name = getattr(type(synthesizer), '__name__', 'undefined')
+
         synthesizers_dicts.append({
             'name': synthesizer_name,
             'synthesizer': synthesizer,
diff --git a/tests/integration/result_explorer/test_result_explorer.py b/tests/integration/result_explorer/test_result_explorer.py
index 6f148fc1..4a9e3493 100644
--- a/tests/integration/result_explorer/test_result_explorer.py
+++ b/tests/integration/result_explorer/test_result_explorer.py
@@ -10,16 +10,17 @@
 def test_end_to_end_local(tmp_path):
     """Test the ResultsExplorer end-to-end with local paths."""
     # Setup
-    output_destination = str(tmp_path / 'benchmark_output')
+    output_destination = tmp_path / 'benchmark_output'
+    result_explorer_path = output_destination / 'single_table'
     benchmark_single_table(
-        output_destination=output_destination,
+        
output_destination=str(output_destination), synthesizers=['GaussianCopulaSynthesizer', 'TVAESynthesizer'], sdv_datasets=['expedia_hotel_logs', 'fake_companies'], ) today = time.strftime('%m_%d_%Y') # Run - result_explorer = ResultsExplorer(output_destination) + result_explorer = ResultsExplorer(str(result_explorer_path)) runs = result_explorer.list() results = result_explorer.load_results(runs[0]) metainfo = result_explorer.load_metainfo(runs[0]) @@ -42,7 +43,7 @@ def test_end_to_end_local(tmp_path): new_synthetic_data = synthesizer.sample(num_rows=10) # Assert - expected_results = pd.read_csv(f'{output_destination}/SDGym_results_{today}/results.csv') + expected_results = pd.read_csv(f'{result_explorer_path}/SDGym_results_{today}/results.csv') pd.testing.assert_frame_equal(results, expected_results) assert metainfo[f'run_{today}_0']['jobs'] == [ ['expedia_hotel_logs', 'GaussianCopulaSynthesizer'], @@ -63,7 +64,7 @@ def test_end_to_end_local(tmp_path): def test_summarize(): """Test the `summarize` method.""" # Setup - output_destination = 'tests/integration/result_explorer/_benchmark_results' + output_destination = 'tests/integration/result_explorer/_benchmark_results/' result_explorer = ResultsExplorer(output_destination) # Run diff --git a/tests/integration/test_benchmark.py b/tests/integration/test_benchmark.py index b25a41a0..032a5fb9 100644 --- a/tests/integration/test_benchmark.py +++ b/tests/integration/test_benchmark.py @@ -18,6 +18,7 @@ import sdgym from sdgym import ( + benchmark_multi_table, benchmark_single_table, create_single_table_synthesizer, create_synthesizer_variant, @@ -201,8 +202,7 @@ def test_benchmark_single_table_duplicate_synthesizers(): # Run and Assert error_msg = re.escape( - 'Synthesizers must be unique. Please remove repeated values in the `synthesizers` ' - 'and `custom_synthesizers` parameters.' + 'Synthesizers must be unique. Please remove repeated values in the provided synthesizers.' 
) with pytest.raises(ValueError, match=error_msg): sdgym.benchmark_single_table( @@ -508,7 +508,7 @@ def test_benchmark_single_table_no_synthesizers_with_parameters(): .all() ) assert result['Evaluate_Time'] is None - assert result['error'] == 'ValueError: Unknown single-table metric: a' + assert result['error'] == 'ValueError: Unknown single_table metric: a' def test_benchmark_single_table_custom_synthesizer(): @@ -628,73 +628,82 @@ def sample_from_synthesizer(synthesizer, n_samples): def test_benchmark_single_table_no_warnings(): """Test that the benchmark does not raise any FutureWarnings.""" # Run - with warnings.catch_warnings(record=True) as w: + with warnings.catch_warnings(record=True) as catched_warnings: benchmark_single_table( synthesizers=['GaussianCopulaSynthesizer'], sdv_datasets=['fake_companies'] ) - future_warnings = [warning for warning in w if issubclass(warning.category, FutureWarning)] - assert len(future_warnings) == 0 + + # Assert + future_warnings = [ + warning for warning in catched_warnings if issubclass(warning.category, FutureWarning) + ] + assert len(future_warnings) == 0 def test_benchmark_single_table_with_output_destination(tmp_path): """Test it works with the ``output_destination`` argument.""" # Setup - output_destination = str(tmp_path / 'benchmark_output') + output_destination = tmp_path / 'benchmark_output' today_date = pd.Timestamp.now().strftime('%m_%d_%Y') # Run results = benchmark_single_table( synthesizers=['GaussianCopulaSynthesizer', 'TVAESynthesizer'], sdv_datasets=['fake_companies'], - output_destination=output_destination, + output_destination=str(output_destination), # function may require str ) # Assert - directions = os.listdir(output_destination) - score_saved_separately = pd.DataFrame() - assert directions == [f'SDGym_results_{today_date}'] - subdirections = os.listdir(os.path.join(output_destination, directions[0])) - assert set(subdirections) == { + top_level = os.listdir(output_destination) + assert top_level == ['single_table'] + + second_level = os.listdir(output_destination / 'single_table') + assert second_level == [f'SDGym_results_{today_date}'] + + subdir = output_destination / 'single_table' / f'SDGym_results_{today_date}' + assert set(os.listdir(subdir)) == { 'results.csv', f'fake_companies_{today_date}', 'metainfo.yaml', } - with open(os.path.join(output_destination, directions[0], 'metainfo.yaml'), 'r') as f: + + # Validate metadata + with open(subdir / 'metainfo.yaml', 'r') as f: metadata = yaml.safe_load(f) - assert metadata['completed_date'] is not None - assert metadata['sdgym_version'] == sdgym.__version__ - synthesizer_directions = os.listdir( - os.path.join(output_destination, directions[0], f'fake_companies_{today_date}') - ) - assert set(synthesizer_directions) == { + assert metadata['completed_date'] is not None + assert metadata['sdgym_version'] == sdgym.__version__ + + # Synthesizer directories + synth_dir = subdir / f'fake_companies_{today_date}' + synthesizer_dirs = os.listdir(synth_dir) + assert set(synthesizer_dirs) == { 'TVAESynthesizer', 'GaussianCopulaSynthesizer', 'UniformSynthesizer', } - for synthesizer in sorted(synthesizer_directions): - synthesizer_files = os.listdir( - os.path.join( - output_destination, directions[0], f'fake_companies_{today_date}', synthesizer - ) - ) - assert set(synthesizer_files) == { + + # Validate files in each synthesizer directory + score_saved_separately = pd.DataFrame() + for synthesizer in sorted(synthesizer_dirs): + files = os.listdir(synth_dir / synthesizer) + assert 
set(files) == { f'{synthesizer}.pkl', f'{synthesizer}_synthetic_data.csv', f'{synthesizer}_benchmark_result.csv', } - score = pd.read_csv( - os.path.join( - output_destination, - directions[0], - f'fake_companies_{today_date}', - synthesizer, - f'{synthesizer}_benchmark_result.csv', - ) - ) + + score_path = synth_dir / synthesizer / f'{synthesizer}_benchmark_result.csv' + score = pd.read_csv(score_path) score_saved_separately = pd.concat([score_saved_separately, score], ignore_index=True) - saved_result = pd.read_csv(f'{output_destination}/SDGym_results_{today_date}/results.csv') + # Load top-level results.csv + saved_results_path = ( + output_destination / 'single_table' / f'SDGym_results_{today_date}' / 'results.csv' + ) + saved_result = pd.read_csv(saved_results_path) + + # Assert Results pd.testing.assert_frame_equal(results, saved_result, check_dtype=False) results_no_adjusted = results.drop(columns=['Adjusted_Total_Time', 'Adjusted_Quality_Score']) pd.testing.assert_frame_equal(results_no_adjusted, score_saved_separately, check_dtype=False) @@ -703,83 +712,88 @@ def test_benchmark_single_table_with_output_destination(tmp_path): def test_benchmark_single_table_with_output_destination_multiple_runs(tmp_path): """Test saving in ``output_destination`` with multiple runs. - Here two benchmark runs are performed with different synthesizers - on the same dataset, and the results are saved in the same output directory. - The directory contains a `results.csv` file with the combined results - and a subdirectory for each synthesizer with its own results. + Two benchmark runs are performed with different synthesizers on the same + dataset, saving results to the same output directory. The directory contains + multiple `results.csv` files and synthesizer subdirectories. 
""" # Setup - output_destination = str(tmp_path / 'benchmark_output') + output_destination = tmp_path / 'benchmark_output' today_date = pd.Timestamp.now().strftime('%m_%d_%Y') # Run result_1 = benchmark_single_table( synthesizers=['GaussianCopulaSynthesizer'], sdv_datasets=['fake_companies'], - output_destination=output_destination, + output_destination=str(output_destination), ) result_2 = benchmark_single_table( synthesizers=['TVAESynthesizer'], sdv_datasets=['fake_companies'], - output_destination=output_destination, + output_destination=str(output_destination), ) # Assert score_saved_separately = pd.DataFrame() - directions = os.listdir(output_destination) - assert directions == [f'SDGym_results_{today_date}'] - subdirections = os.listdir(os.path.join(output_destination, directions[0])) - assert set(subdirections) == { + + top_level = os.listdir(output_destination) + assert top_level == ['single_table'] + + second_level = os.listdir(output_destination / 'single_table') + assert second_level == [f'SDGym_results_{today_date}'] + + subdir = output_destination / 'single_table' / f'SDGym_results_{today_date}' + assert set(os.listdir(subdir)) == { 'results.csv', 'results(1).csv', f'fake_companies_{today_date}', 'metainfo.yaml', 'metainfo(1).yaml', } - with open(os.path.join(output_destination, directions[0], 'metainfo.yaml'), 'r') as f: + + # Validate metadata + with open(subdir / 'metainfo.yaml', 'r') as f: metadata = yaml.safe_load(f) - assert metadata['completed_date'] is not None - assert metadata['sdgym_version'] == sdgym.__version__ - synthesizer_directions = os.listdir( - os.path.join(output_destination, directions[0], f'fake_companies_{today_date}') - ) - assert set(synthesizer_directions) == { + assert metadata['completed_date'] is not None + assert metadata['sdgym_version'] == sdgym.__version__ + + # Synthesizer directories + synth_parent = subdir / f'fake_companies_{today_date}' + synthesizer_dirs = os.listdir(synth_parent) + + # Assert Synthesizer directories + assert set(synthesizer_dirs) == { 'TVAESynthesizer(1)', 'GaussianCopulaSynthesizer', 'UniformSynthesizer', 'UniformSynthesizer(1)', } - for synthesizer in sorted(synthesizer_directions): - synthesizer_files = os.listdir( - os.path.join( - output_destination, directions[0], f'fake_companies_{today_date}', synthesizer - ) - ) - assert set(synthesizer_files) == { + + # Validate each synthesizer directory + for synthesizer in sorted(synthesizer_dirs): + synth_path = synth_parent / synthesizer + + synth_files = os.listdir(synth_path) + assert set(synth_files) == { f'{synthesizer}.pkl', f'{synthesizer}_synthetic_data.csv', f'{synthesizer}_benchmark_result.csv', } - score = pd.read_csv( - os.path.join( - output_destination, - directions[0], - f'fake_companies_{today_date}', - synthesizer, - f'{synthesizer}_benchmark_result.csv', - ) - ) + + score = pd.read_csv(synth_path / f'{synthesizer}_benchmark_result.csv') score_saved_separately = pd.concat([score_saved_separately, score], ignore_index=True) - saved_result_1 = pd.read_csv(f'{output_destination}/SDGym_results_{today_date}/results.csv') - saved_result_2 = pd.read_csv(f'{output_destination}/SDGym_results_{today_date}/results(1).csv') + # Load saved results + saved_result_1 = pd.read_csv(subdir / 'results.csv') + saved_result_2 = pd.read_csv(subdir / 'results(1).csv') + + # Assert results pd.testing.assert_frame_equal(result_1, saved_result_1, check_dtype=False) pd.testing.assert_frame_equal(result_2, saved_result_2, check_dtype=False) 
@patch('sdv.single_table.GaussianCopulaSynthesizer.fit', autospec=True) -def test_benchmark_error_during_fit(mock_fit): +def test_benchmark_single_table_error_during_fit(mock_fit): """Test that benchmark_single_table handles errors during synthesizer fitting.""" # Setup @@ -824,7 +838,7 @@ def fit(self, data): @patch('sdv.single_table.GaussianCopulaSynthesizer.sample', autospec=True) -def test_benchmark_error_during_sample(mock_sample): +def test_benchmark_single_table_error_during_sample(mock_sample): """Test that benchmark_single_table handles errors during synthesizer sampling.""" # Setup @@ -864,3 +878,181 @@ def sample(self, num_rows): expected_time = base_time + extra assert np.isclose(row['Adjusted_Total_Time'], expected_time) + + +def test_benchmark_multi_table_basic_synthesizers(): + """Integration test that runs HMASynthesizer and MultiTableUniformSynthesizer on fake_hotels.""" + output = benchmark_multi_table( + synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'], + sdv_datasets=['fake_hotels'], + compute_quality_score=True, + compute_diagnostic_score=True, + limit_dataset_size=True, + show_progress=False, + timeout=30, + ) + + # Assert + assert isinstance(output, pd.DataFrame) + assert not output.empty + + # Required SDGym benchmark output columns + for col in [ + 'Synthesizer', + 'Train_Time', + 'Sample_Time', + 'Quality_Score', + 'Diagnostic_Score', + ]: + assert col in output.columns + + synths = sorted(output['Synthesizer'].unique()) + assert synths == [ + 'HMASynthesizer', + 'MultiTableUniformSynthesizer', + ] + + diagnostic_rank = ( + output.groupby('Synthesizer').Diagnostic_Score.mean().sort_values().index.tolist() + ) + + assert diagnostic_rank == [ + 'MultiTableUniformSynthesizer', + 'HMASynthesizer', + ] + + quality_rank = output.groupby('Synthesizer').Quality_Score.mean().sort_values().index.tolist() + + assert quality_rank == [ + 'MultiTableUniformSynthesizer', + 'HMASynthesizer', + ] + + +def test_benchmark_multi_table_with_output_destination_multiple_runs(tmp_path): + """Test saving in ``output_destination`` with multiple runs in multi-table mode. + + Two benchmark runs are performed with HMASynthesizer on the same multi-table + dataset, saving results to the same output directory. The directory contains + multiple `results*.csv` files, metainfo files, and synthesizer subdirectories. 
+ """ + # Setup + output_destination = tmp_path / 'benchmark_output' + today_date = pd.Timestamp.now().strftime('%m_%d_%Y') + + # Run 1 + result_1 = benchmark_multi_table( + synthesizers=['HMASynthesizer'], + sdv_datasets=['fake_hotels'], + output_destination=str(output_destination), + ) + + # Run 2 + result_2 = benchmark_multi_table( + synthesizers=['HMASynthesizer'], + sdv_datasets=['fake_hotels'], + output_destination=str(output_destination), + ) + + # Assert + score_saved_separately = pd.DataFrame() + + # Top level directory + top_level = os.listdir(output_destination) + assert top_level == ['multi_table'] + + # Second level + second_level = os.listdir(output_destination / 'multi_table') + assert second_level == [f'SDGym_results_{today_date}'] + + # SDGym results folder + subdir = output_destination / 'multi_table' / f'SDGym_results_{today_date}' + assert set(os.listdir(subdir)) == { + 'results.csv', + 'results(1).csv', + f'fake_hotels_{today_date}', + 'metainfo.yaml', + 'metainfo(1).yaml', + } + + # Validate metadata + with open(subdir / 'metainfo.yaml', 'r') as f: + metadata = yaml.safe_load(f) + + assert metadata['completed_date'] is not None + assert metadata['sdgym_version'] == sdgym.__version__ + assert metadata['modality'] == 'multi_table' + + # Synthesizer folders + synth_parent = subdir / f'fake_hotels_{today_date}' + synthesizer_dirs = os.listdir(synth_parent) + + assert set(synthesizer_dirs) == { + 'HMASynthesizer', + 'HMASynthesizer(1)', + 'MultiTableUniformSynthesizer', + 'MultiTableUniformSynthesizer(1)', + } + + # Validate each synthesizer directory + for synthesizer in sorted(synthesizer_dirs): + synth_path = synth_parent / synthesizer + + synth_files = os.listdir(synth_path) + assert set(synth_files) == { + f'{synthesizer}.pkl', + f'{synthesizer}_synthetic_data.zip', + f'{synthesizer}_benchmark_result.csv', + } + + score = pd.read_csv(synth_path / f'{synthesizer}_benchmark_result.csv') + score_saved_separately = pd.concat([score_saved_separately, score], ignore_index=True) + + # Load results for both runs + saved_result_1 = pd.read_csv(subdir / 'results.csv') + saved_result_2 = pd.read_csv(subdir / 'results(1).csv') + + # Validate the stored results match returned results + pd.testing.assert_frame_equal(result_1, saved_result_1, check_dtype=False) + pd.testing.assert_frame_equal(result_2, saved_result_2, check_dtype=False) + + +@patch('sdv.multi_table.HMASynthesizer._augment_tables', autospec=True) +def test_benchmark_multi_table_error_during_fit(mock_augment_tables): + """Test that benchmark_multi_table handles errors during synthesizer fitting.""" + + # Setup + def _augment_tables(self, data): + raise Exception('Fitting error') + + mock_augment_tables.side_effect = _augment_tables + + # Run + result = benchmark_multi_table( + synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'], + sdv_datasets=['Student_loan', 'fake_hotels'], + ) + + # Assert + assert result['error'].to_list() == [ + 'Exception: Fitting error', + np.nan, + 'Exception: Fitting error', + np.nan, + ] + for dataset, data in result.groupby('Dataset'): + uniform = data.loc[data['Synthesizer'] == 'MultiTableUniformSynthesizer'].iloc[0] + uniform_train = uniform['Train_Time'] + uniform_total = uniform[['Train_Time', 'Sample_Time']].sum() + + for synth in ['HMASynthesizer', 'MultiTableUniformSynthesizer']: + row = data.loc[data['Synthesizer'] == synth] + if row.empty: + continue + + row = row.iloc[0] + base_time = row[['Train_Time', 'Sample_Time']].sum(skipna=True) + extra = uniform_total if synth 
== 'HMASynthesizer' else uniform_train + expected_time = base_time + extra + + assert np.isclose(row['Adjusted_Total_Time'], expected_time) diff --git a/tests/unit/synthesizers/test_base.py b/tests/unit/synthesizers/test_base.py index 0922b08c..e1f620d4 100644 --- a/tests/unit/synthesizers/test_base.py +++ b/tests/unit/synthesizers/test_base.py @@ -1,6 +1,6 @@ import re import warnings -from unittest.mock import Mock, call, patch +from unittest.mock import Mock, patch import pandas as pd import pytest @@ -101,33 +101,42 @@ def test_get_trained_synthesizer(self): class TestMultiTableBaselineSynthesizer: - def test_sample_from_synthesizer(self): - """Test it calls the correct methods and returns the correct values.""" - # Setup + @pytest.mark.parametrize( + 'scale, expected_scale', + [ + (None, 1.0), + (2.0, 2.0), + ], + ) + def test_sample_from_synthesizer_valid(self, scale, expected_scale): + """Test that valid calls return correct values and call underlying method.""" synthesizer = MultiTableBaselineSynthesizer() mock_synthesizer = Mock() synthesizer._sample_from_synthesizer = Mock(return_value='sampled_data') + + # Run + if scale is None: + result = synthesizer.sample_from_synthesizer(mock_synthesizer) + else: + result = synthesizer.sample_from_synthesizer(mock_synthesizer, scale) + + # Assert call + synthesizer._sample_from_synthesizer.assert_called_with(mock_synthesizer, expected_scale) + + assert result == 'sampled_data' + assert synthesizer._MODALITY_FLAG == 'multi_table' + + def test_sample_from_synthesizer_raises_on_unexpected_kwarg(self): + """Test that passing n_samples raises a TypeError.""" + synthesizer = MultiTableBaselineSynthesizer() + mock_synthesizer = Mock() + expected_error = re.escape( "sample_from_synthesizer() got an unexpected keyword argument 'n_samples'" ) - # Run - sampled_data = synthesizer.sample_from_synthesizer(mock_synthesizer) - sampled_data_with_scale = synthesizer.sample_from_synthesizer( - mock_synthesizer, - scale=2.0, - ) with pytest.raises(TypeError, match=expected_error): synthesizer.sample_from_synthesizer( mock_synthesizer, n_samples=10, ) - - # Assert - assert synthesizer._MODALITY_FLAG == 'multi_table' - synthesizer._sample_from_synthesizer.assert_has_calls([ - call(mock_synthesizer, scale=1.0), - call(mock_synthesizer, scale=2.0), - ]) - assert sampled_data == 'sampled_data' - assert sampled_data_with_scale == 'sampled_data' diff --git a/tests/unit/test__dataset_utils.py b/tests/unit/test__dataset_utils.py new file mode 100644 index 00000000..627d5873 --- /dev/null +++ b/tests/unit/test__dataset_utils.py @@ -0,0 +1,219 @@ +import json +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest + +from sdgym._dataset_utils import ( + _filter_columns, + _get_dataset_subset, + _get_multi_table_dataset_subset, + _parse_numeric_value, + _read_csv_from_zip, + _read_metadata_json, + _read_zipped_data, +) + + +@pytest.mark.parametrize( + 'value,expected', + [ + ('3.14', 3.14), + ('not-a-number', np.nan), + (None, np.nan), + ], +) +def test__parse_numeric_value(value, expected): + """Test numeric parsing with fallback to NaN.""" + # Setup / Run + result = _parse_numeric_value(value, 'dataset', 'field') + + # Assert + if np.isnan(expected): + assert np.isnan(result) + else: + assert result == expected + + +@patch('sdgym._dataset_utils.poc.get_random_subset') +@patch('sdgym._dataset_utils.Metadata') +def test__get_multi_table_dataset_subset(mock_metadata, mock_subset): + """Test multi-table subset selection calls 
SDV and trims columns.""" + # Setup + df_main = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + df_other = pd.DataFrame({'x': [5, 6], 'y': [7, 8]}) + + data = {'main': df_main, 'other': df_other} + + metadata_dict = { + 'tables': { + 'main': {'columns': {'a': {}, 'b': {}}}, + 'other': {'columns': {'x': {}, 'y': {}}}, + } + } + + mock_meta_obj = MagicMock() + mock_meta_obj.tables = { + 'main': MagicMock(columns={'a': {}, 'b': {}}), + 'other': MagicMock(columns={'x': {}, 'y': {}}), + } + mock_meta_obj._get_all_keys.return_value = [] + mock_metadata.load_from_dict.return_value = mock_meta_obj + + mock_subset.return_value = {'main': df_main[:1], 'other': df_other[:1]} + + # Run + result_data, result_meta = _get_multi_table_dataset_subset(data, metadata_dict) + + # Assert + assert 'main' in result_data + assert 'other' in result_data + mock_subset.assert_called_once() + + +def test__get_dataset_subset_single_table(): + """Test tabular dataset subset reduces rows and columns.""" + # Setup + df = pd.DataFrame({f'c{i}': range(2000) for i in range(15)}) + metadata = {'tables': {'table': {'columns': {f'c{i}': {} for i in range(15)}}}} + + # Run + result_df, result_meta = _get_dataset_subset(df, metadata, modality='single_table') + + # Assert + assert len(result_df) <= 1000 + assert len(result_df.columns) == 10 + assert 'tables' in result_meta + + +def test__get_dataset_subset_sequential(): + """Test sequential dataset preserves mandatory columns.""" + # Setup + df = pd.DataFrame({ + 'seq_id': range(20), + 'seq_key': range(20), + **{f'c{i}': range(20) for i in range(20)}, + }) + + metadata = { + 'tables': { + 'table': { + 'columns': {col: {'sdtype': 'numerical'} for col in df.columns.to_list()}, + 'sequence_index': 'seq_id', + 'sequence_key': 'seq_key', + } + } + } + + # Run + subset_df, _ = _get_dataset_subset(df, metadata, modality='sequential') + + # Assert + assert 'seq_id' in subset_df.columns + assert 'seq_key' in subset_df.columns + assert len(subset_df.columns) <= 12 + + +@patch('sdgym._dataset_utils._get_multi_table_dataset_subset') +def test__get_dataset_subset_multi_table(mock_multi): + """Test multi-table dispatch calls the correct function.""" + # Setup + data = {'table': pd.DataFrame({'a': [1, 2]})} + metadata = {'tables': {}} + mock_multi.return_value = ('DATA', 'META') + + # Run + out_data, out_meta = _get_dataset_subset(data, metadata, modality='multi_table') + + # Assert + assert out_data == 'DATA' + assert out_meta == 'META' + mock_multi.assert_called_once() + + +@patch('sdgym._dataset_utils._read_csv_from_zip') +def test__read_zipped_data_multitable(mock_read): + """Test zipped CSV reading returns a dict for multi-table.""" + # Setup + mock_read.return_value = pd.DataFrame({'a': [1]}) + + mock_zip = MagicMock() + mock_zip.__enter__.return_value = mock_zip + mock_zip.namelist.return_value = ['table1.csv', 'table2.csv'] + + # Run + with patch('sdgym._dataset_utils.ZipFile', return_value=mock_zip): + data_multi = _read_zipped_data('fake.zip', modality='multi_table') + + # Assert + assert isinstance(data_multi, dict) + assert mock_read.call_count == 2 + + +@patch('sdgym._dataset_utils._read_csv_from_zip') +def test__read_zipped_data_single(mock_read): + """Test zipped CSV reading returns a DataFrame for single-table.""" + # Setup + mock_read.return_value = pd.DataFrame({'a': [1]}) + + mock_zip = MagicMock() + mock_zip.__enter__.return_value = mock_zip + mock_zip.namelist.return_value = ['table1.csv'] + + # Run + with patch('sdgym._dataset_utils.ZipFile', return_value=mock_zip): + 
data_single = _read_zipped_data('fake.zip', modality='single_table') + + # Assert + assert isinstance(data_single, pd.DataFrame) + assert mock_read.call_count == 1 + + +@patch('sdgym._dataset_utils.pd') +def test__read_csv_from_zip(mock_pd): + """Test CSV is read from zip and returned as DataFrame.""" + # Setup + csv_bytes = b'a,b\n1,2\n3,4\n' + returned_bytes = csv_bytes.decode().splitlines() + mock_zip = MagicMock() + mock_zip.open.return_value.__enter__.return_value = returned_bytes + + # Run + result = _read_csv_from_zip(mock_zip, 'fake.csv') + + # Assert + mock_pd.read_csv.assert_called_once_with(returned_bytes, low_memory=False) + assert result == mock_pd.read_csv.return_value + + +def test__read_metadata_json(tmp_path): + """Test reading metadata JSON file.""" + # Setup + meta = {'tables': {'a': {}}} + path = tmp_path / 'meta.json' + path.write_text(json.dumps(meta)) + + # Run + result = _read_metadata_json(path) + + # Assert + assert result == meta + + +def test__filter_columns(): + """Test filtering keeps mandatory columns and limits total optional columns.""" + # Setup + columns = {f'c{i}': {} for i in range(20)} + mandatory = ['c10', 'c11', 'c19'] + + # Run + filtered = _filter_columns(columns, mandatory) + + # Assert + for col in mandatory: + assert col in filtered + + assert len(filtered) == len(mandatory) + 10 + expected_optional = [f'c{i}' for i in range(20) if f'c{i}' not in mandatory][:10] + assert list(filtered.keys()) == mandatory + expected_optional diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index b73b0408..899f8824 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -11,7 +11,6 @@ import pytest import yaml -from sdgym import benchmark_single_table from sdgym.benchmark import ( _add_adjusted_scores, _check_write_permissions, @@ -23,16 +22,128 @@ _generate_job_args_list, _get_metainfo_increment, _handle_deprecated_parameters, + _import_and_validate_synthesizers, _setup_output_destination, _setup_output_destination_aws, _update_metainfo_file, _validate_aws_inputs, _validate_output_destination, _write_metainfo_file, + benchmark_multi_table, + benchmark_single_table, benchmark_single_table_aws, ) from sdgym.result_writer import LocalResultsWriter from sdgym.s3 import S3_REGION +from sdgym.synthesizers import MultiTableUniformSynthesizer, UniformSynthesizer + + +class FakeSynth: + """Simple fake synthesizer with a configurable modality flag.""" + + def __init__(self, modality): + self._MODALITY_FLAG = modality + + +@pytest.mark.parametrize('modality', ['single_table', 'multi_table']) +@patch('sdgym.benchmark.get_duplicates', return_value=[]) +@patch('sdgym.benchmark.get_synthesizers') +def test__import_and_validate_synthesizers_valid( + mock_get_synthesizers, mock_get_duplicates, modality +): + """Test that `_import_and_validate_synthesizers` returns the `get_synthesizers` values.""" + # Setup + fake_synth = FakeSynth(modality) + + mock_get_synthesizers.return_value = [{'name': 'FakeSynth', 'synthesizer': fake_synth}] + + # Run + result = _import_and_validate_synthesizers( + synthesizers=['FakeSynth'], + custom_synthesizers=None, + modality=modality, + ) + + # Assert + assert result == mock_get_synthesizers.return_value + mock_get_synthesizers.assert_called_once_with(['FakeSynth']) + + +@pytest.mark.parametrize('modality', ['single_table', 'multi_table']) +@patch('sdgym.benchmark.get_duplicates', return_value=[]) +@patch('sdgym.benchmark.get_synthesizers') +def 
test__import_and_validate_synthesizers_mismatched_modality( + mock_get_synthesizers, mock_get_duplicates, modality +): + """Test `_import_and_validate_synthesizers` raises a ValueError. + + Ensure that a ValueError is raised when a synthesizer's modality does not match the + expected one. + """ + # Setup + wrong_modality = 'multi_table' if modality == 'single_table' else 'single_table' + + # Dynamically create a class named BadSynth + FakeWrong = type('BadSynth', (), {'_MODALITY_FLAG': wrong_modality}) + fake_wrong = FakeWrong() + + mock_get_synthesizers.return_value = [{'name': 'BadSynth', 'synthesizer': fake_wrong}] + + expected_message = ( + f"Synthesizers must be of modality '{modality}'. " + f"Found these synthesizers that don't match: BadSynth" + ) + + # Run and Assert + with pytest.raises(ValueError, match=expected_message): + _import_and_validate_synthesizers( + synthesizers=['BadSynth'], + custom_synthesizers=[], + modality=modality, + ) + + +@pytest.mark.parametrize('modality', ['single_table', 'multi_table']) +@patch('sdgym.benchmark.get_duplicates', return_value=['DupSynth']) +@patch('sdgym.benchmark.get_synthesizers') +def test__import_and_validate_synthesizers_duplicates( + mock_get_synthesizers, mock_get_duplicates, modality +): + """Test `_import_and_validate_synthesizers` raises a ValueError on duplicate synthesizers.""" + # Setup + fake_synth = FakeSynth(modality) + mock_get_synthesizers.return_value = [{'name': 'DupSynth', 'synthesizer': fake_synth}] + + expected_message = re.escape( + 'Synthesizers must be unique. Please remove repeated values in the provided synthesizers.' + ) + + # Run and Assert + with pytest.raises(ValueError, match=expected_message): + _import_and_validate_synthesizers( + synthesizers=['DupSynth'], + custom_synthesizers=['DupSynth'], + modality=modality, + ) + + +@pytest.mark.parametrize('modality', ['single_table', 'multi_table']) +@patch('sdgym.benchmark.get_duplicates', return_value=[]) +@patch('sdgym.benchmark.get_synthesizers') +def test__import_and_validate_synthesizers_none_inputs( + mock_get_synthesizers, mock_get_duplicates, modality +): + """Test `_import_and_validate_synthesizers` when both synthesizer inputs are `None`.""" + # Run + result = _import_and_validate_synthesizers( + synthesizers=None, + custom_synthesizers=None, + modality=modality, + ) + + # Assert + assert result == mock_get_synthesizers.return_value + mock_get_synthesizers.assert_called_once_with([]) @patch('sdgym.benchmark.os.path') @@ -275,37 +386,49 @@ def test_run_ec2_flag(create_ec2_mock, session_mock, mock_write_permissions, moc benchmark_single_table(run_on_ec2=True, output_filepath='s3://BucketName/mock/path') -def test__ensure_uniform_included_adds_uniform(caplog): +@pytest.mark.parametrize( + 'modality,uniform_string', + [('single_table', 'UniformSynthesizer'), ('multi_table', 'MultiTableUniformSynthesizer')], +) +def test__ensure_uniform_included_adds_uniform(modality, uniform_string, caplog): """Test that UniformSynthesizer gets added to the synthesizers list.""" # Setup synthesizers = ['GaussianCopulaSynthesizer'] - expected_message = 'Adding UniformSynthesizer to list of synthesizers.' + expected_message = f'Adding {uniform_string} to the list of synthesizers.' 
# Run with caplog.at_level(logging.INFO): - _ensure_uniform_included(synthesizers) + _ensure_uniform_included(synthesizers, modality) # Assert - assert synthesizers == ['GaussianCopulaSynthesizer', 'UniformSynthesizer'] + assert synthesizers == ['GaussianCopulaSynthesizer', uniform_string] assert any(expected_message in record.message for record in caplog.records) -def test__ensure_uniform_included_detects_uniform_class(caplog): +@pytest.mark.parametrize( + 'modality,uniform_class', + [('single_table', UniformSynthesizer), ('multi_table', MultiTableUniformSynthesizer)], +) +def test__ensure_uniform_included_detects_uniform_class(modality, uniform_class, caplog): """Test that the synthesizers list is unchanged if UniformSynthesizer class present.""" # Setup - synthesizers = ['UniformSynthesizer', 'GaussianCopulaSynthesizer'] - expected_message = 'Adding UniformSynthesizer to list of synthesizers.' + synthesizers = [uniform_class, 'GaussianCopulaSynthesizer'] + expected_message = f'Adding {uniform_class} to the list of synthesizers.' # Run with caplog.at_level(logging.INFO): - _ensure_uniform_included(synthesizers) + _ensure_uniform_included(synthesizers, modality) # Assert - assert synthesizers == ['UniformSynthesizer', 'GaussianCopulaSynthesizer'] + assert synthesizers == [uniform_class, 'GaussianCopulaSynthesizer'] assert all(expected_message not in record.message for record in caplog.records) -def test__ensure_uniform_included_detects_uniform_string(caplog): +@pytest.mark.parametrize( + 'modality,uniform_string', + [('single_table', 'UniformSynthesizer'), ('multi_table', 'MultiTableUniformSynthesizer')], +) +def test__ensure_uniform_included_detects_uniform_string(modality, uniform_string, caplog): """Test that the synthesizers list is unchanged if UniformSynthesizer string present.""" # Setup synthesizers = ['UniformSynthesizer', 'GaussianCopulaSynthesizer'] @@ -313,7 +436,7 @@ def test__ensure_uniform_included_detects_uniform_string(caplog): # Run with caplog.at_level(logging.INFO): - _ensure_uniform_included(synthesizers) + _ensure_uniform_included(synthesizers, 'single_table') # Assert assert synthesizers == ['UniformSynthesizer', 'GaussianCopulaSynthesizer'] @@ -522,18 +645,30 @@ def test__validate_output_destination_with_aws_access_key_ids(mock_validate): ) -def test__setup_output_destination(tmp_path): - """Test the `_setup_output_destination` function.""" +def test__setup_output_destination_none(): + """If output_destination is None, the function should return an empty dict.""" + # Setup + synthesizers = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer'] + datasets = ['adult', 'census'] + + # Run + result = _setup_output_destination(None, synthesizers, datasets, 'single_table') + + # Assert + assert result == {} + + +def test__setup_output_destination_single_table(tmp_path): + """Test the `_setup_output_destination` function with `single_table` modality.""" # Setup output_destination = tmp_path / 'output_destination' synthesizers = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer'] datasets = ['adult', 'census'] today = datetime.today().strftime('%m_%d_%Y') - base_path = output_destination / f'SDGym_results_{today}' + base_path = output_destination / 'single_table' / f'SDGym_results_{today}' # Run - result_1 = _setup_output_destination(None, synthesizers, datasets) - result_2 = _setup_output_destination(output_destination, synthesizers, datasets) + result = _setup_output_destination(output_destination, synthesizers, datasets, 'single_table') # Assert expected = { @@ -554,8 +689,41 
@@ def test__setup_output_destination(tmp_path): for dataset in datasets } - assert result_1 == {} - assert json.loads(json.dumps(result_2)) == expected + assert json.loads(json.dumps(result)) == expected + + +def test__setup_output_destination_multi_table(tmp_path): + """Test the `_setup_output_destination` function with `multi_table` modality.""" + # Setup + output_destination = tmp_path / 'output_destination' + synthesizers = ['HMASynthesizer'] + datasets = ['NBA', 'financial'] + today = datetime.today().strftime('%m_%d_%Y') + base_path = output_destination / 'multi_table' / f'SDGym_results_{today}' + + # Run + result = _setup_output_destination(output_destination, synthesizers, datasets, 'multi_table') + + # Assert + expected = { + dataset: { + synth: { + 'synthesizer': str(base_path / f'{dataset}_{today}' / synth / f'{synth}.pkl'), + 'synthetic_data': str( + base_path / f'{dataset}_{today}' / synth / f'{synth}_synthetic_data.zip' + ), + 'benchmark_result': str( + base_path / f'{dataset}_{today}' / synth / f'{synth}_benchmark_result.csv' + ), + 'metainfo': str(base_path / 'metainfo.yaml'), + 'results': str(base_path / 'results.csv'), + } + for synth in synthesizers + } + for dataset in datasets + } + + assert json.loads(json.dumps(result)) == expected @patch('sdgym.benchmark.datetime') @@ -572,22 +740,28 @@ def test__write_metainfo_file(mock_datetime, tmp_path): ({'name': 'CTGANSynthesizer'}, 'census', None, None), ] expected_jobs = [['adult', 'GaussianCopulaSynthesizer'], ['census', 'CTGANSynthesizer']] - synthesizers = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer', 'RealTabFormerSynthesizer'] + synthesizers = [ + {'name': 'GaussianCopulaSynthesizer'}, + {'name': 'CTGANSynthesizer'}, + {'name': 'RealTabFormerSynthesizer'}, + ] # Run - _write_metainfo_file(synthesizers, jobs, result_writer) + _write_metainfo_file(synthesizers, jobs, 'single_table', result_writer) # Assert - assert Path(file_name['metainfo']).exists() with open(file_name['metainfo'], 'r') as file: metainfo_data = yaml.safe_load(file) - assert metainfo_data['run_id'] == 'run_06_26_2025_0' - assert metainfo_data['starting_date'] == '06_26_2025' - assert metainfo_data['jobs'] == expected_jobs - assert metainfo_data['sdgym_version'] == version('sdgym') - assert metainfo_data['sdv_version'] == version('sdv') - assert metainfo_data['realtabformer_version'] == version('realtabformer') - assert metainfo_data['completed_date'] is None + + assert Path(file_name['metainfo']).exists() + assert metainfo_data['run_id'] == 'run_06_26_2025_0' + assert metainfo_data['starting_date'] == '06_26_2025' + assert metainfo_data['jobs'] == expected_jobs + assert metainfo_data['sdgym_version'] == version('sdgym') + assert metainfo_data['sdv_version'] == version('sdv') + assert metainfo_data['realtabformer_version'] == version('realtabformer') + assert metainfo_data['completed_date'] is None + assert metainfo_data['modality'] == 'single_table' @patch('sdgym.benchmark.datetime') @@ -746,11 +920,15 @@ def test_validate_aws_inputs_permission_error(mock_check_write_permissions, mock _validate_aws_inputs(valid_url, None, None) +@patch('sdgym.benchmark._import_and_validate_synthesizers') @patch('sdgym.benchmark._validate_output_destination') @patch('sdgym.benchmark._generate_job_args_list') @patch('sdgym.benchmark._run_on_aws') def test_benchmark_single_table_aws( - mock_run_on_aws, mock_generate_job_args_list, mock_validate_output_destination + mock_run_on_aws, + mock_generate_job_args_list, + mock_validate_output_destination, + 
mock__import_and_validate_synthesizers, ): """Test `benchmark_single_table_aws` method.""" # Setup @@ -761,6 +939,7 @@ def test_benchmark_single_table_aws( aws_secret_access_key = '67890' mock_validate_output_destination.return_value = 's3_client_mock' mock_generate_job_args_list.return_value = 'job_args_list_mock' + mock__import_and_validate_synthesizers.return_value = synthesizers # Run benchmark_single_table_aws( @@ -792,8 +971,8 @@ def test_benchmark_single_table_aws( compute_privacy_score=True, synthesizers=synthesizers, detailed_results_folder=None, - custom_synthesizers=None, s3_client='s3_client_mock', + modality='single_table', ) mock_run_on_aws.assert_called_once_with( output_destination=output_destination, @@ -805,11 +984,15 @@ def test_benchmark_single_table_aws( ) +@patch('sdgym.benchmark._import_and_validate_synthesizers') @patch('sdgym.benchmark._validate_output_destination') @patch('sdgym.benchmark._generate_job_args_list') @patch('sdgym.benchmark._run_on_aws') def test_benchmark_single_table_aws_synthesizers_none( - mock_run_on_aws, mock_generate_job_args_list, mock_validate_output_destination + mock_run_on_aws, + mock_generate_job_args_list, + mock_validate_output_destination, + mock__import_and_validate_synthesizers, ): """Test `benchmark_single_table_aws` includes UniformSynthesizer if omitted.""" # Setup @@ -820,6 +1003,7 @@ def test_benchmark_single_table_aws_synthesizers_none( aws_secret_access_key = '67890' mock_validate_output_destination.return_value = 's3_client_mock' mock_generate_job_args_list.return_value = 'job_args_list_mock' + mock__import_and_validate_synthesizers.return_value = ['UniformSynthesizer'] # Run benchmark_single_table_aws( @@ -848,10 +1032,10 @@ def test_benchmark_single_table_aws_synthesizers_none( compute_quality_score=True, compute_diagnostic_score=True, compute_privacy_score=True, - synthesizers=['UniformSynthesizer'], detailed_results_folder=None, - custom_synthesizers=None, + synthesizers=['UniformSynthesizer'], s3_client='s3_client_mock', + modality='single_table', ) mock_run_on_aws.assert_called_once_with( output_destination=output_destination, @@ -1010,13 +1194,18 @@ def test__add_adjusted_scores_missing_fallback(): assert scores.equals(expected) +@pytest.mark.parametrize('modality', ['single_table', 'multi_table']) @patch('sdgym.benchmark.get_dataset_paths') -def test__generate_job_args_list_local_root_additional_folder(get_dataset_paths_mock, tmp_path): +def test__generate_job_args_list_local_root_additional_folder( + get_dataset_paths_mock, + tmp_path, + modality, +): """Local additional_datasets_folder should point to root/single_table.""" # Setup local_root = tmp_path / 'my_root' local_root.mkdir() - dataset_path = tmp_path / 'my_root' / 'single_table' / 'datasetA' + dataset_path = tmp_path / 'my_root' / modality / 'datasetA' get_dataset_paths_mock.return_value = [dataset_path] # Run @@ -1032,14 +1221,14 @@ def test__generate_job_args_list_local_root_additional_folder(get_dataset_paths_ compute_diagnostic_score=False, compute_privacy_score=False, synthesizers=[], - custom_synthesizers=None, s3_client=None, + modality=modality, ) # Assert get_dataset_paths_mock.assert_called_once_with( - modality='single_table', - bucket=str(local_root / 'single_table'), + modality=modality, + bucket=str(local_root / modality), aws_access_key_id=None, aws_secret_access_key=None, ) @@ -1066,8 +1255,8 @@ def test__generate_job_args_list_s3_root_additional_folder(get_dataset_paths_moc compute_diagnostic_score=False, compute_privacy_score=False, 
synthesizers=[], - custom_synthesizers=None, s3_client=None, + modality='single_table', ) # Assert @@ -1106,3 +1295,120 @@ def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn): warnings_text = ' '.join(str(w.message) for w in recwarn) assert 'is incompatible with transformer' not in warnings_text pd.testing.assert_frame_equal(result[expected_result.columns], expected_result) + + +@patch('sdgym.benchmark._import_and_validate_synthesizers') +@patch('sdgym.benchmark._update_metainfo_file') +@patch('sdgym.benchmark._write_metainfo_file') +@patch('sdgym.benchmark._run_jobs') +@patch('sdgym.benchmark._generate_job_args_list') +@patch('sdgym.benchmark.LocalResultsWriter') +@patch('sdgym.benchmark._validate_output_destination') +def test_benchmark_multi_table_with_jobs( + mock__validate_output_destination, + mock_LocalResultsWriter, + mock__generate_job_args_list, + mock__run_jobs, + mock__write_metainfo_file, + mock__update_metainfo_file, + mock__import_and_validate_synthesizers, +): + """Test that `benchmark_multi_table` runs jobs and updates metainfo when there are job args.""" + # Setup + fake_scores = pd.DataFrame({'a': [1]}) + mock__run_jobs.return_value = fake_scores + job_args = ('arg1', 'arg2', {'metainfo': 'meta.yaml'}) + mock__generate_job_args_list.return_value = [job_args] + expected_valid_synthesizers = ['HMASynthesizer', 'MultiTableUniformSynthesizer', 'CustomSynth'] + + mock__import_and_validate_synthesizers.return_value = expected_valid_synthesizers + # Run + scores = benchmark_multi_table( + synthesizers=['HMASynthesizer'], + custom_synthesizers=['CustomSynth'], + sdv_datasets=['dataset1'], + additional_datasets_folder='extra', + limit_dataset_size=True, + compute_quality_score=True, + compute_diagnostic_score=True, + timeout=10, + output_destination='output_dir', + show_progress=True, + ) + + # Assert + mock__validate_output_destination.assert_called_once_with('output_dir') + mock_LocalResultsWriter.assert_called_once_with() + mock__generate_job_args_list.assert_called_once_with( + limit_dataset_size=True, + sdv_datasets=['dataset1'], + additional_datasets_folder='extra', + sdmetrics=None, + detailed_results_folder=None, + timeout=10, + output_destination='output_dir', + compute_quality_score=True, + compute_diagnostic_score=True, + compute_privacy_score=None, + synthesizers=expected_valid_synthesizers, + s3_client=None, + modality='multi_table', + ) + mock__write_metainfo_file.assert_called_once() + mock__run_jobs.assert_called_once_with( + multi_processing_config=None, + job_args_list=[job_args], + show_progress=True, + result_writer=mock_LocalResultsWriter.return_value, + ) + mock__update_metainfo_file.assert_called_once_with( + 'meta.yaml', + mock_LocalResultsWriter.return_value, + ) + pd.testing.assert_frame_equal(scores, fake_scores) + mock__import_and_validate_synthesizers.assert_called_once_with( + ['HMASynthesizer', 'MultiTableUniformSynthesizer'], ['CustomSynth'], 'multi_table' + ) + + +@patch('sdgym.benchmark._import_and_validate_synthesizers') +@patch('sdgym.benchmark._write_metainfo_file') +@patch('sdgym.benchmark._validate_output_destination') +def test_benchmark_multi_table_no_jobs( + mock__validate_output_destination, + mock__write_metainfo_file, + mock__import_and_validate_synthesizers, +): + """Test that benchmark_multi_table returns empty dataframe when there are no job args.""" + # Setup + empty_scores = pd.DataFrame({ + 'Synthesizer': [], + 'Dataset': [], + 'Dataset_Size_MB': [], + 'Train_Time': [], + 'Peak_Memory_MB': [], + 
'Synthesizer_Size_MB': [], + 'Sample_Time': [], + 'Evaluate_Time': [], + 'Adjusted_Total_Time': [], + 'Diagnostic_Score': [], + }) + + # Run + scores = benchmark_multi_table( + synthesizers=[], + custom_synthesizers=None, + sdv_datasets=None, + additional_datasets_folder=None, + limit_dataset_size=False, + compute_quality_score=False, + compute_diagnostic_score=True, + timeout=None, + output_destination=None, + show_progress=False, + ) + + # Assert + mock__validate_output_destination.assert_called_once_with(None) + mock__write_metainfo_file.assert_called_once() + pd.testing.assert_frame_equal(scores, empty_scores) diff --git a/tests/unit/test_dataset_explorer.py b/tests/unit/test_dataset_explorer.py index f788db9c..86288f76 100644 --- a/tests/unit/test_dataset_explorer.py +++ b/tests/unit/test_dataset_explorer.py @@ -169,13 +169,13 @@ def test_get_data_summary_with_multiple_tables(self): assert result['Total_Num_Rows'] == 5 assert result['Max_Num_Rows_Per_Table'] == 3 - def test__validate_output_filepath_valid(self): + def test__validate_output_filepath_valid(self, tmp_path): """Test the ``_validate_output_filepath`` method with valid CSV path.""" # Setup explorer = DatasetExplorer() # Run and Assert - explorer._validate_output_filepath('output.csv') + explorer._validate_output_filepath(tmp_path / 'output.csv') def test__validate_output_filepath_invalid(self): """Test the ``_validate_output_filepath`` method with invalid file path.""" diff --git a/tests/unit/test_result_writer.py b/tests/unit/test_result_writer.py index 1d4b7be8..b2ded100 100644 --- a/tests/unit/test_result_writer.py +++ b/tests/unit/test_result_writer.py @@ -1,3 +1,4 @@ +import zipfile from unittest.mock import Mock, patch import cloudpickle @@ -100,6 +101,37 @@ def test_write_yaml_append(self, tmp_path): expected_data = {**data1, **data2} assert loaded_data == expected_data + def test_write_zipped_dataframes(self, tmp_path): + """Test the `write_zipped_dataframes` method.""" + # Setup + base_path = tmp_path / 'sdgym_results' + base_path.mkdir(parents=True, exist_ok=True) + result_writer = LocalResultsWriter() + file_path = base_path / 'data.zip' + + data = { + 'table1': pd.DataFrame({'a': [1, 2], 'b': [3, 4]}), + 'table2': pd.DataFrame({'x': [5, 6], 'y': [7, 8]}), + } + + # Run + result_writer.write_zipped_dataframes(data, file_path) + + # Assert + assert file_path.exists() + + with zipfile.ZipFile(file_path, 'r') as zf: + # Check that all tables are present + names = zf.namelist() + assert 'table1.csv' in names + assert 'table2.csv' in names + + # Check each table content matches the original + for table_name, df in data.items(): + with zf.open(f'{table_name}.csv') as f: + loaded_df = pd.read_csv(f) + pd.testing.assert_frame_equal(df, loaded_df) + class TestS3ResultsWriter: def test__init__(self):