1+ import json
12from datetime import datetime , timezone
2- from unittest .mock import call , patch
3+ from unittest .mock import Mock , call , patch
34
4- from sdgym ._run_benchmark import OUTPUT_DESTINATION_AWS
5- from sdgym ._run_benchmark .run_benchmark import main
5+ from botocore .exceptions import ClientError
6+
7+ from sdgym ._run_benchmark import OUTPUT_DESTINATION_AWS , SYNTHESIZERS
8+ from sdgym ._run_benchmark .run_benchmark import append_benchmark_run , main
9+
10+
11+ @patch ('sdgym._run_benchmark.run_benchmark.get_s3_client' )
12+ @patch ('sdgym._run_benchmark.run_benchmark.parse_s3_path' )
13+ @patch ('sdgym._run_benchmark.run_benchmark.get_run_name' )
14+ def test_append_benchmark_run (mock_get_run_name , mock_parse_s3_path , mock_get_s3_client ):
15+ """Test the `append_benchmark_run` method."""
16+ # Setup
17+ aws_access_key_id = 'my_access_key'
18+ aws_secret_access_key = 'my_secret_key'
19+ date = '2023-10-01'
20+ mock_get_run_name .return_value = 'SDGym_results_10_01_2023'
21+ mock_parse_s3_path .return_value = ('my-bucket' , 'my-prefix/' )
22+ mock_s3_client = Mock ()
23+ benchmark_date = {
24+ 'runs' : [
25+ {'date' : '2023-09-30' , 'run_name' : 'SDGym_results_09_30_2023' },
26+ ]
27+ }
28+ mock_get_s3_client .return_value = mock_s3_client
29+ mock_s3_client .get_object .return_value = {
30+ 'Body' : Mock (read = lambda : json .dumps (benchmark_date ).encode ('utf-8' ))
31+ }
32+ expected_data = {
33+ 'runs' : [
34+ {'date' : '2023-09-30' , 'run_name' : 'SDGym_results_09_30_2023' },
35+ {'date' : date , 'run_name' : 'SDGym_results_10_01_2023' },
36+ ]
37+ }
38+
39+ # Run
40+ append_benchmark_run (aws_access_key_id , aws_secret_access_key , date )
41+
42+ # Assert
43+ mock_get_s3_client .assert_called_once_with (
44+ aws_access_key_id = aws_access_key_id ,
45+ aws_secret_access_key = aws_secret_access_key ,
46+ )
47+ mock_parse_s3_path .assert_called_once_with (OUTPUT_DESTINATION_AWS )
48+ mock_get_run_name .assert_called_once_with (date )
49+ mock_s3_client .get_object .assert_called_once_with (
50+ Bucket = 'my-bucket' , Key = 'my-prefix/_BENCHMARK_DATES.json'
51+ )
52+ mock_s3_client .put_object .assert_called_once_with (
53+ Bucket = 'my-bucket' ,
54+ Key = 'my-prefix/_BENCHMARK_DATES.json' ,
55+ Body = json .dumps (expected_data ).encode ('utf-8' ),
56+ )
57+
58+
59+ @patch ('sdgym._run_benchmark.run_benchmark.get_s3_client' )
60+ @patch ('sdgym._run_benchmark.run_benchmark.parse_s3_path' )
61+ @patch ('sdgym._run_benchmark.run_benchmark.get_run_name' )
62+ def test_append_benchmark_run_new_file (mock_get_run_name , mock_parse_s3_path , mock_get_s3_client ):
63+ """Test the `append_benchmark_run` with a new file."""
64+ # Setup
65+ aws_access_key_id = 'my_access_key'
66+ aws_secret_access_key = 'my_secret_key'
67+ date = '2023-10-01'
68+ mock_get_run_name .return_value = 'SDGym_results_10_01_2023'
69+ mock_parse_s3_path .return_value = ('my-bucket' , 'my-prefix/' )
70+ mock_s3_client = Mock ()
71+ mock_get_s3_client .return_value = mock_s3_client
72+ mock_s3_client .get_object .side_effect = ClientError (
73+ {'Error' : {'Code' : 'NoSuchKey' }}, 'GetObject'
74+ )
75+ expected_data = {
76+ 'runs' : [
77+ {'date' : date , 'run_name' : 'SDGym_results_10_01_2023' },
78+ ]
79+ }
80+
81+ # Run
82+ append_benchmark_run (aws_access_key_id , aws_secret_access_key , date )
83+
84+ # Assert
85+ mock_get_s3_client .assert_called_once_with (
86+ aws_access_key_id = aws_access_key_id ,
87+ aws_secret_access_key = aws_secret_access_key ,
88+ )
89+ mock_parse_s3_path .assert_called_once_with (OUTPUT_DESTINATION_AWS )
90+ mock_get_run_name .assert_called_once_with (date )
91+ mock_s3_client .get_object .assert_called_once_with (
92+ Bucket = 'my-bucket' , Key = 'my-prefix/_BENCHMARK_DATES.json'
93+ )
94+ mock_s3_client .put_object .assert_called_once_with (
95+ Bucket = 'my-bucket' ,
96+ Key = 'my-prefix/_BENCHMARK_DATES.json' ,
97+ Body = json .dumps (expected_data ).encode ('utf-8' ),
98+ )
699
7100
8101@patch ('sdgym._run_benchmark.run_benchmark.benchmark_single_table_aws' )
@@ -21,14 +114,13 @@ def test_main(mock_append_benchmark_run, mock_getenv, mock_benchmark_single_tabl
21114 mock_getenv .assert_any_call ('AWS_ACCESS_KEY_ID' )
22115 mock_getenv .assert_any_call ('AWS_SECRET_ACCESS_KEY' )
23116 expected_calls = []
24- for synthesizer in [ 'GaussianCopulaSynthesizer' , 'TVAESynthesizer' ] :
117+ for synthesizer in SYNTHESIZERS :
25118 expected_calls .append (
26119 call (
27120 output_destination = OUTPUT_DESTINATION_AWS ,
28121 aws_access_key_id = 'my_access_key' ,
29122 aws_secret_access_key = 'my_secret_key' ,
30123 synthesizers = [synthesizer ],
31- sdv_datasets = ['expedia_hotel_logs' , 'fake_companies' ],
32124 compute_privacy_score = False ,
33125 )
34126 )
0 commit comments