Skip to content

Commit 8c387e5

Browse files
author
gongenlei
authored
[BeamSearchV2] Support converting dygraph to static graph (PaddlePaddle#959)
* feat: beamsearch v2 support dygraph to static graph * docs: add some annotations
1 parent fe69df5 commit 8c387e5

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

paddlenlp/transformers/transformer/modeling.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,14 +1010,15 @@ def beam_search_v2(self, src_word, beam_size=4, max_len=None, alpha=0.6):
10101010
"""
10111011

10121012
def expand_to_beam_size(tensor, beam_size):
1013-
tensor = paddle.reshape(tensor,
1014-
[tensor.shape[0], 1] + tensor.shape[1:])
1013+
tensor = paddle.unsqueeze(tensor, axis=1)
10151014
tile_dims = [1] * len(tensor.shape)
10161015
tile_dims[1] = beam_size
10171016
return paddle.tile(tensor, tile_dims)
10181017

10191018
def merge_beam_dim(tensor):
1020-
return paddle.reshape(tensor, [-1] + tensor.shape[2:])
1019+
shape = tensor.shape
1020+
return paddle.reshape(tensor,
1021+
[shape[0] * shape[1]] + list(shape[2:]))
10211022

10221023
# run encoder
10231024
src_max_len = paddle.shape(src_word)[-1]
@@ -1045,23 +1046,26 @@ def merge_beam_dim(tensor):
10451046

10461047
### initialize states of beam search ###
10471048
## init for the alive ##
1048-
initial_log_probs = paddle.to_tensor(
1049+
initial_log_probs = paddle.assign(
10491050
np.array(
10501051
[[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
10511052
alive_log_probs = paddle.tile(initial_log_probs, [batch_size, 1])
1052-
# (batch_size, beam_size, 1)
1053-
alive_seq = paddle.to_tensor(
1054-
np.tile(np.array([[[self.bos_id]]]), (batch_size, beam_size, 1)),
1055-
dtype=src_word.dtype)
1053+
1054+
alive_seq = paddle.tile(
1055+
paddle.cast(
1056+
paddle.assign(np.array([[[self.bos_id]]])), src_word.dtype),
1057+
[batch_size, beam_size, 1])
10561058

10571059
## init for the finished ##
1058-
finished_scores = paddle.to_tensor(
1060+
finished_scores = paddle.assign(
10591061
np.array(
10601062
[[-inf] * beam_size], dtype="float32"))
10611063
finished_scores = paddle.tile(finished_scores, [batch_size, 1])
1062-
finished_seq = paddle.to_tensor(
1063-
np.tile(np.array([[[self.bos_id]]]), (batch_size, beam_size, 1)),
1064-
dtype=src_word.dtype)
1064+
1065+
finished_seq = paddle.tile(
1066+
paddle.cast(
1067+
paddle.assign(np.array([[[self.bos_id]]])), src_word.dtype),
1068+
[batch_size, beam_size, 1])
10651069
finished_flags = paddle.zeros_like(finished_scores)
10661070

10671071
### initialize inputs and states of transformer decoder ###
@@ -1076,7 +1080,7 @@ def merge_beam_dim(tensor):
10761080
## init states (caches) for transformer, need to be updated according to selected beam
10771081
caches = self.transformer.decoder.gen_cache(enc_output, do_zip=False)
10781082

1079-
def update_states(caches, topk_coordinates, beam_size):
1083+
def update_states(caches, topk_coordinates, beam_size, batch_size):
10801084
new_caches = []
10811085
for cache in caches:
10821086
k = gather_2d(
@@ -1107,9 +1111,11 @@ def gather_2d(tensor_nd,
11071111
beam_size,
11081112
batch_size,
11091113
need_unmerge=False):
1114+
11101115
new_tensor_nd = paddle.reshape(
1111-
tensor_nd, shape=[batch_size, beam_size] +
1112-
tensor_nd.shape[1:]) if need_unmerge else tensor_nd
1116+
tensor_nd,
1117+
shape=[batch_size, beam_size] +
1118+
list(tensor_nd.shape[1:])) if need_unmerge else tensor_nd
11131119
topk_seq = paddle.gather_nd(new_tensor_nd, topk_coordinates)
11141120
return merge_beam_dim(topk_seq) if need_unmerge else topk_seq
11151121

@@ -1162,11 +1168,15 @@ def grow_topk(i, logits, alive_seq, alive_log_probs, states):
11621168
topk_seq = gather_2d(alive_seq, topk_coordinates, beam_size,
11631169
batch_size)
11641170
topk_seq = paddle.concat(
1165-
[topk_seq, paddle.reshape(topk_ids, topk_ids.shape + [1])],
1171+
[
1172+
topk_seq, paddle.reshape(topk_ids,
1173+
list(topk_ids.shape[:]) + [1])
1174+
],
11661175
axis=2)
1167-
states = update_states(states, topk_coordinates, beam_size)
1176+
states = update_states(states, topk_coordinates, beam_size,
1177+
batch_size)
11681178
eos = paddle.full(
1169-
shape=topk_ids.shape,
1179+
shape=paddle.shape(topk_ids),
11701180
dtype=alive_seq.dtype,
11711181
fill_value=self.eos_id)
11721182
topk_finished = paddle.cast(paddle.equal(topk_ids, eos), "float32")
@@ -1192,7 +1202,8 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
11921202

11931203
alive_log_probs = gather_2d(curr_log_probs, topk_coordinates,
11941204
beam_size, batch_size)
1195-
states = update_states(states, topk_coordinates, beam_size * 2)
1205+
states = update_states(states, topk_coordinates, beam_size * 2,
1206+
batch_size)
11961207

11971208
return alive_seq, alive_log_probs, states
11981209

@@ -1234,7 +1245,9 @@ def grow_finished(finished_seq, finished_scores, finished_flags,
12341245
def inner_loop(i, trg_word, alive_seq, alive_log_probs, finished_seq,
12351246
finished_scores, finished_flags, caches):
12361247
trg_pos = paddle.full(
1237-
shape=trg_word.shape, dtype=alive_seq.dtype, fill_value=i)
1248+
shape=paddle.shape(trg_word),
1249+
dtype=alive_seq.dtype,
1250+
fill_value=i)
12381251
trg_emb = self.trg_word_embedding(trg_word)
12391252
trg_pos_emb = self.trg_pos_embedding(trg_pos)
12401253
trg_emb = trg_emb + trg_pos_emb
@@ -1271,13 +1284,19 @@ def is_not_finish(i, trg_word, alive_seq, alive_log_probs, finished_seq,
12711284
finished_seq, finished_scores, finished_flags, caches
12721285
])
12731286

1274-
finished_flags = paddle.any(paddle.cast(
1275-
finished_flags, dtype='bool'),
1276-
axis=1,
1277-
keepdim=True).tile([1, beam_size])
1278-
finished_seq = paddle.where(
1279-
finished_flags.unsqueeze(-1).tile([1, 1, alive_seq.shape[-1]]),
1280-
finished_seq, alive_seq)
1281-
finished_scores = paddle.where(finished_flags, finished_scores,
1282-
alive_log_probs)
1287+
# (gongenlei) `paddle.where` doesn't support broadcast, so we need to use `paddle.unsqueeze`
1288+
# and `paddle.tile` to make condition.shape same as X.shape. But when converting dygraph
1289+
# to static graph, `paddle.tile` will raise error.
1290+
finished_flags = paddle.cast(finished_flags, dtype=finished_seq.dtype)
1291+
neg_finished_flags = 1 - finished_flags
1292+
finished_seq = paddle.multiply(
1293+
finished_seq, finished_flags.unsqueeze(-1)) + paddle.multiply(
1294+
alive_seq, neg_finished_flags.unsqueeze(-1))
1295+
finished_scores = paddle.multiply(
1296+
finished_scores,
1297+
paddle.cast(
1298+
finished_flags, dtype=finished_scores.dtype)) + paddle.multiply(
1299+
alive_log_probs,
1300+
paddle.cast(
1301+
neg_finished_flags, dtype=alive_log_probs.dtype))
12831302
return finished_seq, finished_scores

0 commit comments

Comments
 (0)