@@ -26,6 +26,7 @@ import io.github.mandar2812.dynaml.tensorflow.utils._
2626import io .github .mandar2812 .dynaml .tensorflow .layers ._
2727import org .platanios .tensorflow .api ._
2828import org .platanios .tensorflow .api .core .Shape
29+ import org .platanios .tensorflow .api .learn .StopCriteria
2930import org .platanios .tensorflow .api .learn .layers .{Activation , Input , Layer }
3031import org .platanios .tensorflow .api .ops .NN .SamePadding
3132import org .platanios .tensorflow .api .ops .io .data .Dataset
@@ -303,15 +304,27 @@ package object tensorflow {
303304
304305 type TFDATA = Dataset [(Tensor , Tensor ), (Output , Output ), (DataType , DataType ), (Shape , Shape )]
305306
306- val Phi : layers.Phi .type = layers.Phi
307+ val Phi : layers.Phi .type = layers.Phi
308+ val Tanh : layers.Tanh .type = layers.Tanh
307309
308- val Tanh : layers.Tanh .type = layers.Tanh
310+ val ctrnn : layers.FiniteHorizonCTRNN .type = layers.FiniteHorizonCTRNN
311+ val dctrnn : layers.DynamicTimeStepCTRNN .type = layers.DynamicTimeStepCTRNN
312+ val ts_linear : layers.FiniteHorizonLinear .type = layers.FiniteHorizonLinear
309313
310- val ctrnn : layers.FiniteHorizonCTRNN .type = layers.FiniteHorizonCTRNN
314+ /**
315+ * Stop after a specified maximum number of iterations has been reached.
316+ * */
317+ val max_iter_stop : Long => StopCriteria = (n : Long ) => tf.learn.StopCriteria (maxSteps = Some (n))
311318
312- val dctrnn : layers.DynamicTimeStepCTRNN .type = layers.DynamicTimeStepCTRNN
319+ /**
320+ * Stop after the change in the loss function falls below a specified threshold.
321+ * */
322+ val abs_loss_change_stop : Double => StopCriteria = (d : Double ) => tf.learn.StopCriteria (absLossChangeTol = Some (d))
313323
314- val ts_linear : layers.FiniteHorizonLinear .type = layers.FiniteHorizonLinear
324+ /**
325+ * Stop after the relative change in the loss function falls below a specified threshold.
326+ * */
327+ val rel_loss_change_stop : Double => StopCriteria = (d : Double ) => tf.learn.StopCriteria (relLossChangeTol = Some (d))
315328
316329 /**
317330 * Constructs a feed-forward layer.
@@ -493,7 +506,12 @@ package object tensorflow {
493506 * @param summariesDir A filesystem path of type [[java.nio.file.Path ]], which
494507 * determines where the intermediate model parameters/checkpoints
495508 * will be written.
496- * @param iterations The maximum number of iterations.
509+ * @param stopCriteria The stopping criteria for training, for examples see
510+ * [[max_iter_stop ]], [[abs_loss_change_stop ]] and [[rel_loss_change_stop ]]
511+ *
512+ * @param stepRateFreq The frequency at which to log the step rate (expressed as number of iterations/sec).
513+ * @param summarySaveFreq The frequency at which to log the loss summary.
514+ * @param checkPointFreq The frequency at which to log the model parameters.
497515 * @param training_data A training data set, as an instance of [[Dataset ]].
498516 *
499517 * @return A [[Tuple2 ]] containing the model and estimator.
@@ -506,7 +524,10 @@ package object tensorflow {
506524 loss : Layer [(Output , Output ), Output ],
507525 optimizer : Optimizer ,
508526 summariesDir : java.nio.file.Path ,
509- iterations : Int )(
527+ stopCriteria : StopCriteria ,
528+ stepRateFreq : Int = 5000 ,
529+ summarySaveFreq : Int = 5000 ,
530+ checkPointFreq : Int = 5000 )(
510531 training_data : TFDATA ) = {
511532
512533 val (model, estimator) = tf.createWith(graph = Graph ()) {
@@ -520,16 +541,16 @@ package object tensorflow {
520541 val estimator = tf.learn.FileBasedEstimator (
521542 model,
522543 tf.learn.Configuration (Some (summariesDir)),
523- tf.learn. StopCriteria (maxSteps = Some (iterations)) ,
544+ stopCriteria ,
524545 Set (
525546 tf.learn.StepRateLogger (
526547 log = false , summaryDir = summariesDir,
527- trigger = tf.learn.StepHookTrigger (5000 )),
528- tf.learn.SummarySaver (summariesDir, tf.learn.StepHookTrigger (5000 )),
529- tf.learn.CheckpointSaver (summariesDir, tf.learn.StepHookTrigger (5000 ))),
530- tensorBoardConfig = tf.learn.TensorBoardConfig (summariesDir, reloadInterval = 5000 ))
548+ trigger = tf.learn.StepHookTrigger (stepRateFreq )),
549+ tf.learn.SummarySaver (summariesDir, tf.learn.StepHookTrigger (summarySaveFreq )),
550+ tf.learn.CheckpointSaver (summariesDir, tf.learn.StepHookTrigger (checkPointFreq ))),
551+ tensorBoardConfig = tf.learn.TensorBoardConfig (summariesDir, reloadInterval = checkPointFreq ))
531552
532- estimator.train(() => training_data, tf.learn. StopCriteria (maxSteps = Some (iterations)) )
553+ estimator.train(() => training_data)
533554
534555 (model, estimator)
535556 }
0 commit comments