 title: English-to-Spanish translation with a sequence-to-sequence Transformer
 author: '[fchollet](https://twitter.com/fchollet)'
 date-created: 2021/05/26
-last-modified: 2023/02/25
+last-modified: 2024/11/18
 description: Implementing a sequence-to-sequence Transformer and training it on a
   machine translation task.
 accelerator: GPU
@@ -174,7 +174,7 @@ using the source sentence and the target words 0 to N.
 As such, the training dataset will yield a tuple `(inputs, targets)`, where:

 - `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
-`encoder_inputs` is the vectorized source sentence and `encoder_inputs` is the target sentence "so far",
+`encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence "so far",
 that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
 - `target` is the target sentence offset by one step:
 it provides the next words in the target sentence -- what the model will try to predict.
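For reference, a minimal sketch of the formatting step that produces the structure described above (not part of this diff) could look like the following, assuming `eng_vectorization` and `spa_vectorization` are the `TextVectorization` layers defined earlier in the example:

```python
def format_dataset(eng, spa):
    # Turn raw English / Spanish strings into integer token ids.
    eng = eng_vectorization(eng)
    spa = spa_vectorization(spa)
    return (
        {
            "encoder_inputs": eng,  # full source sentence
            "decoder_inputs": spa[:, :-1],  # target words 0 to N ("so far")
        },
        spa[:, 1:],  # target offset by one step: the words to predict
    )
```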
@@ -304,10 +304,7 @@ class PositionalEmbedding(layers.Layer):
         return embedded_tokens + embedded_positions

     def compute_mask(self, inputs, mask=None):
-        if mask is None:
-            return None
-        else:
-            return ops.not_equal(inputs, 0)
+        return ops.not_equal(inputs, 0)

     def get_config(self):
         config = super().get_config()
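In other words, the embedding layer now always reports a padding mask: positions whose token id is 0 are masked out for downstream layers. A standalone illustration of the expression it returns (toy data, independent of the example code):

```python
from keras import ops

# A toy padded batch where token id 0 is the padding token.
token_ids = ops.convert_to_tensor([[5, 12, 7, 0, 0], [3, 9, 0, 0, 0]])

# The same expression used in compute_mask: True for real tokens,
# False for padding positions.
padding_mask = ops.not_equal(token_ids, 0)
print(ops.convert_to_numpy(padding_mask))
# [[ True  True  True False False]
#  [ True  True False False False]]
```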
@@ -344,24 +341,30 @@ class TransformerDecoder(layers.Layer):
         self.layernorm_3 = layers.LayerNormalization()
         self.supports_masking = True

-    def call(self, inputs, encoder_outputs, mask=None):
+    def call(self, inputs, mask=None):
+        inputs, encoder_outputs = inputs
         causal_mask = self.get_causal_attention_mask(inputs)
-        if mask is not None:
-            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
-            padding_mask = ops.minimum(padding_mask, causal_mask)
+
+        if mask is None:
+            inputs_padding_mask, encoder_outputs_padding_mask = None, None
         else:
-            padding_mask = None
+            inputs_padding_mask, encoder_outputs_padding_mask = mask

         attention_output_1 = self.attention_1(
-            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
+            query=inputs,
+            value=inputs,
+            key=inputs,
+            attention_mask=causal_mask,
+            query_mask=inputs_padding_mask,
         )
         out_1 = self.layernorm_1(inputs + attention_output_1)

         attention_output_2 = self.attention_2(
             query=out_1,
             value=encoder_outputs,
             key=encoder_outputs,
-            attention_mask=padding_mask,
+            query_mask=inputs_padding_mask,
+            key_mask=encoder_outputs_padding_mask,
         )
         out_2 = self.layernorm_2(out_1 + attention_output_2)

@@ -408,14 +411,15 @@ encoder = keras.Model(encoder_inputs, encoder_outputs)
 decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
 encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
 x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
-x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
+x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])
 x = layers.Dropout(0.5)(x)
 decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
 decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)

-decoder_outputs = decoder([decoder_inputs, encoder_outputs])
 transformer = keras.Model(
-    [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
+    {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
+    decoder_outputs,
+    name="transformer",
 )
 ```

@@ -432,7 +436,9 @@ epochs = 1 # This should be at least 30 for convergence
 
 transformer.summary()
 transformer.compile(
-    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
+    "rmsprop",
+    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
+    metrics=["accuracy"],
 )
 transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
 ```
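A quick aside on `ignore_class=0`: it tells the loss to skip positions whose target token id is 0, i.e. padding. A small standalone check (toy numbers, not from the example) shows the effect:

```python
import numpy as np
import keras

# One sequence: two real target tokens followed by two padding positions (class 0).
y_true = np.array([[2, 1, 0, 0]])
# Per-position probabilities over a 3-token vocabulary.
y_pred = np.array([[[0.1, 0.2, 0.7],
                    [0.1, 0.7, 0.2],
                    [0.9, 0.05, 0.05],
                    [0.9, 0.05, 0.05]]])

plain = keras.losses.SparseCategoricalCrossentropy()
masked = keras.losses.SparseCategoricalCrossentropy(ignore_class=0)

print(float(plain(y_true, y_pred)))   # padding positions contribute to the loss
print(float(masked(y_true, y_pred)))  # padding positions are ignored
```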
@@ -455,7 +461,12 @@ def decode_sequence(input_sentence):
     decoded_sentence = "[start]"
     for i in range(max_decoded_sentence_length):
         tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
-        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])
+        predictions = transformer(
+            {
+                "encoder_inputs": tokenized_input_sentence,
+                "decoder_inputs": tokenized_target_sentence,
+            }
+        )

         # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
         sampled_token_index = ops.convert_to_numpy(
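The hunk above is truncated, but once the full `decode_sequence` helper is in place, a usage sketch (assuming the `test_pairs` split built earlier in the example) would be:

```python
import random

test_eng_texts = [pair[0] for pair in test_pairs]
for _ in range(5):
    input_sentence = random.choice(test_eng_texts)
    print(input_sentence)
    print(decode_sequence(input_sentence))
```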