@@ -228,6 +228,83 @@ def test__fit(self, mock_multivariate, mock_warnings):
228228 mock_warnings .catch_warnings .assert_called_once ()
229229 instance ._num_rows == 10
230230
231+ def test__fit_mocked_instance (self ):
232+ """Test that the `_fit` method calls the modularized functions."""
233+ # Setup
234+ instance = Mock (numerical_distributions = {})
235+ processed_data = Mock (columns = [])
236+ numerical_distributions = Mock ()
237+ instance ._get_numerical_distributions .return_value = numerical_distributions
238+
239+ # Run
240+ GaussianCopulaSynthesizer ._fit (instance , processed_data )
241+
242+ # Assert
243+ instance ._learn_num_rows .assert_called_once_with (processed_data )
244+ instance ._get_numerical_distributions .assert_called_once_with (processed_data )
245+ instance ._initialize_model .assert_called_once_with (numerical_distributions )
246+ instance ._fit_model .assert_called_once_with (processed_data )
247+
248+ def test__learn_num_rows (self ):
249+ """Test that the `_learn_num_rows` method returns the correct number of rows."""
250+ # Setup
251+ metadata = Metadata ()
252+ instance = GaussianCopulaSynthesizer (metadata )
253+ processed_data = pd .DataFrame ({'a' : range (5 ), 'b' : range (5 )})
254+
255+ # Run
256+ result = instance ._learn_num_rows (processed_data )
257+
258+ # Assert
259+ assert result == 5
260+
261+ def test__get_numerical_distributions_with_existing_columns (self ):
262+ """Test that `_get_numerical_distributions` returns correct distributions."""
263+ # Setup
264+ metadata = Metadata ()
265+ instance = GaussianCopulaSynthesizer (metadata )
266+ instance ._numerical_distributions = {'a' : 'dist_a' , 'b' : 'dist_b' }
267+ instance ._default_distribution = 'default_dist'
268+
269+ processed_data = Mock ()
270+ processed_data .columns = ['a' , 'b' , 'c' ]
271+
272+ # Run
273+ result = instance ._get_numerical_distributions (processed_data )
274+
275+ # Assert
276+ expected_result = {'a' : 'dist_a' , 'b' : 'dist_b' , 'c' : 'default_dist' }
277+ assert result == expected_result
278+
279+ @patch ('sdv.single_table.copulas.multivariate.GaussianMultivariate' )
280+ def test__initialize_model (self , mock_gaussian_multivariate ):
281+ """Test that `_initialize_model` calls the GaussianMultivariate with correct parameters."""
282+ # Setup
283+ metadata = Metadata ()
284+ instance = GaussianCopulaSynthesizer (metadata )
285+ numerical_distributions = {'a' : 'dist_a' , 'b' : 'dist_b' }
286+
287+ # Run
288+ model = instance ._initialize_model (numerical_distributions )
289+
290+ # Assert
291+ mock_gaussian_multivariate .assert_called_once_with (distribution = numerical_distributions )
292+ assert model == mock_gaussian_multivariate .return_value
293+
294+ def test__fit_model (self ):
295+ """Test that `_fit_model` fits the model correctly."""
296+ # Setup
297+ metadata = Metadata ()
298+ instance = GaussianCopulaSynthesizer (metadata )
299+ instance ._model = Mock ()
300+ processed_data = Mock ()
301+
302+ # Run
303+ instance ._fit_model (processed_data )
304+
305+ # Assert
306+ instance ._model .fit .assert_called_once_with (processed_data )
307+
231308 def test__get_nearest_correlation_matrix_valid (self ):
232309 """Test ``_get_nearest_correlation_matrix`` with a psd input.
233310
0 commit comments