Skip to content

Commit 98e2f5e

Browse files
Add an option to apply TensorFlow control dependencies to the randomized variables generated during stochastic rounding.
PiperOrigin-RevId: 458009468
1 parent c8fc87f commit 98e2f5e

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/encoders/common_encoders.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def identity():
7979
stages_impl.IdentityEncodingStage()).make()
8080

8181

82-
def uniform_quantization(bits):
82+
def uniform_quantization(bits, **kwargs):
8383
"""Returns uniform quanitzation `Encoder`.
8484
8585
The `Encoder` first reshapes the input to a rank-1 `Tensor`, then applies
@@ -94,16 +94,17 @@ def uniform_quantization(bits):
9494
9595
Args:
9696
bits: Number of bits to quantize into.
97+
**kwargs: Keyword arguments.
9798
9899
Returns:
99100
The quantization `Encoder`.
100101
"""
101102
return core_encoder.EncoderComposer(
102103
stages_impl.BitpackingEncodingStage(bits)).add_parent(
103-
stages_impl.UniformQuantizationEncodingStage(bits), stages_impl
104-
.UniformQuantizationEncodingStage.ENCODED_VALUES_KEY).add_parent(
105-
stages_impl.FlattenEncodingStage(),
106-
stages_impl.FlattenEncodingStage.ENCODED_VALUES_KEY).make()
104+
stages_impl.UniformQuantizationEncodingStage(bits, **kwargs),
105+
stages_impl.UniformQuantizationEncodingStage.ENCODED_VALUES_KEY
106+
).add_parent(stages_impl.FlattenEncodingStage(),
107+
stages_impl.FlattenEncodingStage.ENCODED_VALUES_KEY).make()
107108

108109

109110
def hadamard_quantization(bits):

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/stages_impl.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def decode(self,
215215
def _validate_and_expand_encode_input(self, x):
216216
"""Validates the input to encode and modifies it if necessary."""
217217
if x.shape.ndims not in [1, 2]:
218-
raise ValueError(
219-
'Number of dimensions must be 1 or 2. Shape of x: %s' % x.shape)
218+
raise ValueError('Number of dimensions must be 1 or 2. Shape of x: %s' %
219+
x.shape)
220220
if x.shape.ndims == 1:
221221
# The input to the fast_walsh_hadamard_transform must have 2 dimensions.
222222
x = tf.expand_dims(x, 0)
@@ -262,7 +262,7 @@ class UniformQuantizationEncodingStage(encoding_stage.EncodingStageInterface):
262262
# otherwise be numerically unstable for float32 values.
263263
_ALLOWED_BITS_ARG = list(range(1, 17))
264264

265-
def __init__(self, bits=8, min_max=None, stochastic=True):
265+
def __init__(self, bits=8, min_max=None, stochastic=True, **kwargs):
266266
"""Initializer for the UniformQuantizationEncodingStage.
267267
268268
Args:
@@ -275,6 +275,7 @@ def __init__(self, bits=8, min_max=None, stochastic=True):
275275
stochastic: A Python bool, whether to use stochastic or deterministic
276276
rounding. If `True`, the encoding is randomized and on expectation
277277
unbiased. If `False`, the encoding is deterministic.
278+
**kwargs: Keyword arguments.
278279
279280
Raises:
280281
ValueError: The inputs do not satisfy the above constraints.
@@ -300,6 +301,8 @@ def __init__(self, bits=8, min_max=None, stochastic=True):
300301
if not isinstance(stochastic, bool):
301302
raise TypeError('The stochastic argument must be a bool.')
302303
self._stochastic = stochastic
304+
self._force_random_op_after_clipping = kwargs.get(
305+
'reduce_memory_use_by_forcing_random_op_after_clipping', False)
303306

304307
@property
305308
def name(self):
@@ -350,8 +353,18 @@ def encode(self, x, encode_params):
350353
x = tf.compat.v1.div_no_nan(x - min_x, max_x - min_x) * max_value
351354
if self._stochastic: # Randomized rounding.
352355
floored_x = tf.floor(x)
353-
bernoulli = tf.random.uniform(tf.shape(x), dtype=x.dtype)
354-
bernoulli = bernoulli < (x - floored_x)
356+
residuals_x = x - floored_x
357+
# Add graph dependencies to tensor `x` to ensure that the randomized
358+
# rounding variables are not created before `x` is scaled above. This
359+
# prevents TF from preallocating the tensor before it will actually be
360+
# used, reducing memory pressure (especially important for mobile
361+
# deployments).
362+
if self._force_random_op_after_clipping:
363+
with tf.control_dependencies([x]):
364+
bernoulli = tf.random.uniform(tf.shape(x), dtype=x.dtype)
365+
else:
366+
bernoulli = tf.random.uniform(tf.shape(x), dtype=x.dtype)
367+
bernoulli = bernoulli < residuals_x
355368
quantized_x = floored_x + tf.cast(bernoulli, x.dtype)
356369
else: # Deterministic rounding.
357370
quantized_x = tf.round(x)

0 commit comments

Comments
 (0)