Skip to content

Commit f13be76

Browse files
yuanliangzhetensorflower-gardener
authored andcommitted
internal change.
PiperOrigin-RevId: 422665603
1 parent 27fb855 commit f13be76

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

official/projects/movinet/modeling/movinet_model.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,13 @@ def __init__(
8888
# Move backbone after super() call so Keras is happy
8989
self._backbone = backbone
9090

91-
def _build_network(
91+
def _build_backbone(
9292
self,
9393
backbone: tf.keras.Model,
9494
input_specs: Mapping[str, tf.keras.layers.InputSpec],
9595
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
96-
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras
97-
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
98-
"""Builds the model network.
96+
) -> Tuple[Mapping[str, Any], Any, Any]:
97+
"""Builds the backbone network and gets states and endpoints.
9998
10099
Args:
101100
backbone: the model backbone.
@@ -104,9 +103,9 @@ def _build_network(
104103
layer, will overwrite the contents of the buffer(s).
105104
106105
Returns:
107-
Inputs and outputs as a tuple. Inputs are expected to be a dict with
108-
base input and states. Outputs are expected to be a dict of endpoints
109-
and (optionally) output states.
106+
inputs: a dict of input specs.
107+
endpoints: a dict of model endpoints.
108+
states: a dict of model states.
110109
"""
111110
state_specs = state_specs if state_specs is not None else {}
112111

@@ -145,7 +144,30 @@ def _build_network(
145144
mismatched_shapes))
146145
else:
147146
endpoints, states = backbone(inputs)
147+
return inputs, endpoints, states
148148

149+
def _build_network(
150+
self,
151+
backbone: tf.keras.Model,
152+
input_specs: Mapping[str, tf.keras.layers.InputSpec],
153+
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
154+
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras
155+
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
156+
"""Builds the model network.
157+
158+
Args:
159+
backbone: the model backbone.
160+
input_specs: the model input spec to use.
161+
state_specs: a dict of states such that, if any of the keys match for a
162+
layer, will overwrite the contents of the buffer(s).
163+
164+
Returns:
165+
Inputs and outputs as a tuple. Inputs are expected to be a dict with
166+
base input and states. Outputs are expected to be a dict of endpoints
167+
and (optionally) output states.
168+
"""
169+
inputs, endpoints, states = self._build_backbone(
170+
backbone=backbone, input_specs=input_specs, state_specs=state_specs)
149171
x = endpoints['head']
150172

151173
x = movinet_layers.ClassifierHead(

official/projects/movinet/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
# Import movinet libraries to register the backbone and model into tf.vision
4747
# model garden factory.
4848
# pylint: disable=unused-import
49-
# the followings are the necessary imports.
49+
from official.projects.movinet.google.configs import movinet_google
50+
from official.projects.movinet.google.modeling import movinet_model_google
5051
from official.projects.movinet.modeling import movinet
5152
from official.projects.movinet.modeling import movinet_model
5253
# pylint: enable=unused-import

0 commit comments

Comments
 (0)