Skip to content

Commit 99b2aad

Browse files
committed
feat: add scikit-learn compatible estimators with pipeline and grid search support
1 parent c4d653f commit 99b2aad

File tree

11 files changed

+202
-48
lines changed

11 files changed

+202
-48
lines changed

README.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
- Paper: [NeurIPS | 2023](https://openreview.net/pdf?id=3pEBW2UPAD)
1010
<!-- - Open Source: [MIT license](https://opensource.org/licenses/MIT) -->
1111

12-
1312
The **ReHLine** solver has four appealing
1413
"linear properties":
1514

@@ -18,6 +17,16 @@ The **ReHLine** solver has four appealing
1817
- The optimization algorithm has a provable linear convergence rate.
1918
- The per-iteration computational complexity is linear in the sample size.
2019

20+
21+
## ✨ New Features: Scikit-Learn Compatible Estimators
22+
23+
We are excited to introduce full scikit-learn compatibility! `ReHLine` now provides `plq_Ridge_Classifier` and `plq_Ridge_Regressor` estimators that integrate seamlessly with the entire scikit-learn ecosystem.
24+
25+
This means you can:
26+
- Drop `ReHLine` estimators directly into your existing scikit-learn `Pipeline`.
27+
- Perform robust hyperparameter tuning using `GridSearchCV`.
28+
- Use standard scikit-learn evaluation metrics and cross-validation tools.
29+
2130
<!--
2231
## 📝 Formulation
2332
@@ -57,7 +66,3 @@ benchmark code and results at the
5766
|[RidgeHuber](https://github.com/softmin/ReHLine-benchmark/tree/main/benchmark_Huber) | [Result](https://rehline-python.readthedocs.io/en/latest/_static/benchmark/benchmark_Huber.html)|
5867
|[SVM](https://github.com/softmin/ReHLine-benchmark/tree/main/benchmark_SVM) | [Result](https://rehline-python.readthedocs.io/en/latest/_static/benchmark/benchmark_SVM.html)|
5968
|[Smoothed SVM](https://github.com/softmin/ReHLine-benchmark/tree/main/benchmark_sSVM) | [Result](https://rehline-python.readthedocs.io/en/latest/_static/benchmark/benchmark_sSVM.html)|
60-
61-
## 🧾 Overview of Results
62-
63-
![](./figs/res.png)

doc/source/autoapi/rehline/index.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -501,12 +501,12 @@ Classes
501501
and ridge penalty, compatible with the scikit-learn API.
502502

503503
This wrapper makes ``plqERM_Ridge`` behave as a classifier:
504-
- Accepts arbitrary binary labels in the original label space.
505-
- Computes class weights on original labels (if ``class_weight`` is set).
506-
- Encodes labels with ``LabelEncoder`` into {0,1}, then maps to {-1,+1} for training.
507-
- Supports optional intercept fitting (via an augmented constant feature).
508-
- Provides standard methods ``fit``, ``predict``, and ``decision_function``.
509-
- Integrates with scikit-learn ecosystem (e.g., GridSearchCV, Pipeline).
504+
- Accepts arbitrary binary labels in the original label space.
505+
- Computes class weights on original labels (if ``class_weight`` is set).
506+
- Encodes labels with ``LabelEncoder`` into {0,1}, then maps to {-1,+1} for training.
507+
- Supports optional intercept fitting (via an augmented constant feature).
508+
- Provides standard methods ``fit``, ``predict``, and ``decision_function``.
509+
- Integrates with scikit-learn ecosystem (e.g., GridSearchCV, Pipeline).
510510

511511
Parameters
512512
----------

doc/source/clean_notebooks.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import json
2+
import sys
3+
from pathlib import Path
4+
5+
def clean_notebook(file_path):
6+
"""Removes the 'id' field from all cells in a Jupyter notebook."""
7+
try:
8+
with open(file_path, 'r', encoding='utf-8') as f:
9+
notebook = json.load(f)
10+
11+
changes_made = False
12+
if 'cells' in notebook and isinstance(notebook['cells'], list):
13+
for cell in notebook['cells']:
14+
if isinstance(cell, dict) and 'id' in cell:
15+
del cell['id']
16+
changes_made = True
17+
18+
if changes_made:
19+
with open(file_path, 'w', encoding='utf-8') as f:
20+
json.dump(notebook, f, indent=1, ensure_ascii=False)
21+
f.write('\n') # Add a newline at the end of the file
22+
print(f"Cleaned: {file_path}")
23+
else:
24+
print(f"No changes needed: {file_path}")
25+
26+
except Exception as e:
27+
print(f"Error processing {file_path}: {e}")
28+
29+
def main():
30+
if len(sys.argv) != 2:
31+
print("Usage: python clean_notebooks.py <directory>")
32+
sys.exit(1)
33+
34+
target_dir = Path(sys.argv[1])
35+
if not target_dir.is_dir():
36+
print(f"Error: {target_dir} is not a valid directory.")
37+
sys.exit(1)
38+
39+
print(f"Searching for notebooks in {target_dir}...")
40+
notebook_files = list(target_dir.rglob('*.ipynb'))
41+
42+
if not notebook_files:
43+
print("No notebook files found.")
44+
return
45+
46+
for notebook_file in notebook_files:
47+
clean_notebook(notebook_file)
48+
49+
if __name__ == "__main__":
50+
main()

doc/source/examples/.ipynb_checkpoints/FairSVM-checkpoint.ipynb

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
"cells": [
33
{
44
"cell_type": "markdown",
5-
"id": "50711dda-105e-4714-937b-e8be06370605",
65
"metadata": {},
76
"source": [
87
"# **FairSVM**\n",
@@ -38,7 +37,6 @@
3837
{
3938
"cell_type": "code",
4039
"execution_count": 1,
41-
"id": "e66268fa-403d-402b-9ea1-fbfe7573af40",
4240
"metadata": {},
4341
"outputs": [],
4442
"source": [
@@ -62,7 +60,6 @@
6260
},
6361
{
6462
"cell_type": "markdown",
65-
"id": "6a576a09-b700-49cd-b500-219f3a6e40b0",
6663
"metadata": {},
6764
"source": [
6865
"## SVM as baseline"
@@ -71,7 +68,6 @@
7168
{
7269
"cell_type": "code",
7370
"execution_count": 2,
74-
"id": "15531796-3a45-42b3-8a99-da0343be9d4d",
7571
"metadata": {},
7672
"outputs": [],
7773
"source": [
@@ -84,7 +80,6 @@
8480
},
8581
{
8682
"cell_type": "markdown",
87-
"id": "79bb275b-2dfd-4608-83e3-b4b4eb0fdb72",
8883
"metadata": {},
8984
"source": [
9085
"## FairSVM"
@@ -93,7 +88,6 @@
9388
{
9489
"cell_type": "code",
9590
"execution_count": 3,
96-
"id": "c43509f7-031b-4620-bc5e-fb5aea2ef1c2",
9791
"metadata": {},
9892
"outputs": [],
9993
"source": [
@@ -111,7 +105,6 @@
111105
},
112106
{
113107
"cell_type": "markdown",
114-
"id": "794ede1f-13a4-4889-b6d9-f19a61faa510",
115108
"metadata": {},
116109
"source": [
117110
"## Results"
@@ -120,7 +113,6 @@
120113
{
121114
"cell_type": "code",
122115
"execution_count": 4,
123-
"id": "05dc3921-1837-474e-9a6d-4555a94ddc30",
124116
"metadata": {},
125117
"outputs": [
126118
{
@@ -159,7 +151,6 @@
159151
{
160152
"cell_type": "code",
161153
"execution_count": 5,
162-
"id": "ad5a863e-fbbb-4caf-876d-374f3ca9b891",
163154
"metadata": {},
164155
"outputs": [
165156
{

doc/source/examples/CQR.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"cell_type": "markdown",
55
"metadata": {},
66
"source": [
7-
"## **Ridge Composite Quantile Regression**\n",
7+
"## Ridge Composite Quantile Regression\n",
88
"\n",
99
"[![Slides](https://img.shields.io/badge/🦌-ReHLine-blueviolet)](https://rehline-python.readthedocs.io/en/latest/)\n",
1010
"\n",

doc/source/examples/FairSVM.ipynb

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
"cells": [
33
{
44
"cell_type": "markdown",
5-
"id": "50711dda-105e-4714-937b-e8be06370605",
65
"metadata": {},
76
"source": [
8-
"# **FairSVM**\n",
7+
"# FairSVM\n",
98
"\n",
109
"[![Slides](https://img.shields.io/badge/🦌-ReHLine-blueviolet)](https://rehline-python.readthedocs.io/en/latest/)\n",
1110
"\n",
@@ -36,14 +35,12 @@
3635
},
3736
{
3837
"cell_type": "markdown",
39-
"id": "7bf9272115591bbf",
4038
"metadata": {},
4139
"source": []
4240
},
4341
{
4442
"cell_type": "code",
4543
"execution_count": 1,
46-
"id": "e66268fa-403d-402b-9ea1-fbfe7573af40",
4744
"metadata": {},
4845
"outputs": [],
4946
"source": [
@@ -67,7 +64,6 @@
6764
},
6865
{
6966
"cell_type": "markdown",
70-
"id": "6a576a09-b700-49cd-b500-219f3a6e40b0",
7167
"metadata": {},
7268
"source": [
7369
"## SVM as baseline"
@@ -76,7 +72,6 @@
7672
{
7773
"cell_type": "code",
7874
"execution_count": 2,
79-
"id": "15531796-3a45-42b3-8a99-da0343be9d4d",
8075
"metadata": {},
8176
"outputs": [],
8277
"source": [
@@ -89,7 +84,6 @@
8984
},
9085
{
9186
"cell_type": "markdown",
92-
"id": "79bb275b-2dfd-4608-83e3-b4b4eb0fdb72",
9387
"metadata": {},
9488
"source": [
9589
"## FairSVM"
@@ -98,7 +92,6 @@
9892
{
9993
"cell_type": "code",
10094
"execution_count": 3,
101-
"id": "c43509f7-031b-4620-bc5e-fb5aea2ef1c2",
10295
"metadata": {},
10396
"outputs": [],
10497
"source": [
@@ -116,7 +109,6 @@
116109
},
117110
{
118111
"cell_type": "markdown",
119-
"id": "794ede1f-13a4-4889-b6d9-f19a61faa510",
120112
"metadata": {},
121113
"source": [
122114
"## Results"
@@ -125,7 +117,6 @@
125117
{
126118
"cell_type": "code",
127119
"execution_count": 4,
128-
"id": "05dc3921-1837-474e-9a6d-4555a94ddc30",
129120
"metadata": {},
130121
"outputs": [
131122
{
@@ -164,7 +155,6 @@
164155
{
165156
"cell_type": "code",
166157
"execution_count": 5,
167-
"id": "ad5a863e-fbbb-4caf-876d-374f3ca9b891",
168158
"metadata": {},
169159
"outputs": [
170160
{

doc/source/examples/QR.ipynb

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
"cells": [
33
{
44
"cell_type": "markdown",
5-
"id": "e3a11293-4739-476e-a513-48a256d425a2",
65
"metadata": {},
76
"source": [
8-
"## **Ridge Quantile Regression**\n",
7+
"## Ridge Quantile Regression\n",
98
"\n",
109
"[![Slides](https://img.shields.io/badge/🦌-ReHLine-blueviolet)](https://rehline-python.readthedocs.io/en/latest/)\n",
1110
"\n",
@@ -24,7 +23,6 @@
2423
{
2524
"cell_type": "code",
2625
"execution_count": 1,
27-
"id": "b2dd4ce5-bc27-41a4-89ab-7920d393f377",
2826
"metadata": {},
2927
"outputs": [],
3028
"source": [
@@ -46,7 +44,6 @@
4644
{
4745
"cell_type": "code",
4846
"execution_count": 2,
49-
"id": "80129ee6-f886-4e27-a764-630f15826bca",
5047
"metadata": {},
5148
"outputs": [],
5249
"source": [
@@ -63,7 +60,6 @@
6360
{
6461
"cell_type": "code",
6562
"execution_count": 3,
66-
"id": "1d8b90e9-6af9-4856-9751-6fe6fbc7665c",
6763
"metadata": {},
6864
"outputs": [
6965
{
@@ -98,7 +94,7 @@
9894
]
9995
}
10096
],
101-
"metadata": {
97+
"metadata": {
10298
"colab": {
10399
"provenance": []
104100
},
@@ -112,4 +108,4 @@
112108
},
113109
"nbformat": 4,
114110
"nbformat_minor": 0
115-
}
111+
}

doc/source/examples/SVM.ipynb

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
"cells": [
33
{
44
"cell_type": "markdown",
5-
"id": "fbcb401d-6ca6-4933-abd5-f8f504282416",
65
"metadata": {},
76
"source": [
8-
"# **SVM**\n",
7+
"# SVM\n",
98
"\n",
109
"[![Slides](https://img.shields.io/badge/🦌-ReHLine-blueviolet)](https://rehline-python.readthedocs.io/en/latest/)\n",
1110
"\n",
@@ -21,7 +20,6 @@
2120
{
2221
"cell_type": "code",
2322
"execution_count": 1,
24-
"id": "2dd1c096-e0df-492f-be63-8ac272007237",
2523
"metadata": {},
2624
"outputs": [],
2725
"source": [
@@ -42,7 +40,6 @@
4240
{
4341
"cell_type": "code",
4442
"execution_count": 2,
45-
"id": "aece9fbe-f9be-40ae-8179-b44849fb0fd3",
4643
"metadata": {},
4744
"outputs": [],
4845
"source": [
@@ -56,7 +53,6 @@
5653
{
5754
"cell_type": "code",
5855
"execution_count": 3,
59-
"id": "93719987-c6b3-4a9b-9b40-c35e5bf90ef0",
6056
"metadata": {},
6157
"outputs": [
6258
{

doc/source/index.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ The proposed **ReHLine** solver has appealing exhibits appealing properties:
4848
* - **Super-Efficient**
4949
- The optimization algorithm has a provable **LINEAR** convergence rate, and the per-iteration computational complexity is **LINEAR** in the sample size.
5050

51+
✨ New Features: Scikit-Learn Compatible Estimators
52+
---------------------------------------------------
53+
54+
We are excited to introduce full scikit-learn compatibility! `ReHLine` now provides `plq_Ridge_Classifier` and `plq_Ridge_Regressor` estimators that integrate seamlessly with the entire scikit-learn ecosystem.
55+
56+
This means you can:
57+
- Drop `ReHLine` estimators directly into your existing scikit-learn `Pipeline`.
58+
- Perform robust hyperparameter tuning using `GridSearchCV`.
59+
- Use standard scikit-learn evaluation metrics and cross-validation tools.
60+
5161
🔨 Installation
5262
---------------
5363

doc/source/tutorials.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,16 @@ List of Tutorials
4141
- | `plqERM_Ridge <./autoapi/rehline/index.html#rehline.plqERM_Ridge>`_
4242
- | Empirical Risk Minimization (ERM) with a piecewise linear-quadratic (PLQ) objective with a ridge penalty.
4343

44-
* - `ReHLine: Scikit-learn Compatible Estimators Powered by ReHLine <./examples/Sklearn_Mixin.ipynb>`_
45-
- | `plqERM_Ridge <./autoapi/rehline/index.html#rehline.plq_Ridge_Classifier>`_
46-
- | `plqERM_Ridge <./autoapi/rehline/index.html#rehline.plq_Ridge_Regressor>`_
44+
* - `ReHLine: Scikit-learn Compatible Estimators <./tutorials/ReHLine_sklearn.rst>`_
45+
- | `plq_Ridge_Classifier <./autoapi/rehline/index.html#rehline.plq_Ridge_Classifier>`_ `plq_Ridge_Regressor <./autoapi/rehline/index.html#rehline.plq_Ridge_Regressor>`_
4746
- | Scikit-learn compatible estimators framework for empirical risk minimization problem.
4847

4948
* - `ReHLine: Ridge Composite Quantile Regression <./examples/CQR.ipynb>`_
5049
- | `CQR_Ridge <./autoapi/rehline/index.html#rehline.CQR_Ridge>`_
5150
- | Composite Quantile Regression (CQR) with a ridge penalty.
5251

5352
* - `ReHLine: Matrix Factorization <./tutorials/ReHLine_MF.rst>`_
54-
- | `plqMF_Ridge <./autoapi/rehline/index.html#rehline.plqERM_Ridge>`_
53+
- | `plqMF_Ridge <./autoapi/rehline/index.html#rehline.plqMF_Ridge>`_
5554
- | Matrix Factorization (MF) with a piecewise linear-quadratic (PLQ) objective with a ridge penalty.
5655

5756
.. toctree::
@@ -62,5 +61,6 @@ List of Tutorials
6261
./tutorials/ReHLine_ERM
6362
./tutorials/loss
6463
./tutorials/constraint
64+
./tutorials/ReHLine_sklearn
6565
./tutorials/warmstart
6666

0 commit comments

Comments
 (0)