diff --git a/mapie/__init__.py b/mapie/__init__.py
index 5ec9939e1..35fd5b090 100644
--- a/mapie/__init__.py
+++ b/mapie/__init__.py
@@ -4,6 +4,7 @@
     regression,
     utils,
     risk_control,
+    risk_control_draft,
     calibration,
     subsample,
 )
@@ -13,6 +14,7 @@
     "regression",
     "classification",
     "risk_control",
+    "risk_control_draft",
     "calibration",
     "metrics",
     "utils",
diff --git a/mapie/control_risk/ltt.py b/mapie/control_risk/ltt.py
index 9b7d7e124..e19d3b849 100644
--- a/mapie/control_risk/ltt.py
+++ b/mapie/control_risk/ltt.py
@@ -9,11 +9,12 @@
 
 
 def ltt_procedure(
-    r_hat: NDArray,
-    alpha_np: NDArray,
+    r_hat: NDArray[np.float32],
+    alpha_np: NDArray[np.float32],
     delta: Optional[float],
-    n_obs: int
-) -> Tuple[List[List[Any]], NDArray]:
+    n_obs: int,
+    binary: bool = False,  # TODO: maybe we should pass a p_value function instead
+) -> Tuple[List[List[Any]], NDArray[np.float32]]:
     """
     Apply the Learn-Then-Test procedure for risk control.
     Note that we will do a multiple test for ``r_hat`` that are
@@ -63,13 +64,14 @@
             "Invalid delta: delta cannot be None while" +
             " controlling precision with LTT. "
         )
-    p_values = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha_np)
+    p_values = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha_np, binary)
     N = len(p_values)
     valid_index = []
     for i in range(len(alpha_np)):
         l_index = np.where(p_values[:, i] <= delta/N)[0].tolist()
         valid_index.append(l_index)
-    return valid_index, p_values
+    return valid_index, p_values  # TODO: p_values is not used, we could remove it
+    # or return corrected p_values
 
 
 def find_lambda_control_star(
diff --git a/mapie/control_risk/p_values.py b/mapie/control_risk/p_values.py
index 2305c126f..d1a420a4c 100644
--- a/mapie/control_risk/p_values.py
+++ b/mapie/control_risk/p_values.py
@@ -8,10 +8,11 @@
 
 
 def compute_hoeffdding_bentkus_p_value(
-    r_hat: NDArray,
+    r_hat: NDArray[np.float32],
     n_obs: int,
-    alpha: Union[float, NDArray]
-) -> NDArray:
+    alpha: Union[float, NDArray[np.float32]],
+    binary: bool = False,
+) -> NDArray[np.float32]:
     """
     The method computes the p_values according to
     the Hoeffding_Bentkus inequality for each
@@ -63,7 +64,7 @@
     )
     hoeffding_p_value = np.exp(
         -n_obs * _h1(
-            np.where(
+            np.where(  # TODO: shouldn't we use np.minimum?
                 r_hat_repeat > alpha_repeat,
                 alpha_repeat,
                 r_hat_repeat
@@ -71,10 +72,11 @@
             alpha_repeat
         )
     )
-    bentkus_p_value = np.e * binom.cdf(
+    factor = 1 if binary else np.e
+    bentkus_p_value = factor * binom.cdf(
         np.ceil(n_obs * r_hat_repeat), n_obs, alpha_repeat
     )
-    hb_p_value = np.where(
+    hb_p_value = np.where(  # TODO: shouldn't we use np.minimum?
         bentkus_p_value > hoeffding_p_value,
         hoeffding_p_value,
         bentkus_p_value
     )
@@ -83,9 +85,8 @@
 def _h1(
-    r_hats: NDArray,
-    alphas: NDArray
-) -> NDArray:
+    r_hats: NDArray[np.float32], alphas: NDArray[np.float32]
+) -> NDArray[np.float32]:
     """
     This function allow us to compute
     the tighter version of hoeffding inequality.
@@ -114,6 +115,11 @@
     -------
     NDArray of shape a(n_lambdas, n_alpha).
""" - elt1 = r_hats * np.log(r_hats/alphas) - elt2 = (1-r_hats) * np.log((1-r_hats)/(1-alphas)) + elt1 = np.zeros_like(r_hats, dtype=float) + + # Compute only where r_hats != 0 to avoid log(0) + # TODO: check Angelopoulos implementation + mask = r_hats != 0 + elt1[mask] = r_hats[mask] * np.log(r_hats[mask] / alphas[mask]) + elt2 = (1 - r_hats) * np.log((1 - r_hats) / (1 - alphas)) return elt1 + elt2 diff --git a/mapie/control_risk/risk_control_theoretical_tests_proto.ipynb b/mapie/control_risk/risk_control_theoretical_tests_proto.ipynb new file mode 100644 index 000000000..3bf9e9d47 --- /dev/null +++ b/mapie/control_risk/risk_control_theoretical_tests_proto.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2ae91ff6-9706-41f1-bfdb-c39f5f2bfb9d", + "metadata": { + "id": "2ae91ff6-9706-41f1-bfdb-c39f5f2bfb9d" + }, + "source": [ + "# Binary classification risk control - Theoretical tests prototype" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1c564c4f-1e63-4c2f-bdd5-d84029c1473a", + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f1c2e64a", + "metadata": { + "id": "f1c2e64a" + }, + "outputs": [], + "source": [ + "from sklearn.datasets import make_classification\n", + "import numpy as np\n", + "import itertools\n", + "from matplotlib import pyplot as plt\n", + "from sklearn.dummy import DummyClassifier\n", + "\n", + "from mapie.risk_control_draft import BinaryClassificationController" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6c0b5e81-81f1-4688-a4d7-57c6adba44b4", + "metadata": {}, + "outputs": [], + "source": [ + "class RandomClassifier:\n", + " def __init__(self, seed=42, threshold=0.5):\n", + " self.seed = seed\n", + " self.threshold = threshold\n", + "\n", + " def _get_prob(self, x):\n", + " local_seed = hash((x, self.seed)) % (2**32)\n", + " rng = np.random.RandomState(local_seed)\n", + " return np.round(rng.rand(), 2)\n", + "\n", + " def predict_proba(self, X):\n", + " probs = np.array([self._get_prob(x) for x in X])\n", + " return np.vstack([1 - probs, probs]).T\n", + "\n", + " def predict(self, X):\n", + " probs = self.predict_proba(X)[:, 1]\n", + " return (probs >= self.threshold).astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "8da2839a-3f14-4054-acf1-c60b1d02b7d0", + "metadata": { + "id": "8da2839a-3f14-4054-acf1-c60b1d02b7d0" + }, + "outputs": [], + "source": [ + "N = 100 # size of the calibration set\n", + "p = 0.5 # proportion of positives in the data generator\n", + "metric = \"precision\"\n", + "target_level = 0.6\n", + "predict_params = np.linspace(0, 0.99, 100)\n", + "confidence_level = 0.7\n", + "\n", + "n_repeats = 100" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "03383363-b86d-4593-adf4-80215b6f1dcf", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 376 + }, + "id": "03383363-b86d-4593-adf4-80215b6f1dcf", + "outputId": "b15146cf-518e-4a93-8128-6c1865a08b01" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean number of valid thresholds found per iteration: 1.37\n", + "Proportion of times LTT finds no valid threshold: 0.63\n", + "Proportion of times the risk is not controlled: 0.37\n", + "Risk level: 0.30000000000000004\n", + "Risk not controlled\n" + ] + } + ], + "source": [ + "clf = RandomClassifier()\n", + "nb_errors = 0 # number of iterations where 
the risk is not controlled (i.e., not all the valid thresholds found by LTT are actually valid)\n", + "no_valid_params = 0 # number of iterations where LTT finds no valid threshold\n", + "nb_valid_params = 0 # total number of valid thresholds LTT finds over all iterations\n", + "\n", + "for _ in range(n_repeats):\n", + "\n", + " X_calibrate, y_calibrate = make_classification(\n", + " n_samples=N,\n", + " n_features=1,\n", + " n_informative=1,\n", + " n_redundant=0,\n", + " n_repeated=0,\n", + " n_classes=2,\n", + " n_clusters_per_class=1,\n", + " weights=[1 - p, p],\n", + " flip_y=0,\n", + " random_state=None\n", + " )\n", + " X_calibrate = X_calibrate.squeeze()\n", + "\n", + " controller = BinaryClassificationController(\n", + " fitted_binary_classifier=clf,\n", + " metric=metric,\n", + " target_level=target_level,\n", + " confidence_level=confidence_level,\n", + " )\n", + " controller._thresholds = predict_params\n", + " controller.calibrate(X_calibrate, y_calibrate)\n", + " valid_parameters = controller.valid_thresholds\n", + "\n", + " nb_valid_params += len(valid_parameters)\n", + "\n", + " if len(valid_parameters) == 0:\n", + " no_valid_params += 1\n", + "\n", + " if metric == \"precision\" or metric == \"accuracy\": # vérifier que l'accuracy ne dépend pas de p\n", + " if target_level > p and len(valid_parameters) >= 1:\n", + " nb_errors += 1\n", + " elif metric == \"recall\":\n", + " if any(x < 0 or x > np.round(1-target_level, 2) for x in valid_parameters) and len(valid_parameters) >= 1:\n", + " nb_errors += 1\n", + "\n", + "print(f\"Mean number of valid thresholds found per iteration: {nb_valid_params/n_repeats}\")\n", + "print(f\"Proportion of times LTT finds no valid threshold: {no_valid_params/n_repeats}\")\n", + "print(f\"Proportion of times the risk is not controlled: {nb_errors/n_repeats}\")\n", + "print(f\"Risk level: {1-confidence_level}\")\n", + "\n", + "if nb_errors/n_repeats <= 1 - confidence_level:\n", + " print(\"Risk controlled\")\n", + "else:\n", + " print(\"Risk not controlled\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "104c7232-c8b1-432e-94dd-3f65e730483f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.17" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mapie/control_risk/risk_control_theoretical_tests_target_api.ipynb b/mapie/control_risk/risk_control_theoretical_tests_target_api.ipynb new file mode 100644 index 000000000..89836e3cb --- /dev/null +++ b/mapie/control_risk/risk_control_theoretical_tests_target_api.ipynb @@ -0,0 +1,255 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2ae91ff6-9706-41f1-bfdb-c39f5f2bfb9d", + "metadata": {}, + "source": [ + "# Binary classification risk control - Theoretical tests - Target API" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "f1c2e64a", + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'mapie'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + 
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[41]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mitertools\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmapie\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbinary_risk_control_target_api\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m BinaryClassificationRisk, BinaryClassificationController\n", + "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'mapie'" + ] + } + ], + "source": [ + "import numpy as np\n", + "import itertools\n", + "\n", + "from mapie.binary_risk_control_target_api import BinaryClassificationRisk, BinaryClassificationController" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "1fef2dc6-b5b1-43bc-ad05-e5e1fe7844bd", + "metadata": {}, + "outputs": [], + "source": [ + "class RandomClassifier:\n", + " def __init__(self, seed=42, threshold=0.5):\n", + " self.random_state = np.random.RandomState(seed)\n", + " self.threshold = threshold\n", + "\n", + " def predict_proba(self, X):\n", + " probs = np.round(self.random_state.rand(len(X)), 2)\n", + " return np.vstack([1 - probs, probs]).T\n", + "\n", + " def predict(self, X):\n", + " probs = self.predict_proba(X)[:, 1]\n", + " return (probs >= self.threshold).astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "837bbfd2-0e30-4b27-99e6-71a6ee5c21d7", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'BinaryClassificationRisk' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[43]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m precision = \u001b[43mBinaryClassificationRisk\u001b[49m(\n\u001b[32m 2\u001b[39m occurrence=\u001b[38;5;28;01mlambda\u001b[39;00m y_true, y_pred: \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m y_pred == \u001b[32m0\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mint\u001b[39m(y_pred == y_true),\n\u001b[32m 3\u001b[39m higher_is_better=\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[32m 4\u001b[39m binary=\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[32m 5\u001b[39m )\n\u001b[32m 7\u001b[39m false_discovery_rate = precision.transform_to_opposite()\n\u001b[32m 9\u001b[39m recall = BinaryClassificationRisk(\n\u001b[32m 10\u001b[39m occurrence=\u001b[38;5;28;01mlambda\u001b[39;00m y_true, y_pred: \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m y_true == \u001b[32m0\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mint\u001b[39m(y_pred == y_true),\n\u001b[32m 11\u001b[39m higher_is_better=\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[32m 12\u001b[39m binary=\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[32m 13\u001b[39m )\n", + "\u001b[31mNameError\u001b[39m: name 'BinaryClassificationRisk' is not defined" + ] + 
+    }
+   ],
+   "source": [
+    "precision = BinaryClassificationRisk(\n",
+    "    occurrence=lambda y_true, y_pred: None if y_pred == 0 else int(y_pred == y_true),\n",
+    "    higher_is_better=True,\n",
+    "    binary=True,\n",
+    ")\n",
+    "\n",
+    "false_discovery_rate = precision.transform_to_opposite()\n",
+    "\n",
+    "recall = BinaryClassificationRisk(\n",
+    "    occurrence=lambda y_true, y_pred: None if y_true == 0 else int(y_pred == y_true),\n",
+    "    higher_is_better=True,\n",
+    "    binary=True,\n",
+    ")\n",
+    "\n",
+    "false_negative_rate = recall.transform_to_opposite()\n",
+    "\n",
+    "accuracy = BinaryClassificationRisk(\n",
+    "    occurrence=lambda y_true, y_pred: int(y_pred == y_true),\n",
+    "    higher_is_better=True,\n",
+    "    binary=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "id": "8da2839a-3f14-4054-acf1-c60b1d02b7d0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "N_values = [1, 100]  # size of the calibration set\n",
+    "p = 0.5  # proportion of positives in the calibration set\n",
+    "metrics = ['recall', 'precision']\n",
+    "target_levels = [0.2, 0.8]\n",
+    "predict_params_sets = [np.linspace(0, 0.99, 100), [0.5]]\n",
+    "confidence_levels = [0.1, 0.9]\n",
+    "\n",
+    "n_repeats = 100"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "id": "b415f516-7782-4d76-b304-3d630af365fc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Combination 1:\n",
+      "N = 1\n",
+      "Metric = recall\n",
+      "Target level = 0.2\n",
+      "Predict params = [0. 0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 0.1 0.11 0.12 0.13\n",
+      " 0.14 0.15 0.16 0.17 0.18 0.19 0.2 0.21 0.22 0.23 0.24 0.25 0.26 0.27\n",
+      " 0.28 0.29 0.3 0.31 0.32 0.33 0.34 0.35 0.36 0.37 0.38 0.39 0.4 0.41\n",
+      " 0.42 0.43 0.44 0.45 0.46 0.47 0.48 0.49 0.5 0.51 0.52 0.53 0.54 0.55\n",
+      " 0.56 0.57 0.58 0.59 0.6 0.61 0.62 0.63 0.64 0.65 0.66 0.67 0.68 0.69\n",
+      " 0.7 0.71 0.72 0.73 0.74 0.75 0.76 0.77 0.78 0.79 0.8 0.81 0.82 0.83\n",
+      " 0.84 0.85 0.86 0.87 0.88 0.89 0.9 0.91 0.92 0.93 0.94 0.95 0.96 0.97\n",
+      " 0.98 0.99]\n",
+      "Confidence Level = 0.1\n"
+     ]
+    }
+   ],
+   "source": [
+    "combinations = list(itertools.product(N_values, metrics, target_levels, predict_params_sets, confidence_levels))\n",
+    "\n",
+    "# for i, combination in enumerate(combinations[0], 1):\n",
+    "i, combination = 1, combinations[0]\n",
+    "\n",
+    "N, metric, target_level, predict_params, confidence_level = combination\n",
+    "print(f\"Combination {i}:\")\n",
+    "print(f\"N = {N}\")\n",
+    "print(f\"Metric = {metric}\")\n",
+    "print(f\"Target level = {target_level}\")\n",
+    "print(f\"Predict params = {predict_params}\")\n",
+    "print(f\"Confidence Level = {confidence_level}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "id": "649f5ef0-3c5e-410a-949e-e7aa3142d5fc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "X_calibrate = list(range(1, N+1))\n",
+    "y_calibrate = [1] * int(p*N) + [0] * (N - int(p*N))\n",
+    "np.random.seed(42)\n",
+    "np.random.shuffle(y_calibrate)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "id": "03383363-b86d-4593-adf4-80215b6f1dcf",
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'recall' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+      "\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
+      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[47]\u001b[39m\u001b[32m, line 7\u001b[39m\n\u001b[32m 5\u001b[39m theoretical_value = p\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m metric == \u001b[33m'\u001b[39m\u001b[33mrecall\u001b[39m\u001b[33m'\u001b[39m:\n\u001b[32m----> \u001b[39m\u001b[32m7\u001b[39m risk = \u001b[43mrecall\u001b[49m\n\u001b[32m 8\u001b[39m theoretical_value = \u001b[32m1\u001b[39m - clf.threshold\n\u001b[32m 10\u001b[39m all_valid_parameters = []\n",
+      "\u001b[31mNameError\u001b[39m: name 'recall' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "clf = RandomClassifier()\n",
+    "\n",
+    "if metric == 'precision':\n",
+    "    risk = precision\n",
+    "    theoretical_value = p\n",
+    "elif metric == 'recall':\n",
+    "    risk = recall\n",
+    "    theoretical_value = 1 - clf.threshold\n",
+    "\n",
+    "all_valid_parameters = []\n",
+    "\n",
+    "for _ in range(n_repeats):\n",
+    "\n",
+    "    controller = BinaryClassificationController(\n",
+    "        predict_function=clf.predict_proba,\n",
+    "        risk=risk,\n",
+    "        target_level=target_level,\n",
+    "        confidence_level=confidence_level,\n",
+    "        best_predict_param_choice=\"auto\",\n",
+    "    )\n",
+    "    controller.calibrate(X_calibrate, y_calibrate)\n",
+    "\n",
+    "    valid_parameters = controller.valid_thresholds\n",
+    "    all_valid_parameters.append(valid_parameters)\n",
+    "\n",
+    "if metric == 'precision':\n",
+    "    nb_actual_valid = sum(1 for x in all_valid_parameters if p >= theoretical_value)\n",
+    "elif metric == 'recall':\n",
+    "    nb_actual_valid = sum(1 for x in all_valid_parameters if x <= (1 - theoretical_value))\n",
+    "\n",
+    "if nb_actual_valid/len(all_valid_parameters) >= confidence_level:\n",
+    "    print(\"Risk controlled\")\n",
+    "else:\n",
+    "    print(\"Risk not controlled\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "42f85519-17ae-46c7-bfdc-e26c1d87017a",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.17"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/mapie/risk_control_draft.py b/mapie/risk_control_draft.py
new file mode 100644
index 000000000..334f157df
--- /dev/null
+++ b/mapie/risk_control_draft.py
@@ -0,0 +1,243 @@
+import warnings
+from typing import Any, Optional, Union
+
+import numpy as np
+from numpy.typing import ArrayLike, NDArray
+from sklearn.utils import check_random_state
+
+from mapie.control_risk.ltt import ltt_procedure
+from mapie.utils import _check_n_jobs, _check_verbose
+
+# General TODOs:
+# TODO: maybe use type float instead of float32?
+# TODO: in calibration and prediction,
+# use _transform_pred_proba or a function adapted to binary
+# to get the probabilities depending on the classifier
+
+
+class BinaryClassificationController:  # pragma: no cover
+    # TODO: test that this is working with a sklearn pipeline
+    # TODO: test that this is working with pandas dataframes
+    """
+    Controller that calibrates the decision threshold of a binary classifier.
+
+    Parameters
+    ----------
+    fitted_binary_classifier: Any
+        Any object that provides a `predict_proba` method.
+
+    metric: str
+        The performance metric we want to control (ex: "precision").
+
+    target_level: float
+        The target performance level we want to achieve (ex: 0.8).
+
+    confidence_level: float
+        The probability with which the target level must be achieved
+        (ex: 0.9). Equivalently, 1 - confidence_level is the maximum
+        acceptable probability of the controlled metric falling below
+        the target level.
+
+    Attributes
+    ----------
+    precision_per_threshold: NDArray
+        Precision of the binary classifier on the calibration set for each
+        threshold from self._thresholds.
+
+    valid_thresholds: NDArray
+        Thresholds that meet the target precision with the desired confidence.
+
+    best_threshold: float
+        Valid threshold that maximizes the recall, i.e. the smallest valid
+        threshold.
+    """
+
+    def __init__(
+        self,
+        fitted_binary_classifier: Any,
+        metric: str,
+        target_level: float,
+        confidence_level: float = 0.9,
+        n_jobs: Optional[int] = None,
+        random_state: Optional[Union[int, np.random.RandomState]] = None,
+        verbose: int = 0
+    ):
+        _check_n_jobs(n_jobs)
+        _check_verbose(verbose)
+        check_random_state(random_state)
+
+        self._classifier = fitted_binary_classifier
+        self._alpha = 1 - target_level
+        self._delta = 1 - confidence_level
+        self._n_jobs = n_jobs  # TODO: use this in the class or delete
+        self._random_state = random_state  # TODO: use this in the class or delete
+        self._verbose = verbose  # TODO: use this in the class or delete
+        self._metric = metric
+
+        self._thresholds: NDArray[np.float32] = np.arange(0, 1, 0.01)
+        # TODO: add a _is_calibrated attribute to check at prediction time
+
+        self.valid_thresholds: Optional[NDArray[np.float32]] = None
+        self.best_threshold: Optional[float] = None
+
+    def calibrate(self, X_calibrate: ArrayLike, y_calibrate: ArrayLike) -> None:
+        """
+        Find the threshold that statistically guarantees the desired precision
+        level while maximizing the recall.
+
+        Parameters
+        ----------
+        X_calibrate: ArrayLike
+            Features of the calibration set.
+
+        y_calibrate: ArrayLike
+            True labels of the calibration set.
+
+        Warns
+        -----
+        UserWarning
+            If no thresholds that meet the target precision with the desired
+            confidence level are found.
+        """
+        # TODO: Make sure this works with sklearn train_test_split/Series
+        y_calibrate_ = np.asarray(y_calibrate)
+
+        predictions_proba = self._classifier.predict_proba(X_calibrate)[:, 1]
+
+        if self._metric == "precision":
+            risk_per_threshold = 1 - self._compute_precision(
+                predictions_proba, y_calibrate_
+            )
+            valid_thresholds_index, _ = ltt_procedure(
+                risk_per_threshold,
+                np.array([self._alpha]),
+                self._delta,
+                int(len(y_calibrate_)/2),
+                True,
+            )
+
+        elif self._metric == "recall":
+            risk_per_threshold = 1 - self._compute_recall(
+                predictions_proba, y_calibrate_
+            )
+            valid_thresholds_index, _ = ltt_procedure(
+                risk_per_threshold,
+                np.array([self._alpha]),
+                self._delta,
+                int(len(y_calibrate_)/2),
+                True,
+            )
+
+        elif self._metric == "accuracy":
+            risk_per_threshold = 1 - self._compute_accuracy(
+                predictions_proba, y_calibrate_
+            )
+            valid_thresholds_index, _ = ltt_procedure(
+                risk_per_threshold,
+                np.array([self._alpha]),
+                self._delta,
+                len(y_calibrate_),
+                True,
+            )
+
+        self.valid_thresholds = self._thresholds[valid_thresholds_index[0]]
+        if len(self.valid_thresholds) == 0:
+            warnings.warn("No valid thresholds found", UserWarning)
+
+        else:
+            # Minimum in case of precision control only
+            self.best_threshold = min(self.valid_thresholds)
+
+    def predict(self, X_test: ArrayLike) -> NDArray:
+        """
+        Predict binary labels on the test set, using the best threshold found
+        during calibration.
+
+        Parameters
+        ----------
+        X_test: ArrayLike
+            Features of the test set.
+
+        Returns
+        -------
+        NDArray
+            Predicted labels (0 or 1) for each sample in the test set.
+        """
+        predictions_proba = self._classifier.predict_proba(X_test)[:, 1]
+        return (predictions_proba >= self.best_threshold).astype(int)
+
+    def _compute_precision(  # TODO: use sklearn or MAPIE?
+        self, predictions_proba: NDArray[np.float32], y_cal: NDArray[np.float32]
+    ) -> NDArray[np.float32]:
+        """
+        Compute the precision for each threshold.
+        """
+        predictions_per_threshold = (
+            predictions_proba[:, np.newaxis] >= self._thresholds
+        ).astype(int)
+
+        true_positives = np.sum(
+            (predictions_per_threshold == 1) & (y_cal[:, np.newaxis] == 1),
+            axis=0,
+        )
+        false_positives = np.sum(
+            (predictions_per_threshold == 1) & (y_cal[:, np.newaxis] == 0),
+            axis=0,
+        )
+
+        positive_predictions = true_positives + false_positives
+
+        # Avoid division by zero
+        precision_per_threshold = np.zeros_like(self._thresholds, dtype=float)
+        nonzero_mask = positive_predictions > 0
+        precision_per_threshold[nonzero_mask] = (
+            true_positives[nonzero_mask] / positive_predictions[nonzero_mask]
+        )
+
+        return precision_per_threshold
+
+    def _compute_recall(
+        self, predictions_proba: NDArray[np.float32], y_cal: NDArray[np.float32]
+    ) -> NDArray[np.float32]:
+        """
+        Compute the recall for each threshold.
+        """
+        predictions_per_threshold = (
+            predictions_proba[:, np.newaxis] >= self._thresholds
+        ).astype(int)
+
+        true_positives = np.sum(
+            (predictions_per_threshold == 1) & (y_cal[:, np.newaxis] == 1),
+            axis=0,
+        )
+        false_negatives = np.sum(
+            (predictions_per_threshold == 0) & (y_cal[:, np.newaxis] == 1),
+            axis=0,
+        )
+
+        actual_positives = true_positives + false_negatives
+
+        # Avoid division by zero
+        recall_per_threshold = np.ones_like(self._thresholds, dtype=float)
+        nonzero_mask = actual_positives > 0
+        recall_per_threshold[nonzero_mask] = (
+            true_positives[nonzero_mask] / actual_positives[nonzero_mask]
+        )
+
+        return recall_per_threshold
+
+    def _compute_accuracy(
+        self, predictions_proba: NDArray[np.float32], y_cal: NDArray[np.float32]
+    ) -> NDArray[np.float32]:
+        """
+        Compute the accuracy for each threshold.
+        """
+        predictions_per_threshold = (
+            predictions_proba[:, np.newaxis] >= self._thresholds
+        ).astype(int)
+
+        correct_predictions = (
+            predictions_per_threshold == y_cal[:, np.newaxis]
+        ).astype(int)
+
+        accuracy_per_threshold = np.mean(correct_predictions, axis=0)
+
+        return accuracy_per_threshold
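
Review note: for anyone trying this branch locally, below is a minimal usage sketch of the new `BinaryClassificationController` from `mapie/risk_control_draft.py`. The classifier, dataset, and split are illustrative assumptions, not part of this PR; only the controller API comes from the diff above.

```python
# Minimal sketch, assuming this branch is installed as `mapie`.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from mapie.risk_control_draft import BinaryClassificationController

X, y = make_classification(n_samples=2000, n_features=5, random_state=0)
X_train, X_calib, y_train, y_calib = train_test_split(
    X, y, test_size=0.5, random_state=0
)
clf = LogisticRegression().fit(X_train, y_train)

# Ask for precision >= 0.8, guaranteed with probability >= 0.9.
controller = BinaryClassificationController(
    fitted_binary_classifier=clf,
    metric="precision",
    target_level=0.8,
    confidence_level=0.9,
)
controller.calibrate(X_calib, y_calib)

print("Valid thresholds:", controller.valid_thresholds)
print("Best threshold:", controller.best_threshold)
if controller.best_threshold is not None:  # None if no threshold was valid
    y_pred = controller.predict(X_calib)
```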
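
A quick sanity check one can run against the new `binary` flag in `compute_hoeffdding_bentkus_p_value` (the numbers below are made up): with `binary=True` the Bentkus term drops its factor `e`, so the resulting p-values can only shrink, and `ltt_procedure` can only validate more thresholds at a given `delta`.

```python
import numpy as np

from mapie.control_risk.p_values import compute_hoeffdding_bentkus_p_value

r_hat = np.array([0.15, 0.25, 0.35])  # empirical risks, one per threshold
alpha = np.array([0.3])               # target risk level
n_obs = 100                           # calibration set size

p_default = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha)
p_binary = compute_hoeffdding_bentkus_p_value(r_hat, n_obs, alpha, binary=True)

# With factor 1 instead of e in the Bentkus term, the elementwise min
# with the Hoeffding term can only decrease.
assert np.all(p_binary <= p_default)
print(p_default.ravel(), p_binary.ravel())
```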