Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion sdgym/dataset_explorer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Dataset Explorer to summarize datasets stored in S3 buckets."""

import warnings
from collections import defaultdict
from pathlib import Path

Expand All @@ -9,6 +10,24 @@
from sdgym.datasets import BUCKET, _get_available_datasets, _validate_modality, load_dataset
from sdgym.s3 import _validate_s3_url

# Canonical column order for the DataFrame returned by
# ``DatasetExplorer.summarize_datasets`` (and the CSV it optionally writes).
# Kept at module level so tests can import it instead of duplicating the list.
# NOTE(review): 'Datasize_Size_MB' looks like a typo for 'Dataset_Size_MB',
# but renaming it would change the public CSV schema — confirm with consumers
# before changing.
SUMMARY_OUTPUT_COLUMNS = [
    'Dataset',
    'Datasize_Size_MB',
    'Num_Tables',
    'Total_Num_Columns',
    'Total_Num_Columns_Categorical',
    'Total_Num_Columns_Numerical',
    'Total_Num_Columns_Datetime',
    'Total_Num_Columns_PII',
    'Total_Num_Columns_ID_NonKey',
    'Max_Num_Columns_Per_Table',
    'Total_Num_Rows',
    'Max_Num_Rows_Per_Table',
    'Num_Relationships',
    'Max_Schema_Depth',
    'Max_Schema_Branch',
]


class DatasetExplorer:
"""``DatasetExplorer`` class.
Expand Down Expand Up @@ -277,7 +296,17 @@ def summarize_datasets(self, modality, output_filepath=None):
self._validate_output_filepath(output_filepath)
_validate_modality(modality)
results = self._load_and_summarize_datasets(modality)
dataset_summary = pd.DataFrame(results)

if not results:
warning_msg = (
f"The provided S3 URL '{self.s3_url}' does not contain any datasets "
f"of modality '{modality}'."
)
warnings.warn(warning_msg, UserWarning)
dataset_summary = pd.DataFrame(columns=SUMMARY_OUTPUT_COLUMNS)
else:
dataset_summary = pd.DataFrame(results)

if output_filepath:
dataset_summary.to_csv(output_filepath, index=False)

Expand Down
48 changes: 30 additions & 18 deletions tests/integration/test_dataset_explorer.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,11 @@
import re
from pathlib import Path
from unittest.mock import patch

import pandas as pd
import pytest

from sdgym import DatasetExplorer

SUMMARY_OUTPUT_COLUMNS = [
'Dataset',
'Datasize_Size_MB',
'Num_Tables',
'Total_Num_Columns',
'Total_Num_Columns_Categorical',
'Total_Num_Columns_Numerical',
'Total_Num_Columns_Datetime',
'Total_Num_Columns_PII',
'Total_Num_Columns_ID_NonKey',
'Max_Num_Columns_Per_Table',
'Total_Num_Rows',
'Max_Num_Rows_Per_Table',
'Num_Relationships',
'Max_Schema_Depth',
'Max_Schema_Branch',
]
from sdgym.dataset_explorer import SUMMARY_OUTPUT_COLUMNS


@pytest.mark.parametrize('modality', ['single_table', 'multi_table'])
Expand All @@ -42,3 +26,31 @@ def test_end_to_end_dataset_explorer(modality, tmp_path):
assert list(dataset_summary.columns) == SUMMARY_OUTPUT_COLUMNS
loaded_summary = pd.read_csv(output_filepath)
pd.testing.assert_frame_equal(loaded_summary, dataset_summary)


@pytest.mark.parametrize('modality', ['single_table', 'multi_table'])
def test_dataset_explorer_empty_bucket_warns_and_returns_header_only(modality, tmp_path):
    """When no datasets are present, warn and return a header-only table and write the CSV.

    Covers both modalities via parametrization. ``_get_available_datasets`` is
    patched to return an empty frame so ``summarize_datasets`` takes its empty
    branch: it must emit a ``UserWarning``, return an empty DataFrame with the
    canonical summary columns, and still write that header-only CSV to disk.
    """
    # Setup
    de = DatasetExplorer(s3_url='s3://my_bucket/')
    output_filepath = Path(tmp_path) / f'datasets_summary_empty_{modality}.csv'

    with patch('sdgym.dataset_explorer._get_available_datasets') as mock_get:
        mock_get.return_value = pd.DataFrame([])

        # ``pytest.warns(match=...)`` treats the string as a regex; escape it so
        # the '.' characters in the message match literally instead of any char.
        expected_message = re.escape(
            "The provided S3 URL 's3://my_bucket/' does not contain any datasets "
            f"of modality '{modality}'."
        )

        # Run
        with pytest.warns(UserWarning, match=expected_message):
            frame = de.summarize_datasets(modality=modality, output_filepath=str(output_filepath))

    # Assert
    assert isinstance(frame, pd.DataFrame)
    assert frame.empty
    assert list(frame.columns) == SUMMARY_OUTPUT_COLUMNS
    assert output_filepath.exists()
    loaded_summary = pd.read_csv(output_filepath)
    pd.testing.assert_frame_equal(loaded_summary, frame)
20 changes: 1 addition & 19 deletions tests/unit/test_dataset_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,9 @@
import pandas as pd
import pytest

from sdgym.dataset_explorer import DatasetExplorer
from sdgym.dataset_explorer import SUMMARY_OUTPUT_COLUMNS, DatasetExplorer
from sdgym.datasets import BUCKET

SUMMARY_OUTPUT_COLUMNS = [
'Dataset',
'Datasize_Size_MB',
'Num_Tables',
'Total_Num_Columns',
'Total_Num_Columns_Categorical',
'Total_Num_Columns_Numerical',
'Total_Num_Columns_Datetime',
'Total_Num_Columns_PII',
'Total_Num_Columns_ID_NonKey',
'Max_Num_Columns_Per_Table',
'Total_Num_Rows',
'Max_Num_Rows_Per_Table',
'Num_Relationships',
'Max_Schema_Depth',
'Max_Schema_Branch',
]


class TestDatasetExplorer:
def test___init__default(self):
Expand Down