Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit b6ea77f

Browse files
author
Mesh TensorFlow Team
committed
Enable multi-file inference
PiperOrigin-RevId: 475338276
1 parent d229a44 commit b6ea77f

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,6 +1500,35 @@ def input_fn(params):
15001500
write_lines_to_file(decodes, output_filename)
15011501

15021502

1503+
@gin.configurable
1504+
def decode_from_files(
1505+
estimator,
1506+
vocabulary,
1507+
model_type,
1508+
batch_size,
1509+
sequence_length,
1510+
checkpoint_path=None,
1511+
input_filenames=gin.REQUIRED,
1512+
output_filenames=gin.REQUIRED,
1513+
eos_id=1,
1514+
repeats=1,
1515+
):
1516+
"""Decodes from multiple files and writes to output_filenames."""
1517+
if len(input_filenames) != len(output_filenames):
1518+
raise ValueError("Input and output filename lists must have equal length.")
1519+
for input_filename, output_filename in zip(input_filenames, output_filenames):
1520+
decode_from_file(estimator=estimator,
1521+
vocabulary=vocabulary,
1522+
model_type=model_type,
1523+
batch_size=batch_size,
1524+
sequence_length=sequence_length,
1525+
checkpoint_path=checkpoint_path,
1526+
input_filename=input_filename,
1527+
output_filename=output_filename,
1528+
eos_id=eos_id,
1529+
repeats=repeats)
1530+
1531+
15031532
@gin.configurable
15041533
def decode_from_dataset(estimator,
15051534
vocabulary,

0 commit comments

Comments
 (0)