|
| 1 | +{ |
| 2 | + import org.platanios.tensorflow.api._ |
| 3 | + |
| 4 | + import scala.collection.mutable.ArrayBuffer |
| 5 | + import scala.util.Random |
| 6 | + |
| 7 | + val random = new Random() |
| 8 | + |
| 9 | + val weight = random.nextFloat() |
| 10 | + |
| 11 | + def batch(batchSize: Int): (Tensor, Tensor) = { |
| 12 | + val inputs = ArrayBuffer.empty[Float] |
| 13 | + val outputs = ArrayBuffer.empty[Float] |
| 14 | + var i = 0 |
| 15 | + while (i < batchSize) { |
| 16 | + val input = random.nextFloat() |
| 17 | + inputs += input |
| 18 | + outputs += weight * input |
| 19 | + i += 1 |
| 20 | + } |
| 21 | + (Tensor(inputs).reshape(Shape(-1, 1)), Tensor(outputs).reshape(Shape(-1, 1))) |
| 22 | + } |
| 23 | + |
| 24 | + print("Building linear regression model.") |
| 25 | + val inputs = tf.placeholder(FLOAT32, Shape(-1, 1)) |
| 26 | + val outputs = tf.placeholder(FLOAT32, Shape(-1, 1)) |
| 27 | + val weights = tf.variable("weights", FLOAT32, Shape(1, 1), tf.ZerosInitializer) |
| 28 | + val predictions = tf.matmul(inputs, weights) |
| 29 | + val loss = tf.sum(tf.square(predictions - outputs)) |
| 30 | + val trainOp = tf.train.AdaGrad(1.0).minimize(loss) |
| 31 | + |
| 32 | + println("Training the linear regression model.") |
| 33 | + val session = Session() |
| 34 | + session.run(targets = tf.globalVariablesInitializer()) |
| 35 | + for (i <- 0 to 50) { |
| 36 | + val trainBatch = batch(10000) |
| 37 | + val feeds = Map(inputs -> trainBatch._1, outputs -> trainBatch._2) |
| 38 | + val trainLoss = session.run(feeds = feeds, fetches = loss, targets = trainOp) |
| 39 | + if (i % 1 == 0) |
| 40 | + println(s"Train loss at iteration $i = ${trainLoss.scalar} " + |
| 41 | + s"(weight = ${session.run(fetches = weights.value).scalar})") |
| 42 | + } |
| 43 | + |
| 44 | + println(s"Trained weight value: ${session.run(fetches = weights.value).scalar}") |
| 45 | + println(s"True weight value: $weight") |
| 46 | +} |
0 commit comments