Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion classification_models/keras.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import keras
# import keras
from tensorflow import keras
from .models_factory import ModelsFactory


Expand Down
21 changes: 16 additions & 5 deletions classification_models/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._common_blocks import ChannelSE
from .. import get_submodules_from_kwargs
from ..weights import load_model_weights
from tensorflow.keras.utils import get_source_inputs

backend = None
layers = None
Expand Down Expand Up @@ -209,10 +210,16 @@ def ResNet(model_params, input_shape=None, input_tensor=None, include_top=True,
if input_tensor is None:
img_input = layers.Input(shape=input_shape, name='data')
else:
if not backend.is_keras_tensor(input_tensor):
img_input = layers.Input(tensor=input_tensor, shape=input_shape)
else:
img_input = input_tensor
""" Commented to solve following error:
ValueError: Unexpectedly found an instance of type
`<class 'tensorflow.python.keras.engine.keras_tensor.KerasTensor'>`.
Expected a symbolic tensor instance.
"""
# if not backend.is_keras_tensor(input_tensor):
# img_input = layers.Input(tensor=input_tensor, shape=input_shape)
# else:
# img_input = input_tensor
img_input = input_tensor

# choose residual block type
ResidualBlock = model_params.residual_block
Expand Down Expand Up @@ -266,7 +273,11 @@ def ResNet(model_params, input_shape=None, input_tensor=None, include_top=True,

# Ensure that the model takes into account any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = keras_utils.get_source_inputs(input_tensor)
""" Modified to solve following error:
module 'keras.utils' has no attribute 'get_source_inputs'
"""
# inputs = keras_utils.get_source_inputs(input_tensor)
inputs = get_source_inputs(input_tensor)
else:
inputs = img_input

Expand Down
7 changes: 6 additions & 1 deletion classification_models/weights.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import get_submodules_from_kwargs
from tensorflow.keras.utils import get_file

__all__ = ['load_model_weights']

Expand All @@ -22,7 +23,11 @@ def load_model_weights(model, model_name, dataset, classes, include_top, **kwarg
raise ValueError('If using `weights` and `include_top`'
' as true, `classes` should be {}'.format(weights['classes']))

weights_path = keras_utils.get_file(
""" Modified to solve following error:
module 'keras.utils' has no attribute 'get_file'
"""
# weights_path = keras_utils.get_file(
weights_path = get_file(
weights['name'],
weights['url'],
cache_subdir='models',
Expand Down