Skip to content

Commit 96e0025

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Generate names with number suffix SepConv transform
SepConv1D transform generates lambda layers. This change ensures the names of these layers are numbered in suffix to avoid duplicates. PiperOrigin-RevId: 324901335
1 parent b1666cf commit 96e0025

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import numpy as np
2121
import tensorflow as tf
2222

23+
from tensorflow.python.keras import backend
24+
2325
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
2426
from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
2527
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
@@ -258,6 +260,11 @@ class SeparableConv1DQuantize(transforms.Transform):
258260
def pattern(self):
259261
return LayerPattern('SeparableConv1D')
260262

263+
def _get_name(self, prefix):
264+
# TODO(pulkitb): Move away from `backend.unique_object_name` since it isn't
265+
# exposed as externally usable.
266+
return backend.unique_object_name(prefix)
267+
261268
def replacement(self, match_layer):
262269
if _has_custom_quantize_config(match_layer):
263270
return match_layer
@@ -321,16 +328,20 @@ def replacement(self, match_layer):
321328
# Needed to ensure these new layers are considered for quantization.
322329
sepconv2d_metadata = {'quantize_config': None}
323330

331+
# TODO(pulkitb): Consider moving from Lambda to custom ExpandDims/Squeeze.
332+
324333
# Layer before SeparableConv2D which expands input tensors to match 2D.
325334
expand_layer = tf.keras.layers.Lambda(
326-
lambda x: tf.expand_dims(x, spatial_dim), name='sepconv1d_expand')
335+
lambda x: tf.expand_dims(x, spatial_dim),
336+
name=self._get_name('sepconv1d_expand'))
327337
expand_layer_config = keras.layers.serialize(expand_layer)
328338
expand_layer_config['name'] = expand_layer.name
329339
expand_layer_metadata = {
330340
'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()}
331341

332342
squeeze_layer = tf.keras.layers.Lambda(
333-
lambda x: tf.squeeze(x, [spatial_dim]), name='sepconv1d_squeeze')
343+
lambda x: tf.squeeze(x, [spatial_dim]),
344+
name=self._get_name('sepconv1d_squeeze'))
334345
squeeze_layer_config = keras.layers.serialize(squeeze_layer)
335346
squeeze_layer_config['name'] = squeeze_layer.name
336347
squeeze_layer_metadata = {

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,8 @@ def testSeparableConv1DQuantize_(self, kwargs):
258258
).transform()
259259

260260
self.assertContainsSubset(
261-
{'sepconv1d_expand', 'separable_conv1d_QAT_SepConv2D',
262-
'sepconv1d_squeeze'},
261+
{'sepconv1d_expand_1', 'separable_conv1d_QAT_SepConv2D',
262+
'sepconv1d_squeeze_1'},
263263
updated_metadata.keys())
264264
self.assertEqual(sepconv_model.output_shape, transformed_model.output_shape)
265265

0 commit comments

Comments
 (0)