@@ -229,11 +229,15 @@ def make(self, encoder: LayerRef):
     blank_idx = self.ctx.blank_idx

     rec_decoder = {
-      "am0": {"class": "gather_nd", "from": _base(encoder), "position": "prev:t"},  # [B,D]
+      "index": {"class": "eval", "from": ["prev:t", "enc_seq_len"], "eval": 'tf.minimum(source(0), source(1)-1)'},
+      "am0": {"class": "gather_nd", "from": _base(encoder), "position": "index"},  # [B,D]
       "am": {"class": "copy", "from": "am0" if search else "data:source"},

+      "prev_output_wo_b": {
+        "class": "masked_computation", "unit": {"class": "copy", "initial_output": 0},
+        "from": "prev:output_", "mask": "prev:output_emit", "initial_output": 0},
       "prev_out_non_blank": {
-        "class": "reinterpret_data", "from": "prev:output_", "set_sparse_dim": target.get_num_classes()},
+        "class": "reinterpret_data", "from": "prev_output_wo_b", "set_sparse_dim": target.get_num_classes()},

       "slow_rnn": self.slow_rnn.make(
         prev_sparse_label_nb="prev_out_non_blank",
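
The two added layers above can be read as: clamp the time pointer "prev:t" to the last valid encoder frame before the gather, and carry the last emitted non-blank label across blank steps. A minimal plain-TensorFlow sketch of those two operations, assuming hypothetical per-batch tensors encoder_out [B,T,D], prev_t [B], enc_seq_len [B], prev_label [B], prev_emit [B] (bool) and last_non_blank [B], which are illustrative names, not identifiers from this recipe:

import tensorflow as tf

def gather_am_frame(encoder_out, prev_t, enc_seq_len):
  """Clamped encoder-frame lookup, mirroring the "index" + "am0" layers."""
  # Same effect as eval 'tf.minimum(source(0), source(1)-1)': never index past the last frame.
  index = tf.minimum(prev_t, enc_seq_len - 1)         # [B]
  return tf.gather(encoder_out, index, batch_dims=1)  # [B,D]

def hold_last_non_blank(prev_label, prev_emit, last_non_blank):
  """Per-step view of "prev_output_wo_b" (masked_computation with a copy unit):
  take the previous output only where a label was emitted, else keep the old value."""
  return tf.where(prev_emit, prev_label, last_non_blank)  # [B]
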
@@ -252,7 +256,7 @@ def make(self, encoder: LayerRef):

       "output": {
         "class": 'choice',
-        'target': target.key,  # note: wrong! but this is ignored both in full-sum training and in search
+        'target': target.key if train else None,  # note: wrong! but this is ignored both in full-sum training and in search
         'beam_size': beam_size,
         'from': "output_log_prob_wb", "input_type": "log_prob",
         "initial_output": 0,
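
The 'target' tweak makes the choice layer reference a ground-truth data key only when building the training network, so a pure-search config does not need that key at all. A hedged sketch of the resulting layer dict (train, target and beam_size are assumed to be in scope as in the recipe; this is not a copy of its code):

def make_choice_layer(train, target, beam_size):
  """Build the "output" choice layer; 'target' is only set while training."""
  return {
    "class": 'choice',
    'target': target.key if train else None,  # no ground-truth key needed in pure search
    'beam_size': beam_size,
    'from': "output_log_prob_wb", "input_type": "log_prob",
    "initial_output": 0,
  }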