From 02bb1f2cb3eb861038a7b4a5be28d51082ae6392 Mon Sep 17 00:00:00 2001 From: FrostML <380185688@qq.com> Date: Fri, 20 Aug 2021 06:32:01 +0000 Subject: [PATCH] fix paddle.sum --- .../ops/faster_transformer/transformer/faster_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py index 8bee55af23d2..c04127b2329b 100644 --- a/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py +++ b/paddlenlp/ops/faster_transformer/transformer/faster_transformer.py @@ -125,6 +125,7 @@ def forward(self, src_word): mem_seq_lens = paddle.sum(paddle.cast( src_word != self.bos_id, dtype="int32"), + dtype="int32", axis=1) ids = self.decoding(enc_output, mem_seq_lens)