Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 5ce9683

Browse files
author
Mesh TensorFlow Team
committed
Fix bug in T5 model export for eval_with_score mode -- handle case where vocabulary is tuple.
The fix is to encode inputs and targets each with their respective vocabularies. PiperOrigin-RevId: 350812757
1 parent 82ffe9d commit 5ce9683

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,9 +1991,18 @@ def str_placeholder(name):
19911991
targets = str_placeholder("targets")
19921992

19931993
predict_batch_size = tf.shape(inputs)[0]
1994-
dataset = tf.data.Dataset.from_tensor_slices(
1995-
{"inputs": inputs, "targets": targets})
1996-
dataset = transformer_dataset.encode_all_features(dataset, vocabulary)
1994+
1995+
inputs_dataset = transformer_dataset.encode_all_features(
1996+
tf.data.Dataset.from_tensor_slices({"inputs": inputs}),
1997+
inputs_vocabulary(vocabulary))
1998+
1999+
targets_dataset = transformer_dataset.encode_all_features(
2000+
tf.data.Dataset.from_tensor_slices({"targets": targets}),
2001+
targets_vocabulary(vocabulary))
2002+
2003+
dataset = tf.data.Dataset.zip((inputs_dataset, targets_dataset))
2004+
dataset = dataset.map(lambda x, y: {**x, **y},
2005+
num_parallel_calls=tf.data.experimental.AUTOTUNE)
19972006

19982007
receiver_tensors = {"inputs": inputs, "targets": targets}
19992008

0 commit comments

Comments
 (0)