Skip to content

Commit 975a3a8

Browse files
committed
evaluation: Fix to TF Regression Metrics' plot method
1 parent 8206bf3 commit 975a3a8

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/evaluation/RegressionMetricsTF.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ class RegressionMetricsTF(preds: Tensor, targets: Tensor)
4242

4343
if(num_outputs == 1) {
4444
val (pr, tar) = (
45-
scoresAndLabels._1.entriesIterator.map(_.asInstanceOf[Double]),
46-
scoresAndLabels._2.entriesIterator.map(_.asInstanceOf[Double]))
45+
scoresAndLabels._1.entriesIterator.map(_.asInstanceOf[Float]),
46+
scoresAndLabels._2.entriesIterator.map(_.asInstanceOf[Float]))
4747

4848
regression(pr.zip(tar).toSeq)
4949

@@ -54,8 +54,8 @@ class RegressionMetricsTF(preds: Tensor, targets: Tensor)
5454
} else {
5555
(0 until num_outputs).foreach(output => {
5656
val (pr, tar) = (
57-
scoresAndLabels._1(::, output).entriesIterator.map(_.asInstanceOf[Double]),
58-
scoresAndLabels._2(::, output).entriesIterator.map(_.asInstanceOf[Double]))
57+
scoresAndLabels._1(::, output).entriesIterator.map(_.asInstanceOf[Float]),
58+
scoresAndLabels._2(::, output).entriesIterator.map(_.asInstanceOf[Float]))
5959

6060
regression(pr.zip(tar).toSeq)
6161
})

0 commit comments

Comments
 (0)