Skip to content

Commit 10f538b

Browse files
zoyahav and tfx-copybara
authored and committed
Improve error when a user returns an unexpected type from their preprocessing_fn.
Previously, the error raised in this case was, for example: "Using a symbolic `tf.Tensor` as a Python `bool` is not allowed in Graph execution". PiperOrigin-RevId: 466037670
1 parent 5efb51d commit 10f538b

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

tensorflow_transform/beam/impl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,11 @@ def expand(self, dataset):
10671067
# output. This is because if we allowed the output of preprocessing_fn to
10681068
# be empty, we wouldn't be able to determine how many instances to
10691069
# "unbatch" the output into.
1070+
if not isinstance(structured_outputs, dict):
1071+
raise ValueError(
1072+
'A `preprocessing_fn` must return a '
1073+
'Dict[str, Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]]. '
1074+
f'Got: {structured_outputs}')
10701075
if not structured_outputs:
10711076
raise ValueError('The preprocessing function returned an empty dict')
10721077

tensorflow_transform/beam/impl_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4529,6 +4529,18 @@ def preprocessing_fn(inputs):
45294529
self.assertAnalyzeAndTransformResults(input_data, input_metadata,
45304530
preprocessing_fn, expected_outputs)
45314531

4532+
def test_preprocessing_fn_returns_wrong_type(self):
4533+
with self.assertRaisesRegexp( # pylint: disable=g-error-prone-assert-raises
4534+
ValueError, r'A `preprocessing_fn` must return a '
4535+
r'Dict\[str, Union\[tf.Tensor, tf.SparseTensor, tf.RaggedTensor\]\]. '
4536+
'Got: Tensor.*'):
4537+
self.assertAnalyzeAndTransformResults(
4538+
input_data=[{'f1': 0}],
4539+
input_metadata=tft.DatasetMetadata.from_feature_spec(
4540+
{'f1': tf.io.FixedLenFeature([], tf.float32)}),
4541+
preprocessing_fn=lambda inputs: inputs['f1'],
4542+
expected_data=None)
4543+
45324544

45334545
if __name__ == '__main__':
45344546
tft_unit.main()

0 commit comments

Comments
 (0)