Skip to content

Commit 91724f1

Browse files
authored
Add confusion matrix (#533)
1 parent 13ad9cd commit 91724f1

File tree

5 files changed

+420
-36
lines changed

5 files changed

+420
-36
lines changed

core/src/main/scala/com/salesforce/op/evaluators/OpMultiClassificationEvaluator.scala

Lines changed: 253 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,16 @@ import com.twitter.algebird.Operators._
3737
import com.twitter.algebird.Tuple2Semigroup
3838
import com.salesforce.op.utils.spark.RichEvaluator._
3939
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
40-
import org.apache.spark.ml.linalg.Vector
41-
import org.apache.spark.ml.param.{DoubleArrayParam, IntArrayParam}
40+
import org.apache.spark.ml.linalg.{Vector, DenseVector}
41+
import org.apache.spark.ml.param.{DoubleArrayParam, IntArrayParam, IntParam, ParamValidators}
4242
import org.apache.spark.mllib.evaluation.MulticlassMetrics
4343
import org.apache.spark.rdd.RDD
4444
import org.apache.spark.sql.functions.col
4545
import org.apache.spark.sql.types.DoubleType
4646
import org.apache.spark.sql.{Dataset, Row}
4747
import org.slf4j.LoggerFactory
48+
import scala.collection.Searching._
49+
4850

4951
/**
5052
* Instance to evaluate Multi Classification metrics
@@ -94,6 +96,35 @@ private[op] class OpMultiClassificationEvaluator
9496

9597
def setThresholds(v: Array[Double]): this.type = set(thresholds, v)
9698

99+
/**
 * Param holding how many of the most frequent classes are used for the confusion matrix metrics.
 * Accepted values are 1 to 30, both bounds inclusive; the default set below is 15.
 */
final val confMatrixNumClasses = new IntParam(
100+
parent = this,
101+
name = "confMatrixNumClasses",
102+
doc = "# of the top most frequent classes used for confusion matrix metrics",
103+
isValid = ParamValidators.inRange(1, 30, lowerInclusive = true, upperInclusive = true)
104+
)
105+
setDefault(confMatrixNumClasses, 15)
106+
107+
/** Setter for the confMatrixNumClasses param. */
def setConfMatrixNumClasses(v: Int): this.type = set(confMatrixNumClasses, v)
108+
109+
/**
 * Param holding how many of the most frequently misclassified classes are kept for each
 * label/prediction category. The default set below is 5.
 *
 * NOTE(review): the validator uses lowerInclusive = false, so the accepted range is (1, 10] and a
 * value of 1 is rejected - confirm that excluding 1 is intended.
 */
final val confMatrixMinSupport = new IntParam(
110+
parent = this,
111+
name = "confMatrixMinSupport",
112+
doc = "# of the top most frequent misclassified classes in each label/prediction category",
113+
isValid = ParamValidators.inRange(1, 10, lowerInclusive = false, upperInclusive = true)
114+
)
115+
setDefault(confMatrixMinSupport, 5)
116+
117+
/** Setter for the confMatrixMinSupport param. */
def setConfMatrixMinSupport(v: Int): this.type = set(confMatrixMinSupport, v)
118+
119+
/**
 * Param holding the threshold values used for the confusion matrix metrics.
 * Every value must satisfy 0.0 <= x < 1.0; the default set below is 0.0, 0.2, 0.4, 0.6, 0.8.
 */
final val confMatrixThresholds = new DoubleArrayParam(
120+
parent = this,
121+
name = "confMatrixThresholds",
122+
doc = "sequence of threshold values used for confusion matrix metrics",
123+
isValid = _.forall(x => x >= 0.0 && x < 1.0)
124+
)
125+
setDefault(confMatrixThresholds, Array(0.0, 0.2, 0.4, 0.6, 0.8))
126+
/** Setter for the confMatrixThresholds param. */
def setConfMatrixThresholds(v: Array[Double]): this.type = set(confMatrixThresholds, v)
127+
97128
override def evaluateAll(data: Dataset[_]): MultiClassificationMetrics = {
98129
val labelColName = getLabelCol
99130
val dataUse = makeDataToUse(data, labelColName)
@@ -112,7 +143,9 @@ private[op] class OpMultiClassificationEvaluator
112143
log.warn("The dataset is empty. Returning empty metrics.")
113144
MultiClassificationMetrics(0.0, 0.0, 0.0, 0.0,
114145
MulticlassThresholdMetrics(Seq.empty, Seq.empty, Map.empty, Map.empty, Map.empty),
115-
MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty))
146+
MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty),
147+
MulticlassConfMatrixMetricsByThreshold($(confMatrixNumClasses), Seq.empty, $(confMatrixThresholds), Seq.empty),
148+
MisClassificationMetrics($(confMatrixMinSupport), Seq.empty, Seq.empty))
116149
} else {
117150
val multiclassMetrics = new MulticlassMetrics(rdd)
118151
val error = 1.0 - multiclassMetrics.accuracy
@@ -133,20 +166,183 @@ private[op] class OpMultiClassificationEvaluator
133166
topKs = $(topKs)
134167
)
135168

169+
val rddCm = dataUse.select(col(labelColName), col(predictionColName), col(probabilityColName)).rdd.map{
170+
case Row(label: Double, pred: Double, prob: DenseVector) => (label, pred, prob.toArray)
171+
}
172+
val confusionMatrixByThreshold = calculateConfMatrixMetricsByThreshold(rddCm)
173+
val misClassifications = calculateMisClassificationMetrics( rddCm.map{ case (label, pred, _) => (label, pred)} )
174+
136175
val metrics = MultiClassificationMetrics(
137176
Precision = precision,
138177
Recall = recall,
139178
F1 = f1,
140179
Error = error,
141180
ThresholdMetrics = thresholdMetrics,
142-
TopKMetrics = topKMetrics
181+
TopKMetrics = topKMetrics,
182+
ConfusionMatrixMetrics = confusionMatrixByThreshold,
183+
MisClassificationMetrics = misClassifications
143184
)
144185

145186
log.info("Evaluated metrics: {}", metrics.toString)
146187
metrics
147188
}
148189
}
149190

191+
/**
 * Builds the flattened confusion matrix restricted to the given classes.
 *
 * The counts are first summed over the confidence component, so that every (label, prediction) pair
 * has a single total, and the totals are then collected as a map.
 *
 * @param labelPredictionCtRDD RDD of ((label, prediction, confidence), count)
 * @param cmClasses            the classes to report; they index both the labels and the predictions
 * @return the counts laid out label by label: for each label in cmClasses, one entry per prediction in
 *         cmClasses, with 0 for a pair that never occurs
 */
def constructConfusionMatrix(
  labelPredictionCtRDD: RDD[((Double, Double, Double), Long)],
  cmClasses: Seq[Double]): Seq[Long] = {

  // Drop the confidence component and total the counts of each (label, prediction) pair
  val pairCounts = labelPredictionCtRDD.map {
    case ((label, prediction, _), count) => ((label, prediction), count)
  }
  val countsByPair = pairCounts.reduceByKey(_ + _).collectAsMap()

  cmClasses.flatMap { label =>
    cmClasses.map { prediction =>
      countsByPair.getOrElse((label, prediction), 0L)
    }
  }
}
212+
213+
/**
 * Helper that maps a probability score onto one of a sorted sequence of confidence thresholds.
 */
private[evaluators] object SearchHelper extends Serializable {

  /**
   * Looks up the confidence threshold that corresponds to a probability score.
   *
   * @param arr     a sorted, non-empty sequence of confidence thresholds
   * @param element the probability score to look up
   * @return the last threshold when element is greater than it; 0.0 when element is smaller than the first
   *         threshold; otherwise the threshold equal to element if there is one, or else the threshold right
   *         before the point at which element would be inserted
   * @throws IllegalArgumentException if arr is empty
   */
  def findThreshold(arr: IndexedSeq[Double], element: Double): Double = {
    require(arr.nonEmpty, "Array of confidence thresholds can't be empty!")
    element match {
      case aboveAll if aboveAll > arr.last => arr.last
      case belowAll if belowAll < arr.head => 0.0
      case _ =>
        val position = arr.search(element).insertionPoint
        val candidate = arr(position)
        // No exact match: fall back to the threshold just before the insertion point
        if (candidate == element) candidate else arr(position - 1)
    }
  }
}
235+
236+
/**
 * Calculates the confusion matrices of the top n most occurring labels (n = confMatrixNumClasses),
 * one matrix per confidence threshold.
 *
 * Only records whose label and prediction both belong to the top n labels are counted. Every record is
 * assigned the threshold found for its highest predicted probability, and the matrix of a threshold counts
 * the records whose assigned threshold is greater than or equal to it.
 *
 * @param data RDD of (label, prediction, prediction probability vector)
 * @return a MulticlassConfMatrixMetricsByThreshold instance. Its thresholds are reported in ascending order,
 *         which is the order the matrices are computed in, so the i-th matrix belongs to the i-th threshold
 *         even when the param was set with unsorted values.
 */
def calculateConfMatrixMetricsByThreshold(
  data: RDD[(Double, Double, Array[Double])]): MulticlassConfMatrixMetricsByThreshold = {

  val numClasses = $(confMatrixNumClasses)

  val labelCountsRDD = data.map { case (label, _, _) => (label, 1L) }.reduceByKey(_ + _)
  val cmClasses = labelCountsRDD.sortBy(-_._2).map(_._1).take(numClasses).toSeq
  val cmClassesSet = cmClasses.toSet

  val dataTopNLabels = data.filter { case (label, prediction, _) =>
    cmClassesSet.contains(label) && cmClassesSet.contains(prediction)
  }

  val sortedThresholds = $(confMatrixThresholds).sorted.toIndexedSeq

  // reduce data to a coarser RDD (with size N * N * thresholds at most) for further aggregation
  val labelPredictionConfidenceCountRDD = dataTopNLabels.map { case (label, prediction, proba) =>
    ((label, prediction, SearchHelper.findThreshold(sortedThresholds, proba.max)), 1L)
  }.reduceByKey(_ + _)

  labelPredictionConfidenceCountRDD.persist()

  val cmByThreshold =
    try {
      sortedThresholds.map { threshold =>
        val filteredRDD = labelPredictionConfidenceCountRDD.filter {
          case ((_, _, confidence), _) => confidence >= threshold
        }
        constructConfusionMatrix(filteredRDD, cmClasses)
      }
    } finally {
      // Release the persisted RDD even when building one of the matrices fails
      labelPredictionConfidenceCountRDD.unpersist()
    }

  MulticlassConfMatrixMetricsByThreshold(
    ConfMatrixNumClasses = numClasses,
    ConfMatrixClassIndices = cmClasses,
    // Report the sorted thresholds: cmByThreshold is built in this order
    ConfMatrixThresholds = sortedThresholds,
    ConfMatrices = cmByThreshold
  )
}
280+
281+
/**
 * Calculates the most frequently mis-classified classes for each label category and each prediction category.
 *
 * For a label category the mis-classifications are the predictions that differ from the label; for a
 * prediction category they are the labels that differ from the prediction. At most confMatrixMinSupport
 * of them are kept per category, taken in descending order of their counts.
 *
 * @param data RDD of (label, prediction)
 * @return a MisClassificationMetrics instance. The categories of each list are sorted by total record count
 *         in descending order, and the mis-classifications of a category iterate in descending count order.
 */
def calculateMisClassificationMetrics(data: RDD[(Double, Double)]): MisClassificationMetrics = {
  import scala.collection.immutable.ListMap

  // Read the param once here instead of inside each RDD closure
  val minSupport = $(confMatrixMinSupport)

  val labelPredictionCountRDD = data
    .map { case (label, prediction) => ((label, prediction), 1L) }
    .reduceByKey(_ + _)

  // Summarizes one category from the (other class, count) pairs observed together with it
  val summarize: (Double, Seq[(Double, Long)]) => MisClassificationsPerCategory = (category, counts) => {
    val topMisClassified = counts
      .filter { case (other, _) => other != category }
      .sortBy { case (_, count) => -count }
      .take(minSupport)

    MisClassificationsPerCategory(
      Category = category,
      TotalCount = counts.map(_._2).sum,
      CorrectCount = counts.collect { case (other, count) if other == category => count }.sum,
      // ListMap iterates in insertion order, which keeps the descending count order established above;
      // a plain toMap gives no ordering guarantee
      MisClassifications = ListMap(topMisClassified: _*)
    )
  }

  val misClassificationsByLabel = labelPredictionCountRDD
    .map { case ((label, prediction), count) => (label, Seq((prediction, count))) }
    .reduceByKey(_ ++ _)
    .map { case (label, predictionCounts) => summarize(label, predictionCounts) }
    .sortBy(-_.TotalCount)
    .collect()

  val misClassificationsByPrediction = labelPredictionCountRDD
    .map { case ((label, prediction), count) => (prediction, Seq((label, count))) }
    .reduceByKey(_ ++ _)
    .map { case (prediction, labelCounts) => summarize(prediction, labelCounts) }
    .sortBy(-_.TotalCount)
    .collect()

  MisClassificationMetrics(
    ConfMatrixMinSupport = minSupport,
    MisClassificationsByLabel = misClassificationsByLabel,
    MisClassificationsByPrediction = misClassificationsByPrediction
  )
}
345+
150346
/**
151347
* Function that calculates Multi Classification Metrics for different topK most occurring labels given an RDD
152348
* of scores & labels, and a list of topK values to consider.
@@ -187,7 +383,6 @@ private[op] class OpMultiClassificationEvaluator
187383
)
188384
}
189385

190-
191386
/**
192387
* Function that calculates a set of threshold metrics for different topN values given an RDD of scores & labels,
193388
* a list of topN values to consider, and a list of thresholds to use.
@@ -308,7 +503,6 @@ private[op] class OpMultiClassificationEvaluator
308503
.setMetricName(metricName.sparkEntryName)
309504
.evaluateOrDefault(dataUse, default = default)
310505
}
311-
312506
}
313507

314508

@@ -328,7 +522,9 @@ case class MultiClassificationMetrics
328522
F1: Double,
329523
Error: Double,
330524
ThresholdMetrics: MulticlassThresholdMetrics,
331-
TopKMetrics: MultiClassificationMetricsTopK
525+
TopKMetrics: MultiClassificationMetricsTopK,
526+
ConfusionMatrixMetrics: MulticlassConfMatrixMetricsByThreshold,
527+
MisClassificationMetrics: MisClassificationMetrics
332528
) extends EvaluationMetrics
333529

334530
/**
@@ -352,6 +548,56 @@ case class MultiClassificationMetricsTopK
352548
Error: Seq[Double]
353549
) extends EvaluationMetrics
354550

551+
/**
 * Metrics for multi-class confusion matrix. It captures the confusion matrix of records of which
 * 1) the label and the prediction both belong to the top n most occurring classes (n = confMatrixNumClasses)
 * 2) the top predicted probability reaches a certain threshold in confMatrixThresholds
 *
 * @param ConfMatrixNumClasses value of n, the number of most occurring classes in the dataset that are used
 * @param ConfMatrixClassIndices label index of the top n most occurring classes
 * @param ConfMatrixThresholds a sequence of thresholds
 * @param ConfMatrices a sequence of confusion matrices, one per threshold, each stored as a flat sequence of
 *                     counts
 */
561+
case class MulticlassConfMatrixMetricsByThreshold
562+
(
563+
ConfMatrixNumClasses: Int,
564+
@JsonDeserialize(contentAs = classOf[java.lang.Double])
565+
ConfMatrixClassIndices: Seq[Double],
566+
@JsonDeserialize(contentAs = classOf[java.lang.Double])
567+
ConfMatrixThresholds: Seq[Double],
568+
ConfMatrices: Seq[Seq[Long]]
569+
) extends EvaluationMetrics
570+
571+
/**
 * Multiclass mis-classification metrics, including the top n (n = confMatrixMinSupport) most frequently
 * mis-classified classes for each label or prediction category.
 *
 * @param ConfMatrixMinSupport value of n, the number of mis-classified classes kept for each category
 * @param MisClassificationsByLabel one entry per label category
 * @param MisClassificationsByPrediction one entry per prediction category
 */
576+
case class MisClassificationMetrics
577+
(
578+
ConfMatrixMinSupport: Int,
579+
MisClassificationsByLabel: Seq[MisClassificationsPerCategory],
580+
MisClassificationsByPrediction: Seq[MisClassificationsPerCategory]
581+
)
582+
583+
/**
 * Container to store the most frequently mis-classified classes for each label/prediction category
 *
 * @param Category a category which a record's label or prediction equals to
 * @param TotalCount total # of records in that category
 * @param CorrectCount # of correctly predicted records in that category
 * @param MisClassifications the top n most frequently misclassified classes (n = confMatrixMinSupport) and
 *                           their respective counts in that category.
 *                           NOTE(review): the entries are meant to be ordered by counts in descending order,
 *                           but the declared type is a plain Map, which by itself does not promise an
 *                           iteration order - confirm what the producer actually passes in.
 */
592+
case class MisClassificationsPerCategory
593+
(
594+
Category: Double,
595+
TotalCount: Long,
596+
CorrectCount: Long,
597+
@JsonDeserialize(keyAs = classOf[java.lang.Double])
598+
MisClassifications: Map[Double, Long]
599+
)
600+
355601
/**
356602
* Threshold-based metrics for multiclass classification
357603
*

0 commit comments

Comments
 (0)