@@ -103,7 +103,12 @@ def _get_producers(self, layer):
103103 return producers
104104
105105 def _get_consumers (self , layer ):
106- return [node .outbound_layer for node in layer ._outbound_nodes ]
106+
107+ def unpack (layer ):
108+ return (unpack (layer .layers [0 ])
109+ if isinstance (layer , tf .keras .Sequential ) else layer )
110+
111+ return [unpack (node .outbound_layer ) for node in layer ._outbound_nodes ]
107112
108113 def _lookup_layers (self , source_layers , stop_fn , next_fn ):
109114 """Traverses the model and returns layers satisfying `stop_fn` criteria."""
@@ -127,9 +132,15 @@ def _lookup_layers(self, source_layers, stop_fn, next_fn):
127132
128133 def _start_layer_stop_fn (self , layer ):
129134 """Determines whether the layer starts a subgraph of sparse inference."""
130- return (isinstance (layer , layers .Conv2D ) and hasattr (layer , 'kernel' ) and
131- layer .kernel .shape [:3 ] == (3 , 3 , 3 ) and layer .strides == (2 , 2 ) and
132- layer .padding .lower () == 'valid' )
135+ if isinstance (layer , layers .Conv2D ):
136+ producers = self ._get_producers (layer )
137+ return (hasattr (layer , 'kernel' ) and
138+ layer .kernel .shape [:3 ] == (3 , 3 , 3 ) and
139+ layer .strides == (2 , 2 ) and layer .padding .lower () == 'valid' and
140+ len (producers ) == 1 and
141+ isinstance (producers [0 ], layers .ZeroPadding2D ) and
142+ producers [0 ].padding == ((1 , 1 ), (1 , 1 )))
143+ return False
133144
134145 def _end_layer_stop_fn (self , layer ):
135146 """Determines whether the layer ends a subgraph of sparse inference."""
@@ -157,10 +168,36 @@ def _check_layer_support(self, layer):
157168 # 3x3 stride-2 convolution (no dilation, padding 1 on each side).
158169 # 5x5 stride-1 convolution (no dilation, padding 2 on each side).
159170 # 5x5 stride-2 convolution (no dilation, padding 2 on each side).
160- return (layer .depth_multiplier == 1 and layer .dilation_rate == (1 , 1 ) and
161- (layer .kernel_size == (3 , 3 ) or layer .kernel_size == (5 , 5 )) and
162- ((layer .padding .lower () == 'same' and layer .strides == (1 , 1 )) or
163- (layer .padding .lower () == 'valid' and layer .strides == (2 , 2 ))))
171+ padding = layer .padding .lower ()
172+ producers = self ._get_producers (layer )
173+ zero_padding = (
174+ producers [0 ] if len (producers ) == 1 and
175+ isinstance (producers [0 ], layers .ZeroPadding2D ) else None )
176+
177+ supported_case_1 = (
178+ layer .kernel_size == (3 , 3 ) and layer .strides == (1 , 1 ) and
179+ padding == 'same' )
180+
181+ supported_case_2 = (
182+ layer .kernel_size == (3 , 3 ) and layer .strides == (2 , 2 ) and
183+ padding == 'valid' and zero_padding and
184+ zero_padding .padding == ((1 , 1 ), (1 , 1 )))
185+
186+ supported_case_3 = (
187+ layer .kernel_size == (5 , 5 ) and layer .strides == (1 , 1 ) and
188+ padding == 'same' )
189+
190+ supported_case_4 = (
191+ layer .kernel_size == (5 , 5 ) and layer .strides == (2 , 2 ) and
192+ padding == 'valid' and zero_padding and
193+ zero_padding .padding == ((2 , 2 ), (2 , 2 )))
194+
195+ supported = (
196+ layer .depth_multiplier == 1 and layer .dilation_rate == (1 , 1 ) and
197+ (supported_case_1 or supported_case_2 or supported_case_3 or
198+ supported_case_4 ))
199+
200+ return supported
164201 elif isinstance (layer , layers .Conv2D ):
165202 # 1x1 convolution (no stride, no dilation, no padding, no groups).
166203 return (layer .groups == 1 and layer .dilation_rate == (1 , 1 ) and
@@ -200,8 +237,9 @@ def ensure_model_supports_pruning(self, model):
200237 )
201238 if not start_layers :
202239 raise ValueError (('Could not find `Conv2D 3x3` layer with stride 2x2, '
203- '`input filters == 3` and `VALID` padding in all input '
204- 'branches of the model' ))
240+ '`input filters == 3` and `VALID` padding and '
241+ 'preceding `ZeroPadding2D` with `padding == 1` in all '
242+ 'input branches of the model' ))
205243
206244 # Search for the end layer (GlobalAveragePooling with `keepdims = True`)
207245 # for every output branch (backward).
0 commit comments