Skip to content
2 changes: 1 addition & 1 deletion .github/workflows/minimum.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ concurrency:
jobs:
minimum:
runs-on: ${{ matrix.os }}
timeout-minutes: 30
timeout-minutes: 45
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
Expand Down
14 changes: 7 additions & 7 deletions sdgym/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import math
import multiprocessing
import os
import pickle
import re
import textwrap
import threading
Expand Down Expand Up @@ -345,17 +344,18 @@ def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, re
tracemalloc.start()
now = get_utc_now()
synthesizer_obj = get_synthesizer(data, metadata)
synthesizer_size = len(pickle.dumps(synthesizer_obj)) / N_BYTES_IN_MB
synthesizer_size = len(cloudpickle.dumps(synthesizer_obj)) / N_BYTES_IN_MB
train_now = get_utc_now()
synthetic_data = sample_from_synthesizer(synthesizer_obj, num_samples)
synthetic_data = sample_from_synthesizer(synthesizer_obj, n_samples=num_samples)
sample_now = get_utc_now()

peak_memory = tracemalloc.get_traced_memory()[1] / N_BYTES_IN_MB
tracemalloc.stop()
tracemalloc.clear_traces()
if synthesizer_path is not None and result_writer is not None:
result_writer.write_dataframe(synthetic_data, synthesizer_path['synthetic_data'])
result_writer.write_pickle(synthesizer_obj, synthesizer_path['synthesizer'])
internal_synthesizer = getattr(synthesizer_obj, '_internal_synthesizer', synthesizer_obj)
result_writer.write_pickle(internal_synthesizer, synthesizer_path['synthesizer'])

return synthetic_data, train_now - now, sample_now - train_now, synthesizer_size, peak_memory

Expand Down Expand Up @@ -1373,7 +1373,7 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client):
job_args_key = f'job_args_list_{metainfo}.pkl'
job_args_key = f'{path}{job_args_key}' if path else job_args_key

serialized_data = pickle.dumps(job_args_list)
serialized_data = cloudpickle.dumps(job_args_list)
s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=serialized_data)

return bucket_name, job_args_key
Expand All @@ -1384,7 +1384,7 @@ def _get_s3_script_content(
):
return f"""
import boto3
import pickle
import cloudpickle
from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file
from io import StringIO
from sdgym.result_writer import S3ResultsWriter
Expand All @@ -1396,7 +1396,7 @@ def _get_s3_script_content(
region_name='{region_name}'
)
response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}')
job_args_list = pickle.loads(response['Body'].read())
job_args_list = cloudpickle.loads(response['Body'].read())
result_writer = S3ResultsWriter(s3_client=s3_client)
_write_metainfo_file({synthesizers}, job_args_list, result_writer)
scores = _run_jobs(None, job_args_list, False, result_writer=result_writer)
Expand Down
6 changes: 3 additions & 3 deletions sdgym/result_explorer/result_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import io
import operator
import os
import pickle
from abc import ABC, abstractmethod
from datetime import datetime

import cloudpickle
import pandas as pd
import yaml
from botocore.exceptions import ClientError
Expand Down Expand Up @@ -250,7 +250,7 @@ def get_file_path(self, path_parts, end_filename):
def load_synthesizer(self, file_path):
"""Load a synthesizer from a pickle file."""
with open(os.path.join(self.base_path, file_path), 'rb') as f:
return pickle.load(f)
return cloudpickle.load(f)

def load_synthetic_data(self, file_path):
"""Load synthetic data from a CSV file."""
Expand Down Expand Up @@ -361,7 +361,7 @@ def load_synthesizer(self, file_path):
response = self.s3_client.get_object(
Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}'
)
return pickle.loads(response['Body'].read())
return cloudpickle.loads(response['Body'].read())

def load_synthetic_data(self, file_path):
"""Load synthetic data from S3."""
Expand Down
6 changes: 3 additions & 3 deletions sdgym/result_writer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Results writer for SDGym benchmark."""

import io
import pickle
from abc import ABC, abstractmethod
from pathlib import Path

import cloudpickle
import pandas as pd
import plotly.graph_objects as go
import yaml
Expand Down Expand Up @@ -82,7 +82,7 @@ def write_xlsx(self, data, file_path, index=False):
def write_pickle(self, obj, file_path):
"""Write a Python object to a pickle file."""
with open(file_path, 'wb') as f:
pickle.dump(obj, f)
cloudpickle.dump(obj, f)

def write_yaml(self, data, file_path, append=False):
"""Write data to a YAML file."""
Expand Down Expand Up @@ -126,7 +126,7 @@ def write_pickle(self, obj, file_path):
"""Write a Python object to S3 as a pickle file."""
bucket, key = parse_s3_path(file_path)
buffer = io.BytesIO()
pickle.dump(obj, buffer)
cloudpickle.dump(obj, buffer)
buffer.seek(0)
self.s3_client.put_object(Body=buffer.read(), Bucket=bucket, Key=key)

Expand Down
3 changes: 2 additions & 1 deletion sdgym/synthesizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sdgym.synthesizers.identity import DataIdentity
from sdgym.synthesizers.column import ColumnSynthesizer
from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer
from sdgym.synthesizers.uniform import UniformSynthesizer
from sdgym.synthesizers.uniform import UniformSynthesizer, MultiTableUniformSynthesizer
from sdgym.synthesizers.utils import (
get_available_single_table_synthesizers,
get_available_multi_table_synthesizers,
Expand All @@ -26,6 +26,7 @@
'create_synthesizer_variant',
'get_available_single_table_synthesizers',
'get_available_multi_table_synthesizers',
'MultiTableUniformSynthesizer',
]

for sdv_name in _get_all_sdv_synthesizers():
Expand Down
82 changes: 74 additions & 8 deletions sdgym/synthesizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,23 @@
LOGGER = logging.getLogger(__name__)


def _is_valid_modality(modality):
return modality in ('single_table', 'multi_table')


def _validate_modality(modality):
if not _is_valid_modality(modality):
raise ValueError(
f"Modality '{modality}' is not valid. Must be either 'single_table' or 'multi_table'."
)


class BaselineSynthesizer(abc.ABC):
"""Base class for all the ``SDGym`` baselines."""

_MODEL_KWARGS = {}
_NATIVELY_SUPPORTED = True
_MODALITY_FLAG = None

@classmethod
def get_subclasses(cls, include_parents=False):
Expand All @@ -34,15 +46,18 @@ def get_subclasses(cls, include_parents=False):
return subclasses

@classmethod
def _get_supported_synthesizers(cls):
def _get_supported_synthesizers(cls, modality):
"""Get the natively supported synthesizer class names."""
subclasses = cls.get_subclasses(include_parents=True)
synthesizers = set()
for name, subclass in subclasses.items():
if subclass._NATIVELY_SUPPORTED:
synthesizers.add(name)

return sorted(synthesizers)
_validate_modality(modality)
return sorted({
name
for name, subclass in cls.get_subclasses(include_parents=True).items()
if (
name != 'MultiTableBaselineSynthesizer'
and subclass._NATIVELY_SUPPORTED
and subclass._MODALITY_FLAG == modality
)
})

@classmethod
def get_baselines(cls):
Expand All @@ -55,6 +70,35 @@ def get_baselines(cls):

return synthesizers

def _fit(self, data, metadata):
"""Fit the synthesizer to the data.

Args:
data (pandas.DataFrame):
The data to fit the synthesizer to.
metadata (sdv.metadata.Metadata):
The metadata describing the data.
"""
raise NotImplementedError()

@classmethod
def _get_trained_synthesizer(cls, data, metadata):
"""Train a synthesizer on the provided data and metadata.

Args:
data (pd.DataFrame or dict):
The data to train on.
metadata (sdv.metadata.Metadata):
The metadata

Returns:
A synthesizer object
"""
synthesizer = cls()
synthesizer._fit(data, metadata)

return synthesizer

def get_trained_synthesizer(self, data, metadata):
"""Get a synthesizer that has been trained on the provided data and metadata.

Expand Down Expand Up @@ -90,3 +134,25 @@ def sample_from_synthesizer(self, synthesizer, n_samples):
should be a dict mapping table name to DataFrame.
"""
return self._sample_from_synthesizer(synthesizer, n_samples)


class MultiTableBaselineSynthesizer(BaselineSynthesizer):
"""Base class for all multi-table synthesizers."""

_MODALITY_FLAG = 'multi_table'

def sample_from_synthesizer(self, synthesizer, scale=1.0):
"""Sample data from the provided synthesizer.

Args:
synthesizer (obj):
The synthesizer object to sample data from.
scale (float):
The scale of data to sample.
Defaults to 1.0.

Returns:
dict:
The sampled data. A dict mapping table name to DataFrame.
"""
return self._sample_from_synthesizer(synthesizer, scale=scale)
26 changes: 16 additions & 10 deletions sdgym/synthesizers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ class ColumnSynthesizer(BaselineSynthesizer):
Continuous columns are learned and sampled using a GMM.
"""

def _get_trained_synthesizer(self, real_data, metadata):
_MODALITY_FLAG = 'single_table'

def _fit(self, data, metadata):
hyper_transformer = HyperTransformer()
hyper_transformer.detect_initial_config(real_data)
hyper_transformer.detect_initial_config(data)
supported_sdtypes = hyper_transformer._get_supported_sdtypes()
config = {}
if isinstance(metadata, Metadata):
Expand All @@ -46,14 +48,14 @@ def _get_trained_synthesizer(self, real_data, metadata):

# This is done to match the behavior of the synthesizer for SDGym <= 0.6.0
columns_to_remove = [
column_name for column_name, data in real_data.items() if data.dtype.kind in {'O', 'i'}
column_name for column_name, data in data.items() if data.dtype.kind in {'O', 'i'}
]
hyper_transformer.remove_transformers(columns_to_remove)

hyper_transformer.fit(real_data)
transformed = hyper_transformer.transform(real_data)
hyper_transformer.fit(data)
transformed = hyper_transformer.transform(data)

self.length = len(real_data)
self.length = len(data)
gm_models = {}
for name, column in transformed.items():
kind = column.dtype.kind
Expand All @@ -63,18 +65,22 @@ def _get_trained_synthesizer(self, real_data, metadata):
model.fit(column.to_numpy().reshape(-1, 1))
gm_models[name] = model

return (hyper_transformer, transformed, gm_models)
self.hyper_transformer = hyper_transformer
self.transformed_data = transformed
self.gm_models = gm_models

def _sample_from_synthesizer(self, synthesizer, n_samples):
hyper_transformer, transformed, gm_models = synthesizer
hyper_transformer = synthesizer.hyper_transformer
transformed = synthesizer.transformed_data
gm_models = synthesizer.gm_models
sampled = pd.DataFrame()
for name, column in transformed.items():
kind = column.dtype.kind
if kind == 'O':
values = column.sample(self.length, replace=True).to_numpy()
values = column.sample(n_samples, replace=True).to_numpy()
else:
model = gm_models.get(name)
values = model.sample(self.length)[0].ravel().clip(column.min(), column.max())
values = model.sample(n_samples)[0].ravel().clip(column.min(), column.max())

sampled[name] = values

Expand Down
Loading