Merged
6 changes: 3 additions & 3 deletions sdgym/benchmark.py
@@ -1832,7 +1832,7 @@ def benchmark_multi_table(
output_destination=None,
show_progress=False,
):
"""Run the SDGym benchmark on single-table datasets.
"""Run the SDGym benchmark on multi-table datasets.

Args:
synthesizers (list[string]):
@@ -1844,8 +1844,8 @@ def benchmark_multi_table(
or ``create_synthesizer_variant``). Defaults to ``None``.
sdv_datasets (list[str] or ``None``):
Names of the SDV demo datasets to use for the benchmark. Defaults to
``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news,
covtype]``. Use ``None`` to disable using any sdv datasets.
``[NBA, financial, Student_loan, Biodegradability, fake_hotels, restbase,
airbnb-simplified]``. Use ``None`` to disable using any sdv datasets.
additional_datasets_folder (str or ``None``):
The path to a folder (local or an S3 bucket). Datasets found in this folder are
run in addition to the SDV datasets. If ``None``, no additional datasets are used.
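For reference, a minimal sketch of invoking the renamed multi-table entry point with the new defaults. The synthesizer and dataset names come from the test fixtures at the bottom of this diff; the top-level import path and the returned DataFrame columns are assumptions based on the fixture CSVs.

```python
# Hedged sketch, not part of this diff: exercising benchmark_multi_table
# with one of the new default multi-table datasets.
from sdgym import benchmark_multi_table  # top-level export assumed

results = benchmark_multi_table(
    synthesizers=['HMASynthesizer', 'MultiTableUniformSynthesizer'],
    sdv_datasets=['fake_hotels'],     # subset of the new defaults
    additional_datasets_folder=None,  # no extra local or S3 datasets
    output_destination=None,
    show_progress=False,
)
print(results[['Synthesizer', 'Dataset', 'Quality_Score']])
```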
42 changes: 36 additions & 6 deletions sdgym/datasets.py
@@ -34,6 +34,40 @@ def _get_bucket_name(bucket):
return bucket[len(S3_PREFIX) :] if bucket.startswith(S3_PREFIX) else bucket


def _raise_dataset_not_found_error(
s3_client,
bucket_name,
dataset_name,
current_modality,
bucket,
modality,
):
display_name = dataset_name
if isinstance(dataset_name, Path):
display_name = dataset_name.name

available_modalities = []
for other_modality in MODALITIES:
if other_modality == current_modality:
continue

other_prefix = f'{other_modality.lower()}/{dataset_name}/'
other_contents = _list_s3_bucket_contents(s3_client, bucket_name, other_prefix)
if other_contents:
available_modalities.append(other_modality)

if available_modalities:
available_list = ', '.join(sorted(available_modalities))
raise ValueError(
f"Dataset '{display_name}' not found in bucket '{bucket}' "
f"for modality '{modality}'. It is available under modality: '{available_list}'."
)
else:
raise ValueError(
f"Dataset '{display_name}' not found in bucket '{bucket}' for modality '{modality}'."
)


def _download_dataset(
modality,
dataset_name,
@@ -53,12 +87,8 @@ def _download_dataset(

contents = _list_s3_bucket_contents(s3_client, bucket_name, prefix)
if not contents:
display_name = dataset_name
if isinstance(dataset_name, Path):
display_name = dataset_name.name

raise ValueError(
f"Dataset '{display_name}' not found in bucket '{bucket}' for modality '{modality}'."
_raise_dataset_not_found_error(
s3_client, bucket_name, dataset_name, modality, bucket, modality
)

for obj in contents:
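The new helper makes wrong-modality failures self-explanatory: before raising, it probes the other modality prefixes in the bucket and names any that contain the dataset. A hedged illustration of the resulting behavior (whether the hint appears depends on the bucket contents; the `load_dataset` keyword names are taken from its use in `load_real_data` below):

```python
from sdgym.datasets import load_dataset

try:
    # 'fake_hotels' ships as a multi-table demo dataset, so requesting it
    # as single_table should now produce the cross-modality hint.
    load_dataset(modality='single_table', dataset='fake_hotels')
except ValueError as error:
    print(error)
    # Expected shape of the message (bucket name elided):
    # Dataset 'fake_hotels' not found in bucket '<bucket>' for modality
    # 'single_table'. It is available under modality: 'multi_table'.
```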
69 changes: 51 additions & 18 deletions sdgym/result_explorer/result_explorer.py
@@ -2,10 +2,14 @@

import os

from sdgym.benchmark import DEFAULT_SINGLE_TABLE_DATASETS
from sdgym.datasets import load_dataset
from sdgym.result_explorer.result_handler import LocalResultsHandler, S3ResultsHandler
from sdgym.result_explorer.result_handler import (
SYNTHESIZER_BASELINE,
LocalResultsHandler,
S3ResultsHandler,
)
from sdgym.s3 import _get_s3_client, is_s3_path
from sdgym.synthesizers.base import _validate_modality


def _validate_local_path(path):
@@ -14,20 +18,51 @@ def _validate_local_path(path):
raise ValueError(f"The provided path '{path}' is not a valid local directory.")


_BASELINE_BY_MODALITY = {
'single_table': SYNTHESIZER_BASELINE,
'multi_table': 'MultiTableUniformSynthesizer',
Contributor comment:
@fealho can you ask in the Team-Engineering channel which synthesizer we should use as the baseline for multi-table here?
Just to make sure everything is as expected. With this implementation it will be easy to update it in the future 👍
}


def _resolve_effective_path(path, modality):
"""Append the modality folder to the given base path if provided."""
# Avoid double-appending if already included
if str(path).rstrip('/').endswith(('/' + modality, modality)):
return path

if is_s3_path(path):
return path.rstrip('/') + '/' + modality

return os.path.join(path, modality)


class ResultsExplorer:
"""Explorer for SDGym benchmark results, supporting both local and S3 storage."""

def __init__(self, path, aws_access_key_id=None, aws_secret_access_key=None):
def _create_results_handler(self, original_path, effective_path):
"""Create the appropriate results handler for local or S3 storage."""
baseline_synthesizer = _BASELINE_BY_MODALITY.get(self.modality, SYNTHESIZER_BASELINE)
if is_s3_path(original_path):
s3_client = _get_s3_client(
original_path, self.aws_access_key_id, self.aws_secret_access_key
)
return S3ResultsHandler(
effective_path, s3_client, baseline_synthesizer=baseline_synthesizer
)

_validate_local_path(effective_path)
return LocalResultsHandler(effective_path, baseline_synthesizer=baseline_synthesizer)

def __init__(
self, path, modality='single_table', aws_access_key_id=None, aws_secret_access_key=None
):
self.path = path
_validate_modality(modality)
self.modality = modality.lower()
self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key

if is_s3_path(path):
s3_client = _get_s3_client(path, aws_access_key_id, aws_secret_access_key)
self._handler = S3ResultsHandler(path, s3_client)
else:
_validate_local_path(path)
self._handler = LocalResultsHandler(path)
effective_path = _resolve_effective_path(path, self.modality)
self._handler = self._create_results_handler(path, effective_path)

def list(self):
"""List all runs available in the results directory."""
@@ -37,7 +72,11 @@ def _get_file_path(self, results_folder_name, dataset_name, synthesizer_name, file_type):
"""Validate access to the synthesizer or synthetic data file."""
end_filename = f'{synthesizer_name}'
if file_type == 'synthetic_data':
end_filename += '_synthetic_data.csv'
# Multi-table synthetic data is zipped (multiple CSVs), single table is CSV
if self.modality == 'multi_table':
end_filename += '_synthetic_data.zip'
else:
end_filename += '_synthetic_data.csv'
elif file_type == 'synthesizer':
end_filename += '.pkl'

@@ -62,14 +101,8 @@ def load_synthetic_data(self, results_folder_name, dataset_name, synthesizer_name):

def load_real_data(self, dataset_name):
"""Load the real data for a given dataset."""
if dataset_name not in DEFAULT_SINGLE_TABLE_DATASETS:
raise ValueError(
f"Dataset '{dataset_name}' is not a SDGym dataset. "
'Please provide a valid dataset name.'
)

data, _ = load_dataset(
modality='single_table',
modality=self.modality,
dataset=dataset_name,
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
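Putting the result_explorer changes together, a sketch of the new modality-aware constructor. The local layout is assumed to be `<path>/<modality>/SDGym_results_*`, as implied by `_resolve_effective_path`:

```python
from sdgym.result_explorer.result_explorer import ResultsExplorer

# _resolve_effective_path appends the modality folder, so this explorer
# reads from 'results/multi_table' (for local paths that directory must
# already exist). Passing 'results/multi_table' directly also works: the
# endswith() guard avoids double-appending.
explorer = ResultsExplorer('results', modality='multi_table')

runs = explorer.list()

# load_real_data no longer checks DEFAULT_SINGLE_TABLE_DATASETS; it loads
# the dataset through the explorer's own modality.
real_data = explorer.load_real_data('fake_hotels')
```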
47 changes: 32 additions & 15 deletions sdgym/result_explorer/result_handler.py
@@ -11,6 +11,8 @@
import yaml
from botocore.exceptions import ClientError

from sdgym._dataset_utils import _read_zipped_data

SYNTHESIZER_BASELINE = 'GaussianCopulaSynthesizer'
RESULTS_FOLDER_PREFIX = 'SDGym_results_'
metainfo_PREFIX = 'metainfo'
@@ -22,6 +24,9 @@
class ResultsHandler(ABC):
"""Abstract base class for handling results storage and retrieval."""

def __init__(self, baseline_synthesizer=SYNTHESIZER_BASELINE):
self.baseline_synthesizer = baseline_synthesizer or SYNTHESIZER_BASELINE

@abstractmethod
def list(self):
"""List all runs in the results directory."""
@@ -59,7 +64,8 @@ def _compute_wins(self, result):
result['Win'] = 0
for dataset in datasets:
score_baseline = result.loc[
(result['Synthesizer'] == SYNTHESIZER_BASELINE) & (result['Dataset'] == dataset)
(result['Synthesizer'] == self.baseline_synthesizer)
& (result['Dataset'] == dataset)
]['Quality_Score'].to_numpy()
if score_baseline.size == 0:
continue
@@ -84,7 +90,7 @@ def _get_summarize_table(self, folder_to_results, folder_infos):
f' - # datasets: {folder_infos[folder]["# datasets"]}'
f' - sdgym version: {folder_infos[folder]["sdgym_version"]}'
)
results = results.loc[results['Synthesizer'] != SYNTHESIZER_BASELINE]
results = results.loc[results['Synthesizer'] != self.baseline_synthesizer]
column_data = results.groupby(['Synthesizer'])['Win'].sum()
columns.append((date_obj, column_name, column_data))

@@ -107,9 +113,11 @@ def _get_column_name_infos(self, folder_to_results):
continue

metainfo_info = self._load_yaml_file(folder, yaml_files[0])
num_datasets = results.loc[
results['Synthesizer'] == SYNTHESIZER_BASELINE, 'Dataset'
].nunique()
baseline_mask = results['Synthesizer'] == self.baseline_synthesizer
if baseline_mask.any():
num_datasets = results.loc[baseline_mask, 'Dataset'].nunique()
else:
num_datasets = results['Dataset'].nunique()
folder_to_info[folder] = {
'date': metainfo_info.get('starting_date')[:NUM_DIGITS_DATE],
'sdgym_version': metainfo_info.get('sdgym_version'),
@@ -236,7 +244,8 @@ def all_runs_complete(self, folder_name):
class LocalResultsHandler(ResultsHandler):
"""Results handler for local filesystem."""

def __init__(self, base_path):
def __init__(self, base_path, baseline_synthesizer=SYNTHESIZER_BASELINE):
super().__init__(baseline_synthesizer=baseline_synthesizer)
self.base_path = base_path

def list(self):
@@ -262,8 +271,12 @@ def load_synthesizer(self, file_path):
return cloudpickle.load(f)

def load_synthetic_data(self, file_path):
"""Load synthetic data from a CSV file."""
return pd.read_csv(os.path.join(self.base_path, file_path))
"""Load synthetic data from a CSV or ZIP file."""
full_path = os.path.join(self.base_path, file_path)
if full_path.endswith('.zip'):
return _read_zipped_data(full_path, modality='multi_table')

return pd.read_csv(full_path)

def _get_results_files(self, folder_name, prefix, suffix):
return [
@@ -287,7 +300,8 @@ def _load_yaml_file(self, folder_name, file_name):
class S3ResultsHandler(ResultsHandler):
"""Results handler for AWS S3 storage."""

def __init__(self, path, s3_client):
def __init__(self, path, s3_client, baseline_synthesizer=SYNTHESIZER_BASELINE):
super().__init__(baseline_synthesizer=baseline_synthesizer)
self.s3_client = s3_client
self.bucket_name = path.split('/')[2]
self.prefix = '/'.join(path.split('/')[3:]).rstrip('/') + '/'
@@ -374,10 +388,13 @@ def load_synthesizer(self, file_path):

def load_synthetic_data(self, file_path):
"""Load synthetic data from S3."""
response = self.s3_client.get_object(
Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}'
)
return pd.read_csv(io.BytesIO(response['Body'].read()))
key = f'{self.prefix}{file_path}'
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
body = response['Body'].read()
if file_path.endswith('.zip'):
return _read_zipped_data(io.BytesIO(body), modality='multi_table')

return pd.read_csv(io.BytesIO(body))

def _get_results_files(self, folder_name, prefix, suffix):
s3_prefix = f'{self.prefix}{folder_name}/'
@@ -396,8 +413,8 @@ def _get_results(self, folder_name, file_names):
for file_name in file_names:
s3_key = f'{self.prefix}{folder_name}/{file_name}'
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
df = pd.read_csv(io.BytesIO(response['Body'].read()))
results.append(df)
result_df = pd.read_csv(io.BytesIO(response['Body'].read()))
results.append(result_df)

return results

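Both handlers now take the baseline as a constructor argument and dispatch synthetic-data loading on the file extension. A hedged sketch of the two new behaviors; the results path and archive name below are hypothetical, following the `{synthesizer_name}_synthetic_data.zip` pattern from `_get_file_path`:

```python
from sdgym.result_explorer.result_handler import LocalResultsHandler

# Win counts in summaries are computed against this baseline instead of
# the previously hard-coded GaussianCopulaSynthesizer.
handler = LocalResultsHandler(
    'results/multi_table',
    baseline_synthesizer='MultiTableUniformSynthesizer',
)

# '.zip' archives (one CSV per table) go through _read_zipped_data;
# plain '.csv' files still go through pandas. Path is illustrative only.
tables = handler.load_synthetic_data(
    'SDGym_results_12_02_2025/HMASynthesizer_synthetic_data.zip'
)
```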
1 change: 1 addition & 0 deletions sdgym/run_benchmark/upload_benchmark_results.py
@@ -116,6 +116,7 @@ def upload_results(
run_date = folder_infos['date']
result_explorer = ResultsExplorer(
OUTPUT_DESTINATION_AWS,
modality='single_table',
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
Binary file not shown.
@@ -0,0 +1,2 @@
Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score
HMASynthesizer,fake_hotels,0.048698,22.852492,33.315142,0.988611,2.723049,0.082362,1.0,0.7353482911012336
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,2 @@
Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score
MultiTableUniformSynthesizer,fake_hotels,0.048698,0.201284,0.851853,0.109464,0.02749,0.081629,0.9122678149273894,0.5962941240006595
Binary file not shown.
@@ -0,0 +1,11 @@
completed_date: 12_02_2025 08:28:47
jobs:
- - fake_hotels
- MultiTableUniformSynthesizer
- - fake_hotels
- HMASynthesizer
modality: multi_table
run_id: run_12_02_2025_0
sdgym_version: 0.11.2.dev0
sdv_version: 1.28.0
starting_date: 12_02_2025 08:28:21
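The metainfo fixture above is what `_load_yaml_file` feeds into `_get_column_name_infos`. A sketch of the fields consumed; the file name is hypothetical (the handlers match on the `metainfo` prefix), and `NUM_DIGITS_DATE` is assumed to be 10, the length of the date portion of the timestamp:

```python
import yaml

with open('metainfo_run_12_02_2025_0.yaml') as f:  # hypothetical name
    info = yaml.safe_load(f)

run_date = info.get('starting_date')[:10]  # '12_02_2025'
version = info.get('sdgym_version')        # '0.11.2.dev0'
modality = info.get('modality')            # 'multi_table'
```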
@@ -0,0 +1,3 @@
Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score,Adjusted_Total_Time,Adjusted_Quality_Score
MultiTableUniformSynthesizer,fake_hotels,0.048698,0.201284,0.851853,0.109464,0.02749,0.081629,0.9122678149273894,0.5962941240006595,0.430058,0.5962941240006595
HMASynthesizer,fake_hotels,0.048698,22.852492,33.315142,0.988611,2.723049,0.082362,1.0,0.7353482911012336,25.776825000000002,0.7353482911012336