Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mapie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
regression,
utils,
risk_control,
risk_control_draft,
calibration,
subsample,
)
Expand All @@ -13,6 +14,7 @@
"regression",
"classification",
"risk_control",
"risk_control_draft",
"calibration",
"metrics",
"utils",
Expand Down
14 changes: 8 additions & 6 deletions mapie/control_risk/ltt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@


def ltt_procedure(
r_hat: NDArray,
alpha_np: NDArray,
r_hat: NDArray[np.float32],
alpha_np: NDArray[np.float32],
delta: Optional[float],
n_obs: int
) -> Tuple[List[List[Any]], NDArray]:
n_obs: int,
binary: bool = False, # TODO: maybe should pass p_values fonction instead
) -> Tuple[List[List[Any]], NDArray[np.float32]]:
"""
Apply the Learn-Then-Test procedure for risk control.
Note that we will do a multiple test for ``r_hat`` that are
Expand Down Expand Up @@ -63,13 +64,14 @@ def ltt_procedure(
"Invalid delta: delta cannot be None while"
+ " controlling precision with LTT. "
)
p_values = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha_np)
p_values = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha_np, binary)
N = len(p_values)
valid_index = []
for i in range(len(alpha_np)):
l_index = np.where(p_values[:, i] <= delta/N)[0].tolist()
valid_index.append(l_index)
return valid_index, p_values
return valid_index, p_values # TODO : p_values is not used, we could remove it
# Or return corrected p_values


def find_lambda_control_star(
Expand Down
28 changes: 17 additions & 11 deletions mapie/control_risk/p_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@


def compute_hoeffdding_bentkus_p_value(
r_hat: NDArray,
r_hat: NDArray[np.float32],
n_obs: int,
alpha: Union[float, NDArray]
) -> NDArray:
alpha: Union[float, NDArray[np.float32]],
binary: bool = False,
) -> NDArray[np.float32]:
"""
The method computes the p_values according to
the Hoeffding_Bentkus inequality for each
Expand Down Expand Up @@ -63,18 +64,19 @@ def compute_hoeffdding_bentkus_p_value(
)
hoeffding_p_value = np.exp(
-n_obs * _h1(
np.where(
np.where( # TODO : shouldn't we use np.minimum ?
r_hat_repeat > alpha_repeat,
alpha_repeat,
r_hat_repeat
),
alpha_repeat
)
)
bentkus_p_value = np.e * binom.cdf(
factor = 1 if binary else np.e
bentkus_p_value = factor * binom.cdf(
np.ceil(n_obs * r_hat_repeat), n_obs, alpha_repeat
)
hb_p_value = np.where(
hb_p_value = np.where( # TODO : shouldn't we use np.minimum ?
bentkus_p_value > hoeffding_p_value,
hoeffding_p_value,
bentkus_p_value
Expand All @@ -83,9 +85,8 @@ def compute_hoeffdding_bentkus_p_value(


def _h1(
r_hats: NDArray,
alphas: NDArray
) -> NDArray:
r_hats: NDArray[np.float32], alphas: NDArray[np.float32]
) -> NDArray[np.float32]:
"""
This function allow us to compute
the tighter version of hoeffding inequality.
Expand Down Expand Up @@ -114,6 +115,11 @@ def _h1(
-------
NDArray of shape a(n_lambdas, n_alpha).
"""
elt1 = r_hats * np.log(r_hats/alphas)
elt2 = (1-r_hats) * np.log((1-r_hats)/(1-alphas))
elt1 = np.zeros_like(r_hats, dtype=float)

# Compute only where r_hats != 0 to avoid log(0)
# TODO: check Angelopoulos implementation
mask = r_hats != 0
elt1[mask] = r_hats[mask] * np.log(r_hats[mask] / alphas[mask])
elt2 = (1 - r_hats) * np.log((1 - r_hats) / (1 - alphas))
return elt1 + elt2
200 changes: 200 additions & 0 deletions mapie/control_risk/risk_control_theoretical_tests_proto.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2ae91ff6-9706-41f1-bfdb-c39f5f2bfb9d",
"metadata": {
"id": "2ae91ff6-9706-41f1-bfdb-c39f5f2bfb9d"
},
"source": [
"# Binary classification risk control - Theoretical tests prototype"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1c564c4f-1e63-4c2f-bdd5-d84029c1473a",
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f1c2e64a",
"metadata": {
"id": "f1c2e64a"
},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"import numpy as np\n",
"import itertools\n",
"from matplotlib import pyplot as plt\n",
"from sklearn.dummy import DummyClassifier\n",
"\n",
"from mapie.risk_control_draft import BinaryClassificationController"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6c0b5e81-81f1-4688-a4d7-57c6adba44b4",
"metadata": {},
"outputs": [],
"source": [
"class RandomClassifier:\n",
" def __init__(self, seed=42, threshold=0.5):\n",
" self.seed = seed\n",
" self.threshold = threshold\n",
"\n",
" def _get_prob(self, x):\n",
" local_seed = hash((x, self.seed)) % (2**32)\n",
" rng = np.random.RandomState(local_seed)\n",
" return np.round(rng.rand(), 2)\n",
"\n",
" def predict_proba(self, X):\n",
" probs = np.array([self._get_prob(x) for x in X])\n",
" return np.vstack([1 - probs, probs]).T\n",
"\n",
" def predict(self, X):\n",
" probs = self.predict_proba(X)[:, 1]\n",
" return (probs >= self.threshold).astype(int)"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "8da2839a-3f14-4054-acf1-c60b1d02b7d0",
"metadata": {
"id": "8da2839a-3f14-4054-acf1-c60b1d02b7d0"
},
"outputs": [],
"source": [
"N = 100 # size of the calibration set\n",
"p = 0.5 # proportion of positives in the data generator\n",
"metric = \"precision\"\n",
"target_level = 0.6\n",
"predict_params = np.linspace(0, 0.99, 100)\n",
"confidence_level = 0.7\n",
"\n",
"n_repeats = 100"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "03383363-b86d-4593-adf4-80215b6f1dcf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 376
},
"id": "03383363-b86d-4593-adf4-80215b6f1dcf",
"outputId": "b15146cf-518e-4a93-8128-6c1865a08b01"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean number of valid thresholds found per iteration: 1.37\n",
"Proportion of times LTT finds no valid threshold: 0.63\n",
"Proportion of times the risk is not controlled: 0.37\n",
"Risk level: 0.30000000000000004\n",
"Risk not controlled\n"
]
}
],
"source": [
"clf = RandomClassifier()\n",
"nb_errors = 0 # number of iterations where the risk is not controlled (i.e., not all the valid thresholds found by LTT are actually valid)\n",
"no_valid_params = 0 # number of iterations where LTT finds no valid threshold\n",
"nb_valid_params = 0 # total number of valid thresholds LTT finds over all iterations\n",
"\n",
"for _ in range(n_repeats):\n",
"\n",
" X_calibrate, y_calibrate = make_classification(\n",
" n_samples=N,\n",
" n_features=1,\n",
" n_informative=1,\n",
" n_redundant=0,\n",
" n_repeated=0,\n",
" n_classes=2,\n",
" n_clusters_per_class=1,\n",
" weights=[1 - p, p],\n",
" flip_y=0,\n",
" random_state=None\n",
" )\n",
" X_calibrate = X_calibrate.squeeze()\n",
"\n",
" controller = BinaryClassificationController(\n",
" fitted_binary_classifier=clf,\n",
" metric=metric,\n",
" target_level=target_level,\n",
" confidence_level=confidence_level,\n",
" )\n",
" controller._thresholds = predict_params\n",
" controller.calibrate(X_calibrate, y_calibrate)\n",
" valid_parameters = controller.valid_thresholds\n",
"\n",
" nb_valid_params += len(valid_parameters)\n",
"\n",
" if len(valid_parameters) == 0:\n",
" no_valid_params += 1\n",
"\n",
" if metric == \"precision\" or metric == \"accuracy\": # vérifier que l'accuracy ne dépend pas de p\n",
" if target_level > p and len(valid_parameters) >= 1:\n",
" nb_errors += 1\n",
" elif metric == \"recall\":\n",
" if any(x < 0 or x > np.round(1-target_level, 2) for x in valid_parameters) and len(valid_parameters) >= 1:\n",
" nb_errors += 1\n",
"\n",
"print(f\"Mean number of valid thresholds found per iteration: {nb_valid_params/n_repeats}\")\n",
"print(f\"Proportion of times LTT finds no valid threshold: {no_valid_params/n_repeats}\")\n",
"print(f\"Proportion of times the risk is not controlled: {nb_errors/n_repeats}\")\n",
"print(f\"Risk level: {1-confidence_level}\")\n",
"\n",
"if nb_errors/n_repeats <= 1 - confidence_level:\n",
" print(\"Risk controlled\")\n",
"else:\n",
" print(\"Risk not controlled\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "104c7232-c8b1-432e-94dd-3f65e730483f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading
Loading