Skip to content

Commit c7ccfdd

Browse files
Merge pull request #659 from rstudio/feature/tf-keras-default
switch to tf.keras as default implementation
2 parents aae09de + 86a4e94 commit c7ccfdd

File tree

6 files changed

+25
-19
lines changed

6 files changed

+25
-19
lines changed

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11

2+
## Development version
3+
4+
- Use `tf.keras` as default implementation module.
5+
6+
27
## Keras 2.2.4 (CRAN)
38

49
- Improve handling of `timeseries_generator()` in calls to `fit_generator()`

R/package.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ use_backend <- function(backend = c("tensorflow", "cntk", "theano", "plaidml"))
7777
} else {
7878
Sys.setenv(KERAS_BACKEND = match.arg(backend))
7979
}
80+
if (backend != "tensorflow") use_implementation("keras")
8081
}
8182

8283

@@ -87,7 +88,7 @@ keras <- NULL
8788

8889
# resolve the implementaiton module (might be keras proper or might be tensorflow)
8990
implementation_module <- resolve_implementation_module()
90-
91+
9192
# if KERAS_PYTHON is defined then forward it to RETICULATE_PYTHON
9293
keras_python <- get_keras_python()
9394
if (!is.null(keras_python))
@@ -160,7 +161,7 @@ resolve_implementation_module <- function() {
160161
implementation_module
161162
}
162163

163-
get_keras_implementation <- function(default = "keras") {
164+
get_keras_implementation <- function(default = "tensorflow") {
164165
get_keras_option("KERAS_IMPLEMENTATION", default = default)
165166
}
166167

inst/python/kerastools/layer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11

22
import os
33

4-
if (os.getenv('KERAS_IMPLEMENTATION', 'keras') == 'tensorflow'):
4+
if (os.getenv('KERAS_IMPLEMENTATION', 'tensorflow') == 'keras'):
5+
from keras.engine.topology import Layer
6+
def shape_filter(shape):
7+
return shape
8+
else:
59
from tensorflow.python.keras.engine import Layer
610
def shape_filter(shape):
711
if not isinstance(shape, list):
812
return shape.as_list()
913
else:
1014
return shape
11-
else:
12-
from keras.engine.topology import Layer
13-
def shape_filter(shape):
14-
return shape
15+
1516

1617
class RLayer(Layer):
1718

inst/python/kerastools/model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22

33
import os
44

5-
if (os.getenv('KERAS_IMPLEMENTATION', 'keras') == 'tensorflow'):
5+
if (os.getenv('KERAS_IMPLEMENTATION', 'tensorflow') == 'keras'):
6+
from keras.engine import Model
7+
else:
68
try:
79
from tensorflow.python.keras.engine import Model
810
except:
911
from tensorflow.python.keras.engine.training import Model
10-
else:
11-
from keras.engine import Model
12-
13-
12+
1413
class RModel(Model):
1514

1615
def __init__(self, name = None):

inst/python/kerastools/wrapper.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import os
22

3-
if (os.getenv('KERAS_IMPLEMENTATION', 'keras') == 'tensorflow'):
3+
if (os.getenv('KERAS_IMPLEMENTATION', 'tensorflow') == 'keras'):
4+
from keras.layers import Wrapper
5+
def shape_filter(shape):
6+
return shape
7+
else:
48
from tensorflow.python.keras.layers import Wrapper
59
def shape_filter(shape):
610
if not isinstance(shape, list):
711
return shape.as_list()
812
else:
913
return shape
10-
else:
11-
from keras.layers import Wrapper
12-
def shape_filter(shape):
13-
return shape
14-
14+
1515
class RWrapper(Wrapper):
1616

1717
def __init__(self, r_build, r_call, r_compute_output_shape, **kwargs):

tests/testthat/test-model-persistence.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ test_succeeds("model can be exported to TensorFlow", {
115115

116116
export <- function() tensorflow::export_savedmodel(model, model_dir)
117117

118-
if (grepl("^tensorflow", Sys.getenv("KERAS_IMPLEMENTATION"))) {
118+
if (!grepl("^keras", Sys.getenv("KERAS_IMPLEMENTATION"))) {
119119
expect_error(export())
120120
}
121121
else {

0 commit comments

Comments
 (0)