Skip to content

Commit 7f67f93

Browse files
Rudraksh Tuwani
authored and committed
covshift conf draft
1 parent e85c5a5 commit 7f67f93

File tree

4 files changed

+973
-22
lines changed

4 files changed

+973
-22
lines changed

examples/regression/4-covariate-shift/paper_replication.ipynb

Lines changed: 545 additions & 0 deletions
Large diffs are not rendered by default.

mapie/dre.py

100755 → 100644
Lines changed: 26 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -12,17 +12,17 @@
1212
class DensityRatioEstimator():
1313
""" Template class for density ratio estimation. """
1414

15-
def __init__(self):
16-
pass
15+
def __init__(self) -> None:
16+
raise NotImplementedError
1717

18-
def fit(self):
19-
pass
18+
def fit(self) -> None:
19+
raise NotImplementedError
2020

21-
def predict(self):
22-
pass
21+
def predict(self) -> None:
22+
raise NotImplementedError
2323

24-
def check_is_fitted(self):
25-
pass
24+
def check_is_fitted(self) -> None:
25+
raise NotImplementedError
2626

2727

2828
class ProbClassificationDRE(DensityRatioEstimator):
@@ -37,8 +37,8 @@ class ProbClassificationDRE(DensityRatioEstimator):
3737
Parameters
3838
----------
3939
estimator: Optional[ClassifierMixin]
40-
Any classifier with scikit-learn API
41-
(i.e. with fit, predict, and predict_proba methods), by default ``None``.
40+
Any classifier with scikit-learn API (i.e. with fit, predict, and
41+
predict_proba methods), by default ``None``.
4242
If ``None``, estimator defaults to a ``LogisticRegression`` instance.
4343
4444
clip_min: Optional[float]
@@ -56,11 +56,11 @@ class ProbClassificationDRE(DensityRatioEstimator):
5656
Attributes
5757
----------
5858
source_prob: float
59-
The marginal probability of getting a datapoint from the source
59+
The marginal probability of getting a datapoint from the source
6060
distribution.
6161
6262
target_prob: float
63-
The marginal probability of getting a datapoint from the target
63+
The marginal probability of getting a datapoint from the target
6464
distribution.
6565
6666
References
@@ -80,14 +80,14 @@ def __init__(
8080

8181
self.estimator = self._check_estimator(estimator)
8282

83-
if self.clip_max is None:
83+
if clip_max is None:
8484
self.clip_max = 1
8585
elif all((clip_max >= 0, clip_max <= 1)):
8686
self.clip_max = clip_max
8787
else:
8888
raise ValueError("Expected `clip_max` to be between 0 and 1.")
8989

90-
if self.clip_min is None:
90+
if clip_min is None:
9191
self.clip_min = 0
9292
elif all((clip_min >= 0, clip_min <= clip_max)):
9393
self.clip_min = clip_min
@@ -160,19 +160,19 @@ def fit(
160160
161161
source_prob: Optional[float]
162162
The marginal probability of getting a datapoint from the source
163-
distribution. If ``None``, the proportion of source examples in
163+
distribution. If ``None``, the proportion of source examples in
164164
the training dataset is used.
165165
166166
By default ``None``.
167167
168168
target_prob: Optional[float]
169169
The marginal probability of getting a datapoint from the target
170-
distribution. If ``None``, the proportion of target examples in
170+
distribution. If ``None``, the proportion of target examples in
171171
the training dataset is used.
172172
173173
By default ``None``.
174174
175-
sample_weight : Optional[ArrayLike] of shape (n_source_samples + n_target_samples,)
175+
sample_weight : Optional[ArrayLike] of shape (n_source + n_target,)
176176
Sample weights for fitting the out-of-fold models.
177177
If ``None``, then samples are equally weighted.
178178
If some weights are null,
@@ -192,15 +192,18 @@ def fit(
192192
n_target = X_target.shape[0]
193193

194194
if source_prob is None:
195-
source_prob = self.n_source/(self.n_source + self.n_target)
195+
source_prob = n_source/(n_source + n_target)
196196

197197
if target_prob is None:
198-
target_prob = self.n_target/(self.n_source + self.n_target)
198+
target_prob = n_target/(n_source + n_target)
199199

200200
if source_prob + target_prob != 1:
201201
raise ValueError(
202202
"``source_prob`` and ``target_prob`` do not add up to 1.")
203203

204+
self.source_prob = source_prob
205+
self.target_prob = target_prob
206+
204207
# Estimate the conditional probability of source/target given X.
205208
X = np.concatenate((X_source, X_target), axis=0)
206209
y = np.concatenate((np.zeros(n_source), np.ones(n_target)), axis=0)
@@ -243,9 +246,10 @@ def predict(
243246
log_probs = np.clip(log_probs, a_min=np.log(
244247
self.clip_min), a_max=np.log(self.clip_max))
245248

246-
return np.exp(log_probs[:, 1] - log_probs[:, 0] + np.log(self.source_prob) - np.log(self.target_prob))
249+
return np.exp(log_probs[:, 1] - log_probs[:, 0] +
250+
np.log(self.source_prob) - np.log(self.target_prob))
247251

248-
def check_is_fitted(self):
252+
def check_is_fitted(self) -> None:
249253
if isinstance(self.estimator, Pipeline):
250254
check_is_fitted(self.estimator[-1])
251255
else:
@@ -254,7 +258,7 @@ def check_is_fitted(self):
254258

255259
def calculate_ess(weights: ArrayLike) -> float:
256260
"""
257-
Calculates the effective sample size given importance weights for the
261+
Calculates the effective sample size given importance weights for the
258262
source distribution.
259263
260264
Parameters

0 commit comments

Comments
 (0)