Skip to content

Commit ec50a92

Browse files
authored
Add multiclassification topk and confmatrix metrics to model insights serialization format (#537)
1 parent 37df638 commit ec50a92

File tree

6 files changed

+137
-66
lines changed

6 files changed

+137
-66
lines changed

core/src/main/scala/com/salesforce/op/ModelInsights.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,9 @@ case object ModelInsights {
401401
classOf[DataBalancerSummary], classOf[DataCutterSummary], classOf[DataSplitterSummary],
402402
classOf[SingleMetric], classOf[MultiMetrics], classOf[BinaryClassificationMetrics],
403403
classOf[BinaryClassificationBinMetrics], classOf[MulticlassThresholdMetrics],
404-
classOf[BinaryThresholdMetrics], classOf[MultiClassificationMetrics], classOf[RegressionMetrics]
404+
classOf[BinaryThresholdMetrics], classOf[MultiClassificationMetrics], classOf[RegressionMetrics],
405+
classOf[MultiClassificationMetricsTopK],
406+
classOf[MulticlassConfMatrixMetricsByThreshold], classOf[MisClassificationMetrics]
405407
))
406408
val evalMetricsSerializer = new CustomSerializer[EvalMetric](_ =>
407409
( { case JString(s) => EvalMetric.withNameInsensitive(s) },

core/src/main/scala/com/salesforce/op/evaluators/OpMultiClassificationEvaluator.scala

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,34 @@ private[op] class OpMultiClassificationEvaluator
278278
)
279279
}
280280

281+
/**
282+
* function to convert a sequence of ClassCount to a MisClassificationsPerCategory instance
283+
*
284+
* @param allClassCtSeq sequence of ClassCount containing labels or predictions and their counts for each category
285+
* @param category index of a labelled or predicted class
286+
* @return a MisClassificationsPerCategory instance
287+
*/
288+
private def getMisclassificationsPerCategory(
289+
category: Double, allClassCtSeq: Seq[ClassCount]): MisClassificationsPerCategory = {
290+
291+
val misClassificationCtMap = allClassCtSeq
292+
.filter(_.ClassIndex != category)
293+
.sortBy(-_.Count)
294+
.take($(confMatrixMinSupport))
295+
296+
val labelCount = allClassCtSeq.map(_.Count).reduce(_ + _)
297+
val correctCount = allClassCtSeq.filter(_.ClassIndex == category)
298+
.map(_.Count)
299+
.reduceOption(_ + _).getOrElse(0L)
300+
301+
MisClassificationsPerCategory(
302+
Category = category,
303+
TotalCount = labelCount,
304+
CorrectCount = correctCount,
305+
MisClassifications = misClassificationCtMap
306+
)
307+
}
308+
281309
/**
282310
* function to calculate the most frequently mis-classified classes for each label/prediction category
283311
*
@@ -291,49 +319,15 @@ private[op] class OpMultiClassificationEvaluator
291319
.reduceByKey(_ + _)
292320

293321
val misClassificationsByLabel = labelPredictionCountRDD.map {
294-
case ((label, prediction), count) => (label, Seq((prediction, count)))
322+
case ((label, prediction), count) => (label, Seq(ClassCount(prediction, count)))
295323
}.reduceByKey(_ ++ _)
296-
.map { case (label, predictionCountsIter) => {
297-
val misClassificationCtMap = predictionCountsIter
298-
.filter { case (pred, _) => pred != label }
299-
.sortBy(-_._2)
300-
.take($(confMatrixMinSupport)).toMap
301-
302-
val labelCount = predictionCountsIter.map(_._2).reduce(_ + _)
303-
val correctCount = predictionCountsIter
304-
.collect { case (pred, count) if pred == label => count }
305-
.reduceOption(_ + _).getOrElse(0L)
306-
307-
MisClassificationsPerCategory(
308-
Category = label,
309-
TotalCount = labelCount,
310-
CorrectCount = correctCount,
311-
MisClassifications = misClassificationCtMap
312-
)
313-
}
314-
}.sortBy(-_.TotalCount).collect()
324+
.map { case (label, predictionCountsSeq) => getMisclassificationsPerCategory(label, predictionCountsSeq)}
325+
.sortBy(-_.TotalCount).collect()
315326

316327
val misClassificationsByPrediction = labelPredictionCountRDD.map {
317-
case ((label, prediction), count) => (prediction, Seq((label, count)))
328+
case ((label, prediction), count) => (prediction, Seq(ClassCount(label, count)))
318329
}.reduceByKey(_ ++ _)
319-
.map { case (prediction, labelCountsIter) => {
320-
val sortedMisclassificationCt = labelCountsIter
321-
.filter { case (label, _) => label != prediction }
322-
.sortBy(-_._2)
323-
.take($(confMatrixMinSupport)).toMap
324-
325-
val predictionCount = labelCountsIter.map(_._2).reduce(_ + _)
326-
val correctCount = labelCountsIter
327-
.collect { case (label, count) if label == prediction => count }
328-
.reduceOption(_ + _).getOrElse(0L)
329-
330-
MisClassificationsPerCategory(
331-
Category = prediction,
332-
TotalCount = predictionCount,
333-
CorrectCount = correctCount,
334-
MisClassifications = sortedMisclassificationCt
335-
)
336-
}
330+
.map { case (prediction, labelCountsSeq) => getMisclassificationsPerCategory(prediction, labelCountsSeq)
337331
}.sortBy(-_.TotalCount).collect()
338332

339333
MisClassificationMetrics(
@@ -541,10 +535,15 @@ case class MultiClassificationMetrics
541535
*/
542536
case class MultiClassificationMetricsTopK
543537
(
538+
@JsonDeserialize(contentAs = classOf[java.lang.Integer])
544539
topKs: Seq[Int],
540+
@JsonDeserialize(contentAs = classOf[java.lang.Double])
545541
Precision: Seq[Double],
542+
@JsonDeserialize(contentAs = classOf[java.lang.Double])
546543
Recall: Seq[Double],
544+
@JsonDeserialize(contentAs = classOf[java.lang.Double])
547545
F1: Seq[Double],
546+
@JsonDeserialize(contentAs = classOf[java.lang.Double])
548547
Error: Seq[Double]
549548
) extends EvaluationMetrics
550549

@@ -594,8 +593,19 @@ case class MisClassificationsPerCategory
594593
Category: Double,
595594
TotalCount: Long,
596595
CorrectCount: Long,
597-
@JsonDeserialize(keyAs = classOf[java.lang.Double])
598-
MisClassifications: Map[Double, Long]
596+
MisClassifications: Seq[ClassCount]
597+
)
598+
599+
/**
600+
* container to store the count of a class
601+
*
602+
* @param ClassIndex index of the class being counted
603+
* @param Count count recorded for that class
604+
*/
605+
case class ClassCount
606+
(
607+
ClassIndex: Double,
608+
Count: Long
599609
)
600610

601611
/**

core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import com.salesforce.op.stages.impl.preparators._
4040
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, OpXGBoostRegressor, RegressionModelSelector}
4141
import com.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
4242
import com.salesforce.op.stages.impl.selector.ValidationType._
43-
import com.salesforce.op.stages.impl.selector.{SelectedCombinerModel, SelectedModel, SelectedModelCombiner}
43+
import com.salesforce.op.stages.impl.selector.{ModelSelectorSummary, ProblemType, SelectedCombinerModel, SelectedModel, SelectedModelCombiner, ValidationType}
4444
import com.salesforce.op.stages.impl.tuning.{DataCutter, DataSplitter}
4545
import com.salesforce.op.test.{PassengerSparkFixtureTest, TestFeatureBuilder}
4646
import com.salesforce.op.testkit.RandomReal
@@ -406,6 +406,61 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
406406
pretty should include("Top Contributions")
407407
}
408408

409+
it should "correctly serialize and deserialize from json with MulticlassificationMetrics" in {
410+
val trainMetrics = MultiClassificationMetrics(
411+
Precision = 0.1,
412+
Recall = 0.2,
413+
F1 = 0.3,
414+
Error = 0.4,
415+
ThresholdMetrics = MulticlassThresholdMetrics(topNs = Seq(1, 2), thresholds = Seq(1.1, 1.2),
416+
correctCounts = Map(1 -> Seq(100L)), incorrectCounts = Map(2 -> Seq(200L)),
417+
noPredictionCounts = Map(3 -> Seq(300L))),
418+
TopKMetrics = MultiClassificationMetricsTopK(Seq(1), Seq(0.1), Seq(0.1), Seq(0.1), Seq(0.1)),
419+
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold( 2, Seq(0.1), Seq(0.1), Seq(Seq(1L))),
420+
MisClassificationMetrics = MisClassificationMetrics(1, Seq.empty,
421+
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Seq(ClassCount(1.0, 3L)))))
422+
)
423+
424+
val holdoutMetrics = MultiClassificationMetrics(
425+
Precision = 0.1,
426+
Recall = 0.2,
427+
F1 = 0.3,
428+
Error = 0.4,
429+
ThresholdMetrics = MulticlassThresholdMetrics(topNs = Seq(1, 2), thresholds = Seq(1.1, 1.2),
430+
correctCounts = Map(1 -> Seq(100L)), incorrectCounts = Map(2 -> Seq(200L)),
431+
noPredictionCounts = Map(3 -> Seq(300L))),
432+
TopKMetrics = MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty),
433+
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold(2, Seq(0.1), Seq(0.1), Seq.empty),
434+
MisClassificationMetrics = MisClassificationMetrics(1, Seq.empty,
435+
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Seq.empty)))
436+
)
437+
438+
val summary = ModelSelectorSummary(
439+
validationType = ValidationType.TrainValidationSplit,
440+
validationParameters = Map.empty,
441+
dataPrepParameters = Map.empty,
442+
dataPrepResults = None,
443+
evaluationMetric = MultiClassEvalMetrics.Error,
444+
problemType = ProblemType.MultiClassification,
445+
bestModelUID = "test1",
446+
bestModelName = "test2",
447+
bestModelType = "test3",
448+
validationResults = Seq.empty,
449+
trainEvaluation = trainMetrics,
450+
holdoutEvaluation = Some(holdoutMetrics)
451+
)
452+
453+
val insights = workflowModel.modelInsights(pred).copy(selectedModelInfo = Some(summary))
454+
ModelInsights.fromJson(insights.toJson()) match {
455+
case Failure(e) => fail(e)
456+
case Success(deser) =>
457+
insights.selectedModelInfo.toSeq.zip(deser.selectedModelInfo.toSeq).foreach {
458+
case (o, i) =>
459+
o.trainEvaluation shouldEqual i.trainEvaluation
460+
o.holdoutEvaluation shouldEqual i.holdoutEvaluation
461+
}
462+
}
463+
}
409464

410465
it should "correctly serialize and deserialize from json when raw feature filter is not used" in {
411466
val insights = workflowModel.modelInsights(pred)

core/src/test/scala/com/salesforce/op/evaluators/OpMultiClassificationEvaluatorTest.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -398,10 +398,10 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext
398398

399399
// create a test 2D array where 1st dimension is the label and 2nd dimension is the prediction,
400400
// and the # of each (label, prediction) pair equals the value of the label
401-
// _| 1 2 3
402-
// 1| 1L 1L 1L
403-
// 2| 2L 2L 2L
404-
// 3| 3L 3L 3L
401+
// ___| 1.0 2.0 3.0
402+
// 1.0| 1L 1L 1L
403+
// 2.0| 2L 2L 2L
404+
// 3.0| 3L 3L 3L
405405
val testLabels = Array(1.0, 2.0, 3.0)
406406
val labelAndPrediction = testLabels.flatMap(label => {
407407
testLabels.flatMap(pred => Seq.fill(label.toInt)((label, pred)))
@@ -437,10 +437,10 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext
437437

438438
// create a test 2D array with the count of each label & prediction combination as:
439439
// row is label and column is prediction
440-
// _| 1 2 3
441-
// 1| 2L 3L 4L
442-
// 2| 3L 4L 5L
443-
// 3| 4L 5L 6L
440+
// ___| 1.0 2.0 3.0
441+
// 1.0| 2L 3L 4L
442+
// 2.0| 3L 4L 5L
443+
// 3.0| 4L 5L 6L
444444
val testLabels = List(1.0, 2.0, 3.0)
445445
val labelAndPrediction = testLabels.flatMap(label => {
446446
testLabels.flatMap(pred => Seq.fill(label.toInt + pred.toInt)((label, pred)))
@@ -452,21 +452,21 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext
452452
outputMetrics.MisClassificationsByLabel shouldEqual
453453
Seq(
454454
MisClassificationsPerCategory(Category = 3.0, TotalCount = 15L, CorrectCount = 6L,
455-
MisClassifications = Map(2.0 -> 5L, 1.0 -> 4L)),
455+
MisClassifications = Seq(ClassCount(2.0, 5L), ClassCount(1.0, 4L))),
456456
MisClassificationsPerCategory(Category = 2.0, TotalCount = 12L, CorrectCount = 4L,
457-
MisClassifications = Map(3.0 -> 5L, 1.0 -> 3L)),
457+
MisClassifications = Seq(ClassCount(3.0, 5L), ClassCount(1.0, 3L))),
458458
MisClassificationsPerCategory(Category = 1.0, TotalCount = 9L, CorrectCount = 2L,
459-
MisClassifications = Map(3.0 -> 4L, 2.0 -> 3L))
459+
MisClassifications = Seq(ClassCount(3.0, 4L), ClassCount(2.0, 3L)))
460460
)
461461

462462
outputMetrics.MisClassificationsByPrediction shouldEqual
463463
Seq(
464464
MisClassificationsPerCategory(Category = 3.0, TotalCount = 15L, CorrectCount = 6L,
465-
MisClassifications = Map(2.0 -> 5L, 1.0 -> 4L)),
465+
MisClassifications = Seq(ClassCount(2.0, 5L), ClassCount(1.0, 4L))),
466466
MisClassificationsPerCategory(Category = 2.0, TotalCount = 12L, CorrectCount = 4L,
467-
MisClassifications = Map(3.0 -> 5L, 1.0 -> 3L)),
467+
MisClassifications = Seq(ClassCount(3.0, 5L), ClassCount(1.0, 3L))),
468468
MisClassificationsPerCategory(Category = 1.0, TotalCount = 9L, CorrectCount = 2L,
469-
MisClassifications = Map(3.0 -> 4L, 2.0 -> 3L))
469+
MisClassifications = Seq(ClassCount(3.0, 4L), ClassCount(2.0, 3L)))
470470
)
471471
}
472472
}

core/src/test/scala/com/salesforce/op/stages/impl/insights/RecordInsightsLOCOTest.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ class RecordInsightsLOCOTest extends FunSpec with TestSparkContext with RecordIn
232232
info("Each feature vector should only have either three or four non-zero entries. One each from country and " +
233233
"picklist, while currency can have either two (if it's null the currency column will be filled with the mean)" +
234234
" or just one if it's not null.")
235-
it("should pick between 1 and 4 of the features") {
236-
all(parsed.map(_.size)) should (be >= 1 and be <= 4)
235+
it("should pick between 0 and 4 of the features") {
236+
all(parsed.map(_.size)) should (be >= 0 and be <= 4)
237237
}
238238

239239
// Grab the feature vector metadata for comparison against the LOCO record insights

core/src/test/scala/com/salesforce/op/stages/impl/selector/ModelSelectorSummaryTest.scala

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ class ModelSelectorSummaryTest extends FlatSpec with TestSparkContext {
9797
correctCounts = Map(1 -> Seq(100L)), incorrectCounts = Map(2 -> Seq(200L)),
9898
noPredictionCounts = Map(3 -> Seq(300L))),
9999
TopKMetrics = MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty),
100-
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold(1, Seq(1.0), Seq(0.0, 0.5), Seq.empty),
100+
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold(1, Seq(1.0), Seq(0.0, 0.5), Seq(Seq(1L))),
101101
MisClassificationMetrics = MisClassificationMetrics(1, Seq.empty,
102102
Seq(MisClassificationsPerCategory(TotalCount = 5L, CorrectCount = 3L, Category = 1.0,
103-
MisClassifications = Map(1.0 -> 2L))))),
103+
MisClassifications = Seq(ClassCount(1.0, 2L)))))),
104104
holdoutEvaluation = None
105105
)
106106

@@ -121,19 +121,23 @@ class ModelSelectorSummaryTest extends FlatSpec with TestSparkContext {
121121
}
122122

123123
it should "not hide the root cause of JSON parsing errors" in {
124-
val evalMetrics = MultiClassificationMetrics(Precision = 0.1, Recall = 0.2, F1 = 0.3, Error = 0.4,
124+
val evalMetrics = MultiClassificationMetrics(
125+
Precision = 0.1,
126+
Recall = 0.2,
127+
F1 = 0.3,
128+
Error = 0.4,
125129
ThresholdMetrics = MulticlassThresholdMetrics(topNs = Seq(1, 2), thresholds = Seq(1.1, 1.2),
126130
correctCounts = Map(1 -> Seq(100L)), incorrectCounts = Map(2 -> Seq(200L)),
127131
noPredictionCounts = Map(3 -> Seq(300L))),
128132
TopKMetrics = MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty),
129-
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold( 2, Seq(0.1), Seq(0.1),
130-
Seq.empty),
133+
ConfusionMatrixMetrics = MulticlassConfMatrixMetricsByThreshold( 2, Seq(0.1), Seq(0.1), Seq(Seq(1L))),
131134
MisClassificationMetrics = MisClassificationMetrics(1,
132-
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Map(1.0 -> 3L))),
133-
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Map(1.0 -> 3L))))
135+
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Seq(ClassCount(1.0, 3L)))),
136+
Seq(MisClassificationsPerCategory(0.0, 5L, 5L, Seq(ClassCount(1.0, 3L)))))
134137
)
135138

136139
val evalMetricsJson = evalMetrics.toJson()
140+
println(1)
137141
val roundTripEvalMetrics = ModelSelectorSummary.evalMetFromJson(
138142
classOf[MultiClassificationMetrics].getName, evalMetricsJson).get
139143
roundTripEvalMetrics shouldBe evalMetrics

0 commit comments

Comments
 (0)