Skip to content

Commit a29ee75

Browse files
committed
Update path
1 parent 6a3b217 commit a29ee75

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

sdgym/benchmark.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -271,7 +271,11 @@ def _generate_job_args_list(
271271
if additional_datasets_folder is None
272272
else get_dataset_paths(
273273
modality='single_table',
274-
bucket=additional_datasets_folder,
274+
bucket=(
275+
additional_datasets_folder
276+
if is_s3_path(additional_datasets_folder)
277+
else os.path.join(additional_datasets_folder, 'single_table')
278+
),
275279
aws_access_key_id=aws_access_key_id,
276280
aws_secret_access_key=aws_secret_access_key_key,
277281
)

tests/unit/test_benchmark.py

Lines changed: 70 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
_ensure_uniform_included,
2121
_fill_adjusted_scores_with_none,
2222
_format_output,
23+
_generate_job_args_list,
2324
_get_metainfo_increment,
2425
_handle_deprecated_parameters,
2526
_setup_output_destination,
@@ -1001,6 +1002,75 @@ def test__add_adjusted_scores_missing_fallback():
10011002
assert scores.equals(expected)
10021003

10031004

1005+
@patch('sdgym.benchmark.get_dataset_paths')
1006+
def test__generate_job_args_list_local_root_additional_folder(get_dataset_paths_mock, tmp_path):
1007+
"""Local additional_datasets_folder should point to root/single_table."""
1008+
# Setup
1009+
local_root = tmp_path / 'my_root'
1010+
local_root.mkdir()
1011+
dataset_path = tmp_path / 'my_root' / 'single_table' / 'datasetA'
1012+
get_dataset_paths_mock.return_value = [dataset_path]
1013+
1014+
# Run
1015+
_generate_job_args_list(
1016+
limit_dataset_size=False,
1017+
sdv_datasets=None,
1018+
additional_datasets_folder=str(local_root),
1019+
sdmetrics=None,
1020+
detailed_results_folder=None,
1021+
timeout=None,
1022+
output_destination=None,
1023+
compute_quality_score=False,
1024+
compute_diagnostic_score=False,
1025+
compute_privacy_score=False,
1026+
synthesizers=[],
1027+
custom_synthesizers=None,
1028+
s3_client=None,
1029+
)
1030+
1031+
# Assert
1032+
get_dataset_paths_mock.assert_called_once_with(
1033+
modality='single_table',
1034+
bucket=str(local_root / 'single_table'),
1035+
aws_access_key_id=None,
1036+
aws_secret_access_key=None,
1037+
)
1038+
1039+
1040+
@patch('sdgym.benchmark.get_dataset_paths')
1041+
def test__generate_job_args_list_s3_root_additional_folder(get_dataset_paths_mock):
1042+
"""S3 additional_datasets_folder should point to the root path."""
1043+
# Setup
1044+
s3_root = 's3://my-bucket/custom-datasets'
1045+
dataset_path = Path('/dummy/single_table/datasetA')
1046+
get_dataset_paths_mock.return_value = [dataset_path]
1047+
1048+
# Run
1049+
_generate_job_args_list(
1050+
limit_dataset_size=False,
1051+
sdv_datasets=None,
1052+
additional_datasets_folder=s3_root,
1053+
sdmetrics=None,
1054+
detailed_results_folder=None,
1055+
timeout=None,
1056+
output_destination=None,
1057+
compute_quality_score=False,
1058+
compute_diagnostic_score=False,
1059+
compute_privacy_score=False,
1060+
synthesizers=[],
1061+
custom_synthesizers=None,
1062+
s3_client=None,
1063+
)
1064+
1065+
# Assert
1066+
get_dataset_paths_mock.assert_called_once_with(
1067+
modality='single_table',
1068+
bucket=s3_root,
1069+
aws_access_key_id=None,
1070+
aws_secret_access_key=None,
1071+
)
1072+
1073+
10041074
def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn):
10051075
"""Test that no UserWarning is raised when running `UniformSynthesizer`."""
10061076
# Setup

0 commit comments

Comments
 (0)