|
18 | 18 | from sklearn.multiclass import OneVsRestClassifier, check_classification_targets |
19 | 19 |
|
20 | 20 | from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD, ProxNewton, LBFGS |
21 | | -from skglm.datafits import (Cox, Quadratic, Logistic, QuadraticSVC, |
22 | | - QuadraticMultiTask, QuadraticGroup,) |
| 21 | +from skglm.datafits import ( |
| 22 | + Cox, Quadratic, Logistic, Poisson, PoissonGroup, QuadraticSVC, |
| 23 | + QuadraticMultiTask, QuadraticGroup,) |
23 | 24 | from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2, |
24 | 25 | MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1) |
25 | 26 | from skglm.utils.data import grp_converter |
@@ -265,6 +266,8 @@ def predict(self, X): |
265 | 266 | else: |
266 | 267 | indices = scores.argmax(axis=1) |
267 | 268 | return self.classes_[indices] |
| 269 | + elif isinstance(self.datafit, (Poisson, PoissonGroup)): |
| 270 | + return np.exp(self._decision_function(X)) |
268 | 271 | else: |
269 | 272 | return self._decision_function(X) |
270 | 273 |
|
|
0 commit comments