
Commit 1821a16

Add bucket and credentials parameters to download_demo (#2748)
Co-authored-by: Plamen Valentinov Kolev <[email protected]>
1 parent c3f97e5 commit 1821a16
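
For orientation, a minimal usage sketch of the signature this commit introduces (the dataset name below is illustrative, not part of the diff):

    from sdv.datasets.demo import download_demo

    # Defaults preserve the old behavior: unsigned access to the public bucket.
    data, metadata = download_demo(
        modality='single_table',
        dataset_name='fake_hotel_guests',  # illustrative dataset name
        s3_bucket_name='sdv-datasets-public',  # only bucket supported in SDV Community
    )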

File tree

3 files changed: +259 −62 lines changed


sdv/datasets/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1 +1,5 @@
 """Dataset loading and managing module."""
+
+from sdv.datasets import demo, local
+
+__all__ = ['demo', 'local']
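
Since the new `__init__.py` imports both submodules eagerly, they become reachable directly from the package. A quick sketch of the effect:

    from sdv import datasets

    # Both submodules are exposed and listed in __all__ by this commit.
    print(datasets.__all__)  # ['demo', 'local']
    print(datasets.demo)     # <module 'sdv.datasets.demo' ...>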

sdv/datasets/demo.py

Lines changed: 112 additions & 32 deletions
@@ -20,8 +20,7 @@
 from sdv.metadata.metadata import Metadata

 LOGGER = logging.getLogger(__name__)
-BUCKET = 'sdv-datasets-public'
-BUCKET_URL = f'https://{BUCKET}.s3.amazonaws.com'
+PUBLIC_BUCKET = 'sdv-datasets-public'
 SIGNATURE_VERSION = UNSIGNED
 METADATA_FILENAME = 'metadata.json'
 FALLBACK_ENCODING = 'latin-1'
@@ -41,36 +40,45 @@ def _validate_output_folder(output_folder_name):
     )


-def _create_s3_client():
+def _create_s3_client(bucket, credentials=None):
     """Create and return an S3 client with unsigned requests."""
+    if bucket != PUBLIC_BUCKET:
+        raise ValueError('Private buckets are only supported in SDV Enterprise.')
+    if credentials is not None:
+        raise ValueError(
+            'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
+        )
+
     return boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))


-def _get_data_from_bucket(object_key):
-    s3 = _create_s3_client()
-    response = s3.get_object(Bucket=BUCKET, Key=object_key)
+def _get_data_from_bucket(object_key, bucket, client):
+    response = client.get_object(Bucket=bucket, Key=object_key)
     return response['Body'].read()


-def _list_objects(prefix):
+def _list_objects(prefix, bucket, client):
     """List all objects under a given prefix using pagination.

     Args:
         prefix (str):
             The S3 prefix to list.
+        bucket (str):
+            The name of the bucket to list objects of.
+        client (botocore.client.S3):
+            S3 client.

     Returns:
         list[dict]:
             A list of object summaries.
     """
-    client = _create_s3_client()
     contents = []
     paginator = client.get_paginator('list_objects_v2')
-    for resp in paginator.paginate(Bucket=BUCKET, Prefix=prefix):
+    for resp in paginator.paginate(Bucket=bucket, Prefix=prefix):
         contents.extend(resp.get('Contents', []))

     if not contents:
-        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{BUCKET}'.")
+        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{bucket}'.")

     return contents
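
The guard added to `_create_s3_client` is what enforces the Community/Enterprise split: any bucket other than the public one, or any credentials at all, raises a `ValueError` before a client is created. A sketch of that behavior in SDV Community (the private bucket name is hypothetical):

    from sdv.datasets.demo import download_demo

    # Requesting a non-public bucket fails fast in SDV Community.
    try:
        download_demo('single_table', 'fake_hotel_guests', s3_bucket_name='my-private-bucket')
    except ValueError as error:
        print(error)  # Private buckets are only supported in SDV Enterprise.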

@@ -125,7 +133,7 @@ def is_data_zip(key):
     raise DemoResourceNotFoundError("Could not find 'data.zip' for the requested dataset.")


-def _get_first_v1_metadata_bytes(contents, dataset_prefix):
+def _get_first_v1_metadata_bytes(contents, dataset_prefix, bucket, client):
     """Find and return bytes of the first V1 metadata JSON under `dataset_prefix`.

     Scans S3 listing `contents` and, for any JSON file directly under the dataset prefix,
@@ -150,7 +158,7 @@ def is_direct_json_under_prefix(key):

     for key in candidate_keys:
         try:
-            raw = _get_data_from_bucket(key)
+            raw = _get_data_from_bucket(key, bucket=bucket, client=client)
             metadict = json.loads(raw)
             if isinstance(metadict, dict) and metadict.get('METADATA_SPEC_VERSION') == 'V1':
                 return raw
@@ -163,23 +171,26 @@ def is_direct_json_under_prefix(key):
     )


-def _download(modality, dataset_name):
+def _download(modality, dataset_name, bucket, credentials=None):
     """Download dataset resources from a bucket.

     Returns:
         tuple:
             (BytesIO(zip_bytes), metadata_bytes)
     """
+    client = _create_s3_client(bucket=bucket, credentials=credentials)
     dataset_prefix = f'{modality}/{dataset_name}/'
+    bucket_url = f'https://{bucket}.s3.amazonaws.com'
     LOGGER.info(
         f"Downloading dataset '{dataset_name}' for modality '{modality}' from "
-        f'{BUCKET_URL}/{dataset_prefix}'
+        f'{bucket_url}/{dataset_prefix}'
     )
-    contents = _list_objects(dataset_prefix)
-
+    contents = _list_objects(dataset_prefix, bucket=bucket, client=client)
     zip_key = _find_data_zip_key(contents, dataset_prefix)
-    zip_bytes = _get_data_from_bucket(zip_key)
-    metadata_bytes = _get_first_v1_metadata_bytes(contents, dataset_prefix)
+    zip_bytes = _get_data_from_bucket(zip_key, bucket=bucket, client=client)
+    metadata_bytes = _get_first_v1_metadata_bytes(
+        contents, dataset_prefix, bucket=bucket, client=client
+    )

     return io.BytesIO(zip_bytes), metadata_bytes

@@ -310,7 +321,9 @@ def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
     return metadata


-def download_demo(modality, dataset_name, output_folder_name=None):
+def download_demo(
+    modality, dataset_name, output_folder_name=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
     """Download a demo dataset.

     Args:
@@ -322,6 +335,15 @@ def download_demo(modality, dataset_name, output_folder_name=None):
             The name of the local folder where the metadata and data should be stored.
             If ``None`` the data is not saved locally and is loaded as a Python object.
             Defaults to ``None``.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': '[email protected]',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         tuple (data, metadata):
@@ -338,7 +360,7 @@ def download_demo(modality, dataset_name, output_folder_name=None):
     _validate_modalities(modality)
     _validate_output_folder(output_folder_name)

-    data_io, metadata_bytes = _download(modality, dataset_name)
+    data_io, metadata_bytes = _download(modality, dataset_name, s3_bucket_name, credentials)
     in_memory_directory = _extract_data(data_io, output_folder_name)
     data = _get_data(modality, output_folder_name, in_memory_directory)
     metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
@@ -368,9 +390,9 @@ def is_metainfo_yaml(key):
         yield dataset_name, key


-def _get_info_from_yaml_key(yaml_key):
+def _get_info_from_yaml_key(yaml_key, bucket, client):
     """Load and parse YAML metadata from an S3 key."""
-    raw = _get_data_from_bucket(yaml_key)
+    raw = _get_data_from_bucket(yaml_key, bucket=bucket, client=client)
     return yaml.safe_load(raw) or {}

@@ -406,12 +428,21 @@ def _parse_num_tables(num_tables_val, dataset_name):
     return np.nan


-def get_available_demos(modality):
+def get_available_demos(modality, s3_bucket_name=PUBLIC_BUCKET, credentials=None):
     """Get demo datasets available for a ``modality``.

     Args:
         modality (str):
             The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': '[email protected]',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         pandas.DataFrame:
@@ -421,11 +452,12 @@ def get_available_demos(modality):
             ``num_tables``: The number of tables in the dataset.
     """
     _validate_modalities(modality)
-    contents = _list_objects(f'{modality}/')
+    s3_client = _create_s3_client(bucket=s3_bucket_name, credentials=credentials)
+    contents = _list_objects(f'{modality}/', bucket=s3_bucket_name, client=s3_client)
     tables_info = defaultdict(list)
     for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality):
         try:
-            info = _get_info_from_yaml_key(yaml_key)
+            info = _get_info_from_yaml_key(yaml_key, bucket=s3_bucket_name, client=s3_client)

             size_mb = _parse_size_mb(info.get('dataset-size-mb'), dataset_name)
             num_tables = _parse_num_tables(info.get('num-tables', np.nan), dataset_name)
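
A hedged usage sketch for the updated `get_available_demos`; the exact DataFrame columns beyond `num_tables` are not visible in this diff:

    from sdv.datasets.demo import get_available_demos

    # Defaults keep the old behavior: list demo datasets from the public bucket.
    demos = get_available_demos(modality='single_table')
    print(demos.head())  # includes a num_tables column per the docstring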
@@ -513,7 +545,9 @@ def _save_document(text, output_filepath, filename, dataset_name):
     LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')


-def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
+def _get_text_file_content(
+    modality, dataset_name, filename, output_filepath=None, bucket=PUBLIC_BUCKET, credentials=None
+):
     """Fetch text file content under the dataset prefix.

     Args:
@@ -525,6 +559,15 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=Non
             The filename to fetch (``'README.txt'`` or ``'SOURCE.txt'``).
         output_filepath (str or None):
             If provided, save the file contents at this path.
+        bucket (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': '[email protected]',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         str or None:
@@ -533,14 +576,15 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=Non
     _validate_text_file_content(modality, output_filepath, filename)

     dataset_prefix = f'{modality}/{dataset_name}/'
-    contents = _list_objects(dataset_prefix)
+    s3_client = _create_s3_client(bucket=bucket, credentials=credentials)
+    contents = _list_objects(dataset_prefix, bucket=bucket, client=s3_client)
     key = _find_text_key(contents, dataset_prefix, filename)
     if not key:
         _raise_warnings(filename, output_filepath)
         return None

     try:
-        raw = _get_data_from_bucket(key)
+        raw = _get_data_from_bucket(key, bucket=bucket, client=s3_client)
     except Exception:
         LOGGER.info(f'Error fetching {filename} for dataset {dataset_name}.')
         return None
@@ -551,7 +595,9 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=Non
     return text


-def get_source(modality, dataset_name, output_filepath=None):
+def get_source(
+    modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
     """Get dataset source/citation text.

     Args:
@@ -561,15 +607,33 @@ def get_source(modality, dataset_name, output_filepath=None):
             The name of the dataset to get the source information for.
         output_filepath (str or None):
             Optional path where to save the file.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': '[email protected]',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         str or None:
             The contents of the source file if it exists; otherwise ``None``.
     """
-    return _get_text_file_content(modality, dataset_name, 'SOURCE.txt', output_filepath)
+    return _get_text_file_content(
+        modality=modality,
+        dataset_name=dataset_name,
+        filename='SOURCE.txt',
+        output_filepath=output_filepath,
+        bucket=s3_bucket_name,
+        credentials=credentials,
+    )


-def get_readme(modality, dataset_name, output_filepath=None):
+def get_readme(
+    modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
     """Get dataset README text.

     Args:
@@ -579,9 +643,25 @@ def get_readme(modality, dataset_name, output_filepath=None):
             The name of the dataset to get the README for.
         output_filepath (str or None):
             Optional path where to save the file.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': '[email protected]',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         str or None:
             The contents of the README file if it exists; otherwise ``None``.
     """
-    return _get_text_file_content(modality, dataset_name, 'README.txt', output_filepath)
+    return _get_text_file_content(
+        modality=modality,
+        dataset_name=dataset_name,
+        filename='README.txt',
+        output_filepath=output_filepath,
+        bucket=s3_bucket_name,
+        credentials=credentials,
+    )
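
Both text helpers now thread the bucket and credentials down through `_get_text_file_content`. A short sketch (dataset name illustrative):

    from sdv.datasets.demo import get_readme, get_source

    # Each returns the file contents as a string, or None if the file is missing.
    readme = get_readme('single_table', 'fake_hotel_guests')
    source = get_source('single_table', 'fake_hotel_guests', output_filepath='SOURCE.txt')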
