Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 135f62f

Browse files
katelee168Mesh TensorFlow Team
authored andcommitted
Log sequence lengths of targets during evaluation.
PiperOrigin-RevId: 353246188
1 parent b1b5364 commit 135f62f

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,18 @@ def save_scores(results, vocabulary,
12821282
if scores_filename is not None:
12831283
write_lines_to_file(targets, scores_filename+".targets")
12841284

1285+
# Write sequence lengths
1286+
def get_sequence_length(tokens, pad_id=0):
1287+
tokens = np.array(tokens)
1288+
if not np.isin(pad_id, tokens):
1289+
return len(tokens)
1290+
# Argmax returns the index of the first occurrence of pad_id.
1291+
return np.argmax(tokens == pad_id)
1292+
1293+
seq_lengths = [get_sequence_length(r["targets"]) for r in results]
1294+
if scores_filename is not None:
1295+
write_lines_to_file(seq_lengths, scores_filename+".lengths")
1296+
12851297
# Inputs may only exist for some tasks.
12861298
if "inputs" in results[0]:
12871299
inputs = [r.get("inputs_pretokenized", r["inputs"]) for r in results]

0 commit comments

Comments
 (0)