Skip to content
2 changes: 1 addition & 1 deletion sdgym/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, re
synthesizer_obj = get_synthesizer(data, metadata)
synthesizer_size = len(pickle.dumps(synthesizer_obj)) / N_BYTES_IN_MB
train_now = get_utc_now()
synthetic_data = sample_from_synthesizer(synthesizer_obj, num_samples)
synthetic_data = sample_from_synthesizer(synthesizer_obj, n_samples=num_samples)
sample_now = get_utc_now()

peak_memory = tracemalloc.get_traced_memory()[1] / N_BYTES_IN_MB
Expand Down
3 changes: 2 additions & 1 deletion sdgym/synthesizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sdgym.synthesizers.identity import DataIdentity
from sdgym.synthesizers.column import ColumnSynthesizer
from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
from sdgym.synthesizers.uniform import UniformSynthesizer
from sdgym.synthesizers.uniform import UniformSynthesizer, MultiTableUniformSynthesizer
from sdgym.synthesizers.utils import (
get_available_single_table_synthesizers,
get_available_multi_table_synthesizers,
Expand All @@ -26,6 +26,7 @@
'create_synthesizer_variant',
'get_available_single_table_synthesizers',
'get_available_multi_table_synthesizers',
'MultiTableUniformSynthesizer',
]

for sdv_name in _get_all_sdv_synthesizers():
Expand Down
71 changes: 62 additions & 9 deletions sdgym/synthesizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,23 @@
LOGGER = logging.getLogger(__name__)


def _is_valid_modality(modality):
return modality in ('single_table', 'multi_table')


def _validate_modality(modality):
if not _is_valid_modality(modality):
raise ValueError(
f"Modality '{modality}' is not valid. Must be either 'single_table' or 'multi_table'."
)


class BaselineSynthesizer(abc.ABC):
"""Base class for all the ``SDGym`` baselines."""

_MODEL_KWARGS = {}
_NATIVELY_SUPPORTED = True
_MODALITY_FLAG = None

@classmethod
def get_subclasses(cls, include_parents=False):
Expand All @@ -34,15 +46,18 @@ def get_subclasses(cls, include_parents=False):
return subclasses

@classmethod
def _get_supported_synthesizers(cls):
def _get_supported_synthesizers(cls, modality):
"""Get the natively supported synthesizer class names."""
subclasses = cls.get_subclasses(include_parents=True)
synthesizers = set()
for name, subclass in subclasses.items():
if subclass._NATIVELY_SUPPORTED:
synthesizers.add(name)

return sorted(synthesizers)
_validate_modality(modality)
return sorted({
name
for name, subclass in cls.get_subclasses(include_parents=True).items()
if (
name != 'MultiTableBaselineSynthesizer'
and subclass._NATIVELY_SUPPORTED
and subclass._MODALITY_FLAG == modality
)
})

@classmethod
def get_baselines(cls):
Expand All @@ -55,6 +70,13 @@ def get_baselines(cls):

return synthesizers

def _validate_modality_flag(self):
if not _is_valid_modality(self._MODALITY_FLAG):
raise ValueError(
f"The `_MODALITY_FLAG` '{self._MODALITY_FLAG}' of the synthesizer is not valid. "
"Must be either 'single_table' or 'multi_table'."
)

def get_trained_synthesizer(self, data, metadata):
"""Get a synthesizer that has been trained on the provided data and metadata.

Expand All @@ -68,14 +90,15 @@ def get_trained_synthesizer(self, data, metadata):
obj:
The synthesizer object.
"""
self._validate_modality_flag()
metadata_object = Metadata()
with warnings.catch_warnings():
warnings.simplefilter('ignore', UserWarning)
metadata = metadata_object.load_from_dict(metadata)

return self._get_trained_synthesizer(data, metadata)

def sample_from_synthesizer(self, synthesizer, n_samples):
def sample_from_synthesizer(self, synthesizer, *, n_samples):
"""Sample data from the provided synthesizer.

Args:
Expand All @@ -90,3 +113,33 @@ def sample_from_synthesizer(self, synthesizer, n_samples):
should be a dict mapping table name to DataFrame.
"""
return self._sample_from_synthesizer(synthesizer, n_samples)


class MultiTableBaselineSynthesizer(BaselineSynthesizer):
"""Base class for all multi-table synthesizers."""

_MODALITY_FLAG = 'multi_table'

def sample_from_synthesizer(self, synthesizer, *, scale=1.0, n_samples=None):
"""Sample data from the provided synthesizer.

Args:
synthesizer (obj):
The synthesizer object to sample data from.
scale (float):
The scale of data to sample.
Defaults to 1.0.
n_samples (int):
This parameter is not supported for multi-table synthesizers.
Use `scale` instead.

Returns:
dict:
The sampled data. A dict mapping table name to DataFrame.
"""
if n_samples is not None:
raise TypeError(
'Multi-table synthesizers do not support `n_samples`. Use `scale` instead.'
)

return self._sample_from_synthesizer(synthesizer, scale=scale)
2 changes: 2 additions & 0 deletions sdgym/synthesizers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class ColumnSynthesizer(BaselineSynthesizer):
Continuous columns are learned and sampled using a GMM.
"""

_MODALITY_FLAG = 'single_table'

def _get_trained_synthesizer(self, real_data, metadata):
hyper_transformer = HyperTransformer()
hyper_transformer.detect_initial_config(real_data)
Expand Down
34 changes: 20 additions & 14 deletions sdgym/synthesizers/generate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Helpers to create SDGym synthesizer variants."""

from sdgym.synthesizers.base import BaselineSynthesizer
from sdgym.synthesizers.base import BaselineSynthesizer, MultiTableBaselineSynthesizer
from sdgym.synthesizers.utils import _get_supported_synthesizers


Expand Down Expand Up @@ -36,7 +36,7 @@ def create_synthesizer_variant(display_name, synthesizer_class, synthesizer_para
return NewSynthesizer


def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, sample_arg_name):
def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, modality):
"""Create a synthesizer class.

Args:
Expand All @@ -47,10 +47,8 @@ def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, sample_ar
A function to generate and train a synthesizer, given the real data and metadata.
sample_from_synthesizer (callable):
A function to sample from the given synthesizer.
sample_arg_name (str):
The name of the argument used to specify the number of samples to generate.
Either 'num_samples' for single-table synthesizers, or 'scale' for multi-table
synthesizers.
modality (str):
The modality of the synthesizer. Either 'single_table' or 'multi_table'.

Returns:
class:
Expand All @@ -61,22 +59,30 @@ def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, sample_ar
def get_trained_synthesizer(self, data, metadata):
return get_trained_fn(data, metadata)

if sample_arg_name == 'num_samples':
if modality == 'multi_table':

def sample_from_synthesizer(self, synthesizer, num_samples):
return sample_fn(synthesizer, num_samples)
def sample_from_synthesizer(self, synthesizer, *, scale=1.0, n_samples=None):
if n_samples is not None:
raise TypeError(
'Multi-table synthesizers do not support `n_samples`. Use `scale` instead.'
)
return sample_fn(synthesizer, scale)

base_class = MultiTableBaselineSynthesizer
else:

def sample_from_synthesizer(self, synthesizer, scale):
return sample_fn(synthesizer, scale)
def sample_from_synthesizer(self, synthesizer, *, n_samples):
return sample_fn(synthesizer, n_samples)

base_class = BaselineSynthesizer

CustomSynthesizer = type(
class_name,
(BaselineSynthesizer,),
(base_class,),
{
'__module__': __name__,
'_NATIVELY_SUPPORTED': False,
'_MODALITY_FLAG': modality,
'get_trained_synthesizer': get_trained_synthesizer,
'sample_from_synthesizer': sample_from_synthesizer,
},
Expand All @@ -94,7 +100,7 @@ def create_single_table_synthesizer(
display_name,
get_trained_synthesizer_fn,
sample_from_synthesizer_fn,
sample_arg_name='num_samples',
modality='single_table',
)


Expand All @@ -106,5 +112,5 @@ def create_multi_table_synthesizer(
display_name,
get_trained_synthesizer_fn,
sample_from_synthesizer_fn,
sample_arg_name='scale',
modality='multi_table',
)
2 changes: 2 additions & 0 deletions sdgym/synthesizers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class DataIdentity(BaselineSynthesizer):
Returns the same exact data that is used to fit it.
"""

_MODALITY_FLAG = 'single_table'

def __init__(self):
self._data = None

Expand Down
1 change: 1 addition & 0 deletions sdgym/synthesizers/realtabformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class RealTabFormerSynthesizer(BaselineSynthesizer):

LOGGER = logging.getLogger(__name__)
_MODEL_KWARGS = None
_MODALITY_FLAG = 'single_table'

def _get_trained_synthesizer(self, data, metadata):
try:
Expand Down
21 changes: 10 additions & 11 deletions sdgym/synthesizers/sdv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from sdv import multi_table, single_table

from sdgym.synthesizers.base import BaselineSynthesizer
from sdgym.synthesizers.base import (
BaselineSynthesizer,
MultiTableBaselineSynthesizer,
_validate_modality,
)

LOGGER = logging.getLogger(__name__)
UNSUPPORTED_SDV_SYNTHESIZERS = ['DayZSynthesizer']
Expand All @@ -16,12 +20,6 @@
}


def _validate_modality(modality):
"""Validate that the modality is correct."""
if modality not in ['single_table', 'multi_table']:
raise ValueError("`modality` must be one of 'single_table' or 'multi_table'.")


def _get_sdv_synthesizers(modality):
_validate_modality(modality)
module = MODALITY_TO_MODULE[modality]
Expand All @@ -41,15 +39,15 @@ def _get_all_sdv_synthesizers():

def _get_trained_synthesizer(self, data, metadata):
LOGGER.info('Fitting %s', self.__class__.__name__)
sdv_class = getattr(import_module(f'sdv.{self.modality}'), self.SDV_NAME)
sdv_class = getattr(import_module(f'sdv.{self._MODALITY_FLAG}'), self.SDV_NAME)
synthesizer = sdv_class(metadata=metadata, **self._MODEL_KWARGS)
synthesizer.fit(data)
return synthesizer


def _sample_from_synthesizer(self, synthesizer, sample_arg):
LOGGER.info('Sampling %s', self.__class__.__name__)
if self.modality == 'multi_table':
if self._MODALITY_FLAG == 'multi_table':
return synthesizer.sample(scale=sample_arg)

return synthesizer.sample(num_rows=sample_arg)
Expand Down Expand Up @@ -82,13 +80,14 @@ def _create_sdv_class(sdv_name):
"""Create a SDV synthesizer class dynamically."""
current_module = sys.modules[__name__]
modality = _get_modality(sdv_name)
base_class = MultiTableBaselineSynthesizer if modality == 'multi_table' else BaselineSynthesizer
synthesizer_class = type(
sdv_name,
(BaselineSynthesizer,),
(base_class,),
{
'__module__': __name__,
'SDV_NAME': sdv_name,
'modality': modality,
'_MODALITY_FLAG': modality,
'_MODEL_KWARGS': {},
'_get_trained_synthesizer': _get_trained_synthesizer,
'_sample_from_synthesizer': _sample_from_synthesizer,
Expand Down
Loading