15 | 15 | """Tests for nlp.serving.serving_modules."""
16 | 16 |
17 | 17 | import os
| 18 | +
18 | 19 | from absl.testing import parameterized
19 | 20 | import tensorflow as tf
| 21 | +
| 22 | +from sentencepiece import SentencePieceTrainer
20 | 23 | from official.nlp.configs import bert
21 | 24 | from official.nlp.configs import encoders
22 | 25 | from official.nlp.serving import serving_modules
23 | 26 | from official.nlp.tasks import masked_lm
24 | 27 | from official.nlp.tasks import question_answering
25 | 28 | from official.nlp.tasks import sentence_prediction
26 | 29 | from official.nlp.tasks import tagging
| 30 | +from official.nlp.tasks import translation
27 | 31 |
28 | 32 |
29 | 33 | def _create_fake_serialized_examples(features_dict):
@@ -59,6 +63,33 @@ def _create_fake_vocab_file(vocab_file_path):
59 | 63 | outfile.write("\n".join(tokens))
60 | 64 |
61 | 65 |
| 66 | +def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
| 67 | +  argstr = " ".join([
| 68 | +      f"--input={input_path}", f"--vocab_size={vocab_size}",
| 69 | +      "--character_coverage=0.995",
| 70 | +      f"--model_prefix={model_path}", "--model_type=bpe",
| 71 | +      "--bos_id=-1", "--pad_id=0", f"--eos_id={eos_id}", "--unk_id=2"
| 72 | +  ])
| 73 | +  SentencePieceTrainer.Train(argstr)
| 74 | +
| 75 | +
| 76 | +def _generate_line_file(filepath, lines):
| 77 | +  with tf.io.gfile.GFile(filepath, "w") as f:
| 78 | +    for l in lines:
| 79 | +      f.write("{}\n".format(l))
| 80 | +
| 81 | +
| 82 | +def _make_sentencepiece(output_dir):
| 83 | +  src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
| 84 | +  tgt_lines = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]
| 85 | +  sentencepiece_input_path = os.path.join(output_dir, "inputs.txt")
| 86 | +  _generate_line_file(sentencepiece_input_path, src_lines + tgt_lines)
| 87 | +  sentencepiece_model_prefix = os.path.join(output_dir, "sp")
| 88 | +  _train_sentencepiece(sentencepiece_input_path, 11, sentencepiece_model_prefix)
| 89 | +  sentencepiece_model_path = "{}.model".format(sentencepiece_model_prefix)
| 90 | +  return sentencepiece_model_path
| 91 | +
| 92 | +
62 | 93 | class ServingModulesTest(tf.test.TestCase, parameterized.TestCase):
63 | 94 |
64 | 95 | @parameterized.parameters(
@@ -312,6 +343,31 @@ def test_tagging(self, use_v2_feature_names, output_encoder_outputs):
312 | 343 | with self.assertRaises(ValueError):
313 | 344 | _ = export_module.get_inference_signatures({"foo": None})
314 | 345 |
| 346 | +  def test_translation(self):
| 347 | +    sp_path = _make_sentencepiece(self.get_temp_dir())
| 348 | +    encdecoder = translation.EncDecoder(
| 349 | +        num_attention_heads=4, intermediate_size=256)
| 350 | +    config = translation.TranslationConfig(
| 351 | +        model=translation.ModelConfig(
| 352 | +            encoder=encdecoder,
| 353 | +            decoder=encdecoder,
| 354 | +            embedding_width=256,
| 355 | +            padded_decode=False,
| 356 | +            decode_max_length=100),
| 357 | +        sentencepiece_model_path=sp_path,
| 358 | +    )
| 359 | +    task = translation.TranslationTask(config)
| 360 | +    model = task.build_model()
| 361 | +
| 362 | +    params = serving_modules.Translation.Params(
| 363 | +        sentencepiece_model_path=sp_path)
| 364 | +    export_module = serving_modules.Translation(params=params, model=model)
| 365 | +    functions = export_module.get_inference_signatures({
| 366 | +        "serve_text": "serving_default"
| 367 | +    })
| 368 | +    outputs = functions["serving_default"](tf.constant(["abcd", "ef gh"]))
| 369 | +    self.assertEqual(outputs.shape, (2,))
| 370 | +    self.assertEqual(outputs.dtype, tf.string)
315 | 371 |
316 | 372 | if __name__ == "__main__":
317 | 373 | tf.test.main()