Skip to content

Commit d7eabef

Browse files
[translation] Add text2text export module.
PiperOrigin-RevId: 418559537
1 parent 439d515 commit d7eabef

File tree

3 files changed

+111
-3
lines changed

3 files changed

+111
-3
lines changed

official/nlp/serving/export_savedmodel.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# limitations under the License.
1414

1515
"""A binary/library to export TF-NLP serving `SavedModel`."""
16+
import dataclasses
1617
import os
1718
from typing import Any, Dict, Text
19+
1820
from absl import app
1921
from absl import flags
20-
import dataclasses
2122
import yaml
23+
2224
from official.core import base_task
2325
from official.core import task_factory
2426
from official.modeling import hyperparams
@@ -29,6 +31,7 @@
2931
from official.nlp.tasks import question_answering
3032
from official.nlp.tasks import sentence_prediction
3133
from official.nlp.tasks import tagging
34+
from official.nlp.tasks import translation
3235

3336
FLAGS = flags.FLAGS
3437

@@ -40,7 +43,9 @@
4043
question_answering.QuestionAnsweringTask:
4144
serving_modules.QuestionAnswering,
4245
tagging.TaggingTask:
43-
serving_modules.Tagging
46+
serving_modules.Tagging,
47+
translation.TranslationTask:
48+
serving_modules.Translation
4449
}
4550

4651

official/nlp/serving/serving_modules.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
"""Serving export modules for TF Model Garden NLP models."""
1616
# pylint:disable=missing-class-docstring
17+
import dataclasses
1718
from typing import Dict, List, Optional, Text
1819

19-
import dataclasses
2020
import tensorflow as tf
21+
import tensorflow_text as tf_text
22+
2123
from official.core import export_base
2224
from official.modeling.hyperparams import base_config
2325
from official.nlp.data import sentence_prediction_dataloader
@@ -407,3 +409,48 @@ def get_inference_signatures(self, function_keys: Dict[Text, Text]):
407409
signatures[signature_key] = self.serve_examples.get_concrete_function(
408410
tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
409411
return signatures
412+
413+
414+
class Translation(export_base.ExportModule):
415+
"""The export module for the translation task."""
416+
417+
@dataclasses.dataclass
418+
class Params(base_config.Config):
419+
sentencepiece_model_path: str = ""
420+
421+
def __init__(self, params, model: tf.keras.Model, inference_step=None):
422+
super().__init__(params, model, inference_step)
423+
self._sp_tokenizer = tf_text.SentencepieceTokenizer(
424+
model=tf.io.gfile.GFile(params.sentencepiece_model_path, "rb").read(),
425+
add_eos=True)
426+
try:
427+
empty_str_tokenized = self._sp_tokenizer.tokenize("").numpy()
428+
except tf.errors.InternalError:
429+
raise ValueError(
430+
"EOS token not in tokenizer vocab."
431+
"Please make sure the tokenizer generates a single token for an "
432+
"empty string.")
433+
self._eos_id = empty_str_tokenized.item()
434+
435+
@tf.function
436+
def serve(self, inputs) -> Dict[str, tf.Tensor]:
437+
return self.inference_step(inputs)
438+
439+
@tf.function
440+
def serve_text(self, text: tf.Tensor) -> Dict[str, tf.Tensor]:
441+
tokenized = self._sp_tokenizer.tokenize(text).to_tensor(0)
442+
return self._sp_tokenizer.detokenize(
443+
self.serve({"inputs": tokenized})["outputs"])
444+
445+
def get_inference_signatures(self, function_keys: Dict[Text, Text]):
446+
signatures = {}
447+
valid_keys = ("serve_text")
448+
for func_key, signature_key in function_keys.items():
449+
if func_key not in valid_keys:
450+
raise ValueError("Invalid function key for the module: %s with key %s. "
451+
"Valid keys are: %s" %
452+
(self.__class__, func_key, valid_keys))
453+
if func_key == "serve_text":
454+
signatures[signature_key] = self.serve_text.get_concrete_function(
455+
tf.TensorSpec(shape=[None], dtype=tf.string, name="text"))
456+
return signatures

official/nlp/serving/serving_modules_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@
1515
"""Tests for nlp.serving.serving_modules."""
1616

1717
import os
18+
1819
from absl.testing import parameterized
1920
import tensorflow as tf
21+
22+
from sentencepiece import SentencePieceTrainer
2023
from official.nlp.configs import bert
2124
from official.nlp.configs import encoders
2225
from official.nlp.serving import serving_modules
2326
from official.nlp.tasks import masked_lm
2427
from official.nlp.tasks import question_answering
2528
from official.nlp.tasks import sentence_prediction
2629
from official.nlp.tasks import tagging
30+
from official.nlp.tasks import translation
2731

2832

2933
def _create_fake_serialized_examples(features_dict):
@@ -59,6 +63,33 @@ def _create_fake_vocab_file(vocab_file_path):
5963
outfile.write("\n".join(tokens))
6064

6165

66+
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
67+
argstr = " ".join([
68+
f"--input={input_path}", f"--vocab_size={vocab_size}",
69+
"--character_coverage=0.995",
70+
f"--model_prefix={model_path}", "--model_type=bpe",
71+
"--bos_id=-1", "--pad_id=0", f"--eos_id={eos_id}", "--unk_id=2"
72+
])
73+
SentencePieceTrainer.Train(argstr)
74+
75+
76+
def _generate_line_file(filepath, lines):
77+
with tf.io.gfile.GFile(filepath, "w") as f:
78+
for l in lines:
79+
f.write("{}\n".format(l))
80+
81+
82+
def _make_sentencepeice(output_dir):
83+
src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
84+
tgt_lines = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]
85+
sentencepeice_input_path = os.path.join(output_dir, "inputs.txt")
86+
_generate_line_file(sentencepeice_input_path, src_lines + tgt_lines)
87+
sentencepeice_model_prefix = os.path.join(output_dir, "sp")
88+
_train_sentencepiece(sentencepeice_input_path, 11, sentencepeice_model_prefix)
89+
sentencepeice_model_path = "{}.model".format(sentencepeice_model_prefix)
90+
return sentencepeice_model_path
91+
92+
6293
class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
6394

6495
@parameterized.parameters(
@@ -312,6 +343,31 @@ def test_tagging(self, use_v2_feature_names, output_encoder_outputs):
312343
with self.assertRaises(ValueError):
313344
_ = export_module.get_inference_signatures({"foo": None})
314345

346+
def test_translation(self):
347+
sp_path = _make_sentencepeice(self.get_temp_dir())
348+
encdecoder = translation.EncDecoder(
349+
num_attention_heads=4, intermediate_size=256)
350+
config = translation.TranslationConfig(
351+
model=translation.ModelConfig(
352+
encoder=encdecoder,
353+
decoder=encdecoder,
354+
embedding_width=256,
355+
padded_decode=False,
356+
decode_max_length=100),
357+
sentencepiece_model_path=sp_path,
358+
)
359+
task = translation.TranslationTask(config)
360+
model = task.build_model()
361+
362+
params = serving_modules.Translation.Params(
363+
sentencepiece_model_path=sp_path)
364+
export_module = serving_modules.Translation(params=params, model=model)
365+
functions = export_module.get_inference_signatures({
366+
"serve_text": "serving_default"
367+
})
368+
outputs = functions["serving_default"](tf.constant(["abcd", "ef gh"]))
369+
self.assertEqual(outputs.shape, (2,))
370+
self.assertEqual(outputs.dtype, tf.string)
315371

316372
if __name__ == "__main__":
317373
tf.test.main()

0 commit comments

Comments
 (0)