Skip to content

Commit a9f3296

Browse files
committed
Feedback
1 parent ba3482b commit a9f3296

File tree

5 files changed

+37
-9
lines changed

5 files changed

+37
-9
lines changed

sdgym/benchmark.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
)
3939
from sdmetrics.single_table import DCRBaselineProtection
4040

41-
# from sdgym import __version__ as SDGYM_VERSION
4241
from sdgym.datasets import get_dataset_paths, load_dataset
4342
from sdgym.errors import BenchmarkError, SDGymError
4443
from sdgym.metrics import get_metrics

sdgym/result_explorer/result_explorer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
S3ResultsHandler,
1111
)
1212
from sdgym.s3 import _get_s3_client, is_s3_path
13+
from sdgym.synthesizers.base import _validate_modality
1314

1415

1516
def _validate_local_path(path):
@@ -34,9 +35,7 @@ def _resolve_effective_path(path, modality):
3435
if not modality:
3536
return path
3637

37-
if modality not in ('single_table', 'multi_table'):
38-
valid = ', '.join(sorted(('single_table', 'multi_table')))
39-
raise ValueError(f'Invalid modality "{modality}". Valid options are: {valid}.')
38+
_validate_modality(modality)
4039

4140
# Avoid double-appending if already included
4241
if str(path).rstrip('/').endswith(('/' + modality, modality)):

sdgym/result_explorer/result_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def load_synthetic_data(self, file_path):
401401
if name.endswith('.csv'):
402402
table_name = os.path.splitext(os.path.basename(name))[0]
403403
with zf.open(name) as csv_file:
404-
tables[table_name] = pd.read_csv(csv_file, low_memory=False)
404+
tables[table_name] = pd.read_csv(csv_file)
405405

406406
return tables
407407

tests/unit/result_explorer/test_result_handler.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,24 @@ def test_load_metainfo(self):
261261
class TestLocalResultsHandler:
262262
"""Unit tests for the LocalResultsHandler class."""
263263

264+
def test__init__sets_base_path_and_default_baseline(self, tmp_path):
265+
"""Test it initializes base_path and default baseline."""
266+
# Run
267+
handler = LocalResultsHandler(str(tmp_path))
268+
269+
# Assert
270+
assert handler.base_path == str(tmp_path)
271+
assert handler.baseline_synthesizer == 'GaussianCopulaSynthesizer'
272+
273+
def test__init__supports_baseline_override(self, tmp_path):
274+
"""Test it allows overriding baseline synthesizer."""
275+
# Run
276+
handler = LocalResultsHandler(str(tmp_path), baseline_synthesizer='CustomBaseline')
277+
278+
# Assert
279+
assert handler.base_path == str(tmp_path)
280+
assert handler.baseline_synthesizer == 'CustomBaseline'
281+
264282
def test_list(self, tmp_path):
265283
"""Test the `list` method"""
266284
# Setup
@@ -418,9 +436,7 @@ def test_get_file_path_local_error(self, mock_isfile, mock_exists):
418436
class TestS3ResultsHandler:
419437
"""Unit tests for the S3ResultsHandler class."""
420438

421-
def test__init__(
422-
self,
423-
):
439+
def test__init__(self):
424440
"""Test the `__init__` method."""
425441
# Setup
426442
path = 's3://my-bucket/prefix'
@@ -432,6 +448,21 @@ def test__init__(
432448
assert result_handler.s3_client == 's3_client'
433449
assert result_handler.bucket_name == 'my-bucket'
434450
assert result_handler.prefix == 'prefix/'
451+
assert result_handler.baseline_synthesizer == 'GaussianCopulaSynthesizer'
452+
453+
def test__init__supports_baseline_override(self):
454+
"""Test it allows overriding baseline synthesizer."""
455+
# Run
456+
s3_client = Mock()
457+
handler = S3ResultsHandler(
458+
's3://bkt/prefix', s3_client, baseline_synthesizer='CustomBaseline'
459+
)
460+
461+
# Assert
462+
assert handler.baseline_synthesizer == 'CustomBaseline'
463+
assert handler.s3_client == s3_client
464+
assert handler.bucket_name == 'bkt'
465+
assert handler.prefix == 'prefix/'
435466

436467
def test_list(self):
437468
"""Test the `list` method."""

tests/unit/test_benchmark.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,6 @@ def test__setup_output_destination_multi_table(tmp_path):
729729
@patch('sdgym.benchmark.datetime')
730730
def test__write_metainfo_file(mock_datetime, tmp_path):
731731
"""Test the `_write_metainfo_file` method."""
732-
pytest.importorskip('realtabformer')
733732
# Setup
734733
output_destination = tmp_path / 'SDGym_results_06_26_2025'
735734
output_destination.mkdir()

0 commit comments

Comments
 (0)