@@ -70,26 +70,6 @@ def _transform_to_normal(self, X):
7070
7171 return stats .norm .ppf (np .column_stack (U ))
7272
73- def _get_correlation (self , X ):
74- """Compute correlation matrix with transformed data.
75-
76- Args:
77- X (numpy.ndarray):
78- Data for which the correlation needs to be computed.
79-
80- Returns:
81- numpy.ndarray:
82- computed correlation matrix.
83- """
84- result = self ._transform_to_normal (X )
85- correlation = pd .DataFrame (data = result ).corr ().to_numpy ()
86- correlation = np .nan_to_num (correlation , nan = 0.0 )
87- # If singular, add some noise to the diagonal
88- if np .linalg .cond (correlation ) > 1.0 / sys .float_info .epsilon :
89- correlation = correlation + np .identity (correlation .shape [0 ]) * EPSILON
90-
91- return pd .DataFrame (correlation , index = self .columns , columns = self .columns )
92-
9373 @check_valid_values
9474 def fit (self , X ):
9575 """Compute the distribution for each variable and then its correlation matrix.
@@ -100,42 +80,88 @@ def fit(self, X):
10080 """
10181 LOGGER .info ('Fitting %s' , self )
10282
83+ # Validate the input data
84+ X = self ._validate_input (X )
85+ columns , univariates = self ._fit_columns (X )
86+
87+ self .columns = columns
88+ self .univariates = univariates
89+
90+ LOGGER .debug ('Computing correlation.' )
91+ self .correlation = self ._get_correlation (X )
92+ self .fitted = True
93+ LOGGER .debug ('GaussianMultivariate fitted successfully' )
94+
95+ def _validate_input (self , X ):
96+ """Validate the input data."""
10397 if not isinstance (X , pd .DataFrame ):
10498 X = pd .DataFrame (X )
10599
100+ return X
101+
102+ def _fit_columns (self , X ):
103+ """Fit each column to its distribution."""
106104 columns = []
107105 univariates = []
108106 for column_name , column in X .items ():
109- if isinstance (self .distribution , dict ):
110- distribution = self .distribution .get (column_name , DEFAULT_DISTRIBUTION )
111- else :
112- distribution = self .distribution
113-
107+ distribution = self ._get_distribution_for_column (column_name )
114108 LOGGER .debug ('Fitting column %s to %s' , column_name , distribution )
115109
116- univariate = get_instance (distribution )
117- try :
118- univariate .fit (column )
119- except BaseException :
120- log_message = (
121- f'Unable to fit to a { distribution } distribution for column { column_name } . '
122- 'Using a Gaussian distribution instead.'
123- )
124- LOGGER .info (log_message )
125- univariate = GaussianUnivariate ()
126- univariate .fit (column )
127-
110+ univariate = self ._fit_column (column , distribution , column_name )
128111 columns .append (column_name )
129112 univariates .append (univariate )
130113
131- self .columns = columns
132- self .univariates = univariates
114+ return columns , univariates
115+
116+ def _get_distribution_for_column (self , column_name ):
117+ """Retrieve the distribution for a given column name."""
118+ if isinstance (self .distribution , dict ):
119+ return self .distribution .get (column_name , DEFAULT_DISTRIBUTION )
120+
121+ return self .distribution
122+
123+ def _fit_column (self , column , distribution , column_name ):
124+ """Fit a single column to its distribution with exception handling."""
125+ univariate = get_instance (distribution )
126+ try :
127+ univariate .fit (column )
128+ except Exception as error :
129+ univariate = self ._fit_with_fallback_distribution (
130+ column , distribution , column_name , error
131+ )
132+
133+ return univariate
134+
135+ def _fit_with_fallback_distribution (self , column , distribution , column_name , error ):
136+ """Fall back to fitting a Gaussian distribution and log the error."""
137+ log_message = (
138+ f'Unable to fit to a { distribution } distribution for column { column_name } . '
139+ 'Using a Gaussian distribution instead.'
140+ )
141+ LOGGER .info (log_message )
142+ univariate = GaussianUnivariate ()
143+ univariate .fit (column )
144+ return univariate
133145
134- LOGGER .debug ('Computing correlation' )
135- self .correlation = self ._get_correlation (X )
136- self .fitted = True
146+ def _get_correlation (self , X ):
147+ """Compute correlation matrix with transformed data.
137148
138- LOGGER .debug ('GaussianMultivariate fitted successfully' )
149+ Args:
150+ X (numpy.ndarray):
151+ Data for which the correlation needs to be computed.
152+
153+ Returns:
154+ numpy.ndarray:
155+ computed correlation matrix.
156+ """
157+ result = self ._transform_to_normal (X )
158+ correlation = pd .DataFrame (data = result ).corr ().to_numpy ()
159+ correlation = np .nan_to_num (correlation , nan = 0.0 )
160+ # If singular, add some noise to the diagonal
161+ if np .linalg .cond (correlation ) > 1.0 / sys .float_info .epsilon :
162+ correlation = correlation + np .identity (correlation .shape [0 ]) * EPSILON
163+
164+ return pd .DataFrame (correlation , index = self .columns , columns = self .columns )
139165
140166 def probability_density (self , X ):
141167 """Compute the probability density for each point in X.
0 commit comments