Skip to content

Commit 779ce04

Browse files
committed
Update binary_logistic_regression_manual.ipynb
ROC curve and other curves based on shifted threshold
1 parent 0a1857f commit 779ce04

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed

binary_logistic_regression_manual.ipynb

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"\n",
5555
"from sklearn.metrics import confusion_matrix, precision_recall_fscore_support\n",
5656
"from sklearn.metrics import balanced_accuracy_score, accuracy_score\n",
57+
"# from sklearn.metrics import RocCurveDisplay\n",
5758
"\n",
5859
"from util_binary_logistic_regression import toy_data, init_weights\n",
5960
"from util_binary_logistic_regression import my_sigmoid, predict_class\n",
@@ -637,6 +638,112 @@
637638
"print('accuray', accuracy_test)"
638639
]
639640
},
641+
{
642+
"cell_type": "markdown",
643+
"id": "5c265037",
644+
"metadata": {},
645+
"source": [
646+
"### Curves Based on True/False Positive/Negative Rates\n",
647+
"\n",
648+
"- model robustness is evaluated by checking different prediction thresholds\n",
649+
"- for each prediction threshold four entries in the confusion matrix can be calculated on a unseen test data set\n",
650+
"- we can set up (at least) four characteristic curves\n",
651+
"- a model that is fair w.r.t. binary classes should have type-I $\\approx$ type-II error, which is the same as requiring TPR $\\approx$ TNR\n",
652+
"- this information is conveniently and directly deduced by inspecting not only one but at least two of those curves (for example the first column or the second column of the plot below)\n",
653+
"- very often the **receiver operating characteristic** (ROC) curve is discussed (left, top subplot)\n",
654+
"- the ROC tells us about type-I error vs. test power"
655+
]
656+
},
657+
{
658+
"cell_type": "code",
659+
"execution_count": null,
660+
"id": "caedb2f6",
661+
"metadata": {},
662+
"outputs": [],
663+
"source": [
664+
"# RocCurveDisplay.from_predictions(\n",
665+
"# Y_test_man[0, :],\n",
666+
"# my_sigmoid(np.dot(w.T, X_test_man) + b)[0, :])\n",
667+
"\n",
668+
"plt.figure(figsize=(8, 8))\n",
669+
"N = 1000\n",
670+
"predict_threshold = np.linspace(0, 1, N, endpoint=False)\n",
671+
"TPR = np.zeros_like(predict_threshold)\n",
672+
"FPR = np.zeros_like(predict_threshold)\n",
673+
"TNR = np.zeros_like(predict_threshold)\n",
674+
"FNR = np.zeros_like(predict_threshold)\n",
675+
"for idx, val in enumerate(predict_threshold):\n",
676+
" Y_pred = (my_sigmoid(np.dot(w.T, X_test_man) + b) >= val) * 1\n",
677+
" cm = confusion_matrix(Y_test_man[0, :], Y_pred[0, :])\n",
678+
" TN, FP = cm[0, 0], cm[0, 1]\n",
679+
" FN, TP = cm[1, 0], cm[1, 1]\n",
680+
" FPR[idx] = FP / (TN+FP) # type I error\n",
681+
" TPR[idx] = TP / (FN+TP) # recall, sensitivity, test power\n",
682+
" FNR[idx] = FN / (FN+TP) # type II error\n",
683+
" TNR[idx] = TN / (TN+FP) # specificity, selectivity\n",
684+
" if idx == N//2: # indicate 0.5 probability decision point\n",
685+
" plt.subplot(2, 2, 1)\n",
686+
" plt.text(FPR[idx], TPR[idx], '. %0.2f' % val)\n",
687+
" plt.subplot(2, 2, 2)\n",
688+
" plt.text(FPR[idx], FNR[idx], '. %0.2f' % val)\n",
689+
" plt.subplot(2, 2, 3)\n",
690+
" plt.text(FNR[idx], TNR[idx], '. %0.2f' % val)\n",
691+
" plt.subplot(2, 2, 4)\n",
692+
" plt.text(TPR[idx], TNR[idx], '. %0.2f' % val)\n",
693+
"\n",
694+
"# receiver operating characteristic (ROC) curve:\n",
695+
"plt.subplot(2, 2, 1)\n",
696+
"plt.plot(FPR, TPR, lw=2) # TN ok, FP ok, FN ok, TP ok\n",
697+
"plt.plot(0.05, 0.95, 'C3x'),\n",
698+
"plt.plot([0, 1], [0, 1])\n",
699+
"plt.text(0.05, 0.7, 'ROC curve')\n",
700+
"plt.xlabel('FPR = type I error')\n",
701+
"plt.ylabel('TPR = recall = sensitivity = power')\n",
702+
"plt.grid(True)\n",
703+
"plt.axis([0, 1, 0, 1])\n",
704+
"\n",
705+
"plt.subplot(2, 2, 2)\n",
706+
"plt.plot(FPR, FNR, lw=2) # TN ok, FP ok, FN ok, TP ok\n",
707+
"plt.plot(0.05, 0.05, 'C3x')\n",
708+
"plt.plot([0, 1], [1, 0])\n",
709+
"plt.xlabel('FPR = type I error')\n",
710+
"plt.ylabel('FNR = type II error')\n",
711+
"plt.grid(True)\n",
712+
"plt.axis([0, 1, 0, 1])\n",
713+
"\n",
714+
"plt.subplot(2, 2, 3)\n",
715+
"plt.plot(FNR, TNR, lw=2) # TN ok, FP ok, FN ok, TP ok\n",
716+
"plt.plot(0.05, 0.95, 'C3x')\n",
717+
"plt.plot([0, 1], [0, 1])\n",
718+
"plt.xlabel('FNR = type II error')\n",
719+
"plt.ylabel('TNR = specificity = selectivity')\n",
720+
"plt.grid(True)\n",
721+
"plt.axis([0, 1, 0, 1])\n",
722+
"\n",
723+
"plt.subplot(2, 2, 4)\n",
724+
"plt.plot(TPR, TNR, lw=2) # TN ok, FP ok, FN ok, TP ok\n",
725+
"plt.plot(0.95, 0.95, 'C3x')\n",
726+
"plt.plot([0, 1], [1, 0])\n",
727+
"plt.xlabel('TPR = recall = sensitivity = power')\n",
728+
"plt.ylabel('TNR = specificity = selectivity')\n",
729+
"plt.grid(True)\n",
730+
"plt.axis([0, 1, 0, 1])"
731+
]
732+
},
733+
{
734+
"cell_type": "markdown",
735+
"id": "c386ec3c",
736+
"metadata": {},
737+
"source": [
738+
"We see that the model is highly balanced w.r.t. the {0,1}-class predictions.\n",
739+
"\n",
740+
"The model performs with a type-I error less than 5% and with a type-II error less than 5%.\n",
741+
"\n",
742+
"The model performs with a true positive rate (TPR) larger than 95% and with a true negative rate (TNR) larger than 95%.\n",
743+
"\n",
744+
"In medical applications (cf. COVID-19 testing) type-I and type-II percentages are typically even smaller; and thus TPR and TNR percentages are typically even larger."
745+
]
746+
},
640747
{
641748
"cell_type": "markdown",
642749
"id": "651b1eff",

0 commit comments

Comments
 (0)