4 changes: 4 additions & 0 deletions sdv/datasets/__init__.py
@@ -1 +1,5 @@
"""Dataset loading and managing module."""

from sdv.datasets import demo, local

__all__ = ['demo', 'local']
Contributor (author) commented:
this was needed so that I could load the add-ons on Enterprise
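
For illustration, the new exports can be sanity-checked with a minimal snippet (not part of the diff):

from sdv.datasets import demo, local

# Both submodules are now importable straight from the package,
# which is what the Enterprise add-on loading relies on.
print(demo.__name__)   # sdv.datasets.demo
print(local.__name__)  # sdv.datasets.local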

144 changes: 112 additions & 32 deletions sdv/datasets/demo.py
@@ -20,8 +20,7 @@
from sdv.metadata.metadata import Metadata

LOGGER = logging.getLogger(__name__)
-BUCKET = 'sdv-datasets-public'
-BUCKET_URL = f'https://{BUCKET}.s3.amazonaws.com'
+PUBLIC_BUCKET = 'sdv-datasets-public'
SIGNATURE_VERSION = UNSIGNED
METADATA_FILENAME = 'metadata.json'
FALLBACK_ENCODING = 'latin-1'
@@ -41,36 +40,45 @@ def _validate_output_folder(output_folder_name):
)


-def _create_s3_client():
+def _create_s3_client(credentials=None, bucket=None):
"""Create and return an S3 client with unsigned requests."""
+    if bucket is not None and bucket != PUBLIC_BUCKET:
+        raise ValueError('Private buckets are only supported in SDV Enterprise.')
+    if credentials is not None:
+        raise ValueError(
+            'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
+        )

return boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
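
For illustration, a hedged sketch of how these guards behave in SDV Community (the private bucket name below is hypothetical):

client = _create_s3_client()                      # unsigned client, public bucket
client = _create_s3_client(bucket=PUBLIC_BUCKET)  # explicit public bucket also allowed

try:
    _create_s3_client(bucket='my-private-bucket')  # hypothetical private bucket
except ValueError as error:
    print(error)  # Private buckets are only supported in SDV Enterprise.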


-def _get_data_from_bucket(object_key):
-    s3 = _create_s3_client()
-    response = s3.get_object(Bucket=BUCKET, Key=object_key)
+def _get_data_from_bucket(object_key, bucket, client):
+    response = client.get_object(Bucket=bucket, Key=object_key)
return response['Body'].read()


-def _list_objects(prefix):
+def _list_objects(prefix, bucket, client):
"""List all objects under a given prefix using pagination.

Args:
prefix (str):
The S3 prefix to list.
bucket (str):
The name of the bucket whose objects should be listed.
client (botocore.client.S3):
S3 client.

Returns:
list[dict]:
A list of object summaries.
"""
-    client = _create_s3_client()
contents = []
paginator = client.get_paginator('list_objects_v2')
-    for resp in paginator.paginate(Bucket=BUCKET, Prefix=prefix):
+    for resp in paginator.paginate(Bucket=bucket, Prefix=prefix):
contents.extend(resp.get('Contents', []))

if not contents:
-        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{BUCKET}'.")
+        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{bucket}'.")

return contents
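
For illustration, a short sketch of the new call shape, using only names defined in this module:

client = _create_s3_client()
# Raises DemoResourceNotFoundError if nothing exists under the prefix.
objects = _list_objects('single_table/', bucket=PUBLIC_BUCKET, client=client)
keys = [obj['Key'] for obj in objects]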

@@ -125,7 +133,7 @@ def is_data_zip(key):
raise DemoResourceNotFoundError("Could not find 'data.zip' for the requested dataset.")


-def _get_first_v1_metadata_bytes(contents, dataset_prefix):
+def _get_first_v1_metadata_bytes(contents, dataset_prefix, bucket, client):
"""Find and return bytes of the first V1 metadata JSON under `dataset_prefix`.

Scans S3 listing `contents` and, for any JSON file directly under the dataset prefix,
@@ -150,7 +158,7 @@ def is_direct_json_under_prefix(key):

for key in candidate_keys:
try:
-            raw = _get_data_from_bucket(key)
+            raw = _get_data_from_bucket(key, bucket=bucket, client=client)
metadict = json.loads(raw)
if isinstance(metadict, dict) and metadict.get('METADATA_SPEC_VERSION') == 'V1':
return raw
@@ -163,23 +171,26 @@
)


-def _download(modality, dataset_name):
+def _download(modality, dataset_name, bucket, credentials=None):
"""Download dataset resources from a bucket.

Returns:
tuple:
(BytesIO(zip_bytes), metadata_bytes)
"""
+    client = _create_s3_client(credentials, bucket)
    dataset_prefix = f'{modality}/{dataset_name}/'
+    bucket_url = f'https://{bucket}.s3.amazonaws.com'
LOGGER.info(
f"Downloading dataset '{dataset_name}' for modality '{modality}' from "
-        f'{BUCKET_URL}/{dataset_prefix}'
+        f'{bucket_url}/{dataset_prefix}'
)
-    contents = _list_objects(dataset_prefix)

+    contents = _list_objects(dataset_prefix, bucket=bucket, client=client)
zip_key = _find_data_zip_key(contents, dataset_prefix)
-    zip_bytes = _get_data_from_bucket(zip_key)
-    metadata_bytes = _get_first_v1_metadata_bytes(contents, dataset_prefix)
+    zip_bytes = _get_data_from_bucket(zip_key, bucket=bucket, client=client)
+    metadata_bytes = _get_first_v1_metadata_bytes(
+        contents, dataset_prefix, bucket=bucket, client=client
+    )

return io.BytesIO(zip_bytes), metadata_bytes

@@ -310,7 +321,9 @@ def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
return metadata


-def download_demo(modality, dataset_name, output_folder_name=None):
+def download_demo(
+    modality, dataset_name, output_folder_name=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
"""Download a demo dataset.

Args:
@@ -322,6 +335,15 @@ def download_demo(modality, dataset_name, output_folder_name=None):
The name of the local folder where the metadata and data should be stored.
If ``None`` the data is not saved locally and is loaded as a Python object.
Defaults to ``None``.
s3_bucket_name (str):
The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
SDV Community. SDV Enterprise is required for other buckets.
credentials (str):
Member commented: dict, not str. This happens in all docstrings.

Dictionary containing DataCebo license key and username. It takes the form:
{
'username': '[email protected]',
'license_key': '<MY_LICENSE_KEY>'
}

Returns:
tuple (data, metadata):
@@ -338,7 +360,7 @@ def download_demo(modality, dataset_name, output_folder_name=None):
_validate_modalities(modality)
_validate_output_folder(output_folder_name)

-    data_io, metadata_bytes = _download(modality, dataset_name)
+    data_io, metadata_bytes = _download(modality, dataset_name, s3_bucket_name, credentials)
in_memory_directory = _extract_data(data_io, output_folder_name)
data = _get_data(modality, output_folder_name, in_memory_directory)
metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
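
For illustration, a hedged usage sketch of the extended signature; the dataset name is illustrative and the new arguments are shown with their SDV Community defaults:

from sdv.datasets.demo import download_demo

data, metadata = download_demo(
    modality='single_table',
    dataset_name='fake_hotel_guests',      # illustrative dataset name
    s3_bucket_name='sdv-datasets-public',  # only supported value in SDV Community
    credentials=None,                      # anything else raises ValueError
)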
@@ -368,9 +390,9 @@ def is_metainfo_yaml(key):
yield dataset_name, key


-def _get_info_from_yaml_key(yaml_key):
+def _get_info_from_yaml_key(yaml_key, bucket, client):
    """Load and parse YAML metadata from an S3 key."""
-    raw = _get_data_from_bucket(yaml_key)
+    raw = _get_data_from_bucket(yaml_key, bucket=bucket, client=client)
return yaml.safe_load(raw) or {}


@@ -406,12 +428,21 @@ def _parse_num_tables(num_tables_val, dataset_name):
return np.nan


-def get_available_demos(modality):
+def get_available_demos(modality, s3_bucket_name=PUBLIC_BUCKET, credentials=None):
"""Get demo datasets available for a ``modality``.

Args:
modality (str):
The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
s3_bucket_name (str):
The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
SDV Community. SDV Enterprise is required for other buckets.
credentials (str):
Dictionary containing DataCebo license key and username. It takes the form:
{
'username': '[email protected]',
'license_key': '<MY_LICENSE_KEY>'
}

Returns:
pandas.DataFrame:
@@ -421,11 +452,12 @@
``num_tables``: The number of tables in the dataset.
"""
_validate_modalities(modality)
-    contents = _list_objects(f'{modality}/')
+    s3_client = _create_s3_client(credentials=credentials, bucket=s3_bucket_name)
+    contents = _list_objects(f'{modality}/', bucket=s3_bucket_name, client=s3_client)
tables_info = defaultdict(list)
for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality):
try:
-            info = _get_info_from_yaml_key(yaml_key)
+            info = _get_info_from_yaml_key(yaml_key, bucket=s3_bucket_name, client=s3_client)

size_mb = _parse_size_mb(info.get('dataset-size-mb'), dataset_name)
num_tables = _parse_num_tables(info.get('num-tables', np.nan), dataset_name)
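
For illustration, a minimal sketch of the updated entry point (public bucket by default):

from sdv.datasets.demo import get_available_demos

demos = get_available_demos('single_table')  # pandas.DataFrame of available demos
print(demos.head())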
@@ -513,7 +545,9 @@ def _save_document(text, output_filepath, filename, dataset_name):
LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')


-def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
+def _get_text_file_content(
+    modality, dataset_name, filename, output_filepath=None, bucket=PUBLIC_BUCKET, credentials=None
+):
"""Fetch text file content under the dataset prefix.

Args:
@@ -525,6 +559,15 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
The filename to fetch (``'README.txt'`` or ``'SOURCE.txt'``).
output_filepath (str or None):
If provided, save the file contents at this path.
bucket (str):
The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
SDV Community. SDV Enterprise is required for other buckets.
credentials (str):
Dictionary containing DataCebo license key and username. It takes the form:
{
'username': '[email protected]',
'license_key': '<MY_LICENSE_KEY>'
}

Returns:
str or None:
@@ -533,14 +576,15 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
_validate_text_file_content(modality, output_filepath, filename)

dataset_prefix = f'{modality}/{dataset_name}/'
-    contents = _list_objects(dataset_prefix)
+    s3_client = _create_s3_client(credentials=credentials, bucket=bucket)
+    contents = _list_objects(dataset_prefix, bucket=bucket, client=s3_client)
key = _find_text_key(contents, dataset_prefix, filename)
if not key:
_raise_warnings(filename, output_filepath)
return None

try:
-        raw = _get_data_from_bucket(key)
+        raw = _get_data_from_bucket(key, bucket=bucket, client=s3_client)
except Exception:
LOGGER.info(f'Error fetching {filename} for dataset {dataset_name}.')
return None
@@ -551,7 +595,9 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
return text


-def get_source(modality, dataset_name, output_filepath=None):
+def get_source(
+    modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
"""Get dataset source/citation text.

Args:
Expand All @@ -561,15 +607,33 @@ def get_source(modality, dataset_name, output_filepath=None):
The name of the dataset to get the source information for.
output_filepath (str or None):
Optional path where to save the file.
s3_bucket_name (str):
The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
SDV Community. SDV Enterprise is required for other buckets.
credentials (str):
Dictionary containing DataCebo license key and username. It takes the form:
{
'username': '[email protected]',
'license_key': '<MY_LICENSE_KEY>'
}

Returns:
str or None:
The contents of the source file if it exists; otherwise ``None``.
"""
-    return _get_text_file_content(modality, dataset_name, 'SOURCE.txt', output_filepath)
+    return _get_text_file_content(
+        modality=modality,
+        dataset_name=dataset_name,
+        filename='SOURCE.txt',
+        output_filepath=output_filepath,
+        bucket=s3_bucket_name,
+        credentials=credentials,
+    )


-def get_readme(modality, dataset_name, output_filepath=None):
+def get_readme(
+    modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
"""Get dataset README text.

Args:
Expand All @@ -579,9 +643,25 @@ def get_readme(modality, dataset_name, output_filepath=None):
The name of the dataset to get the README for.
output_filepath (str or None):
Optional path where to save the file.
s3_bucket_name (str):
The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
SDV Community. SDV Enterprise is required for other buckets.
credentials (str):
Dictionary containing DataCebo license key and username. It takes the form:
{
'username': '[email protected]',
'license_key': '<MY_LICENSE_KEY>'
}

Returns:
str or None:
The contents of the README file if it exists; otherwise ``None``.
"""
-    return _get_text_file_content(modality, dataset_name, 'README.txt', output_filepath)
+    return _get_text_file_content(
+        modality=modality,
+        dataset_name=dataset_name,
+        filename='README.txt',
+        output_filepath=output_filepath,
+        bucket=s3_bucket_name,
+        credentials=credentials,
+    )
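
For illustration, a combined sketch of the two text helpers above; the dataset name is illustrative:

from sdv.datasets.demo import get_readme, get_source

source_text = get_source('single_table', 'fake_hotel_guests')
readme_text = get_readme(
    'single_table',
    'fake_hotel_guests',
    output_filepath='README.txt',  # optionally saves the text locally
)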