Skip to content

Commit d4ae9e2

Browse files
authored
When running a benchmark locally, the additional_datasets_folder path should be the root path (#493)
1 parent ddd3dcf commit d4ae9e2

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

sdgym/benchmark.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -271,7 +271,11 @@ def _generate_job_args_list(
271271
if additional_datasets_folder is None
272272
else get_dataset_paths(
273273
modality='single_table',
274-
bucket=additional_datasets_folder,
274+
bucket=(
275+
additional_datasets_folder
276+
if is_s3_path(additional_datasets_folder)
277+
else os.path.join(additional_datasets_folder, 'single_table')
278+
),
275279
aws_access_key_id=aws_access_key_id,
276280
aws_secret_access_key=aws_secret_access_key_key,
277281
)

tests/unit/test_benchmark.py

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
_ensure_uniform_included,
2121
_fill_adjusted_scores_with_none,
2222
_format_output,
23+
_generate_job_args_list,
2324
_get_metainfo_increment,
2425
_handle_deprecated_parameters,
2526
_setup_output_destination,
@@ -1000,6 +1001,75 @@ def test__add_adjusted_scores_missing_fallback():
10001001
assert scores.equals(expected)
10011002

10021003

1004+
@patch('sdgym.benchmark.get_dataset_paths')
1005+
def test__generate_job_args_list_local_root_additional_folder(get_dataset_paths_mock, tmp_path):
1006+
"""Local additional_datasets_folder should point to root/single_table."""
1007+
# Setup
1008+
local_root = tmp_path / 'my_root'
1009+
local_root.mkdir()
1010+
dataset_path = tmp_path / 'my_root' / 'single_table' / 'datasetA'
1011+
get_dataset_paths_mock.return_value = [dataset_path]
1012+
1013+
# Run
1014+
_generate_job_args_list(
1015+
limit_dataset_size=False,
1016+
sdv_datasets=None,
1017+
additional_datasets_folder=str(local_root),
1018+
sdmetrics=None,
1019+
detailed_results_folder=None,
1020+
timeout=None,
1021+
output_destination=None,
1022+
compute_quality_score=False,
1023+
compute_diagnostic_score=False,
1024+
compute_privacy_score=False,
1025+
synthesizers=[],
1026+
custom_synthesizers=None,
1027+
s3_client=None,
1028+
)
1029+
1030+
# Assert
1031+
get_dataset_paths_mock.assert_called_once_with(
1032+
modality='single_table',
1033+
bucket=str(local_root / 'single_table'),
1034+
aws_access_key_id=None,
1035+
aws_secret_access_key=None,
1036+
)
1037+
1038+
1039+
@patch('sdgym.benchmark.get_dataset_paths')
1040+
def test__generate_job_args_list_s3_root_additional_folder(get_dataset_paths_mock):
1041+
"""S3 additional_datasets_folder should point to the root path."""
1042+
# Setup
1043+
s3_root = 's3://my-bucket/custom-datasets'
1044+
dataset_path = Path('/dummy/single_table/datasetA')
1045+
get_dataset_paths_mock.return_value = [dataset_path]
1046+
1047+
# Run
1048+
_generate_job_args_list(
1049+
limit_dataset_size=False,
1050+
sdv_datasets=None,
1051+
additional_datasets_folder=s3_root,
1052+
sdmetrics=None,
1053+
detailed_results_folder=None,
1054+
timeout=None,
1055+
output_destination=None,
1056+
compute_quality_score=False,
1057+
compute_diagnostic_score=False,
1058+
compute_privacy_score=False,
1059+
synthesizers=[],
1060+
custom_synthesizers=None,
1061+
s3_client=None,
1062+
)
1063+
1064+
# Assert
1065+
get_dataset_paths_mock.assert_called_once_with(
1066+
modality='single_table',
1067+
bucket=s3_root,
1068+
aws_access_key_id=None,
1069+
aws_secret_access_key=None,
1070+
)
1071+
1072+
10031073
def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn):
10041074
"""Test that no UserWarning is raised when running `UniformSynthesizer`."""
10051075
# Setup

0 commit comments

Comments
 (0)