Skip to content

Commit ced519e

Browse files
committed
Add changes
Update tests; fix init
1 parent 3aa5202 commit ced519e

File tree

5 files changed

+261
-14
lines changed

5 files changed

+261
-14
lines changed

sdgym/result_explorer/result_explorer.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,51 @@ def _validate_local_path(path):
1414
raise ValueError(f"The provided path '{path}' is not a valid local directory.")
1515

1616

17+
_FOLDER_BY_MODALITY = {
18+
'single_table': 'single-table',
19+
'multi_table': 'multi_table',
20+
}
21+
22+
23+
def _resolve_effective_path(path, modality):
24+
"""Append the modality folder to the given base path if provided."""
25+
if not modality:
26+
return path
27+
28+
folder = _FOLDER_BY_MODALITY.get(modality)
29+
if folder is None:
30+
valid = ', '.join(sorted(_FOLDER_BY_MODALITY))
31+
raise ValueError(f'Invalid modality "{modality}". Valid options are: {valid}.')
32+
33+
# Avoid double-appending if already included
34+
if str(path).rstrip('/').endswith(('/' + folder, folder)):
35+
return path
36+
37+
if is_s3_path(path):
38+
path = path.rstrip('/') + '/' + folder
39+
return path
40+
41+
return os.path.join(path, folder)
42+
43+
1744
class ResultsExplorer:
1845
"""Explorer for SDGym benchmark results, supporting both local and S3 storage."""
1946

20-
def __init__(self, path, modality=None, aws_access_key_id=None, aws_secret_access_key=None):
    """Initialize the explorer for a local or S3 results location.

    Args:
        path (str): Base results path; either a local directory or an
            ``s3://`` URI.
        modality (str or None): Optional modality ('single_table' or
            'multi_table'). When given, all handler operations are scoped to
            the corresponding modality subfolder. Defaults to ``None`` so
            existing callers that pass only ``path`` (and keyword
            credentials) keep working; ``load_real_data`` already treats a
            ``None`` modality as single-table.
        aws_access_key_id (str or None): AWS credential; S3 paths only.
        aws_secret_access_key (str or None): AWS credential; S3 paths only.
    """
    self.path = path
    self.modality = modality
    self.aws_access_key_id = aws_access_key_id
    self.aws_secret_access_key = aws_secret_access_key

    # Scope the handler to the modality-specific subfolder (no-op when
    # modality is None).
    effective_path = _resolve_effective_path(path, modality)
    if is_s3_path(path):
        # Use the original path to obtain the client (keeps backwards
        # compatibility), but the handler operates on the modality-specific
        # effective path.
        s3_client = _get_s3_client(path, aws_access_key_id, aws_secret_access_key)
        self._handler = S3ResultsHandler(effective_path, s3_client)
    else:
        _validate_local_path(effective_path)
        self._handler = LocalResultsHandler(effective_path)
3162

3263
def list(self):
3364
"""List all runs available in the results directory."""
@@ -37,7 +68,11 @@ def _get_file_path(self, results_folder_name, dataset_name, synthesizer_name, fi
3768
"""Validate access to the synthesizer or synthetic data file."""
3869
end_filename = f'{synthesizer_name}'
3970
if file_type == 'synthetic_data':
40-
end_filename += '_synthetic_data.csv'
71+
# Multi-table synthetic data is zipped (multiple CSVs), single table is CSV
72+
if self.modality == 'multi_table':
73+
end_filename += '_synthetic_data.zip'
74+
else:
75+
end_filename += '_synthetic_data.csv'
4176
elif file_type == 'synthesizer':
4277
end_filename += '.pkl'
4378

@@ -62,14 +97,17 @@ def load_synthetic_data(self, results_folder_name, dataset_name, synthesizer_nam
6297

6398
def load_real_data(self, dataset_name):
6499
"""Load the real data for a given dataset."""
65-
if dataset_name not in DEFAULT_SINGLE_TABLE_DATASETS:
100+
# Keep strict validation for single_table to preserve existing behavior
101+
if (self.modality is None or self.modality == 'single_table') and (
102+
dataset_name not in DEFAULT_SINGLE_TABLE_DATASETS
103+
):
66104
raise ValueError(
67105
f"Dataset '{dataset_name}' is not a SDGym dataset. "
68106
'Please provide a valid dataset name.'
69107
)
70108

71109
data, _ = load_dataset(
72-
modality='single_table',
110+
modality=self.modality or 'single_table',
73111
dataset=dataset_name,
74112
aws_access_key_id=self.aws_access_key_id,
75113
aws_secret_access_key=self.aws_secret_access_key,

sdgym/result_explorer/result_handler.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
import os
66
from abc import ABC, abstractmethod
77
from datetime import datetime
8+
from io import BytesIO
9+
from zipfile import ZipFile
810

911
import cloudpickle
1012
import pandas as pd
1113
import yaml
1214
from botocore.exceptions import ClientError
1315

16+
from sdgym._dataset_utils import _read_zipped_data
17+
1418
SYNTHESIZER_BASELINE = 'GaussianCopulaSynthesizer'
1519
RESULTS_FOLDER_PREFIX = 'SDGym_results_'
1620
metainfo_PREFIX = 'metainfo'
@@ -262,8 +266,12 @@ def load_synthesizer(self, file_path):
262266
return cloudpickle.load(f)
263267

264268
def load_synthetic_data(self, file_path):
    """Load synthetic data from disk.

    A ``.zip`` file is treated as zipped multi-table data and loaded via
    ``_read_zipped_data``; any other file is read as a single CSV table.

    Args:
        file_path (str): Path relative to ``self.base_path``.

    Returns:
        The loaded synthetic data (DataFrame for CSV; multi-table structure
        for ZIP).
    """
    target = os.path.join(self.base_path, file_path)
    if not target.endswith('.zip'):
        return pd.read_csv(target)

    return _read_zipped_data(target, modality='multi_table')
267275

268276
def _get_results_files(self, folder_name, prefix, suffix):
269277
return [
@@ -374,10 +382,21 @@ def load_synthesizer(self, file_path):
374382

375383
def load_synthetic_data(self, file_path):
    """Load synthetic data from S3.

    ``.zip`` keys are unpacked into a dict mapping table name (the CSV's
    base filename without extension) to a DataFrame; any other key is read
    as a single CSV table.

    Args:
        file_path (str): Key suffix appended to ``self.prefix``.

    Returns:
        pandas.DataFrame or dict[str, pandas.DataFrame]: The loaded data.
    """
    payload = self.s3_client.get_object(
        Bucket=self.bucket_name,
        Key=f'{self.prefix}{file_path}',
    )['Body'].read()

    if not file_path.endswith('.zip'):
        return pd.read_csv(io.BytesIO(payload))

    tables = {}
    with ZipFile(BytesIO(payload)) as archive:
        # Non-CSV members (if any) are deliberately ignored.
        for member in archive.namelist():
            if member.endswith('.csv'):
                table_name = os.path.splitext(os.path.basename(member))[0]
                with archive.open(member) as handle:
                    tables[table_name] = pd.read_csv(handle, low_memory=False)

    return tables
381400

382401
def _get_results_files(self, folder_name, prefix, suffix):
383402
s3_prefix = f'{self.prefix}{folder_name}/'

tests/integration/result_explorer/test_result_explorer.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,71 @@ def test_summarize():
8888
expected_results['Win'] = expected_results['Win'].astype('int64')
8989
pd.testing.assert_frame_equal(summary, expected_summary)
9090
pd.testing.assert_frame_equal(results, expected_results)
91+
92+
93+
def test_summarize_multi_table(tmp_path):
    """Test summarize works under the multi_table subfolder."""
    # Setup: copy existing fixtures into multi_table folder
    import shutil

    src_root = 'tests/integration/result_explorer/_benchmark_results'
    dst_root = tmp_path / 'benchmark_output' / 'multi_table'
    dst_root.mkdir(parents=True, exist_ok=True)
    run_folders = (
        'SDGym_results_04_05_2024',
        'SDGym_results_05_10_2024',
        'SDGym_results_10_11_2024',
    )
    for run_folder in run_folders:
        shutil.copytree(f'{src_root}/{run_folder}', dst_root / run_folder)

    explorer = ResultsExplorer(str(tmp_path / 'benchmark_output'), modality='multi_table')

    # Run
    summary, results = explorer.summarize('SDGym_results_10_11_2024')

    # Assert
    expected_summary = pd.DataFrame({
        'Synthesizer': ['CTGANSynthesizer', 'CopulaGANSynthesizer', 'TVAESynthesizer'],
        '10_11_2024 - # datasets: 9 - sdgym version: 0.9.1': [6, 4, 5],
        '05_10_2024 - # datasets: 9 - sdgym version: 0.8.0': [4, 4, 5],
        '04_05_2024 - # datasets: 9 - sdgym version: 0.7.0': [5, 3, 5],
    })
    expected_results = (
        pd.read_csv(f'{src_root}/SDGym_results_10_11_2024/results.csv')
        .sort_values(by=['Dataset', 'Synthesizer'])
        .reset_index(drop=True)
    )
    expected_results['Win'] = expected_results['Win'].astype('int64')
    pd.testing.assert_frame_equal(summary, expected_summary)
    pd.testing.assert_frame_equal(results, expected_results)
128+
129+
130+
def test_list_and_load_results_multi_table(tmp_path):
    """Test listing and loading results under multi_table subfolder."""
    # Setup
    import shutil

    src_root = 'tests/integration/result_explorer/_benchmark_results/SDGym_results_10_11_2024'
    dst_root = tmp_path / 'benchmark_output' / 'multi_table' / 'SDGym_results_10_11_2024'
    shutil.copytree(src_root, dst_root)

    explorer = ResultsExplorer(str(tmp_path / 'benchmark_output'), modality='multi_table')

    # Run
    runs = explorer.list()
    assert runs == ['SDGym_results_10_11_2024']

    run_name = runs[0]
    sort_columns = ['Dataset', 'Synthesizer']
    loaded_results = (
        explorer.load_results(run_name).sort_values(by=sort_columns).reset_index(drop=True)
    )
    metainfo = explorer.load_metainfo(run_name)

    # Assert
    expected_results = (
        pd.read_csv(dst_root / 'results.csv').sort_values(by=sort_columns).reset_index(drop=True)
    )
    pd.testing.assert_frame_equal(loaded_results, expected_results)
    assert isinstance(metainfo, dict) and len(metainfo) >= 1

tests/unit/result_explorer/test_result_explorer.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
import shutil
23
from unittest.mock import Mock, patch
34

45
import pandas as pd
@@ -59,7 +60,11 @@ def test__init__s3(self, mock_is_s3_path, mock_get_s3_client):
5960
mock_get_s3_client.return_value = s3_client
6061

6162
# Run
62-
result_explorer = ResultsExplorer(path, aws_access_key_id, aws_secret_access_key)
63+
result_explorer = ResultsExplorer(
64+
path,
65+
aws_access_key_id=aws_access_key_id,
66+
aws_secret_access_key=aws_secret_access_key,
67+
)
6368

6469
# Assert
6570
mock_is_s3_path.assert_called_once_with(path)
@@ -69,6 +74,22 @@ def test__init__s3(self, mock_is_s3_path, mock_get_s3_client):
6974
assert result_explorer.aws_secret_access_key == aws_secret_access_key
7075
assert isinstance(result_explorer._handler, S3ResultsHandler)
7176

77+
def test_list_with_modality_local(self, tmp_path):
78+
"""Test the `list` method respects the modality subfolder (local)."""
79+
# Setup
80+
base = tmp_path / 'results'
81+
(base / 'unscoped_run').mkdir(parents=True)
82+
(base / 'multi_table' / 'run_mt1').mkdir(parents=True)
83+
(base / 'multi_table' / 'run_mt2').mkdir(parents=True)
84+
85+
result_explorer = ResultsExplorer(str(base), modality='multi_table')
86+
87+
# Run
88+
runs = result_explorer.list()
89+
90+
# Assert
91+
assert set(runs) == {'run_mt1', 'run_mt2'}
92+
7293
def test_list_local(self, tmp_path):
7394
"""Test the `list` method with a local path"""
7495
# Setup
@@ -129,6 +150,28 @@ def test__get_file_path(self):
129150
)
130151
assert file_path == expected_filepath
131152

153+
def test__get_file_path_multi_table_synthetic_data(self, tmp_path):
154+
"""Test `_get_file_path` returns .zip for multi_table synthetic data."""
155+
base = tmp_path / 'results'
156+
multi_table_dir = base / 'multi_table'
157+
multi_table_dir.mkdir(parents=True, exist_ok=True)
158+
explorer = ResultsExplorer(str(multi_table_dir), modality='multi_table')
159+
try:
160+
explorer._handler = Mock()
161+
explorer._handler.get_file_path.return_value = 'irrelevant'
162+
explorer._get_file_path(
163+
results_folder_name='results_folder_07_07_2025',
164+
dataset_name='my_dataset',
165+
synthesizer_name='my_synthesizer',
166+
file_type='synthetic_data',
167+
)
168+
explorer._handler.get_file_path.assert_called_once_with(
169+
['results_folder_07_07_2025', 'my_dataset_07_07_2025', 'my_synthesizer'],
170+
'my_synthesizer_synthetic_data.zip',
171+
)
172+
finally:
173+
shutil.rmtree(multi_table_dir)
174+
132175
def test_load_synthesizer(self, tmp_path):
133176
"""Test `load_synthesizer` method."""
134177
# Setup
@@ -196,6 +239,31 @@ def test_load_real_data(self, mock_load_dataset, tmp_path):
196239
)
197240
pd.testing.assert_frame_equal(real_data, expected_data)
198241

242+
@patch('sdgym.result_explorer.result_explorer.load_dataset')
243+
def test_load_real_data_multi_table(self, mock_load_dataset, tmp_path):
244+
"""Test `load_real_data` for multi_table modality calls load_dataset correctly."""
245+
dataset_name = 'synthea'
246+
expected_data = {'patients': pd.DataFrame({'id': [1]})}
247+
mock_load_dataset.return_value = (expected_data, None)
248+
multi_table_dir = tmp_path / 'multi_table'
249+
multi_table_dir.mkdir(parents=True, exist_ok=True)
250+
result_explorer = ResultsExplorer(tmp_path, modality='multi_table')
251+
252+
try:
253+
# Run
254+
real_data = result_explorer.load_real_data(dataset_name)
255+
256+
# Assert
257+
mock_load_dataset.assert_called_once_with(
258+
modality='multi_table',
259+
dataset='synthea',
260+
aws_access_key_id=None,
261+
aws_secret_access_key=None,
262+
)
263+
assert real_data == expected_data
264+
finally:
265+
shutil.rmtree(multi_table_dir)
266+
199267
def test_load_real_data_invalid_dataset(self, tmp_path):
200268
"""Test `load_real_data` method with an invalid dataset."""
201269
# Setup

tests/unit/result_explorer/test_result_handler.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import os
23
import pickle
34
import re
@@ -324,6 +325,31 @@ def test_load_synthesizer(self, tmp_path):
324325
assert loaded_synthesizer is not None
325326
assert isinstance(loaded_synthesizer, GaussianCopulaSynthesizer)
326327

328+
def test_load_synthetic_data_zip(self, tmp_path):
329+
"""Test the `load_synthetic_data` method for zipped multi-table data (local)."""
330+
# Setup
331+
base = tmp_path / 'results'
332+
data_dir = base / 'SDGym_results_07_07_2025' / 'dataset_07_07_2025' / 'Synth'
333+
data_dir.mkdir(parents=True)
334+
335+
# Create a zip with two csvs
336+
import zipfile
337+
338+
zip_path = data_dir / 'Synth_synthetic_data.zip'
339+
with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
340+
zf.writestr('table1.csv', 'a,b\n1,2\n')
341+
zf.writestr('table2.csv', 'x,y\n3,4\n')
342+
343+
result_handler = LocalResultsHandler(str(base))
344+
345+
# Run
346+
tables = result_handler.load_synthetic_data(str(zip_path))
347+
348+
# Assert
349+
assert set(tables.keys()) == {'table1', 'table2'}
350+
pd.testing.assert_frame_equal(tables['table1'], pd.DataFrame({'a': [1], 'b': [2]}))
351+
pd.testing.assert_frame_equal(tables['table2'], pd.DataFrame({'x': [3], 'y': [4]}))
352+
327353
@patch('os.path.exists')
328354
@patch('os.path.isfile')
329355
def test_get_file_path_local(self, mock_isfile, mock_exists):
@@ -464,6 +490,34 @@ def test_load_synthesizer(self):
464490
Bucket='my-bucket', Key='prefix/synthesizer.pkl'
465491
)
466492

493+
def test_load_synthetic_data_zip(self):
494+
"""Test the `load_synthetic_data` method for zipped multi-table data (S3)."""
495+
# Setup
496+
import zipfile
497+
498+
buffer = io.BytesIO()
499+
with zipfile.ZipFile(buffer, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
500+
zf.writestr('customers.csv', 'id,age\n1,30\n')
501+
zf.writestr('transactions.csv', 'id,amount\n1,100\n')
502+
buffer.seek(0)
503+
504+
mock_s3_client = Mock()
505+
mock_s3_client.get_object.return_value = {'Body': Mock(read=lambda: buffer.getvalue())}
506+
result_handler = S3ResultsHandler('s3://my-bucket/prefix', mock_s3_client)
507+
508+
# Run
509+
tables = result_handler.load_synthetic_data('some/path.zip')
510+
511+
# Assert
512+
assert set(tables.keys()) == {'customers', 'transactions'}
513+
pd.testing.assert_frame_equal(tables['customers'], pd.DataFrame({'id': [1], 'age': [30]}))
514+
pd.testing.assert_frame_equal(
515+
tables['transactions'], pd.DataFrame({'id': [1], 'amount': [100]})
516+
)
517+
mock_s3_client.get_object.assert_called_once_with(
518+
Bucket='my-bucket', Key='prefix/some/path.zip'
519+
)
520+
467521
def test_get_file_path_s3(self):
468522
"""Test `get_file_path` for S3 path when folders and file exist."""
469523
# Setup

0 commit comments

Comments
 (0)