Skip to content

Commit 315266f

Browse files
Modularize GaussianCopulaSynthesizer._fit (#2273)
1 parent 05f7494 commit 315266f

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

sdv/single_table/copulas.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,28 @@ def _fit(self, processed_data):
135135
log_numerical_distributions_error(
136136
self.numerical_distributions, processed_data.columns, LOGGER
137137
)
138-
self._num_rows = len(processed_data)
138+
self._num_rows = self._learn_num_rows(processed_data)
139+
numerical_distributions = self._get_numerical_distributions(processed_data)
140+
self._model = self._initialize_model(numerical_distributions)
141+
self._fit_model(processed_data)
139142

143+
def _learn_num_rows(self, processed_data):
144+
return len(processed_data)
145+
146+
def _get_numerical_distributions(self, processed_data):
140147
numerical_distributions = deepcopy(self._numerical_distributions)
141148
for column in processed_data.columns:
142149
if column not in numerical_distributions:
143150
numerical_distributions[column] = self._numerical_distributions.get(
144151
column, self._default_distribution
145152
)
146-
self._model = multivariate.GaussianMultivariate(distribution=numerical_distributions)
147153

154+
return numerical_distributions
155+
156+
def _initialize_model(self, numerical_distributions):
157+
return multivariate.GaussianMultivariate(distribution=numerical_distributions)
158+
159+
def _fit_model(self, processed_data):
148160
with warnings.catch_warnings():
149161
warnings.filterwarnings('ignore', module='scipy')
150162
self._model.fit(processed_data)

tests/unit/single_table/test_copulas.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,83 @@ def test__fit(self, mock_multivariate, mock_warnings):
228228
mock_warnings.catch_warnings.assert_called_once()
229229
instance._num_rows == 10
230230

231+
def test__fit_mocked_instance(self):
232+
"""Test that the `_fit` method calls the modularized functions."""
233+
# Setup
234+
instance = Mock(numerical_distributions={})
235+
processed_data = Mock(columns=[])
236+
numerical_distributions = Mock()
237+
instance._get_numerical_distributions.return_value = numerical_distributions
238+
239+
# Run
240+
GaussianCopulaSynthesizer._fit(instance, processed_data)
241+
242+
# Assert
243+
instance._learn_num_rows.assert_called_once_with(processed_data)
244+
instance._get_numerical_distributions.assert_called_once_with(processed_data)
245+
instance._initialize_model.assert_called_once_with(numerical_distributions)
246+
instance._fit_model.assert_called_once_with(processed_data)
247+
248+
def test__learn_num_rows(self):
249+
"""Test that the `_learn_num_rows` method returns the correct number of rows."""
250+
# Setup
251+
metadata = Metadata()
252+
instance = GaussianCopulaSynthesizer(metadata)
253+
processed_data = pd.DataFrame({'a': range(5), 'b': range(5)})
254+
255+
# Run
256+
result = instance._learn_num_rows(processed_data)
257+
258+
# Assert
259+
assert result == 5
260+
261+
def test__get_numerical_distributions_with_existing_columns(self):
262+
"""Test that `_get_numerical_distributions` returns correct distributions."""
263+
# Setup
264+
metadata = Metadata()
265+
instance = GaussianCopulaSynthesizer(metadata)
266+
instance._numerical_distributions = {'a': 'dist_a', 'b': 'dist_b'}
267+
instance._default_distribution = 'default_dist'
268+
269+
processed_data = Mock()
270+
processed_data.columns = ['a', 'b', 'c']
271+
272+
# Run
273+
result = instance._get_numerical_distributions(processed_data)
274+
275+
# Assert
276+
expected_result = {'a': 'dist_a', 'b': 'dist_b', 'c': 'default_dist'}
277+
assert result == expected_result
278+
279+
@patch('sdv.single_table.copulas.multivariate.GaussianMultivariate')
280+
def test__initialize_model(self, mock_gaussian_multivariate):
281+
"""Test that `_initialize_model` calls the GaussianMultivariate with correct parameters."""
282+
# Setup
283+
metadata = Metadata()
284+
instance = GaussianCopulaSynthesizer(metadata)
285+
numerical_distributions = {'a': 'dist_a', 'b': 'dist_b'}
286+
287+
# Run
288+
model = instance._initialize_model(numerical_distributions)
289+
290+
# Assert
291+
mock_gaussian_multivariate.assert_called_once_with(distribution=numerical_distributions)
292+
assert model == mock_gaussian_multivariate.return_value
293+
294+
def test__fit_model(self):
295+
"""Test that `_fit_model` fits the model correctly."""
296+
# Setup
297+
metadata = Metadata()
298+
instance = GaussianCopulaSynthesizer(metadata)
299+
instance._model = Mock()
300+
processed_data = Mock()
301+
302+
# Run
303+
instance._fit_model(processed_data)
304+
305+
# Assert
306+
instance._model.fit.assert_called_once_with(processed_data)
307+
231308
def test__get_nearest_correlation_matrix_valid(self):
232309
"""Test ``_get_nearest_correlation_matrix`` with a psd input.
233310

0 commit comments

Comments
 (0)