Skip to content

Commit a9277dc

Browse files
Move utility functions to utils.py and tests accordingly. (#430)
1 parent b3b9147 commit a9277dc

File tree

23 files changed

+698
-687
lines changed

23 files changed

+698
-687
lines changed

copulas/__init__.py

Lines changed: 0 additions & 252 deletions
Original file line numberDiff line numberDiff line change
@@ -1,268 +1,16 @@
1-
# -*- coding: utf-8 -*-
2-
31
"""Top-level package for Copulas."""
42

53
__author__ = 'DataCebo, Inc.'
64
__email__ = '[email protected]'
75
__version__ = '0.11.2.dev0'
86

9-
import contextlib
10-
import importlib
117
import sys
128
import warnings
139
from copy import deepcopy
1410
from importlib.metadata import entry_points
1511
from operator import attrgetter
1612
from types import ModuleType
1713

18-
import numpy as np
19-
import pandas as pd
20-
21-
EPSILON = np.finfo(np.float32).eps
22-
23-
24-
class NotFittedError(Exception):
25-
"""NotFittedError class."""
26-
27-
28-
@contextlib.contextmanager
29-
def set_random_state(random_state, set_model_random_state):
30-
"""Context manager for managing the random state.
31-
32-
Args:
33-
random_state (int or np.random.RandomState):
34-
The random seed or RandomState.
35-
set_model_random_state (function):
36-
Function to set the random state on the model.
37-
"""
38-
original_state = np.random.get_state()
39-
40-
np.random.set_state(random_state.get_state())
41-
42-
try:
43-
yield
44-
finally:
45-
current_random_state = np.random.RandomState()
46-
current_random_state.set_state(np.random.get_state())
47-
set_model_random_state(current_random_state)
48-
np.random.set_state(original_state)
49-
50-
51-
def random_state(function):
52-
"""Set the random state before calling the function.
53-
54-
Args:
55-
function (Callable):
56-
The function to wrap around.
57-
"""
58-
59-
def wrapper(self, *args, **kwargs):
60-
if self.random_state is None:
61-
return function(self, *args, **kwargs)
62-
63-
else:
64-
with set_random_state(self.random_state, self.set_random_state):
65-
return function(self, *args, **kwargs)
66-
67-
return wrapper
68-
69-
70-
def validate_random_state(random_state):
71-
"""Validate random state argument.
72-
73-
Args:
74-
random_state (int, numpy.random.RandomState, tuple, or None):
75-
Seed or RandomState for the random generator.
76-
77-
Output:
78-
numpy.random.RandomState
79-
"""
80-
if random_state is None:
81-
return None
82-
83-
if isinstance(random_state, int):
84-
return np.random.RandomState(seed=random_state)
85-
elif isinstance(random_state, np.random.RandomState):
86-
return random_state
87-
else:
88-
raise TypeError(
89-
f'`random_state` {random_state} expected to be an int '
90-
'or `np.random.RandomState` object.'
91-
)
92-
93-
94-
def get_instance(obj, **kwargs):
95-
"""Create new instance of the ``obj`` argument.
96-
97-
Args:
98-
obj (str, type, instance):
99-
"""
100-
instance = None
101-
if isinstance(obj, str):
102-
package, name = obj.rsplit('.', 1)
103-
instance = getattr(importlib.import_module(package), name)(**kwargs)
104-
elif isinstance(obj, type):
105-
instance = obj(**kwargs)
106-
else:
107-
if kwargs:
108-
instance = obj.__class__(**kwargs)
109-
else:
110-
args = getattr(obj, '__args__', ())
111-
kwargs = getattr(obj, '__kwargs__', {})
112-
instance = obj.__class__(*args, **kwargs)
113-
114-
return instance
115-
116-
117-
def store_args(__init__):
118-
"""Save ``*args`` and ``**kwargs`` used in the ``__init__`` of a copula.
119-
120-
Args:
121-
__init__(callable): ``__init__`` function to store their arguments.
122-
123-
Returns:
124-
callable: Decorated ``__init__`` function.
125-
"""
126-
127-
def new__init__(self, *args, **kwargs):
128-
args_copy = deepcopy(args)
129-
kwargs_copy = deepcopy(kwargs)
130-
__init__(self, *args, **kwargs)
131-
self.__args__ = args_copy
132-
self.__kwargs__ = kwargs_copy
133-
134-
return new__init__
135-
136-
137-
def get_qualified_name(_object):
138-
"""Return the Fully Qualified Name from an instance or class."""
139-
module = _object.__module__
140-
if hasattr(_object, '__name__'):
141-
_class = _object.__name__
142-
143-
else:
144-
_class = _object.__class__.__name__
145-
146-
return module + '.' + _class
147-
148-
149-
def vectorize(function):
150-
"""Allow a method that only accepts scalars to accept vectors too.
151-
152-
This decorator has two different behaviors depending on the dimensionality of the
153-
array passed as an argument:
154-
155-
**1-d array**
156-
157-
It will work under the assumption that the `function` argument is a callable
158-
with signature::
159-
160-
function(self, X, *args, **kwargs)
161-
162-
where X is a scalar magnitude.
163-
164-
In this case the arguments of the input array will be given one at a time, and
165-
both the input and output of the decorated function will have shape (n,).
166-
167-
**2-d array**
168-
169-
It will work under the assumption that the `function` argument is a callable with signature::
170-
171-
function(self, X0, ..., Xj, *args, **kwargs)
172-
173-
where `Xi` are scalar magnitudes.
174-
175-
It will pass the contents of each row unpacked on each call. The input is expected to have
176-
shape (n, j), the output a shape of (n,)
177-
178-
It will return a function that is guaranteed to return a `numpy.array`.
179-
180-
Args:
181-
function(callable): Function that only accepts and returns scalars.
182-
183-
Returns:
184-
callable: Decorated function that can accept and return :attr:`numpy.array`.
185-
186-
"""
187-
188-
def decorated(self, X, *args, **kwargs):
189-
if not isinstance(X, np.ndarray):
190-
return function(self, X, *args, **kwargs)
191-
192-
if len(X.shape) == 1:
193-
X = X.reshape([-1, 1])
194-
195-
if len(X.shape) == 2:
196-
return np.fromiter(
197-
(function(self, *x, *args, **kwargs) for x in X), np.dtype('float64')
198-
)
199-
200-
else:
201-
raise ValueError('Arrays of dimensionality higher than 2 are not supported.')
202-
203-
decorated.__doc__ = function.__doc__
204-
return decorated
205-
206-
207-
def scalarize(function):
208-
"""Allow methods that only accept 1-d vectors to work with scalars.
209-
210-
Args:
211-
function(callable): Function that accepts and returns vectors.
212-
213-
Returns:
214-
callable: Decorated function that accepts and returns scalars.
215-
"""
216-
217-
def decorated(self, X, *args, **kwargs):
218-
scalar = not isinstance(X, np.ndarray)
219-
220-
if scalar:
221-
X = np.array([X])
222-
223-
result = function(self, X, *args, **kwargs)
224-
if scalar:
225-
result = result[0]
226-
227-
return result
228-
229-
decorated.__doc__ = function.__doc__
230-
return decorated
231-
232-
233-
def check_valid_values(function):
234-
"""Raise an exception if the given values are not supported.
235-
236-
Args:
237-
function(callable): Method whose unique argument is a numpy.array-like object.
238-
239-
Returns:
240-
callable: Decorated function
241-
242-
Raises:
243-
ValueError: If there are missing or invalid values or if the dataset is empty.
244-
"""
245-
246-
def decorated(self, X, *args, **kwargs):
247-
if isinstance(X, pd.DataFrame):
248-
W = X.to_numpy()
249-
250-
else:
251-
W = X
252-
253-
if not len(W):
254-
raise ValueError('Your dataset is empty.')
255-
256-
if not (np.issubdtype(W.dtype, np.floating) or np.issubdtype(W.dtype, np.integer)):
257-
raise ValueError('There are non-numerical values in your data.')
258-
259-
if np.isnan(W).any().any():
260-
raise ValueError('There are nan values in your data.')
261-
262-
return function(self, X, *args, **kwargs)
263-
264-
return decorated
265-
26614

26715
def _get_addon_target(addon_path_name):
26816
"""Find the target object for the add-on.

copulas/bivariate/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pandas as pd
55

6-
from copulas import EPSILON
6+
from copulas.utils import EPSILON
77
from copulas.bivariate.base import Bivariate, CopulaTypes
88
from copulas.bivariate.clayton import Clayton
99
from copulas.bivariate.frank import Frank

copulas/bivariate/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from scipy import stats
99
from scipy.optimize import brentq
1010

11-
from copulas import EPSILON, NotFittedError, random_state, validate_random_state
1211
from copulas.bivariate.utils import split_matrix
12+
from copulas.errors import NotFittedError
13+
from copulas.utils import EPSILON, random_state, validate_random_state
1314

1415

1516
class CopulaTypes(Enum):

copulas/bivariate/frank.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import scipy.integrate as integrate
77
from scipy.optimize import least_squares
88

9-
from copulas import EPSILON
109
from copulas.bivariate.base import Bivariate, CopulaTypes
1110
from copulas.bivariate.utils import split_matrix
11+
from copulas.utils import EPSILON
1212

1313
MIN_FLOAT_LOG = np.log(sys.float_info.min)
1414
MAX_FLOAT_LOG = np.log(sys.float_info.max)

copulas/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pandas as pd
55
from scipy import stats
66

7-
from copulas import set_random_state, validate_random_state
7+
from copulas.utils import set_random_state, validate_random_state
88

99

1010
def _dummy_fn(state):

copulas/errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Copulas Exceptions."""
2+
3+
4+
class NotFittedError(Exception):
5+
"""NotFittedError class."""

copulas/multivariate/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import numpy as np
66

7-
from copulas import NotFittedError, get_instance, validate_random_state
7+
from copulas.errors import NotFittedError
8+
from copulas.utils import get_instance, validate_random_state
89

910

1011
class Multivariate(object):

copulas/multivariate/gaussian.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import pandas as pd
88
from scipy import stats
99

10-
from copulas import (
10+
from copulas.multivariate.base import Multivariate
11+
from copulas.univariate import GaussianUnivariate, Univariate
12+
from copulas.utils import (
1113
EPSILON,
1214
check_valid_values,
1315
get_instance,
@@ -16,8 +18,6 @@
1618
store_args,
1719
validate_random_state,
1820
)
19-
from copulas.multivariate.base import Multivariate
20-
from copulas.univariate import GaussianUnivariate, Univariate
2121

2222
LOGGER = logging.getLogger(__name__)
2323
DEFAULT_DISTRIBUTION = Univariate

copulas/multivariate/tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import numpy as np
77
import scipy
88

9-
from copulas import EPSILON, get_qualified_name
109
from copulas.bivariate.base import Bivariate
1110
from copulas.multivariate.base import Multivariate
11+
from copulas.utils import EPSILON, get_qualified_name
1212

1313
LOGGER = logging.getLogger(__name__)
1414

copulas/multivariate/vine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77
import numpy as np
88
import pandas as pd
99

10-
from copulas import (
10+
from copulas.bivariate.base import Bivariate, CopulaTypes
11+
from copulas.multivariate.base import Multivariate
12+
from copulas.multivariate.tree import Tree, get_tree
13+
from copulas.univariate.gaussian_kde import GaussianKDE
14+
from copulas.utils import (
1115
EPSILON,
1216
check_valid_values,
1317
get_qualified_name,
1418
random_state,
1519
store_args,
1620
validate_random_state,
1721
)
18-
from copulas.bivariate.base import Bivariate, CopulaTypes
19-
from copulas.multivariate.base import Multivariate
20-
from copulas.multivariate.tree import Tree, get_tree
21-
from copulas.univariate.gaussian_kde import GaussianKDE
2222

2323
LOGGER = logging.getLogger(__name__)
2424

0 commit comments

Comments
 (0)