@@ -289,6 +289,61 @@ def replacement(self, match_layer):
289289 # Should match since bias is initialized with zeros.
290290 self ._assert_model_results_equal (model , transformed_model )
291291
292+ def testReplaceSingleLayerWithSingleLayer_CustomLayerWithNontensorInput (self ):
293+
294+ class CustomDense (keras .layers .Dense ):
295+
296+ def call (self , inputs , identity = False ):
297+ if identity :
298+ return inputs
299+ return super ().call (inputs )
300+
301+ class QuantizedCustomDense (CustomDense ):
302+ pass
303+
304+ class ReplaceCustomDenseLayer (transforms .Transform ):
305+ """Replaces `CustomDense` layers with `QuantizedCustomDense`."""
306+
307+ def pattern (self ):
308+ return LayerPattern ('CustomDense' )
309+
310+ def replacement (self , match_layer ):
311+ match_layer_config = match_layer .layer ['config' ]
312+ my_dense_layer = QuantizedCustomDense (** match_layer_config )
313+
314+ replace_layer = keras .layers .serialize (my_dense_layer )
315+ replace_layer ['name' ] = replace_layer ['config' ]['name' ]
316+
317+ return LayerNode (replace_layer , match_layer .weights , [])
318+
319+ def custom_objects (self ):
320+ return {
321+ 'CustomDense' : CustomDense ,
322+ 'QuantizedCustomDense' : QuantizedCustomDense }
323+
324+ inp = keras .layers .Input ((3 ,))
325+ x1 = CustomDense (2 )(inp , identity = False )
326+ x2 = CustomDense (2 )(x1 , identity = True )
327+ out = keras .layers .ReLU (6.0 )(x2 )
328+ model = keras .Model (inp , [out ])
329+
330+ transformed_model , _ = ModelTransformer (
331+ model , [ReplaceCustomDenseLayer ()]).transform ()
332+
333+ # build_input_shape is a TensorShape object and the two objects are not
334+ # considered the same even though the shapes are the same.
335+ self ._assert_config (model .get_config (), transformed_model .get_config (),
336+ ['class_name' , 'build_input_shape' ])
337+
338+ self .assertEqual (
339+ 'QuantizedCustomDense' ,
340+ self ._get_layer (transformed_model , 0 , 'functional' ).__class__ .__name__ )
341+ self .assertEqual (
342+ 'QuantizedCustomDense' ,
343+ self ._get_layer (transformed_model , 1 , 'functional' ).__class__ .__name__ )
344+
345+ self ._assert_model_results_equal (model , transformed_model )
346+
292347 @parameterized .parameters (['sequential' , 'functional' ])
293348 def testReplaceSingleLayer_WithMultipleLayers (self , model_type ):
294349
0 commit comments