Skip to content

Commit e3f474e

Browse files
committed
evaluation: Added ClassificationMetricsTF class
1 parent ca96dc7 commit e3f474e

File tree

3 files changed

+155
-77
lines changed

3 files changed

+155
-77
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
* */
19+
package io.github.mandar2812.dynaml.evaluation
20+
21+
import org.platanios.tensorflow.api.Tensor
22+
23+
/**
24+
* Evaluates classification models, by calculating confusion matrices.
25+
*
26+
* @param num_classes The number of classes in the classification task
27+
* @param preds Predictions expressed as class probabilities/one-hot vectors
28+
* @param targets Class labels expressed as one-hot vectors
29+
* */
30+
class ClassificationMetricsTF(num_classes: Int, preds: Tensor, targets: Tensor) extends
31+
MetricsTF(Seq("Class Fidelity Score"), preds, targets) {
32+
33+
val confusion_matrix: Tensor = targets.transpose().matmul(preds)
34+
35+
val class_score: Tensor = {
36+
val d = confusion_matrix.trace
37+
val s = confusion_matrix.sum()
38+
d.divide(s)
39+
}
40+
41+
override protected def run(): Tensor = class_score
42+
43+
override def print(): Unit = {
44+
println("\nClassification Model Performance: "+name)
45+
scala.Predef.print("Number of classes: ")
46+
pprint.pprintln(num_classes)
47+
println("============================")
48+
println()
49+
50+
println("Confusion Matrix: ")
51+
println(confusion_matrix.summarize(maxEntries = confusion_matrix.size.toInt))
52+
println()
53+
54+
scala.Predef.print("Class Prediction Fidelity Score: ")
55+
pprint.pprintln(class_score.scalar.asInstanceOf[Double])
56+
}
57+
}

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/evaluation/MetricsTF.scala

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ under the License.
1818
* */
1919
package io.github.mandar2812.dynaml.evaluation
2020

21-
import com.quantifind.charts.Highcharts.{regression, title, xAxis, yAxis}
22-
import io.github.mandar2812.dynaml.tensorflow.dtf
2321
import org.platanios.tensorflow.api.{---, ::, Tensor}
2422

2523

@@ -73,78 +71,3 @@ abstract class MetricsTF(val names: Seq[String], val preds: Tensor, val targets:
7371

7472

7573
}
76-
77-
/**
78-
* Implements a common use for Regression Task Evaluators.
79-
* */
80-
class RegressionMetricsTF(preds: Tensor, targets: Tensor)
81-
extends MetricsTF(Seq("RMSE", "MAE", "Coefficient of Corr.", "Yield"), preds, targets) {
82-
83-
private val num_outputs = if (preds.shape.toTensor().size == 1) 1 else preds.shape(1)
84-
85-
private lazy val (_ , rmse , mae, corr) = RegressionMetricsTF.calculate(preds, targets)
86-
87-
private lazy val modelyield =
88-
(preds.max(axes = 0) - preds.min(axes = 0)).divide(targets.max(axes = 0) - targets.min(axes = 0))
89-
90-
override protected def run(): Tensor = dtf.stack(Seq(rmse, mae, corr, modelyield))
91-
92-
override def generatePlots(): Unit = {
93-
println("Generating Plot of Fit for each target")
94-
95-
if(num_outputs == 1) {
96-
val (pr, tar) = (
97-
scoresAndLabels._1.entriesIterator.map(_.asInstanceOf[Double]),
98-
scoresAndLabels._2.entriesIterator.map(_.asInstanceOf[Double]))
99-
100-
regression(pr.zip(tar).toSeq)
101-
102-
title("Goodness of fit: "+name)
103-
xAxis("Predicted "+name)
104-
yAxis("Actual "+name)
105-
106-
} else {
107-
(0 until num_outputs).foreach(output => {
108-
val (pr, tar) = (
109-
scoresAndLabels._1(::, output).entriesIterator.map(_.asInstanceOf[Double]),
110-
scoresAndLabels._2(::, output).entriesIterator.map(_.asInstanceOf[Double]))
111-
112-
regression(pr.zip(tar).toSeq)
113-
})
114-
}
115-
}
116-
}
117-
118-
/**
119-
* Implements core logic of [[RegressionMetricsTF]]
120-
* */
121-
object RegressionMetricsTF {
122-
123-
protected def calculate(preds: Tensor, targets: Tensor): (Tensor, Tensor, Tensor, Tensor) = {
124-
val error = targets.subtract(preds)
125-
126-
println("Shape of error tensor: "+error.shape.toString()+"\n")
127-
128-
val num_instances = error.shape(0)
129-
val rmse = error.square.mean(axes = 0).sqrt
130-
131-
val mae = error.abs.mean(axes = 0)
132-
133-
val corr = {
134-
135-
val mean_preds = preds.mean(axes = 0)
136-
137-
val mean_targets = targets.mean(axes = 0)
138-
139-
val preds_c = preds.subtract(dtf.stack(Seq.fill(num_instances)(mean_preds)))
140-
141-
val targets_c = targets.subtract(dtf.stack(Seq.fill(num_instances)(mean_targets)))
142-
143-
val (sigma_t, sigma_p) = (targets_c.square.mean().sqrt, preds_c.square.mean().sqrt)
144-
145-
preds_c.multiply(targets_c).mean(axes = 0).divide(sigma_t.multiply(sigma_p))
146-
}
147-
148-
(error, rmse, mae, corr)
149-
}
150-
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
* */
19+
package io.github.mandar2812.dynaml.evaluation
20+
21+
import com.quantifind.charts.Highcharts.{regression, title, xAxis, yAxis}
22+
import io.github.mandar2812.dynaml.tensorflow.dtf
23+
import org.platanios.tensorflow.api.{::, Tensor}
24+
25+
/**
26+
* Implements a common use for Regression Task Evaluators.
27+
* */
28+
class RegressionMetricsTF(preds: Tensor, targets: Tensor)
29+
extends MetricsTF(Seq("RMSE", "MAE", "Coefficient of Corr.", "Yield"), preds, targets) {
30+
31+
private val num_outputs = if (preds.shape.toTensor().size == 1) 1 else preds.shape(1)
32+
33+
private lazy val (_ , rmse , mae, corr) = RegressionMetricsTF.calculate(preds, targets)
34+
35+
private lazy val modelyield =
36+
(preds.max(axes = 0) - preds.min(axes = 0)).divide(targets.max(axes = 0) - targets.min(axes = 0))
37+
38+
override protected def run(): Tensor = dtf.stack(Seq(rmse, mae, corr, modelyield))
39+
40+
override def generatePlots(): Unit = {
41+
println("Generating Plot of Fit for each target")
42+
43+
if(num_outputs == 1) {
44+
val (pr, tar) = (
45+
scoresAndLabels._1.entriesIterator.map(_.asInstanceOf[Double]),
46+
scoresAndLabels._2.entriesIterator.map(_.asInstanceOf[Double]))
47+
48+
regression(pr.zip(tar).toSeq)
49+
50+
title("Goodness of fit: "+name)
51+
xAxis("Predicted "+name)
52+
yAxis("Actual "+name)
53+
54+
} else {
55+
(0 until num_outputs).foreach(output => {
56+
val (pr, tar) = (
57+
scoresAndLabels._1(::, output).entriesIterator.map(_.asInstanceOf[Double]),
58+
scoresAndLabels._2(::, output).entriesIterator.map(_.asInstanceOf[Double]))
59+
60+
regression(pr.zip(tar).toSeq)
61+
})
62+
}
63+
}
64+
}
65+
66+
/**
67+
* Implements core logic of [[RegressionMetricsTF]]
68+
* */
69+
object RegressionMetricsTF {
70+
71+
protected def calculate(preds: Tensor, targets: Tensor): (Tensor, Tensor, Tensor, Tensor) = {
72+
val error = targets.subtract(preds)
73+
74+
println("Shape of error tensor: "+error.shape.toString()+"\n")
75+
76+
val num_instances = error.shape(0)
77+
val rmse = error.square.mean(axes = 0).sqrt
78+
79+
val mae = error.abs.mean(axes = 0)
80+
81+
val corr = {
82+
83+
val mean_preds = preds.mean(axes = 0)
84+
85+
val mean_targets = targets.mean(axes = 0)
86+
87+
val preds_c = preds.subtract(dtf.stack(Seq.fill(num_instances)(mean_preds)))
88+
89+
val targets_c = targets.subtract(dtf.stack(Seq.fill(num_instances)(mean_targets)))
90+
91+
val (sigma_t, sigma_p) = (targets_c.square.mean().sqrt, preds_c.square.mean().sqrt)
92+
93+
preds_c.multiply(targets_c).mean(axes = 0).divide(sigma_t.multiply(sigma_p))
94+
}
95+
96+
(error, rmse, mae, corr)
97+
}
98+
}

0 commit comments

Comments
 (0)