Skip to content

Commit 296adf0

Browse files
authored
API match sklearn behavior for poisson (#321)
1 parent 7e4802b commit 296adf0

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

skglm/estimators.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
from sklearn.multiclass import OneVsRestClassifier, check_classification_targets
1919

2020
from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD, ProxNewton, LBFGS
21-
from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC,
22-
QuadraticMultiTask, QuadraticGroup,)
21+
from skglm.datafits import (
22+
Cox, Quadratic, Logistic, Poisson, PoissonGroup, QuadraticSVC,
23+
QuadraticMultiTask, QuadraticGroup,)
2324
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
2425
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
2526
from skglm.utils.data import grp_converter
@@ -265,6 +266,8 @@ def predict(self, X):
265266
else:
266267
indices = scores.argmax(axis=1)
267268
return self.classes_[indices]
269+
elif isinstance(self.datafit, (Poisson, PoissonGroup)):
270+
return np.exp(self._decision_function(X))
268271
else:
269272
return self._decision_function(X)
270273

0 commit comments

Comments
 (0)