 from sklearn.metrics import roc_curve, precision_recall_curve, auc
 from sklearn.preprocessing import LabelEncoder
 from typing import Any, Iterable
-from .typing import Number, OneDimArray
+from .typing import Number, OneDimArray, MetricGraphResult, SingleCurveResult, SingleMethodResult
 from ._private import convert, plot_or_not

 __all__ = ["random_forest_feature_importance", "metric_graph", "ks_abc"]
@@ -53,29 +53,38 @@ def _draw_estimated_optimal_threshold_mark(
     ms: int,
     fmt: str,
     ax: Axes,
-) -> tuple[Number, Number, Number]:
+) -> list[tuple[Number, Number, Number]]:
     annotation_offset = (-0.027, 0.03)
     a = np.zeros((len(x_axis), 2))
     a[:, 0] = x_axis
     a[:, 1] = y_axis
+    a = a[a[:, 0] != a[:, 1]]  # drop points lying on the x == y diagonal
     if metric == "roc":
-        dist = lambda row: row[0] ** 2 + (1 - row[1]) ** 2  # optimal: (0,1)
+        dists = [  # optimal: (0,1)
+            lambda row: row[0] ** 2 + (1 - row[1]) ** 2,  # geometric distance
+            lambda row: row[0] - row[1]  # inverse Youden's J (X - Y instead of Y - X), since we later take the min while Youden's J is maximized
+        ]
     else:  # metric == 'pr'
-        dist = (
-            lambda row: (1 - row[0]) ** 2 + (1 - row[1]) ** 2
-        )  # optimal: (1,1)
-    amin = np.apply_along_axis(dist, 1, a).argmin()
-    ax.plot(x_axis[amin], y_axis[amin], color=color, marker="o", ms=ms)  # pyright: ignore[reportCallIssue, reportArgumentType]
-    ax.annotate(
-        "{th:{fmt}}".format(th=thresholds[amin], fmt=fmt),  # pyright: ignore[reportCallIssue, reportArgumentType]
-        xy=(x_axis[amin], y_axis[amin]),  # pyright: ignore[reportCallIssue, reportArgumentType]
-        color=color,
-        xytext=(
-            x_axis[amin] + annotation_offset[0],  # pyright: ignore[reportCallIssue, reportArgumentType, reportOperatorIssue]
-            y_axis[amin] + annotation_offset[1],  # pyright: ignore[reportCallIssue, reportArgumentType, reportOperatorIssue]
-        ),
-    )
-    return thresholds[amin], x_axis[amin], y_axis[amin]  # pyright: ignore[reportCallIssue, reportArgumentType, reportReturnType]
+        dists = [  # optimal: (1,1)
+            lambda row: (1 - row[0]) ** 2 + (1 - row[1]) ** 2  # geometric distance
+        ]
+    output_tuples = []
+    for dist, marker in zip(dists, ['o', 'x']):
+        amin = np.apply_along_axis(dist, 1, a).argmin()
+        ax.plot(x_axis[amin], y_axis[amin], color=color, marker=marker, ms=ms)  # pyright: ignore[reportCallIssue, reportArgumentType]
+        ax.annotate(
+            "{th:{fmt}}".format(th=thresholds[amin], fmt=fmt),  # pyright: ignore[reportCallIssue, reportArgumentType]
+            xy=(x_axis[amin], y_axis[amin]),  # pyright: ignore[reportCallIssue, reportArgumentType]
+            color=color,
+            xytext=(
+                x_axis[amin] + annotation_offset[0],  # pyright: ignore[reportCallIssue, reportArgumentType, reportOperatorIssue]
+                y_axis[amin] + annotation_offset[1],  # pyright: ignore[reportCallIssue, reportArgumentType, reportOperatorIssue]
+            ),
+        )
+        output_tuples.append(
+            (thresholds[amin], x_axis[amin], y_axis[amin])  # pyright: ignore[reportArgumentType, reportCallIssue]
+        )
+    return output_tuples


 def _plot_macro_metric(
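For intuition, here is a minimal standalone sketch of the two selection rules on toy data (array names and values are illustrative, not from the codebase). Note that it applies the diagonal filter to all three arrays together, so the argmin indices stay aligned with `thresholds`:

```python
import numpy as np

# Toy ROC points: (FPR, TPR) pairs and the thresholds that produced them.
fpr = np.array([0.0, 0.3, 0.5, 1.0])
tpr = np.array([0.0, 0.7, 1.0, 1.0])
thresholds = np.array([1.0, 0.6, 0.35, 0.0])

# Same diagonal filter as in the hunk above, applied to all three arrays
# together so the argmin indices stay aligned with `thresholds`.
mask = fpr != tpr
fpr, tpr, thresholds = fpr[mask], tpr[mask], thresholds[mask]

geo = (fpr**2 + (1 - tpr) ** 2).argmin()  # closest point to the (0, 1) corner
youden = (fpr - tpr).argmin()             # argmin(FPR - TPR) == argmax(TPR - FPR)

print(thresholds[geo])     # 0.6  -- geometric distance picks (0.3, 0.7)
print(thresholds[youden])  # 0.35 -- Youden's J picks (0.5, 1.0)
```

The two rules can disagree: geometric distance penalizes imbalance between the two error directions quadratically, while Youden's J trades them off linearly.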
@@ -141,39 +150,58 @@ def _binary_metric_graph(
         metric=metric.upper(), class_label=class_label, auc=auc_score, fmt=fmt
     )
     if metric == "pr":
-        label += ", naive = {ytr:{fmt}}".format(ytr=y_t_ratio, fmt=fmt)
+        label += ", naive = {ytr:{fmt}})".format(ytr=y_t_ratio, fmt=fmt)
     if eoptimal:
-        eopt, eopt_x, eopt_y = _draw_estimated_optimal_threshold_mark(
+        eopts = _draw_estimated_optimal_threshold_mark(
             metric, x_axis, y_axis, th, color, ms, fmt, ax
         )
-        label += ", eOpT = {th:{fmt}})".format(th=eopt, fmt=fmt)
+        if len(eopts) == 1:
+            eopts.append((None, None, None))  # pyright: ignore[reportArgumentType]
     else:
-        eopt = None
-        eopt_x = None
-        eopt_y = None
-        label += ")"
+        eopts = [
+            (None, None, None),
+            (None, None, None),
+        ]
     ax.plot(x_axis, y_axis, color=color, lw=lw, ls=ls, label=label)
     return {
         "x": x_axis,
         "y": y_axis,
         "thresholds": th,
         "auc": auc_score,
-        "eopt": eopt,
-        "eopt_x": eopt_x,
-        "eopt_y": eopt_y,
+        "eopts": [
+            {
+                "eopt": eopts[0][0],
+                "eopt_x": eopts[0][1],
+                "eopt_y": eopts[0][2],
+                "name": "geo",
+            },
+            {
+                "eopt": eopts[1][0],
+                "eopt_x": eopts[1][1],
+                "eopt_y": eopts[1][2],
+                "name": "youden_j",
+            },
+        ],
         "y_t_ratio": y_t_ratio,
     }


 def _build_metric_graph_output_dict(
     metric: str,
     d: dict[str, Any]
-) -> dict[str, dict[str, Any]]:
+) -> SingleCurveResult:
     naive = d["y_t_ratio"] if metric == "pr" else 0.5
-    return {
-        "auc": {"val": d["auc"], "naive": naive},
-        "eopt": {"val": d["eopt"], "x": d["eopt_x"], "y": d["eopt_y"]},
-    }
+    output: dict = {'auc': {"val": d["auc"], "naive": naive}}
+    for eopt in d['eopts']:
+        if eopt['eopt'] is None:
+            continue
+        method_result = SingleMethodResult(
+            x=eopt['eopt_x'],
+            y=eopt['eopt_y'],
+            val=eopt['eopt']
+        )
+        output[eopt['name']] = method_result
+    return output  # pyright: ignore[reportReturnType]


 def metric_graph(
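To see how the new `eopts` payload flows through `_build_metric_graph_output_dict`, here is a fabricated example (values invented for illustration; it assumes `SingleMethodResult` is a TypedDict, so instantiating it yields a plain dict). For a PR curve the `youden_j` slot is the `(None, None, None)` padding and is skipped:

```python
# Illustration only: a dict shaped like the one _binary_metric_graph now
# returns for a PR curve; the second eopts slot is the None padding.
d = {
    "auc": 0.91,
    "y_t_ratio": 0.25,
    "eopts": [
        {"eopt": 0.4, "eopt_x": 0.8, "eopt_y": 0.85, "name": "geo"},
        {"eopt": None, "eopt_x": None, "eopt_y": None, "name": "youden_j"},
    ],
}

result = _build_metric_graph_output_dict("pr", d)
# The padded youden_j entry is filtered out, so result holds only:
#   {"auc": {"val": 0.91, "naive": 0.25},
#    "geo": {"x": 0.8, "y": 0.85, "val": 0.4}}
```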
@@ -199,15 +227,25 @@ def metric_graph(
     title: str | None = None,
     filename: str | None = None,
     force_multiclass: bool = False,
-) -> dict[str, Any]:
+) -> MetricGraphResult:
     """
-    Plot a ROC graph of predictor's results (including AUC scores), where each
+    Plot a metric graph of predictor's results (including AUC scores), where each
     row of y_true and y_pred represent a single example.
-    If there are 1 or two columns only, the data is treated as a binary
-    classification (see input example below).
-    If there are more then 2 columns, each column is considered a
-    unique class, and a ROC graph and AUC score will be computed for each.
-    A Macro-ROC and Micro-ROC are computed and plotted too by default.
+
+    **ROC:**
+    Plots the true-positive rate as a function of the false-positive rate of the positive label in a binary
+    classification, where $TPR = TP / (TP + FN)$ and $FPR = FP / (FP + TN)$. A naive algorithm yields a straight
+    line from (0,0) to (1,1), and therefore an area under the curve (AUC) of 0.5.
+
+    Computes the estimated optimal threshold using two methods:
+    * Geometric distance: finding the closest point to the optimum at (0,1) by Euclidean distance
+    * Youden's J: maximizing $TPR - FPR$ (corresponding to $Y - X$)
+
+    **Precision-Recall:**
+    Plots precision as a function of recall of the positive label in a binary classification, where
+    $Precision = TP / (TP + FP)$ and $Recall = TP / (TP + FN)$. A naive algorithm yields a horizontal line
+    at a precision equal to the ratio of positive examples in the dataset.
+    The estimated optimal threshold is computed using the Euclidean (geometric) distance only.

     Based on sklearn examples (as was seen on April 2018):
     http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
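The two "naive" baselines in the new docstring can be sanity-checked numerically. A small sketch using sklearn directly (random data, illustrative only): an uninformative scorer lands at ROC AUC ≈ 0.5, and the naive PR precision is simply the positive-example ratio.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=10_000)  # balanced binary labels
y_score = rng.random(10_000)              # uninformative random scores

print(roc_auc_score(y_true, y_score))  # ~0.5: the naive ROC AUC
print(y_true.mean())                   # positive ratio: the naive PR precision
```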
@@ -270,8 +308,20 @@ def metric_graph(

     Returns:
     --------
-    A dictionary, one key for each class. Each value is another dictionary,
-    holding AUC and eOpT values.
+    A dictionary with these keys:
+    - `ax`: the Matplotlib plot axis
+    - `metrics`: each key is a class name from the list of provided classes.
+      For each class, another dict holds the AUC results and the results of
+      each measurement method.
+      The AUC key holds both the measured area under the curve (under `val`)
+      and the AUC of a random-guess classifier (under `naive`) for
+      comparison.
+      Each measurement method key contains three values, `x`, `y` and `val`,
+      corresponding to the (x,y) coordinates of the threshold on the metric
+      graph and its value.
+      If only one class exists, the measurement method keys and AUC will sit
+      directly under `metrics`.
+

     Binary Classification Input Example:
     ------------------------------------
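Assuming the structure described above (and that `MetricGraphResult` and `SingleCurveResult` are TypedDicts, i.e. dict-style access), a multiclass result would be consumed roughly like this; inputs and class names are hypothetical:

```python
result = metric_graph(y_true, y_pred, metric="roc")  # hypothetical inputs

for class_name, curve in result["metrics"].items():  # multiclass case
    auc_val = curve["auc"]["val"]
    naive = curve["auc"]["naive"]
    eopt = curve["youden_j"]["val"]  # or curve["geo"]["val"]
    print(f"{class_name}: AUC={auc_val:.3f} (naive={naive:.2f}), eOpT={eopt:.3f}")

# Single-class case: the curve dict sits directly under "metrics", e.g.
# result["metrics"]["auc"]["val"].
```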
@@ -325,7 +375,7 @@ def metric_graph(
     else:
         colors_list: list[str] = colors or _ROC_PLOT_COLORS

-    output_dict = dict()
+    output_dict: dict[str, SingleCurveResult] = {}
     pr_naives = list()
     if (
         len(y_pred_array.shape) == 1
@@ -422,8 +472,11 @@ def metric_graph(
         filename=filename,
         plot=plot,
     )
-    output_dict["ax"] = axis
-    return output_dict
+    metric_graph_result = MetricGraphResult(
+        ax=axis,
+        metrics=output_dict if len(output_dict) > 1 else output_dict[list(output_dict.keys())[0]],
+    )
+    return metric_graph_result


 def random_forest_feature_importance(
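The new `.typing` members imported at the top of this diff are not shown in it. A plausible reconstruction, consistent with how they are used above (this is an assumption for illustration, not the actual module):

```python
from typing import TypedDict, Union
from matplotlib.axes import Axes

Number = Union[int, float]  # assumed; the real alias lives in .typing

class SingleMethodResult(TypedDict):
    x: Number
    y: Number
    val: Number

class SingleCurveResult(TypedDict, total=False):
    auc: dict[str, Number]      # {"val": ..., "naive": ...}
    geo: SingleMethodResult     # present when a geometric optimum was found
    youden_j: SingleMethodResult  # ROC curves only

class MetricGraphResult(TypedDict):
    ax: Axes
    metrics: Union[dict[str, SingleCurveResult], SingleCurveResult]
```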