Skip to content

Commit 13ae6cf

Browse files
committed
py_to_r method for keras SharedObjectConfig
1 parent ab138a4 commit 13ae6cf

File tree

4 files changed

+23
-4
lines changed

4 files changed

+23
-4
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ S3method(print,keras_training_history)
1414
S3method(print,kerastools.model.RModel)
1515
S3method(print,py_R6ClassGenerator)
1616
S3method(py_str,keras.engine.training.Model)
17+
S3method(py_to_r,keras.utils.generic_utils.SharedObjectConfig)
1718
S3method(py_to_r_wrapper,keras.engine.base_layer.Layer)
1819
S3method(py_to_r_wrapper,keras.engine.training.Model)
1920
S3method(py_to_r_wrapper,kerastools.model.RModel)

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
layer_add(block_1_output)
1313
```
1414

15+
- `model$get_config()` method now returns an R object that can be safely serialized
16+
to rds.
17+
1518
# keras 2.9.0
1619

1720
- New functions for constructing custom keras subclasses:

R/layer-methods.R

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,27 @@ get_config <- function(object) {
3838
config
3939
}
4040

41+
# model$get_config() returns a nested object of list/dicts, but the contract
42+
# is that the object is serializable to json, which means that once
43+
# reticulate conversion rules are redone, the config will be guaranteed to be safe
44+
# to convert to R and back.
45+
46+
# this py_to_r method is so that model$get_config() can return a pure R object.
47+
# A keras SharedObjectConfig is a keras dictionary
48+
#' @export
49+
py_to_r.keras.utils.generic_utils.SharedObjectConfig <- function(x) {
50+
import_builtins()$dict(x)
51+
}
52+
4153

4254
#' @rdname get_config
4355
#' @export
44-
from_config <- function(config) {
45-
class <- attr(config, "config_class")
46-
class$from_config(config)
56+
from_config <- function(config, custom_objects = NULL) {
57+
class <- attr(config, "config_class") %||% keras$Model
58+
args <- list(config)
59+
if(length(custom_objects))
60+
args[[2L]] <- objects_with_py_function_names(custom_objects)
61+
do.call(class$from_config, args)
4762
}
4863

4964
#' Layer/Model weights as R arrays

man/get_config.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)