Conversation

@floriankozikowski (Contributor) commented Jul 23, 2025

Follows up on #321: Add unit tests verifying sklearn prediction compatibility

This PR addresses the request from @mathurinm to add unit tests ensuring that skglm's Poisson and Gamma estimators produce the same predictions as sklearn on simple data.

These tests validate that the prediction fix from #321 (applying exp() transform for log-link datafits) correctly matches sklearn's behavior.
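For context, sklearn's log-link GLMs apply exp() to the linear predictor at predict time, which is the behavior the fix replicates. A quick runnable illustration of that sklearn behavior (not part of the PR itself):

import numpy as np
from sklearn.linear_model import PoissonRegressor

rng = np.random.RandomState(0)
X = rng.randn(20, 5)
y = rng.randint(1, 6, size=20).astype(float)

reg = PoissonRegressor().fit(X, y)
# predict() returns exp(intercept_ + X @ coef_), i.e. the inverse log link
linear_pred = reg.intercept_ + X @ reg.coef_
np.testing.assert_allclose(reg.predict(X), np.exp(linear_pred))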

     indices = scores.argmax(axis=1)
     return self.classes_[indices]
-elif isinstance(self.datafit, (Poisson, PoissonGroup)):
+elif isinstance(self.datafit, (Poisson, PoissonGroup, Gamma)):
Collaborator comment:

if hasattr(self.datafit, "inverse_link_function"):
    return self.datafit.inverse_link_function(self._decision_function(X))

assert isinstance(res, str)


def test_poisson_predictions_match_sklearn():
Collaborator comment:
merge in a single parametrized test test_inverse_link_prediction

@floriankozikowski (Contributor, Author) commented Jul 24, 2025

@Badr-MOUFAD @mathurinm Ready for review!

@mathurinm (Collaborator):

@Badr-MOUFAD this is the last one we need to release 0.5, WDYT?

@Badr-MOUFAD (Collaborator) left a review:

Thanks for the great work guys 🙏

I have two comments:

1. What do you think about making inverse_link the identity by default, to have a consistent API?

   @staticmethod
   def inverse_link(x):
       return x

2. The inverse_link for Logistic is missing; is there a reason for not implementing it?

@floriankozikowski (Contributor, Author):

@Badr-MOUFAD
Good idea making inverse_link the identity by default; I just added it.

Regarding Question 2:
Logistic now inherits the default identity inverse_link from BaseDatafit (see Q1).
The classification logic stays in predict() rather than inverse_link, because I think inverse_link should handle pure mathematical transformations (like log-odds to probabilities). Classification involves class mapping with self.classes_[indices], which I would say is an estimator-level responsibility.
We could implement a proper sigmoid inverse link for Logistic to convert log-odds to probabilities, but since the classification branch never calls it, the default identity is sufficient.
Let me know what you think and I can change it!
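For concreteness, a minimal sketch of the shape being agreed on here (simplified, not the actual skglm source; the argument is named Xw, anticipating the review comments below):

import numpy as np

class BaseDatafit:
    """Base class for datafits (simplified)."""

    @staticmethod
    def inverse_link(Xw):
        # identity by default, so every datafit exposes the same API
        return Xw

class Poisson(BaseDatafit):
    @staticmethod
    def inverse_link(Xw):
        # Poisson uses a log link, so its inverse link is exp
        return np.exp(Xw)

# Logistic inherits the identity default; the class-mapping step
# (classes_[scores.argmax(axis=1)]) stays in the estimator's predict().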

@mathurinm (Collaborator):

Agree with @floriankozikowski, the prediction logic for classification is different

"""Base class for datafits."""

@staticmethod
def inverse_link(x):
Collaborator comment:
just call the argument Xw for clarity

self.grp_ptr, self.grp_indices = grp_ptr, grp_indices

@staticmethod
def inverse_link(x):
Collaborator comment:
same

pass

@staticmethod
def inverse_link(x):
Collaborator comment:
same

pass

@staticmethod
def inverse_link(x):
Collaborator comment:
same

)
def test_inverse_link_prediction(sklearn_reg, skglm_datafit, y_gen):
    np.random.seed(42)
    X = np.random.randn(20, 5)
Collaborator comment:
one last thing: IMO it makes sense to run the test on completely random values of y. They don't have to fit the model well; they could be random integers between 0 and 5. We're not checking statistical validity, we're checking that the optimizer works well and we return the same thing as sklearn. This would make the test simpler.
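A minimal sketch of what that simplified, parametrized test could look like. The near-zero L1 penalty paired with the ProxNewton solver is an assumed way to approximate sklearn's unpenalized fit, not necessarily the code merged in this PR, and y is drawn in 1..5 so the Gamma fit stays well defined:

import numpy as np
import pytest
from sklearn.linear_model import GammaRegressor, PoissonRegressor

from skglm import GeneralizedLinearEstimator
from skglm.datafits import Gamma, Poisson
from skglm.penalties import L1
from skglm.solvers import ProxNewton


@pytest.mark.parametrize("sklearn_cls, datafit_cls", [
    (PoissonRegressor, Poisson),
    (GammaRegressor, Gamma),
])
def test_inverse_link_prediction(sklearn_cls, datafit_cls):
    rng = np.random.RandomState(42)
    X = rng.randn(20, 5)
    # random integer targets, strictly positive so Gamma is well defined;
    # we check optimizer agreement with sklearn, not statistical fit
    y = rng.randint(1, 6, size=20).astype(float)

    sk_reg = sklearn_cls(alpha=1e-12, tol=1e-12, max_iter=10_000).fit(X, y)
    our_reg = GeneralizedLinearEstimator(
        datafit=datafit_cls(),
        penalty=L1(alpha=1e-12),
        solver=ProxNewton(tol=1e-12, max_iter=50),
    ).fit(X, y)

    np.testing.assert_allclose(our_reg.predict(X), sk_reg.predict(X), rtol=1e-4)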

@mathurinm merged commit ba5d9d9 into scikit-learn-contrib:main on Jul 28, 2025 (4 checks passed).
@mathurinm (Collaborator):

Thanks @floriankozikowski!
