
Commit fb7688c

tf-transform-team authored and tfx-copybara committed
Optimize analyzers and mappers for faster inference in TF2
PiperOrigin-RevId: 604944502
1 parent c4e6066 commit fb7688c

File tree

5 files changed (+7, -22 lines)


RELEASE.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 
 ## Breaking Changes
 
-* Existing `tft.vocabulary` cache is automatically invalidated.
+* Existing analyzer cache is automatically invalidated.
 
 ## Deprecations
 

tensorflow_transform/analyzer_nodes.py

Lines changed: 1 addition & 5 deletions
@@ -281,11 +281,7 @@ def _bind_future_as_tensor_v2(
           replaced_result)
       return replaced_result
     else:
-      # Without the identity wrapper some V2 tests fail with AttributeError:
-      # Tensor.name is meaningless when eager execution is enabled.
-      # TODO(b/149997088): Remove the identity wrapper once we no longer rely on
-      # tensor names.
-      return tf.identity(replaced_result)
+      return replaced_result
   else:
     graph.add_to_collection(TENSOR_REPLACEMENTS, tensor_sink)
     eager_asset_path = temporary_analyzer_info.eager_asset_path
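
The deleted `tf.identity` wrapper existed only so early TF2 tests could read `Tensor.name`; with that dependency gone, the tensor is returned as-is and the traced graph carries one less op per bound analyzer output. A minimal sketch of the cost being removed (hypothetical function names, not from this repo):

import tensorflow as tf

@tf.function
def with_identity(x):
  # Pre-commit shape of the code path: an extra Identity node is traced.
  return tf.identity(x + 1.0)

@tf.function
def without_identity(x):
  # Post-commit: the result tensor is returned directly.
  return x + 1.0

x = tf.constant(2.0)
assert with_identity(x).numpy() == without_identity(x).numpy()  # same value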

tensorflow_transform/analyzers.py

Lines changed: 0 additions & 5 deletions
@@ -1973,15 +1973,10 @@ def _get_vocabulary_analyzer_inputs(
   elif vocab_ordering_type == _VocabOrderingType.WEIGHTED_FREQUENCY:
     reduced_batch = tf_utils.reduce_batch_weighted_counts(
         x, weights, filter_regex=filter_regex)
-    assert reduced_batch.summed_positive_per_x_and_y is None
-    assert reduced_batch.counts_per_x is None
     return [reduced_batch.unique_x, reduced_batch.summed_weights_per_x]
   else:
     reduced_batch = tf_utils.reduce_batch_weighted_counts(
         x, filter_regex=filter_regex)
-    assert reduced_batch.summed_weights_per_x is None
-    assert reduced_batch.summed_positive_per_x_and_y is None
-    assert reduced_batch.counts_per_x is None
     return [reduced_batch.unique_x]
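
The removed `assert` statements were plain Python checks that fields of the returned `ReducedBatchWeightedCounts` tuple are `None` on these code paths; since the caller forwards only the populated fields, the checks were redundant. A rough sketch of the pattern, using a simplified stand-in for the real tuple (field names mirror the diff, but this type is illustrative only):

import collections

# Illustrative stand-in for tf_utils.ReducedBatchWeightedCounts.
ReducedBatch = collections.namedtuple(
    'ReducedBatch',
    ['unique_x', 'summed_weights_per_x',
     'summed_positive_per_x_and_y', 'counts_per_x'])

batch = ReducedBatch(
    unique_x=['a', 'b'],
    summed_weights_per_x=[1.0, 2.0],
    summed_positive_per_x_and_y=None,
    counts_per_x=None)

# Selecting only the populated fields makes asserting that the rest are
# None unnecessary, which is exactly what the deleted lines did.
inputs = [batch.unique_x, batch.summed_weights_per_x]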

tensorflow_transform/mappers.py

Lines changed: 3 additions & 2 deletions
@@ -1449,11 +1449,12 @@ def _deduplicate_row(dedup_row_loop_vars):
 
     # Keep track of the maximum number of unique elements in a row, as this
     # will determine the resulting dense shape.
+    num_unique_values = tf.shape(row_values)[0]
     max_unique = tf.cast(
-        tf.maximum(tf.cast(tf.shape(row_values)[0], tf.int64), max_unique),
+        tf.maximum(tf.cast(num_unique_values, tf.int64), max_unique),
         tf.int64)
     column_indices = tf.cast(
-        tf.expand_dims(tf.range(tf.shape(row_values)[0]), axis=1), tf.int64)
+        tf.expand_dims(tf.range(num_unique_values), axis=1), tf.int64)
     row_indices = tf.fill(tf.shape(column_indices), tf.cast(index, tf.int64))
     values = values.write(index, row_values)
     indices = indices.write(index, tf.concat([row_indices, column_indices], 1))
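
The change here is common-subexpression hoisting: `tf.shape(row_values)[0]` was built twice per loop iteration, and naming it `num_unique_values` builds it once. The same hoisting in isolation (standalone sketch, not the repo's loop body):

import tensorflow as tf

row_values = tf.constant([3, 1, 4, 1, 5])

# Before: two separate shape ops for the same quantity.
max_unique_before = tf.cast(tf.shape(row_values)[0], tf.int64)
column_ids_before = tf.range(tf.shape(row_values)[0])

# After: compute the row length once and reuse the tensor.
num_unique_values = tf.shape(row_values)[0]
max_unique_after = tf.cast(num_unique_values, tf.int64)
column_ids_after = tf.range(num_unique_values)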

tensorflow_transform/tf_utils.py

Lines changed: 2 additions & 9 deletions
@@ -187,8 +187,6 @@ def reduce_batch_weighted_counts(
   else:
     # TODO(b/112916494): Always do batch wise reduction once possible.
     return ReducedBatchWeightedCounts(flat_x, None, None, None)
-  # TODO(b/134075780): Revisit expected weights shape when input is composite.
-  x, weights = assert_same_shape(x, weights)
   weights = filter_fn(tf.reshape(weights, [-1]))
   unique_x_values, unique_idx, _ = tf.unique_with_counts(
       flat_x, out_idx=tf.int64)

@@ -410,7 +408,6 @@ def _preprocess_tensors_for_cooccurences(
     x, weights_input = assert_same_shape(x, weights_input)
     weights = weights_input
   y = _broadcast_to_x_shape(x, y)
-  x, y = assert_same_shape(x, y)
   x = tf.reshape(x, [-1])
   filter_fn = _make_regex_filter_fn(x, filter_regex)
   x = filter_fn(x)

@@ -593,8 +590,7 @@ def _broadcast_to_x_shape(x, y):
   y_shape = tf.shape(input=y)
   assert_eq = tf.compat.v1.assert_equal(x_shape[0], y_shape[0])
   with tf.control_dependencies([assert_eq]):
-    y = tf.identity(y)
-    rank_delta = tf.rank(x) - tf.rank(y)
+    rank_delta = tf.rank(x) - tf.rank(y)
   target_shape = tf.concat(
       [tf.shape(y), tf.ones(rank_delta, dtype=tf.int32)], axis=0)
   matched_rank = tf.reshape(y, target_shape)

@@ -1756,7 +1752,7 @@ def reduce_batch_minus_min_and_max(
 
     x_batch_max = tf.reduce_max(input_tensor=x)
     x_batch_minus_min = tf.reduce_max(input_tensor=tf.zeros_like(x) - x)
-    return assert_same_shape(x_batch_minus_min, x_batch_max)
+    return x_batch_minus_min, x_batch_max
 
   elif isinstance(x, tf.SparseTensor):
     return _sparse_minus_reduce_min_and_reduce_max(x)

@@ -1820,9 +1816,6 @@ def get_batch_max_per_key(tensor, key_uniques):  # pylint: disable=missing-docst
   x_batch_maxes = get_batch_max_per_key(x, unique)
   x_batch_minus_mins = get_batch_max_per_key(-x, unique)
 
-  x_batch_minus_mins, x_batch_maxes = assert_same_shape(x_batch_minus_mins,
-                                                        x_batch_maxes)
-
   return (unique.y, x_batch_minus_mins, x_batch_maxes)
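
Most of the deletions in this file drop `assert_same_shape` calls on tensors whose shapes already match by construction. Helpers of that kind typically chain a shape assertion through `tf.control_dependencies`, so the check (and its pass-through `tf.identity` ops) runs on every batch at inference time. A rough sketch, assuming the real `assert_same_shape` works along these lines (its actual implementation is not shown in this diff):

import tensorflow as tf

def assert_same_shape_sketch(x, y):
  # Hypothetical re-creation of the removed helper: it adds an assert op
  # plus two Identity ops that execute on every call of the traced graph.
  assert_op = tf.debugging.assert_equal(tf.shape(x), tf.shape(y))
  with tf.control_dependencies([assert_op]):
    return tf.identity(x), tf.identity(y)

@tf.function
def minus_min_and_max_checked(x):
  # Pre-commit code path: per-batch shape check on the two reductions.
  x_batch_max = tf.reduce_max(input_tensor=x)
  x_batch_minus_min = tf.reduce_max(input_tensor=tf.zeros_like(x) - x)
  return assert_same_shape_sketch(x_batch_minus_min, x_batch_max)

@tf.function
def minus_min_and_max(x):
  # Post-commit code path: the raw tensors are returned directly.
  x_batch_max = tf.reduce_max(input_tensor=x)
  x_batch_minus_min = tf.reduce_max(input_tensor=tf.zeros_like(x) - x)
  return x_batch_minus_min, x_batch_max

print(minus_min_and_max(tf.constant([1.0, 5.0, 3.0])))  # (-1.0, 5.0)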
