@@ -88,14 +88,13 @@ def __init__(
88
88
# Move backbone after super() call so Keras is happy
89
89
self ._backbone = backbone
90
90
91
- def _build_network (
91
+ def _build_backbone (
92
92
self ,
93
93
backbone : tf .keras .Model ,
94
94
input_specs : Mapping [str , tf .keras .layers .InputSpec ],
95
95
state_specs : Optional [Mapping [str , tf .keras .layers .InputSpec ]] = None ,
96
- ) -> Tuple [Mapping [str , tf .keras .Input ], Union [Tuple [Mapping [ # pytype: disable=invalid-annotation # typed-keras
97
- str , tf .Tensor ], Mapping [str , tf .Tensor ]], Mapping [str , tf .Tensor ]]]:
98
- """Builds the model network.
96
+ ) -> Tuple [Mapping [str , Any ], Any , Any ]:
97
+ """Builds the backbone network and gets states and endpoints.
99
98
100
99
Args:
101
100
backbone: the model backbone.
@@ -104,9 +103,9 @@ def _build_network(
104
103
layer, will overwrite the contents of the buffer(s).
105
104
106
105
Returns:
107
- Inputs and outputs as a tuple. Inputs are expected to be a dict with
108
- base input and states. Outputs are expected to be a dict of endpoints
109
- and (optionally) output states.
106
+ inputs: a dict of input specs.
107
+ endpoints: a dict of model endpoints.
108
+ states: a dict of model states.
110
109
"""
111
110
state_specs = state_specs if state_specs is not None else {}
112
111
@@ -145,7 +144,30 @@ def _build_network(
145
144
mismatched_shapes ))
146
145
else :
147
146
endpoints , states = backbone (inputs )
147
+ return inputs , endpoints , states
148
148
149
+ def _build_network (
150
+ self ,
151
+ backbone : tf .keras .Model ,
152
+ input_specs : Mapping [str , tf .keras .layers .InputSpec ],
153
+ state_specs : Optional [Mapping [str , tf .keras .layers .InputSpec ]] = None ,
154
+ ) -> Tuple [Mapping [str , tf .keras .Input ], Union [Tuple [Mapping [ # pytype: disable=invalid-annotation # typed-keras
155
+ str , tf .Tensor ], Mapping [str , tf .Tensor ]], Mapping [str , tf .Tensor ]]]:
156
+ """Builds the model network.
157
+
158
+ Args:
159
+ backbone: the model backbone.
160
+ input_specs: the model input spec to use.
161
+ state_specs: a dict of states such that, if any of the keys match for a
162
+ layer, will overwrite the contents of the buffer(s).
163
+
164
+ Returns:
165
+ Inputs and outputs as a tuple. Inputs are expected to be a dict with
166
+ base input and states. Outputs are expected to be a dict of endpoints
167
+ and (optionally) output states.
168
+ """
169
+ inputs , endpoints , states = self ._build_backbone (
170
+ backbone = backbone , input_specs = input_specs , state_specs = state_specs )
149
171
x = endpoints ['head' ]
150
172
151
173
x = movinet_layers .ClassifierHead (
0 commit comments