Skip to content

Commit 2ea3213

Browse files
authored
When Copulas univariate fit fails, produce a log instead of a warning (#365)
1 parent 56a76c0 commit 2ea3213

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

copulas/multivariate/gaussian.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import logging
44
import sys
5-
import warnings
65

76
import numpy as np
87
import pandas as pd
@@ -112,11 +111,11 @@ def fit(self, X):
112111
try:
113112
univariate.fit(column)
114113
except BaseException:
115-
warning_message = (
114+
log_message = (
116115
f'Unable to fit to a {distribution} distribution for column {column_name}. '
117116
'Using a Gaussian distribution instead.'
118117
)
119-
warnings.warn(warning_message)
118+
LOGGER.info(log_message)
120119
univariate = GaussianUnivariate()
121120
univariate.fit(column)
122121

tests/end-to-end/multivariate/test_gaussian.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ def test_save_load(self):
179179

180180

181181
@patch('copulas.univariate.truncated_gaussian.TruncatedGaussian._fit')
182-
@patch('copulas.multivariate.gaussian.warnings')
183-
def test_broken_distribution(warnings_mock, truncated_mock):
182+
@patch('copulas.multivariate.gaussian.LOGGER')
183+
def test_broken_distribution(logger_mock, truncated_mock):
184184
"""Fit should use a gaussian if the passed distribution crashes."""
185185
# Setup
186186
truncated_mock.side_effect = ValueError()
@@ -194,11 +194,14 @@ def test_broken_distribution(warnings_mock, truncated_mock):
194194
samples = model.sample()
195195

196196
# Asserts
197-
expected_warnings_msg = (
197+
expected_logging_msg = (
198198
'Unable to fit to a copulas.univariate.truncated_gaussian.TruncatedGaussian '
199199
'distribution for column y. Using a Gaussian distribution instead.'
200200
)
201-
warnings_mock.warn.assert_called_once_with(expected_warnings_msg)
201+
calls = logger_mock.info.call_args_list
202+
assert calls[0].args[0] == 'Fitting %s'
203+
assert calls[1].args[0] == expected_logging_msg
204+
assert len(calls) == 2
202205

203206
expected_model = GaussianMultivariate(
204207
distribution={'y': 'copulas.univariate.truncated_gaussian.TruncatedGaussian'}

tests/unit/multivariate/test_gaussian.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,8 @@ def test_fit_numpy_array(self):
273273
assert (copula.correlation == expected_correlation).all().all()
274274

275275
@patch('copulas.univariate.truncated_gaussian.TruncatedGaussian._fit')
276-
@patch('copulas.multivariate.gaussian.warnings')
277-
def test_fit_broken_distribution(self, warnings_mock, truncated_mock):
276+
@patch('copulas.multivariate.gaussian.LOGGER')
277+
def test_fit_broken_distribution(self, logger_mock, truncated_mock):
278278
"""Fit should use a gaussian if the passed distribution crashes."""
279279
# Setup
280280
truncated_mock.side_effect = ValueError()
@@ -287,11 +287,14 @@ def test_fit_broken_distribution(self, warnings_mock, truncated_mock):
287287
copula.fit(data)
288288

289289
# Check
290-
expected_warnings_msg = (
290+
expected_logging_msg = (
291291
'Unable to fit to a copulas.univariate.truncated_gaussian.TruncatedGaussian '
292292
'distribution for column column1. Using a Gaussian distribution instead.'
293293
)
294-
warnings_mock.warn.assert_called_once_with(expected_warnings_msg)
294+
calls = logger_mock.info.call_args_list
295+
assert calls[0].args[0] == 'Fitting %s'
296+
assert calls[1].args[0] == expected_logging_msg
297+
assert len(calls) == 2
295298

296299
assert len(copula.univariates) == 1
297300
assert isinstance(copula.univariates[0], GaussianUnivariate)

0 commit comments

Comments
 (0)