 from sdv.metadata.metadata import Metadata

 LOGGER = logging.getLogger(__name__)
-BUCKET = 'sdv-datasets-public'
-BUCKET_URL = f'https://{BUCKET}.s3.amazonaws.com'
+PUBLIC_BUCKET = 'sdv-datasets-public'
 SIGNATURE_VERSION = UNSIGNED
 METADATA_FILENAME = 'metadata.json'
 FALLBACK_ENCODING = 'latin-1'
@@ -41,36 +40,45 @@ def _validate_output_folder(output_folder_name):
     )


-def _create_s3_client():
+def _create_s3_client(bucket, credentials=None):
     """Create and return an S3 client with unsigned requests."""
+    if bucket != PUBLIC_BUCKET:
+        raise ValueError('Private buckets are only supported in SDV Enterprise.')
+    if credentials is not None:
+        raise ValueError(
+            'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
+        )
+
     return boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))


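For reference, the unsigned-request pattern this helper wraps is standard botocore usage; a minimal standalone sketch, with the object key purely illustrative:

    import boto3
    from botocore import UNSIGNED
    from botocore.config import Config

    # Anonymous S3 client: no AWS credentials are required or sent.
    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
    # Fetch a public object (hypothetical key shown for illustration).
    obj = client.get_object(Bucket='sdv-datasets-public', Key='single_table/example/data.zip')
    body = obj['Body'].read()
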
-def _get_data_from_bucket(object_key):
-    s3 = _create_s3_client()
-    response = s3.get_object(Bucket=BUCKET, Key=object_key)
+def _get_data_from_bucket(object_key, bucket, client):
+    response = client.get_object(Bucket=bucket, Key=object_key)
     return response['Body'].read()


-def _list_objects(prefix):
+def _list_objects(prefix, bucket, client):
     """List all objects under a given prefix using pagination.

     Args:
         prefix (str):
             The S3 prefix to list.
+        bucket (str):
+            The name of the bucket to list objects from.
+        client (botocore.client.S3):
+            S3 client.

     Returns:
         list[dict]:
             A list of object summaries.
     """
-    client = _create_s3_client()
     contents = []
     paginator = client.get_paginator('list_objects_v2')
-    for resp in paginator.paginate(Bucket=BUCKET, Prefix=prefix):
+    for resp in paginator.paginate(Bucket=bucket, Prefix=prefix):
         contents.extend(resp.get('Contents', []))

     if not contents:
-        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{BUCKET}'.")
+        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{bucket}'.")

     return contents

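The loop above relies on the standard boto3 paginator idiom, which matters for buckets holding more than 1000 keys; the same pattern in isolation, with an illustrative prefix:

    # Reusing an unsigned client as in the earlier sketch.
    client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
    # list_objects_v2 returns at most 1000 keys per call; the paginator
    # follows continuation tokens so every page is collected.
    paginator = client.get_paginator('list_objects_v2')
    contents = []
    for page in paginator.paginate(Bucket='sdv-datasets-public', Prefix='single_table/'):
        contents.extend(page.get('Contents', []))
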
@@ -125,7 +133,7 @@ def is_data_zip(key):
     raise DemoResourceNotFoundError("Could not find 'data.zip' for the requested dataset.")


-def _get_first_v1_metadata_bytes(contents, dataset_prefix):
+def _get_first_v1_metadata_bytes(contents, dataset_prefix, bucket, client):
     """Find and return bytes of the first V1 metadata JSON under `dataset_prefix`.

     Scans S3 listing `contents` and, for any JSON file directly under the dataset prefix,
@@ -150,7 +158,7 @@ def is_direct_json_under_prefix(key):

     for key in candidate_keys:
         try:
-            raw = _get_data_from_bucket(key)
+            raw = _get_data_from_bucket(key, bucket=bucket, client=client)
             metadict = json.loads(raw)
             if isinstance(metadict, dict) and metadict.get('METADATA_SPEC_VERSION') == 'V1':
                 return raw
@@ -163,23 +171,26 @@ def is_direct_json_under_prefix(key):
     )


-def _download(modality, dataset_name):
+def _download(modality, dataset_name, bucket, credentials=None):
     """Download dataset resources from a bucket.

     Returns:
         tuple:
             (BytesIO(zip_bytes), metadata_bytes)
     """
+    client = _create_s3_client(bucket=bucket, credentials=credentials)
     dataset_prefix = f'{modality}/{dataset_name}/'
+    bucket_url = f'https://{bucket}.s3.amazonaws.com'
     LOGGER.info(
         f"Downloading dataset '{dataset_name}' for modality '{modality}' from "
-        f'{BUCKET_URL}/{dataset_prefix}'
+        f'{bucket_url}/{dataset_prefix}'
     )
-    contents = _list_objects(dataset_prefix)
-
+    contents = _list_objects(dataset_prefix, bucket=bucket, client=client)
     zip_key = _find_data_zip_key(contents, dataset_prefix)
-    zip_bytes = _get_data_from_bucket(zip_key)
-    metadata_bytes = _get_first_v1_metadata_bytes(contents, dataset_prefix)
+    zip_bytes = _get_data_from_bucket(zip_key, bucket=bucket, client=client)
+    metadata_bytes = _get_first_v1_metadata_bytes(
+        contents, dataset_prefix, bucket=bucket, client=client
+    )

     return io.BytesIO(zip_bytes), metadata_bytes

@@ -310,7 +321,9 @@ def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
     return metadata


-def download_demo(modality, dataset_name, output_folder_name=None):
+def download_demo(
+    modality, dataset_name, output_folder_name=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
     """Download a demo dataset.

     Args:
@@ -322,6 +335,15 @@ def download_demo(modality, dataset_name, output_folder_name=None):
             The name of the local folder where the metadata and data should be stored.
             If ``None`` the data is not saved locally and is loaded as a Python object.
             Defaults to ``None``.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing the DataCebo username and license key. It takes the form:
+            {
+                'username': '<MY_USERNAME>',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         tuple (data, metadata):
@@ -338,7 +360,7 @@ def download_demo(modality, dataset_name, output_folder_name=None):
     _validate_modalities(modality)
     _validate_output_folder(output_folder_name)

-    data_io, metadata_bytes = _download(modality, dataset_name)
+    data_io, metadata_bytes = _download(modality, dataset_name, s3_bucket_name, credentials)
     in_memory_directory = _extract_data(data_io, output_folder_name)
     data = _get_data(modality, output_folder_name, in_memory_directory)
     metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
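A usage sketch of the extended public API; the new keyword arguments default to the public bucket and no credentials, so existing SDV Community calls behave exactly as before (dataset name illustrative):

    from sdv.datasets.demo import download_demo

    # Equivalent to the pre-change call: s3_bucket_name defaults to
    # 'sdv-datasets-public' and credentials defaults to None.
    data, metadata = download_demo(
        modality='single_table',
        dataset_name='fake_hotel_guests',
    )
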
@@ -368,9 +390,9 @@ def is_metainfo_yaml(key):
         yield dataset_name, key


-def _get_info_from_yaml_key(yaml_key):
+def _get_info_from_yaml_key(yaml_key, bucket, client):
     """Load and parse YAML metadata from an S3 key."""
-    raw = _get_data_from_bucket(yaml_key)
+    raw = _get_data_from_bucket(yaml_key, bucket=bucket, client=client)
     return yaml.safe_load(raw) or {}


@@ -406,12 +428,21 @@ def _parse_num_tables(num_tables_val, dataset_name):
         return np.nan


-def get_available_demos(modality):
+def get_available_demos(modality, s3_bucket_name=PUBLIC_BUCKET, credentials=None):
     """Get demo datasets available for a ``modality``.

     Args:
         modality (str):
             The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing the DataCebo username and license key. It takes the form:
+            {
+                'username': '<MY_USERNAME>',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         pandas.DataFrame:
@@ -421,11 +452,12 @@ def get_available_demos(modality):
             ``num_tables``: The number of tables in the dataset.
     """
     _validate_modalities(modality)
-    contents = _list_objects(f'{modality}/')
+    s3_client = _create_s3_client(bucket=s3_bucket_name, credentials=credentials)
+    contents = _list_objects(f'{modality}/', bucket=s3_bucket_name, client=s3_client)
     tables_info = defaultdict(list)
     for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality):
         try:
-            info = _get_info_from_yaml_key(yaml_key)
+            info = _get_info_from_yaml_key(yaml_key, bucket=s3_bucket_name, client=s3_client)

             size_mb = _parse_size_mb(info.get('dataset-size-mb'), dataset_name)
             num_tables = _parse_num_tables(info.get('num-tables', np.nan), dataset_name)
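Likewise for get_available_demos, a minimal sketch with the new keywords left at their defaults; passing any other bucket or credentials raises the ValueError from _create_s3_client in SDV Community:

    from sdv.datasets.demo import get_available_demos

    # Returns a pandas.DataFrame describing the public demo datasets.
    demos = get_available_demos(modality='single_table')
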
@@ -513,7 +545,9 @@ def _save_document(text, output_filepath, filename, dataset_name):
         LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')


-def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
+def _get_text_file_content(
+    modality, dataset_name, filename, output_filepath=None, bucket=PUBLIC_BUCKET, credentials=None
+):
     """Fetch text file content under the dataset prefix.

     Args:
@@ -525,6 +559,15 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
             The filename to fetch (``'README.txt'`` or ``'SOURCE.txt'``).
         output_filepath (str or None):
             If provided, save the file contents at this path.
+        bucket (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing the DataCebo username and license key. It takes the form:
+            {
+                'username': '<MY_USERNAME>',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         str or None:
@@ -533,14 +576,15 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
     _validate_text_file_content(modality, output_filepath, filename)

     dataset_prefix = f'{modality}/{dataset_name}/'
-    contents = _list_objects(dataset_prefix)
+    s3_client = _create_s3_client(bucket=bucket, credentials=credentials)
+    contents = _list_objects(dataset_prefix, bucket=bucket, client=s3_client)
     key = _find_text_key(contents, dataset_prefix, filename)
     if not key:
         _raise_warnings(filename, output_filepath)
         return None

     try:
-        raw = _get_data_from_bucket(key)
+        raw = _get_data_from_bucket(key, bucket=bucket, client=s3_client)
     except Exception:
         LOGGER.info(f'Error fetching {filename} for dataset {dataset_name}.')
         return None
@@ -551,7 +595,9 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
     return text


-def get_source(modality, dataset_name, output_filepath=None):
+def get_source(
+    modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
     """Get dataset source/citation text.

     Args:
@@ -561,15 +607,33 @@ def get_source(modality, dataset_name, output_filepath=None):
             The name of the dataset to get the source information for.
         output_filepath (str or None):
             Optional path where to save the file.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing the DataCebo username and license key. It takes the form:
+            {
+                'username': '<MY_USERNAME>',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         str or None:
             The contents of the source file if it exists; otherwise ``None``.
     """
-    return _get_text_file_content(modality, dataset_name, 'SOURCE.txt', output_filepath)
+    return _get_text_file_content(
+        modality=modality,
+        dataset_name=dataset_name,
+        filename='SOURCE.txt',
+        output_filepath=output_filepath,
+        bucket=s3_bucket_name,
+        credentials=credentials,
+    )


-def get_readme(modality, dataset_name, output_filepath=None):
+def get_readme(
+    modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
     """Get dataset README text.

     Args:
@@ -579,9 +643,25 @@ def get_readme(modality, dataset_name, output_filepath=None):
             The name of the dataset to get the README for.
         output_filepath (str or None):
             Optional path where to save the file.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing the DataCebo username and license key. It takes the form:
+            {
+                'username': '<MY_USERNAME>',
+                'license_key': '<MY_LICENSE_KEY>'
+            }

     Returns:
         str or None:
             The contents of the README file if it exists; otherwise ``None``.
     """
-    return _get_text_file_content(modality, dataset_name, 'README.txt', output_filepath)
+    return _get_text_file_content(
+        modality=modality,
+        dataset_name=dataset_name,
+        filename='README.txt',
+        output_filepath=output_filepath,
+        bucket=s3_bucket_name,
+        credentials=credentials,
+    )
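
Finally, a sketch of the two text-file helpers with the new parameters defaulted; dataset name and output path are illustrative:

    from sdv.datasets.demo import get_readme, get_source

    # README text as a string (None if the dataset ships no README.txt).
    readme = get_readme('single_table', 'fake_hotel_guests')
    # SOURCE.txt contents, also saved locally via output_filepath.
    source = get_source('single_table', 'fake_hotel_guests', output_filepath='SOURCE.txt')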