@@ -37,14 +37,16 @@ import com.twitter.algebird.Operators._
3737import com .twitter .algebird .Tuple2Semigroup
3838import com .salesforce .op .utils .spark .RichEvaluator ._
3939import org .apache .spark .ml .evaluation .MulticlassClassificationEvaluator
40- import org .apache .spark .ml .linalg .Vector
41- import org .apache .spark .ml .param .{DoubleArrayParam , IntArrayParam }
40+ import org .apache .spark .ml .linalg .{ Vector , DenseVector }
41+ import org .apache .spark .ml .param .{DoubleArrayParam , IntArrayParam , IntParam , ParamValidators }
4242import org .apache .spark .mllib .evaluation .MulticlassMetrics
4343import org .apache .spark .rdd .RDD
4444import org .apache .spark .sql .functions .col
4545import org .apache .spark .sql .types .DoubleType
4646import org .apache .spark .sql .{Dataset , Row }
4747import org .slf4j .LoggerFactory
48+ import scala .collection .Searching ._
49+
4850
4951/**
5052 * Instance to evaluate Multi Classification metrics
@@ -94,6 +96,35 @@ private[op] class OpMultiClassificationEvaluator
9496
9597 def setThresholds (v : Array [Double ]): this .type = set(thresholds, v)
9698
/**
 * Param holding how many of the most frequent classes are used for the confusion matrix metrics.
 * Accepted values lie between 1 and 30, both bounds included. Default is 15.
 */
final val confMatrixNumClasses = new IntParam(
  this,
  name = " confMatrixNumClasses",
  doc = " # of the top most frequent classes used for confusion matrix metrics",
  isValid = ParamValidators.inRange[Int](1, 30)
)
setDefault(confMatrixNumClasses -> 15)

/**
 * Sets [[confMatrixNumClasses]].
 *
 * @param v number of most frequent classes to keep
 * @return this evaluator
 */
def setConfMatrixNumClasses(v: Int): this.type = set(confMatrixNumClasses, v)
108+
/**
 * Param holding how many of the most frequently misclassified classes are reported for each
 * label/prediction category. The validator accepts values greater than 1 (1 itself is rejected)
 * and up to and including 10. Default is 5.
 */
final val confMatrixMinSupport = new IntParam(
  this,
  name = " confMatrixMinSupport",
  doc = " # of the top most frequent misclassified classes in each label/prediction category",
  isValid = ParamValidators.inRange[Int](1, 10, lowerInclusive = false, upperInclusive = true)
)
setDefault(confMatrixMinSupport -> 5)

/**
 * Sets [[confMatrixMinSupport]].
 *
 * @param v number of misclassified classes to report per category
 * @return this evaluator
 */
def setConfMatrixMinSupport(v: Int): this.type = set(confMatrixMinSupport, v)
118+
/**
 * Param holding the confidence thresholds used for the confusion matrix metrics.
 * Every value must satisfy 0.0 <= value < 1.0. Default is (0.0, 0.2, 0.4, 0.6, 0.8).
 */
final val confMatrixThresholds = new DoubleArrayParam(
  this,
  name = " confMatrixThresholds",
  doc = " sequence of threshold values used for confusion matrix metrics",
  isValid = (values: Array[Double]) => values.forall(t => t >= 0.0 && t < 1.0)
)
setDefault(confMatrixThresholds -> Array(0.0, 0.2, 0.4, 0.6, 0.8))

/**
 * Sets [[confMatrixThresholds]].
 *
 * @param v threshold values
 * @return this evaluator
 */
def setConfMatrixThresholds(v: Array[Double]): this.type = set(confMatrixThresholds, v)
127+
97128 override def evaluateAll (data : Dataset [_]): MultiClassificationMetrics = {
98129 val labelColName = getLabelCol
99130 val dataUse = makeDataToUse(data, labelColName)
@@ -112,7 +143,9 @@ private[op] class OpMultiClassificationEvaluator
112143 log.warn(" The dataset is empty. Returning empty metrics." )
113144 MultiClassificationMetrics (0.0 , 0.0 , 0.0 , 0.0 ,
114145 MulticlassThresholdMetrics (Seq .empty, Seq .empty, Map .empty, Map .empty, Map .empty),
115- MultiClassificationMetricsTopK (Seq .empty, Seq .empty, Seq .empty, Seq .empty, Seq .empty))
146+ MultiClassificationMetricsTopK (Seq .empty, Seq .empty, Seq .empty, Seq .empty, Seq .empty),
147+ MulticlassConfMatrixMetricsByThreshold ($(confMatrixNumClasses), Seq .empty, $(confMatrixThresholds), Seq .empty),
148+ MisClassificationMetrics ($(confMatrixMinSupport), Seq .empty, Seq .empty))
116149 } else {
117150 val multiclassMetrics = new MulticlassMetrics (rdd)
118151 val error = 1.0 - multiclassMetrics.accuracy
@@ -133,20 +166,183 @@ private[op] class OpMultiClassificationEvaluator
133166 topKs = $(topKs)
134167 )
135168
169+ val rddCm = dataUse.select(col(labelColName), col(predictionColName), col(probabilityColName)).rdd.map{
170+ case Row (label : Double , pred : Double , prob : DenseVector ) => (label, pred, prob.toArray)
171+ }
172+ val confusionMatrixByThreshold = calculateConfMatrixMetricsByThreshold(rddCm)
173+ val misClassifications = calculateMisClassificationMetrics( rddCm.map{ case (label, pred, _) => (label, pred)} )
174+
136175 val metrics = MultiClassificationMetrics (
137176 Precision = precision,
138177 Recall = recall,
139178 F1 = f1,
140179 Error = error,
141180 ThresholdMetrics = thresholdMetrics,
142- TopKMetrics = topKMetrics
181+ TopKMetrics = topKMetrics,
182+ ConfusionMatrixMetrics = confusionMatrixByThreshold,
183+ MisClassificationMetrics = misClassifications
143184 )
144185
145186 log.info(" Evaluated metrics: {}" , metrics.toString)
146187 metrics
147188 }
148189 }
149190
/**
 * Builds a flattened confusion matrix restricted to the given classes.
 *
 * Counts are first summed per (label, prediction) pair, ignoring the confidence component, and
 * collected to the driver. The result enumerates every (label, prediction) combination of
 * `cmClasses`, labels in the outer position and predictions in the inner one, using 0 for pairs
 * that never occur.
 *
 * @param labelPredictionCtRDD RDD of ((label, prediction, confidence), count)
 * @param cmClasses            the classes spanning the matrix, in the order they should appear
 * @return cmClasses.size * cmClasses.size counts, one row of predictions per label
 */
def constructConfusionMatrix(
  labelPredictionCtRDD: RDD[((Double, Double, Double), Long)],
  cmClasses: Seq[Double]): Seq[Long] = {

  val countsByCell = labelPredictionCtRDD
    .map { case ((label, prediction, _), count) => ((label, prediction), count) }
    .reduceByKey(_ + _)
    .collectAsMap()

  cmClasses.flatMap { label =>
    cmClasses.map(prediction => countsByCell.getOrElse((label, prediction), 0L))
  }
}
212+
private[evaluators] object SearchHelper extends Serializable {

  /**
   * Maps a probability score onto one of the given confidence thresholds.
   *
   * @param arr     confidence thresholds sorted in ascending order, must not be empty
   * @param element the probability score to look up
   * @return the last threshold when the score is above it, 0.0 when the score is below the first
   *         threshold, the score itself when it equals a threshold, and otherwise the threshold
   *         located right before the score's insertion point
   */
  def findThreshold(arr: IndexedSeq[Double], element: Double): Double = {
    require(arr.nonEmpty, " Array of confidence thresholds can't be empty!")
    val lowest = arr.head
    val highest = arr.last
    if (element > highest) {
      highest
    } else if (element < lowest) {
      0.0
    } else {
      // within [lowest, highest]: either an exact hit or the slot just after the wanted threshold
      val idx = new SearchImpl(arr).search(element).insertionPoint
      val candidate = arr(idx)
      if (candidate == element) candidate else arr(idx - 1)
    }
  }
}
235+
/**
 * Calculates confusion matrices for the top N most occurring labels, one matrix per confidence
 * threshold (N = confMatrixNumClasses, thresholds = confMatrixThresholds).
 *
 * Only records whose label AND prediction both belong to the N most frequent labels are counted.
 * A record is bucketed by its highest predicted probability and contributes to the matrix of a
 * threshold when its bucket is greater than or equal to that threshold.
 *
 * @param data RDD of (label, prediction, prediction probability vector)
 * @return a MulticlassConfMatrixMetricsByThreshold instance in which ConfMatrices(i) is the matrix
 *         computed for ConfMatrixThresholds(i); thresholds are reported in the order they were set
 */
def calculateConfMatrixMetricsByThreshold(
  data: RDD[(Double, Double, Array[Double])]): MulticlassConfMatrixMetricsByThreshold = {

  val numClasses = $(confMatrixNumClasses)
  val thresholds = $(confMatrixThresholds).toIndexedSeq

  val labelCountsRDD = data.map { case (label, _, _) => (label, 1L) }.reduceByKey(_ + _)
  val cmClasses = labelCountsRDD.sortBy(-_._2).map(_._1).take(numClasses).toSeq
  val cmClassesSet = cmClasses.toSet

  val dataTopNLabels = data.filter { case (label, prediction, _) =>
    cmClassesSet.contains(label) && cmClassesSet.contains(prediction)
  }

  // the binary search in findThreshold needs the thresholds in ascending order
  val sortedThresholds = thresholds.sorted

  // reduce data to a coarser RDD (with size N * N * thresholds at most) for further aggregation
  val labelPredictionConfidenceCountRDD = dataTopNLabels.map {
    case (label, prediction, proba) =>
      ((label, prediction, SearchHelper.findThreshold(sortedThresholds, proba.max)), 1L)
  }.reduceByKey(_ + _)

  labelPredictionConfidenceCountRDD.persist()

  // Iterate over the thresholds in the order they are reported below: building the matrices from
  // the sorted copy would pair them with the wrong thresholds whenever the param is not sorted.
  val cmByThreshold =
    try {
      thresholds.map { threshold =>
        val filteredRDD = labelPredictionConfidenceCountRDD.filter {
          case ((_, _, confidence), _) => confidence >= threshold
        }
        constructConfusionMatrix(filteredRDD, cmClasses)
      }
    } finally {
      // release the cached RDD even when one of the jobs above fails
      labelPredictionConfidenceCountRDD.unpersist()
    }

  MulticlassConfMatrixMetricsByThreshold(
    ConfMatrixNumClasses = numClasses,
    ConfMatrixClassIndices = cmClasses,
    ConfMatrixThresholds = thresholds,
    ConfMatrices = cmByThreshold
  )
}
280+
/**
 * Calculates the most frequently mis-classified classes for each label category and for each
 * prediction category.
 *
 * For every category the total count, the correctly classified count and the top n
 * (n = confMatrixMinSupport) confusions are reported. Confusions are ordered by descending count
 * (ties broken by ascending class value) and stored in an insertion-ordered map so that the order
 * survives. Categories are ordered by descending total count (ties broken by ascending category).
 *
 * @param data RDD of (label, prediction)
 * @return a MisClassificationMetrics instance
 */
def calculateMisClassificationMetrics(data: RDD[(Double, Double)]): MisClassificationMetrics = {

  // read the param once on the driver; the RDD closures below only use this local value
  val minSupport = $(confMatrixMinSupport)

  val labelPredictionCountRDD = data
    .map { case (label, prediction) => ((label, prediction), 1L) }
    .reduceByKey(_ + _)

  // Summarizes an RDD of (category, (other class, count)): one entry per category.
  def summarize(countsByCategory: RDD[(Double, (Double, Long))]): Seq[MisClassificationsPerCategory] =
    countsByCategory
      .groupByKey()
      .map { case (category, others) =>
        val counts = others.toList
        val (correct, wrong) = counts.partition { case (other, _) => other == category }
        val topWrong = wrong.sortBy { case (other, count) => (-count, other) }.take(minSupport)
        (category, counts.map(_._2).sum, correct.map(_._2).sum, topWrong)
      }
      .collect()
      // one row per category, so ordering on the driver is cheap
      .sortBy { case (category, total, _, _) => (-total, category) }
      .map { case (category, total, correctCount, topWrong) =>
        MisClassificationsPerCategory(
          Category = category,
          TotalCount = total,
          CorrectCount = correctCount,
          // a plain Map with more than four entries would not keep the descending order
          MisClassifications = scala.collection.immutable.ListMap(topWrong: _*)
        )
      }
      .toSeq

  val misClassificationsByLabel = summarize(
    labelPredictionCountRDD.map { case ((label, prediction), count) => (label, (prediction, count)) }
  )

  val misClassificationsByPrediction = summarize(
    labelPredictionCountRDD.map { case ((label, prediction), count) => (prediction, (label, count)) }
  )

  MisClassificationMetrics(
    ConfMatrixMinSupport = minSupport,
    MisClassificationsByLabel = misClassificationsByLabel,
    MisClassificationsByPrediction = misClassificationsByPrediction
  )
}
345+
150346/**
151347 * Function that calculates Multi Classification Metrics for different topK most occuring labels given an RDD
152348 * of scores & labels, and a list of topK values to consider.
@@ -187,7 +383,6 @@ private[op] class OpMultiClassificationEvaluator
187383 )
188384 }
189385
190-
191386 /**
192387 * Function that calculates a set of threshold metrics for different topN values given an RDD of scores & labels,
193388 * a list of topN values to consider, and a list of thresholds to use.
@@ -308,7 +503,6 @@ private[op] class OpMultiClassificationEvaluator
308503 .setMetricName(metricName.sparkEntryName)
309504 .evaluateOrDefault(dataUse, default = default)
310505 }
311-
312506}
313507
314508
@@ -328,7 +522,9 @@ case class MultiClassificationMetrics
328522 F1 : Double ,
329523 Error : Double ,
330524 ThresholdMetrics : MulticlassThresholdMetrics ,
331- TopKMetrics : MultiClassificationMetricsTopK
525+ TopKMetrics : MultiClassificationMetricsTopK ,
526+ ConfusionMatrixMetrics : MulticlassConfMatrixMetricsByThreshold ,
527+ MisClassificationMetrics : MisClassificationMetrics
332528) extends EvaluationMetrics
333529
334530/**
@@ -352,6 +548,56 @@ case class MultiClassificationMetricsTopK
352548 Error : Seq [Double ]
353549) extends EvaluationMetrics
354550
/**
 * Confusion matrix metrics for multiclass classification, computed per confidence threshold.
 * The evaluator counts a record when
 * 1) its label and its prediction both belong to the top n most occurring classes
 *    (n = ConfMatrixNumClasses)
 * 2) its top predicted probability reaches the threshold in question
 *
 * @param ConfMatrixNumClasses   value of n, the number of most occurring classes considered
 * @param ConfMatrixClassIndices label indices of the top n most occurring classes
 * @param ConfMatrixThresholds   a sequence of thresholds
 * @param ConfMatrices           one sequence of counts per threshold, each storing a confusion matrix
 */
case class MulticlassConfMatrixMetricsByThreshold(
  ConfMatrixNumClasses: Int,
  @JsonDeserialize(contentAs = classOf[java.lang.Double]) ConfMatrixClassIndices: Seq[Double],
  @JsonDeserialize(contentAs = classOf[java.lang.Double]) ConfMatrixThresholds: Seq[Double],
  ConfMatrices: Seq[Seq[Long]]
) extends EvaluationMetrics
570+
/**
 * Multiclass mis-classification metrics: for each label category and for each prediction category,
 * the top n (n = ConfMatrixMinSupport) most frequently mis-classified classes.
 *
 * @param ConfMatrixMinSupport           value of n used when the metrics were computed
 * @param MisClassificationsByLabel      one entry per label category
 * @param MisClassificationsByPrediction one entry per prediction category
 */
case class MisClassificationMetrics(
  ConfMatrixMinSupport: Int,
  MisClassificationsByLabel: Seq[MisClassificationsPerCategory],
  MisClassificationsByPrediction: Seq[MisClassificationsPerCategory]
)
582+
/**
 * Container for the most frequently mis-classified classes of a single label/prediction category.
 *
 * @param Category           the category which a record's label or prediction equals to
 * @param TotalCount         total # of records in that category
 * @param CorrectCount       # of correctly predicted records in that category
 * @param MisClassifications the top n most frequently misclassified classes (n = confMatrixMinSupport)
 *                           mapped to their respective counts in that category. NOTE(review): the
 *                           Map type itself does not guarantee an iteration order - confirm with the
 *                           producer before relying on descending-count ordering.
 */
case class MisClassificationsPerCategory(
  Category: Double,
  TotalCount: Long,
  CorrectCount: Long,
  @JsonDeserialize(keyAs = classOf[java.lang.Double]) MisClassifications: Map[Double, Long]
)
600+
355601/**
356602 * Threshold-based metrics for multiclass classification
357603 *
0 commit comments