# Module-level logger, named after this module via ``__name__``.
99LOGGER = logging .getLogger (__name__ )
1010
1111
12+ def _is_valid_modality (modality ):
13+ return modality in ('single_table' , 'multi_table' )
14+
15+
def _validate_modality(modality):
    """Raise an error when ``modality`` is not a recognized modality.

    Args:
        modality:
            The value to check. Accepted values are ``'single_table'`` and
            ``'multi_table'``.

    Raises:
        ValueError:
            If ``_is_valid_modality`` rejects ``modality``.
    """
    if not _is_valid_modality(modality):
        # The closing quote must sit right after the replacement field so the
        # rejected value is echoed back exactly, with no trailing space.
        raise ValueError(
            f"Modality '{modality}' is not valid. Must be either 'single_table' or 'multi_table'."
        )
21+
22+
1223class BaselineSynthesizer (abc .ABC ):
1324 """Base class for all the ``SDGym`` baselines."""
1425
1526 _MODEL_KWARGS = {}
1627 _NATIVELY_SUPPORTED = True
28+ _MODALITY_FLAG = None
1729
1830 @classmethod
1931 def get_subclasses (cls , include_parents = False ):
@@ -34,15 +46,18 @@ def get_subclasses(cls, include_parents=False):
3446 return subclasses
3547
@classmethod
def _get_supported_synthesizers(cls, modality):
    """List the natively supported synthesizer class names for one modality.

    Args:
        modality (str):
            Either ``'single_table'`` or ``'multi_table'``.

    Returns:
        list[str]:
            Sorted names of the entries returned by
            ``get_subclasses(include_parents=True)`` whose class is natively
            supported and whose ``_MODALITY_FLAG`` equals ``modality``.

    Raises:
        ValueError:
            If ``modality`` is not a valid modality.
    """
    _validate_modality(modality)

    candidates = cls.get_subclasses(include_parents=True)
    supported = set()
    for name, subclass in candidates.items():
        # The multi-table base class itself is never reported.
        if name == 'MultiTableBaselineSynthesizer':
            continue

        if not subclass._NATIVELY_SUPPORTED:
            continue

        if subclass._MODALITY_FLAG == modality:
            supported.add(name)

    return sorted(supported)
4661
4762 @classmethod
4863 def get_baselines (cls ):
@@ -55,6 +70,35 @@ def get_baselines(cls):
5570
5671 return synthesizers
5772
73+ def _fit (self , data , metadata ):
74+ """Fit the synthesizer to the data.
75+
76+ Args:
77+ data (pandas.DataFrame):
78+ The data to fit the synthesizer to.
79+ metadata (sdv.metadata.Metadata):
80+ The metadata describing the data.
81+ """
82+ raise NotImplementedError ()
83+
84+ @classmethod
85+ def _get_trained_synthesizer (cls , data , metadata ):
86+ """Train a synthesizer on the provided data and metadata.
87+
88+ Args:
89+ data (pd.DataFrame or dict):
90+ The data to train on.
91+ metadata (sdv.metadata.Metadata):
92+ The metadata
93+
94+ Returns:
95+ A synthesizer object
96+ """
97+ synthesizer = cls ()
98+ synthesizer ._fit (data , metadata )
99+
100+ return synthesizer
101+
58102 def get_trained_synthesizer (self , data , metadata ):
59103 """Get a synthesizer that has been trained on the provided data and metadata.
60104
@@ -90,3 +134,25 @@ def sample_from_synthesizer(self, synthesizer, n_samples):
90134 should be a dict mapping table name to DataFrame.
91135 """
92136 return self ._sample_from_synthesizer (synthesizer , n_samples )
137+
138+
class MultiTableBaselineSynthesizer(BaselineSynthesizer):
    """Common parent class for the multi-table synthesizers."""

    # Modality tag for this class; subclasses inherit it unless they override it.
    _MODALITY_FLAG = 'multi_table'

    def sample_from_synthesizer(self, synthesizer, scale=1.0):
        """Draw synthetic data from an already trained synthesizer.

        Args:
            synthesizer (obj):
                The synthesizer object to sample data from.
            scale (float):
                The scale of data to sample, forwarded as a keyword argument to
                ``_sample_from_synthesizer``. Defaults to 1.0.

        Returns:
            dict:
                The sampled data. A dict mapping table name to DataFrame.
        """
        sampled = self._sample_from_synthesizer(synthesizer, scale=scale)
        return sampled
0 commit comments