Skip to content

Commit 51b033e

Browse files
committed
Adding bucket to s3 client creation
1 parent 67d60db commit 51b033e

File tree

2 files changed: 21 additions (+21) and 27 deletions (-27).

sdv/datasets/demo.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,10 @@ def _validate_output_folder(output_folder_name):
4040
)
4141

4242

43-
def _validate_bucket(bucket):
44-
if bucket != PUBLIC_BUCKET:
45-
raise ValueError('Private buckets are only supported in SDV Enterprise.')
46-
47-
48-
def _create_s3_client(credentials=None):
43+
def _create_s3_client(credentials=None, bucket=None):
4944
"""Create and return an S3 client with unsigned requests."""
45+
if bucket is not None and bucket != PUBLIC_BUCKET:
46+
raise ValueError('Private buckets are only supported in SDV Enterprise.')
5047
if credentials is not None:
5148
raise ValueError(
5249
'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
@@ -181,7 +178,7 @@ def _download(modality, dataset_name, bucket, credentials=None):
181178
tuple:
182179
(BytesIO(zip_bytes), metadata_bytes)
183180
"""
184-
client = _create_s3_client(credentials)
181+
client = _create_s3_client(credentials, bucket)
185182
dataset_prefix = f'{modality}/{dataset_name}/'
186183
bucket_url = f'https://{bucket}.s3.amazonaws.com'
187184
LOGGER.info(
@@ -325,7 +322,7 @@ def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
325322

326323

327324
def download_demo(
328-
modality, dataset_name, output_folder_name=None, bucket=PUBLIC_BUCKET, credentials=None
325+
modality, dataset_name, output_folder_name=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
329326
):
330327
"""Download a demo dataset.
331328
@@ -338,7 +335,7 @@ def download_demo(
338335
The name of the local folder where the metadata and data should be stored.
339336
If ``None`` the data is not saved locally and is loaded as a Python object.
340337
Defaults to ``None``.
341-
bucket (str):
338+
s3_bucket_name (str):
342339
The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
343340
SDV Community. SDV Enterprise is required for other buckets.
344341
credentials (str):
@@ -362,9 +359,8 @@ def download_demo(
362359
"""
363360
_validate_modalities(modality)
364361
_validate_output_folder(output_folder_name)
365-
_validate_bucket(bucket=bucket)
366362

367-
data_io, metadata_bytes = _download(modality, dataset_name, bucket, credentials)
363+
data_io, metadata_bytes = _download(modality, dataset_name, s3_bucket_name, credentials)
368364
in_memory_directory = _extract_data(data_io, output_folder_name)
369365
data = _get_data(modality, output_folder_name, in_memory_directory)
370366
metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
@@ -432,13 +428,13 @@ def _parse_num_tables(num_tables_val, dataset_name):
432428
return np.nan
433429

434430

435-
def get_available_demos(modality, bucket=PUBLIC_BUCKET, credentials=None):
431+
def get_available_demos(modality, s3_bucket_name=PUBLIC_BUCKET, credentials=None):
436432
"""Get demo datasets available for a ``modality``.
437433
438434
Args:
439435
modality (str):
440436
The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
441-
bucket (str):
437+
s3_bucket_name (str):
442438
The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
443439
SDV Community. SDV Enterprise is required for other buckets.
444440
credentials (str):
@@ -456,13 +452,12 @@ def get_available_demos(modality, bucket=PUBLIC_BUCKET, credentials=None):
456452
``num_tables``: The number of tables in the dataset.
457453
"""
458454
_validate_modalities(modality)
459-
_validate_bucket(bucket=bucket)
460-
s3_client = _create_s3_client(credentials=credentials)
461-
contents = _list_objects(f'{modality}/', bucket=bucket, client=s3_client)
455+
s3_client = _create_s3_client(credentials=credentials, bucket=s3_bucket_name)
456+
contents = _list_objects(f'{modality}/', bucket=s3_bucket_name, client=s3_client)
462457
tables_info = defaultdict(list)
463458
for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality):
464459
try:
465-
info = _get_info_from_yaml_key(yaml_key, bucket=bucket, client=s3_client)
460+
info = _get_info_from_yaml_key(yaml_key, bucket=s3_bucket_name, client=s3_client)
466461

467462
size_mb = _parse_size_mb(info.get('dataset-size-mb'), dataset_name)
468463
num_tables = _parse_num_tables(info.get('num-tables', np.nan), dataset_name)
@@ -579,10 +574,9 @@ def _get_text_file_content(
579574
The decoded text contents if the file exists, otherwise ``None``.
580575
"""
581576
_validate_text_file_content(modality, output_filepath, filename)
582-
_validate_bucket(bucket=bucket)
583577

584578
dataset_prefix = f'{modality}/{dataset_name}/'
585-
s3_client = _create_s3_client(credentials=credentials)
579+
s3_client = _create_s3_client(credentials=credentials, bucket=bucket)
586580
contents = _list_objects(dataset_prefix, bucket=bucket, client=s3_client)
587581
key = _find_text_key(contents, dataset_prefix, filename)
588582
if not key:
@@ -602,7 +596,7 @@ def _get_text_file_content(
602596

603597

604598
def get_source(
605-
modality, dataset_name, output_filepath=None, bucket=PUBLIC_BUCKET, credentials=None
599+
modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
606600
):
607601
"""Get dataset source/citation text.
608602
@@ -613,7 +607,7 @@ def get_source(
613607
The name of the dataset to get the source information for.
614608
output_filepath (str or None):
615609
Optional path where to save the file.
616-
bucket (str):
610+
s3_bucket_name (str):
617611
The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
618612
SDV Community. SDV Enterprise is required for other buckets.
619613
credentials (str):
@@ -632,13 +626,13 @@ def get_source(
632626
dataset_name=dataset_name,
633627
filename='SOURCE.txt',
634628
output_filepath=output_filepath,
635-
bucket=bucket,
629+
bucket=s3_bucket_name,
636630
credentials=credentials,
637631
)
638632

639633

640634
def get_readme(
641-
modality, dataset_name, output_filepath=None, bucket=PUBLIC_BUCKET, credentials=None
635+
modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
642636
):
643637
"""Get dataset README text.
644638
@@ -649,7 +643,7 @@ def get_readme(
649643
The name of the dataset to get the README for.
650644
output_filepath (str or None):
651645
Optional path where to save the file.
652-
bucket (str):
646+
s3_bucket_name (str):
653647
The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
654648
SDV Community. SDV Enterprise is required for other buckets.
655649
credentials (str):
@@ -668,6 +662,6 @@ def get_readme(
668662
dataset_name=dataset_name,
669663
filename='README.txt',
670664
output_filepath=output_filepath,
671-
bucket=bucket,
665+
bucket=s3_bucket_name,
672666
credentials=credentials,
673667
)

tests/unit/datasets/test_demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test__download(mock_list, mock_get_data_from_bucket):
133133

134134
# Run
135135
data_io, metadata_bytes = _download(
136-
'single_table', 'ring', bucket='test_bucket', credentials=None
136+
'single_table', 'ring', bucket='sdv-datasets-public', credentials=None
137137
)
138138

139139
# Assert
@@ -561,7 +561,7 @@ def test_get_available_demos_credentials_raises_error():
561561
with pytest.raises(ValueError, match=error_message):
562562
get_available_demos(
563563
'single_table',
564-
bucket='sdv-datasets-public',
564+
s3_bucket_name='sdv-datasets-public',
565565
credentials={'username': 'user@example.com', 'license_key': 'FakeKey123'},
566566
)
567567

0 commit comments

Comments
 (0)