Skip to content

Commit 3b4998c

Browse files
Johannes Ballécopybara-github
authored andcommitted
Adds an aggregation strategy to CDF variable.
This enables training entropy models in a distributed scenario. PiperOrigin-RevId: 305893627 Change-Id: I8293e83558f00a29765df3583b5d3d0500d0b8bb
1 parent 49c8b0a commit 3b4998c

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

tensorflow_compression/python/layers/entropy_models.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,8 @@ def quantiles_initializer(shape, dtype=None, partition_info=None):
512512

513513
quantiles = self.add_weight(
514514
"quantiles", shape=(channels, 1, 3), dtype=self.dtype,
515-
initializer=quantiles_initializer)
515+
initializer=quantiles_initializer,
516+
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
516517
logits = self._logits_cumulative(quantiles, stop_gradient=True)
517518
loss = tf.math.reduce_sum(abs(logits - target))
518519
self.add_loss(loss, inputs=None)
@@ -577,10 +578,12 @@ def cdf_initializer(shape, dtype=None, partition_info=None):
577578
shape=(channels, None),
578579
dtype=tf.int32,
579580
trainable=False,
580-
initializer=cdf_initializer)
581+
initializer=cdf_initializer,
582+
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
581583
cdf_length = self.add_weight(
582584
"cdf_length", shape=(channels,), dtype=tf.int32, trainable=False,
583-
initializer=tf.initializers.constant(3))
585+
initializer=tf.initializers.constant(3),
586+
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
584587
# Works around a weird TF issue with reading variables inside a loop.
585588
self._quantized_cdf = tf.identity(quantized_cdf)
586589
self._cdf_length = tf.identity(cdf_length)
@@ -855,11 +858,13 @@ def cdf_initializer(shape, dtype=None, partition_info=None):
855858

856859
quantized_cdf = self.add_weight(
857860
"quantized_cdf", shape=(len(pmf_length), max_length + 2),
858-
initializer=cdf_initializer, dtype=tf.int32, trainable=False)
861+
initializer=cdf_initializer, dtype=tf.int32, trainable=False,
862+
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
859863
cdf_length = self.add_weight(
860864
"cdf_length", shape=(len(pmf_length),),
861865
initializer=tf.initializers.constant(pmf_length + 2),
862-
dtype=tf.int32, trainable=False)
866+
dtype=tf.int32, trainable=False,
867+
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
863868
# Works around a weird TF issue with reading variables inside a loop.
864869
self._quantized_cdf = tf.identity(quantized_cdf)
865870
self._cdf_length = tf.identity(cdf_length)

0 commit comments

Comments
 (0)