@@ -60,7 +60,14 @@ def _get_layer(self, model, n_excluding_input, model_type):
60
60
return model .layers [n_excluding_input ]
61
61
62
62
def _create_model_inputs (self , model ):
63
- return np .random .randn (* self ._batch (model .input .get_shape ().as_list (), 1 ))
63
+ if isinstance (model .input , dict ):
64
+ inputs = {}
65
+ for key , input_layer in model .input .items ():
66
+ inputs [key ] = np .random .randn (
67
+ * self ._batch (input_layer .get_shape ().as_list (), 1 ))
68
+ return inputs
69
+ else :
70
+ return np .random .randn (* self ._batch (model .input .get_shape ().as_list (), 1 ))
64
71
65
72
def _simple_dense_model (self , model_type = 'functional' ):
66
73
if model_type == 'functional' :
@@ -90,9 +97,7 @@ def _nested_model(self, model_type='functional', submodel_type='functional'):
90
97
out = keras .layers .ReLU (6.0 )(x )
91
98
return keras .Model (inp , out )
92
99
elif model_type == 'sequential' :
93
- return keras .Sequential (
94
- [submodel ,
95
- keras .layers .ReLU (6.0 )])
100
+ return keras .Sequential ([submodel , keras .layers .ReLU (6.0 )])
96
101
97
102
def _assert_config (self , expected_config , actual_config , exclude_keys = None ):
98
103
"""Asserts that the two config dictionaries are equal.
@@ -216,6 +221,35 @@ def testReplaceSingleLayerWithSingleLayer_MultipleOccurrences(
216
221
217
222
self ._assert_model_results_equal (model , transformed_model )
218
223
224
+ def testReplaceSingleLayerWithSingleLayer_DictInputOutput (self ):
225
+ inp = {
226
+ 'input1' : keras .layers .Input ((3 ,)),
227
+ 'input2' : keras .layers .Input ((3 ,))
228
+ }
229
+ x1 = keras .layers .Dense (2 )(inp ['input1' ])
230
+ x2 = keras .layers .Dense (2 )(inp ['input2' ])
231
+ out1 = keras .layers .ReLU (6.0 )(x1 )
232
+ out2 = keras .layers .ReLU (6.0 )(x2 )
233
+ model = keras .Model (inp , {'output1' : out1 , 'output2' : out2 })
234
+
235
+ transformed_model , _ = ModelTransformer (
236
+ model , [self .ReplaceDenseLayer ()]).transform ()
237
+
238
+ # build_input_shape is a TensorShape object and the two objects are not
239
+ # considered the same even though the shapes are the same.
240
+ self ._assert_config (model .get_config (), transformed_model .get_config (),
241
+ ['class_name' , 'build_input_shape' ])
242
+
243
+ # There are two input layers in the input dict.
244
+ self .assertEqual (
245
+ 'MyDense' ,
246
+ self ._get_layer (transformed_model , 1 , 'functional' ).__class__ .__name__ )
247
+ self .assertEqual (
248
+ 'MyDense' ,
249
+ self ._get_layer (transformed_model , 2 , 'functional' ).__class__ .__name__ )
250
+
251
+ self ._assert_model_results_equal (model , transformed_model )
252
+
219
253
@parameterized .parameters (['sequential' , 'functional' ])
220
254
def testReplaceSingleLayerWithSingleLayer_MatchParameters (self , model_type ):
221
255
@@ -241,8 +275,8 @@ def replacement(self, match_layer):
241
275
242
276
model = self ._simple_dense_model (model_type )
243
277
244
- transformed_model , _ = ModelTransformer (
245
- model , [RemoveBiasInDense ()]).transform ()
278
+ transformed_model , _ = ModelTransformer (model ,
279
+ [RemoveBiasInDense ()]).transform ()
246
280
247
281
# build_input_shape is a TensorShape object and the two objects are not
248
282
# considered the same even though the shapes are the same.
@@ -312,8 +346,7 @@ def replacement(self, match_layer):
312
346
layer_config ['name' ] = activation_layer .name
313
347
314
348
activation_layer_node = LayerNode (
315
- layer_config ,
316
- input_layers = [match_layer ])
349
+ layer_config , input_layers = [match_layer ])
317
350
318
351
return activation_layer_node
319
352
@@ -371,8 +404,8 @@ def replacement(self, match_layer):
371
404
keras .layers .ReLU ()])
372
405
model .set_weights (model_fused .get_weights ())
373
406
374
- transformed_model , _ = ModelTransformer (
375
- model , [FuseReLUIntoDense ()]).transform ()
407
+ transformed_model , _ = ModelTransformer (model ,
408
+ [FuseReLUIntoDense ()]).transform ()
376
409
377
410
self ._assert_config (
378
411
model_fused .get_config (),
@@ -430,6 +463,7 @@ def replacement(self, match_layer):
430
463
['build_input_shape' ])
431
464
432
465
def testReplaceListOfLayers_Sequential (self ):
466
+
433
467
class ReplaceConvBatchNorm (transforms .Transform ):
434
468
"""Replaces a ConvBatchNorm pattern with the same set of layers.
435
469
@@ -438,8 +472,8 @@ class ReplaceConvBatchNorm(transforms.Transform):
438
472
"""
439
473
440
474
def pattern (self ):
441
- return LayerPattern ('BatchNormalization' ,
442
- inputs = [LayerPattern ('Conv2D' )])
475
+ return LayerPattern (
476
+ 'BatchNormalization' , inputs = [LayerPattern ('Conv2D' )])
443
477
444
478
def replacement (self , match_layer ):
445
479
# Adds a modification so the transform happens. If the layers are
@@ -457,7 +491,8 @@ def replacement(self, match_layer):
457
491
transformed_model , _ = ModelTransformer (
458
492
model , [ReplaceConvBatchNorm ()]).transform ()
459
493
transformed_model_layer_names = [
460
- layer .name for layer in transformed_model .layers ]
494
+ layer .name for layer in transformed_model .layers
495
+ ]
461
496
462
497
self .assertEqual (model_layer_names , transformed_model_layer_names )
463
498
@@ -495,10 +530,10 @@ def replacement(self, match_layer):
495
530
496
531
model = self ._simple_dense_model (model_type )
497
532
transformed_model , _ = ModelTransformer (
498
- model ,
499
- [ ReplaceReLUWithSoftmax (), ReplaceSoftmaxWithELU ()],
500
- candidate_layers = set ([layer .name for layer in model .layers ])
501
- ).transform ()
533
+ model , [ ReplaceReLUWithSoftmax (),
534
+ ReplaceSoftmaxWithELU ()],
535
+ candidate_layers = set ([layer .name for layer in model .layers
536
+ ]) ).transform ()
502
537
503
538
self .assertEqual (transformed_model .layers [- 1 ].__class__ .__name__ , 'ELU' )
504
539
@@ -515,8 +550,8 @@ def replacement(self, match_layer):
515
550
516
551
model = self ._simple_dense_model (model_type )
517
552
518
- transformed_model , _ = ModelTransformer (
519
- model , [ReplaceWithSelf ()]).transform ()
553
+ transformed_model , _ = ModelTransformer (model ,
554
+ [ReplaceWithSelf ()]).transform ()
520
555
521
556
# build_input_shape is a TensorShape object and the two objects are not
522
557
# considered the same even though the shapes are the same.
@@ -689,8 +724,8 @@ def replacement(self, match_layer):
689
724
}
690
725
}
691
726
692
- transformer = ModelTransformer (
693
- model , [ ReplaceLayerMetadata ()], None , layer_metadata )
727
+ transformer = ModelTransformer (model , [ ReplaceLayerMetadata ()], None ,
728
+ layer_metadata )
694
729
transformed_model , updated_metadata = transformer .transform ()
695
730
696
731
self .assertEqual (expected_metadata , updated_metadata )
@@ -704,12 +739,12 @@ def replacement(self, match_layer):
704
739
('sequential' , 'sequential' ),
705
740
('sequential' , 'functional' ),
706
741
('functional' , 'sequential' ),
707
- ('functional' , 'functional' ),])
742
+ ('functional' , 'functional' ),
743
+ ])
708
744
def testNestedModelNoChange (self , model_type , submodel_type ):
709
745
model = self ._nested_model (model_type , submodel_type )
710
746
711
- transformed_model , _ = ModelTransformer (
712
- model , []).transform ()
747
+ transformed_model , _ = ModelTransformer (model , []).transform ()
713
748
714
749
# build_input_shape is a TensorShape object and the two objects are not
715
750
# considered the same even though the shapes are the same.
@@ -721,6 +756,7 @@ def testNestedModelNoChange(self, model_type, submodel_type):
721
756
# Validation Tests
722
757
723
758
def testRaisesErrorForSubclassModels (self ):
759
+
724
760
class MyModel (keras .Model ):
725
761
pass
726
762
0 commit comments