Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sdgym/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,11 @@ def _generate_job_args_list(
if additional_datasets_folder is None
else get_dataset_paths(
modality='single_table',
bucket=additional_datasets_folder,
bucket=(
additional_datasets_folder
if is_s3_path(additional_datasets_folder)
else os.path.join(additional_datasets_folder, 'single_table')
),
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key_key,
)
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
_ensure_uniform_included,
_fill_adjusted_scores_with_none,
_format_output,
_generate_job_args_list,
_get_metainfo_increment,
_handle_deprecated_parameters,
_setup_output_destination,
Expand Down Expand Up @@ -1000,6 +1001,75 @@ def test__add_adjusted_scores_missing_fallback():
assert scores.equals(expected)


@patch('sdgym.benchmark.get_dataset_paths')
def test__generate_job_args_list_local_root_additional_folder(get_dataset_paths_mock, tmp_path):
    """A local ``additional_datasets_folder`` is looked up under ``<root>/single_table``."""
    # Setup
    root_dir = tmp_path / 'my_root'
    root_dir.mkdir()
    # The dataset path is only handed back by the mock; it is never created on disk here.
    get_dataset_paths_mock.return_value = [tmp_path / 'my_root' / 'single_table' / 'datasetA']
    job_kwargs = {
        'limit_dataset_size': False,
        'sdv_datasets': None,
        'additional_datasets_folder': str(root_dir),
        'sdmetrics': None,
        'detailed_results_folder': None,
        'timeout': None,
        'output_destination': None,
        'compute_quality_score': False,
        'compute_diagnostic_score': False,
        'compute_privacy_score': False,
        'synthesizers': [],
        'custom_synthesizers': None,
        's3_client': None,
    }

    # Run
    _generate_job_args_list(**job_kwargs)

    # Assert
    # The lookup must receive the ``single_table`` subfolder of the local root, not the root.
    expected_lookup = {
        'modality': 'single_table',
        'bucket': str(root_dir / 'single_table'),
        'aws_access_key_id': None,
        'aws_secret_access_key': None,
    }
    get_dataset_paths_mock.assert_called_once_with(**expected_lookup)


@patch('sdgym.benchmark.get_dataset_paths')
def test__generate_job_args_list_s3_root_additional_folder(get_dataset_paths_mock):
    """An S3 ``additional_datasets_folder`` is forwarded to the dataset lookup unchanged."""
    # Setup
    bucket_uri = 's3://my-bucket/custom-datasets'
    # Placeholder path returned by the mock so the function has one dataset to work with.
    get_dataset_paths_mock.return_value = [Path('/dummy/single_table/datasetA')]
    job_kwargs = {
        'limit_dataset_size': False,
        'sdv_datasets': None,
        'additional_datasets_folder': bucket_uri,
        'sdmetrics': None,
        'detailed_results_folder': None,
        'timeout': None,
        'output_destination': None,
        'compute_quality_score': False,
        'compute_diagnostic_score': False,
        'compute_privacy_score': False,
        'synthesizers': [],
        'custom_synthesizers': None,
        's3_client': None,
    }

    # Run
    _generate_job_args_list(**job_kwargs)

    # Assert
    # No ``single_table`` suffix is appended for S3 roots: the URI is passed through as-is.
    expected_lookup = {
        'modality': 'single_table',
        'bucket': bucket_uri,
        'aws_access_key_id': None,
        'aws_secret_access_key': None,
    }
    get_dataset_paths_mock.assert_called_once_with(**expected_lookup)


def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn):
"""Test that no UserWarning is raised when running `UniformSynthesizer`."""
# Setup
Expand Down