Skip to content

Commit 1a6729e

Browse files
authored
Add custom vocab file for ce (PaddlePaddle#963)
1 parent: 6ce057a — commit: 1a6729e

File tree

14 files changed

+261
-38
lines changed

14 files changed

+261
-38
lines changed

examples/machine_translation/transformer/deploy/python/inference.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,28 @@ def parse_args():
6464
default="./output/",
6565
type=str,
6666
help="The path to save logs when profile is enabled. ")
67+
parser.add_argument(
68+
"--vocab_file",
69+
default=None,
70+
type=str,
71+
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
72+
)
73+
parser.add_argument(
74+
"--unk_token",
75+
default=None,
76+
type=str,
77+
help="The unknown token. It should be provided when use custom vocab_file. "
78+
)
79+
parser.add_argument(
80+
"--bos_token",
81+
default=None,
82+
type=str,
83+
help="The bos token. It should be provided when use custom vocab_file. ")
84+
parser.add_argument(
85+
"--eos_token",
86+
default=None,
87+
type=str,
88+
help="The eos token. It should be provided when use custom vocab_file. ")
6789
args = parser.parse_args()
6890
return args
6991

@@ -222,6 +244,10 @@ def do_inference(args):
222244
args.inference_model_dir = ARGS.model_dir
223245
args.test_file = ARGS.test_file
224246
args.save_log_path = ARGS.save_log_path
247+
args.vocab_file = ARGS.vocab_file
248+
args.unk_token = ARGS.unk_token
249+
args.bos_token = ARGS.bos_token
250+
args.eos_token = ARGS.eos_token
225251
pprint(args)
226252

227253
if args.profile:

examples/machine_translation/transformer/export_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,28 @@ def parse_args():
2424
action="store_true",
2525
help="Whether to print logs on each cards and use benchmark vocab. Normally, not necessary to set --benchmark. "
2626
)
27+
parser.add_argument(
28+
"--vocab_file",
29+
default=None,
30+
type=str,
31+
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
32+
)
33+
parser.add_argument(
34+
"--unk_token",
35+
default=None,
36+
type=str,
37+
help="The unknown token. It should be provided when use custom vocab_file. "
38+
)
39+
parser.add_argument(
40+
"--bos_token",
41+
default=None,
42+
type=str,
43+
help="The bos token. It should be provided when use custom vocab_file. ")
44+
parser.add_argument(
45+
"--eos_token",
46+
default=None,
47+
type=str,
48+
help="The eos token. It should be provided when use custom vocab_file. ")
2749
args = parser.parse_args()
2850
return args
2951

@@ -87,6 +109,10 @@ def do_export(args):
87109
with open(yaml_file, 'rt') as f:
88110
args = AttrDict(yaml.safe_load(f))
89111
args.benchmark = ARGS.benchmark
112+
args.vocab_file = ARGS.vocab_file
113+
args.unk_token = ARGS.unk_token
114+
args.bos_token = ARGS.bos_token
115+
args.eos_token = ARGS.eos_token
90116
pprint(args)
91117

92118
do_export(args)

examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,28 @@ def parse_args():
7474
action="store_true",
7575
help="Whether to print logs on each cards and use benchmark vocab. Normally, not necessary to set --benchmark. "
7676
)
77+
parser.add_argument(
78+
"--vocab_file",
79+
default=None,
80+
type=str,
81+
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
82+
)
83+
parser.add_argument(
84+
"--unk_token",
85+
default=None,
86+
type=str,
87+
help="The unknown token. It should be provided when use custom vocab_file. "
88+
)
89+
parser.add_argument(
90+
"--bos_token",
91+
default=None,
92+
type=str,
93+
help="The bos token. It should be provided when use custom vocab_file. ")
94+
parser.add_argument(
95+
"--eos_token",
96+
default=None,
97+
type=str,
98+
help="The eos token. It should be provided when use custom vocab_file. ")
7799
args = parser.parse_args()
78100
return args
79101

@@ -191,6 +213,10 @@ def do_predict(args):
191213
if ARGS.batch_size:
192214
args.infer_batch_size = ARGS.batch_size
193215
args.test_file = ARGS.test_file
216+
args.vocab_file = ARGS.vocab_file
217+
args.unk_token = ARGS.unk_token
218+
args.bos_token = ARGS.bos_token
219+
args.eos_token = ARGS.eos_token
194220
pprint(args)
195221

196222
do_predict(args)

examples/machine_translation/transformer/faster_transformer/export_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,28 @@ def parse_args():
6262
action="store_true",
6363
help="Whether to print logs on each cards and use benchmark vocab. Normally, not necessary to set --benchmark. "
6464
)
65+
parser.add_argument(
66+
"--vocab_file",
67+
default=None,
68+
type=str,
69+
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
70+
)
71+
parser.add_argument(
72+
"--unk_token",
73+
default=None,
74+
type=str,
75+
help="The unknown token. It should be provided when use custom vocab_file. "
76+
)
77+
parser.add_argument(
78+
"--bos_token",
79+
default=None,
80+
type=str,
81+
help="The bos token. It should be provided when use custom vocab_file. ")
82+
parser.add_argument(
83+
"--eos_token",
84+
default=None,
85+
type=str,
86+
help="The eos token. It should be provided when use custom vocab_file. ")
6587
args = parser.parse_args()
6688
return args
6789

@@ -133,6 +155,10 @@ def do_predict(args):
133155
args.topk = ARGS.topk
134156
args.topp = ARGS.topp
135157
args.benchmark = ARGS.benchmark
158+
args.vocab_file = ARGS.vocab_file
159+
args.unk_token = ARGS.unk_token
160+
args.bos_token = ARGS.bos_token
161+
args.eos_token = ARGS.eos_token
136162
pprint(args)
137163

138164
do_predict(args)

examples/machine_translation/transformer/predict.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,28 @@ def parse_args():
3535
"--without_ft",
3636
action="store_true",
3737
help="Whether to use Faster Transformer to do predict. ")
38+
parser.add_argument(
39+
"--vocab_file",
40+
default=None,
41+
type=str,
42+
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
43+
)
44+
parser.add_argument(
45+
"--unk_token",
46+
default=None,
47+
type=str,
48+
help="The unknown token. It should be provided when use custom vocab_file. "
49+
)
50+
parser.add_argument(
51+
"--bos_token",
52+
default=None,
53+
type=str,
54+
help="The bos token. It should be provided when use custom vocab_file. ")
55+
parser.add_argument(
56+
"--eos_token",
57+
default=None,
58+
type=str,
59+
help="The eos token. It should be provided when use custom vocab_file. ")
3860
args = parser.parse_args()
3961
return args
4062

@@ -127,6 +149,10 @@ def do_predict(args):
127149
args.benchmark = ARGS.benchmark
128150
args.test_file = ARGS.test_file
129151
args.without_ft = ARGS.without_ft
152+
args.vocab_file = ARGS.vocab_file
153+
args.unk_token = ARGS.unk_token
154+
args.bos_token = ARGS.bos_token
155+
args.eos_token = ARGS.eos_token
130156
pprint(args)
131157

132158
do_predict(args)

examples/machine_translation/transformer/reader.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,13 @@ def create_data_loader(args, places=None):
4545
raise ValueError(
4646
"--train_file and --dev_file must be both or neither set. ")
4747

48-
if not args.benchmark:
48+
if args.vocab_file is not None:
49+
src_vocab = Vocab.load_vocabulary(
50+
filepath=args.vocab_file,
51+
unk_token=args.unk_token,
52+
bos_token=args.bos_token,
53+
eos_token=args.eos_token)
54+
elif not args.benchmark:
4955
src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["bpe"])
5056
else:
5157
src_vocab = Vocab.load_vocabulary(**datasets[0].vocab_info["benchmark"])
@@ -109,7 +115,13 @@ def create_infer_loader(args):
109115
else:
110116
dataset = load_dataset('wmt14ende', splits=('test'))
111117

112-
if not args.benchmark:
118+
if args.vocab_file is not None:
119+
src_vocab = Vocab.load_vocabulary(
120+
filepath=args.vocab_file,
121+
unk_token=args.unk_token,
122+
bos_token=args.bos_token,
123+
eos_token=args.eos_token)
124+
elif not args.benchmark:
113125
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
114126
else:
115127
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
@@ -151,11 +163,18 @@ def convert_samples(sample):
151163

152164

153165
def adapt_vocab_size(args):
154-
dataset = load_dataset('wmt14ende', splits=('test'))
155-
if not args.benchmark:
156-
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
166+
if args.vocab_file is not None:
167+
src_vocab = Vocab.load_vocabulary(
168+
filepath=args.vocab_file,
169+
unk_token=args.unk_token,
170+
bos_token=args.bos_token,
171+
eos_token=args.eos_token)
157172
else:
158-
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
173+
dataset = load_dataset('wmt14ende', splits=('test'))
174+
if not args.benchmark:
175+
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["bpe"])
176+
else:
177+
src_vocab = Vocab.load_vocabulary(**dataset.vocab_info["benchmark"])
159178
trg_vocab = src_vocab
160179

161180
padding_vocab = (

examples/machine_translation/transformer/static/predict.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,28 @@ def parse_args():
5151
type=str,
5252
help="The file for testing. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to process testing."
5353
)
54+
parser.add_argument(
55+
"--vocab_file",
56+
default=None,
57+
type=str,
58+
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
59+
)
60+
parser.add_argument(
61+
"--unk_token",
62+
default=None,
63+
type=str,
64+
help="The unknown token. It should be provided when use custom vocab_file. "
65+
)
66+
parser.add_argument(
67+
"--bos_token",
68+
default=None,
69+
type=str,
70+
help="The bos token. It should be provided when use custom vocab_file. ")
71+
parser.add_argument(
72+
"--eos_token",
73+
default=None,
74+
type=str,
75+
help="The eos token. It should be provided when use custom vocab_file. ")
5476
args = parser.parse_args()
5577
return args
5678

@@ -146,6 +168,10 @@ def do_predict(args):
146168
args = AttrDict(yaml.safe_load(f))
147169
args.benchmark = ARGS.benchmark
148170
args.test_file = ARGS.test_file
171+
args.vocab_file = ARGS.vocab_file
172+
args.unk_token = ARGS.unk_token
173+
args.bos_token = ARGS.bos_token
174+
args.eos_token = ARGS.eos_token
149175
pprint(args)
150176

151177
do_predict(args)

examples/machine_translation/transformer/static/train.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,28 @@ def parse_args():
6060
type=str,
6161
help="The files for validation, including [source language file, target language file]. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to do validation. "
6262
)
63+
parser.add_argument(
64+
"--vocab_file",
65+
default=None,
66+
type=str,
67+
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
68+
)
69+
parser.add_argument(
70+
"--unk_token",
71+
default=None,
72+
type=str,
73+
help="The unknown token. It should be provided when use custom vocab_file. "
74+
)
75+
parser.add_argument(
76+
"--bos_token",
77+
default=None,
78+
type=str,
79+
help="The bos token. It should be provided when use custom vocab_file. ")
80+
parser.add_argument(
81+
"--eos_token",
82+
default=None,
83+
type=str,
84+
help="The eos token. It should be provided when use custom vocab_file. ")
6385
args = parser.parse_args()
6486
return args
6587

@@ -299,6 +321,10 @@ def do_train(args):
299321
args.max_iter = ARGS.max_iter
300322
args.train_file = ARGS.train_file
301323
args.dev_file = ARGS.dev_file
324+
args.vocab_file = ARGS.vocab_file
325+
args.unk_token = ARGS.unk_token
326+
args.bos_token = ARGS.bos_token
327+
args.eos_token = ARGS.eos_token
302328
pprint(args)
303329

304330
do_train(args)

examples/machine_translation/transformer/train.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,28 @@ def parse_args():
4848
type=str,
4949
help="The files for validation, including [source language file, target language file]. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used to do validation. "
5050
)
51+
parser.add_argument(
52+
"--vocab_file",
53+
default=None,
54+
type=str,
55+
help="The vocab file. Normally, it shouldn't be set and in this case, the default WMT14 dataset will be used."
56+
)
57+
parser.add_argument(
58+
"--unk_token",
59+
default=None,
60+
type=str,
61+
help="The unknown token. It should be provided when use custom vocab_file. "
62+
)
63+
parser.add_argument(
64+
"--bos_token",
65+
default=None,
66+
type=str,
67+
help="The bos token. It should be provided when use custom vocab_file. ")
68+
parser.add_argument(
69+
"--eos_token",
70+
default=None,
71+
type=str,
72+
help="The eos token. It should be provided when use custom vocab_file. ")
5173
args = parser.parse_args()
5274
return args
5375

@@ -270,6 +292,10 @@ def do_train(args):
270292
args.max_iter = ARGS.max_iter
271293
args.train_file = ARGS.train_file
272294
args.dev_file = ARGS.dev_file
295+
args.vocab_file = ARGS.vocab_file
296+
args.unk_token = ARGS.unk_token
297+
args.bos_token = ARGS.bos_token
298+
args.eos_token = ARGS.eos_token
273299
pprint(args)
274300

275301
do_train(args)

0 commit comments

Comments (0)