Skip to content

Commit 2dc218b

Browse files
committed
fix benchmark
1 parent 180aa00 commit 2dc218b

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

sdgym/benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,15 +1191,17 @@ def _validate_aws_inputs(output_destination, aws_access_key_id, aws_secret_acces
11911191
if not bucket_name:
11921192
raise ValueError(f'Invalid S3 URL: {output_destination}')
11931193

1194+
config = Config(connect_timeout=30, read_timeout=300)
11941195
if aws_access_key_id and aws_secret_access_key:
11951196
s3_client = boto3.client(
11961197
's3',
11971198
aws_access_key_id=aws_access_key_id,
11981199
aws_secret_access_key=aws_secret_access_key,
1200+
config=config,
11991201
)
12001202
else:
12011203
# No credentials provided — rely on default session
1202-
s3_client = boto3.client('s3')
1204+
s3_client = boto3.client('s3', config=config)
12031205

12041206
s3_client.head_bucket(Bucket=bucket_name)
12051207
if not _check_write_permissions(s3_client, bucket_name):
@@ -1425,14 +1427,12 @@ def benchmark_single_table_aws(
14251427
pandas.DataFrame:
14261428
A table containing one row per synthesizer + dataset + metric.
14271429
"""
1428-
config = Config(connect_timeout=30, read_timeout=300)
14291430
s3_client = _validate_output_destination(
14301431
output_destination,
14311432
aws_keys={
14321433
'aws_access_key_id': aws_access_key_id,
14331434
'aws_secret_access_key': aws_secret_access_key,
14341435
},
1435-
config=config,
14361436
)
14371437
job_args_list = _generate_job_args_list(
14381438
limit_dataset_size=limit_dataset_size,

tests/unit/test_benchmark.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,12 @@ def test_setup_output_destination_aws(mock_get_run_id_increment):
542542

543543
@patch('sdgym.benchmark.boto3.client')
544544
@patch('sdgym.benchmark._check_write_permissions')
545-
def test_validate_aws_inputs_valid(mock_check_write_permissions, mock_boto3_client):
545+
@patch('sdgym.benchmark.Config')
546+
def test_validate_aws_inputs_valid(mock_config, mock_check_write_permissions, mock_boto3_client):
546547
"""Test `_validate_aws_inputs` with valid inputs and credentials."""
547548
# Setup
549+
config_mock = Mock()
550+
mock_config.return_value = config_mock
548551
valid_url = 's3://my-bucket/some/path'
549552
s3_client_mock = Mock()
550553
mock_boto3_client.return_value = s3_client_mock
@@ -557,7 +560,7 @@ def test_validate_aws_inputs_valid(mock_check_write_permissions, mock_boto3_clie
557560

558561
# Assert
559562
mock_boto3_client.assert_called_once_with(
560-
's3', aws_access_key_id='AKIA...', aws_secret_access_key='SECRET'
563+
's3', aws_access_key_id='AKIA...', aws_secret_access_key='SECRET', config=config_mock
561564
)
562565
s3_client_mock.head_bucket.assert_called_once_with(Bucket='my-bucket')
563566
mock_check_write_permissions.assert_called_once_with(s3_client_mock, 'my-bucket')

0 commit comments

Comments
 (0)