Skip to content

Commit 44ac3e3

Browse files
committed
Update handler
1 parent e14f7ff commit 44ac3e3

File tree

23 files changed

+198
-42
lines changed

23 files changed

+198
-42
lines changed

sdgym/benchmark.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from sdmetrics.single_table import DCRBaselineProtection
4040

41+
from sdgym import __version__ as SDGYM_VERSION
4142
from sdgym.datasets import get_dataset_paths, load_dataset
4243
from sdgym.errors import BenchmarkError, SDGymError
4344
from sdgym.metrics import get_metrics
@@ -1219,7 +1220,7 @@ def _write_metainfo_file(synthesizers, job_args_list, modality, result_writer=No
12191220
'modality': modality,
12201221
'starting_date': datetime.today().strftime('%m_%d_%Y %H:%M:%S'),
12211222
'completed_date': None,
1222-
'sdgym_version': version('sdgym'),
1223+
'sdgym_version': SDGYM_VERSION,
12231224
'jobs': jobs,
12241225
}
12251226

sdgym/result_explorer/result_explorer.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
from sdgym.benchmark import DEFAULT_SINGLE_TABLE_DATASETS
66
from sdgym.datasets import load_dataset
7-
from sdgym.result_explorer.result_handler import LocalResultsHandler, S3ResultsHandler
7+
from sdgym.result_explorer.result_handler import (
8+
SYNTHESIZER_BASELINE,
9+
LocalResultsHandler,
10+
S3ResultsHandler,
11+
)
812
from sdgym.s3 import _get_s3_client, is_s3_path
913

1014

@@ -14,6 +18,42 @@ def _validate_local_path(path):
1418
raise ValueError(f"The provided path '{path}' is not a valid local directory.")
1519

1620

21+
_FOLDER_BY_MODALITY = {
22+
'single_table': 'single_table',
23+
'multi_table': 'multi_table',
24+
}
25+
26+
_BASELINE_BY_MODALITY = {
27+
'multi_table': 'MultiTableUniformSynthesizer',
28+
}
29+
30+
31+
def _get_baseline_synthesizer(modality):
32+
"""Return the appropriate baseline synthesizer for the given modality."""
33+
return _BASELINE_BY_MODALITY.get(modality, SYNTHESIZER_BASELINE)
34+
35+
36+
def _resolve_effective_path(path, modality):
37+
"""Append the modality folder to the given base path if provided."""
38+
if not modality:
39+
return path
40+
41+
folder = _FOLDER_BY_MODALITY.get(modality)
42+
if folder is None:
43+
valid = ', '.join(sorted(_FOLDER_BY_MODALITY))
44+
raise ValueError(f'Invalid modality "{modality}". Valid options are: {valid}.')
45+
46+
# Avoid double-appending if already included
47+
if str(path).rstrip('/').endswith(('/' + folder, folder)):
48+
return path
49+
50+
if is_s3_path(path):
51+
path = path.rstrip('/') + '/' + folder
52+
return path
53+
54+
return os.path.join(path, folder)
55+
56+
1757
class ResultsExplorer:
1858
"""Explorer for SDGym benchmark results, supporting both local and S3 storage."""
1959

@@ -22,12 +62,18 @@ def __init__(self, path, aws_access_key_id=None, aws_secret_access_key=None):
2262
self.aws_access_key_id = aws_access_key_id
2363
self.aws_secret_access_key = aws_secret_access_key
2464

65+
baseline_synthesizer = _get_baseline_synthesizer(modality)
66+
effective_path = _resolve_effective_path(path, modality)
2567
if is_s3_path(path):
2668
s3_client = _get_s3_client(path, aws_access_key_id, aws_secret_access_key)
27-
self._handler = S3ResultsHandler(path, s3_client)
69+
self._handler = S3ResultsHandler(
70+
effective_path, s3_client, baseline_synthesizer=baseline_synthesizer
71+
)
2872
else:
29-
_validate_local_path(path)
30-
self._handler = LocalResultsHandler(path)
73+
_validate_local_path(effective_path)
74+
self._handler = LocalResultsHandler(
75+
effective_path, baseline_synthesizer=baseline_synthesizer
76+
)
3177

3278
def list(self):
3379
"""List all runs available in the results directory."""

sdgym/result_explorer/result_handler.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
class ResultsHandler(ABC):
2323
"""Abstract base class for handling results storage and retrieval."""
2424

25+
def __init__(self, baseline_synthesizer=SYNTHESIZER_BASELINE):
26+
# Allow overrides per modality while maintaining the historical default.
27+
self.baseline_synthesizer = baseline_synthesizer or SYNTHESIZER_BASELINE
28+
2529
@abstractmethod
2630
def list(self):
2731
"""List all runs in the results directory."""
@@ -59,7 +63,8 @@ def _compute_wins(self, result):
5963
result['Win'] = 0
6064
for dataset in datasets:
6165
score_baseline = result.loc[
62-
(result['Synthesizer'] == SYNTHESIZER_BASELINE) & (result['Dataset'] == dataset)
66+
(result['Synthesizer'] == self.baseline_synthesizer)
67+
& (result['Dataset'] == dataset)
6368
]['Quality_Score'].to_numpy()
6469
if score_baseline.size == 0:
6570
continue
@@ -84,7 +89,7 @@ def _get_summarize_table(self, folder_to_results, folder_infos):
8489
f' - # datasets: {folder_infos[folder]["# datasets"]}'
8590
f' - sdgym version: {folder_infos[folder]["sdgym_version"]}'
8691
)
87-
results = results.loc[results['Synthesizer'] != SYNTHESIZER_BASELINE]
92+
results = results.loc[results['Synthesizer'] != self.baseline_synthesizer]
8893
column_data = results.groupby(['Synthesizer'])['Win'].sum()
8994
columns.append((date_obj, column_name, column_data))
9095

@@ -107,9 +112,11 @@ def _get_column_name_infos(self, folder_to_results):
107112
continue
108113

109114
metainfo_info = self._load_yaml_file(folder, yaml_files[0])
110-
num_datasets = results.loc[
111-
results['Synthesizer'] == SYNTHESIZER_BASELINE, 'Dataset'
112-
].nunique()
115+
baseline_mask = results['Synthesizer'] == self.baseline_synthesizer
116+
if baseline_mask.any():
117+
num_datasets = results.loc[baseline_mask, 'Dataset'].nunique()
118+
else:
119+
num_datasets = results['Dataset'].nunique()
113120
folder_to_info[folder] = {
114121
'date': metainfo_info.get('starting_date')[:NUM_DIGITS_DATE],
115122
'sdgym_version': metainfo_info.get('sdgym_version'),
@@ -236,7 +243,8 @@ def all_runs_complete(self, folder_name):
236243
class LocalResultsHandler(ResultsHandler):
237244
"""Results handler for local filesystem."""
238245

239-
def __init__(self, base_path):
246+
def __init__(self, base_path, baseline_synthesizer=SYNTHESIZER_BASELINE):
247+
super().__init__(baseline_synthesizer=baseline_synthesizer)
240248
self.base_path = base_path
241249

242250
def list(self):
@@ -287,7 +295,8 @@ def _load_yaml_file(self, folder_name, file_name):
287295
class S3ResultsHandler(ResultsHandler):
288296
"""Results handler for AWS S3 storage."""
289297

290-
def __init__(self, path, s3_client):
298+
def __init__(self, path, s3_client, baseline_synthesizer=SYNTHESIZER_BASELINE):
299+
super().__init__(baseline_synthesizer=baseline_synthesizer)
291300
self.s3_client = s3_client
292301
self.bucket_name = path.split('/')[2]
293302
self.prefix = '/'.join(path.split('/')[3:]).rstrip('/') + '/'
@@ -396,8 +405,8 @@ def _get_results(self, folder_name, file_names):
396405
for file_name in file_names:
397406
s3_key = f'{self.prefix}{folder_name}/{file_name}'
398407
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
399-
df = pd.read_csv(io.BytesIO(response['Body'].read()))
400-
results.append(df)
408+
result_df = pd.read_csv(io.BytesIO(response['Body'].read()))
409+
results.append(result_df)
401410

402411
return results
403412

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score
2+
HMASynthesizer,fake_hotels,0.048698,22.852492,33.315142,0.988611,2.723049,0.082362,1.0,0.7353482911012336
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score
2+
MultiTableUniformSynthesizer,fake_hotels,0.048698,0.201284,0.851853,0.109464,0.02749,0.081629,0.9122678149273894,0.5962941240006595
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
completed_date: 12_02_2025 08:28:47
2+
jobs:
3+
- - fake_hotels
4+
- MultiTableUniformSynthesizer
5+
- - fake_hotels
6+
- HMASynthesizer
7+
modality: multi_table
8+
run_id: run_12_02_2025_0
9+
sdgym_version: 0.11.2.dev0
10+
sdv_version: 1.28.0
11+
starting_date: 12_02_2025 08:28:21

0 commit comments

Comments
 (0)