Skip to content

Commit 8a08880

Browse files
tf-transform-teamtfx-copybara
authored andcommitted
Improve error message by printing shapes
PiperOrigin-RevId: 491319525
1 parent bc27920 commit 8a08880

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tensorflow_transform/beam/tft_unit.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,16 +225,20 @@ def preprocessing_fn(inputs):
225225
result = {}
226226
for key, output_tensor in analyzer_outputs.items():
227227
# Get the expected shape, and set it.
228-
output_shape = list(expected_outputs[key].shape)
228+
expected_output_shape = list(expected_outputs[key].shape)
229229
try:
230-
output_tensor.set_shape(output_shape)
230+
output_tensor.set_shape(expected_output_shape)
231231
except ValueError as e:
232-
raise ValueError(f'Error for key {key}') from e
232+
raise ValueError(
233+
f'Error for key {key}, shapes are incompatible. Got '
234+
f'{output_tensor.shape}, expected {expected_output_shape}.'
235+
) from e
233236
# Add a batch dimension
234237
output_tensor = tf.expand_dims(output_tensor, 0)
235238
# Broadcast along the batch dimension
236239
result[key] = tf.tile(
237-
output_tensor, multiples=[batch_size] + [1] * len(output_shape))
240+
output_tensor,
241+
multiples=[batch_size] + [1] * len(expected_output_shape))
238242

239243
return result
240244

0 commit comments

Comments
 (0)