Skip to content

Commit 9f133a5

Browse files
committed
Wrap the errors that yield from AWS client
1 parent 1676167 commit 9f133a5

File tree

2 files changed

+195
-1
lines changed

2 files changed

+195
-1
lines changed

sdv/datasets/demo.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import os
77
import warnings
88
from collections import defaultdict
9+
from functools import wraps
910
from pathlib import Path
1011
from zipfile import ZipFile
1112

1213
import boto3
14+
import botocore
1315
import numpy as np
1416
import pandas as pd
1517
import yaml
@@ -195,6 +197,63 @@ def is_direct_json_under_prefix(key):
195197
)
196198

197199

200+
def _download_error_message(
    modality,
    dataset_name,
    output_folder_name=None,
    s3_bucket_name=PUBLIC_BUCKET,
    credentials=None,
    **kwargs,
):
    """Build the message reported when a dataset could not be downloaded.

    The parameters exist so that the arguments given to the decorated download
    function can be forwarded here unchanged; only ``dataset_name`` and
    ``s3_bucket_name`` end up in the message.

    Args:
        modality (str):
            Modality of the dataset. Not used in the message.
        dataset_name (str):
            Name of the dataset that failed to download.
        output_folder_name (str or None):
            Third positional argument of the decorated function. Not used in the message.
        s3_bucket_name (str):
            Name of the bucket that was queried. Defaults to ``PUBLIC_BUCKET``.
        credentials (dict or None):
            Credentials given to the decorated function. Not used in the message.
        **kwargs:
            Any other keyword arguments of the decorated function. Ignored.

    Returns:
        str:
            The error message.
    """
    failure = f"Could not download dataset '{dataset_name}' from bucket '{s3_bucket_name}'. "
    advice = (
        'Make sure the bucket name is correct. If the bucket is private '
        'make sure to provide your credentials.'
    )
    return failure + advice
213+
214+
215+
def _list_modality_error_message(modality, s3_bucket_name, **kwargs):
216+
return (
217+
f"Could not list datasets in modality '{modality}' from bucket '{s3_bucket_name}'. "
218+
'Make sure the bucket name is correct. If the bucket is private '
219+
'make sure to provide your credentials.'
220+
)
221+
222+
223+
def handle_aws_client_errors(error_message_builder):
    """Create a decorator that reports AWS client errors as ``DemoResourceNotFoundError``.

    Any ``botocore.exceptions.ClientError`` raised by the decorated function is
    replaced by a ``DemoResourceNotFoundError`` whose message comes from
    ``error_message_builder``. The original error is kept as the cause of the new
    one. All other exceptions propagate untouched.

    Args:
        error_message_builder (Callable):
            Called with the same ``*args`` and ``**kwargs`` the decorated function
            was called with, and expected to return the error message. Only the
            arguments the caller actually passed are forwarded; defaults of the
            decorated function are not filled in.

    Returns:
        func:
            A decorator that wraps a function with the error translation.
    """

    def decorator(func):
        @wraps(func)
        def guarded(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except botocore.exceptions.ClientError as client_error:
                description = error_message_builder(*args, **kwargs)
                raise DemoResourceNotFoundError(description) from client_error

        return guarded

    return decorator
255+
256+
198257
def _download(modality, dataset_name, bucket, credentials=None):
199258
"""Download dataset resources from a bucket.
200259
@@ -349,6 +408,7 @@ def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
349408
return metadata
350409

351410

411+
@handle_aws_client_errors(_download_error_message)
352412
def download_demo(
353413
modality, dataset_name, output_folder_name=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
354414
):
@@ -462,6 +522,7 @@ def _parse_num_tables(num_tables_val, dataset_name):
462522
return np.nan
463523

464524

525+
@handle_aws_client_errors(_list_modality_error_message)
465526
def get_available_demos(modality, s3_bucket_name=PUBLIC_BUCKET, credentials=None):
466527
"""Get demo datasets available for a ``modality``.
467528
@@ -629,6 +690,7 @@ def _get_text_file_content(
629690
return text
630691

631692

693+
@handle_aws_client_errors(_download_error_message)
632694
def get_source(
633695
modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
634696
):
@@ -665,6 +727,7 @@ def get_source(
665727
)
666728

667729

730+
@handle_aws_client_errors(_download_error_message)
668731
def get_readme(
669732
modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
670733
):

tests/unit/datasets/test_demo.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import pandas as pd
1010
import pytest
11+
from botocore.exceptions import ClientError
1112

1213
from sdv.datasets.demo import (
1314
_download,
@@ -1365,7 +1366,7 @@ def test__list_objects_raises_when_no_contents_and_dataset_found():
13651366
_list_objects(prefix='single_table/mydataset/', bucket='bucket', client=mock_client)
13661367

13671368

1368-
def test_list_objects_raises_when_no_contents_and_no_dataset():
1369+
def test__list_objects_raises_when_no_contents_and_no_dataset():
13691370
"""Test that `_list_objects` raise a modality-specific error when dataset name is unknown."""
13701371
# Setup
13711372
mock_client = Mock()
@@ -1380,3 +1381,133 @@ def test_list_objects_raises_when_no_contents_and_no_dataset():
13801381
)
13811382
with pytest.raises(DemoResourceNotFoundError, match=error_msg):
13821383
_list_objects(prefix='single_table/', bucket='bucket', client=mock_client)
1384+
1385+
1386+
@patch('sdv.datasets.demo._create_s3_client')
def test_download_with_client_error(mock__create_s3_client):
    """Raise ``DemoResourceNotFoundError`` if ``download_demo`` hits an AWS ``ClientError``."""
    # Setup
    access_denied = ClientError(
        error_response={
            'Error': {'Code': 'AccessDenied', 'Message': 'Access Denied'},
            'ResponseMetadata': {'HTTPStatusCode': 403},
        },
        operation_name='ListObjectsV2',
    )
    s3_client = Mock()
    s3_client.get_paginator.side_effect = access_denied
    mock__create_s3_client.return_value = s3_client

    expected_message = (
        "Could not download dataset 'fake_hotels' from bucket 'private_bucket'. "
        'Make sure the bucket name is correct. If the bucket is private '
        'make sure to provide your credentials.'
    )

    # Run and Assert
    with pytest.raises(DemoResourceNotFoundError, match=expected_message):
        download_demo('single_table', 'fake_hotels', None, 'private_bucket')
1413+
1414+
1415+
@patch('sdv.datasets.demo._create_s3_client')
def test_get_available_demos_with_client_error(mock__create_s3_client):
    """Raise ``DemoResourceNotFoundError`` if listing demos hits an AWS ``ClientError``."""
    # Setup
    access_denied = ClientError(
        error_response={
            'Error': {'Code': 'AccessDenied', 'Message': 'Access Denied'},
            'ResponseMetadata': {'HTTPStatusCode': 403},
        },
        operation_name='ListObjectsV2',
    )
    s3_client = Mock()
    s3_client.get_paginator.side_effect = access_denied
    mock__create_s3_client.return_value = s3_client

    expected_message = (
        "Could not list datasets in modality 'single_table' from bucket 'private_bucket'. "
        'Make sure the bucket name is correct. If the bucket is private '
        'make sure to provide your credentials.'
    )

    # Run and Assert
    with pytest.raises(DemoResourceNotFoundError, match=expected_message):
        get_available_demos(modality='single_table', s3_bucket_name='private_bucket')
1446+
1447+
1448+
@patch('sdv.datasets.demo._create_s3_client')
def test_get_source_with_client_error(mock__create_s3_client):
    """Raise ``DemoResourceNotFoundError`` if fetching SOURCE hits an AWS ``ClientError``."""
    # Setup
    access_denied = ClientError(
        error_response={
            'Error': {'Code': 'AccessDenied', 'Message': 'Access Denied'},
            'ResponseMetadata': {'HTTPStatusCode': 403},
        },
        operation_name='ListObjectsV2',
    )
    s3_client = Mock()
    s3_client.get_paginator.side_effect = access_denied
    mock__create_s3_client.return_value = s3_client

    expected_message = (
        "Could not download dataset 'fake_hotels' from bucket 'private_bucket'. "
        'Make sure the bucket name is correct. If the bucket is private '
        'make sure to provide your credentials.'
    )

    # Run and Assert
    with pytest.raises(DemoResourceNotFoundError, match=expected_message):
        get_source(
            modality='single_table',
            dataset_name='fake_hotels',
            s3_bucket_name='private_bucket',
        )
1480+
1481+
1482+
@patch('sdv.datasets.demo._create_s3_client')
def test_get_readme_with_client_error(mock__create_s3_client):
    """Raise ``DemoResourceNotFoundError`` if fetching README hits an AWS ``ClientError``."""
    # Setup
    access_denied = ClientError(
        error_response={
            'Error': {'Code': 'AccessDenied', 'Message': 'Access Denied'},
            'ResponseMetadata': {'HTTPStatusCode': 403},
        },
        operation_name='ListObjectsV2',
    )
    s3_client = Mock()
    s3_client.get_paginator.side_effect = access_denied
    mock__create_s3_client.return_value = s3_client

    expected_message = (
        "Could not download dataset 'fake_hotels' from bucket 'private_bucket'. "
        'Make sure the bucket name is correct. If the bucket is private '
        'make sure to provide your credentials.'
    )

    # Run and Assert
    with pytest.raises(DemoResourceNotFoundError, match=expected_message):
        get_readme(
            modality='single_table',
            dataset_name='fake_hotels',
            s3_bucket_name='private_bucket',
        )

0 commit comments

Comments
 (0)