|
2 | 2 | from datetime import datetime, timezone |
3 | 3 | from unittest.mock import Mock, call, patch |
4 | 4 |
|
| 5 | +import pytest |
5 | 6 | from botocore.exceptions import ClientError |
6 | 7 |
|
7 | | -from sdgym.run_benchmark.run_benchmark import append_benchmark_run, main |
8 | | -from sdgym.run_benchmark.utils import OUTPUT_DESTINATION_AWS, SYNTHESIZERS_SPLIT_SINGLE_TABLE |
| 8 | +from sdgym.run_benchmark.run_benchmark import ( |
| 9 | + append_benchmark_run, |
| 10 | + main, |
| 11 | +) |
| 12 | +from sdgym.run_benchmark.utils import ( |
| 13 | + OUTPUT_DESTINATION_AWS, |
| 14 | + SYNTHESIZERS_SPLIT_MULTI_TABLE, |
| 15 | + SYNTHESIZERS_SPLIT_SINGLE_TABLE, |
| 16 | +) |
9 | 17 |
|
10 | 18 |
|
11 | 19 | @patch('sdgym.run_benchmark.run_benchmark.get_s3_client') |
@@ -51,7 +59,7 @@ def test_append_benchmark_run(mock_get_result_folder_name, mock_parse_s3_path, m |
51 | 59 | ) |
52 | 60 | mock_s3_client.put_object.assert_called_once_with( |
53 | 61 | Bucket='my-bucket', |
54 | | - Key='my-prefix/_BENCHMARK_DATES.json', |
| 62 | + Key='my-prefix/single_table/_BENCHMARK_DATES.json', |
55 | 63 | Body=json.dumps(expected_data).encode('utf-8'), |
56 | 64 | ) |
57 | 65 |
|
@@ -91,53 +99,84 @@ def test_append_benchmark_run_new_file( |
91 | 99 | mock_parse_s3_path.assert_called_once_with(OUTPUT_DESTINATION_AWS) |
92 | 100 | mock_get_result_folder_name.assert_called_once_with(date) |
93 | 101 | mock_s3_client.get_object.assert_called_once_with( |
94 | | - Bucket='my-bucket', Key='my-prefix/_BENCHMARK_DATES.json' |
| 102 | + Bucket='my-bucket', Key='my-prefix/single_table/_BENCHMARK_DATES.json' |
95 | 103 | ) |
96 | 104 | mock_s3_client.put_object.assert_called_once_with( |
97 | 105 | Bucket='my-bucket', |
98 | | - Key='my-prefix/_BENCHMARK_DATES.json', |
| 106 | + Key='my-prefix/single_table/_BENCHMARK_DATES.json', |
99 | 107 | Body=json.dumps(expected_data).encode('utf-8'), |
100 | 108 | ) |
101 | 109 |
|
102 | 110 |
|
103 | | -@patch('sdgym.run_benchmark.run_benchmark.benchmark_single_table_aws') |
104 | | -@patch('sdgym.run_benchmark.run_benchmark.os.getenv') |
105 | | -@patch('sdgym.run_benchmark.run_benchmark.append_benchmark_run') |
| 111 | +@pytest.mark.parametrize( |
| 112 | + 'modality,synthesizer_split', |
| 113 | + [ |
| 114 | + ('single_table', SYNTHESIZERS_SPLIT_SINGLE_TABLE), |
| 115 | + ('multi_table', SYNTHESIZERS_SPLIT_MULTI_TABLE), |
| 116 | + ], |
| 117 | +) |
106 | 118 | @patch('sdgym.run_benchmark.run_benchmark.post_benchmark_launch_message') |
| 119 | +@patch('sdgym.run_benchmark.run_benchmark.append_benchmark_run') |
| 120 | +@patch('sdgym.run_benchmark.run_benchmark.os.getenv') |
| 121 | +@patch('sdgym.run_benchmark.run_benchmark._parse_args') |
| 122 | +@patch.dict( |
| 123 | + 'sdgym.run_benchmark.run_benchmark.MODALITY_TO_SETUP', |
| 124 | + values={ |
| 125 | + 'single_table': { |
| 126 | + 'method': Mock(name='mock_single_method'), |
| 127 | + 'synthesizers_split': [], |
| 128 | + }, |
| 129 | + 'multi_table': { |
| 130 | + 'method': Mock(name='mock_multi_method'), |
| 131 | + 'synthesizers_split': [], |
| 132 | + }, |
| 133 | + }, |
| 134 | + clear=True, |
| 135 | +) |
107 | 136 | def test_main( |
108 | | - mock_post_benchmark_launch_message, |
109 | | - mock_append_benchmark_run, |
| 137 | + mock_parse_args, |
110 | 138 | mock_getenv, |
111 | | - mock_benchmark_single_table_aws, |
| 139 | + mock_append_benchmark_run, |
| 140 | + mock_post_benchmark_launch_message, |
| 141 | + modality, |
| 142 | + synthesizer_split, |
112 | 143 | ): |
113 | | - """Test the `main` method.""" |
| 144 | + """Test the `main` function with both single_table and multi_table modalities.""" |
114 | 145 | # Setup |
115 | | - mock_getenv.side_effect = ['my_access_key', 'my_secret_key'] |
| 146 | + from sdgym.run_benchmark.run_benchmark import MODALITY_TO_SETUP |
| 147 | + |
| 148 | + mock_parse_args.return_value = Mock(modality=modality) |
| 149 | + mock_getenv.side_effect = lambda key: { |
| 150 | + 'AWS_ACCESS_KEY_ID': 'my_access_key', |
| 151 | + 'AWS_SECRET_ACCESS_KEY': 'my_secret_key', |
| 152 | + 'CREDENTIALS_FILEPATH': '/path/to/creds.json', |
| 153 | + }.get(key) |
| 154 | + MODALITY_TO_SETUP[modality]['synthesizers_split'] = synthesizer_split |
| 155 | + mock_method = MODALITY_TO_SETUP[modality]['method'] |
116 | 156 | date = datetime.now(timezone.utc).strftime('%Y-%m-%d') |
117 | 157 |
|
118 | 158 | # Run |
119 | 159 | main() |
120 | 160 |
|
121 | 161 | # Assert |
122 | | - mock_getenv.assert_any_call('AWS_ACCESS_KEY_ID') |
123 | | - mock_getenv.assert_any_call('AWS_SECRET_ACCESS_KEY') |
124 | | - expected_calls = [] |
125 | | - for synthesizer in SYNTHESIZERS_SPLIT_SINGLE_TABLE: |
126 | | - expected_calls.append( |
127 | | - call( |
128 | | - output_destination=OUTPUT_DESTINATION_AWS, |
129 | | - aws_access_key_id='my_access_key', |
130 | | - aws_secret_access_key='my_secret_key', |
131 | | - synthesizers=synthesizer, |
132 | | - compute_privacy_score=False, |
133 | | - timeout=345600, |
134 | | - ) |
| 162 | + expected_calls = [ |
| 163 | + call( |
| 164 | + output_destination=OUTPUT_DESTINATION_AWS, |
| 165 | + credential_filepath='/path/to/creds.json', |
| 166 | + synthesizers=group, |
| 167 | + timeout=345600, |
135 | 168 | ) |
136 | | - |
137 | | - mock_benchmark_single_table_aws.assert_has_calls(expected_calls) |
| 169 | + for group in synthesizer_split |
| 170 | + ] |
| 171 | + mock_method.assert_has_calls(expected_calls) |
138 | 172 | mock_append_benchmark_run.assert_called_once_with( |
139 | 173 | 'my_access_key', |
140 | 174 | 'my_secret_key', |
141 | 175 | date, |
| 176 | + modality=modality, |
| 177 | + ) |
| 178 | + mock_post_benchmark_launch_message.assert_called_once_with( |
| 179 | + date, |
| 180 | + compute_service='GCP', |
| 181 | + modality=modality, |
142 | 182 | ) |
143 | | - mock_post_benchmark_launch_message.assert_called_once_with(date) |
|
0 commit comments