Skip to content

Commit daf4511

Browse files
authored
Merge branch 'main' into issue-2758-metadata-visualize-color
2 parents 8c7a2b5 + 7b10f34 commit daf4511

File tree

5 files changed

+381
-21
lines changed

5 files changed

+381
-21
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ ignore = [
219219
# pydocstyle
220220
"D107", # Missing docstring in __init__
221221
"D417", # Missing argument descriptions in the docstring, this is a bug from pydocstyle: https://github.com/PyCQA/pydocstyle/issues/449
222-
"PD901",
223222
"PD101",
224223
]
225224

sdv/datasets/demo.py

Lines changed: 122 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import os
77
import warnings
88
from collections import defaultdict
9+
from functools import wraps
910
from pathlib import Path
1011
from zipfile import ZipFile
1112

1213
import boto3
14+
import botocore
1315
import numpy as np
1416
import pandas as pd
1517
import yaml
@@ -57,6 +59,10 @@ def _get_data_from_bucket(object_key, bucket, client):
5759
return response['Body'].read()
5860

5961

62+
def _get_dataset_name_from_prefix(prefix):
63+
return prefix.split('/')[1]
64+
65+
6066
def _list_objects(prefix, bucket, client):
6167
"""List all objects under a given prefix using pagination.
6268
@@ -78,7 +84,21 @@ def _list_objects(prefix, bucket, client):
7884
contents.extend(resp.get('Contents', []))
7985

8086
if not contents:
81-
raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{bucket}'.")
87+
prefix_parts = prefix.split('/')
88+
modality = prefix_parts[0]
89+
dataset_name = _get_dataset_name_from_prefix(prefix)
90+
if dataset_name:
91+
raise DemoResourceNotFoundError(
92+
f"Could not download dataset '{dataset_name}' from bucket '{bucket}'. "
93+
'Make sure the bucket name is correct. If the bucket is private '
94+
'make sure to provide your credentials.'
95+
)
96+
else:
97+
raise DemoResourceNotFoundError(
98+
f"Could not list datasets in modality '{modality}' from bucket '{bucket}'. "
99+
'Make sure the bucket name is correct. If the bucket is private '
100+
'make sure to provide your credentials.'
101+
)
82102

83103
return contents
84104

@@ -108,7 +128,7 @@ def _search_contents_keys(contents, match_fn):
108128
return matches
109129

110130

111-
def _find_data_zip_key(contents, dataset_prefix):
131+
def _find_data_zip_key(contents, dataset_prefix, bucket):
112132
"""Find the 'data.zip' object key under dataset prefix, case-insensitive.
113133
114134
Args:
@@ -130,7 +150,11 @@ def is_data_zip(key):
130150
if matches:
131151
return matches[0]
132152

133-
raise DemoResourceNotFoundError("Could not find 'data.zip' for the requested dataset.")
153+
dataset_name = _get_dataset_name_from_prefix(dataset_prefix)
154+
raise DemoResourceNotFoundError(
155+
f"Could not download dataset '{dataset_name}' from bucket '{bucket}'. "
156+
"The dataset is missing 'data.zip' file."
157+
)
134158

135159

136160
def _get_first_v1_metadata_bytes(contents, dataset_prefix, bucket, client):
@@ -166,11 +190,86 @@ def is_direct_json_under_prefix(key):
166190
except Exception:
167191
continue
168192

193+
dataset_name = _get_dataset_name_from_prefix(dataset_prefix)
169194
raise DemoResourceNotFoundError(
170-
'Could not find a valid metadata JSON with METADATA_SPEC_VERSION "V1".'
195+
f"Could not download dataset '{dataset_name}' from bucket '{bucket}'. "
196+
'The dataset is missing a valid metadata.'
171197
)
172198

173199

200+
def _download_text_file_error_message(
201+
modality,
202+
dataset_name,
203+
output_filepath=None,
204+
bucket=PUBLIC_BUCKET,
205+
filename=None,
206+
**kwargs,
207+
):
208+
return (
209+
        f"Could not retrieve '{filename}' for dataset '{dataset_name}' "
210+
f"from bucket '{bucket}'. "
211+
'Make sure the bucket name is correct. If the bucket is private '
212+
'make sure to provide your credentials.'
213+
)
214+
215+
216+
def _download_error_message(
217+
modality,
218+
dataset_name,
219+
output_folder_name=None,
220+
s3_bucket_name=PUBLIC_BUCKET,
221+
credentials=None,
222+
**kwargs,
223+
):
224+
return (
225+
f"Could not download dataset '{dataset_name}' from bucket '{s3_bucket_name}'. "
226+
'Make sure the bucket name is correct. If the bucket is private '
227+
'make sure to provide your credentials.'
228+
)
229+
230+
231+
def _list_modality_error_message(modality, s3_bucket_name, **kwargs):
232+
return (
233+
f"Could not list datasets in modality '{modality}' from bucket '{s3_bucket_name}'. "
234+
'Make sure the bucket name is correct. If the bucket is private '
235+
'make sure to provide your credentials.'
236+
)
237+
238+
239+
def handle_aws_client_errors(error_message_builder):
240+
"""Decorate a function to translate AWS client errors into more descriptive errors.
241+
242+
This decorator catches ``botocore.exceptions.ClientError`` raised by the wrapped
243+
function and re-raises it as a ``DemoResourceNotFoundError`` with a custom error
244+
message. The error message is generated dynamically using the provided
245+
``error_message_builder`` function.
246+
247+
Args:
248+
error_message_builder (Callable):
249+
A callable that receives the same ``*args`` and ``**kwargs`` as the wrapped
250+
function and returns an error message.
251+
252+
Returns:
253+
func:
254+
A wrapped function.
255+
"""
256+
257+
def decorator(func):
258+
@wraps(func)
259+
def wrapper(*args, **kwargs):
260+
try:
261+
function_result = func(*args, **kwargs)
262+
except botocore.exceptions.ClientError as error:
263+
message = error_message_builder(*args, **kwargs)
264+
raise DemoResourceNotFoundError(message) from error
265+
266+
return function_result
267+
268+
return wrapper
269+
270+
return decorator
271+
272+
174273
def _download(modality, dataset_name, bucket, credentials=None):
175274
"""Download dataset resources from a bucket.
176275
@@ -186,7 +285,7 @@ def _download(modality, dataset_name, bucket, credentials=None):
186285
f'{bucket_url}/{dataset_prefix}'
187286
)
188287
contents = _list_objects(dataset_prefix, bucket=bucket, client=client)
189-
zip_key = _find_data_zip_key(contents, dataset_prefix)
288+
zip_key = _find_data_zip_key(contents, dataset_prefix, bucket)
190289
zip_bytes = _get_data_from_bucket(zip_key, bucket=bucket, client=client)
191290
metadata_bytes = _get_first_v1_metadata_bytes(
192291
contents, dataset_prefix, bucket=bucket, client=client
@@ -262,7 +361,7 @@ def _get_data_without_output_folder(in_memory_directory):
262361
return data, skipped_files
263362

264363

265-
def _get_data(modality, output_folder_name, in_memory_directory):
364+
def _get_data(modality, output_folder_name, in_memory_directory, bucket, dataset_name):
266365
if output_folder_name:
267366
data, skipped_files = _get_data_with_output_folder(output_folder_name)
268367
else:
@@ -273,7 +372,8 @@ def _get_data(modality, output_folder_name, in_memory_directory):
273372

274373
if not data:
275374
raise DemoResourceNotFoundError(
276-
'Demo data could not be downloaded because no csv files were found in data.zip'
375+
f"Could not download dataset '{dataset_name}' from bucket '{bucket}'. "
376+
'The dataset is missing `csv` file/s.'
277377
)
278378

279379
if modality != 'multi_table':
@@ -301,7 +401,10 @@ def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
301401
metadict = json.loads(metadata_bytes)
302402
metadata = Metadata().load_from_dict(metadict, dataset_name)
303403
except Exception as e:
304-
raise DemoResourceNotFoundError('Failed to parse metadata JSON for the dataset.') from e
404+
raise DemoResourceNotFoundError(
405+
f"Could not parse the metadata for dataset '{dataset_name}'. "
406+
'The dataset is missing a valid metadata file.'
407+
) from e
305408

306409
if output_folder_name:
307410
try:
@@ -321,6 +424,7 @@ def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
321424
return metadata
322425

323426

427+
@handle_aws_client_errors(_download_error_message)
324428
def download_demo(
325429
modality, dataset_name, output_folder_name=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
326430
):
@@ -362,7 +466,13 @@ def download_demo(
362466

363467
data_io, metadata_bytes = _download(modality, dataset_name, s3_bucket_name, credentials)
364468
in_memory_directory = _extract_data(data_io, output_folder_name)
365-
data = _get_data(modality, output_folder_name, in_memory_directory)
469+
data = _get_data(
470+
modality,
471+
output_folder_name,
472+
in_memory_directory,
473+
s3_bucket_name,
474+
dataset_name,
475+
)
366476
metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
367477

368478
return data, metadata
@@ -428,6 +538,7 @@ def _parse_num_tables(num_tables_val, dataset_name):
428538
return np.nan
429539

430540

541+
@handle_aws_client_errors(_list_modality_error_message)
431542
def get_available_demos(modality, s3_bucket_name=PUBLIC_BUCKET, credentials=None):
432543
"""Get demo datasets available for a ``modality``.
433544
@@ -545,6 +656,7 @@ def _save_document(text, output_filepath, filename, dataset_name):
545656
LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')
546657

547658

659+
@handle_aws_client_errors(_download_text_file_error_message)
548660
def _get_text_file_content(
549661
modality, dataset_name, filename, output_filepath=None, bucket=PUBLIC_BUCKET, credentials=None
550662
):
@@ -595,6 +707,7 @@ def _get_text_file_content(
595707
return text
596708

597709

710+
@handle_aws_client_errors(_download_text_file_error_message)
598711
def get_source(
599712
modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
600713
):

sdv/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ class DemoResourceNotFoundError(Exception):
104104
metadata, license, README, or other auxiliary files in the demo bucket.
105105
"""
106106

107+
def __init__(self, message):
108+
self.message = message
109+
super().__init__(self.message)
110+
107111

108112
class DemoResourceNotFoundWarning(UserWarning):
109113
"""Warning raised when an optional demo resource is not available.

tests/integration/datasets/test_demo.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pandas as pd
2+
import pytest
23

3-
from sdv.datasets.demo import get_available_demos
4+
from sdv.datasets.demo import download_demo, get_available_demos
5+
from sdv.metadata import Metadata
46

57

68
def test_get_available_demos_single_table():
@@ -79,3 +81,40 @@ def test_get_available_demos_multi_table():
7981
assert not tables_info.empty
8082
assert (tables_info['num_tables'] > 1).all()
8183
assert (tables_info['size_MB'] >= 0).all()
84+
85+
86+
@pytest.mark.parametrize('output_path', [None, 'tmp_path'])
87+
def test_download_demo_single_table(output_path, tmp_path):
88+
"""Test that the `download_demo` function works as intended for single-table."""
89+
# Run
90+
output_folder_name = tmp_path / 'sdv' if output_path else None
91+
data, metadata = download_demo(
92+
modality='single_table',
93+
dataset_name='fake_hotel_guests',
94+
output_folder_name=output_folder_name,
95+
)
96+
97+
# Assert
98+
metadata.validate_data({'fake_hotel_guests': data})
99+
assert len(data) > 1
100+
assert isinstance(metadata, Metadata)
101+
102+
103+
@pytest.mark.parametrize('output_path', [None, 'tmp_path'])
104+
def test_download_demo_multi_table(output_path, tmp_path):
105+
"""Test that the `download_demo` function works as intended for multi-table."""
106+
# Run
107+
output_folder_name = tmp_path / 'sdv' if output_path else None
108+
data, metadata = download_demo(
109+
modality='multi_table',
110+
dataset_name='fake_hotels',
111+
output_folder_name=output_folder_name,
112+
)
113+
114+
# Assert
115+
metadata.validate_data(data)
116+
expected_tables = ['hotels', 'guests']
117+
assert set(expected_tables) == set(data)
118+
assert isinstance(metadata, Metadata)
119+
assert len(data['hotels']) > 1
120+
assert len(data['guests']) > 1

0 commit comments

Comments
 (0)