Skip to content

Commit a985b08

Browse files
committed
Update handler
1 parent ced519e commit a985b08

File tree

23 files changed

+124
-68
lines changed

23 files changed

+124
-68
lines changed

sdgym/benchmark.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from sdmetrics.single_table import DCRBaselineProtection
4040

41+
from sdgym import __version__ as SDGYM_VERSION
4142
from sdgym.datasets import get_dataset_paths, load_dataset
4243
from sdgym.errors import BenchmarkError, SDGymError
4344
from sdgym.metrics import get_metrics
@@ -1219,7 +1220,7 @@ def _write_metainfo_file(synthesizers, job_args_list, modality, result_writer=No
12191220
'modality': modality,
12201221
'starting_date': datetime.today().strftime('%m_%d_%Y %H:%M:%S'),
12211222
'completed_date': None,
1222-
'sdgym_version': version('sdgym'),
1223+
'sdgym_version': SDGYM_VERSION,
12231224
'jobs': jobs,
12241225
}
12251226

sdgym/result_explorer/result_explorer.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
from sdgym.benchmark import DEFAULT_SINGLE_TABLE_DATASETS
66
from sdgym.datasets import load_dataset
7-
from sdgym.result_explorer.result_handler import LocalResultsHandler, S3ResultsHandler
7+
from sdgym.result_explorer.result_handler import (
8+
SYNTHESIZER_BASELINE,
9+
LocalResultsHandler,
10+
S3ResultsHandler,
11+
)
812
from sdgym.s3 import _get_s3_client, is_s3_path
913

1014

@@ -15,10 +19,19 @@ def _validate_local_path(path):
1519

1620

1721
_FOLDER_BY_MODALITY = {
18-
'single_table': 'single-table',
22+
'single_table': 'single_table',
1923
'multi_table': 'multi_table',
2024
}
2125

26+
_BASELINE_BY_MODALITY = {
27+
'multi_table': 'MultiTableUniformSynthesizer',
28+
}
29+
30+
31+
def _get_baseline_synthesizer(modality):
32+
"""Return the appropriate baseline synthesizer for the given modality."""
33+
return _BASELINE_BY_MODALITY.get(modality, SYNTHESIZER_BASELINE)
34+
2235

2336
def _resolve_effective_path(path, modality):
2437
"""Append the modality folder to the given base path if provided."""
@@ -50,15 +63,20 @@ def __init__(self, path, modality, aws_access_key_id=None, aws_secret_access_key
5063
self.aws_access_key_id = aws_access_key_id
5164
self.aws_secret_access_key = aws_secret_access_key
5265

66+
baseline_synthesizer = _get_baseline_synthesizer(modality)
5367
effective_path = _resolve_effective_path(path, modality)
5468
if is_s3_path(path):
5569
# Use original path to obtain client (keeps backwards compatibility),
5670
# but handler should operate on the modality-specific effective path.
5771
s3_client = _get_s3_client(path, aws_access_key_id, aws_secret_access_key)
58-
self._handler = S3ResultsHandler(effective_path, s3_client)
72+
self._handler = S3ResultsHandler(
73+
effective_path, s3_client, baseline_synthesizer=baseline_synthesizer
74+
)
5975
else:
6076
_validate_local_path(effective_path)
61-
self._handler = LocalResultsHandler(effective_path)
77+
self._handler = LocalResultsHandler(
78+
effective_path, baseline_synthesizer=baseline_synthesizer
79+
)
6280

6381
def list(self):
6482
"""List all runs available in the results directory."""

sdgym/result_explorer/result_handler.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
class ResultsHandler(ABC):
2727
"""Abstract base class for handling results storage and retrieval."""
2828

29+
def __init__(self, baseline_synthesizer=SYNTHESIZER_BASELINE):
30+
# Allow overrides per modality while maintaining the historical default.
31+
self.baseline_synthesizer = baseline_synthesizer or SYNTHESIZER_BASELINE
32+
2933
@abstractmethod
3034
def list(self):
3135
"""List all runs in the results directory."""
@@ -63,7 +67,8 @@ def _compute_wins(self, result):
6367
result['Win'] = 0
6468
for dataset in datasets:
6569
score_baseline = result.loc[
66-
(result['Synthesizer'] == SYNTHESIZER_BASELINE) & (result['Dataset'] == dataset)
70+
(result['Synthesizer'] == self.baseline_synthesizer)
71+
& (result['Dataset'] == dataset)
6772
]['Quality_Score'].to_numpy()
6873
if score_baseline.size == 0:
6974
continue
@@ -88,7 +93,7 @@ def _get_summarize_table(self, folder_to_results, folder_infos):
8893
f' - # datasets: {folder_infos[folder]["# datasets"]}'
8994
f' - sdgym version: {folder_infos[folder]["sdgym_version"]}'
9095
)
91-
results = results.loc[results['Synthesizer'] != SYNTHESIZER_BASELINE]
96+
results = results.loc[results['Synthesizer'] != self.baseline_synthesizer]
9297
column_data = results.groupby(['Synthesizer'])['Win'].sum()
9398
columns.append((date_obj, column_name, column_data))
9499

@@ -111,9 +116,11 @@ def _get_column_name_infos(self, folder_to_results):
111116
continue
112117

113118
metainfo_info = self._load_yaml_file(folder, yaml_files[0])
114-
num_datasets = results.loc[
115-
results['Synthesizer'] == SYNTHESIZER_BASELINE, 'Dataset'
116-
].nunique()
119+
baseline_mask = results['Synthesizer'] == self.baseline_synthesizer
120+
if baseline_mask.any():
121+
num_datasets = results.loc[baseline_mask, 'Dataset'].nunique()
122+
else:
123+
num_datasets = results['Dataset'].nunique()
117124
folder_to_info[folder] = {
118125
'date': metainfo_info.get('starting_date')[:NUM_DIGITS_DATE],
119126
'sdgym_version': metainfo_info.get('sdgym_version'),
@@ -240,7 +247,8 @@ def all_runs_complete(self, folder_name):
240247
class LocalResultsHandler(ResultsHandler):
241248
"""Results handler for local filesystem."""
242249

243-
def __init__(self, base_path):
250+
def __init__(self, base_path, baseline_synthesizer=SYNTHESIZER_BASELINE):
251+
super().__init__(baseline_synthesizer=baseline_synthesizer)
244252
self.base_path = base_path
245253

246254
def list(self):
@@ -295,7 +303,8 @@ def _load_yaml_file(self, folder_name, file_name):
295303
class S3ResultsHandler(ResultsHandler):
296304
"""Results handler for AWS S3 storage."""
297305

298-
def __init__(self, path, s3_client):
306+
def __init__(self, path, s3_client, baseline_synthesizer=SYNTHESIZER_BASELINE):
307+
super().__init__(baseline_synthesizer=baseline_synthesizer)
299308
self.s3_client = s3_client
300309
self.bucket_name = path.split('/')[2]
301310
self.prefix = '/'.join(path.split('/')[3:]).rstrip('/') + '/'
@@ -415,8 +424,8 @@ def _get_results(self, folder_name, file_names):
415424
for file_name in file_names:
416425
s3_key = f'{self.prefix}{folder_name}/{file_name}'
417426
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=s3_key)
418-
df = pd.read_csv(io.BytesIO(response['Body'].read()))
419-
results.append(df)
427+
result_df = pd.read_csv(io.BytesIO(response['Body'].read()))
428+
results.append(result_df)
420429

421430
return results
422431

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score
2+
HMASynthesizer,fake_hotels,0.048698,22.852492,33.315142,0.988611,2.723049,0.082362,1.0,0.7353482911012336
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Synthesizer,Dataset,Dataset_Size_MB,Train_Time,Peak_Memory_MB,Synthesizer_Size_MB,Sample_Time,Evaluate_Time,Diagnostic_Score,Quality_Score
2+
MultiTableUniformSynthesizer,fake_hotels,0.048698,0.201284,0.851853,0.109464,0.02749,0.081629,0.9122678149273894,0.5962941240006595
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
completed_date: 12_02_2025 08:28:47
2+
jobs:
3+
- - fake_hotels
4+
- MultiTableUniformSynthesizer
5+
- - fake_hotels
6+
- HMASynthesizer
7+
modality: multi_table
8+
run_id: run_12_02_2025_0
9+
sdgym_version: 0.11.2.dev0
10+
sdv_version: 1.28.0
11+
starting_date: 12_02_2025 08:28:21

0 commit comments

Comments
 (0)