@@ -1010,14 +1010,15 @@ def beam_search_v2(self, src_word, beam_size=4, max_len=None, alpha=0.6):
         """
 
         def expand_to_beam_size(tensor, beam_size):
-            tensor = paddle.reshape(tensor,
-                                    [tensor.shape[0], 1] + tensor.shape[1:])
+            tensor = paddle.unsqueeze(tensor, axis=1)
             tile_dims = [1] * len(tensor.shape)
             tile_dims[1] = beam_size
             return paddle.tile(tensor, tile_dims)
 
         def merge_beam_dim(tensor):
-            return paddle.reshape(tensor, [-1] + tensor.shape[2:])
+            shape = tensor.shape
+            return paddle.reshape(tensor,
+                                  [shape[0] * shape[1]] + list(shape[2:]))
 
         # run encoder
         src_max_len = paddle.shape(src_word)[-1]
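The two helpers above are the shape plumbing for the whole routine: `expand_to_beam_size` copies a per-sentence tensor across `beam_size` hypotheses, and `merge_beam_dim` folds batch and beam back into one axis so the decoder can treat each hypothesis as an independent sample. A minimal standalone sketch of the updated versions (assuming only a working `paddle` install; the shapes are invented for illustration):

```python
import paddle

def expand_to_beam_size(tensor, beam_size):
    # (batch, ...) -> (batch, beam, ...): insert a beam axis, then tile it.
    tensor = paddle.unsqueeze(tensor, axis=1)
    tile_dims = [1] * len(tensor.shape)
    tile_dims[1] = beam_size
    return paddle.tile(tensor, tile_dims)

def merge_beam_dim(tensor):
    # (batch, beam, ...) -> (batch * beam, ...)
    shape = tensor.shape
    return paddle.reshape(tensor, [shape[0] * shape[1]] + list(shape[2:]))

enc_output = paddle.rand([2, 5, 8])                # (batch, src_len, d_model)
tiled = expand_to_beam_size(enc_output, beam_size=4)
print(tiled.shape)                                 # [2, 4, 5, 8]
print(merge_beam_dim(tiled).shape)                 # [8, 5, 8]
```

Spelling the merged dimension out as `shape[0] * shape[1]` instead of `-1` presumably avoids leaning on shape inference when the program is later traced to a static graph.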
@@ -1045,23 +1046,26 @@ def merge_beam_dim(tensor):
 
         ### initialize states of beam search ###
         ## init for the alive ##
-        initial_log_probs = paddle.to_tensor(
+        initial_log_probs = paddle.assign(
             np.array(
                 [[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
         alive_log_probs = paddle.tile(initial_log_probs, [batch_size, 1])
-        # (batch_size, beam_size, 1)
-        alive_seq = paddle.to_tensor(
-            np.tile(np.array([[[self.bos_id]]]), (batch_size, beam_size, 1)),
-            dtype=src_word.dtype)
+
+        alive_seq = paddle.tile(
+            paddle.cast(
+                paddle.assign(np.array([[[self.bos_id]]])), src_word.dtype),
+            [batch_size, beam_size, 1])
 
         ## init for the finished ##
-        finished_scores = paddle.to_tensor(
+        finished_scores = paddle.assign(
             np.array(
                 [[-inf] * beam_size], dtype="float32"))
         finished_scores = paddle.tile(finished_scores, [batch_size, 1])
-        finished_seq = paddle.to_tensor(
-            np.tile(np.array([[[self.bos_id]]]), (batch_size, beam_size, 1)),
-            dtype=src_word.dtype)
+
+        finished_seq = paddle.tile(
+            paddle.cast(
+                paddle.assign(np.array([[[self.bos_id]]])), src_word.dtype),
+            [batch_size, beam_size, 1])
         finished_flags = paddle.zeros_like(finished_scores)
 
         ### initialize inputs and states of transformer decoder ###
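The switch from `paddle.to_tensor` to `paddle.assign` throughout this block is what keeps the initialization traceable when the model is exported via dygraph-to-static. The pattern itself is easy to check in isolation (a hedged standalone sketch; `batch_size`, `beam_size`, and `bos_id` are stand-in values):

```python
import numpy as np
import paddle

inf = float("inf")
batch_size, beam_size, bos_id = 2, 4, 0

# Only beam 0 of each sentence starts alive (log-prob 0); the remaining
# beams get -inf so the first top-k step cannot pick duplicate <bos> beams.
initial_log_probs = paddle.assign(
    np.array([[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
alive_log_probs = paddle.tile(initial_log_probs, [batch_size, 1])

# Every hypothesis starts from <bos>; result shape (batch, beam, 1).
alive_seq = paddle.tile(
    paddle.cast(paddle.assign(np.array([[[bos_id]]])), "int64"),
    [batch_size, beam_size, 1])

print(alive_log_probs.numpy())   # [[0. -inf -inf -inf], [0. -inf -inf -inf]]
print(alive_seq.shape)           # [2, 4, 1]
```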
@@ -1076,7 +1080,7 @@ def merge_beam_dim(tensor):
         ## init states (caches) for transformer, need to be updated according to selected beam
         caches = self.transformer.decoder.gen_cache(enc_output, do_zip=False)
 
-        def update_states(caches, topk_coordinates, beam_size):
+        def update_states(caches, topk_coordinates, beam_size, batch_size):
             new_caches = []
             for cache in caches:
                 k = gather_2d(
@@ -1107,9 +1111,11 @@ def gather_2d(tensor_nd,
                       beam_size,
                       batch_size,
                       need_unmerge=False):
+
             new_tensor_nd = paddle.reshape(
-                tensor_nd, shape=[batch_size, beam_size] +
-                tensor_nd.shape[1:]) if need_unmerge else tensor_nd
+                tensor_nd,
+                shape=[batch_size, beam_size] +
+                list(tensor_nd.shape[1:])) if need_unmerge else tensor_nd
             topk_seq = paddle.gather_nd(new_tensor_nd, topk_coordinates)
             return merge_beam_dim(topk_seq) if need_unmerge else topk_seq
 
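`gather_2d` is the selection primitive used by both `grow_topk` and `update_states`: given `(batch, beam)` coordinates it pulls the surviving hypotheses out of a `(batch, beam, ...)` tensor (or, with `need_unmerge=True`, out of a merged `(batch * beam, ...)` one). A hedged standalone sketch of the core `gather_nd` step; the coordinate construction shown here is an assumption for illustration, since the diff builds `topk_coordinates` elsewhere in `beam_search_v2`:

```python
import paddle

batch_size, beam_size = 2, 3

# Pair each batch row with the beam indices chosen by top-k:
# batch_pos[b, k] == b, so the stacked coordinates have shape (batch, beam, 2).
batch_pos = paddle.tile(
    paddle.unsqueeze(paddle.arange(batch_size), axis=1), [1, beam_size])
topk_indexes = paddle.to_tensor([[0, 2, 1], [1, 1, 0]])
topk_coordinates = paddle.stack([batch_pos, topk_indexes], axis=2)

scores = paddle.to_tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
picked = paddle.gather_nd(scores, topk_coordinates)
print(picked.numpy())  # [[0.1 0.3 0.2] [0.5 0.5 0.4]]
```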
@@ -1162,11 +1168,15 @@ def grow_topk(i, logits, alive_seq, alive_log_probs, states):
             topk_seq = gather_2d(alive_seq, topk_coordinates, beam_size,
                                  batch_size)
             topk_seq = paddle.concat(
-                [topk_seq, paddle.reshape(topk_ids, topk_ids.shape + [1])],
+                [
+                    topk_seq, paddle.reshape(topk_ids,
+                                             list(topk_ids.shape[:]) + [1])
+                ],
                 axis=2)
-            states = update_states(states, topk_coordinates, beam_size)
+            states = update_states(states, topk_coordinates, beam_size,
+                                   batch_size)
             eos = paddle.full(
-                shape=topk_ids.shape,
+                shape=paddle.shape(topk_ids),
                 dtype=alive_seq.dtype,
                 fill_value=self.eos_id)
             topk_finished = paddle.cast(paddle.equal(topk_ids, eos), "float32")
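The `shape=paddle.shape(topk_ids)` change is the recurring theme of this commit: `Tensor.shape` is a plain Python list captured once at trace time, while `paddle.shape(...)` is an operator whose result is computed at run time, so the exported static graph keeps working when the batch or beam layout differs from the traced one. A quick hedged sketch of the difference:

```python
import paddle

x = paddle.rand([2, 12])
print(x.shape)          # Python list: [2, 12], fixed at trace time
print(paddle.shape(x))  # Tensor holding [2, 12], evaluated at run time

# paddle.full accepts either form, but only the tensor version stays
# dynamic after dygraph-to-static conversion.
eos = paddle.full(shape=paddle.shape(x), dtype="int64", fill_value=1)
```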
@@ -1192,7 +1202,8 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
 
             alive_log_probs = gather_2d(curr_log_probs, topk_coordinates,
                                         beam_size, batch_size)
-            states = update_states(states, topk_coordinates, beam_size * 2)
+            states = update_states(states, topk_coordinates, beam_size * 2,
+                                   batch_size)
 
             return alive_seq, alive_log_probs, states
 
@@ -1234,7 +1245,9 @@ def grow_finished(finished_seq, finished_scores, finished_flags,
         def inner_loop(i, trg_word, alive_seq, alive_log_probs, finished_seq,
                        finished_scores, finished_flags, caches):
             trg_pos = paddle.full(
-                shape=trg_word.shape, dtype=alive_seq.dtype, fill_value=i)
+                shape=paddle.shape(trg_word),
+                dtype=alive_seq.dtype,
+                fill_value=i)
             trg_emb = self.trg_word_embedding(trg_word)
             trg_pos_emb = self.trg_pos_embedding(trg_pos)
             trg_emb = trg_emb + trg_pos_emb
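The same fix appears inside `inner_loop`, and the position handling is worth noting: during incremental decoding every live hypothesis feeds exactly one token at step `i`, so the position ids are simply a tensor filled with `i`. A hedged miniature of that step (the sizes are invented):

```python
import paddle

i = 3                                              # current decoding step
trg_word = paddle.full([8, 1], 42, dtype="int64")  # (batch * beam, 1) token ids
trg_pos = paddle.full(
    shape=paddle.shape(trg_word), dtype=trg_word.dtype, fill_value=i)
print(trg_pos.shape)  # [8, 1], every entry == 3
```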
@@ -1271,13 +1284,19 @@ def is_not_finish(i, trg_word, alive_seq, alive_log_probs, finished_seq,
                 finished_seq, finished_scores, finished_flags, caches
             ])
 
-        finished_flags = paddle.any(paddle.cast(
-            finished_flags, dtype='bool'),
-                                    axis=1,
-                                    keepdim=True).tile([1, beam_size])
-        finished_seq = paddle.where(
-            finished_flags.unsqueeze(-1).tile([1, 1, alive_seq.shape[-1]]),
-            finished_seq, alive_seq)
-        finished_scores = paddle.where(finished_flags, finished_scores,
-                                       alive_log_probs)
+        # (gongenlei) `paddle.where` doesn't support broadcast, so we need to use `paddle.unsqueeze`
+        # and `paddle.tile` to make condition.shape same as X.shape. But when converting dygraph
+        # to static graph, `paddle.tile` will raise error.
+        finished_flags = paddle.cast(finished_flags, dtype=finished_seq.dtype)
+        neg_finished_flags = 1 - finished_flags
+        finished_seq = paddle.multiply(
+            finished_seq, finished_flags.unsqueeze(-1)) + paddle.multiply(
+                alive_seq, neg_finished_flags.unsqueeze(-1))
+        finished_scores = paddle.multiply(
+            finished_scores,
+            paddle.cast(
+                finished_flags, dtype=finished_scores.dtype)) + paddle.multiply(
+                    alive_log_probs,
+                    paddle.cast(
+                        neg_finished_flags, dtype=alive_log_probs.dtype))
         return finished_seq, finished_scores
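The workaround in this last hunk is worth spelling out: `flag * a + (1 - flag) * b` reproduces `paddle.where(flag, a, b)` without ever materializing a tiled condition tensor, because `paddle.multiply` broadcasts the trailing singleton axis on its own and broadcasting survives dygraph-to-static conversion. A hedged standalone check:

```python
import paddle

# finished_flags-style mask, one entry per sentence: 1.0 = keep finished.
flags = paddle.to_tensor([[1.], [0.]])             # (batch, 1)
finished = paddle.to_tensor([[1., 2.], [3., 4.]])  # (batch, beam)
alive = paddle.to_tensor([[9., 9.], [8., 8.]])

# Equivalent to paddle.where with a broadcast condition.
out = paddle.multiply(finished, flags) + paddle.multiply(alive, 1 - flags)
print(out.numpy())  # [[1. 2.] [8. 8.]]
```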