Skip to content

Commit 13f647d

Browse files
author
Guillaume Lemaitre
committed
Forgot to add the base class
1 parent 0fcd502 commit 13f647d

File tree

1 file changed

+183
-0
lines changed

1 file changed

+183
-0
lines changed

unbalanced_dataset/base.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""Base class for sampling"""
2+
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
import warnings
7+
8+
import numpy as np
9+
10+
from abc import ABCMeta, abstractmethod
11+
12+
from collections import Counter
13+
14+
from sklearn.base import BaseEstimator
15+
from sklearn.utils import check_X_y
16+
from sklearn.externals import six
17+
18+
from six import string_types
19+
20+
21+
class SamplerMixin(six.with_metaclass(ABCMeta, BaseEstimator)):
22+
"""Mixin class for samplers with abstact method.
23+
24+
Warning: This class should not be used directly. Use the derive classes
25+
instead.
26+
"""
27+
28+
@abstractmethod
29+
def __init__(self, ratio='auto', random_state=None, verbose=True):
30+
"""Initialize this object and its instance variables.
31+
32+
Parameters
33+
----------
34+
ratio : str or float, optional (default='auto')
35+
If 'auto', the ratio will be defined automatically to balanced
36+
the dataset. Otherwise, the ratio will corresponds to the number
37+
of samples in the minority class over the the number of samples
38+
in the majority class.
39+
40+
random_state : int or None, optional (default=None)
41+
Seed for random number generation.
42+
43+
verbose : bool, optional (default=True)
44+
Boolean to either or not print information about the processing
45+
46+
Returns
47+
-------
48+
None
49+
50+
"""
51+
# The ratio correspond to the number of samples in the minority class
52+
# over the number of samples in the majority class. Thus, the ratio
53+
# cannot be greater than 1.0
54+
if isinstance(ratio, float):
55+
if ratio > 1:
56+
raise ValueError('Ration cannot be greater than one.')
57+
elif ratio <= 0:
58+
raise ValueError('Ratio cannot be negative.')
59+
else:
60+
self.ratio = ratio
61+
elif isinstance(ratio, string_types):
62+
if ratio == 'auto':
63+
self.ratio = ratio
64+
else:
65+
raise ValueError('Unknown string for the parameter ratio.')
66+
else:
67+
raise ValueError('Unknown parameter type for ratio.')
68+
69+
self.random_state = random_state
70+
self.verbose = verbose
71+
72+
# Create the member variables regarding the classes statistics
73+
self.min_c_ = None
74+
self.maj_c_ = None
75+
self.stats_c_ = {}
76+
77+
@abstractmethod
78+
def fit(self, X, y):
79+
"""Find the classes statistics before to perform sampling.
80+
81+
Parameters
82+
----------
83+
X : ndarray, shape (n_samples, n_features)
84+
Matrix containing the data which have to be sampled.
85+
86+
y : ndarray, shape (n_samples, )
87+
Corresponding label for each sample in X.
88+
89+
Returns
90+
-------
91+
self : object,
92+
Return self.
93+
94+
"""
95+
96+
# Check the consistency of X and y
97+
X, y = check_X_y(X, y)
98+
99+
if self.verbose:
100+
print("Determining classes statistics... ", end="")
101+
102+
# Get all the unique elements in the target array
103+
uniques = np.unique(y)
104+
105+
# # Raise an error if there is only one class
106+
# if uniques.size == 1:
107+
# raise RuntimeError("Only one class detected, aborting...")
108+
# Raise a warning for the moment to be compatible with BaseEstimator
109+
if uniques.size == 1:
110+
warnings.warn('Only one class detected, something will get wrong',
111+
RuntimeWarning)
112+
113+
# Create a dictionary containing the class statistics
114+
self.stats_c_ = Counter(y)
115+
116+
# Find the minority and majority classes
117+
self.min_c_ = min(self.stats_c_, key=self.stats_c_.get)
118+
self.maj_c_ = max(self.stats_c_, key=self.stats_c_.get)
119+
120+
if self.verbose:
121+
print('{} classes detected: {}'.format(uniques.size,
122+
self.stats_c_))
123+
124+
# Check if the ratio provided at initialisation make sense
125+
if isinstance(self.ratio, float):
126+
if self.ratio < (self.stats_c_[self.min_c_] /
127+
self.stats_c_[self.maj_c_]):
128+
raise RuntimeError('The ratio requested at initialisation'
129+
' should be greater or equal than the'
130+
' balancing ratio of the current data.')
131+
132+
return self
133+
134+
@abstractmethod
135+
def sample(self, X, y):
136+
"""Resample the dataset.
137+
138+
Parameters
139+
----------
140+
X : ndarray, shape (n_samples, n_features)
141+
Matrix containing the data which have to be sampled.
142+
143+
y : ndarray, shape (n_samples, )
144+
Corresponding label for each sample in X.
145+
146+
Returns
147+
-------
148+
X_resampled : ndarray, shape (n_samples_new, n_features)
149+
The array containing the resampled data.
150+
151+
y_resampled : ndarray, shape (n_samples_new)
152+
The corresponding label of `X_resampled`
153+
154+
"""
155+
156+
# Check that the data have been fitted
157+
if not self.stats_c_:
158+
raise RuntimeError('You need to fit the data, first!!!')
159+
160+
return self
161+
162+
def fit_sample(self, X, y):
163+
"""Fit the statistics and resample the data directly.
164+
165+
Parameters
166+
----------
167+
X : ndarray, shape (n_samples, n_features)
168+
Matrix containing the data which have to be sampled.
169+
170+
y : ndarray, shape (n_samples, )
171+
Corresponding label for each sample in X.
172+
173+
Returns
174+
-------
175+
X_resampled : ndarray, shape (n_samples_new, n_features)
176+
The array containing the resampled data.
177+
178+
y_resampled : ndarray, shape (n_samples_new)
179+
The corresponding label of `X_resampled`
180+
181+
"""
182+
183+
return self.fit(X, y).sample(X, y)

0 commit comments

Comments
 (0)