@@ -103,7 +103,12 @@ def _get_producers(self, layer):
103
103
return producers
104
104
105
105
def _get_consumers (self , layer ):
106
- return [node .outbound_layer for node in layer ._outbound_nodes ]
106
+
107
+ def unpack (layer ):
108
+ return (unpack (layer .layers [0 ])
109
+ if isinstance (layer , tf .keras .Sequential ) else layer )
110
+
111
+ return [unpack (node .outbound_layer ) for node in layer ._outbound_nodes ]
107
112
108
113
def _lookup_layers (self , source_layers , stop_fn , next_fn ):
109
114
"""Traverses the model and returns layers satisfying `stop_fn` criteria."""
@@ -127,9 +132,15 @@ def _lookup_layers(self, source_layers, stop_fn, next_fn):
127
132
128
133
def _start_layer_stop_fn (self , layer ):
129
134
"""Determines whether the layer starts a subgraph of sparse inference."""
130
- return (isinstance (layer , layers .Conv2D ) and hasattr (layer , 'kernel' ) and
131
- layer .kernel .shape [:3 ] == (3 , 3 , 3 ) and layer .strides == (2 , 2 ) and
132
- layer .padding .lower () == 'valid' )
135
+ if isinstance (layer , layers .Conv2D ):
136
+ producers = self ._get_producers (layer )
137
+ return (hasattr (layer , 'kernel' ) and
138
+ layer .kernel .shape [:3 ] == (3 , 3 , 3 ) and
139
+ layer .strides == (2 , 2 ) and layer .padding .lower () == 'valid' and
140
+ len (producers ) == 1 and
141
+ isinstance (producers [0 ], layers .ZeroPadding2D ) and
142
+ producers [0 ].padding == ((1 , 1 ), (1 , 1 )))
143
+ return False
133
144
134
145
def _end_layer_stop_fn (self , layer ):
135
146
"""Determines whether the layer ends a subgraph of sparse inference."""
@@ -157,10 +168,36 @@ def _check_layer_support(self, layer):
157
168
# 3x3 stride-2 convolution (no dilation, padding 1 on each side).
158
169
# 5x5 stride-1 convolution (no dilation, padding 2 on each side).
159
170
# 5x5 stride-2 convolution (no dilation, padding 2 on each side).
160
- return (layer .depth_multiplier == 1 and layer .dilation_rate == (1 , 1 ) and
161
- (layer .kernel_size == (3 , 3 ) or layer .kernel_size == (5 , 5 )) and
162
- ((layer .padding .lower () == 'same' and layer .strides == (1 , 1 )) or
163
- (layer .padding .lower () == 'valid' and layer .strides == (2 , 2 ))))
171
+ padding = layer .padding .lower ()
172
+ producers = self ._get_producers (layer )
173
+ zero_padding = (
174
+ producers [0 ] if len (producers ) == 1 and
175
+ isinstance (producers [0 ], layers .ZeroPadding2D ) else None )
176
+
177
+ supported_case_1 = (
178
+ layer .kernel_size == (3 , 3 ) and layer .strides == (1 , 1 ) and
179
+ padding == 'same' )
180
+
181
+ supported_case_2 = (
182
+ layer .kernel_size == (3 , 3 ) and layer .strides == (2 , 2 ) and
183
+ padding == 'valid' and zero_padding and
184
+ zero_padding .padding == ((1 , 1 ), (1 , 1 )))
185
+
186
+ supported_case_3 = (
187
+ layer .kernel_size == (5 , 5 ) and layer .strides == (1 , 1 ) and
188
+ padding == 'same' )
189
+
190
+ supported_case_4 = (
191
+ layer .kernel_size == (5 , 5 ) and layer .strides == (2 , 2 ) and
192
+ padding == 'valid' and zero_padding and
193
+ zero_padding .padding == ((2 , 2 ), (2 , 2 )))
194
+
195
+ supported = (
196
+ layer .depth_multiplier == 1 and layer .dilation_rate == (1 , 1 ) and
197
+ (supported_case_1 or supported_case_2 or supported_case_3 or
198
+ supported_case_4 ))
199
+
200
+ return supported
164
201
elif isinstance (layer , layers .Conv2D ):
165
202
# 1x1 convolution (no stride, no dilation, no padding, no groups).
166
203
return (layer .groups == 1 and layer .dilation_rate == (1 , 1 ) and
@@ -200,8 +237,9 @@ def ensure_model_supports_pruning(self, model):
200
237
)
201
238
if not start_layers :
202
239
raise ValueError (('Could not find `Conv2D 3x3` layer with stride 2x2, '
203
- '`input filters == 3` and `VALID` padding in all input '
204
- 'branches of the model' ))
240
+ '`input filters == 3` and `VALID` padding and '
241
+ 'preceding `ZeroPadding2D` with `padding == 1` in all '
242
+ 'input branches of the model' ))
205
243
206
244
# Search for the end layer (GlobalAveragePooling with `keepdims = True`)
207
245
# for every output branch (backward).
0 commit comments