
Commit c1fc418

csferng authored and tensorflow-copybara committed

Add epsilon to maximize_within_unit_norm to avoid division by 0.

PiperOrigin-RevId: 294535768

1 parent 6329455 commit c1fc418
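For context on the failure mode this commit guards against (a minimal sketch, not part of the commit, assuming TF 2.x eager execution): when an example's weights are all zero, the L2 branch computed tf.math.rsqrt(0.0) = inf, and inf * 0.0 propagates NaN into the result.

import tensorflow as tf

weights = tf.constant([0.0, 0.0])
squared_norm = tf.reduce_sum(tf.square(weights))
inv_norm = tf.math.rsqrt(squared_norm)          # rsqrt(0.0) -> inf
print((weights * inv_norm).numpy())             # [nan nan]: the old behavior

# The fix: bound the squared norm below by epsilon**2 before taking rsqrt.
epsilon = 1e-6
inv_norm = tf.math.rsqrt(tf.maximum(squared_norm, epsilon**2))
print((weights * inv_norm).numpy())             # [0. 0.]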

File tree

3 files changed: +17 −5 lines changed

neural_structured_learning/lib/BUILD

Lines changed: 1 addition & 0 deletions

@@ -137,6 +137,7 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":utils",
+        # package absl/testing:parameterized
         "//neural_structured_learning/configs",
         # package numpy
         # package tensorflow

neural_structured_learning/lib/utils.py

Lines changed: 8 additions & 4 deletions

@@ -59,18 +59,20 @@ def normalize(tensor, norm_type, epsilon=1e-6):
     norm = tf.maximum(norm, epsilon)
     normalized_tensor = tensor / norm
   elif norm_type == configs.NormType.L2:
-    normalized_tensor = tf.nn.l2_normalize(tensor, axis=target_axes)
+    normalized_tensor = tf.nn.l2_normalize(
+        tensor, axis=target_axes, epsilon=epsilon**2)
   else:
     raise NotImplementedError('Unrecognized or unimplemented "norm_type": %s' %
                               norm_type)
   return normalized_tensor


 def _expand_to_rank(vector, rank):
+  """Expands a batched scalar to a tensor of certain rank."""
   return tf.reshape(vector, shape=[-1] + [1] * (rank - 1))


-def maximize_within_unit_norm(weights, norm_type):
+def maximize_within_unit_norm(weights, norm_type, epsilon=1e-6):
   """Solves the maximization problem weights^T*x with the constraint norm(x)=1.

   This op solves a batch of maximization problems at one time. The first axis of
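A note on why the new argument is passed as epsilon**2 (my reading of the TF API, not text from the commit): tf.nn.l2_normalize computes x / sqrt(max(sum(x**2), epsilon)), i.e. its epsilon bounds the squared sum, so passing epsilon**2 bounds the divisor itself at epsilon, mirroring the tf.maximum(norm, epsilon) guard used for the other norm types above. A minimal check:

import tensorflow as tf

epsilon = 1e-6
zeros = tf.constant([0.0, 0.0])
# Divisor is sqrt(max(0.0, epsilon**2)) = epsilon, so an all-zero input
# normalizes to zeros instead of producing NaN.
print(tf.nn.l2_normalize(zeros, axis=-1, epsilon=epsilon**2).numpy())  # [0. 0.]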
@@ -91,6 +93,7 @@ def maximize_within_unit_norm(weights, norm_type):
       size).
     norm_type: One of `nsl.configs.NormType`, the type of the norm in the
       constraint.
+    epsilon: A lower bound value for the norm to avoid division by 0.

   Returns:
     A `Tensor` or a collection of `Tensor` objects (with the same structure and
@@ -122,7 +125,7 @@ def reduce_across_tensors(reduce_fn, input_tensors):
   if norm_type == configs.NormType.L2:
     squared_norm = reduce_across_tensors(tf.reduce_sum,
                                          [tf.square(t) for t in tensors])
-    inv_global_norm = tf.math.rsqrt(squared_norm)
+    inv_global_norm = tf.math.rsqrt(tf.maximum(squared_norm, epsilon**2))
     normalized_tensors = [
         tensor * _expand_to_rank(inv_global_norm, rank)
         for tensor, rank in zip(tensors, tensor_ranks)
@@ -141,8 +144,9 @@ def reduce_across_tensors(reduce_fn, input_tensors):
         for t, rank in zip(abs_tensors, tensor_ranks)
     ]
     num_nonzero = reduce_across_tensors(tf.reduce_sum, is_max_elem)
+    denominator = tf.maximum(num_nonzero, epsilon)
     mask = [
-        is_max * tf.sign(t) / _expand_to_rank(num_nonzero, rank)
+        is_max * tf.sign(t) / _expand_to_rank(denominator, rank)
         for t, rank, is_max in zip(tensors, tensor_ranks, is_max_elem)
     ]
   return tf.nest.pack_sequence_as(weights, mask)
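The division guarded in this hunk is the tie-splitting step of the L1 maximizer: mass goes to the largest-magnitude coordinates, split evenly among ties, hence the division by num_nonzero. A hand-rolled single-tensor illustration of that arithmetic (a sketch only, not the library's exact code):

import tensorflow as tf

weights = tf.constant([[3.0, -3.0, 1.0]])
abs_w = tf.abs(weights)
# 1.0 where |w| attains the per-example maximum, 0.0 elsewhere.
is_max = tf.cast(
    tf.equal(abs_w, tf.reduce_max(abs_w, axis=-1, keepdims=True)), tf.float32)
num_nonzero = tf.reduce_sum(is_max, axis=-1, keepdims=True)
denominator = tf.maximum(num_nonzero, 1e-6)  # the clamp this commit adds
mask = is_max * tf.sign(weights) / denominator
print(mask.numpy())  # [[ 0.5 -0.5  0. ]], an L1-unit-norm maximizer of w.x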

neural_structured_learning/lib/utils_test.py

Lines changed: 8 additions & 1 deletion

@@ -19,13 +19,14 @@

 import math

+from absl.testing import parameterized
 import neural_structured_learning.configs as configs
 from neural_structured_learning.lib import utils
 import numpy as np
 import tensorflow as tf


-class UtilsTest(tf.test.TestCase):
+class UtilsTest(tf.test.TestCase, parameterized.TestCase):

   def testNormalizeInf(self):
     target_tensor = tf.constant([[1.0, 2.0, -4.0], [-1.0, 5.0, -3.0]])
@@ -109,6 +110,12 @@ def testMaximizeWithinUnitNormWithMultipleInputs(self):
     }
     self.assertAllClose(actual, expected)

+  @parameterized.parameters('l2', 'l1', 'infinity')
+  def testMaximizeWithinUnitNormL2WithZeroInputShouldReturnZero(self, norm):
+    weights = tf.constant([[0.0, 0.0]])
+    actual = self.evaluate(utils.maximize_within_unit_norm(weights, norm))
+    self.assertAllEqual(actual, weights)
+
   def testReplicateEmbeddingsWithConstant(self):
     """Test the replicate_embeddings function with constant replicate_times."""
     input_embeddings = tf.constant([
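For reference, a minimal usage sketch of the patched function (my own example, not from the commit, assuming TF 2.x eager execution; the expected values follow from the L2 math above):

import neural_structured_learning.configs as configs
from neural_structured_learning.lib import utils
import tensorflow as tf

# A batch with one all-zero example (NaN before this fix) and one regular one.
weights = tf.constant([[0.0, 0.0], [3.0, 4.0]])
perturbation = utils.maximize_within_unit_norm(weights, configs.NormType.L2)
print(perturbation.numpy())  # [[0.  0. ]
                             #  [0.6 0.8]]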
