
Commit e14f7ff

Add working multi table benchmark (#504)
1 parent cf085f5 commit e14f7ff

File tree

18 files changed: +1325, -254 lines


Makefile

Lines changed: 6 additions & 2 deletions
@@ -93,7 +93,11 @@ fix-lint:
 # TEST TARGETS
 .PHONY: test-unit
 test-unit: ## run tests quickly with the default Python
-	python -m pytest --cov=sdgym
+	invoke unit
+
+.PHONY: test-integration
+test-integration: ## run tests quickly with the default Python
+	invoke integration

 .PHONY: test-readme
 test-readme: ## run the readme snippets
@@ -102,7 +106,7 @@ test-readme: ## run the readme snippets
 	rm -rf tests/readme_test

 .PHONY: test
-test: test-unit test-readme ## test everything that needs test dependencies
+test: test-unit test-integration ## test everything that needs test dependencies

 .PHONY: test-devel
 test-devel: lint ## test everything that needs development dependencies
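The Makefile targets now delegate to invoke tasks named `unit` and `integration`. The `tasks.py` that defines those tasks is not part of this excerpt; the following is a minimal sketch, assuming standard invoke conventions, where only the task names come from the Makefile above and the test paths and flags are assumptions.

# tasks.py (hypothetical sketch; paths and pytest flags are assumed, not taken from the repo)
from invoke import task


@task
def unit(context):
    # Run the unit test suite, roughly mirroring the old `python -m pytest --cov=sdgym` target.
    context.run('python -m pytest ./tests/unit --cov=sdgym')


@task
def integration(context):
    # Run the integration tests exercised by the new `make test-integration` target.
    context.run('python -m pytest ./tests/integration')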

sdgym/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -12,7 +12,11 @@

 import logging

-from sdgym.benchmark import benchmark_single_table, benchmark_single_table_aws
+from sdgym.benchmark import (
+    benchmark_multi_table,
+    benchmark_single_table,
+    benchmark_single_table_aws,
+)
 from sdgym.cli.collect import collect_results
 from sdgym.cli.summary import make_summary_spreadsheet
 from sdgym.dataset_explorer import DatasetExplorer
@@ -31,12 +35,13 @@
 __all__ = [
     'DatasetExplorer',
     'ResultsExplorer',
+    'benchmark_multi_table',
     'benchmark_single_table',
     'benchmark_single_table_aws',
     'collect_results',
-    'create_synthesizer_variant',
-    'create_single_table_synthesizer',
     'create_multi_table_synthesizer',
+    'create_single_table_synthesizer',
+    'create_synthesizer_variant',
     'load_dataset',
     'make_summary_spreadsheet',
 ]
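With `benchmark_multi_table` exported from the package root, it can be imported directly from `sdgym`. Its signature lives in `sdgym/benchmark.py` and is not shown in this diff; the call below is a hedged sketch that assumes it mirrors the `benchmark_single_table` API, and the synthesizer and dataset names are illustrative only.

# Hypothetical usage sketch; parameter names are assumed to follow the
# benchmark_single_table API and are not confirmed by this diff.
from sdgym import benchmark_multi_table

results = benchmark_multi_table(
    synthesizers=['HMASynthesizer'],   # assumed multi-table synthesizer name
    sdv_datasets=['fake_hotels'],      # assumed demo dataset name
    limit_dataset_size=True,           # would exercise the new subsetting utilities below
)
print(results)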

sdgym/_dataset_utils.py

Lines changed: 77 additions & 28 deletions
@@ -7,9 +7,14 @@

 import numpy as np
 import pandas as pd
+from sdv.metadata import Metadata
+from sdv.utils import poc

 LOGGER = logging.getLogger(__name__)

+MAX_NUM_COLUMNS = 10
+MAX_NUM_ROWS = 1000
+

 def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
     """Generic parser for numeric values with logging and NaN fallback."""
@@ -23,6 +28,65 @@ def _parse_numeric_value(value, dataset_name, field_name, target_type=float):
         return np.nan


+def _filter_columns(columns, mandatory_columns):
+    """Given a dictionary of columns and a list of mandatory ones, return a filtered subset."""
+    mandatory_columns = [m_col for m_col in mandatory_columns if m_col in columns]
+    optional_columns = [col for col in columns if col not in mandatory_columns]
+    keep_columns = mandatory_columns + optional_columns[:MAX_NUM_COLUMNS]
+    return {col: columns[col] for col in keep_columns if col in columns}
+
+
+def _get_multi_table_dataset_subset(data, metadata_dict):
+    """Create a smaller, referentially consistent subset of multi-table data.
+
+    This function limits each table to at most 10 columns by keeping all
+    mandatory columns and, if needed, a subset of the remaining columns, then
+    trims the underlying DataFrames to match the updated metadata. Finally, it
+    uses SDV's multi-table utility to sample up to 1,000 rows from
+    the main table and a consistent subset of rows from all related tables
+    while preserving referential integrity.
+
+    Args:
+        data (dict):
+            A dictionary where keys are table names and values are DataFrames
+            representing tables.
+        metadata_dict (dict):
+            Metadata dictionary containing schema information for each table.
+
+    Returns:
+        tuple:
+            A tuple containing:
+            - dict: The subset of the input data with reduced columns and rows.
+            - dict: The updated metadata dictionary reflecting the reduced column sets.
+    """
+    metadata = Metadata.load_from_dict(metadata_dict)
+    for table_name, table in metadata.tables.items():
+        table_columns = table.columns
+        mandatory_columns = list(metadata._get_all_keys(table_name))
+        subset_column_schema = _filter_columns(
+            columns=table_columns, mandatory_columns=mandatory_columns
+        )
+        metadata_dict['tables'][table_name]['columns'] = subset_column_schema
+
+    # Re-load the metadata object that will be used with the `SDV` utility function
+    metadata = Metadata.load_from_dict(metadata_dict)
+    largest_table_name = max(data, key=lambda table_name: len(data[table_name]))
+
+    # Trim the data to contain only the subset of columns
+    for table_name, table in metadata.tables.items():
+        data[table_name] = data[table_name][list(table.columns)]
+
+    # Subsample the data maintaining the referential integrity
+    data = poc.get_random_subset(
+        data=data,
+        metadata=metadata,
+        main_table_name=largest_table_name,
+        num_rows=MAX_NUM_ROWS,
+        verbose=False,
+    )
+    return data, metadata_dict
+
+
 def _get_dataset_subset(data, metadata_dict, modality):
     """Limit the size of a dataset for faster evaluation or testing.

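A quick, self-contained illustration of the `_filter_columns` rule introduced in the hunk above: mandatory (key) columns are always kept, then the first `MAX_NUM_COLUMNS` remaining columns are appended. The toy schema and key names below are made up for demonstration, and the constant mirrors the one added to the module.

# Standalone demo of the column-filtering behavior (logic copied from the diff above).
MAX_NUM_COLUMNS = 10  # mirrors the constant added in sdgym/_dataset_utils.py


def _filter_columns(columns, mandatory_columns):
    mandatory_columns = [m_col for m_col in mandatory_columns if m_col in columns]
    optional_columns = [col for col in columns if col not in mandatory_columns]
    keep_columns = mandatory_columns + optional_columns[:MAX_NUM_COLUMNS]
    return {col: columns[col] for col in keep_columns if col in columns}


# Toy schema with 15 columns; 'guest_id' and 'hotel_id' stand in for primary/foreign keys.
columns = {f'col_{i}': {'sdtype': 'numerical'} for i in range(13)}
columns.update({'guest_id': {'sdtype': 'id'}, 'hotel_id': {'sdtype': 'id'}})

filtered = _filter_columns(columns, mandatory_columns=['guest_id', 'hotel_id'])
print(list(filtered))  # keys first, then the first 10 optional columns (12 names total here)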
@@ -31,52 +95,37 @@ def _get_dataset_subset(data, metadata_dict, modality):
     columns—such as sequence indices and keys in sequential datasets—are always retained.

     Args:
-        data (pd.DataFrame):
+        data (pd.DataFrame or dict):
             The dataset to be reduced.
         metadata_dict (dict):
-            A dictionary containing the dataset's metadata.
+            A dictionary representing the dataset's metadata.
         modality (str):
-            The dataset modality. Must be one of: ``'single_table'``, ``'sequential'``.
+            The dataset modality.

     Returns:
         tuple[pd.DataFrame, dict]:
             A tuple containing:
-            - The reduced dataset as a DataFrame.
+            - The reduced dataset as a DataFrame or Dictionary.
             - The updated metadata dictionary reflecting any removed columns.
-
-    Raises:
-        ValueError:
-            If the provided modality is ``'multi_table'``.
     """
     if modality == 'multi_table':
-        raise ValueError('limit_dataset_size is not supported for multi-table datasets.')
+        return _get_multi_table_dataset_subset(data, metadata_dict)

-    max_rows, max_columns = (1000, 10)
     tables = metadata_dict.get('tables', {})
     mandatory_columns = []
     table_name, table_info = next(iter(tables.items()))
-
     columns = table_info.get('columns', {})
-    keep_columns = list(columns)
-    if modality == 'sequential':
-        seq_index = table_info.get('sequence_index')
-        seq_key = table_info.get('sequence_key')
-        mandatory_columns = [col for col in (seq_index, seq_key) if col]

-    optional_columns = [col for col in columns if col not in mandatory_columns]
+    seq_index = table_info.get('sequence_index')
+    seq_key = table_info.get('sequence_key')
+    mandatory_columns = [column for column in (seq_index, seq_key) if column]
+    filtered = _filter_columns(columns=columns, mandatory_columns=mandatory_columns)

-    # If we have too many columns, drop extras but never mandatory ones
-    if len(columns) > max_columns:
-        keep_count = max_columns - len(mandatory_columns)
-        keep_columns = mandatory_columns + optional_columns[:keep_count]
-        table_info['columns'] = {
-            column_name: column_definition
-            for column_name, column_definition in columns.items()
-            if column_name in keep_columns
-        }
-
-    data = data[list(keep_columns)]
+    table_info['columns'] = filtered
+    data = data[list(filtered)]
+    max_rows = min(MAX_NUM_ROWS, len(data))
     data = data.sample(max_rows)
+
     return data, metadata_dict

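The single-table path of the updated `_get_dataset_subset` only needs a DataFrame and a metadata dictionary, so it can be exercised directly. A minimal sketch follows, assuming sdgym (and its SDV dependency) is installed; it imports the private helper shown in the diff above (private API, subject to change) and uses a made-up 15-column, 3,000-row table.

import numpy as np
import pandas as pd

from sdgym._dataset_utils import _get_dataset_subset  # private helper from the diff above

# Made-up single-table dataset: 3,000 rows and 15 numerical columns.
data = pd.DataFrame(np.random.rand(3000, 15), columns=[f'col_{i}' for i in range(15)])
metadata_dict = {
    'tables': {
        'table': {'columns': {f'col_{i}': {'sdtype': 'numerical'} for i in range(15)}}
    }
}

subset, updated_metadata = _get_dataset_subset(data, metadata_dict, modality='single_table')
print(subset.shape)  # expected (1000, 10): capped at MAX_NUM_ROWS rows and MAX_NUM_COLUMNS columns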
