@@ -59,18 +59,20 @@ def normalize(tensor, norm_type, epsilon=1e-6):
5959 norm = tf .maximum (norm , epsilon )
6060 normalized_tensor = tensor / norm
6161 elif norm_type == configs .NormType .L2 :
62- normalized_tensor = tf .nn .l2_normalize (tensor , axis = target_axes )
62+ normalized_tensor = tf .nn .l2_normalize (
63+ tensor , axis = target_axes , epsilon = epsilon ** 2 )
6364 else :
6465 raise NotImplementedError ('Unrecognized or unimplemented "norm_type": %s' %
6566 norm_type )
6667 return normalized_tensor
6768
6869
6970def _expand_to_rank (vector , rank ):
71+ """Expands a batched scalar to a tensor of certain rank."""
7072 return tf .reshape (vector , shape = [- 1 ] + [1 ] * (rank - 1 ))
7173
7274
73- def maximize_within_unit_norm (weights , norm_type ):
75+ def maximize_within_unit_norm (weights , norm_type , epsilon = 1e-6 ):
7476 """Solves the maximization problem weights^T*x with the constraint norm(x)=1.
7577
7678 This op solves a batch of maximization problems at one time. The first axis of
@@ -91,6 +93,7 @@ def maximize_within_unit_norm(weights, norm_type):
9193 size).
9294 norm_type: One of `nsl.configs.NormType`, the type of the norm in the
9395 constraint.
96+ epsilon: A lower bound value for the norm to avoid division by 0.
9497
9598 Returns:
9699 A `Tensor` or a collection of `Tensor` objects (with the same structure and
@@ -122,7 +125,7 @@ def reduce_across_tensors(reduce_fn, input_tensors):
122125 if norm_type == configs .NormType .L2 :
123126 squared_norm = reduce_across_tensors (tf .reduce_sum ,
124127 [tf .square (t ) for t in tensors ])
125- inv_global_norm = tf .math .rsqrt (squared_norm )
128+ inv_global_norm = tf .math .rsqrt (tf . maximum ( squared_norm , epsilon ** 2 ) )
126129 normalized_tensors = [
127130 tensor * _expand_to_rank (inv_global_norm , rank )
128131 for tensor , rank in zip (tensors , tensor_ranks )
@@ -141,8 +144,9 @@ def reduce_across_tensors(reduce_fn, input_tensors):
141144 for t , rank in zip (abs_tensors , tensor_ranks )
142145 ]
143146 num_nonzero = reduce_across_tensors (tf .reduce_sum , is_max_elem )
147+ denominator = tf .maximum (num_nonzero , epsilon )
144148 mask = [
145- is_max * tf .sign (t ) / _expand_to_rank (num_nonzero , rank )
149+ is_max * tf .sign (t ) / _expand_to_rank (denominator , rank )
146150 for t , rank , is_max in zip (tensors , tensor_ranks , is_max_elem )
147151 ]
148152 return tf .nest .pack_sequence_as (weights , mask )
0 commit comments