@@ -512,7 +512,8 @@ def quantiles_initializer(shape, dtype=None, partition_info=None):
512
512
513
513
quantiles = self .add_weight (
514
514
"quantiles" , shape = (channels , 1 , 3 ), dtype = self .dtype ,
515
- initializer = quantiles_initializer )
515
+ initializer = quantiles_initializer ,
516
+ aggregation = tf .VariableAggregation .ONLY_FIRST_REPLICA )
516
517
logits = self ._logits_cumulative (quantiles , stop_gradient = True )
517
518
loss = tf .math .reduce_sum (abs (logits - target ))
518
519
self .add_loss (loss , inputs = None )
@@ -577,10 +578,12 @@ def cdf_initializer(shape, dtype=None, partition_info=None):
577
578
shape = (channels , None ),
578
579
dtype = tf .int32 ,
579
580
trainable = False ,
580
- initializer = cdf_initializer )
581
+ initializer = cdf_initializer ,
582
+ aggregation = tf .VariableAggregation .ONLY_FIRST_REPLICA )
581
583
cdf_length = self .add_weight (
582
584
"cdf_length" , shape = (channels ,), dtype = tf .int32 , trainable = False ,
583
- initializer = tf .initializers .constant (3 ))
585
+ initializer = tf .initializers .constant (3 ),
586
+ aggregation = tf .VariableAggregation .ONLY_FIRST_REPLICA )
584
587
# Works around a weird TF issue with reading variables inside a loop.
585
588
self ._quantized_cdf = tf .identity (quantized_cdf )
586
589
self ._cdf_length = tf .identity (cdf_length )
@@ -855,11 +858,13 @@ def cdf_initializer(shape, dtype=None, partition_info=None):
855
858
856
859
quantized_cdf = self .add_weight (
857
860
"quantized_cdf" , shape = (len (pmf_length ), max_length + 2 ),
858
- initializer = cdf_initializer , dtype = tf .int32 , trainable = False )
861
+ initializer = cdf_initializer , dtype = tf .int32 , trainable = False ,
862
+ aggregation = tf .VariableAggregation .ONLY_FIRST_REPLICA )
859
863
cdf_length = self .add_weight (
860
864
"cdf_length" , shape = (len (pmf_length ),),
861
865
initializer = tf .initializers .constant (pmf_length + 2 ),
862
- dtype = tf .int32 , trainable = False )
866
+ dtype = tf .int32 , trainable = False ,
867
+ aggregation = tf .VariableAggregation .ONLY_FIRST_REPLICA )
863
868
# Works around a weird TF issue with reading variables inside a loop.
864
869
self ._quantized_cdf = tf .identity (quantized_cdf )
865
870
self ._cdf_length = tf .identity (cdf_length )
0 commit comments