diff --git a/sdv/datasets/__init__.py b/sdv/datasets/__init__.py
index cd82c4f8f..c0d4eecf3 100644
--- a/sdv/datasets/__init__.py
+++ b/sdv/datasets/__init__.py
@@ -1 +1,5 @@
 """Dataset loading and managing module."""
+
+from sdv.datasets import demo, local
+
+__all__ = ['demo', 'local']
diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py
index 6d8c49a81..ca8f041fb 100644
--- a/sdv/datasets/demo.py
+++ b/sdv/datasets/demo.py
@@ -20,8 +20,7 @@ from sdv.metadata.metadata import Metadata
 
 LOGGER = logging.getLogger(__name__)
 
-BUCKET = 'sdv-datasets-public'
-BUCKET_URL = f'https://{BUCKET}.s3.amazonaws.com'
+PUBLIC_BUCKET = 'sdv-datasets-public'
 SIGNATURE_VERSION = UNSIGNED
 METADATA_FILENAME = 'metadata.json'
 FALLBACK_ENCODING = 'latin-1'
@@ -41,36 +40,45 @@ def _validate_output_folder(output_folder_name):
     )
 
 
-def _create_s3_client():
+def _create_s3_client(credentials=None, bucket=None):
     """Create and return an S3 client with unsigned requests."""
+    if bucket is not None and bucket != PUBLIC_BUCKET:
+        raise ValueError('Private buckets are only supported in SDV Enterprise.')
+    if credentials is not None:
+        raise ValueError(
+            'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
+        )
+
     return boto3.client('s3', config=Config(signature_version=SIGNATURE_VERSION))
 
 
-def _get_data_from_bucket(object_key):
-    s3 = _create_s3_client()
-    response = s3.get_object(Bucket=BUCKET, Key=object_key)
+def _get_data_from_bucket(object_key, bucket, client):
+    response = client.get_object(Bucket=bucket, Key=object_key)
     return response['Body'].read()
 
 
-def _list_objects(prefix):
+def _list_objects(prefix, bucket, client):
     """List all objects under a given prefix using pagination.
 
     Args:
         prefix (str):
             The S3 prefix to list.
+        bucket (str):
+            The name of the bucket to list objects of.
+        client (botocore.client.S3):
+            S3 client.
 
     Returns:
         list[dict]:
             A list of object summaries.
     """
-    client = _create_s3_client()
     contents = []
     paginator = client.get_paginator('list_objects_v2')
-    for resp in paginator.paginate(Bucket=BUCKET, Prefix=prefix):
+    for resp in paginator.paginate(Bucket=bucket, Prefix=prefix):
         contents.extend(resp.get('Contents', []))
 
     if not contents:
-        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{BUCKET}'.")
+        raise DemoResourceNotFoundError(f"No objects found under '{prefix}' in bucket '{bucket}'.")
 
     return contents
@@ -125,7 +133,7 @@ is_data_zip(key):
     raise DemoResourceNotFoundError("Could not find 'data.zip' for the requested dataset.")
 
 
-def _get_first_v1_metadata_bytes(contents, dataset_prefix):
+def _get_first_v1_metadata_bytes(contents, dataset_prefix, bucket, client):
     """Find and return bytes of the first V1 metadata JSON under `dataset_prefix`.
 
     Scans S3 listing `contents` and, for any JSON file directly under the dataset prefix,
@@ -150,7 +158,7 @@ def is_direct_json_under_prefix(key):
 
     for key in candidate_keys:
         try:
-            raw = _get_data_from_bucket(key)
+            raw = _get_data_from_bucket(key, bucket=bucket, client=client)
             metadict = json.loads(raw)
             if isinstance(metadict, dict) and metadict.get('METADATA_SPEC_VERSION') == 'V1':
                 return raw
@@ -163,23 +171,26 @@ def is_direct_json_under_prefix(key):
     )
 
 
-def _download(modality, dataset_name):
+def _download(modality, dataset_name, bucket, credentials=None):
     """Download dataset resources from a bucket.
 
     Returns:
         tuple:
             (BytesIO(zip_bytes), metadata_bytes)
     """
+    client = _create_s3_client(credentials, bucket)
     dataset_prefix = f'{modality}/{dataset_name}/'
+    bucket_url = f'https://{bucket}.s3.amazonaws.com'
     LOGGER.info(
         f"Downloading dataset '{dataset_name}' for modality '{modality}' from "
-        f'{BUCKET_URL}/{dataset_prefix}'
+        f'{bucket_url}/{dataset_prefix}'
     )
-    contents = _list_objects(dataset_prefix)
-
+    contents = _list_objects(dataset_prefix, bucket=bucket, client=client)
     zip_key = _find_data_zip_key(contents, dataset_prefix)
-    zip_bytes = _get_data_from_bucket(zip_key)
-    metadata_bytes = _get_first_v1_metadata_bytes(contents, dataset_prefix)
+    zip_bytes = _get_data_from_bucket(zip_key, bucket=bucket, client=client)
+    metadata_bytes = _get_first_v1_metadata_bytes(
+        contents, dataset_prefix, bucket=bucket, client=client
+    )
 
     return io.BytesIO(zip_bytes), metadata_bytes
@@ -310,7 +321,9 @@ def _get_metadata(metadata_bytes, dataset_name, output_folder_name=None):
     return metadata
 
 
-def download_demo(modality, dataset_name, output_folder_name=None):
+def download_demo(
+    modality, dataset_name, output_folder_name=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
     """Download a demo dataset.
 
     Args:
@@ -322,6 +335,15 @@
             The name of the local folder where the metadata and data should be stored.
             If ``None`` the data is not saved locally and is loaded as a Python object.
             Defaults to ``None``.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': 'example@datacebo.com',
+                'license_key': ''
+            }
 
     Returns:
         tuple (data, metadata):
@@ -338,7 +360,7 @@
     _validate_modalities(modality)
     _validate_output_folder(output_folder_name)
 
-    data_io, metadata_bytes = _download(modality, dataset_name)
+    data_io, metadata_bytes = _download(modality, dataset_name, s3_bucket_name, credentials)
     in_memory_directory = _extract_data(data_io, output_folder_name)
     data = _get_data(modality, output_folder_name, in_memory_directory)
     metadata = _get_metadata(metadata_bytes, dataset_name, output_folder_name)
@@ -368,9 +390,9 @@ def is_metainfo_yaml(key):
             yield dataset_name, key
 
 
-def _get_info_from_yaml_key(yaml_key):
+def _get_info_from_yaml_key(yaml_key, bucket, client):
     """Load and parse YAML metadata from an S3 key."""
-    raw = _get_data_from_bucket(yaml_key)
+    raw = _get_data_from_bucket(yaml_key, bucket=bucket, client=client)
     return yaml.safe_load(raw) or {}
 
 
@@ -406,12 +428,21 @@ def _parse_num_tables(num_tables_val, dataset_name):
     return np.nan
 
 
-def get_available_demos(modality):
+def get_available_demos(modality, s3_bucket_name=PUBLIC_BUCKET, credentials=None):
     """Get demo datasets available for a ``modality``.
 
     Args:
         modality (str):
            The modality of the dataset: ``'single_table'``, ``'multi_table'``, ``'sequential'``.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': 'example@datacebo.com',
+                'license_key': ''
+            }
 
     Returns:
         pandas.DataFrame:
@@ -421,11 +452,12 @@
            ``num_tables``: The number of tables in the dataset.
     """
     _validate_modalities(modality)
-    contents = _list_objects(f'{modality}/')
+    s3_client = _create_s3_client(credentials=credentials, bucket=s3_bucket_name)
+    contents = _list_objects(f'{modality}/', bucket=s3_bucket_name, client=s3_client)
     tables_info = defaultdict(list)
     for dataset_name, yaml_key in _iter_metainfo_yaml_entries(contents, modality):
         try:
-            info = _get_info_from_yaml_key(yaml_key)
+            info = _get_info_from_yaml_key(yaml_key, bucket=s3_bucket_name, client=s3_client)
 
             size_mb = _parse_size_mb(info.get('dataset-size-mb'), dataset_name)
             num_tables = _parse_num_tables(info.get('num-tables', np.nan), dataset_name)
@@ -513,7 +545,9 @@ def _save_document(text, output_filepath, filename, dataset_name):
         LOGGER.info(f'Error saving {filename} for dataset {dataset_name}.')
 
 
-def _get_text_file_content(modality, dataset_name, filename, output_filepath=None):
+def _get_text_file_content(
+    modality, dataset_name, filename, output_filepath=None, bucket=PUBLIC_BUCKET, credentials=None
+):
     """Fetch text file content under the dataset prefix.
 
     Args:
@@ -525,6 +559,15 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=Non
             The filename to fetch (``'README.txt'`` or ``'SOURCE.txt'``).
         output_filepath (str or None):
             If provided, save the file contents at this path.
+        bucket (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': 'example@datacebo.com',
+                'license_key': ''
+            }
 
     Returns:
         str or None:
@@ -533,14 +576,15 @@ def _get_text_file_content(modality, dataset_name, filename, output_filepath=Non
     _validate_text_file_content(modality, output_filepath, filename)
 
     dataset_prefix = f'{modality}/{dataset_name}/'
-    contents = _list_objects(dataset_prefix)
+    s3_client = _create_s3_client(credentials=credentials, bucket=bucket)
+    contents = _list_objects(dataset_prefix, bucket=bucket, client=s3_client)
     key = _find_text_key(contents, dataset_prefix, filename)
     if not key:
         _raise_warnings(filename, output_filepath)
         return None
 
     try:
-        raw = _get_data_from_bucket(key)
+        raw = _get_data_from_bucket(key, bucket=bucket, client=s3_client)
     except Exception:
         LOGGER.info(f'Error fetching {filename} for dataset {dataset_name}.')
         return None
@@ -551,7 +595,9 @@
     return text
 
 
-def get_source(modality, dataset_name, output_filepath=None):
+def get_source(
+    modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
     """Get dataset source/citation text.
 
     Args:
@@ -561,15 +607,33 @@
             The name of the dataset to get the source information for.
         output_filepath (str or None):
             Optional path where to save the file.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': 'example@datacebo.com',
+                'license_key': ''
+            }
 
     Returns:
         str or None:
             The contents of the source file if it exists; otherwise ``None``.
     """
-    return _get_text_file_content(modality, dataset_name, 'SOURCE.txt', output_filepath)
+    return _get_text_file_content(
+        modality=modality,
+        dataset_name=dataset_name,
+        filename='SOURCE.txt',
+        output_filepath=output_filepath,
+        bucket=s3_bucket_name,
+        credentials=credentials,
+    )
 
 
-def get_readme(modality, dataset_name, output_filepath=None):
+def get_readme(
+    modality, dataset_name, output_filepath=None, s3_bucket_name=PUBLIC_BUCKET, credentials=None
+):
     """Get dataset README text.
 
     Args:
@@ -579,9 +643,25 @@
             The name of the dataset to get the README for.
         output_filepath (str or None):
             Optional path where to save the file.
+        s3_bucket_name (str):
+            The name of the bucket to download from. Only 'sdv-datasets-public' is supported in
+            SDV Community. SDV Enterprise is required for other buckets.
+        credentials (dict):
+            Dictionary containing DataCebo license key and username. It takes the form:
+            {
+                'username': 'example@datacebo.com',
+                'license_key': ''
+            }
 
     Returns:
         str or None:
             The contents of the README file if it exists; otherwise ``None``.
     """
-    return _get_text_file_content(modality, dataset_name, 'README.txt', output_filepath)
+    return _get_text_file_content(
+        modality=modality,
+        dataset_name=dataset_name,
+        filename='README.txt',
+        output_filepath=output_filepath,
+        bucket=s3_bucket_name,
+        credentials=credentials,
+    )
diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py
index 8dd95abde..a704d37c9 100644
--- a/tests/unit/datasets/test_demo.py
+++ b/tests/unit/datasets/test_demo.py
@@ -75,7 +75,7 @@ def test_download_demo_single_table(mock_list, mock_get, tmpdir):
         'relationships': [],
     }).encode()
 
-    def side_effect(key):
+    def side_effect(key, bucket='test_bucket', client=None):
         if key.endswith('data.zip'):
             return zip_bytes
         if key.endswith('metadata.json'):
@@ -105,22 +105,19 @@ def side_effect(key):
     assert metadata.to_dict() == expected_metadata_dict
 
 
-@patch('sdv.datasets.demo._create_s3_client')
-@patch('sdv.datasets.demo.BUCKET', 'bucket')
-def test__get_data_from_bucket(create_client_mock):
+def test__get_data_from_bucket():
     """Test the ``_get_data_from_bucket`` method."""
     # Setup
     mock_s3_client = Mock()
-    create_client_mock.return_value = mock_s3_client
    mock_s3_client.get_object.return_value = {'Body': Mock(read=lambda: b'data')}
+    bucket = 'sdv-datasets-public'
 
     # Run
-    result = _get_data_from_bucket('object_key')
+    result = _get_data_from_bucket('object_key', bucket, mock_s3_client)
 
     # Assert
     assert result == b'data'
-    create_client_mock.assert_called_once()
-    mock_s3_client.get_object.assert_called_once_with(Bucket='bucket', Key='object_key')
+    mock_s3_client.get_object.assert_called_once_with(Bucket=bucket, Key='object_key')
 
 
 @patch('sdv.datasets.demo._get_data_from_bucket')
 @patch('sdv.datasets.demo._list_objects')
@@ -135,7 +132,9 @@ def test__download(mock_list, mock_get_data_from_bucket):
     mock_get_data_from_bucket.return_value = json.dumps({'METADATA_SPEC_VERSION': 'V1'}).encode()
 
     # Run
-    data_io, metadata_bytes = _download('single_table', 'ring')
+    data_io, metadata_bytes = _download(
+        'single_table', 'ring', bucket='sdv-datasets-public', credentials=None
+    )
 
     # Assert
     assert isinstance(data_io, io.BytesIO)
@@ -165,7 +164,9 @@ def test_download_demo_single_table_no_output_folder(mock_list, mock_get):
         },
         'relationships': [],
     }).encode()
-    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+    mock_get.side_effect = lambda key, bucket, client: (
+        zip_bytes if key.endswith('data.zip') else meta_bytes
+    )
 
     # Run
     table, metadata = download_demo('single_table', 'ring')
@@ -220,7 +221,9 @@ def test_download_demo_timeseries(mock_list, mock_get, tmpdir):
             }
         },
     }).encode()
-    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+    mock_get.side_effect = lambda key, bucket, client: (
+        zip_bytes if key.endswith('data.zip') else meta_bytes
+    )
 
     # Run
     table, metadata = download_demo('sequential', 'Libras', tmpdir / 'test_folder')
@@ -318,7 +321,9 @@ def test_download_demo_multi_table(mock_list, mock_get, tmpdir):
         ],
         'METADATA_SPEC_VERSION': 'V1',
     }).encode()
-    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+    mock_get.side_effect = lambda key, bucket, client: (
+        zip_bytes if key.endswith('data.zip') else meta_bytes
+    )
 
     # Run
     tables, metadata = download_demo('multi_table', 'got_families', tmpdir / 'test_folder')
@@ -361,7 +366,7 @@ def test__get_first_v1_metadata_bytes(mock_get):
     bad = b'not-json'
     v1 = json.dumps({'METADATA_SPEC_VERSION': 'V1'}).encode()
 
-    def side_effect(key):
+    def side_effect(key, bucket, client):
         return {
             'single_table/dataset/k1.json': v2,
             'single_table/dataset/k2.json': bad,
@@ -376,7 +381,9 @@ def side_effect(key):
     ]
 
     # Run
-    got = _get_first_v1_metadata_bytes(contents, 'single_table/dataset/')
+    got = _get_first_v1_metadata_bytes(
+        contents, 'single_table/dataset/', bucket='test_bucket', client=None
+    )
 
     # Assert
     assert got == v1
@@ -413,7 +420,7 @@ def test_get_available_demos_robust_parsing(mock_list, mock_get):
         {'Key': 'single_table/ignore.txt'},
     ]
 
-    def side_effect(key):
+    def side_effect(key, bucket, client):
         if key.endswith('d1/metainfo.yaml'):
             return b'dataset-name: d1\nnum-tables: 2\ndataset-size-mb: 10.5\nsource: EXTERNAL\n'
         if key.endswith('d2/metainfo.yaml'):
@@ -445,7 +452,7 @@ def test_get_available_demos_logs_invalid_size_mb(mock_list, mock_get, caplog):
         {'Key': 'single_table/dsize/metainfo.yaml'},
     ]
 
-    def side_effect(key):
+    def side_effect(key, bucket, client):
         return b'dataset-name: dsize\nnum-tables: 2\ndataset-size-mb: invalid\n'
 
     mock_get.side_effect = side_effect
@@ -470,7 +477,7 @@ def test_get_available_demos_logs_num_tables_str_cast_fail_exact(mock_list, mock
         {'Key': 'single_table/dnum/metainfo.yaml'},
     ]
 
-    def side_effect(key):
+    def side_effect(key, bucket, client):
         return b'dataset-name: dnum\nnum-tables: not_a_number\ndataset-size-mb: 1.1\n'
 
     mock_get.side_effect = side_effect
@@ -497,7 +504,7 @@ def test_get_available_demos_logs_num_tables_int_parse_fail_exact(mock_list, moc
         {'Key': 'single_table/dnum/metainfo.yaml'},
     ]
 
-    def side_effect(key):
+    def side_effect(key, bucket, client):
         return b'dataset-name: dnum\nnum-tables: [1, 2]\ndataset-size-mb: 1.1\n'
 
     mock_get.side_effect = side_effect
@@ -524,7 +531,7 @@ def test_get_available_demos_ignores_yaml_dataset_name_mismatch(mock_list, mock_
     ]
 
     # YAML uses a different name; should be ignored for dataset_name field
-    def side_effect(key):
+    def side_effect(key, bucket, client):
         return b'dataset-name: DIFFERENT\nnum-tables: 3\ndataset-size-mb: 2.5\n'
 
     mock_get.side_effect = side_effect
@@ -539,6 +546,26 @@ def side_effect(key):
     assert row['size_MB'] == 2.5
 
 
+def test_get_available_demos_private_bucket_raises_error():
+    """Test that an error is raised if a private bucket is given."""
+    # Run and Assert
+    error_message = 'Private buckets are only supported in SDV Enterprise.'
+    with pytest.raises(ValueError, match=error_message):
+        get_available_demos('single_table', 'private-bucket')
+
+
+def test_get_available_demos_credentials_raises_error():
+    """Test that an error is raised if credentials are given."""
+    # Run and Assert
+    error_message = 'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
+    with pytest.raises(ValueError, match=error_message):
+        get_available_demos(
+            'single_table',
+            s3_bucket_name='sdv-datasets-public',
+            credentials={'username': 'test@gmail.com', 'license_key': 'FakeKey123'},
+        )
+
+
 @patch('sdv.datasets.demo._get_data_from_bucket')
 @patch('sdv.datasets.demo._list_objects')
 def test_download_demo_success_single_table(mock_list, mock_get):
@@ -563,7 +590,7 @@ def test_download_demo_success_single_table(mock_list, mock_get):
         'relationships': [],
     }).encode()
 
-    def side_effect(key):
+    def side_effect(key, bucket, client):
         if key.endswith('data.ZIP'):
             return zip_bytes
         if key.endswith('metadata.json'):
@@ -602,7 +629,9 @@ def test_download_demo_no_v1_metadata_raises(mock_list, mock_get):
         {'Key': 'single_table/word/data.zip'},
         {'Key': 'single_table/word/metadata.json'},
     ]
-    mock_get.side_effect = lambda key: json.dumps({'METADATA_SPEC_VERSION': 'V2'}).encode()
+    mock_get.side_effect = lambda key, bucket, client: (
+        json.dumps({'METADATA_SPEC_VERSION': 'V2'}).encode()
+    )
 
     # Run and Assert
     with pytest.raises(DemoResourceNotFoundError, match='METADATA_SPEC_VERSION'):
@@ -674,7 +703,7 @@ def test_download_demo_writes_metadata_and_discovers_nested_csv(mock_list, mock_
     }
     meta_bytes = json.dumps(meta_dict).encode()
 
-    def side_effect(key):
+    def side_effect(key, bucket, client):
         if key.endswith('data.zip'):
             return zip_bytes
         if key.endswith('metadata.json'):
@@ -853,7 +882,14 @@ def test_get_readme_and_get_source_call_wrapper(monkeypatch):
     # Setup
     calls = []
 
-    def fake(modality, dataset_name, filename, output_filepath=None):
+    def fake(
+        modality,
+        dataset_name,
+        filename,
+        output_filepath=None,
+        bucket='test_bucket',
+        credentials=None,
+    ):
         calls.append((modality, dataset_name, filename, output_filepath))
         return 'X'
 
@@ -888,6 +924,28 @@ def test_get_readme_raises_if_output_file_exists(mock_list, mock_get, tmp_path):
         get_readme('single_table', 'dataset1', str(out))
 
 
+def test_get_readme_private_bucket_raises_error():
+    """Test that an error is raised if a private bucket is given."""
+    # Run and Assert
+    error_message = 'Private buckets are only supported in SDV Enterprise.'
+    with pytest.raises(ValueError, match=error_message):
+        get_readme('single_table', 'dataset', None, 'private-bucket')
+
+
+def test_get_readme_credentials_raises_error():
+    """Test that an error is raised if credentials are given."""
+    # Run and Assert
+    error_message = 'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
+    with pytest.raises(ValueError, match=error_message):
+        get_readme(
+            'single_table',
+            'dataset',
+            None,
+            'sdv-datasets-public',
+            {'username': 'test@gmail.com', 'license_key': 'FakeKey123'},
+        )
+
+
 @patch('sdv.datasets.demo._get_data_from_bucket')
 @patch('sdv.datasets.demo._list_objects')
 def test_get_source_raises_if_output_file_exists(mock_list, mock_get, tmp_path):
@@ -907,6 +965,28 @@ def test_get_source_raises_if_output_file_exists(mock_list, mock_get, tmp_path):
         get_source('single_table', 'dataset1', str(out))
 
 
+def test_get_source_private_bucket_raises_error():
+    """Test that an error is raised if a private bucket is given."""
+    # Run and Assert
+    error_message = 'Private buckets are only supported in SDV Enterprise.'
+    with pytest.raises(ValueError, match=error_message):
+        get_source('single_table', 'dataset', None, 'private-bucket')
+
+
+def test_get_source_credentials_raises_error():
+    """Test that an error is raised if credentials are given."""
+    # Run and Assert
+    error_message = 'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
+    with pytest.raises(ValueError, match=error_message):
+        get_source(
+            'single_table',
+            'dataset',
+            None,
+            'sdv-datasets-public',
+            {'username': 'test@gmail.com', 'license_key': 'FakeKey123'},
+        )
+
+
 def test_get_readme_raises_for_non_txt_output():
     """get_readme should raise ValueError if output path is not .txt."""
     err = "The README can only be saved as a txt file. Please provide a filepath ending in '.txt'"
@@ -1011,7 +1091,9 @@ def test_download_demo_raises_when_no_csv_in_zip_single_table(mock_list, mock_ge
     zip_bytes = zip_buf.getvalue()
     meta_bytes = json.dumps({'METADATA_SPEC_VERSION': 'V1'}).encode()
 
-    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+    mock_get.side_effect = lambda key, client, bucket: (
+        zip_bytes if key.endswith('data.zip') else meta_bytes
+    )
 
     # Run and Assert
     msg = 'Demo data could not be downloaded because no csv files were found in data.zip'
@@ -1052,14 +1134,17 @@ def test_download_demo_skips_non_csv_in_memory_no_warning(mock_list, mock_get):
         'relationships': [],
     }).encode()
 
-    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+    mock_get.side_effect = lambda key, bucket, client: (
+        zip_bytes if key.endswith('data.zip') else meta_bytes
+    )
 
     # Run and Assert
     warn_msg = 'Skipped files: empty_dir/, nested/readme.md, note.txt'
     with pytest.warns(UserWarning, match=warn_msg) as rec:
         data, _ = download_demo('single_table', 'mix')
 
-    assert len(rec) == 1
+    assert any(warn_msg in str(warn_record) for warn_record in rec)
+
     expected = pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']})
     pd.testing.assert_frame_equal(data, expected)
@@ -1094,7 +1179,9 @@ def test_download_demo_on_disk_warns_failed_csv_only(mock_list, mock_get, tmp_pa
         'relationships': [],
     }).encode()
 
-    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+    mock_get.side_effect = lambda key, client, bucket: (
+        zip_bytes if key.endswith('data.zip') else meta_bytes
+    )
 
     # Force read_csv to fail on bad.csv only
     orig_read_csv = pd.read_csv
@@ -1113,7 +1200,7 @@ def fake_read_csv(path_or_buf, *args, **kwargs):
     with pytest.warns(UserWarning, match=warn_msg) as rec:
         data, _ = download_demo('single_table', 'mix', out_dir)
 
-    assert len(rec) == 1
+    assert any(warn_msg in str(warn_record) for warn_record in rec)
 
     pd.testing.assert_frame_equal(data, good)
@@ -1146,7 +1233,9 @@ def test_download_demo_handles_non_utf8_in_memory(mock_list, mock_get):
         'relationships': [],
     }).encode()
 
-    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+    mock_get.side_effect = lambda key, bucket, client: (
+        zip_bytes if key.endswith('data.zip') else meta_bytes
+    )
 
     # Run
     data, _ = download_demo('single_table', 'nonutf')
@@ -1185,7 +1274,9 @@ def test_download_demo_handles_non_utf8_on_disk(mock_list, mock_get, tmp_path):
         'relationships': [],
     }).encode()
 
-    mock_get.side_effect = lambda key: zip_bytes if key.endswith('data.zip') else meta_bytes
+    mock_get.side_effect = lambda key, client, bucket: (
+        zip_bytes if key.endswith('data.zip') else meta_bytes
+    )
 
     out_dir = tmp_path / 'latin_out'
@@ -1195,3 +1286,25 @@ def test_download_demo_handles_non_utf8_on_disk(mock_list, mock_get, tmp_path):
     # Assert
     expected = pd.DataFrame({'id': [1], 'name': ['café']})
     pd.testing.assert_frame_equal(data, expected)
+
+
+def test_download_demo_private_bucket_raises_error():
+    """Test that an error is raised if a private bucket is given."""
+    # Run and Assert
+    error_message = 'Private buckets are only supported in SDV Enterprise.'
+    with pytest.raises(ValueError, match=error_message):
+        download_demo('single_table', 'dataset', None, 'private-bucket')
+
+
+def test_download_demo_credentials_raises_error():
+    """Test that an error is raised if credentials are given."""
+    # Run and Assert
+    error_message = 'DataCebo credentials for private buckets are only supported in SDV Enterprise.'
+    with pytest.raises(ValueError, match=error_message):
+        download_demo(
+            'single_table',
+            'dataset',
+            None,
+            'sdv-datasets-public',
+            {'username': 'test@gmail.com', 'license_key': 'FakeKey123'},
+        )
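Usage sketch (not part of the patch). The calls below exercise the new `s3_bucket_name` and `credentials` parameters exactly as this diff defines them; the dataset name 'ring' is only illustrative, so substitute any name returned by `get_available_demos`. The first two calls reach the public bucket over the network.

# Assumes SDV Community with this patch applied; 'ring' is an illustrative dataset name.
from sdv.datasets.demo import download_demo, get_available_demos

# List the demos available for a modality; defaults to the public 'sdv-datasets-public' bucket.
demos = get_available_demos('single_table')
print(demos[['dataset_name', 'size_MB', 'num_tables']].head())

# Download from the default public bucket (the only value accepted in SDV Community).
data, metadata = download_demo(
    modality='single_table',
    dataset_name='ring',
    s3_bucket_name='sdv-datasets-public',
)

# Private buckets and DataCebo credentials are rejected before any network call.
try:
    download_demo('single_table', 'ring', s3_bucket_name='my-private-bucket')
except ValueError as error:
    print(error)  # Private buckets are only supported in SDV Enterprise.

The validation happens inside `_create_s3_client`, so the same ValueError surfaces from `download_demo`, `get_available_demos`, `get_readme`, and `get_source` alike.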