
Commit 42447b8

Fix ernie csc decode (PaddlePaddle#1005)
* use list() instead of tokenizer.tokenize when splitting the source text (see the sketch below)
* use list() instead of tokenizer.tokenize in the Taskflow text_correction task
* add --max_seq_length to the README training commands
* add dynamic-graph prediction to the text_correction task
* fix the Windows predict bug
Parent commit: 1914c8a
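
The common thread in these changes: ERNIE-CSC is a character-level correction task, so the example and Taskflow code now split the source text with list() instead of tokenizer.tokenize, which can merge or split characters and break the per-character alignment between the input and the model's predictions. A minimal sketch of the difference, assuming the ernie-1.0 tokenizer can be downloaded (the sample sentence is purely illustrative):

```python
# Minimal sketch: WordPiece-style tokenization vs. plain character splitting.
# Exact token counts depend on the ernie-1.0 vocabulary.
from paddlenlp.transformers import ErnieTokenizer

text = "abc123中文字符测试"                  # illustrative input only
tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")

subwords = tokenizer.tokenize(text)          # may merge digits/letters, emit "##" pieces
chars = list(text)                           # exactly one element per input character

print(len(subwords), len(chars))             # lengths usually differ for mixed text
```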

5 files changed: +87 additions, -142 deletions

examples/text_correction/ernie-csc/README.md

Lines changed: 2 additions & 2 deletions
@@ -69,13 +69,13 @@ python change_sgml_to_txt.py -i extra_train_ds/train.sgml -o extra_train_ds/trai
 ### 单卡训练
 
 ```python
-python train.py --batch_size 32 --logging_steps 100 --epochs 10 --learning_rate 5e-5 --model_name_or_path ernie-1.0 --output_dir ./checkpoints/ --extra_train_ds_dir ./extra_train_ds/
+python train.py --batch_size 32 --logging_steps 100 --epochs 10 --learning_rate 5e-5 --model_name_or_path ernie-1.0 --output_dir ./checkpoints/ --extra_train_ds_dir ./extra_train_ds/ --max_seq_length 192
 ```
 
 ### 多卡训练
 
 ```python
-python -m paddle.distributed.launch --gpus "0,1" train.py --batch_size 32 --logging_steps 100 --epochs 10 --learning_rate 5e-5 --model_name_or_path ernie-1.0 --output_dir ./checkpoints/ --extra_train_ds_dir ./extra_train_ds/
+python -m paddle.distributed.launch --gpus "0,1" train.py --batch_size 32 --logging_steps 100 --epochs 10 --learning_rate 5e-5 --model_name_or_path ernie-1.0 --output_dir ./checkpoints/ --extra_train_ds_dir ./extra_train_ds/ --max_seq_length 192
 ```
 
 ## 模型预测

examples/text_correction/ernie-csc/predict.py

Lines changed: 3 additions & 3 deletions
@@ -89,9 +89,9 @@ def predict(self, data, batch_size=1):
            is_test=True)
 
        batchify_fn = lambda samples, fn=Tuple(
-            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # input
-            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id),  # segment
-            Pad(axis=0, pad_val=self.pinyin_vocab.token_to_idx[self.pinyin_vocab.pad_token]),  # pinyin
+            Pad(axis=0, pad_val=self.tokenizer.pad_token_id, dtype='int64'),  # input
+            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id, dtype='int64'),  # segment
+            Pad(axis=0, pad_val=self.pinyin_vocab.token_to_idx[self.pinyin_vocab.pad_token], dtype='int64'),  # pinyin
            Stack(axis=0, dtype='int64'),  # length
        ): [data for data in fn(samples)]
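
The added dtype='int64' lines up with the "fix the Windows predict bug" item in the commit message. A plausible reading, stated as an assumption rather than confirmed by the diff: NumPy's default integer type is 32-bit on Windows, so without an explicit dtype the padded id arrays can come out as int32 and mismatch an inference program expecting int64 inputs. A small self-contained sketch of the batching helper, with placeholder pad values:

```python
# Sketch only: pad_val=0 is a placeholder; the real code uses the tokenizer's
# pad_token_id / pad_token_type_id and the pinyin vocab's pad index.
from paddlenlp.data import Pad, Stack, Tuple

batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=0, dtype='int64'),   # input ids
    Pad(axis=0, pad_val=0, dtype='int64'),   # segment ids
    Pad(axis=0, pad_val=0, dtype='int64'),   # pinyin ids
    Stack(axis=0, dtype='int64'),            # sequence lengths
): [data for data in fn(samples)]

samples = [([1, 2, 3], [0, 0, 0], [7, 8, 9], 3),
           ([4, 5], [0, 0], [7, 8], 2)]
input_ids, segment_ids, pinyin_ids, lengths = batchify_fn(samples)
print(input_ids.dtype, input_ids.shape)      # int64 (2, 3) on every platform
```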

examples/text_correction/ernie-csc/predict_sighan.py

Lines changed: 1 addition & 5 deletions
@@ -57,15 +57,11 @@ def write_sighan_result_to_file(args, corr_preds, det_preds, lengths,
                                        lengths[i], tokenizer,
                                        args.max_seq_length)
            words = list(words)
-            if len(words) > args.max_seq_length - 2:
-                words = words[:args.max_seq_length - 2]
-            words = ''.join(words)
-
+            pred_result = list(pred_result)
            result = ids
            if pred_result == words:
                result += ', 0'
            else:
-                pred_result = list(pred_result)
                assert len(pred_result) == len(
                    words), "pred_result: {}, words: {}".format(pred_result,
                                                                words)
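
With both words and pred_result converted via list(), the comparison is between character sequences of equal length, and the local truncation block is no longer needed because parse_decode (see utils.py below) now handles overlong inputs itself. A hypothetical helper, not the script's actual SIGHAN writer, only to illustrate the character-aligned comparison:

```python
# Hypothetical helper for illustration; the real script formats SIGHAN result
# lines instead of returning tuples.
def find_corrections(source_chars, pred_chars):
    """Return 1-based (position, corrected_char) pairs where the lists differ."""
    assert len(source_chars) == len(pred_chars), "character counts must match"
    return [(i + 1, p)
            for i, (s, p) in enumerate(zip(source_chars, pred_chars))
            if s != p]

words = list("遇到逆竟时")
pred_result = list("遇到逆境时")
print(find_corrections(words, pred_result))   # [(4, '境')]
```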

examples/text_correction/ernie-csc/utils.py

Lines changed: 8 additions & 51 deletions
@@ -39,7 +39,7 @@ def convert_example(example,
                    ignore_label=-1,
                    is_test=False):
    source = example["source"]
-    words = tokenizer.tokenize(text=source)
+    words = list(source)
    if len(words) > max_seq_length - 2:
        words = words[:max_seq_length - 2]
    length = len(words)
@@ -50,7 +50,6 @@ def convert_example(example,
    # Use pad token in pinyin emb to map word emb [CLS], [SEP]
    pinyins = lazy_pinyin(
        source, style=Style.TONE3, neutral_tone_with_five=True)
-
    pinyin_ids = [0]
    # Align pinyin and chinese char
    pinyin_offset = 0
@@ -71,7 +70,7 @@ def convert_example(example,
 
    if not is_test:
        target = example["target"]
-        correction_labels = tokenizer.tokenize(text=target)
+        correction_labels = list(target)
        if len(correction_labels) > max_seq_length - 2:
            correction_labels = correction_labels[:max_seq_length - 2]
        correction_labels = tokenizer.convert_tokens_to_ids(correction_labels)
@@ -114,64 +113,22 @@ def parse_decode(words, corr_preds, det_preds, lengths, tokenizer,
                 max_seq_length):
    UNK = tokenizer.unk_token
    UNK_id = tokenizer.convert_tokens_to_ids(UNK)
-    tokens = tokenizer.tokenize(words)
-    if len(tokens) > max_seq_length - 2:
-        tokens = tokens[:max_seq_length - 2]
+
    corr_pred = corr_preds[1:1 + lengths].tolist()
    det_pred = det_preds[1:1 + lengths].tolist()
    words = list(words)
+    rest_words = []
    if len(words) > max_seq_length - 2:
+        rest_words = words[max_seq_length - 2:]
        words = words[:max_seq_length - 2]
 
-    assert len(tokens) == len(
-        corr_pred
-    ), "The number of tokens should be equal to the number of labels {}: {}: {}".format(
-        len(tokens), len(corr_pred), tokens)
    pred_result = ""
-
-    align_offset = 0
-    # Need to be aligned
-    if len(words) != len(tokens):
-        first_unk_flag = True
-        for j, word in enumerate(words):
-            if word.isspace():
-                tokens.insert(j + 1, word)
-                corr_pred.insert(j + 1, UNK_id)
-                det_pred.insert(j + 1, 0)  # No error
-            elif tokens[j] != word:
-                if tokenizer.convert_tokens_to_ids(word) == UNK_id:
-                    if first_unk_flag:
-                        first_unk_flag = False
-                        corr_pred[j] = UNK_id
-                        det_pred[j] = 0
-                    else:
-                        tokens.insert(j, UNK)
-                        corr_pred.insert(j, UNK_id)
-                        det_pred.insert(j, 0)  # No error
-                    continue
-                elif tokens[j] == UNK:
-                    # Remove rest unk
-                    k = 0
-                    while k + j < len(tokens) and tokens[k + j] == UNK:
-                        k += 1
-                    tokens = tokens[:j] + tokens[j + k:]
-                    corr_pred = corr_pred[:j] + corr_pred[j + k:]
-                    det_pred = det_pred[:j] + det_pred[j + k:]
-                else:
-                    # Maybe English, number, or suffix
-                    token = tokens[j].lstrip("##")
-                    corr_pred = corr_pred[:j] + [UNK_id] * len(
-                        token) + corr_pred[j + 1:]
-                    det_pred = det_pred[:j] + [0] * len(token) + det_pred[j +
-                                                                          1:]
-                    tokens = tokens[:j] + list(token) + tokens[j + 1:]
-            first_unk_flag = True
-
    for j, word in enumerate(words):
        candidates = tokenizer.convert_ids_to_tokens(corr_pred[j])
-        if det_pred[j] == 0 or candidates == UNK or candidates == '[PAD]':
+        if not is_chinese_char(ord(word)) or det_pred[
+                j] == 0 or candidates == UNK or candidates == '[PAD]':
            pred_result += word
        else:
            pred_result += candidates.lstrip("##")
-
+    pred_result += ''.join(rest_words)
    return pred_result
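
The rewritten parse_decode drops the token/character re-alignment loop entirely: the input is already split per character, so decoding reduces to "keep the character unless it is Chinese, flagged as an error, and the predicted candidate is usable", plus re-attaching any tail cut off by max_seq_length. A hedged, self-contained sketch of that flow; is_chinese_char below is a local stand-in covering only CJK Unified Ideographs, and the real function also strips the [CLS] offset and maps prediction ids back to vocabulary tokens:

```python
def is_chinese_char(codepoint):
    # Stand-in check: CJK Unified Ideographs only.
    return 0x4E00 <= codepoint <= 0x9FFF

def decode(words, candidates, det_pred, rest_words=(), UNK="[UNK]", PAD="[PAD]"):
    pred_result = ""
    for word, cand, det in zip(words, candidates, det_pred):
        # Keep non-Chinese characters and positions without a detected error.
        if not is_chinese_char(ord(word)) or det == 0 or cand in (UNK, PAD):
            pred_result += word
        else:
            pred_result += cand.lstrip("##")
    return pred_result + "".join(rest_words)

print(decode(list("遇到逆竟时"), ["遇", "到", "逆", "境", "时"], [0, 0, 0, 1, 0]))
# -> 遇到逆境时
```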

paddlenlp/taskflow/text_correction.py

Lines changed: 73 additions & 81 deletions
@@ -104,6 +104,18 @@ def __init__(self, task, model, **kwargs):
        )
        self._pypinyin = pypinyin
        self._max_seq_length = 128
+        self._batchify_fn = lambda samples, fn=Tuple(
+            Pad(axis=0, pad_val=self._tokenizer.pad_token_id),  # input
+            Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id),  # segment
+            Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token]),  # pinyin
+            Stack(axis=0, dtype='int64'),  # length
+        ): [data for data in fn(samples)]
+        self._num_workers = self.kwargs[
+            'num_workers'] if 'num_workers' in self.kwargs else 0
+        self._batch_size = self.kwargs[
+            'batch_size'] if 'batch_size' in self.kwargs else 1
+        self._lazy_load = self.kwargs[
+            'lazy_load'] if 'lazy_load' in self.kwargs else False
 
    def _construct_input_spec(self):
        """
@@ -141,61 +153,83 @@ def _construct_tokenizer(self, model):
 
    def _preprocess(self, inputs, padding=True, add_special_tokens=True):
        inputs = self._check_input_text(inputs)
-        batch_size = self.kwargs[
-            'batch_size'] if 'batch_size' in self.kwargs else 1
-        trans_func = self._convert_example
-
-        batchify_fn = lambda samples, fn=Tuple(
-            Pad(axis=0, pad_val=self._tokenizer.pad_token_id),  # input
-            Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id),  # segment
-            Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token]),  # pinyin
-            Stack(axis=0, dtype='int64'),  # length
-        ): [data for data in fn(samples)]
-
        examples = []
        texts = []
        for text in inputs:
            if not (isinstance(text, str) and len(text) > 0):
                continue
            example = {"source": text.strip()}
-            input_ids, token_type_ids, pinyin_ids, length = trans_func(example)
+            input_ids, token_type_ids, pinyin_ids, length = self._convert_example(
+                example)
            examples.append((input_ids, token_type_ids, pinyin_ids, length))
            texts.append(example["source"])
 
        batch_examples = [
-            examples[idx:idx + batch_size]
-            for idx in range(0, len(examples), batch_size)
+            examples[idx:idx + self._batch_size]
+            for idx in range(0, len(examples), self._batch_size)
        ]
        batch_texts = [
-            texts[idx:idx + batch_size]
-            for idx in range(0, len(examples), batch_size)
+            texts[idx:idx + self._batch_size]
+            for idx in range(0, len(examples), self._batch_size)
        ]
        outputs = {}
        outputs['batch_examples'] = batch_examples
        outputs['batch_texts'] = batch_texts
-        self.batchify_fn = batchify_fn
+        if not self._static_mode:
+
+            def read(inputs):
+                for text in inputs:
+                    example = {"source": text.strip()}
+                    input_ids, token_type_ids, pinyin_ids, length = self._convert_example(
+                        example)
+                    yield input_ids, token_type_ids, pinyin_ids, length
+
+            infer_ds = load_dataset(read, inputs=inputs, lazy=self._lazy_load)
+            outputs['data_loader'] = paddle.io.DataLoader(
+                infer_ds,
+                collate_fn=self._batchify_fn,
+                num_workers=self._num_workers,
+                batch_size=self._batch_size,
+                shuffle=False,
+                return_list=True)
+
        return outputs
 
    def _run_model(self, inputs):
        """
        Run the task model from the outputs of the `_tokenize` function.
        """
        results = []
-        with static_mode_guard():
-            for examples in inputs['batch_examples']:
-                token_ids, token_type_ids, pinyin_ids, lengths = self.batchify_fn(
-                    examples)
-                self.input_handles[0].copy_from_cpu(token_ids)
-                self.input_handles[1].copy_from_cpu(pinyin_ids)
-                self.predictor.run()
-                det_preds = self.output_handle[0].copy_to_cpu()
-                char_preds = self.output_handle[1].copy_to_cpu()
-
-                batch_result = []
-                for i in range(len(lengths)):
-                    batch_result.append(
-                        (det_preds[i], char_preds[i], lengths[i]))
-                results.append(batch_result)
+        if not self._static_mode:
+            with dygraph_mode_guard():
+                for examples in inputs['data_loader']:
+                    token_ids, token_type_ids, pinyin_ids, lengths = examples
+                    det_preds, char_preds = self._model(token_ids, pinyin_ids)
+                    det_preds = det_preds.numpy()
+                    char_preds = char_preds.numpy()
+                    lengths = lengths.numpy()
+
+                    batch_result = []
+                    for i in range(len(lengths)):
+                        batch_result.append(
+                            (det_preds[i], char_preds[i], lengths[i]))
+                    results.append(batch_result)
+        else:
+            with static_mode_guard():
+                for examples in inputs['batch_examples']:
+                    token_ids, token_type_ids, pinyin_ids, lengths = self._batchify_fn(
+                        examples)
+                    self.input_handles[0].copy_from_cpu(token_ids)
+                    self.input_handles[1].copy_from_cpu(pinyin_ids)
+                    self.predictor.run()
+                    det_preds = self.output_handle[0].copy_to_cpu()
+                    char_preds = self.output_handle[1].copy_to_cpu()
+
+                    batch_result = []
+                    for i in range(len(lengths)):
+                        batch_result.append(
+                            (det_preds[i], char_preds[i], lengths[i]))
+                    results.append(batch_result)
        inputs['batch_results'] = results
        return inputs
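
The dynamic-graph branch wraps a plain generator as a dataset with load_dataset, batches it through the shared batchify_fn via paddle.io.DataLoader, and calls the dygraph model directly. A stripped-down sketch of that pipeline under stated assumptions: encode_example is a dummy stand-in for _convert_example and no real model is loaded.

```python
import paddle
from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.datasets import load_dataset

def encode_example(text):
    # Dummy stand-in for _convert_example: it should return
    # (input_ids, token_type_ids, pinyin_ids, length) as integer sequences.
    ids = [hash(ch) % 100 + 1 for ch in text]
    return ids, [0] * len(ids), [0] * len(ids), len(ids)

def read(inputs):
    for text in inputs:
        yield encode_example(text.strip())

batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=0, dtype='int64'),   # input ids
    Pad(axis=0, pad_val=0, dtype='int64'),   # token type ids
    Pad(axis=0, pad_val=0, dtype='int64'),   # pinyin ids
    Stack(axis=0, dtype='int64'),            # lengths
): [data for data in fn(samples)]

texts = ["遇到逆竟时，我们必须勇于面对", "人生就是如此"]
infer_ds = load_dataset(read, inputs=texts, lazy=False)
loader = paddle.io.DataLoader(
    infer_ds, collate_fn=batchify_fn, batch_size=2, shuffle=False, return_list=True)

for token_ids, token_type_ids, pinyin_ids, lengths in loader:
    # In the real task: det_preds, char_preds = self._model(token_ids, pinyin_ids)
    print(token_ids.shape, pinyin_ids.shape, lengths.numpy())
```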

@@ -232,7 +266,7 @@ def _postprocess(self, inputs):
 
    def _convert_example(self, example):
        source = example["source"]
-        words = self._tokenizer.tokenize(text=source)
+        words = list(source)
        if len(words) > self._max_seq_length - 2:
            words = words[:self._max_seq_length - 2]
        length = len(words)
@@ -269,64 +303,22 @@ def _convert_example(self, example):
    def _parse_decode(self, words, corr_preds, det_preds, lengths):
        UNK = self._tokenizer.unk_token
        UNK_id = self._tokenizer.convert_tokens_to_ids(UNK)
-        tokens = self._tokenizer.tokenize(words)
-        if len(tokens) > self._max_seq_length - 2:
-            tokens = tokens[:self._max_seq_length - 2]
+
        corr_pred = corr_preds[1:1 + lengths].tolist()
        det_pred = det_preds[1:1 + lengths].tolist()
        words = list(words)
+        rest_words = []
        if len(words) > self._max_seq_length - 2:
+            rest_words = words[self._max_seq_length - 2:]
            words = words[:self._max_seq_length - 2]
 
-        assert len(tokens) == len(
-            corr_pred
-        ), "The number of tokens should be equal to the number of labels {}: {}: {}".format(
-            len(tokens), len(corr_pred), tokens)
        pred_result = ""
-
-        align_offset = 0
-        # Need to be aligned
-        if len(words) != len(tokens):
-            first_unk_flag = True
-            for j, word in enumerate(words):
-                if word.isspace():
-                    tokens.insert(j + 1, word)
-                    corr_pred.insert(j + 1, UNK_id)
-                    det_pred.insert(j + 1, 0)  # No error
-                elif tokens[j] != word:
-                    if self._tokenizer.convert_tokens_to_ids(word) == UNK_id:
-                        if first_unk_flag:
-                            first_unk_flag = False
-                            corr_pred[j] = UNK_id
-                            det_pred[j] = 0
-                        else:
-                            tokens.insert(j, UNK)
-                            corr_pred.insert(j, UNK_id)
-                            det_pred.insert(j, 0)  # No error
-                        continue
-                    elif tokens[j] == UNK:
-                        # Remove rest unk
-                        k = 0
-                        while k + j < len(tokens) and tokens[k + j] == UNK:
-                            k += 1
-                        tokens = tokens[:j] + tokens[j + k:]
-                        corr_pred = corr_pred[:j] + corr_pred[j + k:]
-                        det_pred = det_pred[:j] + det_pred[j + k:]
-                    else:
-                        # Maybe English, number, or suffix
-                        token = tokens[j].lstrip("##")
-                        corr_pred = corr_pred[:j] + [UNK_id] * len(
-                            token) + corr_pred[j + 1:]
-                        det_pred = det_pred[:j] + [0] * len(token) + det_pred[
-                            j + 1:]
-                        tokens = tokens[:j] + list(token) + tokens[j + 1:]
-                first_unk_flag = True
-
        for j, word in enumerate(words):
            candidates = self._tokenizer.convert_ids_to_tokens(corr_pred[j])
-            if det_pred[j] == 0 or candidates == UNK or candidates == '[PAD]':
+            if not is_chinese_char(ord(word)) or det_pred[
+                    j] == 0 or candidates == UNK or candidates == '[PAD]':
                pred_result += word
            else:
                pred_result += candidates.lstrip("##")
-
+        pred_result += ''.join(rest_words)
        return pred_result
