Skip to content

Commit a02032c

Browse files
Modularize GaussianMultivariate fit (#436)
1 parent 0cbd0e9 commit a02032c

File tree

2 files changed

+174
-43
lines changed

2 files changed

+174
-43
lines changed

copulas/multivariate/gaussian.py

Lines changed: 69 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -70,26 +70,6 @@ def _transform_to_normal(self, X):
7070

7171
return stats.norm.ppf(np.column_stack(U))
7272

73-
def _get_correlation(self, X):
74-
"""Compute correlation matrix with transformed data.
75-
76-
Args:
77-
X (numpy.ndarray):
78-
Data for which the correlation needs to be computed.
79-
80-
Returns:
81-
numpy.ndarray:
82-
computed correlation matrix.
83-
"""
84-
result = self._transform_to_normal(X)
85-
correlation = pd.DataFrame(data=result).corr().to_numpy()
86-
correlation = np.nan_to_num(correlation, nan=0.0)
87-
# If singular, add some noise to the diagonal
88-
if np.linalg.cond(correlation) > 1.0 / sys.float_info.epsilon:
89-
correlation = correlation + np.identity(correlation.shape[0]) * EPSILON
90-
91-
return pd.DataFrame(correlation, index=self.columns, columns=self.columns)
92-
9373
@check_valid_values
9474
def fit(self, X):
9575
"""Compute the distribution for each variable and then its correlation matrix.
@@ -100,42 +80,88 @@ def fit(self, X):
10080
"""
10181
LOGGER.info('Fitting %s', self)
10282

83+
# Validate the input data
84+
X = self._validate_input(X)
85+
columns, univariates = self._fit_columns(X)
86+
87+
self.columns = columns
88+
self.univariates = univariates
89+
90+
LOGGER.debug('Computing correlation.')
91+
self.correlation = self._get_correlation(X)
92+
self.fitted = True
93+
LOGGER.debug('GaussianMultivariate fitted successfully')
94+
95+
def _validate_input(self, X):
96+
"""Validate the input data."""
10397
if not isinstance(X, pd.DataFrame):
10498
X = pd.DataFrame(X)
10599

100+
return X
101+
102+
def _fit_columns(self, X):
103+
"""Fit each column to its distribution."""
106104
columns = []
107105
univariates = []
108106
for column_name, column in X.items():
109-
if isinstance(self.distribution, dict):
110-
distribution = self.distribution.get(column_name, DEFAULT_DISTRIBUTION)
111-
else:
112-
distribution = self.distribution
113-
107+
distribution = self._get_distribution_for_column(column_name)
114108
LOGGER.debug('Fitting column %s to %s', column_name, distribution)
115109

116-
univariate = get_instance(distribution)
117-
try:
118-
univariate.fit(column)
119-
except BaseException:
120-
log_message = (
121-
f'Unable to fit to a {distribution} distribution for column {column_name}. '
122-
'Using a Gaussian distribution instead.'
123-
)
124-
LOGGER.info(log_message)
125-
univariate = GaussianUnivariate()
126-
univariate.fit(column)
127-
110+
univariate = self._fit_column(column, distribution, column_name)
128111
columns.append(column_name)
129112
univariates.append(univariate)
130113

131-
self.columns = columns
132-
self.univariates = univariates
114+
return columns, univariates
115+
116+
def _get_distribution_for_column(self, column_name):
117+
"""Retrieve the distribution for a given column name."""
118+
if isinstance(self.distribution, dict):
119+
return self.distribution.get(column_name, DEFAULT_DISTRIBUTION)
120+
121+
return self.distribution
122+
123+
def _fit_column(self, column, distribution, column_name):
124+
"""Fit a single column to its distribution with exception handling."""
125+
univariate = get_instance(distribution)
126+
try:
127+
univariate.fit(column)
128+
except Exception as error:
129+
univariate = self._fit_with_fallback_distribution(
130+
column, distribution, column_name, error
131+
)
132+
133+
return univariate
134+
135+
def _fit_with_fallback_distribution(self, column, distribution, column_name, error):
136+
"""Fall back to fitting a Gaussian distribution and log the error."""
137+
log_message = (
138+
f'Unable to fit to a {distribution} distribution for column {column_name}. '
139+
'Using a Gaussian distribution instead.'
140+
)
141+
LOGGER.info(log_message)
142+
univariate = GaussianUnivariate()
143+
univariate.fit(column)
144+
return univariate
133145

134-
LOGGER.debug('Computing correlation')
135-
self.correlation = self._get_correlation(X)
136-
self.fitted = True
146+
def _get_correlation(self, X):
147+
"""Compute correlation matrix with transformed data.
137148
138-
LOGGER.debug('GaussianMultivariate fitted successfully')
149+
Args:
150+
X (numpy.ndarray):
151+
Data for which the correlation needs to be computed.
152+
153+
Returns:
154+
numpy.ndarray:
155+
computed correlation matrix.
156+
"""
157+
result = self._transform_to_normal(X)
158+
correlation = pd.DataFrame(data=result).corr().to_numpy()
159+
correlation = np.nan_to_num(correlation, nan=0.0)
160+
# If singular, add some noise to the diagonal
161+
if np.linalg.cond(correlation) > 1.0 / sys.float_info.epsilon:
162+
correlation = correlation + np.identity(correlation.shape[0]) * EPSILON
163+
164+
return pd.DataFrame(correlation, index=self.columns, columns=self.columns)
139165

140166
def probability_density(self, X):
141167
"""Compute the probability density for each point in X.

tests/unit/multivariate/test_gaussian.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,111 @@ def test_fit_broken_distribution(self, logger_mock, truncated_mock):
350350
assert isinstance(copula.univariates[0], GaussianUnivariate)
351351
assert copula.univariates[0]._params == {'loc': np.mean(data), 'scale': np.std(data)}
352352

353+
def test__validate_input_with_dataframe(self):
354+
"""Test that `_validate_input` returns the same DataFrame."""
355+
# Setup
356+
instance = GaussianMultivariate()
357+
input_df = pd.DataFrame({'A': [1, 2, 3]})
358+
359+
# Run
360+
result = instance._validate_input(input_df)
361+
362+
# Assert
363+
pd.testing.assert_frame_equal(result, input_df)
364+
365+
def test__validate_input_with_non_dataframe(self):
366+
"""Test that `_validate_input` converts non-DataFrame input into a DataFrame."""
367+
# Setup
368+
instance = GaussianMultivariate()
369+
input_data = [[1, 2, 3], [4, 5, 6]]
370+
371+
# Run
372+
result = instance._validate_input(input_data)
373+
374+
# Assert
375+
expected_df = pd.DataFrame(input_data)
376+
pd.testing.assert_frame_equal(result, expected_df)
377+
378+
@patch('copulas.multivariate.gaussian.LOGGER')
379+
def test__fit_columns(self, mock_logger):
380+
"""Test that `_fit_columns` fits each column to its distribution."""
381+
# Setup
382+
instance = GaussianMultivariate()
383+
instance._get_distribution_for_column = Mock(return_value='normal')
384+
instance._fit_column = Mock(return_value='fitted_univariate')
385+
386+
X = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
387+
388+
# Run
389+
columns, univariates = instance._fit_columns(X)
390+
391+
# Assert
392+
assert columns == ['A', 'B']
393+
assert univariates == ['fitted_univariate', 'fitted_univariate']
394+
instance._get_distribution_for_column.assert_any_call('A')
395+
instance._get_distribution_for_column.assert_any_call('B')
396+
mock_logger.debug.assert_any_call('Fitting column %s to %s', 'A', 'normal')
397+
mock_logger.debug.assert_any_call('Fitting column %s to %s', 'B', 'normal')
398+
399+
@patch('copulas.multivariate.gaussian.DEFAULT_DISTRIBUTION', new='default_distribution')
400+
def test__get_distribution_for_column_with_dict(self):
401+
"""Test that `_get_distribution_for_column` retrieves correct distribution from dict."""
402+
# Setup
403+
instance = GaussianMultivariate()
404+
instance.distribution = {'A': 'normal', 'B': 'uniform'}
405+
406+
# Run
407+
result_A = instance._get_distribution_for_column('A')
408+
result_B = instance._get_distribution_for_column('B')
409+
result_C = instance._get_distribution_for_column('C')
410+
411+
# Assert
412+
assert result_A == 'normal'
413+
assert result_B == 'uniform'
414+
assert result_C == 'default_distribution'
415+
416+
@patch('copulas.multivariate.gaussian.get_instance')
417+
@patch('copulas.multivariate.gaussian.GaussianUnivariate')
418+
def test__fit_column_with_exception(self, mock_gaussian_univariate, mock_get_instance):
419+
"""Test that `_fit_column` falls back to a Gaussian distribution on exception."""
420+
# Setup
421+
instance = GaussianMultivariate()
422+
column = pd.Series([1, 2, 3])
423+
distribution = 'normal'
424+
column_name = 'A'
425+
instance._fit_with_fallback_distribution = Mock(return_value='fallback_univariate')
426+
427+
mock_univariate = Mock()
428+
mock_univariate.fit.side_effect = Exception('Fit error')
429+
mock_get_instance.return_value = mock_univariate
430+
431+
# Run
432+
result = instance._fit_column(column, distribution, column_name)
433+
434+
# Assert
435+
instance._fit_with_fallback_distribution.assert_called_once_with(
436+
column, distribution, column_name, mock_univariate.fit.side_effect
437+
)
438+
assert result == 'fallback_univariate'
439+
440+
@patch('copulas.multivariate.gaussian.GaussianUnivariate')
441+
def test__fit_with_fallback_distribution(self, mock_gaussian_univariate):
442+
"""Test that `_fit_with_fallback_distribution` fits a Gaussian distribution."""
443+
# Setup
444+
instance = GaussianMultivariate()
445+
column = pd.Series([1, 2, 3])
446+
distribution = 'normal'
447+
column_name = 'A'
448+
error = Exception('Test error')
449+
mock_gaussian_univariate.return_value = Mock(fit=Mock())
450+
451+
# Run
452+
result = instance._fit_with_fallback_distribution(column, distribution, column_name, error)
453+
454+
# Assert
455+
mock_gaussian_univariate.return_value.fit.assert_called_once_with(column)
456+
assert result == mock_gaussian_univariate.return_value
457+
353458
def test_probability_density(self):
354459
"""Probability_density computes probability for the given values."""
355460
# Setup

0 commit comments

Comments
 (0)