Skip to content

Commit 9b76fa1

Browse files
committed
Modified mnist example for updated TF-Scala API
1 parent 7d54e04 commit 9b76fa1

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

scripts/mnist.sc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,16 @@
3232

3333
loss = loss >> tf.learn.ScalarSummary("Loss") // Collect loss summaries for plotting
3434
val summariesDir = java.nio.file.Paths.get((tempdir/"summaries").toString()) // Directory in which to save summaries and checkpoints
35-
val estimator = Estimator(model, Configuration(Some(summariesDir)))
36-
estimator.train(
37-
trainData, StopCriteria(maxSteps = Some(17000)),
38-
Seq(
35+
36+
val estimator = FileBasedEstimator(
37+
supervisedTrainableModelToModelFunction(model),
38+
Configuration(Some(summariesDir)),
39+
StopCriteria(maxSteps = Some(17000)),
40+
Set(
3941
SummarySaverHook(summariesDir, StepHookTrigger(100)), // Save summaries every 1000 steps
4042
CheckpointSaverHook(summariesDir, StepHookTrigger(1000))), // Save checkpoint every 1000 steps
4143
tensorBoardConfig = TensorBoardConfig(summariesDir)) // Launch TensorBoard server in the background
44+
45+
estimator.train(trainData)
46+
4247
}

0 commit comments

Comments
 (0)