Skip to content

Commit 6feb6d0

Browse files
author
gongenlei
authored
fix: fix mem_seq_lens dtype (PaddlePaddle#1050)
1 parent 2cf80b9 commit 6feb6d0

File tree

1 file changed

+2
-1
lines changed
  • paddlenlp/ops/faster_transformer/transformer

1 file changed

+2
-1
lines changed

paddlenlp/ops/faster_transformer/transformer/decoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,8 @@ def forward(self, src_word):
374374
mem_seq_lens = paddle.sum(paddle.cast(
375375
src_word != self.bos_id, dtype="int32"),
376376
axis=-1,
377-
keepdim=True)
377+
keepdim=True,
378+
dtype="int32")
378379

379380
src_slf_attn_bias = paddle.cast(
380381
src_word == self.bos_id,

0 commit comments

Comments
 (0)