@@ -215,8 +215,8 @@ def decode(self,
215
215
def _validate_and_expand_encode_input (self , x ):
216
216
"""Validates the input to encode and modifies it if necessary."""
217
217
if x .shape .ndims not in [1 , 2 ]:
218
- raise ValueError (
219
- 'Number of dimensions must be 1 or 2. Shape of x: %s' % x .shape )
218
+ raise ValueError ('Number of dimensions must be 1 or 2. Shape of x: %s' %
219
+ x .shape )
220
220
if x .shape .ndims == 1 :
221
221
# The input to the fast_walsh_hadamard_transform must have 2 dimensions.
222
222
x = tf .expand_dims (x , 0 )
@@ -262,7 +262,7 @@ class UniformQuantizationEncodingStage(encoding_stage.EncodingStageInterface):
262
262
# otherwise be numerically unstable for float32 values.
263
263
_ALLOWED_BITS_ARG = list (range (1 , 17 ))
264
264
265
- def __init__ (self , bits = 8 , min_max = None , stochastic = True ):
265
+ def __init__ (self , bits = 8 , min_max = None , stochastic = True , ** kwargs ):
266
266
"""Initializer for the UniformQuantizationEncodingStage.
267
267
268
268
Args:
@@ -275,6 +275,7 @@ def __init__(self, bits=8, min_max=None, stochastic=True):
275
275
stochastic: A Python bool, whether to use stochastic or deterministic
276
276
rounding. If `True`, the encoding is randomized and on expectation
277
277
unbiased. If `False`, the encoding is deterministic.
278
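+ **kwargs: Optional keyword arguments. `reduce_memory_use_by_forcing_random_op_after_clipping` (default `False`): if true, `encode` creates the random op used for stochastic rounding under a control dependency on the scaled input.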
278
279
279
280
Raises:
280
281
ValueError: The inputs do not satisfy the above constraints.
@@ -300,6 +301,8 @@ def __init__(self, bits=8, min_max=None, stochastic=True):
300
301
if not isinstance (stochastic , bool ):
301
302
raise TypeError ('The stochastic argument must be a bool.' )
302
303
self ._stochastic = stochastic
304
+ self ._force_random_op_after_clipping = kwargs .get (
305
+ 'reduce_memory_use_by_forcing_random_op_after_clipping' , False )
303
306
304
307
@property
305
308
def name (self ):
@@ -350,8 +353,18 @@ def encode(self, x, encode_params):
350
353
x = tf .compat .v1 .div_no_nan (x - min_x , max_x - min_x ) * max_value
351
354
if self ._stochastic : # Randomized rounding.
352
355
floored_x = tf .floor (x )
353
- bernoulli = tf .random .uniform (tf .shape (x ), dtype = x .dtype )
354
- bernoulli = bernoulli < (x - floored_x )
356
+ residuals_x = x - floored_x
357
+ # Add graph dependencies to tensor `x` to ensure that the randomized
358
+ # rounding variables are not created before `x` is scaled above. This
359
+ # prevents TF from preallocating the tensor before it will actually be
360
+ # used, reducing memory pressure (especially important for mobile
361
+ # deployments).
362
+ if self ._force_random_op_after_clipping :
363
+ with tf .control_dependencies ([x ]):
364
+ bernoulli = tf .random .uniform (tf .shape (x ), dtype = x .dtype )
365
+ else :
366
+ bernoulli = tf .random .uniform (tf .shape (x ), dtype = x .dtype )
367
+ bernoulli = bernoulli < residuals_x
355
368
quantized_x = floored_x + tf .cast (bernoulli , x .dtype )
356
369
else : # Deterministic rounding.
357
370
quantized_x = tf .round (x )
0 commit comments