@@ -40,13 +40,10 @@ def _validate_output_folder(output_folder_name):
4040 )
4141
4242
43- def _validate_bucket (bucket ):
44- if bucket != PUBLIC_BUCKET :
45- raise ValueError ('Private buckets are only supported in SDV Enterprise.' )
46-
47-
48- def _create_s3_client (credentials = None ):
43+ def _create_s3_client (credentials = None , bucket = None ):
4944 """Create and return an S3 client with unsigned requests."""
45+ if bucket is not None and bucket != PUBLIC_BUCKET :
46+ raise ValueError ('Private buckets are only supported in SDV Enterprise.' )
5047 if credentials is not None :
5148 raise ValueError (
5249 'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
@@ -181,7 +178,7 @@ def _download(modality, dataset_name, bucket, credentials=None):
181178 tuple:
182179 (BytesIO(zip_bytes), metadata_bytes)
183180 """
184- client = _create_s3_client (credentials )
181+ client = _create_s3_client (credentials , bucket )
185182 dataset_prefix = f'{ modality } /{ dataset_name } /'
186183 bucket_url = f'https://{ bucket } .s3.amazonaws.com'
187184 LOGGER .info (
@@ -325,7 +322,7 @@ def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
325322
326323
327324def download_demo (
328- modality , dataset_name , output_folder_name = None , bucket = PUBLIC_BUCKET , credentials = None
325+ modality , dataset_name , output_folder_name = None , s3_bucket_name = PUBLIC_BUCKET , credentials = None
329326):
330327 """Download a demo dataset.
331328
@@ -338,7 +335,7 @@ def download_demo(
338335 The name of the local folder where the metadata and data should be stored.
339336 If ``None`` the data is not saved locally and is loaded as a Python object.
340337 Defaults to ``None``.
341- bucket (str):
338+ s3_bucket_name (str):
342339 The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
343340 SDV Community. SDV Enterprise is required for other buckets.
344341 credentials (str):
@@ -362,9 +359,8 @@ def download_demo(
362359 """
363360 _validate_modalities (modality )
364361 _validate_output_folder (output_folder_name )
365- _validate_bucket (bucket = bucket )
366362
367- data_io , metadata_bytes = _download (modality , dataset_name , bucket , credentials )
363+ data_io , metadata_bytes = _download (modality , dataset_name , s3_bucket_name , credentials )
368364 in_memory_directory = _extract_data (data_io , output_folder_name )
369365 data = _get_data (modality , output_folder_name , in_memory_directory )
370366 metadata = _get_metadata (metadata_bytes , dataset_name , output_folder_name )
@@ -432,13 +428,13 @@ def _parse_num_tables(num_tables_val, dataset_name):
432428 return np .nan
433429
434430
435- def get_available_demos (modality , bucket = PUBLIC_BUCKET , credentials = None ):
431+ def get_available_demos (modality , s3_bucket_name = PUBLIC_BUCKET , credentials = None ):
436432 """Get demo datasets available for a ``modality``.
437433
438434 Args:
439435 modality (str):
440436 The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
441- bucket (str):
437+ s3_bucket_name (str):
442438 The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
443439 SDV Community. SDV Enterprise is required for other buckets.
444440 credentials (str):
@@ -456,13 +452,12 @@ def get_available_demos(modality, bucket=PUBLIC_BUCKET, credentials=None):
456452 ``num_tables``: The number of tables in the dataset.
457453 """
458454 _validate_modalities (modality )
459- _validate_bucket (bucket = bucket )
460- s3_client = _create_s3_client (credentials = credentials )
461- contents = _list_objects (f'{ modality } /' , bucket = bucket , client = s3_client )
455+ s3_client = _create_s3_client (credentials = credentials , bucket = s3_bucket_name )
456+ contents = _list_objects (f'{ modality } /' , bucket = s3_bucket_name , client = s3_client )
462457 tables_info = defaultdict (list )
463458 for dataset_name , yaml_key in _iter_metainfo_yaml_entries (contents , modality ):
464459 try :
465- info = _get_info_from_yaml_key (yaml_key , bucket = bucket , client = s3_client )
460+ info = _get_info_from_yaml_key (yaml_key , bucket = s3_bucket_name , client = s3_client )
466461
467462 size_mb = _parse_size_mb (info .get ('dataset-size-mb' ), dataset_name )
468463 num_tables = _parse_num_tables (info .get ('num-tables' , np .nan ), dataset_name )
@@ -579,10 +574,9 @@ def _get_text_file_content(
579574 The decoded text contents if the file exists, otherwise ``None``.
580575 """
581576 _validate_text_file_content (modality , output_filepath , filename )
582- _validate_bucket (bucket = bucket )
583577
584578 dataset_prefix = f'{ modality } /{ dataset_name } /'
585- s3_client = _create_s3_client (credentials = credentials )
579+ s3_client = _create_s3_client (credentials = credentials , bucket = bucket )
586580 contents = _list_objects (dataset_prefix , bucket = bucket , client = s3_client )
587581 key = _find_text_key (contents , dataset_prefix , filename )
588582 if not key :
@@ -602,7 +596,7 @@ def _get_text_file_content(
602596
603597
604598def get_source (
605- modality , dataset_name , output_filepath = None , bucket = PUBLIC_BUCKET , credentials = None
599+ modality , dataset_name , output_filepath = None , s3_bucket_name = PUBLIC_BUCKET , credentials = None
606600):
607601 """Get dataset source/citation text.
608602
@@ -613,7 +607,7 @@ def get_source(
613607 The name of the dataset to get the source information for.
614608 output_filepath (str or None):
615609 Optional path where to save the file.
616- bucket (str):
610+ s3_bucket_name (str):
617611 The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
618612 SDV Community. SDV Enterprise is required for other buckets.
619613 credentials (str):
@@ -632,13 +626,13 @@ def get_source(
632626 dataset_name = dataset_name ,
633627 filename = 'SOURCE.txt' ,
634628 output_filepath = output_filepath ,
635- bucket = bucket ,
629+ bucket = s3_bucket_name ,
636630 credentials = credentials ,
637631 )
638632
639633
640634def get_readme (
641- modality , dataset_name , output_filepath = None , bucket = PUBLIC_BUCKET , credentials = None
635+ modality , dataset_name , output_filepath = None , s3_bucket_name = PUBLIC_BUCKET , credentials = None
642636):
643637 """Get dataset README text.
644638
@@ -649,7 +643,7 @@ def get_readme(
649643 The name of the dataset to get the README for.
650644 output_filepath (str or None):
651645 Optional path where to save the file.
652- bucket (str):
646+ s3_bucket_name (str):
653647 The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
654648 SDV Community. SDV Enterprise is required for other buckets.
655649 credentials (str):
@@ -668,6 +662,6 @@ def get_readme(
668662 dataset_name = dataset_name ,
669663 filename = 'README.txt' ,
670664 output_filepath = output_filepath ,
671- bucket = bucket ,
665+ bucket = s3_bucket_name ,
672666 credentials = credentials ,
673667 )
0 commit comments