|
20 | 20 | _ensure_uniform_included, |
21 | 21 | _fill_adjusted_scores_with_none, |
22 | 22 | _format_output, |
| 23 | + _generate_job_args_list, |
23 | 24 | _get_metainfo_increment, |
24 | 25 | _handle_deprecated_parameters, |
25 | 26 | _setup_output_destination, |
@@ -1000,6 +1001,75 @@ def test__add_adjusted_scores_missing_fallback(): |
1000 | 1001 | assert scores.equals(expected) |
1001 | 1002 |
|
1002 | 1003 |
|
| 1004 | +@patch('sdgym.benchmark.get_dataset_paths') |
| 1005 | +def test__generate_job_args_list_local_root_additional_folder(get_dataset_paths_mock, tmp_path): |
| 1006 | + """Local additional_datasets_folder should point to root/single_table.""" |
| 1007 | + # Setup |
| 1008 | + local_root = tmp_path / 'my_root' |
| 1009 | + local_root.mkdir() |
| 1010 | + dataset_path = tmp_path / 'my_root' / 'single_table' / 'datasetA' |
| 1011 | + get_dataset_paths_mock.return_value = [dataset_path] |
| 1012 | + |
| 1013 | + # Run |
| 1014 | + _generate_job_args_list( |
| 1015 | + limit_dataset_size=False, |
| 1016 | + sdv_datasets=None, |
| 1017 | + additional_datasets_folder=str(local_root), |
| 1018 | + sdmetrics=None, |
| 1019 | + detailed_results_folder=None, |
| 1020 | + timeout=None, |
| 1021 | + output_destination=None, |
| 1022 | + compute_quality_score=False, |
| 1023 | + compute_diagnostic_score=False, |
| 1024 | + compute_privacy_score=False, |
| 1025 | + synthesizers=[], |
| 1026 | + custom_synthesizers=None, |
| 1027 | + s3_client=None, |
| 1028 | + ) |
| 1029 | + |
| 1030 | + # Assert |
| 1031 | + get_dataset_paths_mock.assert_called_once_with( |
| 1032 | + modality='single_table', |
| 1033 | + bucket=str(local_root / 'single_table'), |
| 1034 | + aws_access_key_id=None, |
| 1035 | + aws_secret_access_key=None, |
| 1036 | + ) |
| 1037 | + |
| 1038 | + |
| 1039 | +@patch('sdgym.benchmark.get_dataset_paths') |
| 1040 | +def test__generate_job_args_list_s3_root_additional_folder(get_dataset_paths_mock): |
| 1041 | + """S3 additional_datasets_folder should point to the root path.""" |
| 1042 | + # Setup |
| 1043 | + s3_root = 's3://my-bucket/custom-datasets' |
| 1044 | + dataset_path = Path('/dummy/single_table/datasetA') |
| 1045 | + get_dataset_paths_mock.return_value = [dataset_path] |
| 1046 | + |
| 1047 | + # Run |
| 1048 | + _generate_job_args_list( |
| 1049 | + limit_dataset_size=False, |
| 1050 | + sdv_datasets=None, |
| 1051 | + additional_datasets_folder=s3_root, |
| 1052 | + sdmetrics=None, |
| 1053 | + detailed_results_folder=None, |
| 1054 | + timeout=None, |
| 1055 | + output_destination=None, |
| 1056 | + compute_quality_score=False, |
| 1057 | + compute_diagnostic_score=False, |
| 1058 | + compute_privacy_score=False, |
| 1059 | + synthesizers=[], |
| 1060 | + custom_synthesizers=None, |
| 1061 | + s3_client=None, |
| 1062 | + ) |
| 1063 | + |
| 1064 | + # Assert |
| 1065 | + get_dataset_paths_mock.assert_called_once_with( |
| 1066 | + modality='single_table', |
| 1067 | + bucket=s3_root, |
| 1068 | + aws_access_key_id=None, |
| 1069 | + aws_secret_access_key=None, |
| 1070 | + ) |
| 1071 | + |
| 1072 | + |
1003 | 1073 | def test_benchmark_single_table_no_warning_uniform_synthesizer(recwarn): |
1004 | 1074 | """Test that no UserWarning is raised when running `UniformSynthesizer`.""" |
1005 | 1075 | # Setup |
|
0 commit comments