Skip to content

Commit 31bc22e

Browse files
committed
Added tensorflow linear regression toy example
1 parent 98e2d56 commit 31bc22e

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

scripts/tf_linear_regression.sc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
{
2+
import org.platanios.tensorflow.api._
3+
4+
import scala.collection.mutable.ArrayBuffer
5+
import scala.util.Random
6+
7+
val random = new Random()
8+
9+
val weight = random.nextFloat()
10+
11+
def batch(batchSize: Int): (Tensor, Tensor) = {
12+
val inputs = ArrayBuffer.empty[Float]
13+
val outputs = ArrayBuffer.empty[Float]
14+
var i = 0
15+
while (i < batchSize) {
16+
val input = random.nextFloat()
17+
inputs += input
18+
outputs += weight * input
19+
i += 1
20+
}
21+
(Tensor(inputs).reshape(Shape(-1, 1)), Tensor(outputs).reshape(Shape(-1, 1)))
22+
}
23+
24+
print("Building linear regression model.")
25+
val inputs = tf.placeholder(FLOAT32, Shape(-1, 1))
26+
val outputs = tf.placeholder(FLOAT32, Shape(-1, 1))
27+
val weights = tf.variable("weights", FLOAT32, Shape(1, 1), tf.ZerosInitializer)
28+
val predictions = tf.matmul(inputs, weights)
29+
val loss = tf.sum(tf.square(predictions - outputs))
30+
val trainOp = tf.train.AdaGrad(1.0).minimize(loss)
31+
32+
println("Training the linear regression model.")
33+
val session = Session()
34+
session.run(targets = tf.globalVariablesInitializer())
35+
for (i <- 0 to 50) {
36+
val trainBatch = batch(10000)
37+
val feeds = Map(inputs -> trainBatch._1, outputs -> trainBatch._2)
38+
val trainLoss = session.run(feeds = feeds, fetches = loss, targets = trainOp)
39+
if (i % 1 == 0)
40+
println(s"Train loss at iteration $i = ${trainLoss.scalar} " +
41+
s"(weight = ${session.run(fetches = weights.value).scalar})")
42+
}
43+
44+
println(s"Trained weight value: ${session.run(fetches = weights.value).scalar}")
45+
println(s"True weight value: $weight")
46+
}

0 commit comments

Comments
 (0)