@@ -60,7 +60,14 @@ def _get_layer(self, model, n_excluding_input, model_type):
6060 return model .layers [n_excluding_input ]
6161
6262   def _create_model_inputs(self, model):
63-     return np.random.randn(*self._batch(model.input.get_shape().as_list(), 1))
63+     if isinstance(model.input, dict):
64+       inputs = {}
65+       for key, input_layer in model.input.items():
66+         inputs[key] = np.random.randn(
67+             *self._batch(input_layer.get_shape().as_list(), 1))
68+       return inputs
69+     else:
70+       return np.random.randn(*self._batch(model.input.get_shape().as_list(), 1))
6471
6572 def _simple_dense_model (self , model_type = 'functional' ):
6673 if model_type == 'functional' :
@@ -90,9 +97,7 @@ def _nested_model(self, model_type='functional', submodel_type='functional'):
9097 out = keras .layers .ReLU (6.0 )(x )
9198 return keras .Model (inp , out )
9299 elif model_type == 'sequential' :
93- return keras .Sequential (
94- [submodel ,
95- keras .layers .ReLU (6.0 )])
100+ return keras .Sequential ([submodel , keras .layers .ReLU (6.0 )])
96101
97102 def _assert_config (self , expected_config , actual_config , exclude_keys = None ):
98103 """Asserts that the two config dictionaries are equal.
@@ -216,6 +221,35 @@ def testReplaceSingleLayerWithSingleLayer_MultipleOccurrences(
216221
217222 self ._assert_model_results_equal (model , transformed_model )
218223
224+  def testReplaceSingleLayerWithSingleLayer_DictInputOutput(self):
225+    inp = {
226+        'input1': keras.layers.Input((3,)),
227+        'input2': keras.layers.Input((3,))
228+    }
229+    x1 = keras.layers.Dense(2)(inp['input1'])
230+    x2 = keras.layers.Dense(2)(inp['input2'])
231+    out1 = keras.layers.ReLU(6.0)(x1)
232+    out2 = keras.layers.ReLU(6.0)(x2)
233+    model = keras.Model(inp, {'output1': out1, 'output2': out2})
234+
235+    transformed_model, _ = ModelTransformer(
236+        model, [self.ReplaceDenseLayer()]).transform()
237+
238+    # build_input_shape is a TensorShape object and the two objects are not
239+    # considered the same even though the shapes are the same.
240+    self._assert_config(model.get_config(), transformed_model.get_config(),
241+                        ['class_name', 'build_input_shape'])
242+
243+    # There are two input layers in the input dict.
244+    self.assertEqual(
245+        'MyDense',
246+        self._get_layer(transformed_model, 1, 'functional').__class__.__name__)
247+    self.assertEqual(
248+        'MyDense',
249+        self._get_layer(transformed_model, 2, 'functional').__class__.__name__)
250+
251+    self._assert_model_results_equal(model, transformed_model)
252+
219253 @parameterized .parameters (['sequential' , 'functional' ])
220254 def testReplaceSingleLayerWithSingleLayer_MatchParameters (self , model_type ):
221255
@@ -241,8 +275,8 @@ def replacement(self, match_layer):
241275
242276 model = self ._simple_dense_model (model_type )
243277
244- transformed_model , _ = ModelTransformer (
245- model , [RemoveBiasInDense ()]).transform ()
278+ transformed_model , _ = ModelTransformer (model ,
279+ [RemoveBiasInDense ()]).transform ()
246280
247281 # build_input_shape is a TensorShape object and the two objects are not
248282 # considered the same even though the shapes are the same.
@@ -312,8 +346,7 @@ def replacement(self, match_layer):
312346 layer_config ['name' ] = activation_layer .name
313347
314348 activation_layer_node = LayerNode (
315- layer_config ,
316- input_layers = [match_layer ])
349+ layer_config , input_layers = [match_layer ])
317350
318351 return activation_layer_node
319352
@@ -371,8 +404,8 @@ def replacement(self, match_layer):
371404 keras .layers .ReLU ()])
372405 model .set_weights (model_fused .get_weights ())
373406
374- transformed_model , _ = ModelTransformer (
375- model , [FuseReLUIntoDense ()]).transform ()
407+ transformed_model , _ = ModelTransformer (model ,
408+ [FuseReLUIntoDense ()]).transform ()
376409
377410 self ._assert_config (
378411 model_fused .get_config (),
@@ -430,6 +463,7 @@ def replacement(self, match_layer):
430463 ['build_input_shape' ])
431464
432465 def testReplaceListOfLayers_Sequential (self ):
466+
433467 class ReplaceConvBatchNorm (transforms .Transform ):
434468 """Replaces a ConvBatchNorm pattern with the same set of layers.
435469
@@ -438,8 +472,8 @@ class ReplaceConvBatchNorm(transforms.Transform):
438472 """
439473
440474 def pattern (self ):
441- return LayerPattern ('BatchNormalization' ,
442- inputs = [LayerPattern ('Conv2D' )])
475+ return LayerPattern (
476+ 'BatchNormalization' , inputs = [LayerPattern ('Conv2D' )])
443477
444478 def replacement (self , match_layer ):
445479 # Adds a modification so the transform happens. If the layers are
@@ -457,7 +491,8 @@ def replacement(self, match_layer):
457491 transformed_model , _ = ModelTransformer (
458492 model , [ReplaceConvBatchNorm ()]).transform ()
459493 transformed_model_layer_names = [
460- layer .name for layer in transformed_model .layers ]
494+ layer .name for layer in transformed_model .layers
495+ ]
461496
462497 self .assertEqual (model_layer_names , transformed_model_layer_names )
463498
@@ -495,10 +530,10 @@ def replacement(self, match_layer):
495530
496531 model = self ._simple_dense_model (model_type )
497532 transformed_model , _ = ModelTransformer (
498- model ,
499- [ ReplaceReLUWithSoftmax (), ReplaceSoftmaxWithELU ()],
500- candidate_layers = set ([layer .name for layer in model .layers ])
501- ).transform ()
533+ model , [ ReplaceReLUWithSoftmax (),
534+ ReplaceSoftmaxWithELU ()],
535+ candidate_layers = set ([layer .name for layer in model .layers
536+ ]) ).transform ()
502537
503538 self .assertEqual (transformed_model .layers [- 1 ].__class__ .__name__ , 'ELU' )
504539
@@ -515,8 +550,8 @@ def replacement(self, match_layer):
515550
516551 model = self ._simple_dense_model (model_type )
517552
518- transformed_model , _ = ModelTransformer (
519- model , [ReplaceWithSelf ()]).transform ()
553+ transformed_model , _ = ModelTransformer (model ,
554+ [ReplaceWithSelf ()]).transform ()
520555
521556 # build_input_shape is a TensorShape object and the two objects are not
522557 # considered the same even though the shapes are the same.
@@ -689,8 +724,8 @@ def replacement(self, match_layer):
689724 }
690725 }
691726
692- transformer = ModelTransformer (
693- model , [ ReplaceLayerMetadata ()], None , layer_metadata )
727+ transformer = ModelTransformer (model , [ ReplaceLayerMetadata ()], None ,
728+ layer_metadata )
694729 transformed_model , updated_metadata = transformer .transform ()
695730
696731 self .assertEqual (expected_metadata , updated_metadata )
@@ -704,12 +739,12 @@ def replacement(self, match_layer):
704739 ('sequential' , 'sequential' ),
705740 ('sequential' , 'functional' ),
706741 ('functional' , 'sequential' ),
707- ('functional' , 'functional' ),])
742+ ('functional' , 'functional' ),
743+ ])
708744 def testNestedModelNoChange (self , model_type , submodel_type ):
709745 model = self ._nested_model (model_type , submodel_type )
710746
711- transformed_model , _ = ModelTransformer (
712- model , []).transform ()
747+ transformed_model , _ = ModelTransformer (model , []).transform ()
713748
714749 # build_input_shape is a TensorShape object and the two objects are not
715750 # considered the same even though the shapes are the same.
@@ -721,6 +756,7 @@ def testNestedModelNoChange(self, model_type, submodel_type):
721756 # Validation Tests
722757
723758 def testRaisesErrorForSubclassModels (self ):
759+
724760 class MyModel (keras .Model ):
725761 pass
726762