Skip to content

Commit aa9880c

Browse files
FIX huber estimator problem (#98)
* fix huber estimator problem * Add regression test and remove unused finite check * black reformat * fix typo * capture warning for regression test
1 parent 48504c8 commit aa9880c

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

sklearn_extra/robust/mean_estimators.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ def psisx(x, c):
121121
res = np.zeros(len(x))
122122
mask = np.abs(x) <= c
123123
res[mask] = 1
124-
res[~mask] = (c / np.abs(x))[~mask]
125-
res[~np.isfinite(x)] = 0
124+
res[~mask] = c / np.abs(x[~mask])
126125
return res
127126

128127
# Run the iterative reweighting algorithm to compute M-estimator.

sklearn_extra/robust/tests/test_mean_estimators.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
from sklearn_extra.robust.mean_estimators import median_of_means, huber
45

@@ -21,3 +22,10 @@ def test_mom():
2122
sample_cor = sample
2223
sample_cor[:num_out] = np.inf
2324
assert np.abs(median_of_means(sample_cor, num_out, rng)) < 2
25+
26+
27+
def test_huber():
28+
X = np.hstack([np.zeros(90), np.ones(10)])
29+
with pytest.warns(None) as record:
30+
huber(X)
31+
assert len(record) == 0

0 commit comments

Comments
 (0)