@@ -289,6 +289,61 @@ def replacement(self, match_layer):
289
289
# Should match since bias is initialized with zeros.
290
290
self ._assert_model_results_equal (model , transformed_model )
291
291
292
+ def testReplaceSingleLayerWithSingleLayer_CustomLayerWithNontensorInput (self ):
293
+
294
+ class CustomDense (keras .layers .Dense ):
295
+
296
+ def call (self , inputs , identity = False ):
297
+ if identity :
298
+ return inputs
299
+ return super ().call (inputs )
300
+
301
+ class QuantizedCustomDense (CustomDense ):
302
+ pass
303
+
304
+ class ReplaceCustomDenseLayer (transforms .Transform ):
305
+ """Replaces `CustomDense` layers with `QuantizedCustomDense`."""
306
+
307
+ def pattern (self ):
308
+ return LayerPattern ('CustomDense' )
309
+
310
+ def replacement (self , match_layer ):
311
+ match_layer_config = match_layer .layer ['config' ]
312
+ my_dense_layer = QuantizedCustomDense (** match_layer_config )
313
+
314
+ replace_layer = keras .layers .serialize (my_dense_layer )
315
+ replace_layer ['name' ] = replace_layer ['config' ]['name' ]
316
+
317
+ return LayerNode (replace_layer , match_layer .weights , [])
318
+
319
+ def custom_objects (self ):
320
+ return {
321
+ 'CustomDense' : CustomDense ,
322
+ 'QuantizedCustomDense' : QuantizedCustomDense }
323
+
324
+ inp = keras .layers .Input ((3 ,))
325
+ x1 = CustomDense (2 )(inp , identity = False )
326
+ x2 = CustomDense (2 )(x1 , identity = True )
327
+ out = keras .layers .ReLU (6.0 )(x2 )
328
+ model = keras .Model (inp , [out ])
329
+
330
+ transformed_model , _ = ModelTransformer (
331
+ model , [ReplaceCustomDenseLayer ()]).transform ()
332
+
333
+ # build_input_shape is a TensorShape object and the two objects are not
334
+ # considered the same even though the shapes are the same.
335
+ self ._assert_config (model .get_config (), transformed_model .get_config (),
336
+ ['class_name' , 'build_input_shape' ])
337
+
338
+ self .assertEqual (
339
+ 'QuantizedCustomDense' ,
340
+ self ._get_layer (transformed_model , 0 , 'functional' ).__class__ .__name__ )
341
+ self .assertEqual (
342
+ 'QuantizedCustomDense' ,
343
+ self ._get_layer (transformed_model , 1 , 'functional' ).__class__ .__name__ )
344
+
345
+ self ._assert_model_results_equal (model , transformed_model )
346
+
292
347
@parameterized .parameters (['sequential' , 'functional' ])
293
348
def testReplaceSingleLayer_WithMultipleLayers (self , model_type ):
294
349
0 commit comments