Skip to content

Commit 045472b

Browse files
author
Sigrid Keydana
committed
don't set KERAS_IMPLEMENTATION if default
1 parent b4afc0c commit 045472b

File tree

4 files changed

+17
-19
lines changed

4 files changed

+17
-19
lines changed

R/package.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,7 @@ keras <- NULL
8888

8989
# resolve the implementaiton module (might be keras proper or might be tensorflow)
9090
implementation_module <- resolve_implementation_module()
91-
if (implementation_module == "tensorflow.python.keras")
92-
Sys.setenv(KERAS_IMPLEMENTATION = "tensorflow")
93-
91+
9492
# if KERAS_PYTHON is defined then forward it to RETICULATE_PYTHON
9593
keras_python <- get_keras_python()
9694
if (!is.null(keras_python))

inst/python/kerastools/layer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11

22
import os
33

4-
if (os.getenv('KERAS_IMPLEMENTATION', 'keras') == 'tensorflow'):
4+
if (os.getenv('KERAS_IMPLEMENTATION', 'keras') == 'keras'):
5+
from keras.engine.topology import Layer
6+
def shape_filter(shape):
7+
return shape
8+
else:
59
from tensorflow.python.keras.engine import Layer
610
def shape_filter(shape):
711
if not isinstance(shape, list):
812
return shape.as_list()
913
else:
1014
return shape
11-
else:
12-
from keras.engine.topology import Layer
13-
def shape_filter(shape):
14-
return shape
15+
1516

1617
class RLayer(Layer):
1718

inst/python/kerastools/model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22

33
import os
44

5-
if (os.getenv('KERAS_IMPLEMENTATION', 'keras') == 'tensorflow'):
5+
if (os.getenv('KERAS_IMPLEMENTATION', 'keras') == 'keras'):
6+
from keras.engine import Model
7+
else:
68
try:
79
from tensorflow.python.keras.engine import Model
810
except:
911
from tensorflow.python.keras.engine.training import Model
10-
else:
11-
from keras.engine import Model
12-
13-
12+
1413
class RModel(Model):
1514

1615
def __init__(self, name = None):

inst/python/kerastools/wrapper.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import os
22

3-
if (os.getenv('KERAS_IMPLEMENTATION', 'keras') == 'tensorflow'):
3+
if (os.getenv('KERAS_IMPLEMENTATION', 'keras') == 'keras'):
4+
from keras.layers import Wrapper
5+
def shape_filter(shape):
6+
return shape
7+
else:
48
from tensorflow.python.keras.layers import Wrapper
59
def shape_filter(shape):
610
if not isinstance(shape, list):
711
return shape.as_list()
812
else:
913
return shape
10-
else:
11-
from keras.layers import Wrapper
12-
def shape_filter(shape):
13-
return shape
14-
14+
1515
class RWrapper(Wrapper):
1616

1717
def __init__(self, r_build, r_call, r_compute_output_shape, **kwargs):

0 commit comments

Comments
 (0)