@@ -424,6 +424,8 @@ def forward(self, inputs, lengths):
424
424
425
425
# last_ids: batch_size
426
426
scores , last_ids = alpha .max (1 ), alpha .argmax (1 )
427
+ if max_seq_len == 1 :
428
+ return scores , last_ids .unsqueeze (1 )
427
429
# Trace back the best path
428
430
# historys: seq_len, batch_size, n_labels
429
431
historys = paddle .stack (historys )
@@ -438,10 +440,14 @@ def forward(self, inputs, lengths):
438
440
# hist: batch_size, n_labels
439
441
left_length = left_length + 1
440
442
gather_idx = batch_offset + last_ids
441
- tag_mask = paddle .cast ((left_length >= 0 ), 'int64' )
443
+ tag_mask = paddle .cast ((left_length > 0 ), 'int64' )
442
444
last_ids_update = paddle .gather (hist .flatten (),
443
445
gather_idx ) * tag_mask
446
+ zero_len_mask = paddle .cast ((left_length == 0 ), 'int64' )
447
+ last_ids_update = last_ids_update * (1 - zero_len_mask
448
+ ) + last_ids * zero_len_mask
444
449
batch_path .append (last_ids_update )
450
+ tag_mask = paddle .cast ((left_length >= 0 ), 'int64' )
445
451
last_ids = last_ids_update + last_ids * (1 - tag_mask )
446
452
batch_path = paddle .reverse (paddle .stack (batch_path , 1 ), [1 ])
447
453
return scores , batch_path
0 commit comments