Skip to content

Commit 5d814fb

Browse files
committed
dynaml-tensorflow: Improvements to tensorflow wrapper
- Added parameters to `dtflearn.build_tf_model()` method. - Added common stop criteria in `dtflearn` - Updated `mnist.sc` and `cifar.sc` Signed-off-by: mandar2812 <[email protected]>
1 parent 130fd53 commit 5d814fb

File tree

3 files changed

+38
-15
lines changed

3 files changed

+38
-15
lines changed

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/tensorflow/package.scala

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import io.github.mandar2812.dynaml.tensorflow.utils._
2626
import io.github.mandar2812.dynaml.tensorflow.layers._
2727
import org.platanios.tensorflow.api._
2828
import org.platanios.tensorflow.api.core.Shape
29+
import org.platanios.tensorflow.api.learn.StopCriteria
2930
import org.platanios.tensorflow.api.learn.layers.{Activation, Input, Layer}
3031
import org.platanios.tensorflow.api.ops.NN.SamePadding
3132
import org.platanios.tensorflow.api.ops.io.data.Dataset
@@ -303,15 +304,27 @@ package object tensorflow {
303304

304305
type TFDATA = Dataset[(Tensor, Tensor), (Output, Output), (DataType, DataType), (Shape, Shape)]
305306

306-
val Phi: layers.Phi.type = layers.Phi
307+
val Phi: layers.Phi.type = layers.Phi
308+
val Tanh: layers.Tanh.type = layers.Tanh
307309

308-
val Tanh: layers.Tanh.type = layers.Tanh
310+
val ctrnn: layers.FiniteHorizonCTRNN.type = layers.FiniteHorizonCTRNN
311+
val dctrnn: layers.DynamicTimeStepCTRNN.type = layers.DynamicTimeStepCTRNN
312+
val ts_linear: layers.FiniteHorizonLinear.type = layers.FiniteHorizonLinear
309313

310-
val ctrnn: layers.FiniteHorizonCTRNN.type = layers.FiniteHorizonCTRNN
314+
/**
315+
* Stop after a specified maximum number of iterations has been reached.
316+
* */
317+
val max_iter_stop: Long => StopCriteria = (n: Long) => tf.learn.StopCriteria(maxSteps = Some(n))
311318

312-
val dctrnn: layers.DynamicTimeStepCTRNN.type = layers.DynamicTimeStepCTRNN
319+
/**
320+
* Stop after the change in the loss function falls below a specified threshold.
321+
* */
322+
val abs_loss_change_stop: Double => StopCriteria = (d: Double) => tf.learn.StopCriteria(absLossChangeTol = Some(d))
313323

314-
val ts_linear: layers.FiniteHorizonLinear.type = layers.FiniteHorizonLinear
324+
/**
325+
* Stop after the relative change in the loss function falls below a specified threshold.
326+
* */
327+
val rel_loss_change_stop: Double => StopCriteria = (d: Double) => tf.learn.StopCriteria(relLossChangeTol = Some(d))
315328

316329
/**
317330
* Constructs a feed-forward layer.
@@ -493,7 +506,12 @@ package object tensorflow {
493506
* @param summariesDir A filesystem path of type [[java.nio.file.Path]], which
494507
* determines where the intermediate model parameters/checkpoints
495508
* will be written.
496-
* @param iterations The maximum number of iterations.
509+
* @param stopCriteria The stopping criteria for training, for examples see
510+
* [[max_iter_stop]], [[abs_loss_change_stop]] and [[rel_loss_change_stop]]
511+
*
512+
* @param stepRateFreq The frequency at which to log the step rate (expressed as number of iterations/sec).
513+
* @param summarySaveFreq The frequency at which to log the loss summary.
514+
* @param checkPointFreq The frequency at which to log the model parameters.
497515
* @param training_data A training data set, as an instance of [[Dataset]].
498516
*
499517
* @return A [[Tuple2]] containing the model and estimator.
@@ -506,7 +524,10 @@ package object tensorflow {
506524
loss: Layer[(Output, Output), Output],
507525
optimizer: Optimizer,
508526
summariesDir: java.nio.file.Path,
509-
iterations: Int)(
527+
stopCriteria: StopCriteria,
528+
stepRateFreq: Int = 5000,
529+
summarySaveFreq: Int = 5000,
530+
checkPointFreq: Int = 5000)(
510531
training_data: TFDATA) = {
511532

512533
val (model, estimator) = tf.createWith(graph = Graph()) {
@@ -520,16 +541,16 @@ package object tensorflow {
520541
val estimator = tf.learn.FileBasedEstimator(
521542
model,
522543
tf.learn.Configuration(Some(summariesDir)),
523-
tf.learn.StopCriteria(maxSteps = Some(iterations)),
544+
stopCriteria,
524545
Set(
525546
tf.learn.StepRateLogger(
526547
log = false, summaryDir = summariesDir,
527-
trigger = tf.learn.StepHookTrigger(5000)),
528-
tf.learn.SummarySaver(summariesDir, tf.learn.StepHookTrigger(5000)),
529-
tf.learn.CheckpointSaver(summariesDir, tf.learn.StepHookTrigger(5000))),
530-
tensorBoardConfig = tf.learn.TensorBoardConfig(summariesDir, reloadInterval = 5000))
548+
trigger = tf.learn.StepHookTrigger(stepRateFreq)),
549+
tf.learn.SummarySaver(summariesDir, tf.learn.StepHookTrigger(summarySaveFreq)),
550+
tf.learn.CheckpointSaver(summariesDir, tf.learn.StepHookTrigger(checkPointFreq))),
551+
tensorBoardConfig = tf.learn.TensorBoardConfig(summariesDir, reloadInterval = checkPointFreq))
531552

532-
estimator.train(() => training_data, tf.learn.StopCriteria(maxSteps = Some(iterations)))
553+
estimator.train(() => training_data)
533554

534555
(model, estimator)
535556
}

scripts/cifar.sc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@
4848

4949
val (model, estimator) = dtflearn.build_tf_model(
5050
architecture, input, trainInput, trainingInputLayer,
51-
loss, optimizer, summariesDir, 1000)(trainData)
51+
loss, optimizer, summariesDir, dtflearn.max_iter_stop(1000),
52+
100, 100, 100)(trainData)
5253

5354
def accuracy(images: Tensor, labels: Tensor): Float = {
5455
val predictions = estimator.infer(() => images)

scripts/mnist.sc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747

4848
val (model, estimator) = dtflearn.build_tf_model(
4949
architecture, input, trainInput, trainingInputLayer,
50-
loss, optimizer, summariesDir, 1000)(trainData)
50+
loss, optimizer, summariesDir, dtflearn.max_iter_stop(1000),
51+
100, 100, 100)(trainData)
5152

5253
def accuracy(images: Tensor, labels: Tensor): Float = {
5354
val predictions = estimator.infer(() => images)

0 commit comments

Comments
 (0)