|
1 | 1 | import os |
2 | 2 | import re |
| 3 | +import shutil |
3 | 4 | from unittest.mock import Mock, patch |
4 | 5 |
|
5 | 6 | import pandas as pd |
@@ -76,6 +77,22 @@ def test__init__s3(self, mock_is_s3_path, mock_get_s3_client): |
76 | 77 | assert result_explorer.aws_secret_access_key == aws_secret_access_key |
77 | 78 | assert isinstance(result_explorer._handler, S3ResultsHandler) |
78 | 79 |
|
| 80 | + def test_list_with_modality_local(self, tmp_path): |
| 81 | + """Test the `list` method respects the modality subfolder (local).""" |
| 82 | + # Setup |
| 83 | + base = tmp_path / 'results' |
| 84 | + (base / 'unscoped_run').mkdir(parents=True) |
| 85 | + (base / 'multi_table' / 'run_mt1').mkdir(parents=True) |
| 86 | + (base / 'multi_table' / 'run_mt2').mkdir(parents=True) |
| 87 | + |
| 88 | + result_explorer = ResultsExplorer(str(base), modality='multi_table') |
| 89 | + |
| 90 | + # Run |
| 91 | + runs = result_explorer.list() |
| 92 | + |
| 93 | + # Assert |
| 94 | + assert set(runs) == {'run_mt1', 'run_mt2'} |
| 95 | + |
79 | 96 | def test_list_local(self, tmp_path): |
80 | 97 | """Test the `list` method with a local path""" |
81 | 98 | # Setup |
@@ -136,6 +153,28 @@ def test__get_file_path(self): |
136 | 153 | ) |
137 | 154 | assert file_path == expected_filepath |
138 | 155 |
|
| 156 | + def test__get_file_path_multi_table_synthetic_data(self, tmp_path): |
| 157 | + """Test `_get_file_path` returns .zip for multi_table synthetic data.""" |
| 158 | + base = tmp_path / 'results' |
| 159 | + multi_table_dir = base / 'multi_table' |
| 160 | + multi_table_dir.mkdir(parents=True, exist_ok=True) |
| 161 | + explorer = ResultsExplorer(str(multi_table_dir), modality='multi_table') |
| 162 | + try: |
| 163 | + explorer._handler = Mock() |
| 164 | + explorer._handler.get_file_path.return_value = 'irrelevant' |
| 165 | + explorer._get_file_path( |
| 166 | + results_folder_name='results_folder_07_07_2025', |
| 167 | + dataset_name='my_dataset', |
| 168 | + synthesizer_name='my_synthesizer', |
| 169 | + file_type='synthetic_data', |
| 170 | + ) |
| 171 | + explorer._handler.get_file_path.assert_called_once_with( |
| 172 | + ['results_folder_07_07_2025', 'my_dataset_07_07_2025', 'my_synthesizer'], |
| 173 | + 'my_synthesizer_synthetic_data.zip', |
| 174 | + ) |
| 175 | + finally: |
| 176 | + shutil.rmtree(multi_table_dir) |
| 177 | + |
139 | 178 | def test_load_synthesizer(self, tmp_path): |
140 | 179 | """Test `load_synthesizer` method.""" |
141 | 180 | # Setup |
@@ -209,6 +248,31 @@ def test_load_real_data(self, mock_load_dataset, tmp_path): |
209 | 248 | ) |
210 | 249 | pd.testing.assert_frame_equal(real_data, expected_data) |
211 | 250 |
|
| 251 | + @patch('sdgym.result_explorer.result_explorer.load_dataset') |
| 252 | + def test_load_real_data_multi_table(self, mock_load_dataset, tmp_path): |
| 253 | + """Test `load_real_data` for multi_table modality calls load_dataset correctly.""" |
| 254 | + dataset_name = 'synthea' |
| 255 | + expected_data = {'patients': pd.DataFrame({'id': [1]})} |
| 256 | + mock_load_dataset.return_value = (expected_data, None) |
| 257 | + multi_table_dir = tmp_path / 'multi_table' |
| 258 | + multi_table_dir.mkdir(parents=True, exist_ok=True) |
| 259 | + result_explorer = ResultsExplorer(tmp_path, modality='multi_table') |
| 260 | + |
| 261 | + try: |
| 262 | + # Run |
| 263 | + real_data = result_explorer.load_real_data(dataset_name) |
| 264 | + |
| 265 | + # Assert |
| 266 | + mock_load_dataset.assert_called_once_with( |
| 267 | + modality='multi_table', |
| 268 | + dataset='synthea', |
| 269 | + aws_access_key_id=None, |
| 270 | + aws_secret_access_key=None, |
| 271 | + ) |
| 272 | + assert real_data == expected_data |
| 273 | + finally: |
| 274 | + shutil.rmtree(multi_table_dir) |
| 275 | + |
212 | 276 | def test_load_real_data_invalid_dataset(self, tmp_path): |
213 | 277 | """Test `load_real_data` method with an invalid dataset.""" |
214 | 278 | # Setup |
|
0 commit comments