Skip to content

Commit bbaee25

Browse files
committed
Add tests
1 parent 48e9dd0 commit bbaee25

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

tests/unit/test_dataset_explorer.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,57 @@ def test_summarize_datasets_with_output(
277277
assert output_filepath.exists()
278278
assert isinstance(df, pd.DataFrame)
279279
assert df.columns.to_list() == SUMMARY_OUTPUT_COLUMNS
280+
281+
@patch('sdgym.dataset_explorer._validate_modality')
282+
@patch('sdgym.dataset_explorer._get_available_datasets')
283+
def test_list_datasets_without_output(self, mock_get_available, mock_validate_modality):
284+
"""Test that `list_datasets` returns the expected dataframe."""
285+
# Setup
286+
explorer = DatasetExplorer()
287+
expected_df = pd.DataFrame([
288+
{'dataset_name': 'ds1', 'size_MB': 12.5, 'num_tables': 1},
289+
{'dataset_name': 'ds2', 'size_MB': 3.0, 'num_tables': 2},
290+
])
291+
mock_get_available.return_value = expected_df
292+
293+
# Run
294+
result = explorer.list_datasets('single_table')
295+
296+
# Assert
297+
mock_validate_modality.assert_called_once_with('single_table')
298+
mock_get_available.assert_called_once_with(
299+
modality='single_table',
300+
bucket='sdv-datasets-public',
301+
aws_access_key_id=None,
302+
aws_secret_access_key=None,
303+
)
304+
pd.testing.assert_frame_equal(result, expected_df)
305+
306+
@patch('sdgym.dataset_explorer._validate_modality')
307+
@patch('sdgym.dataset_explorer._get_available_datasets')
308+
def test_list_datasets_with_output(self, mock_get_available, mock_validate_modality, tmp_path):
309+
"""Test that `list_datasets` writes CSV when output path is provided."""
310+
# Setup
311+
explorer = DatasetExplorer()
312+
expected_df = pd.DataFrame([
313+
{'dataset_name': 'alpha', 'size_MB': 1.5, 'num_tables': 1},
314+
{'dataset_name': 'beta', 'size_MB': 2.0, 'num_tables': 3},
315+
])
316+
mock_get_available.return_value = expected_df
317+
output_filepath = tmp_path / 'datasets_list.csv'
318+
319+
# Run
320+
result = explorer.list_datasets('multi_table', output_filepath=str(output_filepath))
321+
322+
# Assert
323+
mock_validate_modality.assert_called_once_with('multi_table')
324+
mock_get_available.assert_called_once_with(
325+
modality='multi_table',
326+
bucket='sdv-datasets-public',
327+
aws_access_key_id=None,
328+
aws_secret_access_key=None,
329+
)
330+
assert output_filepath.exists()
331+
loaded = pd.read_csv(output_filepath)
332+
pd.testing.assert_frame_equal(loaded, expected_df)
333+
pd.testing.assert_frame_equal(result, expected_df)

0 commit comments

Comments
 (0)