Skip to content

Commit b76a0e2

Browse files
authored
Merge pull request #888 from dfalbel/test/serialize-dense-features
Can serialize a model with dense features
2 parents 799819e + 4b70056 commit b76a0e2

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

tests/testthat/test-layers.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,4 +524,42 @@ test_call_succeeds("layer_dense_features", required_version = "2.1.3", {
524524
}
525525
})
526526

527+
test_succeeds("Can serialize a model that contains dense_features", {
528+
529+
if (tensorflow::tf_version() <= "2.0")
530+
skip("TensorFlow 2.0 is required.")
531+
532+
fc <- list(tensorflow::tf$feature_column$numeric_column("mpg"))
533+
534+
535+
input <- list(mpg = layer_input(1))
536+
537+
out <- input %>%
538+
layer_dense_features(feature_columns = fc)
539+
540+
# sequential: needs to pass a list in the begining.
541+
feature_layer <- layer_dense_features(feature_columns = fc)
542+
543+
model <- keras_model_sequential(list(
544+
feature_layer,
545+
layer_dense(units = 1)
546+
))
547+
548+
model %>% compile(loss = "mae", optimizer = "adam")
549+
550+
model %>% fit(x = list(mpg = 1:10), y = 1:10, verbose = 0)
551+
552+
pred <- predict(model, list(mpg = 1:10))
553+
554+
fname <- tempfile()
555+
save_model_tf(model, fname)
556+
557+
loaded <- load_model_tf(fname)
558+
pred2 <- predict(loaded, list(mpg = 1:10))
559+
560+
expect_equal(pred, pred2)
561+
})
562+
563+
564+
527565

0 commit comments

Comments
 (0)