File tree Expand file tree Collapse file tree 2 files changed +17
-0
lines changed
tensorflow_transform/beam Expand file tree Collapse file tree 2 files changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -1067,6 +1067,11 @@ def expand(self, dataset):
10671067 # output. This is because if we allowed the output of preprocessing_fn to
10681068 # be empty, we wouldn't be able to determine how many instances to
10691069 # "unbatch" the output into.
1070+ if not isinstance (structured_outputs , dict ):
1071+ raise ValueError (
1072+ 'A `preprocessing_fn` must return a '
1073+ 'Dict[str, Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]]. '
1074+ f'Got: { structured_outputs } ' )
10701075 if not structured_outputs :
10711076 raise ValueError ('The preprocessing function returned an empty dict' )
10721077
Original file line number Diff line number Diff line change @@ -4529,6 +4529,18 @@ def preprocessing_fn(inputs):
45294529 self .assertAnalyzeAndTransformResults (input_data , input_metadata ,
45304530 preprocessing_fn , expected_outputs )
45314531
4532+ def test_preprocessing_fn_returns_wrong_type (self ):
4533+ with self .assertRaisesRegexp ( # pylint: disable=g-error-prone-assert-raises
4534+ ValueError , r'A `preprocessing_fn` must return a '
4535+ r'Dict\[str, Union\[tf.Tensor, tf.SparseTensor, tf.RaggedTensor\]\]. '
4536+ 'Got: Tensor.*' ):
4537+ self .assertAnalyzeAndTransformResults (
4538+ input_data = [{'f1' : 0 }],
4539+ input_metadata = tft .DatasetMetadata .from_feature_spec (
4540+ {'f1' : tf .io .FixedLenFeature ([], tf .float32 )}),
4541+ preprocessing_fn = lambda inputs : inputs ['f1' ],
4542+ expected_data = None )
4543+
45324544
45334545if __name__ == '__main__' :
45344546 tft_unit .main ()
You can’t perform that action at this time.
0 commit comments