diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml index 1171d63b..b7f6d6b4 100644 --- a/.github/workflows/minimum.yml +++ b/.github/workflows/minimum.yml @@ -12,7 +12,7 @@ concurrency: jobs: minimum: runs-on: ${{ matrix.os }} - timeout-minutes: 30 + timeout-minutes: 45 strategy: matrix: python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py index 64bbbbe2..42d78fe8 100644 --- a/sdgym/benchmark.py +++ b/sdgym/benchmark.py @@ -5,7 +5,6 @@ import math import multiprocessing import os -import pickle import re import textwrap import threading @@ -345,9 +344,9 @@ def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, re tracemalloc.start() now = get_utc_now() synthesizer_obj = get_synthesizer(data, metadata) - synthesizer_size = len(pickle.dumps(synthesizer_obj)) / N_BYTES_IN_MB + synthesizer_size = len(cloudpickle.dumps(synthesizer_obj)) / N_BYTES_IN_MB train_now = get_utc_now() - synthetic_data = sample_from_synthesizer(synthesizer_obj, num_samples) + synthetic_data = sample_from_synthesizer(synthesizer_obj, n_samples=num_samples) sample_now = get_utc_now() peak_memory = tracemalloc.get_traced_memory()[1] / N_BYTES_IN_MB @@ -355,7 +354,8 @@ def _synthesize(synthesizer_dict, real_data, metadata, synthesizer_path=None, re tracemalloc.clear_traces() if synthesizer_path is not None and result_writer is not None: result_writer.write_dataframe(synthetic_data, synthesizer_path['synthetic_data']) - result_writer.write_pickle(synthesizer_obj, synthesizer_path['synthesizer']) + internal_synthesizer = getattr(synthesizer_obj, '_internal_synthesizer', synthesizer_obj) + result_writer.write_pickle(internal_synthesizer, synthesizer_path['synthesizer']) return synthetic_data, train_now - now, sample_now - train_now, synthesizer_size, peak_memory @@ -1373,7 +1373,7 @@ def _store_job_args_in_s3(output_destination, job_args_list, s3_client): job_args_key = f'job_args_list_{metainfo}.pkl' job_args_key = f'{path}{job_args_key}' if path else job_args_key - serialized_data = pickle.dumps(job_args_list) + serialized_data = cloudpickle.dumps(job_args_list) s3_client.put_object(Bucket=bucket_name, Key=job_args_key, Body=serialized_data) return bucket_name, job_args_key @@ -1384,7 +1384,7 @@ def _get_s3_script_content( ): return f""" import boto3 -import pickle +import cloudpickle from sdgym.benchmark import _run_jobs, _write_metainfo_file, _update_metainfo_file from io import StringIO from sdgym.result_writer import S3ResultsWriter @@ -1396,7 +1396,7 @@ def _get_s3_script_content( region_name='{region_name}' ) response = s3_client.get_object(Bucket='{bucket_name}', Key='{job_args_key}') -job_args_list = pickle.loads(response['Body'].read()) +job_args_list = cloudpickle.loads(response['Body'].read()) result_writer = S3ResultsWriter(s3_client=s3_client) _write_metainfo_file({synthesizers}, job_args_list, result_writer) scores = _run_jobs(None, job_args_list, False, result_writer=result_writer) diff --git a/sdgym/result_explorer/result_handler.py b/sdgym/result_explorer/result_handler.py index b0a3ea6f..84f22f3d 100644 --- a/sdgym/result_explorer/result_handler.py +++ b/sdgym/result_explorer/result_handler.py @@ -3,10 +3,10 @@ import io import operator import os -import pickle from abc import ABC, abstractmethod from datetime import datetime +import cloudpickle import pandas as pd import yaml from botocore.exceptions import ClientError @@ -250,7 +250,7 @@ def get_file_path(self, path_parts, end_filename): def 
load_synthesizer(self, file_path): """Load a synthesizer from a pickle file.""" with open(os.path.join(self.base_path, file_path), 'rb') as f: - return pickle.load(f) + return cloudpickle.load(f) def load_synthetic_data(self, file_path): """Load synthetic data from a CSV file.""" @@ -361,7 +361,7 @@ def load_synthesizer(self, file_path): response = self.s3_client.get_object( Bucket=self.bucket_name, Key=f'{self.prefix}{file_path}' ) - return pickle.loads(response['Body'].read()) + return cloudpickle.loads(response['Body'].read()) def load_synthetic_data(self, file_path): """Load synthetic data from S3.""" diff --git a/sdgym/result_writer.py b/sdgym/result_writer.py index 3101384a..718871d7 100644 --- a/sdgym/result_writer.py +++ b/sdgym/result_writer.py @@ -1,10 +1,10 @@ """Results writer for SDGym benchmark.""" import io -import pickle from abc import ABC, abstractmethod from pathlib import Path +import cloudpickle import pandas as pd import plotly.graph_objects as go import yaml @@ -82,7 +82,7 @@ def write_xlsx(self, data, file_path, index=False): def write_pickle(self, obj, file_path): """Write a Python object to a pickle file.""" with open(file_path, 'wb') as f: - pickle.dump(obj, f) + cloudpickle.dump(obj, f) def write_yaml(self, data, file_path, append=False): """Write data to a YAML file.""" @@ -126,7 +126,7 @@ def write_pickle(self, obj, file_path): """Write a Python object to S3 as a pickle file.""" bucket, key = parse_s3_path(file_path) buffer = io.BytesIO() - pickle.dump(obj, buffer) + cloudpickle.dump(obj, buffer) buffer.seek(0) self.s3_client.put_object(Body=buffer.read(), Bucket=bucket, Key=key) diff --git a/sdgym/synthesizers/__init__.py b/sdgym/synthesizers/__init__.py index c7f44b8b..67368dc7 100644 --- a/sdgym/synthesizers/__init__.py +++ b/sdgym/synthesizers/__init__.py @@ -8,7 +8,7 @@ from sdgym.synthesizers.identity import DataIdentity from sdgym.synthesizers.column import ColumnSynthesizer from sdgym.synthesizers.realtabformer import RealTabFormerSynthesizer -from sdgym.synthesizers.uniform import UniformSynthesizer +from sdgym.synthesizers.uniform import UniformSynthesizer, MultiTableUniformSynthesizer from sdgym.synthesizers.utils import ( get_available_single_table_synthesizers, get_available_multi_table_synthesizers, @@ -26,6 +26,7 @@ 'create_synthesizer_variant', 'get_available_single_table_synthesizers', 'get_available_multi_table_synthesizers', + 'MultiTableUniformSynthesizer', ] for sdv_name in _get_all_sdv_synthesizers(): diff --git a/sdgym/synthesizers/base.py b/sdgym/synthesizers/base.py index 01c1c19a..0aae184f 100644 --- a/sdgym/synthesizers/base.py +++ b/sdgym/synthesizers/base.py @@ -9,11 +9,23 @@ LOGGER = logging.getLogger(__name__) +def _is_valid_modality(modality): + return modality in ('single_table', 'multi_table') + + +def _validate_modality(modality): + if not _is_valid_modality(modality): + raise ValueError( + f"Modality '{modality}' is not valid. Must be either 'single_table' or 'multi_table'." 
+ ) + + class BaselineSynthesizer(abc.ABC): """Base class for all the ``SDGym`` baselines.""" _MODEL_KWARGS = {} _NATIVELY_SUPPORTED = True + _MODALITY_FLAG = None @classmethod def get_subclasses(cls, include_parents=False): @@ -34,15 +46,18 @@ def get_subclasses(cls, include_parents=False): return subclasses @classmethod - def _get_supported_synthesizers(cls): + def _get_supported_synthesizers(cls, modality): """Get the natively supported synthesizer class names.""" - subclasses = cls.get_subclasses(include_parents=True) - synthesizers = set() - for name, subclass in subclasses.items(): - if subclass._NATIVELY_SUPPORTED: - synthesizers.add(name) - - return sorted(synthesizers) + _validate_modality(modality) + return sorted({ + name + for name, subclass in cls.get_subclasses(include_parents=True).items() + if ( + name != 'MultiTableBaselineSynthesizer' + and subclass._NATIVELY_SUPPORTED + and subclass._MODALITY_FLAG == modality + ) + }) @classmethod def get_baselines(cls): @@ -55,6 +70,35 @@ def get_baselines(cls): return synthesizers + def _fit(self, data, metadata): + """Fit the synthesizer to the data. + + Args: + data (pandas.DataFrame): + The data to fit the synthesizer to. + metadata (sdv.metadata.Metadata): + The metadata describing the data. + """ + raise NotImplementedError() + + @classmethod + def _get_trained_synthesizer(cls, data, metadata): + """Train a synthesizer on the provided data and metadata. + + Args: + data (pd.DataFrame or dict): + The data to train on. + metadata (sdv.metadata.Metadata): + The metadata + + Returns: + A synthesizer object + """ + synthesizer = cls() + synthesizer._fit(data, metadata) + + return synthesizer + def get_trained_synthesizer(self, data, metadata): """Get a synthesizer that has been trained on the provided data and metadata. @@ -90,3 +134,25 @@ def sample_from_synthesizer(self, synthesizer, n_samples): should be a dict mapping table name to DataFrame. """ return self._sample_from_synthesizer(synthesizer, n_samples) + + +class MultiTableBaselineSynthesizer(BaselineSynthesizer): + """Base class for all multi-table synthesizers.""" + + _MODALITY_FLAG = 'multi_table' + + def sample_from_synthesizer(self, synthesizer, scale=1.0): + """Sample data from the provided synthesizer. + + Args: + synthesizer (obj): + The synthesizer object to sample data from. + scale (float): + The scale of data to sample. + Defaults to 1.0. + + Returns: + dict: + The sampled data. A dict mapping table name to DataFrame. + """ + return self._sample_from_synthesizer(synthesizer, scale=scale) diff --git a/sdgym/synthesizers/column.py b/sdgym/synthesizers/column.py index 69233283..94107f69 100644 --- a/sdgym/synthesizers/column.py +++ b/sdgym/synthesizers/column.py @@ -19,9 +19,11 @@ class ColumnSynthesizer(BaselineSynthesizer): Continuous columns are learned and sampled using a GMM. 
""" - def _get_trained_synthesizer(self, real_data, metadata): + _MODALITY_FLAG = 'single_table' + + def _fit(self, data, metadata): hyper_transformer = HyperTransformer() - hyper_transformer.detect_initial_config(real_data) + hyper_transformer.detect_initial_config(data) supported_sdtypes = hyper_transformer._get_supported_sdtypes() config = {} if isinstance(metadata, Metadata): @@ -46,14 +48,14 @@ def _get_trained_synthesizer(self, real_data, metadata): # This is done to match the behavior of the synthesizer for SDGym <= 0.6.0 columns_to_remove = [ - column_name for column_name, data in real_data.items() if data.dtype.kind in {'O', 'i'} + column_name for column_name, data in data.items() if data.dtype.kind in {'O', 'i'} ] hyper_transformer.remove_transformers(columns_to_remove) - hyper_transformer.fit(real_data) - transformed = hyper_transformer.transform(real_data) + hyper_transformer.fit(data) + transformed = hyper_transformer.transform(data) - self.length = len(real_data) + self.length = len(data) gm_models = {} for name, column in transformed.items(): kind = column.dtype.kind @@ -63,18 +65,22 @@ def _get_trained_synthesizer(self, real_data, metadata): model.fit(column.to_numpy().reshape(-1, 1)) gm_models[name] = model - return (hyper_transformer, transformed, gm_models) + self.hyper_transformer = hyper_transformer + self.transformed_data = transformed + self.gm_models = gm_models def _sample_from_synthesizer(self, synthesizer, n_samples): - hyper_transformer, transformed, gm_models = synthesizer + hyper_transformer = synthesizer.hyper_transformer + transformed = synthesizer.transformed_data + gm_models = synthesizer.gm_models sampled = pd.DataFrame() for name, column in transformed.items(): kind = column.dtype.kind if kind == 'O': - values = column.sample(self.length, replace=True).to_numpy() + values = column.sample(n_samples, replace=True).to_numpy() else: model = gm_models.get(name) - values = model.sample(self.length)[0].ravel().clip(column.min(), column.max()) + values = model.sample(n_samples)[0].ravel().clip(column.min(), column.max()) sampled[name] = values diff --git a/sdgym/synthesizers/generate.py b/sdgym/synthesizers/generate.py index 5535dd01..d341d03b 100644 --- a/sdgym/synthesizers/generate.py +++ b/sdgym/synthesizers/generate.py @@ -1,6 +1,10 @@ """Helpers to create SDGym synthesizer variants.""" -from sdgym.synthesizers.base import BaselineSynthesizer +from sdgym.synthesizers.base import ( + BaselineSynthesizer, + MultiTableBaselineSynthesizer, + _validate_modality, +) from sdgym.synthesizers.utils import _get_supported_synthesizers @@ -36,7 +40,7 @@ def create_synthesizer_variant(display_name, synthesizer_class, synthesizer_para return NewSynthesizer -def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, sample_arg_name): +def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, modality): """Create a synthesizer class. Args: @@ -47,36 +51,39 @@ def _create_synthesizer_class(display_name, get_trained_fn, sample_fn, sample_ar A function to generate and train a synthesizer, given the real data and metadata. sample_from_synthesizer (callable): A function to sample from the given synthesizer. - sample_arg_name (str): - The name of the argument used to specify the number of samples to generate. - Either 'num_samples' for single-table synthesizers, or 'scale' for multi-table - synthesizers. + modality (str): + The modality of the synthesizer. Either 'single_table' or 'multi_table'. Returns: class: The synthesizer class. 
""" + _validate_modality(modality) class_name = f'Custom:{display_name}' def get_trained_synthesizer(self, data, metadata): return get_trained_fn(data, metadata) - if sample_arg_name == 'num_samples': + if modality == 'multi_table': - def sample_from_synthesizer(self, synthesizer, num_samples): - return sample_fn(synthesizer, num_samples) + def sample_from_synthesizer(self, synthesizer, scale=1.0): + return sample_fn(synthesizer, scale) + base_class = MultiTableBaselineSynthesizer else: - def sample_from_synthesizer(self, synthesizer, scale): - return sample_fn(synthesizer, scale) + def sample_from_synthesizer(self, synthesizer, n_samples): + return sample_fn(synthesizer, n_samples) + + base_class = BaselineSynthesizer CustomSynthesizer = type( class_name, - (BaselineSynthesizer,), + (base_class,), { '__module__': __name__, '_NATIVELY_SUPPORTED': False, + '_MODALITY_FLAG': modality, 'get_trained_synthesizer': get_trained_synthesizer, 'sample_from_synthesizer': sample_from_synthesizer, }, @@ -94,7 +101,7 @@ def create_single_table_synthesizer( display_name, get_trained_synthesizer_fn, sample_from_synthesizer_fn, - sample_arg_name='num_samples', + modality='single_table', ) @@ -106,5 +113,5 @@ def create_multi_table_synthesizer( display_name, get_trained_synthesizer_fn, sample_from_synthesizer_fn, - sample_arg_name='scale', + modality='multi_table', ) diff --git a/sdgym/synthesizers/identity.py b/sdgym/synthesizers/identity.py index d63b5956..1c41031a 100644 --- a/sdgym/synthesizers/identity.py +++ b/sdgym/synthesizers/identity.py @@ -11,24 +11,21 @@ class DataIdentity(BaselineSynthesizer): Returns the same exact data that is used to fit it. """ + _MODALITY_FLAG = 'single_table' + def __init__(self): self._data = None - def get_trained_synthesizer(self, data, metadata): - """Get a synthesizer that has been trained on the provided data and metadata. + def _fit(self, data, metadata): + """Fit the synthesizer to the data. Args: data (pandas.DataFrame): - The data to train on. + The data to fit the synthesizer to. metadata (dict): The metadata dictionary. - - Returns: - obj: - The synthesizer object. """ self._data = data - return None def sample_from_synthesizer(self, synthesizer, n_samples): """Sample data from the provided synthesizer. @@ -44,4 +41,4 @@ def sample_from_synthesizer(self, synthesizer, n_samples): The sampled data. If single-table, should be a DataFrame. If multi-table, should be a dict mapping table name to DataFrame. 
""" - return copy.deepcopy(self._data) + return copy.deepcopy(synthesizer._data) diff --git a/sdgym/synthesizers/realtabformer.py b/sdgym/synthesizers/realtabformer.py index 8b46bda5..92d8dc68 100644 --- a/sdgym/synthesizers/realtabformer.py +++ b/sdgym/synthesizers/realtabformer.py @@ -24,8 +24,10 @@ class RealTabFormerSynthesizer(BaselineSynthesizer): LOGGER = logging.getLogger(__name__) _MODEL_KWARGS = None + _MODALITY_FLAG = 'single_table' - def _get_trained_synthesizer(self, data, metadata): + def _fit(self, data, metadata): + """Fit the REaLTabFormer model to the data.""" try: from realtabformer import REaLTabFormer except Exception as exception: @@ -39,8 +41,8 @@ def _get_trained_synthesizer(self, data, metadata): model = REaLTabFormer(model_type='tabular', **model_kwargs) model.fit(data) - return model + self._internal_synthesizer = model def _sample_from_synthesizer(self, synthesizer, n_sample): """Sample synthetic data with specified sample count.""" - return synthesizer.sample(n_sample) + return synthesizer._internal_synthesizer.sample(n_sample) diff --git a/sdgym/synthesizers/sdv.py b/sdgym/synthesizers/sdv.py index 90e4e3a6..9fd9418e 100644 --- a/sdgym/synthesizers/sdv.py +++ b/sdgym/synthesizers/sdv.py @@ -6,7 +6,11 @@ from sdv import multi_table, single_table -from sdgym.synthesizers.base import BaselineSynthesizer +from sdgym.synthesizers.base import ( + BaselineSynthesizer, + MultiTableBaselineSynthesizer, + _validate_modality, +) LOGGER = logging.getLogger(__name__) UNSUPPORTED_SDV_SYNTHESIZERS = ['DayZSynthesizer'] @@ -16,12 +20,6 @@ } -def _validate_modality(modality): - """Validate that the modality is correct.""" - if modality not in ['single_table', 'multi_table']: - raise ValueError("`modality` must be one of 'single_table' or 'multi_table'.") - - def _get_sdv_synthesizers(modality): _validate_modality(modality) module = MODALITY_TO_MODULE[modality] @@ -39,20 +37,20 @@ def _get_all_sdv_synthesizers(): return sorted(synthesizers) -def _get_trained_synthesizer(self, data, metadata): +def _fit(self, data, metadata): LOGGER.info('Fitting %s', self.__class__.__name__) - sdv_class = getattr(import_module(f'sdv.{self.modality}'), self.SDV_NAME) + sdv_class = getattr(import_module(f'sdv.{self._MODALITY_FLAG}'), self.SDV_NAME) synthesizer = sdv_class(metadata=metadata, **self._MODEL_KWARGS) synthesizer.fit(data) - return synthesizer + self._internal_synthesizer = synthesizer def _sample_from_synthesizer(self, synthesizer, sample_arg): LOGGER.info('Sampling %s', self.__class__.__name__) - if self.modality == 'multi_table': - return synthesizer.sample(scale=sample_arg) + if self._MODALITY_FLAG == 'multi_table': + return synthesizer._internal_synthesizer.sample(scale=sample_arg) - return synthesizer.sample(num_rows=sample_arg) + return synthesizer._internal_synthesizer.sample(num_rows=sample_arg) def _retrieve_sdv_class(sdv_name): @@ -82,15 +80,16 @@ def _create_sdv_class(sdv_name): """Create a SDV synthesizer class dynamically.""" current_module = sys.modules[__name__] modality = _get_modality(sdv_name) + base_class = MultiTableBaselineSynthesizer if modality == 'multi_table' else BaselineSynthesizer synthesizer_class = type( sdv_name, - (BaselineSynthesizer,), + (base_class,), { '__module__': __name__, 'SDV_NAME': sdv_name, - 'modality': modality, + '_MODALITY_FLAG': modality, '_MODEL_KWARGS': {}, - '_get_trained_synthesizer': _get_trained_synthesizer, + '_fit': _fit, '_sample_from_synthesizer': _sample_from_synthesizer, }, ) diff --git a/sdgym/synthesizers/uniform.py 
b/sdgym/synthesizers/uniform.py index 57713839..f562c1bf 100644 --- a/sdgym/synthesizers/uniform.py +++ b/sdgym/synthesizers/uniform.py @@ -7,7 +7,7 @@ import pandas as pd from rdt.hyper_transformer import HyperTransformer -from sdgym.synthesizers.base import BaselineSynthesizer +from sdgym.synthesizers.base import BaselineSynthesizer, MultiTableBaselineSynthesizer LOGGER = logging.getLogger(__name__) @@ -15,9 +15,24 @@ class UniformSynthesizer(BaselineSynthesizer): """Synthesizer that samples each column using a Uniform distribution.""" - def _get_trained_synthesizer(self, real_data, metadata): + _MODALITY_FLAG = 'single_table' + + def __init__(self): + super().__init__() + self.hyper_transformer = None + self.transformed_data = None + + def _fit(self, data, metadata): + """Fit the synthesizer to the data. + + Args: + data (pd.DataFrame): + The data to fit the synthesizer to. + metadata (sdv.metadata.Metadata): + The metadata describing the data. + """ hyper_transformer = HyperTransformer() - hyper_transformer.detect_initial_config(real_data) + hyper_transformer.detect_initial_config(data) supported_sdtypes = hyper_transformer._get_supported_sdtypes() config = {} table = next(iter(metadata.tables.values())) @@ -44,29 +59,82 @@ def _get_trained_synthesizer(self, real_data, metadata): # This is done to match the behavior of the synthesizer for SDGym <= 0.6.0 columns_to_remove = [ column_name - for column_name, data in real_data.items() - if data.dtype.kind in {'O', 'i', 'b'} + for column_name, column_data in data.items() + if column_data.dtype.kind in {'O', 'i', 'b'} ] hyper_transformer.remove_transformers(columns_to_remove) - hyper_transformer.fit(real_data) - transformed = hyper_transformer.transform(real_data) - - self.length = len(real_data) - return (hyper_transformer, transformed) + hyper_transformer.fit(data) + transformed = hyper_transformer.transform(data) + self.hyper_transformer = hyper_transformer + self.transformed_data = transformed def _sample_from_synthesizer(self, synthesizer, n_samples): - hyper_transformer, transformed = synthesizer + hyper_transformer = synthesizer.hyper_transformer + transformed = synthesizer.transformed_data sampled = pd.DataFrame() for name, column in transformed.items(): kind = column.dtype.kind if kind == 'i': - values = np.random.randint(column.min(), column.max() + 1, size=self.length) + values = np.random.randint( + int(column.min()), int(column.max()) + 1, size=n_samples, dtype=np.int64 + ) elif kind in ['O', 'b']: - values = np.random.choice(column.unique(), size=self.length) + values = np.random.choice(column.unique(), size=n_samples) else: - values = np.random.uniform(column.min(), column.max(), size=self.length) - + values = np.random.uniform(column.min(), column.max(), size=n_samples) sampled[name] = values return hyper_transformer.reverse_transform(sampled) + + +class MultiTableUniformSynthesizer(MultiTableBaselineSynthesizer): + """Multi-table Uniform Synthesizer. + + This synthesizer trains a UniformSynthesizer on each table in the multi-table dataset. + It samples data from each table independently using the corresponding trained synthesizer. + """ + + def __init__(self): + super().__init__() + self.num_rows_per_table = {} + self.table_synthesizers = {} + + def _fit(self, data, metadata): + """Fit the synthesizer to the multi-table data. + + Args: + data (dict): + A dict mapping table name to table data. + metadata (sdv.metadata.MultiTableMetadata): + The multi-table metadata describing the data. 
+ """ + for table_name, table_data in data.items(): + table_metadata = metadata.get_table_metadata(table_name) + synthesizer = UniformSynthesizer() + synthesizer._fit(table_data, table_metadata) + self.num_rows_per_table[table_name] = len(table_data) + self.table_synthesizers[table_name] = synthesizer + + def _sample_from_synthesizer(self, synthesizer, scale): + """Sample data from the provided synthesizer. + + Args: + synthesizer (SDGym synthesizer): + The synthesizer object to sample data from. + scale (float): + The scale of data to sample. + Defaults to 1.0. + + Returns: + dict: A dict mapping table name to the sampled data. + """ + sampled_data = {} + for table_name, table_synthesizer in synthesizer.table_synthesizers.items(): + n_samples = int(synthesizer.num_rows_per_table[table_name] * scale) + sampled_table = UniformSynthesizer().sample_from_synthesizer( + table_synthesizer, n_samples=n_samples + ) + sampled_data[table_name] = sampled_table + + return sampled_data diff --git a/sdgym/synthesizers/utils.py b/sdgym/synthesizers/utils.py index 18cdd9f0..c30f752a 100644 --- a/sdgym/synthesizers/utils.py +++ b/sdgym/synthesizers/utils.py @@ -1,22 +1,6 @@ """Utility functions for synthesizers in SDGym.""" from sdgym.synthesizers.base import BaselineSynthesizer -from sdgym.synthesizers.sdv import _get_all_sdv_synthesizers, _get_sdv_synthesizers - - -def _get_sdgym_synthesizers(): - """Get SDGym synthesizers. - - Returns: - list: - A list of available SDGym synthesizer names. - """ - synthesizers = BaselineSynthesizer._get_supported_synthesizers() - sdv_synthesizer = _get_all_sdv_synthesizers() - sdgym_synthesizer = [ - synthesizer for synthesizer in synthesizers if synthesizer not in sdv_synthesizer - ] - return sorted(sdgym_synthesizer) def get_available_single_table_synthesizers(): @@ -26,9 +10,7 @@ def get_available_single_table_synthesizers(): list: A sorted list of available single-table synthesizer names. """ - sdv_synthesizers = _get_sdv_synthesizers('single_table') - sdgym_synthesizers = _get_sdgym_synthesizers() - return sorted(sdv_synthesizers + sdgym_synthesizers) + return sorted(BaselineSynthesizer._get_supported_synthesizers('single_table')) def get_available_multi_table_synthesizers(): @@ -38,7 +20,7 @@ def get_available_multi_table_synthesizers(): list: A sorted list of available multi-table synthesizer names. """ - return sorted(_get_sdv_synthesizers('multi_table')) + return sorted(BaselineSynthesizer._get_supported_synthesizers('multi_table')) def _get_supported_synthesizers(): @@ -48,4 +30,8 @@ def _get_supported_synthesizers(): list: A list of available SDGym supported synthesizer names. 
""" - return BaselineSynthesizer._get_supported_synthesizers() + synthesizers = [] + for modality in ['single_table', 'multi_table']: + synthesizers.extend(BaselineSynthesizer._get_supported_synthesizers(modality)) + + return sorted(synthesizers) diff --git a/tests/integration/result_explorer/test_result_explorer.py b/tests/integration/result_explorer/test_result_explorer.py index 188053fd..6f148fc1 100644 --- a/tests/integration/result_explorer/test_result_explorer.py +++ b/tests/integration/result_explorer/test_result_explorer.py @@ -38,6 +38,7 @@ def test_end_to_end_local(tmp_path): dataset_name='fake_companies', synthesizer_name='TVAESynthesizer', ) + assert isinstance(synthesizer, TVAESynthesizer) new_synthetic_data = synthesizer.sample(num_rows=10) # Assert diff --git a/tests/integration/synthesizers/test_column.py b/tests/integration/synthesizers/test_column.py index e22c1196..45c29eba 100644 --- a/tests/integration/synthesizers/test_column.py +++ b/tests/integration/synthesizers/test_column.py @@ -29,7 +29,9 @@ def test_column_synthesizer(self): # Run trained_synthesizer = column_synthesizer.get_trained_synthesizer(data, {}) - samples = column_synthesizer.sample_from_synthesizer(trained_synthesizer, n_samples) + samples = column_synthesizer.sample_from_synthesizer( + trained_synthesizer, n_samples=n_samples + ) # Assert assert samples['num'].between(-10, 10).all() @@ -105,7 +107,7 @@ def test_column_synthesizer_sdtypes(self): # Run real_data = pd.DataFrame(data) synthesizer = ColumnSynthesizer().get_trained_synthesizer(real_data, metadata) - hyper_transformer_config = synthesizer[0].get_config() + hyper_transformer_config = synthesizer.hyper_transformer.get_config() # Assert config_sdtypes = hyper_transformer_config['sdtypes'] diff --git a/tests/integration/synthesizers/test_uniform.py b/tests/integration/synthesizers/test_uniform.py index e807d375..366409a2 100644 --- a/tests/integration/synthesizers/test_uniform.py +++ b/tests/integration/synthesizers/test_uniform.py @@ -2,8 +2,10 @@ import numpy as np import pandas as pd +from pandas.api.types import is_numeric_dtype +from sdv.datasets.demo import download_demo -from sdgym.synthesizers.uniform import UniformSynthesizer +from sdgym.synthesizers import MultiTableUniformSynthesizer, UniformSynthesizer def test_uniform_synthesizer(): @@ -30,7 +32,7 @@ def test_uniform_synthesizer(): # Run trained_synthesizer = uniform_synthesizer.get_trained_synthesizer(data, {}) - samples = uniform_synthesizer.sample_from_synthesizer(trained_synthesizer, n_samples) + samples = uniform_synthesizer.sample_from_synthesizer(trained_synthesizer, n_samples=n_samples) # Assert numerical values are uniform min_val = samples['num'].min() @@ -69,3 +71,25 @@ def test_uniform_synthesizer(): assert n_values_interval2 * 0.9 < n_values_interval1 < n_values_interval2 * 1.1 assert n_values_interval3 * 0.9 < n_values_interval1 < n_values_interval3 * 1.1 + + +def test_multitable_uniform_synthesizer_end_to_end(): + """Test the MultiTableUniformSynthesizer end to end.""" + # Setup + data, metadata = download_demo(dataset_name='fake_hotels', modality='multi_table') + synthesizer = MultiTableUniformSynthesizer() + + # Run + trained_synthesizer = synthesizer.get_trained_synthesizer(data, metadata.to_dict()) + sampled_data = synthesizer.sample_from_synthesizer(trained_synthesizer, scale=2) + + # Assert + for table_name, table_data in data.items(): + sampled_table = sampled_data[table_name] + assert len(sampled_table) == len(table_data) * 2 + for column_name in 
table_data.columns: + original_column = table_data[column_name] + sampled_column = sampled_table[column_name] + if is_numeric_dtype(original_column): + assert sampled_column.min() >= original_column.min() + assert sampled_column.max() <= original_column.max() diff --git a/tests/integration/synthesizers/test_utils.py b/tests/integration/synthesizers/test_utils.py index 60232de0..ad78c519 100644 --- a/tests/integration/synthesizers/test_utils.py +++ b/tests/integration/synthesizers/test_utils.py @@ -28,7 +28,7 @@ def test_get_available_single_table_synthesizers(): def test_get_available_multi_table_synthesizers(): """Test the `get_available_multi_table_synthesizers` method""" # Setup - expected_synthesizers = ['HMASynthesizer'] + expected_synthesizers = ['HMASynthesizer', 'MultiTableUniformSynthesizer'] # Run synthesizers = get_available_multi_table_synthesizers() diff --git a/tests/unit/synthesizers/test_base.py b/tests/unit/synthesizers/test_base.py index 73f1be2b..0922b08c 100644 --- a/tests/unit/synthesizers/test_base.py +++ b/tests/unit/synthesizers/test_base.py @@ -1,23 +1,61 @@ +import re import warnings -from unittest.mock import Mock, patch +from unittest.mock import Mock, call, patch import pandas as pd +import pytest from sdv.metadata import Metadata -from sdgym.synthesizers.base import BaselineSynthesizer +from sdgym.synthesizers.base import ( + BaselineSynthesizer, + MultiTableBaselineSynthesizer, + _is_valid_modality, + _validate_modality, +) + + +@pytest.mark.parametrize( + 'modality, result', + [ + ('single_table', True), + ('multi_table', True), + ('invalid_modality', False), + ], +) +def test__is_valid_modality(modality, result): + """Test the `_is_valid_modality` method.""" + assert _is_valid_modality(modality) == result + + +def test__validate_modality(): + """Test the `_validate_modality` method.""" + # Setup + valid_modality = 'single_table' + invalid_modality = 'invalid_modality' + expected_error = re.escape( + f"Modality '{invalid_modality}' is not valid. Must be either " + "'single_table' or 'multi_table'." 
+ ) + + # Run and Assert + _validate_modality(valid_modality) + with pytest.raises(ValueError, match=expected_error): + _validate_modality(invalid_modality) class TestBaselineSynthesizer: - @patch('sdgym.synthesizers.utils.BaselineSynthesizer.get_subclasses') - def test__get_supported_synthesizers_mock(self, mock_get_subclasses): + @patch('sdgym.synthesizers.base.BaselineSynthesizer.get_subclasses') + @patch('sdgym.synthesizers.base._validate_modality') + def test__get_supported_synthesizers_mock(self, mock_validate_modality, mock_get_subclasses): """Test the `_get_supported_synthesizers` method with mocks.""" # Setup mock_get_subclasses.return_value = { - 'Variant:ColumnSynthesizer': Mock(_NATIVELY_SUPPORTED=False), - 'Custom:MySynthesizer': Mock(_NATIVELY_SUPPORTED=False), - 'ColumnSynthesizer': Mock(_NATIVELY_SUPPORTED=True), - 'UniformSynthesizer': Mock(_NATIVELY_SUPPORTED=True), - 'DataIdentity': Mock(_NATIVELY_SUPPORTED=True), + 'Variant:Synthesizer': Mock(_NATIVELY_SUPPORTED=False, _MODALITY_FLAG='single_table'), + 'Custom:MySynthesizer': Mock(_NATIVELY_SUPPORTED=False, _MODALITY_FLAG='single_table'), + 'ColumnSynthesizer': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='single_table'), + 'UniformSynthesizer': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='single_table'), + 'MultiTableSynthesizer': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='multi_table'), + 'DataIdentity': Mock(_NATIVELY_SUPPORTED=True, _MODALITY_FLAG='single_table'), } expected_synthesizers = [ 'ColumnSynthesizer', @@ -26,9 +64,11 @@ def test__get_supported_synthesizers_mock(self, mock_get_subclasses): ] # Run - synthesizers = BaselineSynthesizer._get_supported_synthesizers() + synthesizers = BaselineSynthesizer._get_supported_synthesizers('single_table') # Assert + mock_validate_modality.assert_called_once_with('single_table') + mock_get_subclasses.assert_called_once_with(include_parents=True) assert synthesizers == expected_synthesizers def test_get_trained_synthesizer(self): @@ -58,3 +98,36 @@ def test_get_trained_synthesizer(self): assert args[1].to_dict() == metadata.to_dict() assert isinstance(args[1], Metadata) assert instance._get_trained_synthesizer.return_value == mock_synthesizer + + +class TestMultiTableBaselineSynthesizer: + def test_sample_from_synthesizer(self): + """Test it calls the correct methods and returns the correct values.""" + # Setup + synthesizer = MultiTableBaselineSynthesizer() + mock_synthesizer = Mock() + synthesizer._sample_from_synthesizer = Mock(return_value='sampled_data') + expected_error = re.escape( + "sample_from_synthesizer() got an unexpected keyword argument 'n_samples'" + ) + + # Run + sampled_data = synthesizer.sample_from_synthesizer(mock_synthesizer) + sampled_data_with_scale = synthesizer.sample_from_synthesizer( + mock_synthesizer, + scale=2.0, + ) + with pytest.raises(TypeError, match=expected_error): + synthesizer.sample_from_synthesizer( + mock_synthesizer, + n_samples=10, + ) + + # Assert + assert synthesizer._MODALITY_FLAG == 'multi_table' + synthesizer._sample_from_synthesizer.assert_has_calls([ + call(mock_synthesizer, scale=1.0), + call(mock_synthesizer, scale=2.0), + ]) + assert sampled_data == 'sampled_data' + assert sampled_data_with_scale == 'sampled_data' diff --git a/tests/unit/synthesizers/test_generate.py b/tests/unit/synthesizers/test_generate.py index 45e25a6d..72ad0c81 100644 --- a/tests/unit/synthesizers/test_generate.py +++ b/tests/unit/synthesizers/test_generate.py @@ -55,7 +55,7 @@ def test_create_sdv_variant_synthesizer(): # Assert assert 
out.__name__ == 'Variant:test_synth' - assert out.modality == 'single_table' + assert out._MODALITY_FLAG == 'single_table' assert out._MODEL_KWARGS == synthesizer_parameters assert out.SDV_NAME == synthesizer_class assert out._NATIVELY_SUPPORTED is False @@ -85,7 +85,7 @@ def test_create_sdv_variant_synthesizer_multi_table(): # Assert assert out.__name__ == 'Variant:test_synth' - assert out.modality == 'multi_table' + assert out._MODALITY_FLAG == 'multi_table' assert out._MODEL_KWARGS == synthesizer_parameters assert out.SDV_NAME == synthesizer_class assert out._NATIVELY_SUPPORTED is False @@ -103,13 +103,15 @@ def test__create_synthesizer_class(): 'test_synth', get_trained_synthesizer_fn, sample_fn, - sample_arg_name='num_samples', + modality='single_table', ) # Assert assert out.__name__ == 'Custom:test_synth' assert hasattr(out, 'get_trained_synthesizer') assert hasattr(out, 'sample_from_synthesizer') + assert out._NATIVELY_SUPPORTED is False + assert out._MODALITY_FLAG == 'single_table' @patch('sdgym.synthesizers.generate._create_synthesizer_class') @@ -128,7 +130,7 @@ def test_create_single_table_synthesizer_mock(mock_create_class): 'test_synth', get_trained_synthesizer_fn, sample_fn, - sample_arg_name='num_samples', + modality='single_table', ) assert out == 'synthesizer_class' @@ -149,6 +151,6 @@ def test_create_multi_table_synthesizer_mock(mock_create_class): 'test_synth', get_trained_synthesizer_fn, sample_fn, - sample_arg_name='scale', + modality='multi_table', ) assert out == 'synthesizer_class' diff --git a/tests/unit/synthesizers/test_realtabformer.py b/tests/unit/synthesizers/test_realtabformer.py index 89146c1a..a66a8cb8 100644 --- a/tests/unit/synthesizers/test_realtabformer.py +++ b/tests/unit/synthesizers/test_realtabformer.py @@ -43,13 +43,17 @@ def test__get_trained_synthesizer(self, mock_real_tab_former): # Assert mock_real_tab_former.assert_called_once_with(model_type='tabular') mock_model.fit.assert_called_once_with(data) - assert result == mock_model, 'Expected the trained model to be returned.' 
+ assert result._internal_synthesizer == mock_model + assert isinstance(result, RealTabFormerSynthesizer) def test__sample_from_synthesizer(self): """Test _sample_from_synthesizer generates data with the specified sample size.""" # Setup trained_model = MagicMock() - trained_model.sample.return_value = MagicMock(shape=(10, 5)) # Mock sample data shape + trained_model._internal_synthesizer = MagicMock() + trained_model._internal_synthesizer.sample.return_value = MagicMock( + shape=(10, 5) + ) # Mock sample data shape n_sample = 10 synthesizer = RealTabFormerSynthesizer() @@ -57,7 +61,7 @@ def test__sample_from_synthesizer(self): synthetic_data = synthesizer._sample_from_synthesizer(trained_model, n_sample) # Assert - trained_model.sample.assert_called_once_with(n_sample) + trained_model._internal_synthesizer.sample.assert_called_once_with(n_sample) assert synthetic_data.shape[0] == n_sample, ( f'Expected {n_sample} rows, but got {synthetic_data.shape[0]}' ) diff --git a/tests/unit/synthesizers/test_sdv.py b/tests/unit/synthesizers/test_sdv.py index c0153948..cd6cb65a 100644 --- a/tests/unit/synthesizers/test_sdv.py +++ b/tests/unit/synthesizers/test_sdv.py @@ -9,37 +9,16 @@ from sdgym.synthesizers.base import BaselineSynthesizer from sdgym.synthesizers.sdv import ( _create_sdv_class, + _fit, _get_all_sdv_synthesizers, _get_modality, _get_sdv_synthesizers, - _get_trained_synthesizer, _retrieve_sdv_class, _sample_from_synthesizer, - _validate_modality, create_sdv_synthesizer_class, ) -def test__validate_modality(): - """Test the `_validate_modality` method.""" - # Setup - valid_modalities = ['single_table', 'multi_table'] - - # Run and Assert - for modality in valid_modalities: - _validate_modality(modality) - - -def test__validate_modality_invalid(): - """Test the `_validate_modality` method with invalid modality.""" - # Setup - expected_error = re.escape("`modality` must be one of 'single_table' or 'multi_table'.") - - # Run and Assert - with pytest.raises(ValueError, match=expected_error): - _validate_modality('invalid_modality') - - def test__get_sdv_synthesizers(): """Test the `_get_sdv_synthesizers` method.""" # Setup @@ -79,8 +58,8 @@ def test__get_all_sdv_synthesizers(): @patch('sdgym.synthesizers.sdv.LOGGER') -def test__get_trained_synthesizer(mock_logger): - """Test the `_get_trained_synthesizer` method.""" +def test__fit(mock_logger): + """Test the `_fit` method.""" # Setup data = pd.DataFrame({ 'column1': [1, 2, 3, 4, 5], @@ -99,16 +78,17 @@ def test__get_trained_synthesizer(mock_logger): synthesizer = Mock() synthesizer.__class__.__name__ = 'GaussianCopulaClass' synthesizer._MODEL_KWARGS = {'enforce_min_max_values': False} - synthesizer.modality = 'single_table' + synthesizer._MODALITY_FLAG = 'single_table' synthesizer.SDV_NAME = 'GaussianCopulaSynthesizer' # Run - valid_model = _get_trained_synthesizer(synthesizer, data, metadata) + _fit(synthesizer, data, metadata) # Assert mock_logger.info.assert_called_with('Fitting %s', 'GaussianCopulaClass') - assert isinstance(valid_model, GaussianCopulaSynthesizer) - assert valid_model.enforce_min_max_values is False + assert isinstance(synthesizer._internal_synthesizer, GaussianCopulaSynthesizer) + assert synthesizer._internal_synthesizer.enforce_min_max_values is False + assert synthesizer._internal_synthesizer._fitted is True @patch('sdgym.synthesizers.sdv.LOGGER') @@ -121,9 +101,10 @@ def test__sample_from_synthesizer(mock_logger): }) base_synthesizer = Mock() base_synthesizer.__class__.__name__ = 'GaussianCopulaSynthesizer' - 
base_synthesizer.modality = 'single_table' + base_synthesizer._MODALITY_FLAG = 'single_table' synthesizer = Mock() - synthesizer.sample.return_value = data + synthesizer._internal_synthesizer = Mock() + synthesizer._internal_synthesizer.sample.return_value = data n_samples = 3 # Run @@ -132,7 +113,7 @@ def test__sample_from_synthesizer(mock_logger): # Assert mock_logger.info.assert_called_with('Sampling %s', 'GaussianCopulaSynthesizer') pd.testing.assert_frame_equal(sampled_data, data) - synthesizer.sample.assert_called_once_with(num_rows=n_samples) + synthesizer._internal_synthesizer.sample.assert_called_once_with(num_rows=n_samples) @patch('sdgym.synthesizers.sdv.sys.modules') @@ -187,15 +168,15 @@ def test__create_sdv_class_mock(mock_get_modality, mock_sys_modules): # Assert assert synt_class.__name__ == sdv_name - assert synt_class.modality == 'single_table' + assert synt_class._MODALITY_FLAG == 'single_table' assert synt_class._MODEL_KWARGS == {} assert synt_class.SDV_NAME == sdv_name assert issubclass(synt_class, BaselineSynthesizer) - assert getattr(synt_class, '_get_trained_synthesizer') is _get_trained_synthesizer + assert getattr(synt_class, '_fit') is _fit assert getattr(synt_class, '_sample_from_synthesizer') is _sample_from_synthesizer assert getattr(fake_module, sdv_name) is synt_class - assert instance._get_trained_synthesizer.__self__ is instance - assert instance._get_trained_synthesizer.__func__ is _get_trained_synthesizer + assert instance._fit.__self__ is instance + assert instance._fit.__func__ is _fit assert instance._sample_from_synthesizer.__self__ is instance assert instance._sample_from_synthesizer.__func__ is _sample_from_synthesizer assert instance.SDV_NAME == sdv_name @@ -212,7 +193,7 @@ def test__create_sdv_class(): # Assert assert synthesizer_class.__name__ == sdv_name - assert synthesizer_class.modality == 'single_table' + assert synthesizer_class._MODALITY_FLAG == 'single_table' assert synthesizer_class._MODEL_KWARGS == {} assert issubclass(synthesizer_class, BaselineSynthesizer) diff --git a/tests/unit/synthesizers/test_uniform.py b/tests/unit/synthesizers/test_uniform.py index 8269938c..d7258cd3 100644 --- a/tests/unit/synthesizers/test_uniform.py +++ b/tests/unit/synthesizers/test_uniform.py @@ -1,12 +1,17 @@ +from unittest.mock import Mock, call, patch + import numpy as np import pandas as pd +from rdt import HyperTransformer +from sdv.metadata import Metadata -from sdgym.synthesizers.uniform import UniformSynthesizer +from sdgym.synthesizers.uniform import MultiTableUniformSynthesizer, UniformSynthesizer class TestUniformSynthesizer: def test_uniform_synthesizer_sdtypes(self): """Ensure that sdtypes uniform are taken from metadata instead of inferred.""" + # Setup uniform_synthesizer = UniformSynthesizer() metadata = { 'primary_key': 'guest_email', @@ -67,8 +72,11 @@ def test_uniform_synthesizer_sdtypes(self): } real_data = pd.DataFrame(data) + + # Run synthesizer = uniform_synthesizer.get_trained_synthesizer(real_data, metadata) - hyper_transformer_config = synthesizer[0].get_config() + + hyper_transformer_config = synthesizer.hyper_transformer.get_config() config_sdtypes = hyper_transformer_config['sdtypes'] unknown_sdtypes = ['email', 'credit_card_number', 'address'] for column in metadata['columns']: @@ -78,3 +86,221 @@ def test_uniform_synthesizer_sdtypes(self): assert metadata_sdtype == config_sdtypes[column] else: assert config_sdtypes[column] == 'pii' + + +class TestMultiTableUniformSynthesizer: + 
@patch('sdgym.synthesizers.uniform.MultiTableBaselineSynthesizer.__init__') + def test__init__(self, mock_baseline_init): + """Test the `__init__` method.""" + # Run + synthesizer = MultiTableUniformSynthesizer() + + # Assert + mock_baseline_init.assert_called_once() + assert synthesizer.num_rows_per_table == {} + + @patch('sdgym.synthesizers.uniform.UniformSynthesizer._fit') + def test__fit_mock(self, mock_uniform_fit): + """Test the `fit` method with mocking.""" + # Setup + synthesizer = MultiTableUniformSynthesizer() + data = { + 'table1': pd.DataFrame({ + 'col1': [1, 2, 3], + 'col2': ['A', 'B', 'C'], + }), + 'table2': pd.DataFrame({ + 'col3': [10.0, 20.0, 30.0], + 'col4': [True, False, True], + }), + } + metadata = Mock() + st_metadatas = [ + { + 'primary_key': 'col1', + 'columns': { + 'col1': {'sdtype': 'numerical'}, + 'col2': {'sdtype': 'categorical'}, + }, + }, + { + 'primary_key': 'col3', + 'columns': { + 'col3': {'sdtype': 'numerical'}, + 'col4': {'sdtype': 'boolean'}, + }, + }, + ] + metadata.get_table_metadata.side_effect = st_metadatas + + # Run + synthesizer._fit(data, metadata) + + # Assert + metadata.get_table_metadata.assert_has_calls([ + call('table1'), + call('table2'), + ]) + mock_uniform_fit.assert_has_calls([ + call(data['table1'], st_metadatas[0]), + call(data['table2'], st_metadatas[1]), + ]) + assert synthesizer.num_rows_per_table == { + 'table1': 3, + 'table2': 3, + } + for table_name, table_synthesizer in synthesizer.table_synthesizers.items(): + assert table_name in ('table1', 'table2') + assert isinstance(table_synthesizer, UniformSynthesizer) + + def test__get_trained_synthesizer(self): + """Test the `_get_trained_synthesizer` method.""" + # Setup + synthesizer = MultiTableUniformSynthesizer() + data = { + 'table1': pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'col2': ['A', 'B', 'C', 'D', 'E'], + }), + 'table2': pd.DataFrame({ + 'col3': [10.0, 20.0, 30.0], + 'col4': [True, False, True], + }), + } + metadata = Metadata.load_from_dict({ + 'tables': { + 'table1': { + 'columns': { + 'col1': {'sdtype': 'numerical'}, + 'col2': {'sdtype': 'categorical'}, + }, + 'primary_key': 'col1', + }, + 'table2': { + 'columns': { + 'col3': {'sdtype': 'numerical'}, + 'col4': {'sdtype': 'boolean'}, + }, + 'primary_key': 'col3', + }, + }, + 'relationships': [], + }) + + # Run + trained_synthesizer = synthesizer._get_trained_synthesizer(data, metadata) + + # Assert + assert trained_synthesizer.num_rows_per_table == { + 'table1': 5, + 'table2': 3, + } + assert set(trained_synthesizer.table_synthesizers.keys()) == {'table1', 'table2'} + for table_name, table_synthesizer in trained_synthesizer.table_synthesizers.items(): + hyper_transformer = table_synthesizer.hyper_transformer + transformed = table_synthesizer.transformed_data + assert isinstance(hyper_transformer, HyperTransformer) + assert isinstance(transformed, pd.DataFrame) + assert set(transformed.columns) == set(data[table_name].columns) + + @patch('sdgym.synthesizers.uniform.UniformSynthesizer.sample_from_synthesizer') + def test_sample_from_synthesizer_mock(self, mock_sample_from_synthesizer): + """Test the `sample_from_synthesizer` method with mocking.""" + # Setup + synthesizer = MultiTableUniformSynthesizer() + trained_synthesizer = MultiTableUniformSynthesizer() + trained_synthesizer.num_rows_per_table = { + 'table1': 3, + 'table2': 2, + } + synthesizer_table1 = Mock() + synthesizer_table2 = Mock() + trained_synthesizer.table_synthesizers = { + 'table1': synthesizer_table1, + 'table2': synthesizer_table2, + } + 
mock_sample_from_synthesizer.side_effect = [ + 'sampled_data_table1', + 'sampled_data_table2', + ] + scale = 2 + + # Run + sampled_data = synthesizer.sample_from_synthesizer(trained_synthesizer, scale=scale) + + # Assert + assert sampled_data == { + 'table1': 'sampled_data_table1', + 'table2': 'sampled_data_table2', + } + mock_sample_from_synthesizer.assert_has_calls([ + call(synthesizer_table1, n_samples=6), + call(synthesizer_table2, n_samples=4), + ]) + + def test_sample_from_synthesizer(self): + """Test the `sample_from_synthesizer` method.""" + # Setup + np.random.seed(0) + synthesizer = MultiTableUniformSynthesizer() + data = { + 'table1': pd.DataFrame({ + 'col1': [1, 2, 3, 4, 5], + 'col2': ['A', 'B', 'C', 'D', 'E'], + }), + 'table2': pd.DataFrame({ + 'col3': [10, 20, 30], + 'col4': [True, False, True], + }), + } + table_1 = UniformSynthesizer() + table_1._fit( + data['table1'], + Metadata.load_from_dict({ + 'columns': { + 'col1': {'sdtype': 'numerical'}, + 'col2': {'sdtype': 'categorical'}, + }, + }), + ) + table_2 = UniformSynthesizer() + table_2._fit( + data['table2'], + Metadata.load_from_dict({ + 'columns': { + 'col3': {'sdtype': 'numerical'}, + 'col4': {'sdtype': 'boolean'}, + }, + }), + ) + trained_synthesizer = MultiTableUniformSynthesizer() + + trained_synthesizer.table_synthesizers = { + 'table1': table_1, + 'table2': table_2, + } + trained_synthesizer.num_rows_per_table = { + 'table1': 5, + 'table2': 3, + } + scale = 2 + expected_data = { + 'table1': pd.DataFrame({ + 'col1': [5, 1, 4, 4, 4, 2, 4, 3, 5, 1], + 'col2': ['A', 'E', 'C', 'B', 'A', 'B', 'B', 'A', 'B', 'E'], + }), + 'table2': pd.DataFrame({ + 'col3': [29, 26, 29, 15, 25, 25], + 'col4': [True, True, False, True, False, False], + }), + } + + # Run + sampled_data = synthesizer.sample_from_synthesizer(trained_synthesizer, scale=scale) + + # Assert + for table_name, table_data in sampled_data.items(): + pd.testing.assert_frame_equal( + table_data, + expected_data[table_name], + ) diff --git a/tests/unit/synthesizers/test_utils.py b/tests/unit/synthesizers/test_utils.py index 77d19dd1..0881994a 100644 --- a/tests/unit/synthesizers/test_utils.py +++ b/tests/unit/synthesizers/test_utils.py @@ -1,35 +1,4 @@ -from unittest.mock import patch - -from sdgym.synthesizers.utils import _get_sdgym_synthesizers, _get_supported_synthesizers - - -@patch('sdgym.synthesizers.utils.BaselineSynthesizer._get_supported_synthesizers') -def test__get_sdgym_synthesizers(mock_get_supported_synthesizers): - """Test the `_get_sdgym_synthesizers` method.""" - # Setup - mock_get_supported_synthesizers.return_value = [ - 'ColumnSynthesizer', - 'UniformSynthesizer', - 'DataIdentity', - 'RealTabFormerSynthesizer', - 'CTGANSynthesizer', - 'CopulaGANSynthesizer', - 'GaussianCopulaSynthesizer', - 'HMASynthesizer', - 'TVAESynthesizer', - ] - expected_synthesizers = [ - 'ColumnSynthesizer', - 'DataIdentity', - 'RealTabFormerSynthesizer', - 'UniformSynthesizer', - ] - - # Run - synthesizers = _get_sdgym_synthesizers() - - # Assert - assert synthesizers == expected_synthesizers +from sdgym.synthesizers.utils import _get_supported_synthesizers def test__get_supported_synthesizers(): @@ -42,6 +11,7 @@ def test__get_supported_synthesizers(): 'DataIdentity', 'GaussianCopulaSynthesizer', 'HMASynthesizer', + 'MultiTableUniformSynthesizer', 'RealTabFormerSynthesizer', 'TVAESynthesizer', 'UniformSynthesizer', diff --git a/tests/unit/test_result_writer.py b/tests/unit/test_result_writer.py index 5f338d44..1d4b7be8 100644 --- a/tests/unit/test_result_writer.py +++ 
b/tests/unit/test_result_writer.py @@ -1,6 +1,6 @@ -import pickle from unittest.mock import Mock, patch +import cloudpickle import pandas as pd import yaml @@ -57,7 +57,7 @@ def test_write_pickle(self, tmp_path): # Assert with open(file_path, 'rb') as f: - loaded_obj = pickle.load(f) + loaded_obj = cloudpickle.load(f) assert loaded_obj == obj @@ -172,7 +172,7 @@ def test_write_pickle(self, mockparse_s3_path): # Assert mockparse_s3_path.assert_called_once_with('test_object.pkl') mock_s3_client.put_object.assert_called_once_with( - Body=pickle.dumps(obj), + Body=cloudpickle.dumps(obj), Bucket='bucket_name', Key='key_prefix/test_object.pkl', )
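
Note on the pickle -> cloudpickle swap made throughout this patch: several SDGym synthesizer classes are built dynamically with type() (see _create_sdv_class and _create_synthesizer_class above). The standard library pickles classes by reference (module path + name), which fails for classes that cannot be re-imported at load time, while cloudpickle serializes such classes by value. The following is a minimal illustrative sketch of that difference, not code from the patch; the class name and module string are hypothetical.

import pickle

import cloudpickle

# Hypothetical stand-in for a synthesizer class created at runtime with type(),
# similar in spirit to _create_sdv_class / _create_synthesizer_class.
DynamicSynthesizer = type(
    'DynamicSynthesizer',
    (object,),
    {
        '__module__': 'not.a.real.module',  # not importable, so pickle-by-reference fails
        'sample': lambda self, n: list(range(n)),
    },
)

obj = DynamicSynthesizer()

try:
    pickle.dumps(obj)  # stdlib pickle stores the class by reference and cannot resolve it
except Exception as error:
    print(f'pickle failed: {error}')

# cloudpickle serializes the class definition itself, so the round trip works.
restored = cloudpickle.loads(cloudpickle.dumps(obj))
print(restored.sample(3))  # [0, 1, 2]

The same reasoning applies to the result writer and S3 job-args paths above: once synthesizer objects may be instances of dynamically generated classes, every dump/load site has to use cloudpickle consistently, since data written with cloudpickle is still readable by callers that only need plain pickle semantics.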