diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index cae221f8..42f17502 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -1832,7 +1832,7 @@ def benchmark_multi_table( output_destination=None, show_progress=False, ): - """Run the SDGym benchmark on single-table datasets. + """Run the SDGym benchmark on multi-table datasets. Args: synthesizers (list[string]): @@ -1844,8 +1844,8 @@ def benchmark_multi_table( or ``create_synthesizer_variant``). Defaults to ``None``. sdv_datasets (list[str] or ``None``): Names of the SDV demo datasets to use for the benchmark. Defaults to - ``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news, - covtype]``. Use ``None`` to disable using any sdv datasets. + ``[NBA, financial, Student_loan, Biodegradability, fake_hotels, restbase, + airbnb-simplified]``. Use ``None`` to disable using any sdv datasets. additional_datasets_folder (str or ``None``): The path to a folder (local or an S3 bucket). Datasets found in this folder are run in addition to the SDV datasets. If ``None``, no additional datasets are used. diff --git a/sdgym/datasets.py b/sdgym/datasets.py index 93ebea9b..87804191 100644 --- a/sdgym/datasets.py +++ b/sdgym/datasets.py @@ -34,6 +34,40 @@ def _get_bucket_name(bucket): return bucket[len(S3_PREFIX) :] if bucket.startswith(S3_PREFIX) else bucket +def _raise_dataset_not_found_error( + s3_client, + bucket_name, + dataset_name, + current_modality, + bucket, + modality, +): + display_name = dataset_name + if isinstance(dataset_name, Path): + display_name = dataset_name.name + + available_modalities = [] + for other_modality in MODALITIES: + if other_modality == current_modality: + continue + + other_prefix = f'{other_modality.lower()}/{dataset_name}/' + other_contents = _list_s3_bucket_contents(s3_client, bucket_name, other_prefix) + if other_contents: + available_modalities.append(other_modality) + + if available_modalities: + available_list = ', '.join(sorted(available_modalities)) + raise ValueError( + f"Dataset '{display_name}' not found in bucket '{bucket}' " + f"for modality '{modality}'. It is available under modality: '{available_list}'." + ) + else: + raise ValueError( + f"Dataset '{display_name}' not found in bucket '{bucket}' for modality '{modality}'." + ) + + def _download_dataset( modality, dataset_name, @@ -53,12 +87,8 @@ def _download_dataset( contents = _list_s3_bucket_contents(s3_client, bucket_name, prefix) if not contents: - display_name = dataset_name - if isinstance(dataset_name, Path): - display_name = dataset_name.name - - raise ValueError( - f"Dataset '{display_name}' not found in bucket '{bucket}' for modality '{modality}'." + _raise_dataset_not_found_error( + s3_client, bucket_name, dataset_name, modality, bucket, modality ) for obj in contents: diff --git a/sdgym/result_explorer/result_explorer.py b/sdgym/result_explorer/result_explorer.py index d3fdc7b7..46f3f46c 100644 --- a/sdgym/result_explorer/result_explorer.py +++ b/sdgym/result_explorer/result_explorer.py @@ -2,10 +2,14 @@ import os -from sdgym.benchmark import DEFAULT_SINGLE_TABLE_DATASETS from sdgym.datasets import load_dataset -from sdgym.result_explorer.result_handler import LocalResultsHandler, S3ResultsHandler +from sdgym.result_explorer.result_handler import ( + SYNTHESIZER_BASELINE, + LocalResultsHandler, + S3ResultsHandler, +) from sdgym.s3 import _get_s3_client, is_s3_path +from sdgym.synthesizers.base import _validate_modality def _validate_local_path(path): @@ -14,20 +18,51 @@ def _validate_local_path(path): raise ValueError(f"The provided path '{path}' is not a valid local directory.") +_BASELINE_BY_MODALITY = { + 'single_table': SYNTHESIZER_BASELINE, + 'multi_table': 'MultiTableUniformSynthesizer', +} + + +def _resolve_effective_path(path, modality): + """Append the modality folder to the given base path if provided.""" + # Avoid double-appending if already included + if str(path).rstrip('/').endswith(('/' + modality, modality)): + return path + + if is_s3_path(path): + return path.rstrip('/') + '/' + modality + + return os.path.join(path, modality) + + class ResultsExplorer: """Explorer for SDGym benchmark results, supporting both local and S3 storage.""" - def __init__(self, path, aws_access_key_id=None, aws_secret_access_key=None): + def _create_results_handler(self, original_path, effective_path): + """Create the appropriate results handler for local or S3 storage.""" + baseline_synthesizer = _BASELINE_BY_MODALITY.get(self.modality, SYNTHESIZER_BASELINE) + if is_s3_path(original_path): + s3_client = _get_s3_client( + original_path, self.aws_access_key_id, self.aws_secret_access_key + ) + return S3ResultsHandler( + effective_path, s3_client, baseline_synthesizer=baseline_synthesizer + ) + + _validate_local_path(effective_path) + return LocalResultsHandler(effective_path, baseline_synthesizer=baseline_synthesizer) + + def __init__( + self, path, modality='single_table', aws_access_key_id=None, aws_secret_access_key=None + ): self.path = path + _validate_modality(modality) + self.modality = modality.lower() self.aws_access_key_id = aws_access_key_id self.aws_secret_access_key = aws_secret_access_key - - if is_s3_path(path): - s3_client = _get_s3_client(path, aws_access_key_id, aws_secret_access_key) - self._handler = S3ResultsHandler(path, s3_client) - else: - _validate_local_path(path) - self._handler = LocalResultsHandler(path) + effective_path = _resolve_effective_path(path, self.modality) + self._handler = self._create_results_handler(path, effective_path) def list(self): """List all runs available in the results directory.""" @@ -37,7 +72,11 @@ def _get_file_path(self, results_folder_name, dataset_name, synthesizer_name, fi """Validate access to the synthesizer or synthetic data file.""" end_filename = f'{synthesizer_name}' if file_type == 'synthetic_data': - end_filename += '_synthetic_data.csv' + # Multi-table synthetic data is zipped (multiple CSVs), single table is CSV + if self.modality == 'multi_table': + end_filename += '_synthetic_data.zip' + else: + end_filename += '_synthetic_data.csv' elif file_type == 'synthesizer': end_filename += '.pkl' @@ -62,14 +101,8 @@ def load_synthetic_data(self, results_folder_name, dataset_name, synthesizer_nam def load_real_data(self, dataset_name): """Load the real data for a given dataset.""" - if dataset_name not in DEFAULT_SINGLE_TABLE_DATASETS: - raise ValueError( - f"Dataset '{dataset_name}' is not a SDGym dataset. " - 'Please provide a valid dataset name.' - ) - data, _ = load_dataset( - modality='single_table', + modality=self.modality, dataset=dataset_name, aws_access_key_id=self.aws_access_key_id, aws_secret_access_key=self.aws_secret_access_key, diff --git a/sdgym/result_explorer/result_handler.py b/sdgym/result_explorer/result_handler.py index 72f48da0..ca1c4007 100644 --- a/sdgym/result_explorer/result_handler.py +++ b/sdgym/result_explorer/result_handler.py @@ -11,6 +11,8 @@ import yaml from botocore.exceptions import ClientError +from sdgym._dataset_utils import _read_zipped_data + SYNTHESIZER_BASELINE = 'GaussianCopulaSynthesizer' RESULTS_FOLDER_PREFIX = 'SDGym_results_' metainfo_PREFIX = 'metainfo' @@ -22,6 +24,9 @@ class ResultsHandler(ABC): """Abstract base class for handling results storage and retrieval.""" + def __init__(self, baseline_synthesizer=SYNTHESIZER_BASELINE): + self.baseline_synthesizer = baseline_synthesizer or SYNTHESIZER_BASELINE + @abstractmethod def list(self): """List all runs in the results directory.""" @@ -59,7 +64,8 @@ def _compute_wins(self, result): result['Win'] = 0 for dataset in datasets: score_baseline = result.loc[ - (result['Synthesizer'] == SYNTHESIZER_BASELINE) & (result['Dataset'] == dataset) + (result['Synthesizer'] == self.baseline_synthesizer) + & (result['Dataset'] == dataset) ]['Quality_Score'].to_numpy() if score_baseline.size == 0: continue @@ -84,7 +90,7 @@ def _get_summarize_table(self, folder_to_results, folder_infos): f' - # datasets: {folder_infos[folder]["# datasets"]}' f' - sdgym version: {folder_infos[folder]["sdgym_version"]}' ) - results = results.loc[results['Synthesizer'] != SYNTHESIZER_BASELINE] + results = results.loc[results['Synthesizer'] != self.baseline_synthesizer] column_data = results.groupby(['Synthesizer'])['Win'].sum() columns.append((date_obj, column_name, column_data)) @@ -107,9 +113,11 @@ def _get_column_name_infos(self, folder_to_results): continue metainfo_info = self._load_yaml_file(folder, yaml_files[0]) - num_datasets = results.loc[ - results['Synthesizer'] == SYNTHESIZER_BASELINE, 'Dataset' - ].nunique() + baseline_mask = results['Synthesizer'] == self.baseline_synthesizer + if baseline_mask.any(): + num_datasets = results.loc[baseline_mask, 'Dataset'].nunique() + else: + num_datasets = results['Dataset'].nunique() folder_to_info[folder] = { 'date': metainfo_info.get('starting_date')[:NUM_DIGITS_DATE], 'sdgym_version': metainfo_info.get('sdgym_version'), @@ -236,7 +244,8 @@ def all_runs_complete(self, folder_name): class LocalResultsHandler(ResultsHandler): """Results handler for local filesystem.""" - def __init__(self, base_path): + def __init__(self, base_path, baseline_synthesizer=SYNTHESIZER_BASELINE): + super().__init__(baseline_synthesizer=baseline_synthesizer) self.base_path = base_path def list(self): @@ -262,8 +271,12 @@ def load_synthesizer(self, file_path): return cloudpickle.load(f) def load_synthetic_data(self, file_path): - """Load synthetic data from a CSV file.""" - return pd.read_csv(os.path.join(self.base_path, file_path)) + """Load synthetic data from a CSV or ZIP file.""" + full_path = os.path.join(self.base_path, file_path) + if full_path.endswith('.zip'): + return _read_zipped_data(full_path, modality='multi_table') + + return pd.read_csv(full_path) def _get_results_files(self, folder_name, prefix, suffix): return [ @@ -287,7 +300,8 @@ def _load_yaml_file(self, folder_name, file_name): class S3ResultsHandler(ResultsHandler): """Results handler for AWS S3 storage.""" - def __init__(self, path, s3_client): + def __init__(self, path, s3_client, baseline_synthesizer=SYNTHESIZER_BASELINE): + super().__init__(baseline_synthesizer=baseline_synthesizer) self.s3_client = s3_client self.bucket_name = path.split('/')[2] self.prefix = '/'.join(path.split('/')[3:]).rstrip('/') + '/' @@ -374,10 +388,13 @@ def load_synthesizer(self, file_path): def load_synthetic_data(self, file_path): """Load synthetic data from S3.""" - response = self.s3_client.get_object( - Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}' - ) - return pd.read_csv(io.BytesIO(response['Body'].read())) + key = f'{self.prefix}{file_path}' + response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key) + body = response['Body'].read() + if file_path.endswith('.zip'): + return _read_zipped_data(io.BytesIO(body), modality='multi_table') + + return pd.read_csv(io.BytesIO(body)) def _get_results_files(self, folder_name, prefix, suffix): s3_prefix = f'{self.prefix}{folder_name}/' @@ -396,8 +413,8 @@ def _get_results(self, folder_name, file_names): for file_name in file_names: s3_key = f'{self.prefix}{folder_name}/{file_name}' response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key) - df = pd.read_csv(io.BytesIO(response['Body'].read())) - results.append(df) + result_df = pd.read_csv(io.BytesIO(response['Body'].read())) + results.append(result_df) return results diff --git a/sdgym/run_benchmark/upload_benchmark_results.py b/sdgym/run_benchmark/upload_benchmark_results.py index 8d1c1ea6..29d29343 100644 --- a/sdgym/run_benchmark/upload_benchmark_results.py +++ b/sdgym/run_benchmark/upload_benchmark_results.py @@ -116,6 +116,7 @@ def upload_results( run_date = folder_infos['date'] result_explorer = ResultsExplorer( OUTPUT_DESTINATION_AWS, + modality='single_table', aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, ) diff --git a/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/HMASynthesizer/HMASynthesizer.pkl b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/HMASynthesizer/HMASynthesizer.pkl new file mode 100644 index 00000000..41be87a9 Binary files /dev/null and b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/HMASynthesizer/HMASynthesizer.pkl differ diff --git a/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/HMASynthesizer/HMASynthesizer_benchmark_result.csv b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/HMASynthesizer/HMASynthesizer_benchmark_result.csv new file mode 100644 index 00000000..755a61fe --- /dev/null +++ b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/HMASynthesizer/HMASynthesizer_benchmark_result.csv @@ -0,0 +1,2 @@ +Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score +HMASynthesizer,fake_hotels,0.048698,22.852492,33.315142,0.988611,2.723049,0.082362,1.0,0.7353482911012336 diff --git a/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/HMASynthesizer/HMASynthesizer_synthetic_data.zip b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/HMASynthesizer/HMASynthesizer_synthetic_data.zip new file mode 100644 index 00000000..780d3538 Binary files /dev/null and b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/HMASynthesizer/HMASynthesizer_synthetic_data.zip differ diff --git a/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/MultiTableUniformSynthesizer/MultiTableUniformSynthesizer.pkl b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/MultiTableUniformSynthesizer/MultiTableUniformSynthesizer.pkl new file mode 100644 index 00000000..7ad1d477 Binary files /dev/null and b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/MultiTableUniformSynthesizer/MultiTableUniformSynthesizer.pkl differ diff --git a/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/MultiTableUniformSynthesizer/MultiTableUniformSynthesizer_benchmark_result.csv b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/MultiTableUniformSynthesizer/MultiTableUniformSynthesizer_benchmark_result.csv new file mode 100644 index 00000000..47479e65 --- /dev/null +++ b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/MultiTableUniformSynthesizer/MultiTableUniformSynthesizer_benchmark_result.csv @@ -0,0 +1,2 @@ +Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score +MultiTableUniformSynthesizer,fake_hotels,0.048698,0.201284,0.851853,0.109464,0.02749,0.081629,0.9122678149273894,0.5962941240006595 diff --git a/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/MultiTableUniformSynthesizer/MultiTableUniformSynthesizer_synthetic_data.zip b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/MultiTableUniformSynthesizer/MultiTableUniformSynthesizer_synthetic_data.zip new file mode 100644 index 00000000..e438845d Binary files /dev/null and b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/fake_hotels_12_02_2025/MultiTableUniformSynthesizer/MultiTableUniformSynthesizer_synthetic_data.zip differ diff --git a/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/metainfo.yaml b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/metainfo.yaml new file mode 100644 index 00000000..742ae309 --- /dev/null +++ b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/metainfo.yaml @@ -0,0 +1,11 @@ +completed_date: 12_02_2025 08:28:47 +jobs: +- - fake_hotels + - MultiTableUniformSynthesizer +- - fake_hotels + - HMASynthesizer +modality: multi_table +run_id: run_12_02_2025_0 +sdgym_version: 0.11.2.dev0 +sdv_version: 1.28.0 +starting_date: 12_02_2025 08:28:21 diff --git a/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/results.csv b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/results.csv new file mode 100644 index 00000000..b1308f82 --- /dev/null +++ b/tests/integration/result_explorer/_benchmark_results/multi_table/SDGym_results_12_02_2025/results.csv @@ -0,0 +1,3 @@ +Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score,Adjusted_Total_Time,Adjusted_Quality_Score +MultiTableUniformSynthesizer,fake_hotels,0.048698,0.201284,0.851853,0.109464,0.02749,0.081629,0.9122678149273894,0.5962941240006595,0.430058,0.5962941240006595 +HMASynthesizer,fake_hotels,0.048698,22.852492,33.315142,0.988611,2.723049,0.082362,1.0,0.7353482911012336,25.776825000000002,0.7353482911012336 diff --git a/tests/integration/result_explorer/_benchmark_results/SDGym_results_04_05_2024/metainfo.yaml b/tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_04_05_2024/metainfo.yaml similarity index 100% rename from tests/integration/result_explorer/_benchmark_results/SDGym_results_04_05_2024/metainfo.yaml rename to tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_04_05_2024/metainfo.yaml diff --git a/tests/integration/result_explorer/_benchmark_results/SDGym_results_04_05_2024/results.csv b/tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_04_05_2024/results.csv similarity index 100% rename from tests/integration/result_explorer/_benchmark_results/SDGym_results_04_05_2024/results.csv rename to tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_04_05_2024/results.csv diff --git a/tests/integration/result_explorer/_benchmark_results/SDGym_results_05_10_2024/metainfo.yaml b/tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_05_10_2024/metainfo.yaml similarity index 100% rename from tests/integration/result_explorer/_benchmark_results/SDGym_results_05_10_2024/metainfo.yaml rename to tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_05_10_2024/metainfo.yaml diff --git a/tests/integration/result_explorer/_benchmark_results/SDGym_results_05_10_2024/results.csv b/tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_05_10_2024/results.csv similarity index 100% rename from tests/integration/result_explorer/_benchmark_results/SDGym_results_05_10_2024/results.csv rename to tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_05_10_2024/results.csv diff --git a/tests/integration/result_explorer/_benchmark_results/SDGym_results_10_11_2024/metainfo.yaml b/tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_10_11_2024/metainfo.yaml similarity index 100% rename from tests/integration/result_explorer/_benchmark_results/SDGym_results_10_11_2024/metainfo.yaml rename to tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_10_11_2024/metainfo.yaml diff --git a/tests/integration/result_explorer/_benchmark_results/SDGym_results_10_11_2024/results.csv b/tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_10_11_2024/results.csv similarity index 100% rename from tests/integration/result_explorer/_benchmark_results/SDGym_results_10_11_2024/results.csv rename to tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_10_11_2024/results.csv diff --git a/tests/integration/result_explorer/_benchmark_results/SDGym_results_12_17_2024/metainfo.yaml b/tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_12_17_2024/metainfo.yaml similarity index 100% rename from tests/integration/result_explorer/_benchmark_results/SDGym_results_12_17_2024/metainfo.yaml rename to tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_12_17_2024/metainfo.yaml diff --git a/tests/integration/result_explorer/_benchmark_results/SDGym_results_12_17_2024/results.csv b/tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_12_17_2024/results.csv similarity index 100% rename from tests/integration/result_explorer/_benchmark_results/SDGym_results_12_17_2024/results.csv rename to tests/integration/result_explorer/_benchmark_results/single_table/SDGym_results_12_17_2024/results.csv diff --git a/tests/integration/result_explorer/test_result_explorer.py b/tests/integration/result_explorer/test_result_explorer.py index 4a9e3493..609a7329 100644 --- a/tests/integration/result_explorer/test_result_explorer.py +++ b/tests/integration/result_explorer/test_result_explorer.py @@ -1,3 +1,4 @@ +import shutil import time import pandas as pd @@ -20,7 +21,7 @@ def test_end_to_end_local(tmp_path): today = time.strftime('%m_%d_%Y') # Run - result_explorer = ResultsExplorer(str(result_explorer_path)) + result_explorer = ResultsExplorer(str(result_explorer_path), modality='single_table') runs = result_explorer.list() results = result_explorer.load_results(runs[0]) metainfo = result_explorer.load_metainfo(runs[0]) @@ -65,7 +66,7 @@ def test_summarize(): """Test the `summarize` method.""" # Setup output_destination = 'tests/integration/result_explorer/_benchmark_results/' - result_explorer = ResultsExplorer(output_destination) + result_explorer = ResultsExplorer(output_destination, modality='single_table') # Run summary, results = result_explorer.summarize('SDGym_results_10_11_2024') @@ -79,7 +80,7 @@ def test_summarize(): }) expected_results = ( pd.read_csv( - 'tests/integration/result_explorer/_benchmark_results/' + 'tests/integration/result_explorer/_benchmark_results/single_table/' 'SDGym_results_10_11_2024/results.csv', ) .sort_values(by=['Dataset', 'Synthesizer']) @@ -88,3 +89,62 @@ def test_summarize(): expected_results['Win'] = expected_results['Win'].astype('int64') pd.testing.assert_frame_equal(summary, expected_summary) pd.testing.assert_frame_equal(results, expected_results) + + +def test_summarize_multi_table(): + """Test summarize works under the multi_table subfolder.""" + # Setup + output_destination = 'tests/integration/result_explorer/_benchmark_results/' + result_explorer = ResultsExplorer(output_destination, modality='multi_table') + + # Run + summary, results = result_explorer.summarize('SDGym_results_12_02_2025') + + # Assert + expected_summary = pd.DataFrame({ + 'Synthesizer': ['HMASynthesizer'], + '12_02_2025 - # datasets: 1 - sdgym version: 0.11.2.dev0': [1], + }) + expected_results = ( + pd.read_csv( + 'tests/integration/result_explorer/_benchmark_results/multi_table/' + 'SDGym_results_12_02_2025/results.csv', + ) + .sort_values(by=['Dataset', 'Synthesizer']) + .reset_index(drop=True) + ) + expected_results['Win'] = ( + expected_results['Synthesizer'] != 'MultiTableUniformSynthesizer' + ).astype('int64') + pd.testing.assert_frame_equal(summary, expected_summary) + pd.testing.assert_frame_equal(results, expected_results) + + +def test_list_and_load_results_multi_table(tmp_path): + """Test listing and loading results under multi_table subfolder.""" + # Setup + run_folder = 'SDGym_results_12_02_2025' + src_root = 'tests/integration/result_explorer/_benchmark_results/multi_table/' + run_folder + dst_root = tmp_path / 'benchmark_output' / 'multi_table' / run_folder + shutil.copytree(src_root, dst_root) + + explorer = ResultsExplorer(str(tmp_path / 'benchmark_output'), modality='multi_table') + + # Run + runs = explorer.list() + assert runs == [run_folder] + loaded_results = ( + explorer.load_results(runs[0]) + .sort_values(by=['Dataset', 'Synthesizer']) + .reset_index(drop=True) + ) + metainfo = explorer.load_metainfo(runs[0]) + + # Assert + expected_results = ( + pd.read_csv(dst_root / 'results.csv') + .sort_values(by=['Dataset', 'Synthesizer']) + .reset_index(drop=True) + ) + pd.testing.assert_frame_equal(loaded_results, expected_results) + assert isinstance(metainfo, dict) and len(metainfo) >= 1 diff --git a/tests/unit/result_explorer/test_result_explorer.py b/tests/unit/result_explorer/test_result_explorer.py index ded60ae6..e703248d 100644 --- a/tests/unit/result_explorer/test_result_explorer.py +++ b/tests/unit/result_explorer/test_result_explorer.py @@ -1,4 +1,6 @@ +import os import re +import shutil from unittest.mock import Mock, patch import pandas as pd @@ -36,11 +38,12 @@ def test__init__local(self, mock_validate_local_path, mock_is_s3_path): path = 'local_results_folder' # Run - result_explorer = ResultsExplorer(path) + result_explorer = ResultsExplorer(path, modality='single_table') # Assert - mock_validate_local_path.assert_called_once_with(path) - mock_is_s3_path.assert_called_once_with(path) + expected_path = os.path.join(path, 'single_table') + mock_validate_local_path.assert_called_once_with(expected_path) + mock_is_s3_path.assert_called_with(path) assert isinstance(result_explorer._handler, LocalResultsHandler) assert result_explorer.path == path assert result_explorer.aws_access_key_id is None @@ -59,24 +62,45 @@ def test__init__s3(self, mock_is_s3_path, mock_get_s3_client): mock_get_s3_client.return_value = s3_client # Run - result_explorer = ResultsExplorer(path, aws_access_key_id, aws_secret_access_key) + result_explorer = ResultsExplorer( + path, + modality='single_table', + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) # Assert - mock_is_s3_path.assert_called_once_with(path) + mock_is_s3_path.assert_called_with(path) mock_get_s3_client.assert_called_once_with(path, aws_access_key_id, aws_secret_access_key) assert result_explorer.path == path assert result_explorer.aws_access_key_id == aws_access_key_id assert result_explorer.aws_secret_access_key == aws_secret_access_key assert isinstance(result_explorer._handler, S3ResultsHandler) + def test_list_with_modality_local(self, tmp_path): + """Test the `list` method respects the modality subfolder (local).""" + # Setup + base = tmp_path / 'results' + (base / 'unscoped_run').mkdir(parents=True) + (base / 'multi_table' / 'run_mt1').mkdir(parents=True) + (base / 'multi_table' / 'run_mt2').mkdir(parents=True) + + result_explorer = ResultsExplorer(str(base), modality='multi_table') + + # Run + runs = result_explorer.list() + + # Assert + assert set(runs) == {'run_mt1', 'run_mt2'} + def test_list_local(self, tmp_path): """Test the `list` method with a local path""" # Setup - path = tmp_path / 'results' - path.mkdir() + path = tmp_path / 'results' / 'single_table' + path.mkdir(parents=True) (path / 'run1').mkdir() (path / 'run2').mkdir() - result_explorer = ResultsExplorer(str(path)) + result_explorer = ResultsExplorer(str(path), modality='single_table') # Run runs = result_explorer.list() @@ -87,11 +111,11 @@ def test_list_local(self, tmp_path): def test_list(self, tmp_path): """Test the `list` method with an S3 path""" # Setup - path = tmp_path / 'results' - path.mkdir() + path = tmp_path / 'results' / 'single_table' + path.mkdir(parents=True) (path / 'run1').mkdir() (path / 'run2').mkdir() - result_explorer = ResultsExplorer(str(path)) + result_explorer = ResultsExplorer(str(path), modality='single_table') result_explorer._handler = Mock() result_explorer._handler.list.return_value = ['run1', 'run2'] @@ -129,10 +153,34 @@ def test__get_file_path(self): ) assert file_path == expected_filepath + def test__get_file_path_multi_table_synthetic_data(self, tmp_path): + """Test `_get_file_path` returns .zip for multi_table synthetic data.""" + base = tmp_path / 'results' + multi_table_dir = base / 'multi_table' + multi_table_dir.mkdir(parents=True, exist_ok=True) + explorer = ResultsExplorer(str(multi_table_dir), modality='multi_table') + try: + explorer._handler = Mock() + explorer._handler.get_file_path.return_value = 'irrelevant' + explorer._get_file_path( + results_folder_name='results_folder_07_07_2025', + dataset_name='my_dataset', + synthesizer_name='my_synthesizer', + file_type='synthetic_data', + ) + explorer._handler.get_file_path.assert_called_once_with( + ['results_folder_07_07_2025', 'my_dataset_07_07_2025', 'my_synthesizer'], + 'my_synthesizer_synthetic_data.zip', + ) + finally: + shutil.rmtree(multi_table_dir) + def test_load_synthesizer(self, tmp_path): """Test `load_synthesizer` method.""" # Setup - explorer = ResultsExplorer(str(tmp_path)) + path = tmp_path / 'single_table' + path.mkdir(parents=True) + explorer = ResultsExplorer(str(path), modality='single_table') explorer._handler = Mock() explorer._handler.load_synthesizer = Mock( return_value=GaussianCopulaSynthesizer(Metadata()) @@ -155,7 +203,9 @@ def test_load_synthesizer(self, tmp_path): def test_load_synthetic_data(self, tmp_path): # Setup - explorer = ResultsExplorer(str(tmp_path)) + path = tmp_path / 'single_table' + path.mkdir(parents=True) + explorer = ResultsExplorer(str(path), modality='single_table') explorer._handler = Mock() data = pd.DataFrame({'column1': [1, 2], 'column2': [3, 4]}) explorer._handler.load_synthetic_data = Mock(return_value=data) @@ -182,7 +232,9 @@ def test_load_real_data(self, mock_load_dataset, tmp_path): dataset_name = 'adult' expected_data = pd.DataFrame({'col1': [1, 2], 'col2': [3, 4]}) mock_load_dataset.return_value = (expected_data, None) - result_explorer = ResultsExplorer(tmp_path) + path = tmp_path / 'single_table' + path.mkdir(parents=True) + result_explorer = ResultsExplorer(str(path), modality='single_table') # Run real_data = result_explorer.load_real_data(dataset_name) @@ -196,13 +248,41 @@ def test_load_real_data(self, mock_load_dataset, tmp_path): ) pd.testing.assert_frame_equal(real_data, expected_data) + @patch('sdgym.result_explorer.result_explorer.load_dataset') + def test_load_real_data_multi_table(self, mock_load_dataset, tmp_path): + """Test `load_real_data` for multi_table modality calls load_dataset correctly.""" + dataset_name = 'synthea' + expected_data = {'patients': pd.DataFrame({'id': [1]})} + mock_load_dataset.return_value = (expected_data, None) + multi_table_dir = tmp_path / 'multi_table' + multi_table_dir.mkdir(parents=True, exist_ok=True) + result_explorer = ResultsExplorer(tmp_path, modality='multi_table') + + try: + # Run + real_data = result_explorer.load_real_data(dataset_name) + + # Assert + mock_load_dataset.assert_called_once_with( + modality='multi_table', + dataset='synthea', + aws_access_key_id=None, + aws_secret_access_key=None, + ) + assert real_data == expected_data + finally: + shutil.rmtree(multi_table_dir) + def test_load_real_data_invalid_dataset(self, tmp_path): """Test `load_real_data` method with an invalid dataset.""" # Setup dataset_name = 'invalid_dataset' - result_explorer = ResultsExplorer(tmp_path) - expected_error_message = re.escape( - f"Dataset '{dataset_name}' is not a SDGym dataset. Please provide a valid dataset name." + path = tmp_path / 'single_table' + path.mkdir(parents=True) + result_explorer = ResultsExplorer(str(path), modality='single_table') + expected_error_message = ( + "Dataset 'invalid_dataset' not found in bucket 's3://sdv-datasets-public' " + "for modality 'single_table'." ) # Run and Assert @@ -212,9 +292,10 @@ def test_load_real_data_invalid_dataset(self, tmp_path): def test_summarize(self, tmp_path): """Test the `summarize` method.""" # Setup - output_destination = str(tmp_path / 'benchmark_output') - (tmp_path / 'benchmark_output' / 'SDGym_results_07_07_2025').mkdir(parents=True) - result_explorer = ResultsExplorer(output_destination) + output_dir = tmp_path / 'benchmark_output' / 'single_table' + output_dir.mkdir(parents=True) + (output_dir / 'SDGym_results_07_07_2025').mkdir(parents=True) + result_explorer = ResultsExplorer(str(output_dir), modality='single_table') result_explorer._handler = Mock() results = pd.DataFrame({ 'Synthesizer': ['CTGANSynthesizer', 'CopulaGANSynthesizer', 'TVAESynthesizer'], @@ -235,9 +316,10 @@ def test_summarize(self, tmp_path): def test_load_results(self, tmp_path): """Test the `load_results` method.""" # Setup - output_destination = str(tmp_path / 'benchmark_output') - (tmp_path / 'benchmark_output' / 'SDGym_results_07_07_2025').mkdir(parents=True) - result_explorer = ResultsExplorer(output_destination) + output_dir = tmp_path / 'benchmark_output' / 'single_table' + output_dir.mkdir(parents=True) + (output_dir / 'SDGym_results_07_07_2025').mkdir(parents=True) + result_explorer = ResultsExplorer(str(output_dir), modality='single_table') result_explorer._handler = Mock() results = pd.DataFrame({ 'Dataset': ['A', 'B'], @@ -256,9 +338,10 @@ def test_load_results(self, tmp_path): def test_load_metainfo(self, tmp_path): """Test the `load_metainfo` method.""" # Setup - output_destination = str(tmp_path / 'benchmark_output') - (tmp_path / 'benchmark_output' / 'SDGym_results_07_07_2025').mkdir(parents=True) - result_explorer = ResultsExplorer(output_destination) + output_dir = tmp_path / 'benchmark_output' / 'single_table' + output_dir.mkdir(parents=True) + (output_dir / 'SDGym_results_07_07_2025').mkdir(parents=True) + result_explorer = ResultsExplorer(str(output_dir), modality='single_table') result_explorer._handler = Mock() metainfo = {'synthesizer_versions': {'Synth1': '1.0.0', 'Synth2': '2.0.0'}} result_explorer._handler.load_metainfo = Mock(return_value=metainfo) diff --git a/tests/unit/result_explorer/test_result_handler.py b/tests/unit/result_explorer/test_result_handler.py index 2a7c195d..5d0b0590 100644 --- a/tests/unit/result_explorer/test_result_handler.py +++ b/tests/unit/result_explorer/test_result_handler.py @@ -1,3 +1,4 @@ +import io import os import pickle import re @@ -70,6 +71,7 @@ def test__get_summarize_table(self): } } handler = Mock() + handler.baseline_synthesizer = 'GaussianCopulaSynthesizer' # Run result = ResultsHandler._get_summarize_table(handler, folder_to_results, folder_infos) @@ -97,6 +99,7 @@ def test_get_column_name_infos(self): handler = Mock() handler._get_results_files = Mock(return_value=['run_config.yaml']) handler._load_yaml_file = Mock(return_value=yaml_content) + handler.baseline_synthesizer = 'GaussianCopulaSynthesizer' # Run info = ResultsHandler._get_column_name_infos(handler, folder_to_results) @@ -258,6 +261,24 @@ def test_load_metainfo(self): class TestLocalResultsHandler: """Unit tests for the LocalResultsHandler class.""" + def test__init__sets_base_path_and_default_baseline(self, tmp_path): + """Test it initializes base_path and default baseline.""" + # Run + handler = LocalResultsHandler(str(tmp_path)) + + # Assert + assert handler.base_path == str(tmp_path) + assert handler.baseline_synthesizer == 'GaussianCopulaSynthesizer' + + def test__init__supports_baseline_override(self, tmp_path): + """Test it allows overriding baseline synthesizer.""" + # Run + handler = LocalResultsHandler(str(tmp_path), baseline_synthesizer='CustomBaseline') + + # Assert + assert handler.base_path == str(tmp_path) + assert handler.baseline_synthesizer == 'CustomBaseline' + def test_list(self, tmp_path): """Test the `list` method""" # Setup @@ -324,6 +345,31 @@ def test_load_synthesizer(self, tmp_path): assert loaded_synthesizer is not None assert isinstance(loaded_synthesizer, GaussianCopulaSynthesizer) + def test_load_synthetic_data_zip(self, tmp_path): + """Test the `load_synthetic_data` method for zipped multi-table data (local).""" + # Setup + base = tmp_path / 'results' + data_dir = base / 'SDGym_results_07_07_2025' / 'dataset_07_07_2025' / 'Synth' + data_dir.mkdir(parents=True) + + # Create a zip with two csvs + import zipfile + + zip_path = data_dir / 'Synth_synthetic_data.zip' + with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr('table1.csv', 'a,b\n1,2\n') + zf.writestr('table2.csv', 'x,y\n3,4\n') + + result_handler = LocalResultsHandler(str(base)) + + # Run + tables = result_handler.load_synthetic_data(str(zip_path)) + + # Assert + assert set(tables.keys()) == {'table1', 'table2'} + pd.testing.assert_frame_equal(tables['table1'], pd.DataFrame({'a': [1], 'b': [2]})) + pd.testing.assert_frame_equal(tables['table2'], pd.DataFrame({'x': [3], 'y': [4]})) + @patch('os.path.exists') @patch('os.path.isfile') def test_get_file_path_local(self, mock_isfile, mock_exists): @@ -390,9 +436,7 @@ def test_get_file_path_local_error(self, mock_isfile, mock_exists): class TestS3ResultsHandler: """Unit tests for the S3ResultsHandler class.""" - def test__init__( - self, - ): + def test__init__(self): """Test the `__init__` method.""" # Setup path = 's3://my-bucket/prefix' @@ -404,6 +448,21 @@ def test__init__( assert result_handler.s3_client == 's3_client' assert result_handler.bucket_name == 'my-bucket' assert result_handler.prefix == 'prefix/' + assert result_handler.baseline_synthesizer == 'GaussianCopulaSynthesizer' + + def test__init__supports_baseline_override(self): + """Test it allows overriding baseline synthesizer.""" + # Run + s3_client = Mock() + handler = S3ResultsHandler( + 's3://bkt/prefix', s3_client, baseline_synthesizer='CustomBaseline' + ) + + # Assert + assert handler.baseline_synthesizer == 'CustomBaseline' + assert handler.s3_client == s3_client + assert handler.bucket_name == 'bkt' + assert handler.prefix == 'prefix/' def test_list(self): """Test the `list` method.""" @@ -464,6 +523,34 @@ def test_load_synthesizer(self): Bucket='my-bucket', Key='prefix/synthesizer.pkl' ) + def test_load_synthetic_data_zip(self): + """Test the `load_synthetic_data` method for zipped multi-table data (S3).""" + # Setup + import zipfile + + buffer = io.BytesIO() + with zipfile.ZipFile(buffer, 'w', compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr('customers.csv', 'id,age\n1,30\n') + zf.writestr('transactions.csv', 'id,amount\n1,100\n') + buffer.seek(0) + + mock_s3_client = Mock() + mock_s3_client.get_object.return_value = {'Body': Mock(read=lambda: buffer.getvalue())} + result_handler = S3ResultsHandler('s3://my-bucket/prefix', mock_s3_client) + + # Run + tables = result_handler.load_synthetic_data('some/path.zip') + + # Assert + assert set(tables.keys()) == {'customers', 'transactions'} + pd.testing.assert_frame_equal(tables['customers'], pd.DataFrame({'id': [1], 'age': [30]})) + pd.testing.assert_frame_equal( + tables['transactions'], pd.DataFrame({'id': [1], 'amount': [100]}) + ) + mock_s3_client.get_object.assert_called_once_with( + Bucket='my-bucket', Key='prefix/some/path.zip' + ) + def test_get_file_path_s3(self): """Test `get_file_path` for S3 path when folders and file exist.""" # Setup diff --git a/tests/unit/run_benchmark/test_upload_benchmark_result.py b/tests/unit/run_benchmark/test_upload_benchmark_result.py index f9982968..3c776ebd 100644 --- a/tests/unit/run_benchmark/test_upload_benchmark_result.py +++ b/tests/unit/run_benchmark/test_upload_benchmark_result.py @@ -211,6 +211,7 @@ def test_upload_results( ) mock_sdgym_results_explorer.assert_called_once_with( mock_output_destination_aws, + modality='single_table', aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, ) @@ -260,6 +261,7 @@ def test_upload_results_not_all_runs_complete( mock_logger.warning.assert_called_once_with(f'Run {run_name} is not complete yet. Exiting.') mock_sdgym_results_explorer.assert_called_once_with( mock_output_destination_aws, + modality='single_table', aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, ) diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py index 331b4f7a..85fbd3ed 100644 --- a/tests/unit/test_datasets.py +++ b/tests/unit/test_datasets.py @@ -80,10 +80,13 @@ def test__download_dataset_not_found(list_mock, bucket_name_mock, s3_client_mock # Assert s3_client_mock.assert_called_once() - bucket_name_mock.assert_called_once_with(bucket) - list_mock.assert_called_once_with( - s3_client_mock.return_value, 'fake-bucket', 'single_table/missing_dataset/' - ) + bucket_name_mock.assert_called_with(bucket) + expected_calls = [ + call(s3_client_mock.return_value, 'fake-bucket', 'single_table/missing_dataset/'), + call(s3_client_mock.return_value, 'fake-bucket', 'sequential/missing_dataset/'), + call(s3_client_mock.return_value, 'fake-bucket', 'multi_table/missing_dataset/'), + ] + list_mock.assert_has_calls(expected_calls, any_order=True) @patch('sdgym.datasets.get_s3_client')