Skip to content

Commit 04064f3

Browse files
tf-transform-team authored and tfx-copybara committed
Extended tf_utils.map_per_key_reductions to map reductions per-key and element-wise, for n-dim Dense Tensors.
PiperOrigin-RevId: 455681807
1 parent 0d03e77 commit 04064f3

File tree

3 files changed

+48
-16
lines changed

3 files changed

+48
-16
lines changed

tensorflow_transform/mappers.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -379,7 +379,7 @@ def _scale_by_min_max_internal(
379379
# Missing keys will translate to 0 for both min and max which will be
380380
# ignored below in the tf.where.
381381
min_x_value, max_x_value = tf_utils.map_per_key_reductions(
382-
(min_x_value, max_x_value), key, key_vocab, x)
382+
(min_x_value, max_x_value), key, key_vocab, x, not elementwise)
383383
else:
384384
minus_min_max_for_key = tf_utils.apply_per_key_vocabulary(
385385
key_values, key, target_ndims=x.get_shape().ndims)
@@ -626,8 +626,8 @@ def _scale_to_z_score_internal(
626626
# Missing keys will translate to 0 for both mean and var which will be
627627
# ignored below in the tf.where.
628628
key_vocab, key_means, key_vars = mean_and_var_per_key_result
629-
x_mean, x_var = tf_utils.map_per_key_reductions((key_means, key_vars),
630-
key, key_vocab, x)
629+
x_mean, x_var = tf_utils.map_per_key_reductions(
630+
(key_means, key_vars), key, key_vocab, x, not elementwise)
631631
else:
632632
mean_var_for_key = tf_utils.apply_per_key_vocabulary(
633633
mean_and_var_per_key_result, key, target_ndims=x.get_shape().ndims)

tensorflow_transform/tf_utils.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1234,18 +1234,20 @@ def _align_dims(tensor: tf.Tensor, target_ndims: int) -> tf.Tensor:
12341234
return tensor
12351235

12361236

1237-
def map_per_key_reductions(
1238-
tensors_to_map: Tuple[tf.Tensor, ...], key: common_types.TensorType,
1239-
key_vocab: tf.Tensor,
1240-
original_input: common_types.TensorType) -> Tuple[tf.Tensor, ...]:
1237+
def map_per_key_reductions(tensors_to_map: Tuple[tf.Tensor, ...],
1238+
key: common_types.TensorType, key_vocab: tf.Tensor,
1239+
original_input: common_types.TensorType,
1240+
reduce_instance_dims: bool) -> Tuple[tf.Tensor, ...]:
12411241
"""Rearrange the reduced per-key result to correspond to the original keys.
12421242
12431243
Args:
12441244
tensors_to_map: A tuple of 1-D `Tensor`s that are same shape as key_vocab,
1245-
to be mapped to respective key.
1245+
to be mapped to respective key.
12461246
key: A `Tensor` or `CompositeTensor`.
12471247
key_vocab: A 1-D `Tensor`.
12481248
original_input: A `Tensor` or `CompositeTensor`.
1249+
reduce_instance_dims: A `bool`. True if tensors_to_map are reduced in
1250+
dimension, else False.
12491251
12501252
Returns:
12511253
A tuple same length as tensors_to_map, of `Tensor`s the same dimension as
@@ -1262,17 +1264,22 @@ def map_per_key_reductions(
12621264
(tf.SparseTensor, tf.RaggedTensor)) else
12631265
original_input.get_shape().ndims)
12641266

1265-
# Append a 0 to allow mapping OOVs to it.
1266-
tensors_to_map = [tf.concat([t, [0]], axis=0) for t in tensors_to_map]
1267+
# Append 0s to allow mapping OOVs to it.
1268+
tensors_to_map = [
1269+
tf.concat([t, tf.expand_dims(tf.zeros_like(t[0]), 0)], axis=0)
1270+
for t in tensors_to_map
1271+
]
12671272

12681273
# Replace `-1`s due to OOV with size of key_vocab.
12691274
adjusted_indices = tf.where(
12701275
key_indices >= 0, key_indices,
12711276
tf.cast(
12721277
tf.fill(tf.shape(key_indices), tf.size(key_vocab)), dtype=tf.int64))
1273-
1274-
mapped_result = [_align_dims(tf.gather(t, adjusted_indices, axis=-1), ndims)
1275-
for t in tensors_to_map]
1278+
axis = -1 if reduce_instance_dims else 0
1279+
mapped_result = [
1280+
_align_dims(tf.gather(t, adjusted_indices, axis=axis), ndims)
1281+
for t in tensors_to_map
1282+
]
12761283

12771284
return tuple(mapped_result)
12781285

tensorflow_transform/tf_utils_test.py

Lines changed: 28 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1942,6 +1942,7 @@ def test_convert_ragged_indices(self):
19421942
key_vocab=['a', 'b'],
19431943
reductions=([1, 2], [3, 4]),
19441944
x=[5, 6, 7],
1945+
reduce_instance_dims=True,
19451946
expected_results=([2, 1, 2], [4, 3, 4])),
19461947
dict(
19471948
testcase_name='sparse_tensor_dense_key',
@@ -1952,6 +1953,7 @@ def test_convert_ragged_indices(self):
19521953
indices=[[0, 0], [1, 2], [2, 2], [2, 3]],
19531954
values=[3, 2, -1, 3],
19541955
dense_shape=[3, 5]),
1956+
reduce_instance_dims=True,
19551957
expected_results=([2, 1, 2, 2], [4, 3, 4, 4])),
19561958
dict(
19571959
testcase_name='sparse_tensor_sparse_key',
@@ -1965,6 +1967,7 @@ def test_convert_ragged_indices(self):
19651967
indices=[[0, 0], [1, 2], [2, 2], [2, 3]],
19661968
values=[3, 2, -1, 3],
19671969
dense_shape=[3, 5]),
1970+
reduce_instance_dims=True,
19681971
expected_results=([2, 1, 2, 2], [4, 3, 4, 4])),
19691972
dict(
19701973
testcase_name='ragged_tensor_dense_key',
@@ -1976,6 +1979,7 @@ def test_convert_ragged_indices(self):
19761979
values=np.array([1.2, 1., 1.2, 1.]),
19771980
row_splits=np.array([0, 2, 4])),
19781981
row_splits=np.array([0, 1, 2, 2])),
1982+
reduce_instance_dims=True,
19791983
expected_results=([1, 1, 2, 2], [3, 3, 4, 4])),
19801984
dict(
19811985
testcase_name='ragged_tensor_ragged_key',
@@ -1991,24 +1995,45 @@ def test_convert_ragged_indices(self):
19911995
values=np.array([1.2, 1., 1.2, 1.]),
19921996
row_splits=np.array([0, 2, 4])),
19931997
row_splits=np.array([0, 2])),
1998+
reduce_instance_dims=True,
19941999
expected_results=([1, 2, 2, 1], [3, 4, 4, 3])),
19952000
dict(
19962001
testcase_name='missing_key',
19972002
key=['b', 'a', 'c'],
19982003
key_vocab=['z', 'a', 'b'],
19992004
reductions=([-77, 1, 2], [-99, 3, 4]),
20002005
x=[5, 6, 7],
2006+
reduce_instance_dims=True,
20012007
expected_results=([2, 1, 0], [4, 3, 0])),
2008+
dict(
2009+
testcase_name='_dense_tensor_2d_elementwise',
2010+
key=['a'],
2011+
key_vocab=['a', 'b'],
2012+
reductions=([[1, 5], [-2, 0]], [[5, 9], [2, 4]]),
2013+
x=[[4, 8]],
2014+
reduce_instance_dims=False,
2015+
expected_results=([[1, 5]], [[5, 9]])),
2016+
dict(
2017+
testcase_name='_dense_tensor_3d_elementwise',
2018+
key=['a'],
2019+
key_vocab=['a', 'b'],
2020+
reductions=([[[1, 1], [1, 1]], [[3, -3], [3, 3]]], [[[5, 5], [5, 5]],
2021+
[[3, -3], [3,
2022+
3]]]),
2023+
x=[[[1, 5], [1, 1]]],
2024+
reduce_instance_dims=False,
2025+
expected_results=([[[1, 1], [1, 1]]], [[[5, 5], [5, 5]]])),
20022026
)
2003-
def test_map_per_key_reductions(
2004-
self, key, key_vocab, reductions, x, expected_results):
2027+
def test_map_per_key_reductions(self, key, key_vocab, reductions, x,
2028+
reduce_instance_dims, expected_results):
20052029
with tf.compat.v1.Graph().as_default():
20062030
key = _value_to_tensor(key)
20072031
key_vocab = tf.constant(key_vocab)
20082032
reductions = tuple([tf.constant(t) for t in reductions])
20092033
x = _value_to_tensor(x)
20102034
expected_results = tuple(tf.constant(t) for t in expected_results)
2011-
results = tf_utils.map_per_key_reductions(reductions, key, key_vocab, x)
2035+
results = tf_utils.map_per_key_reductions(reductions, key, key_vocab, x,
2036+
reduce_instance_dims)
20122037
with tf.compat.v1.Session() as sess:
20132038
sess.run(tf.compat.v1.tables_initializer())
20142039
output = sess.run(results)

0 commit comments

Comments
 (0)