Commit cf085f5

Add multi-table UniformSynthesizer (#497)
1 parent a18e9f1 commit cf085f5

File tree

24 files changed: +623 −206 lines


.github/workflows/minimum.yml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ concurrency:
 jobs:
   minimum:
     runs-on: ${{ matrix.os }}
-    timeout-minutes: 30
+    timeout-minutes: 45
     strategy:
       matrix:
         python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']

sdgym/benchmark.py

Lines changed: 8 additions & 6 deletions
@@ -5,7 +5,6 @@
 import math
 import multiprocessing
 import os
-import pickle
 import re
 import textwrap
 import threading
@@ -351,7 +350,7 @@ def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, re
     train_end = None
     try:
         fitted_synthesizer = get_synthesizer(data, metadata)
-        synthesizer_size = len(pickle.dumps(fitted_synthesizer)) / N_BYTES_IN_MB
+        synthesizer_size = len(cloudpickle.dumps(fitted_synthesizer)) / N_BYTES_IN_MB
         train_end = get_utc_now()
         train_time = train_end - start
         synthetic_data = sample_from_synthesizer(fitted_synthesizer, num_samples)
@@ -361,7 +360,10 @@ def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, re

     if synthesizer_path is not None and result_writer is not None:
         result_writer.write_dataframe(synthetic_data, synthesizer_path['synthetic_data'])
-        result_writer.write_pickle(fitted_synthesizer, synthesizer_path['synthesizer'])
+        internal_synthesizer = getattr(
+            fitted_synthesizer, '_internal_synthesizer', fitted_synthesizer
+        )
+        result_writer.write_pickle(internal_synthesizer, synthesizer_path['synthesizer'])

     return synthetic_data, train_time, sample_time, synthesizer_size, peak_memory

@@ -1438,7 +1440,7 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client):
     job_args_key = f'job_args_list_{metainfo}.pkl'
     job_args_key = f'{path}{job_args_key}' if path else job_args_key

-    serialized_data = pickle.dumps(job_args_list)
+    serialized_data = cloudpickle.dumps(job_args_list)
     s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=serialized_data)

     return bucket_name, job_args_key
@@ -1449,7 +1451,7 @@ def _get_s3_script_content(
 ):
     return f"""
 import boto3
-import pickle
+import cloudpickle
 from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file
 from io import StringIO
 from sdgym.result_writer import S3ResultsWriter
@@ -1461,7 +1463,7 @@ def _get_s3_script_content(
     region_name='{region_name}'
 )
 response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
-job_args_list = pickle.loads(response['Body'].read())
+job_args_list = cloudpickle.loads(response['Body'].read())
 result_writer = S3ResultsWriter(s3_client=s3_client)
 _write_metainfo_file({synthesizers}, job_args_list, result_writer)
 scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
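
Note on the pickle → cloudpickle switch: SDGym appears to create some synthesizer classes at runtime (the `for sdv_name in _get_all_sdv_synthesizers()` loop and `create_synthesizer_variant` in `__init__.py` point that way), and the standard library pickle can only serialize classes it can re-import by name. cloudpickle serializes such classes by value. A minimal sketch, not part of this commit, illustrating the difference:

# Minimal sketch (not part of this commit): stdlib pickle serializes classes
# by reference, so a class defined at runtime inside a function cannot be
# re-imported and pickling its instances fails; cloudpickle handles it.
import pickle

import cloudpickle


def make_variant():
    class VariantSynthesizer:  # only exists inside this function
        def sample(self, n):
            return list(range(n))

    return VariantSynthesizer


instance = make_variant()()

try:
    pickle.dumps(instance)
except Exception as error:
    print(f'stdlib pickle failed: {error}')

payload = cloudpickle.dumps(instance)  # serialized by value, not by import path
restored = cloudpickle.loads(payload)
print(restored.sample(3))  # [0, 1, 2]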

sdgym/result_explorer/result_handler.py

Lines changed: 3 additions & 3 deletions
@@ -3,10 +3,10 @@
 import io
 import operator
 import os
-import pickle
 from abc import ABC, abstractmethod
 from datetime import datetime

+import cloudpickle
 import pandas as pd
 import yaml
 from botocore.exceptions import ClientError
@@ -259,7 +259,7 @@ def get_file_path(self, path_parts, end_filename):
     def load_synthesizer(self, file_path):
         """Load a synthesizer from a pickle file."""
         with open(os.path.join(self.base_path, file_path), 'rb') as f:
-            return pickle.load(f)
+            return cloudpickle.load(f)

     def load_synthetic_data(self, file_path):
         """Load synthetic data from a CSV file."""
@@ -370,7 +370,7 @@ def load_synthesizer(self, file_path):
         response = self.s3_client.get_object(
             Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}'
         )
-        return pickle.loads(response['Body'].read())
+        return cloudpickle.loads(response['Body'].read())

     def load_synthetic_data(self, file_path):
         """Load synthetic data from S3."""

sdgym/result_writer.py

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
 """Results writer for SDGym benchmark."""

 import io
-import pickle
 from abc import ABC, abstractmethod
 from pathlib import Path

+import cloudpickle
 import pandas as pd
 import plotly.graph_objects as go
 import yaml
@@ -82,7 +82,7 @@ def write_xlsx(self, data, file_path, index=False):
     def write_pickle(self, obj, file_path):
         """Write a Python object to a pickle file."""
         with open(file_path, 'wb') as f:
-            pickle.dump(obj, f)
+            cloudpickle.dump(obj, f)

     def write_yaml(self, data, file_path, append=False):
         """Write data to a YAML file."""
@@ -126,7 +126,7 @@ def write_pickle(self, obj, file_path):
         """Write a Python object to S3 as a pickle file."""
         bucket, key = parse_s3_path(file_path)
         buffer = io.BytesIO()
-        pickle.dump(obj, buffer)
+        cloudpickle.dump(obj, buffer)
         buffer.seek(0)
         self.s3_client.put_object(Body=buffer.read(), Bucket=bucket, Key=key)
sdgym/synthesizers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,7 @@
 from sdgym.synthesizers.identity import DataIdentity
 from sdgym.synthesizers.column import ColumnSynthesizer
 from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
-from sdgym.synthesizers.uniform import UniformSynthesizer
+from sdgym.synthesizers.uniform import UniformSynthesizer, MultiTableUniformSynthesizer
 from sdgym.synthesizers.utils import (
     get_available_single_table_synthesizers,
     get_available_multi_table_synthesizers,
@@ -26,6 +26,7 @@
     'create_synthesizer_variant',
     'get_available_single_table_synthesizers',
     'get_available_multi_table_synthesizers',
+    'MultiTableUniformSynthesizer',
 ]

 for sdv_name in _get_all_sdv_synthesizers():
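
A hedged usage sketch for the newly exported name. The MultiTableUniformSynthesizer internals are not part of this excerpt, so this assumes it follows the MultiTableBaselineSynthesizer interface added in base.py below and that sdv's Metadata.detect_from_dataframes is available; table and column names are illustrative only.

import pandas as pd
from sdv.metadata import Metadata

from sdgym.synthesizers import MultiTableUniformSynthesizer

# Toy two-table dataset; values are illustrative only.
real_data = {
    'users': pd.DataFrame({'user_id': [1, 2, 3], 'age': [25, 31, 47]}),
    'sessions': pd.DataFrame({'session_id': [10, 11], 'user_id': [1, 3]}),
}
metadata = Metadata.detect_from_dataframes(real_data)

synthesizer = MultiTableUniformSynthesizer()
fitted = synthesizer.get_trained_synthesizer(real_data, metadata)
sampled = synthesizer.sample_from_synthesizer(fitted, scale=1.0)  # dict: table name -> DataFrame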

sdgym/synthesizers/base.py

Lines changed: 74 additions & 8 deletions
@@ -9,11 +9,23 @@
 LOGGER = logging.getLogger(__name__)


+def _is_valid_modality(modality):
+    return modality in ('single_table', 'multi_table')
+
+
+def _validate_modality(modality):
+    if not _is_valid_modality(modality):
+        raise ValueError(
+            f"Modality '{modality}' is not valid. Must be either 'single_table' or 'multi_table'."
+        )
+
+
 class BaselineSynthesizer(abc.ABC):
     """Base class for all the ``SDGym`` baselines."""

     _MODEL_KWARGS = {}
     _NATIVELY_SUPPORTED = True
+    _MODALITY_FLAG = None

     @classmethod
     def get_subclasses(cls, include_parents=False):
@@ -34,15 +46,18 @@ def get_subclasses(cls, include_parents=False):
         return subclasses

     @classmethod
-    def _get_supported_synthesizers(cls):
+    def _get_supported_synthesizers(cls, modality):
         """Get the natively supported synthesizer class names."""
-        subclasses = cls.get_subclasses(include_parents=True)
-        synthesizers = set()
-        for name, subclass in subclasses.items():
-            if subclass._NATIVELY_SUPPORTED:
-                synthesizers.add(name)
-
-        return sorted(synthesizers)
+        _validate_modality(modality)
+        return sorted({
+            name
+            for name, subclass in cls.get_subclasses(include_parents=True).items()
+            if (
+                name != 'MultiTableBaselineSynthesizer'
+                and subclass._NATIVELY_SUPPORTED
+                and subclass._MODALITY_FLAG == modality
+            )
+        })

     @classmethod
     def get_baselines(cls):
@@ -55,6 +70,35 @@ def get_baselines(cls):

         return synthesizers

+    def _fit(self, data, metadata):
+        """Fit the synthesizer to the data.
+
+        Args:
+            data (pandas.DataFrame):
+                The data to fit the synthesizer to.
+            metadata (sdv.metadata.Metadata):
+                The metadata describing the data.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _get_trained_synthesizer(cls, data, metadata):
+        """Train a synthesizer on the provided data and metadata.
+
+        Args:
+            data (pd.DataFrame or dict):
+                The data to train on.
+            metadata (sdv.metadata.Metadata):
+                The metadata
+
+        Returns:
+            A synthesizer object
+        """
+        synthesizer = cls()
+        synthesizer._fit(data, metadata)
+
+        return synthesizer
+
     def get_trained_synthesizer(self, data, metadata):
         """Get a synthesizer that has been trained on the provided data and metadata.

@@ -90,3 +134,25 @@ def sample_from_synthesizer(self, synthesizer, n_samples):
         should be a dict mapping table name to DataFrame.
         """
         return self._sample_from_synthesizer(synthesizer, n_samples)
+
+
+class MultiTableBaselineSynthesizer(BaselineSynthesizer):
+    """Base class for all multi-table synthesizers."""
+
+    _MODALITY_FLAG = 'multi_table'
+
+    def sample_from_synthesizer(self, synthesizer, scale=1.0):
+        """Sample data from the provided synthesizer.
+
+        Args:
+            synthesizer (obj):
+                The synthesizer object to sample data from.
+            scale (float):
+                The scale of data to sample.
+                Defaults to 1.0.
+
+        Returns:
+            dict:
+                The sampled data. A dict mapping table name to DataFrame.
+        """
+        return self._sample_from_synthesizer(synthesizer, scale=scale)
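
The new split, a `_fit` hook plus a default `_get_trained_synthesizer` classmethod that instantiates and fits, means a concrete baseline only needs to implement `_fit` and `_sample_from_synthesizer`; multi-table baselines additionally sample by `scale` rather than a row count. A hedged sketch, not in this commit, of a toy subclass wired into that interface:

# Hedged sketch (not in this commit) of a concrete multi-table baseline:
# implement _fit and _sample_from_synthesizer and let the inherited
# _get_trained_synthesizer / sample_from_synthesizer drive them.
from sdgym.synthesizers.base import MultiTableBaselineSynthesizer


class IdentityMultiTableSynthesizer(MultiTableBaselineSynthesizer):
    """Toy baseline that memorizes the real tables and resamples their rows."""

    def _fit(self, data, metadata):
        # data: dict mapping table name -> DataFrame, as described in base.py
        self.tables = {name: table.copy() for name, table in data.items()}

    def _sample_from_synthesizer(self, synthesizer, scale=1.0):
        sampled = {}
        for name, table in synthesizer.tables.items():
            n_rows = max(1, int(len(table) * scale))
            sampled[name] = table.sample(n=n_rows, replace=True).reset_index(drop=True)

        return sampled

With this in place, `get_trained_synthesizer(data, metadata)` returns a fitted instance and `sample_from_synthesizer(fitted, scale=0.5)` yields a dict of downsized tables.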

sdgym/synthesizers/column.py

Lines changed: 16 additions & 10 deletions
@@ -19,9 +19,11 @@ class ColumnSynthesizer(BaselineSynthesizer):
     Continuous columns are learned and sampled using a GMM.
     """

-    def _get_trained_synthesizer(self, real_data, metadata):
+    _MODALITY_FLAG = 'single_table'
+
+    def _fit(self, data, metadata):
         hyper_transformer = HyperTransformer()
-        hyper_transformer.detect_initial_config(real_data)
+        hyper_transformer.detect_initial_config(data)
         supported_sdtypes = hyper_transformer._get_supported_sdtypes()
         config = {}
         if isinstance(metadata, Metadata):
@@ -46,14 +48,14 @@ def _get_trained_synthesizer(self, real_data, metadata):

         # This is done to match the behavior of the synthesizer for SDGym <= 0.6.0
         columns_to_remove = [
-            column_name for column_name, data in real_data.items() if data.dtype.kind in {'O', 'i'}
+            column_name for column_name, data in data.items() if data.dtype.kind in {'O', 'i'}
         ]
         hyper_transformer.remove_transformers(columns_to_remove)

-        hyper_transformer.fit(real_data)
-        transformed = hyper_transformer.transform(real_data)
+        hyper_transformer.fit(data)
+        transformed = hyper_transformer.transform(data)

-        self.length = len(real_data)
+        self.length = len(data)
         gm_models = {}
         for name, column in transformed.items():
             kind = column.dtype.kind
@@ -63,18 +65,22 @@ def _get_trained_synthesizer(self, real_data, metadata):
             model.fit(column.to_numpy().reshape(-1, 1))
             gm_models[name] = model

-        return (hyper_transformer, transformed, gm_models)
+        self.hyper_transformer = hyper_transformer
+        self.transformed_data = transformed
+        self.gm_models = gm_models

     def _sample_from_synthesizer(self, synthesizer, n_samples):
-        hyper_transformer, transformed, gm_models = synthesizer
+        hyper_transformer = synthesizer.hyper_transformer
+        transformed = synthesizer.transformed_data
+        gm_models = synthesizer.gm_models
         sampled = pd.DataFrame()
         for name, column in transformed.items():
             kind = column.dtype.kind
             if kind == 'O':
-                values = column.sample(self.length, replace=True).to_numpy()
+                values = column.sample(n_samples, replace=True).to_numpy()
             else:
                 model = gm_models.get(name)
-                values = model.sample(self.length)[0].ravel().clip(column.min(), column.max())
+                values = model.sample(n_samples)[0].ravel().clip(column.min(), column.max())

             sampled[name] = values
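
The ColumnSynthesizer refactor replaces the old `(hyper_transformer, transformed, gm_models)` tuple with state stored on the fitted instance, and sampling now honors the requested `n_samples` instead of always emitting `self.length` rows. A hedged usage sketch, assuming `get_trained_synthesizer` delegates to the `_get_trained_synthesizer` classmethod shown in base.py and that sdv's `Metadata.detect_from_dataframe` is available; the data is a toy example:

import pandas as pd
from sdv.metadata import Metadata

from sdgym.synthesizers import ColumnSynthesizer

# Toy numeric data; values are illustrative only.
data = pd.DataFrame({
    'age': [23.0, 35.5, 41.2, 29.8],
    'balance': [100.0, 250.5, 80.3, 150.0],
})
metadata = Metadata.detect_from_dataframe(data)

synthesizer = ColumnSynthesizer()
fitted = synthesizer.get_trained_synthesizer(data, metadata)

# After this change, the requested row count is honored rather than the training length.
samples = synthesizer.sample_from_synthesizer(fitted, 10)
print(len(samples))  # 10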