@@ -104,6 +104,18 @@ def __init__(self, task, model, **kwargs):
        )
        self._pypinyin = pypinyin
        self._max_seq_length = 128
+        self._batchify_fn = lambda samples, fn=Tuple(
+            Pad(axis=0, pad_val=self._tokenizer.pad_token_id),  # input
+            Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id),  # segment
+            Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token]),  # pinyin
+            Stack(axis=0, dtype='int64'),  # length
+        ): [data for data in fn(samples)]
+        self._num_workers = self.kwargs[
+            'num_workers'] if 'num_workers' in self.kwargs else 0
+        self._batch_size = self.kwargs[
+            'batch_size'] if 'batch_size' in self.kwargs else 1
+        self._lazy_load = self.kwargs[
+            'lazy_load'] if 'lazy_load' in self.kwargs else False

    def _construct_input_spec(self):
        """
@@ -141,61 +153,83 @@ def _construct_tokenizer(self, model):

    def _preprocess(self, inputs, padding=True, add_special_tokens=True):
        inputs = self._check_input_text(inputs)
-        batch_size = self.kwargs[
-            'batch_size'] if 'batch_size' in self.kwargs else 1
-        trans_func = self._convert_example
-
-        batchify_fn = lambda samples, fn=Tuple(
-            Pad(axis=0, pad_val=self._tokenizer.pad_token_id),  # input
-            Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id),  # segment
-            Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token]),  # pinyin
-            Stack(axis=0, dtype='int64'),  # length
-        ): [data for data in fn(samples)]
-
        examples = []
        texts = []
        for text in inputs:
            if not (isinstance(text, str) and len(text) > 0):
                continue
            example = {"source": text.strip()}
-            input_ids, token_type_ids, pinyin_ids, length = trans_func(example)
+            input_ids, token_type_ids, pinyin_ids, length = self._convert_example(
+                example)
            examples.append((input_ids, token_type_ids, pinyin_ids, length))
            texts.append(example["source"])

        batch_examples = [
-            examples[idx:idx + batch_size]
-            for idx in range(0, len(examples), batch_size)
+            examples[idx:idx + self._batch_size]
+            for idx in range(0, len(examples), self._batch_size)
        ]
        batch_texts = [
-            texts[idx:idx + batch_size]
-            for idx in range(0, len(examples), batch_size)
+            texts[idx:idx + self._batch_size]
+            for idx in range(0, len(examples), self._batch_size)
        ]
        outputs = {}
        outputs['batch_examples'] = batch_examples
        outputs['batch_texts'] = batch_texts
-        self.batchify_fn = batchify_fn
+        if not self._static_mode:
+
+            def read(inputs):
+                for text in inputs:
+                    example = {"source": text.strip()}
+                    input_ids, token_type_ids, pinyin_ids, length = self._convert_example(
+                        example)
+                    yield input_ids, token_type_ids, pinyin_ids, length
+
+            infer_ds = load_dataset(read, inputs=inputs, lazy=self._lazy_load)
+            outputs['data_loader'] = paddle.io.DataLoader(
+                infer_ds,
+                collate_fn=self._batchify_fn,
+                num_workers=self._num_workers,
+                batch_size=self._batch_size,
+                shuffle=False,
+                return_list=True)
+
        return outputs
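
A hypothetical, self-contained sketch of the generator -> load_dataset -> paddle.io.DataLoader pipeline built above for dygraph mode; featurize and the sample texts are made up and stand in for self._convert_example and real user inputs.

import paddle
from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.datasets import load_dataset

def featurize(text):
    # Fake features: boundary ids around per-character placeholder ids.
    ids = [1] + [100 + i for i in range(len(text))] + [2]
    return ids, len(ids)

def read(inputs):
    for text in inputs:
        yield featurize(text.strip())

infer_ds = load_dataset(read, inputs=["遇到逆竟时", "人生就是如此"], lazy=False)
data_loader = paddle.io.DataLoader(
    infer_ds,
    collate_fn=Tuple(Pad(axis=0, pad_val=0), Stack(dtype='int64')),
    batch_size=2,
    shuffle=False,
    return_list=True)
for ids, lengths in data_loader:
    print(ids.shape, lengths.numpy())  # one padded batch per iteration
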

    def _run_model(self, inputs):
        """
        Run the task model from the outputs of the `_tokenize` function.
        """
        results = []
-        with static_mode_guard():
-            for examples in inputs['batch_examples']:
-                token_ids, token_type_ids, pinyin_ids, lengths = self.batchify_fn(
-                    examples)
-                self.input_handles[0].copy_from_cpu(token_ids)
-                self.input_handles[1].copy_from_cpu(pinyin_ids)
-                self.predictor.run()
-                det_preds = self.output_handle[0].copy_to_cpu()
-                char_preds = self.output_handle[1].copy_to_cpu()
-
-                batch_result = []
-                for i in range(len(lengths)):
-                    batch_result.append(
-                        (det_preds[i], char_preds[i], lengths[i]))
-                results.append(batch_result)
+        if not self._static_mode:
+            with dygraph_mode_guard():
+                for examples in inputs['data_loader']:
+                    token_ids, token_type_ids, pinyin_ids, lengths = examples
+                    det_preds, char_preds = self._model(token_ids, pinyin_ids)
+                    det_preds = det_preds.numpy()
+                    char_preds = char_preds.numpy()
+                    lengths = lengths.numpy()
+
+                    batch_result = []
+                    for i in range(len(lengths)):
+                        batch_result.append(
+                            (det_preds[i], char_preds[i], lengths[i]))
+                    results.append(batch_result)
+        else:
+            with static_mode_guard():
+                for examples in inputs['batch_examples']:
+                    token_ids, token_type_ids, pinyin_ids, lengths = self._batchify_fn(
+                        examples)
+                    self.input_handles[0].copy_from_cpu(token_ids)
+                    self.input_handles[1].copy_from_cpu(pinyin_ids)
+                    self.predictor.run()
+                    det_preds = self.output_handle[0].copy_to_cpu()
+                    char_preds = self.output_handle[1].copy_to_cpu()
+
+                    batch_result = []
+                    for i in range(len(lengths)):
+                        batch_result.append(
+                            (det_preds[i], char_preds[i], lengths[i]))
+                    results.append(batch_result)
        inputs['batch_results'] = results
        return inputs

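
A hedged sketch of how the predictor, input_handles, and output_handle used in the static branch are typically created with Paddle Inference; the exported model paths below are placeholders, not the task's real files.

import paddle.inference as paddle_infer

# Placeholder paths for an exported static graph (.pdmodel/.pdiparams pair).
config = paddle_infer.Config("static/inference.pdmodel",
                             "static/inference.pdiparams")
predictor = paddle_infer.create_predictor(config)

input_handles = [
    predictor.get_input_handle(name) for name in predictor.get_input_names()
]
output_handle = [
    predictor.get_output_handle(name) for name in predictor.get_output_names()
]

# Per batch: feed numpy arrays, run, and copy results back to host memory.
# input_handles[0].copy_from_cpu(token_ids)
# input_handles[1].copy_from_cpu(pinyin_ids)
# predictor.run()
# det_preds = output_handle[0].copy_to_cpu()
# char_preds = output_handle[1].copy_to_cpu()
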
@@ -232,7 +266,7 @@ def _postprocess(self, inputs):

    def _convert_example(self, example):
        source = example["source"]
-        words = self._tokenizer.tokenize(text=source)
+        words = list(source)
        if len(words) > self._max_seq_length - 2:
            words = words[:self._max_seq_length - 2]
        length = len(words)
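
A toy illustration (not from the repo) of why list(source) replaces wordpiece tokenization here: splitting per character keeps one prediction position per original character.

source = "ERNIE纠错123"
print(list(source))
# ['E', 'R', 'N', 'I', 'E', '纠', '错', '1', '2', '3'] -- one entry per character.
# A wordpiece tokenizer would merge 'ERNIE' and '123' into subword tokens, so
# detection/correction outputs would no longer align one-to-one with characters.
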
@@ -269,64 +303,22 @@ def _convert_example(self, example):
    def _parse_decode(self, words, corr_preds, det_preds, lengths):
        UNK = self._tokenizer.unk_token
        UNK_id = self._tokenizer.convert_tokens_to_ids(UNK)
-        tokens = self._tokenizer.tokenize(words)
-        if len(tokens) > self._max_seq_length - 2:
-            tokens = tokens[:self._max_seq_length - 2]
+
        corr_pred = corr_preds[1:1 + lengths].tolist()
        det_pred = det_preds[1:1 + lengths].tolist()
        words = list(words)
+        rest_words = []
        if len(words) > self._max_seq_length - 2:
+            rest_words = words[self._max_seq_length - 2:]
            words = words[:self._max_seq_length - 2]

-        assert len(tokens) == len(
-            corr_pred
-        ), "The number of tokens should be equal to the number of labels {}: {}: {}".format(
-            len(tokens), len(corr_pred), tokens)
        pred_result = ""
-
-        align_offset = 0
-        # Need to be aligned
-        if len(words) != len(tokens):
-            first_unk_flag = True
-            for j, word in enumerate(words):
-                if word.isspace():
-                    tokens.insert(j + 1, word)
-                    corr_pred.insert(j + 1, UNK_id)
-                    det_pred.insert(j + 1, 0)  # No error
-                elif tokens[j] != word:
-                    if self._tokenizer.convert_tokens_to_ids(word) == UNK_id:
-                        if first_unk_flag:
-                            first_unk_flag = False
-                            corr_pred[j] = UNK_id
-                            det_pred[j] = 0
-                        else:
-                            tokens.insert(j, UNK)
-                            corr_pred.insert(j, UNK_id)
-                            det_pred.insert(j, 0)  # No error
-                        continue
-                    elif tokens[j] == UNK:
-                        # Remove rest unk
-                        k = 0
-                        while k + j < len(tokens) and tokens[k + j] == UNK:
-                            k += 1
-                        tokens = tokens[:j] + tokens[j + k:]
-                        corr_pred = corr_pred[:j] + corr_pred[j + k:]
-                        det_pred = det_pred[:j] + det_pred[j + k:]
-                    else:
-                        # Maybe English, number, or suffix
-                        token = tokens[j].lstrip("##")
-                        corr_pred = corr_pred[:j] + [UNK_id] * len(
-                            token) + corr_pred[j + 1:]
-                        det_pred = det_pred[:j] + [0] * len(token) + det_pred[
-                            j + 1:]
-                        tokens = tokens[:j] + list(token) + tokens[j + 1:]
-                        first_unk_flag = True
-
        for j, word in enumerate(words):
            candidates = self._tokenizer.convert_ids_to_tokens(corr_pred[j])
-            if det_pred[j] == 0 or candidates == UNK or candidates == '[PAD]':
+            if not is_chinese_char(ord(word)) or det_pred[
+                    j] == 0 or candidates == UNK or candidates == '[PAD]':
                pred_result += word
            else:
                pred_result += candidates.lstrip("##")
-
+        pred_result += ''.join(rest_words)
        return pred_result
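
is_chinese_char is assumed here to be the usual BERT-style CJK code-point check; a sketch of that convention follows (the helper actually imported by the task may differ).

# Assumed behaviour only: standard CJK Unified Ideograph and compatibility ranges.
def is_chinese_char(cp):
    return ((0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or
            (0x20000 <= cp <= 0x2A6DF) or (0x2A700 <= cp <= 0x2B73F) or
            (0x2B740 <= cp <= 0x2B81F) or (0x2B820 <= cp <= 0x2CEAF) or
            (0xF900 <= cp <= 0xFAFF) or (0x2F800 <= cp <= 0x2FA1F))

print(is_chinese_char(ord("错")), is_chinese_char(ord("A")))  # True False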