Skip to content

Commit bb413b0

Browse files
authored
Merge pull request PaddlePaddle#1126 from joey12300/fix_crf_decode_len_equal_1_bug
fix viterbi len=1 bug
2 parents 04de795 + 74a7f10 commit bb413b0

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

paddlenlp/layers/crf.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,8 @@ def forward(self, inputs, lengths):
424424

425425
# last_ids: batch_size
426426
scores, last_ids = alpha.max(1), alpha.argmax(1)
427+
if max_seq_len == 1:
428+
return scores, last_ids.unsqueeze(1)
427429
# Trace back the best path
428430
# historys: seq_len, batch_size, n_labels
429431
historys = paddle.stack(historys)
@@ -438,10 +440,14 @@ def forward(self, inputs, lengths):
438440
# hist: batch_size, n_labels
439441
left_length = left_length + 1
440442
gather_idx = batch_offset + last_ids
441-
tag_mask = paddle.cast((left_length >= 0), 'int64')
443+
tag_mask = paddle.cast((left_length > 0), 'int64')
442444
last_ids_update = paddle.gather(hist.flatten(),
443445
gather_idx) * tag_mask
446+
zero_len_mask = paddle.cast((left_length == 0), 'int64')
447+
last_ids_update = last_ids_update * (1 - zero_len_mask
448+
) + last_ids * zero_len_mask
444449
batch_path.append(last_ids_update)
450+
tag_mask = paddle.cast((left_length >= 0), 'int64')
445451
last_ids = last_ids_update + last_ids * (1 - tag_mask)
446452
batch_path = paddle.reverse(paddle.stack(batch_path, 1), [1])
447453
return scores, batch_path

0 commit comments

Comments
 (0)