Skip to content

Commit ba3482b

Browse files
committed
Feedback
1 parent 4c87a30 commit ba3482b

File tree

3 files changed

+27
-30
lines changed

3 files changed

+27
-30
lines changed

sdgym/benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939
from sdmetrics.single_table import DCRBaselineProtection
4040

41-
from sdgym import __version__ as SDGYM_VERSION
41+
# from sdgym import __version__ as SDGYM_VERSION
4242
from sdgym.datasets import get_dataset_paths, load_dataset
4343
from sdgym.errors import BenchmarkError, SDGymError
4444
from sdgym.metrics import get_metrics
@@ -1220,7 +1220,7 @@ def _write_metainfo_file(synthesizers, job_args_list, modality, result_writer=No
12201220
'modality': modality,
12211221
'starting_date': datetime.today().strftime('%m_%d_%Y %H:%M:%S'),
12221222
'completed_date': None,
1223-
'sdgym_version': SDGYM_VERSION,
1223+
'sdgym_version': version('sdgym'),
12241224
'jobs': jobs,
12251225
}
12261226

sdgym/result_explorer/result_explorer.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,8 @@ def _validate_local_path(path):
1818
raise ValueError(f"The provided path '{path}' is not a valid local directory.")
1919

2020

21-
_FOLDER_BY_MODALITY = {
22-
'single_table': 'single_table',
23-
'multi_table': 'multi_table',
24-
}
25-
2621
_BASELINE_BY_MODALITY = {
22+
'single_table': SYNTHESIZER_BASELINE,
2723
'multi_table': 'MultiTableUniformSynthesizer',
2824
}
2925

@@ -38,45 +34,47 @@ def _resolve_effective_path(path, modality):
3834
if not modality:
3935
return path
4036

41-
folder = _FOLDER_BY_MODALITY.get(modality)
42-
if folder is None:
43-
valid = ', '.join(sorted(_FOLDER_BY_MODALITY))
37+
if modality not in ('single_table', 'multi_table'):
38+
valid = ', '.join(sorted(('single_table', 'multi_table')))
4439
raise ValueError(f'Invalid modality "{modality}". Valid options are: {valid}.')
4540

4641
# Avoid double-appending if already included
47-
if str(path).rstrip('/').endswith(('/' + folder, folder)):
42+
if str(path).rstrip('/').endswith(('/' + modality, modality)):
4843
return path
4944

5045
if is_s3_path(path):
51-
path = path.rstrip('/') + '/' + folder
46+
path = path.rstrip('/') + '/' + modality
5247
return path
5348

54-
return os.path.join(path, folder)
49+
return os.path.join(path, modality)
5550

5651

5752
class ResultsExplorer:
5853
"""Explorer for SDGym benchmark results, supporting both local and S3 storage."""
5954

55+
def _create_results_handler(self, original_path, effective_path, baseline_synthesizer):
56+
"""Create the appropriate results handler for local or S3 storage."""
57+
if is_s3_path(original_path):
58+
# Use original path to obtain client (keeps backwards compatibility),
59+
# but handler should operate on the modality-specific effective path.
60+
s3_client = _get_s3_client(
61+
original_path, self.aws_access_key_id, self.aws_secret_access_key
62+
)
63+
return S3ResultsHandler(
64+
effective_path, s3_client, baseline_synthesizer=baseline_synthesizer
65+
)
66+
67+
_validate_local_path(effective_path)
68+
return LocalResultsHandler(effective_path, baseline_synthesizer=baseline_synthesizer)
69+
6070
def __init__(self, path, modality, aws_access_key_id=None, aws_secret_access_key=None):
6171
self.path = path
62-
self.modality = modality
72+
self.modality = modality.lower()
6373
self.aws_access_key_id = aws_access_key_id
6474
self.aws_secret_access_key = aws_secret_access_key
65-
6675
baseline_synthesizer = _get_baseline_synthesizer(modality)
6776
effective_path = _resolve_effective_path(path, modality)
68-
if is_s3_path(path):
69-
# Use original path to obtain client (keeps backwards compatibility),
70-
# but handler should operate on the modality-specific effective path.
71-
s3_client = _get_s3_client(path, aws_access_key_id, aws_secret_access_key)
72-
self._handler = S3ResultsHandler(
73-
effective_path, s3_client, baseline_synthesizer=baseline_synthesizer
74-
)
75-
else:
76-
_validate_local_path(effective_path)
77-
self._handler = LocalResultsHandler(
78-
effective_path, baseline_synthesizer=baseline_synthesizer
79-
)
77+
self._handler = self._create_results_handler(path, effective_path, baseline_synthesizer)
8078

8179
def list(self):
8280
"""List all runs available in the results directory."""
@@ -125,7 +123,7 @@ def load_real_data(self, dataset_name):
125123
)
126124

127125
data, _ = load_dataset(
128-
modality=self.modality or 'single_table',
126+
modality=self.modality,
129127
dataset=dataset_name,
130128
aws_access_key_id=self.aws_access_key_id,
131129
aws_secret_access_key=self.aws_secret_access_key,

tests/unit/test_benchmark.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import pytest
1212
import yaml
1313

14-
import sdgym
1514
from sdgym.benchmark import (
1615
_add_adjusted_scores,
1716
_check_write_permissions,
@@ -759,7 +758,7 @@ def test__write_metainfo_file(mock_datetime, tmp_path):
759758
assert metainfo_data['run_id'] == 'run_06_26_2025_0'
760759
assert metainfo_data['starting_date'] == '06_26_2025'
761760
assert metainfo_data['jobs'] == expected_jobs
762-
assert metainfo_data['sdgym_version'] == sdgym.__version__
761+
assert metainfo_data['sdgym_version'] == version('sdgym')
763762
assert metainfo_data['sdv_version'] == version('sdv')
764763
assert metainfo_data['realtabformer_version'] == version('realtabformer')
765764
assert metainfo_data['completed_date'] is None

0 commit comments

Comments
 (0)