Skip to content

Commit c72a368

Browse files
authored
Incorporate the get_available_datasets functionality into the DatasetExplorer (#492)
1 parent 38acb1b commit c72a368

File tree

9 files changed

+111
-86
lines changed

9 files changed

+111
-86
lines changed

DATASETS.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ Out[6]:
6666
## Getting the list of all the datasets
6767

6868
If you want to obtain the list of all the available datasets you can use the
69-
`sdgym.get_available_datasets` function:
69+
`DatasetExplorer.list_datasets` method:
7070

7171
```python
72-
In [7]: from sdgym import get_available_datasets
72+
In [7]: from sdgym import DatasetExplorer
7373

74-
In [8]: get_available_datasets()
74+
In [8]: DatasetExplorer().list_datasets('single_table')
7575
Out[8]:
7676
dataset_name size_MB num_tables
7777
0 KRK_v1 0.072128 1

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ Learn more in the [Custom Synthesizers Guide](https://docs.sdv.dev/sdgym/customi
103103
## Customizing your datasets
104104

105105
The SDGym library includes many publicly available datasets that you can include right away.
106-
List these using the ``get_available_datasets`` feature.
106+
List these using the ``DatasetExplorer.list_datasets`` method.
107107

108108
```python
109-
sdgym.get_available_datasets()
109+
sdgym.dataset_explorer.DatasetExplorer().list_datasets('single_table')
110110
```
111111

112112
```

sdgym/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sdgym.cli.collect import collect_results
1717
from sdgym.cli.summary import make_summary_spreadsheet
1818
from sdgym.dataset_explorer import DatasetExplorer
19-
from sdgym.datasets import get_available_datasets, load_dataset
19+
from sdgym.datasets import load_dataset
2020
from sdgym.synthesizers import (
2121
create_synthesizer_variant,
2222
create_single_table_synthesizer,
@@ -37,7 +37,6 @@
3737
'create_synthesizer_variant',
3838
'create_single_table_synthesizer',
3939
'create_multi_table_synthesizer',
40-
'get_available_datasets',
4140
'load_dataset',
4241
'make_summary_spreadsheet',
4342
]

sdgym/cli/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _download_datasets(args):
9797
_env_setup(args.logfile, args.verbose)
9898
datasets = args.datasets
9999
if not datasets:
100-
datasets = sdgym.datasets.get_available_datasets(
100+
datasets = sdgym.datasets._get_available_datasets(
101101
args.bucket, args.aws_access_key_id, args.aws_secret_access_key
102102
)['name']
103103

@@ -118,7 +118,7 @@ def _list_downloaded(args):
118118

119119

120120
def _list_available(args):
121-
datasets = sdgym.datasets.get_available_datasets(
121+
datasets = sdgym.datasets._get_available_datasets(
122122
args.bucket, args.aws_access_key_id, args.aws_secret_access_key
123123
)
124124
_print_table(datasets, args.sort, args.reverse, {'size': humanfriendly.format_size})

sdgym/dataset_explorer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,36 @@ def summarize_datasets(self, modality, output_filepath=None):
275275
dataset_summary.to_csv(output_filepath, index=False)
276276

277277
return dataset_summary
278+
279+
def list_datasets(self, modality, output_filepath=None):
280+
"""List available datasets for a modality using metainfo only.
281+
282+
This is a lightweight alternative to ``summarize_datasets`` that does not load
283+
the actual data. It reads dataset information from the ``metainfo.yaml`` files
284+
in the bucket and returns a table equivalent to the legacy
285+
``get_available_datasets`` output.
286+
287+
Args:
288+
modality (str):
289+
It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
290+
output_filepath (str, optional):
291+
Full path to a ``.csv`` file where the resulting table will be written.
292+
If not provided, the table is only returned.
293+
294+
Returns:
295+
pd.DataFrame:
296+
A DataFrame with columns: ``['dataset_name', 'size_MB', 'num_tables']``.
297+
"""
298+
self._validate_output_filepath(output_filepath)
299+
_validate_modality(modality)
300+
301+
dataframe = _get_available_datasets(
302+
modality=modality,
303+
bucket=self._bucket_name,
304+
aws_access_key_id=self.aws_access_key_id,
305+
aws_secret_access_key=self.aws_secret_access_key,
306+
)
307+
if output_filepath:
308+
dataframe.to_csv(output_filepath, index=False)
309+
310+
return dataframe

sdgym/datasets.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -254,21 +254,6 @@ def load_dataset(
254254
return data, metadata_dict
255255

256256

257-
def get_available_datasets(modality='single_table'):
258-
"""Get available single_table datasets.
259-
260-
Args:
261-
modality (str):
262-
It must be ``'single_table'``, ``'multi_table'`` or ``'sequential'``.
263-
264-
Return:
265-
pd.DataFrame:
266-
Table of available datasets and their sizes.
267-
"""
268-
_validate_modality(modality)
269-
return _get_available_datasets(modality)
270-
271-
272257
def get_dataset_paths(
273258
modality,
274259
datasets=None,

tests/integration/test_datasets.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,31 @@
1-
from sdgym import get_available_datasets
1+
from sdgym import DatasetExplorer
22

33

4-
def test_get_available_datasets_single_table():
5-
"""Test that `get_available_datasets` returns single table datasets with expected properties."""
4+
def test_list_datasets_single_table():
5+
"""Test that it lists single table datasets with expected properties."""
66
# Run
7-
df = get_available_datasets('single_table')
7+
dataframe = DatasetExplorer().list_datasets('single_table')
88

99
# Assert
10-
assert df.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
11-
assert all(df['num_tables'] == 1)
10+
assert dataframe.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
11+
assert all(dataframe['num_tables'] == 1)
1212

1313

14-
def test_get_available_datasets_multi_table():
15-
"""Test that `get_available_datasets` returns multi table datasets with expected properties."""
14+
def test_list_datasets_multi_table():
15+
"""Test that it lists multi table datasets with expected properties."""
1616
# Run
17-
df = get_available_datasets('multi_table')
17+
dataframe = DatasetExplorer().list_datasets('multi_table')
1818

1919
# Assert
20-
assert df.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
21-
assert all(df['num_tables'] > 1)
20+
assert dataframe.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
21+
assert all(dataframe['num_tables'] > 1)
2222

2323

24-
def test_get_available_datasets_sequential():
25-
"""Test that `get_available_datasets` returns sequential datasets with expected properties."""
24+
def test_list_datasets_sequential():
25+
"""Test that it lists sequential datasets with expected properties."""
2626
# Run
27-
df = get_available_datasets('sequential')
27+
dataframe = DatasetExplorer().list_datasets('sequential')
2828

2929
# Assert
30-
assert df.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
31-
assert all(df['num_tables'] == 1)
30+
assert dataframe.columns.tolist() == ['dataset_name', 'size_MB', 'num_tables']
31+
assert all(dataframe['num_tables'] == 1)

tests/unit/test_dataset_explorer.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,57 @@ def test_summarize_datasets_with_output(
277277
assert output_filepath.exists()
278278
assert isinstance(df, pd.DataFrame)
279279
assert df.columns.to_list() == SUMMARY_OUTPUT_COLUMNS
280+
281+
@patch('sdgym.dataset_explorer._validate_modality')
282+
@patch('sdgym.dataset_explorer._get_available_datasets')
283+
def test_list_datasets_without_output(self, mock_get_available, mock_validate_modality):
284+
"""Test that `list_datasets` returns the expected dataframe."""
285+
# Setup
286+
explorer = DatasetExplorer()
287+
expected_df = pd.DataFrame([
288+
{'dataset_name': 'ds1', 'size_MB': 12.5, 'num_tables': 1},
289+
{'dataset_name': 'ds2', 'size_MB': 3.0, 'num_tables': 2},
290+
])
291+
mock_get_available.return_value = expected_df
292+
293+
# Run
294+
result = explorer.list_datasets('single_table')
295+
296+
# Assert
297+
mock_validate_modality.assert_called_once_with('single_table')
298+
mock_get_available.assert_called_once_with(
299+
modality='single_table',
300+
bucket='sdv-datasets-public',
301+
aws_access_key_id=None,
302+
aws_secret_access_key=None,
303+
)
304+
pd.testing.assert_frame_equal(result, expected_df)
305+
306+
@patch('sdgym.dataset_explorer._validate_modality')
307+
@patch('sdgym.dataset_explorer._get_available_datasets')
308+
def test_list_datasets_with_output(self, mock_get_available, mock_validate_modality, tmp_path):
309+
"""Test that `list_datasets` writes CSV when output path is provided."""
310+
# Setup
311+
explorer = DatasetExplorer()
312+
expected_df = pd.DataFrame([
313+
{'dataset_name': 'alpha', 'size_MB': 1.5, 'num_tables': 1},
314+
{'dataset_name': 'beta', 'size_MB': 2.0, 'num_tables': 3},
315+
])
316+
mock_get_available.return_value = expected_df
317+
output_filepath = tmp_path / 'datasets_list.csv'
318+
319+
# Run
320+
result = explorer.list_datasets('multi_table', output_filepath=str(output_filepath))
321+
322+
# Assert
323+
mock_validate_modality.assert_called_once_with('multi_table')
324+
mock_get_available.assert_called_once_with(
325+
modality='multi_table',
326+
bucket='sdv-datasets-public',
327+
aws_access_key_id=None,
328+
aws_secret_access_key=None,
329+
)
330+
assert output_filepath.exists()
331+
loaded = pd.read_csv(output_filepath)
332+
pd.testing.assert_frame_equal(loaded, expected_df)
333+
pd.testing.assert_frame_equal(result, expected_df)

tests/unit/test_datasets.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
from unittest.mock import Mock, call, patch
33

44
import numpy as np
5-
import pandas as pd
65
import pytest
76

8-
from sdgym import get_available_datasets
97
from sdgym.datasets import (
108
DATASETS_PATH,
119
_download_dataset,
@@ -361,50 +359,6 @@ def test_get_bucket_name_local_folder():
361359
assert bucket_name == 'bucket-name'
362360

363361

364-
@patch('sdgym.datasets._get_available_datasets')
365-
def test_get_available_datasets(helper_mock):
366-
"""Test that the modality is set to single-table."""
367-
# Run
368-
get_available_datasets()
369-
370-
# Assert
371-
helper_mock.assert_called_once_with('single_table')
372-
373-
374-
def test_get_available_datasets_results():
375-
# Run
376-
tables_info = get_available_datasets()
377-
378-
# Assert
379-
expected_table = pd.DataFrame({
380-
'dataset_name': [
381-
'adult',
382-
'alarm',
383-
'census',
384-
'child',
385-
'covtype',
386-
'expedia_hotel_logs',
387-
'insurance',
388-
'intrusion',
389-
'news',
390-
],
391-
'size_MB': [
392-
'3.907448',
393-
'4.520128',
394-
'98.165608',
395-
'3.200128',
396-
'255.645408',
397-
'0.200128',
398-
'3.340128',
399-
'162.039016',
400-
'18.712096',
401-
],
402-
'num_tables': [1] * 9,
403-
})
404-
expected_table['size_MB'] = expected_table['size_MB'].astype(float).round(2)
405-
assert len(expected_table.merge(tables_info.round(2))) == len(expected_table)
406-
407-
408362
@patch('sdgym.datasets._get_dataset_path_and_download')
409363
@patch('sdgym.datasets._path_contains_data_and_metadata', return_value=True)
410364
@patch('sdgym.datasets.Path')

0 commit comments

Comments
 (0)