Skip to content

Commit b1666cf

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
For LayerPattern, match config items with regex
PiperOrigin-RevId: 324901092
1 parent e3404c6 commit b1666cf

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,19 @@ def _match_pattern(self, target, pattern):
104104
def _match_layer(self, layer, pattern):
105105
"""Check if specific layer matches the pattern."""
106106

107-
if self.candidate_layers and layer['config'][
108-
'name'] not in self.candidate_layers:
107+
if self.candidate_layers and \
108+
layer['config']['name'] not in self.candidate_layers:
109109
return False
110110

111111
if not self._match_pattern(layer['class_name'], pattern.class_name):
112112
return False
113113

114114
layer_config = layer['config']
115115
for key, value in pattern.config.items():
116-
# This comparison should probably use the serialized value.
117-
# Consider adding regex support to key/values as well. This will allow
118-
# negative matches as well.
119-
if layer_config.get(key) != value:
116+
# Either the provided value should equal the config value, or
117+
# be a regex match to str(value).
118+
if not (self._match_pattern(str(layer_config.get(key)), str(value)) or \
119+
layer_config.get(key) == value):
120120
return False
121121

122122
return True

tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,16 @@ def matched(self):
527527
def reset(self):
528528
self._matched = False
529529

530+
@parameterized.parameters(['sequential', 'functional'])
531+
def testPatternMatches_ConfigParamsRegex(self, model_type):
532+
pattern = LayerPattern('Dense', config={'name': 'dense.*'})
533+
transform = self.VerifyMatch(pattern)
534+
535+
model = self._simple_dense_model(model_type)
536+
537+
ModelTransformer(model, [transform]).transform()
538+
self.assertTrue(transform.matched())
539+
530540
@parameterized.parameters(['sequential', 'functional'])
531541
def testPatternShouldOnlyMatch_CandidateLayers(self, model_type):
532542
pattern = LayerPattern('ReLU', inputs=[LayerPattern('Dense')])
@@ -680,5 +690,6 @@ class MyModel(keras.Model):
680690
with self.assertRaises(ValueError):
681691
ModelTransformer(MyModel(), [self.ReplaceDenseLayer()]).transform()
682692

693+
683694
if __name__ == '__main__':
684695
tf.test.main()

0 commit comments

Comments
 (0)