Commit 0d03e77

tf-transform-team authored and tfx-copybara committed
Extended _min_and_max_per_key to support element-wise reduction (reduce_instance_dims=False).
PiperOrigin-RevId: 454894071
1 parent 74789ff commit 0d03e77
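
For orientation, a minimal sketch of the usage this commit enables, assuming the private analyzers._min_and_max_per_key API exercised in the test diff below (composite tensor inputs still raise NotImplementedError in the element-wise case):

# A sketch only; mirrors the call pattern in impl_test.py below.
from tensorflow_transform import analyzers

def analyzer_fn(inputs):
  # With reduce_instance_dims=False, min and max are computed per key and
  # per element position, so the outputs keep the instance shape of `x`.
  key_vocab, min_x, max_x = analyzers._min_and_max_per_key(
      x=inputs['x'],               # dense tf.Tensor, e.g. shape [batch, 2]
      key=inputs['key'],           # one string key per example
      reduce_instance_dims=False)
  return {'key_vocab': key_vocab, 'min_x_value': min_x, 'max_x_value': max_x}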

File tree: 2 files changed, +166 −7 lines

tensorflow_transform/analyzers.py

Lines changed: 15 additions & 7 deletions
@@ -348,6 +348,12 @@ def _get_output_shape_from_input(x):
   return (None,)


+def _get_elementwise_per_key_output_shape(
+    x: tf.Tensor, key: Optional[tf.Tensor]) -> Optional[Tuple[int]]:
+  shape = x.get_shape() if key is None else x.get_shape()[1:]
+  return tuple(shape) if shape.is_fully_defined() else None
+
+
 # TODO(b/112414577): Go back to accepting only a single input.
 # Currently we accept multiple inputs so that we can implement min and max
 # with a single combiner. Once this is done, add a return pytype as well.
@@ -401,8 +407,7 @@ def _numeric_combine(inputs: List[tf.Tensor],
   else:
     # Reducing over batch dimensions.
     output_shapes = [
-        (tuple(x.get_shape()) if x.get_shape().is_fully_defined() else None)
-        for x in inputs
+        _get_elementwise_per_key_output_shape(x, key) for x in inputs
     ]
   combiner = NumPyCombiner(fn, default_accumulator_value,
                            [dtype.as_numpy_dtype for dtype in output_dtypes],
@@ -414,8 +419,8 @@ def _numeric_combine(inputs: List[tf.Tensor],
     return _apply_cacheable_combiner_per_key(combiner, key, *inputs)

   return _apply_cacheable_combiner_per_key_large(
-      combiner, _maybe_get_per_key_vocab_filename(key_vocabulary_filename),
-      key, *inputs)
+      combiner, _maybe_get_per_key_vocab_filename(key_vocabulary_filename), key,
+      *inputs)


 @common.log_api_use(common.ANALYZER_COLLECTION)
@@ -565,8 +570,10 @@ def _min_and_max_per_key(
   if key is None:
     raise ValueError('A key is required for _min_and_max_per_key')

-  if not reduce_instance_dims:
-    raise NotImplementedError('Per-key elementwise reduction not supported')
+  if not reduce_instance_dims and isinstance(
+      x, (tf.SparseTensor, tf.RaggedTensor)):
+    raise NotImplementedError(
+        'Per-key elementwise reduction of Composite Tensors not supported ')

   with tf.compat.v1.name_scope(name, 'min_and_max_per_key'):
     output_dtype = x.dtype
@@ -582,7 +589,8 @@ def _min_and_max_per_key(
                       -output_dtype.max)

     key_vocab, x_batch_minus_min, x_batch_max = (
-        tf_utils.reduce_batch_minus_min_and_max_per_key(x, key))
+        tf_utils.reduce_batch_minus_min_and_max_per_key(x, key,
+                                                        reduce_instance_dims))

     key_values = _numeric_combine(  # pylint: disable=unbalanced-tuple-unpacking
         inputs=[x_batch_minus_min, x_batch_max],
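
For reference, a standalone sketch of what the new _get_elementwise_per_key_output_shape helper computes; the tensor shapes here are illustrative assumptions, not taken from the commit:

import tensorflow as tf

x = tf.zeros([8, 2, 3])        # a batch of 8 examples, each of shape [2, 3]
key = tf.constant(['a'] * 8)   # one key per example

# With a key present, the batch dimension is dropped, so each per-key output
# has the instance shape [2, 3]; with key=None the full shape would be kept.
shape = x.get_shape() if key is None else x.get_shape()[1:]
print(tuple(shape) if shape.is_fully_defined() else None)  # (2, 3)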

tensorflow_transform/beam/impl_test.py

Lines changed: 151 additions & 0 deletions
@@ -1919,6 +1919,157 @@ def analyzer_fn(inputs):
         expected_outputs,
         desired_batch_size=10)

+  @tft_unit.named_parameters(
+      dict(
+          testcase_name='_dense_2d',
+          input_data=[{
+              'x': [4, 8],
+              'key': 'a'
+          }, {
+              'x': [1, 5],
+              'key': 'a'
+          }, {
+              'x': [5, 9],
+              'key': 'a'
+          }, {
+              'x': [2, 6],
+              'key': 'a'
+          }, {
+              'x': [-2, 0],
+              'key': 'b'
+          }, {
+              'x': [0, 2],
+              'key': 'b'
+          }, {
+              'x': [2, 4],
+              'key': 'b'
+          }],
+          input_metadata=tft.DatasetMetadata.from_feature_spec({
+              'x': tf.io.FixedLenFeature([2], tf.float32),
+              'key': tf.io.FixedLenFeature([], tf.string),
+          }),
+          reduce_instance_dims=True,
+          expected_outputs={
+              'key_vocab': np.array([b'a', b'b'], np.object),
+              'min_x_value': np.array([1, -2], np.float32),
+              'max_x_value': np.array([9, 4], np.float32),
+          }),
+      dict(
+          testcase_name='_dense_2d_elementwise',
+          input_data=[{
+              'x': [4, 8],
+              'key': 'a'
+          }, {
+              'x': [1, 5],
+              'key': 'a'
+          }, {
+              'x': [5, 9],
+              'key': 'a'
+          }, {
+              'x': [2, 6],
+              'key': 'a'
+          }, {
+              'x': [-2, 0],
+              'key': 'b'
+          }, {
+              'x': [0, 2],
+              'key': 'b'
+          }, {
+              'x': [2, 4],
+              'key': 'b'
+          }],
+          input_metadata=tft.DatasetMetadata.from_feature_spec({
+              'x': tf.io.FixedLenFeature([2], tf.float32),
+              'key': tf.io.FixedLenFeature([], tf.string),
+          }),
+          reduce_instance_dims=False,
+          expected_outputs={
+              'key_vocab': np.array([b'a', b'b'], np.object),
+              'min_x_value': np.array([[1, 5], [-2, 0]], np.float32),
+              'max_x_value': np.array([[5, 9], [2, 4]], np.float32),
+          }),
+      dict(
+          testcase_name='_dense_3d',
+          input_data=[
+              {
+                  'x': [[1, 5], [1, 1]],
+                  'key': 'a'
+              },
+              {
+                  'x': [[5, 1], [5, 5]],
+                  'key': 'a'
+              },
+              {
+                  'x': [[2, 2], [2, 5]],
+                  'key': 'a'
+              },
+              {
+                  'x': [[3, -3], [3, 3]],
+                  'key': 'b'
+              },
+          ],
+          input_metadata=tft.DatasetMetadata.from_feature_spec({
+              'x': tf.io.FixedLenFeature([2, 2], tf.float32),
+              'key': tf.io.FixedLenFeature([], tf.string),
+          }),
+          reduce_instance_dims=True,
+          expected_outputs={
+              'key_vocab': np.array([b'a', b'b'], np.object),
+              'min_x_value': np.array([1, -3], np.float32),
+              'max_x_value': np.array([5, 3], np.float32),
+          }),
+      dict(
+          testcase_name='_dense_3d_elementwise',
+          input_data=[
+              {
+                  'x': [[1, 5], [1, 1]],
+                  'key': 'a'
+              },
+              {
+                  'x': [[5, 1], [5, 5]],
+                  'key': 'a'
+              },
+              {
+                  'x': [[2, 2], [2, 5]],
+                  'key': 'a'
+              },
+              {
+                  'x': [[3, -3], [3, 3]],
+                  'key': 'b'
+              },
+          ],
+          input_metadata=tft.DatasetMetadata.from_feature_spec({
+              'x': tf.io.FixedLenFeature([2, 2], tf.float32),
+              'key': tf.io.FixedLenFeature([], tf.string),
+          }),
+          reduce_instance_dims=False,
+          expected_outputs={
+              'key_vocab':
+                  np.array([b'a', b'b'], np.object),
+              'min_x_value':
+                  np.array([[[1, 1], [1, 1]], [[3, -3], [3, 3]]], np.float32),
+              'max_x_value':
+                  np.array([[[5, 5], [5, 5]], [[3, -3], [3, 3]]], np.float32),
+          }),
+  )
+  def testMinAndMaxPerKey(self, input_data, input_metadata,
+                          reduce_instance_dims, expected_outputs):
+    self._SkipIfOutputRecordBatches()
+
+    def analyzer_fn(inputs):
+      key_vocab, min_x_value, max_x_value = analyzers._min_and_max_per_key(
+          x=inputs['x'],
+          key=inputs['key'],
+          reduce_instance_dims=reduce_instance_dims)
+      return {
+          'key_vocab': key_vocab,
+          'min_x_value': min_x_value,
+          'max_x_value': max_x_value,
+      }
+
+    self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
+                               expected_outputs)
+
   @tft_unit.parameters((True,), (False,))
   def testPerKeyWithOOVKeys(self, use_vocabulary):
     def preprocessing_fn(inputs):
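
As a sanity check on the semantics the '_dense_2d_elementwise' case asserts, the same per-key, element-wise reduction can be reproduced in plain NumPy (a sketch, not part of the commit):

import numpy as np

# Batch and keys from the '_dense_2d_elementwise' test case above.
xs = np.array([[4, 8], [1, 5], [5, 9], [2, 6], [-2, 0], [0, 2], [2, 4]],
              dtype=np.float32)
keys = np.array(['a', 'a', 'a', 'a', 'b', 'b', 'b'])

for k in ('a', 'b'):
  rows = xs[keys == k]
  # Reduce over the batch axis only, keeping each element position separate.
  print(k, rows.min(axis=0), rows.max(axis=0))
# a [1. 5.] [5. 9.]   -> matches min_x_value[0] / max_x_value[0]
# b [-2. 0.] [2. 4.]  -> matches min_x_value[1] / max_x_value[1]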
