Skip to content

Commit 4c87a30

Browse files
committed
Add changes
Update tests; fix init
1 parent 44ac3e3 commit 4c87a30

File tree

5 files changed

+160
-13
lines changed

5 files changed

+160
-13
lines changed

sdgym/benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,7 +1815,7 @@ def benchmark_multi_table(
18151815
output_destination=None,
18161816
show_progress=False,
18171817
):
1818-
"""Run the SDGym benchmark on single-table datasets.
1818+
"""Run the SDGym benchmark on multi-table datasets.
18191819
18201820
Args:
18211821
synthesizers (list[string]):
@@ -1827,8 +1827,8 @@ def benchmark_multi_table(
18271827
or ``create_synthesizer_variant``). Defaults to ``None``.
18281828
sdv_datasets (list[str] or ``None``):
18291829
Names of the SDV demo datasets to use for the benchmark. Defaults to
1830-
``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news,
1831-
covtype]``. Use ``None`` to disable using any sdv datasets.
1830+
``[NBA, financial, Student_loan, Biodegradability, fake_hotels, restbase,
1831+
airbnb-simplified]``. Use ``None`` to disable using any sdv datasets.
18321832
additional_datasets_folder (str or ``None``):
18331833
The path to a folder (local or an S3 bucket). Datasets found in this folder are
18341834
run in addition to the SDV datasets. If ``None``, no additional datasets are used.

sdgym/result_explorer/result_explorer.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,17 @@ def _resolve_effective_path(path, modality):
5757
class ResultsExplorer:
5858
"""Explorer for SDGym benchmark results, supporting both local and S3 storage."""
5959

60-
def __init__(self, path, aws_access_key_id=None, aws_secret_access_key=None):
60+
def __init__(self, path, modality, aws_access_key_id=None, aws_secret_access_key=None):
6161
self.path = path
62+
self.modality = modality
6263
self.aws_access_key_id = aws_access_key_id
6364
self.aws_secret_access_key = aws_secret_access_key
6465

6566
baseline_synthesizer = _get_baseline_synthesizer(modality)
6667
effective_path = _resolve_effective_path(path, modality)
6768
if is_s3_path(path):
69+
# Use original path to obtain client (keeps backwards compatibility),
70+
# but handler should operate on the modality-specific effective path.
6871
s3_client = _get_s3_client(path, aws_access_key_id, aws_secret_access_key)
6972
self._handler = S3ResultsHandler(
7073
effective_path, s3_client, baseline_synthesizer=baseline_synthesizer
@@ -83,7 +86,11 @@ def _get_file_path(self, results_folder_name, dataset_name, synthesizer_name, fi
8386
"""Validate access to the synthesizer or synthetic data file."""
8487
end_filename = f'{synthesizer_name}'
8588
if file_type == 'synthetic_data':
86-
end_filename += '_synthetic_data.csv'
89+
# Multi-table synthetic data is zipped (multiple CSVs), single table is CSV
90+
if self.modality == 'multi_table':
91+
end_filename += '_synthetic_data.zip'
92+
else:
93+
end_filename += '_synthetic_data.csv'
8794
elif file_type == 'synthesizer':
8895
end_filename += '.pkl'
8996

@@ -108,14 +115,17 @@ def load_synthetic_data(self, results_folder_name, dataset_name, synthesizer_nam
108115

109116
def load_real_data(self, dataset_name):
110117
"""Load the real data for a given dataset."""
111-
if dataset_name not in DEFAULT_SINGLE_TABLE_DATASETS:
118+
# Keep strict validation for single_table to preserve existing behavior
119+
if (self.modality is None or self.modality == 'single_table') and (
120+
dataset_name not in DEFAULT_SINGLE_TABLE_DATASETS
121+
):
112122
raise ValueError(
113123
f"Dataset '{dataset_name}' is not a SDGym dataset. "
114124
'Please provide a valid dataset name.'
115125
)
116126

117127
data, _ = load_dataset(
118-
modality='single_table',
128+
modality=self.modality or 'single_table',
119129
dataset=dataset_name,
120130
aws_access_key_id=self.aws_access_key_id,
121131
aws_secret_access_key=self.aws_secret_access_key,

sdgym/result_explorer/result_handler.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
import os
66
from abc import ABC, abstractmethod
77
from datetime import datetime
8+
from io import BytesIO
9+
from zipfile import ZipFile
810

911
import cloudpickle
1012
import pandas as pd
1113
import yaml
1214
from botocore.exceptions import ClientError
1315

16+
from sdgym._dataset_utils import _read_zipped_data
17+
1418
SYNTHESIZER_BASELINE = 'GaussianCopulaSynthesizer'
1519
RESULTS_FOLDER_PREFIX = 'SDGym_results_'
1620
metainfo_PREFIX = 'metainfo'
@@ -270,8 +274,12 @@ def load_synthesizer(self, file_path):
270274
return cloudpickle.load(f)
271275

272276
def load_synthetic_data(self, file_path):
273-
"""Load synthetic data from a CSV file."""
274-
return pd.read_csv(os.path.join(self.base_path, file_path))
277+
"""Load synthetic data from a CSV or ZIP file."""
278+
full_path = os.path.join(self.base_path, file_path)
279+
if full_path.endswith('.zip'):
280+
return _read_zipped_data(full_path, modality='multi_table')
281+
282+
return pd.read_csv(full_path)
275283

276284
def _get_results_files(self, folder_name, prefix, suffix):
277285
return [
@@ -383,10 +391,21 @@ def load_synthesizer(self, file_path):
383391

384392
def load_synthetic_data(self, file_path):
385393
"""Load synthetic data from S3."""
386-
response = self.s3_client.get_object(
387-
Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}'
388-
)
389-
return pd.read_csv(io.BytesIO(response['Body'].read()))
394+
key = f'{self.prefix}{file_path}'
395+
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
396+
body = response['Body'].read()
397+
if file_path.endswith('.zip'):
398+
tables = {}
399+
with ZipFile(BytesIO(body)) as zf:
400+
for name in zf.namelist():
401+
if name.endswith('.csv'):
402+
table_name = os.path.splitext(os.path.basename(name))[0]
403+
with zf.open(name) as csv_file:
404+
tables[table_name] = pd.read_csv(csv_file, low_memory=False)
405+
406+
return tables
407+
408+
return pd.read_csv(io.BytesIO(body))
390409

391410
def _get_results_files(self, folder_name, prefix, suffix):
392411
s3_prefix = f'{self.prefix}{folder_name}/'

tests/unit/result_explorer/test_result_explorer.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import re
3+
import shutil
34
from unittest.mock import Mock, patch
45

56
import pandas as pd
@@ -76,6 +77,22 @@ def test__init__s3(self, mock_is_s3_path, mock_get_s3_client):
7677
assert result_explorer.aws_secret_access_key == aws_secret_access_key
7778
assert isinstance(result_explorer._handler, S3ResultsHandler)
7879

80+
def test_list_with_modality_local(self, tmp_path):
81+
"""Test the `list` method respects the modality subfolder (local)."""
82+
# Setup
83+
base = tmp_path / 'results'
84+
(base / 'unscoped_run').mkdir(parents=True)
85+
(base / 'multi_table' / 'run_mt1').mkdir(parents=True)
86+
(base / 'multi_table' / 'run_mt2').mkdir(parents=True)
87+
88+
result_explorer = ResultsExplorer(str(base), modality='multi_table')
89+
90+
# Run
91+
runs = result_explorer.list()
92+
93+
# Assert
94+
assert set(runs) == {'run_mt1', 'run_mt2'}
95+
7996
def test_list_local(self, tmp_path):
8097
"""Test the `list` method with a local path"""
8198
# Setup
@@ -136,6 +153,28 @@ def test__get_file_path(self):
136153
)
137154
assert file_path == expected_filepath
138155

156+
def test__get_file_path_multi_table_synthetic_data(self, tmp_path):
157+
"""Test `_get_file_path` returns .zip for multi_table synthetic data."""
158+
base = tmp_path / 'results'
159+
multi_table_dir = base / 'multi_table'
160+
multi_table_dir.mkdir(parents=True, exist_ok=True)
161+
explorer = ResultsExplorer(str(multi_table_dir), modality='multi_table')
162+
try:
163+
explorer._handler = Mock()
164+
explorer._handler.get_file_path.return_value = 'irrelevant'
165+
explorer._get_file_path(
166+
results_folder_name='results_folder_07_07_2025',
167+
dataset_name='my_dataset',
168+
synthesizer_name='my_synthesizer',
169+
file_type='synthetic_data',
170+
)
171+
explorer._handler.get_file_path.assert_called_once_with(
172+
['results_folder_07_07_2025', 'my_dataset_07_07_2025', 'my_synthesizer'],
173+
'my_synthesizer_synthetic_data.zip',
174+
)
175+
finally:
176+
shutil.rmtree(multi_table_dir)
177+
139178
def test_load_synthesizer(self, tmp_path):
140179
"""Test `load_synthesizer` method."""
141180
# Setup
@@ -209,6 +248,31 @@ def test_load_real_data(self, mock_load_dataset, tmp_path):
209248
)
210249
pd.testing.assert_frame_equal(real_data, expected_data)
211250

251+
@patch('sdgym.result_explorer.result_explorer.load_dataset')
252+
def test_load_real_data_multi_table(self, mock_load_dataset, tmp_path):
253+
"""Test `load_real_data` for multi_table modality calls load_dataset correctly."""
254+
dataset_name = 'synthea'
255+
expected_data = {'patients': pd.DataFrame({'id': [1]})}
256+
mock_load_dataset.return_value = (expected_data, None)
257+
multi_table_dir = tmp_path / 'multi_table'
258+
multi_table_dir.mkdir(parents=True, exist_ok=True)
259+
result_explorer = ResultsExplorer(tmp_path, modality='multi_table')
260+
261+
try:
262+
# Run
263+
real_data = result_explorer.load_real_data(dataset_name)
264+
265+
# Assert
266+
mock_load_dataset.assert_called_once_with(
267+
modality='multi_table',
268+
dataset='synthea',
269+
aws_access_key_id=None,
270+
aws_secret_access_key=None,
271+
)
272+
assert real_data == expected_data
273+
finally:
274+
shutil.rmtree(multi_table_dir)
275+
212276
def test_load_real_data_invalid_dataset(self, tmp_path):
213277
"""Test `load_real_data` method with an invalid dataset."""
214278
# Setup

tests/unit/result_explorer/test_result_handler.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import os
23
import pickle
34
import re
@@ -326,6 +327,31 @@ def test_load_synthesizer(self, tmp_path):
326327
assert loaded_synthesizer is not None
327328
assert isinstance(loaded_synthesizer, GaussianCopulaSynthesizer)
328329

330+
def test_load_synthetic_data_zip(self, tmp_path):
331+
"""Test the `load_synthetic_data` method for zipped multi-table data (local)."""
332+
# Setup
333+
base = tmp_path / 'results'
334+
data_dir = base / 'SDGym_results_07_07_2025' / 'dataset_07_07_2025' / 'Synth'
335+
data_dir.mkdir(parents=True)
336+
337+
# Create a zip with two csvs
338+
import zipfile
339+
340+
zip_path = data_dir / 'Synth_synthetic_data.zip'
341+
with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
342+
zf.writestr('table1.csv', 'a,b\n1,2\n')
343+
zf.writestr('table2.csv', 'x,y\n3,4\n')
344+
345+
result_handler = LocalResultsHandler(str(base))
346+
347+
# Run
348+
tables = result_handler.load_synthetic_data(str(zip_path))
349+
350+
# Assert
351+
assert set(tables.keys()) == {'table1', 'table2'}
352+
pd.testing.assert_frame_equal(tables['table1'], pd.DataFrame({'a': [1], 'b': [2]}))
353+
pd.testing.assert_frame_equal(tables['table2'], pd.DataFrame({'x': [3], 'y': [4]}))
354+
329355
@patch('os.path.exists')
330356
@patch('os.path.isfile')
331357
def test_get_file_path_local(self, mock_isfile, mock_exists):
@@ -466,6 +492,34 @@ def test_load_synthesizer(self):
466492
Bucket='my-bucket', Key='prefix/synthesizer.pkl'
467493
)
468494

495+
def test_load_synthetic_data_zip(self):
496+
"""Test the `load_synthetic_data` method for zipped multi-table data (S3)."""
497+
# Setup
498+
import zipfile
499+
500+
buffer = io.BytesIO()
501+
with zipfile.ZipFile(buffer, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
502+
zf.writestr('customers.csv', 'id,age\n1,30\n')
503+
zf.writestr('transactions.csv', 'id,amount\n1,100\n')
504+
buffer.seek(0)
505+
506+
mock_s3_client = Mock()
507+
mock_s3_client.get_object.return_value = {'Body': Mock(read=lambda: buffer.getvalue())}
508+
result_handler = S3ResultsHandler('s3://my-bucket/prefix', mock_s3_client)
509+
510+
# Run
511+
tables = result_handler.load_synthetic_data('some/path.zip')
512+
513+
# Assert
514+
assert set(tables.keys()) == {'customers', 'transactions'}
515+
pd.testing.assert_frame_equal(tables['customers'], pd.DataFrame({'id': [1], 'age': [30]}))
516+
pd.testing.assert_frame_equal(
517+
tables['transactions'], pd.DataFrame({'id': [1], 'amount': [100]})
518+
)
519+
mock_s3_client.get_object.assert_called_once_with(
520+
Bucket='my-bucket', Key='prefix/some/path.zip'
521+
)
522+
469523
def test_get_file_path_s3(self):
470524
"""Test `get_file_path` for S3 path when folders and file exist."""
471525
# Setup

0 commit comments

Comments (0)