 title: English-to-Spanish translation with a sequence-to-sequence Transformer
 author: '[fchollet](https://twitter.com/fchollet), t-kalinowski'
 date-created: 2021/05/26
-last-modified: 2023/02/25
+last-modified: 2024/11/18
 description: Implementing a sequence-to-sequence Transformer and training it on a
   machine translation task.
 accelerator: GPU
@@ -199,11 +199,11 @@ using the source sentence and the target words from 1 to N.
 
 As such, the training dataset will yield a tuple `(inputs, targets)`, where:
 
-- `inputs` is a dictionary (named list) with the keys (names) `encoder_inputs` and `decoder_inputs`.
-`encoder_inputs` is the vectorized source sentence and `encoder_inputs` is the target sentence "so far",
-that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
-- `target` is the target sentence offset by one step:
-it provides the next words in the target sentence -- what the model will try to predict.
+* `inputs` is a dictionary (named list) with the keys (names) `encoder_inputs` and `decoder_inputs`.
+`encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence "so far",
+that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
+* `target` is the target sentence offset by one step:
+it provides the next words in the target sentence -- what the model will try to predict.
 
 ```{r}
 format_pair <- function(pair) {
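
As a quick illustrative aside on the hunk above (not part of the commit): once the dataset pipeline in this chunk has been run, one batch of `train_ds` can be pulled to confirm the `(inputs, targets)` structure. The sketch assumes the `as_iterator()`/`iter_next()` helpers and the `%<-%` destructuring operator re-exported by keras3.

```r
# Illustrative only: peek at a single prepared batch of train_ds.
c(inputs, targets) %<-% iter_next(as_iterator(train_ds))

names(inputs)                    # "encoder_inputs" "decoder_inputs"
op_shape(inputs$encoder_inputs)  # (batch_size, sequence_length) -- source token ids
op_shape(inputs$decoder_inputs)  # (batch_size, sequence_length) -- target tokens "so far"
op_shape(targets)                # (batch_size, sequence_length) -- target shifted by one step
```
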
@@ -347,29 +347,37 @@ layer_transformer_decoder <- Layer(
     repeats <- op_stack(c(batch_size, 1L, 1L))
     op_tile(mask[NULL, , ], repeats)
   },
-  call = function(inputs, encoder_outputs, mask = NULL) {
-    causal_mask <- self$get_causal_attention_mask(inputs)
 
-    if (is.null(mask))
-      mask <- causal_mask
-    else
-      mask %<>% { op_minimum(op_cast(.[, NULL, ], "int32"),
-                             causal_mask) }
-
-    inputs %>%
-      { self$attention_1(query = ., value = ., key = .,
-                         attention_mask = causal_mask) + . } %>%
-      self$layernorm_1() %>%
-
-      { self$attention_2(query = .,
-                         value = encoder_outputs,
-                         key = encoder_outputs,
-                         attention_mask = mask) + . } %>%
-      self$layernorm_2() %>%
-
-      { self$dense_proj(.) + . } %>%
-      self$layernorm_3()
+  call = function(inputs, mask = NULL) {
+    c(inputs_seq, encoder_outputs) %<-% inputs
+    causal_mask <- self$get_causal_attention_mask(inputs_seq)
+
+    if (is.null(mask)) {
+      inputs_padding_mask <- NULL
+      encoder_outputs_padding_mask <- NULL
+    } else {
+      c(inputs_padding_mask, encoder_outputs_padding_mask) %<-% mask
+    }
+
+    attention_output_1 <- self$attention_1(
+      query = inputs_seq,
+      value = inputs_seq,
+      key = inputs_seq,
+      attention_mask = causal_mask,
+      query_mask = inputs_padding_mask
+    )
+    out_1 <- self$layernorm_1(inputs_seq + attention_output_1)
+
+    attention_output_2 <- self$attention_2(
+      query = out_1,
+      value = encoder_outputs,
+      key = encoder_outputs,
+      query_mask = inputs_padding_mask,
+      key_mask = encoder_outputs_padding_mask
+    )
+    out_2 <- self$layernorm_2(out_1 + attention_output_2)
 
+    self$layernorm_3(out_2 + self$dense_proj(out_2))
   }
 )
 
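
For intuition about the `causal_mask` produced by `get_causal_attention_mask()` above: it is a lower-triangular matrix, so target position i can only attend to positions 1 through i, never to future tokens. A minimal base-R sketch of that shape (illustrative only, not the layer's actual tensor code):

```r
# Causal mask for a length-4 target sequence: row i marks which positions
# token i may attend to (1 = allowed, 0 = masked out).
n <- 4
causal_mask <- matrix(as.integer(lower.tri(diag(n), diag = TRUE)), n, n)
causal_mask
#>      [,1] [,2] [,3] [,4]
#> [1,]    1    0    0    0
#> [2,]    1    1    0    0
#> [3,]    1    1    1    0
#> [4,]    1    1    1    1
```
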
@@ -398,7 +406,6 @@ layer_positional_embedding <- Layer(
   },
 
   compute_mask = function(inputs, mask = NULL) {
-    if (is.null(mask)) return(NULL)
     inputs != 0L
   },
 
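
With the early `return(NULL)` removed, `compute_mask()` now always propagates a padding mask, which is simply `inputs != 0L`: TRUE for real token ids and FALSE for the 0s used as padding. A small illustration (assumes only `op_array()` from keras3; output summarized as a comment):

```r
# Illustrative only: token id 0 is the padding id, so it gets masked out.
token_ids <- op_array(rbind(c(12L, 7L, 3L, 0L, 0L)), dtype = "int32")
token_ids != 0L
#> shape (1, 5): TRUE TRUE TRUE FALSE FALSE
```
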
@@ -437,21 +444,22 @@ transformer_decoder <- layer_transformer_decoder(NULL,
 
 decoder_outputs <- decoder_inputs %>%
   layer_positional_embedding(sequence_length, vocab_size, embed_dim) %>%
-  transformer_decoder(., encoded_seq_inputs) %>%
+  { transformer_decoder(list(., encoded_seq_inputs)) } %>%
   layer_dropout(0.5) %>%
   layer_dense(vocab_size, activation="softmax")
 
 decoder <- keras_model(inputs = list(decoder_inputs, encoded_seq_inputs),
                        outputs = decoder_outputs)
 
-decoder_outputs = decoder(list(decoder_inputs, encoder_outputs))
+decoder_outputs <- decoder(list(decoder_inputs, encoder_outputs))
 
-transformer <- keras_model(list(encoder_inputs, decoder_inputs),
-                           decoder_outputs,
-                           name = "transformer")
+transformer <- keras_model(
+  inputs = list(encoder_inputs = encoder_inputs, decoder_inputs = decoder_inputs),
+  outputs = decoder_outputs,
+  name = "transformer"
+)
 ```
 
-
 ## Training our model
 
 We'll use accuracy as a quick way to monitor training progress on the validation data.
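
As a sanity check of the wiring in the hunk above (illustrative only, not part of the commit; it assumes `train_ds` and the `as_iterator()`/`iter_next()` helpers mentioned earlier), the end-to-end model can be called on one batch and should return per-token probabilities over the Spanish vocabulary:

```r
# Illustrative only: one forward pass through the assembled transformer.
c(inputs, targets) %<-% iter_next(as_iterator(train_ds))
preds <- transformer(inputs)   # named list matches the named model inputs
op_shape(preds)                # (batch_size, sequence_length, vocab_size)
```
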
@@ -466,15 +474,14 @@ epochs <- 1 # This should be at least 30 for convergence
 transformer
 transformer |> compile(
   "rmsprop",
-  loss = "sparse_categorical_crossentropy",
+  loss = loss_sparse_categorical_crossentropy(ignore_class = 0),
   metrics = "accuracy"
 )
 
 transformer |> fit(train_ds, epochs = epochs,
                    validation_data = val_ds)
 ```
 
-
 ## Decoding test sentences
 
 Finally, let's demonstrate how to translate brand new English sentences.
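
One note on the loss change above: `ignore_class = 0` excludes positions whose target token id is 0 (the padding id) from the loss. A minimal sketch with made-up values (assumes keras3's `op_array()` and `op_ones()`, and the usual keras3 convention that `loss_sparse_categorical_crossentropy()` called without `y_true`/`y_pred` returns a callable loss object):

```r
# Sketch: apply the configured loss to a toy batch whose last two
# target positions are padding (token id 0).
loss_fn <- loss_sparse_categorical_crossentropy(ignore_class = 0)

y_true <- op_array(rbind(c(2L, 1L, 0L, 0L)), dtype = "int32")
y_pred <- op_ones(c(1L, 4L, 5L)) / 5   # uniform predictions over a 5-word vocab

loss_fn(y_true, y_pred)   # the padded positions contribute nothing to the loss
```
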
@@ -500,8 +507,10 @@ tf_decode_sequence <- tf_function(function(input_sentence) {
       spa_vectorization(decoded_sentence)[,NA:-1]
 
     next_token_predictions <-
-      transformer(list(tokenized_input_sentence,
-                       tokenized_target_sentence))
+      transformer(list(
+        encoder_inputs = tokenized_input_sentence,
+        decoder_inputs = tokenized_target_sentence
+      ))
 
     sampled_token_index <- tf$argmax(next_token_predictions[0, i, ])
     sampled_token <- spa_vocab[sampled_token_index]
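
In the loop above, `tf$argmax(next_token_predictions[0, i, ])` greedily selects the most probable next token at step i, and `spa_vocab[...]` maps that index back to a Spanish word. The same idea in plain base R, with invented numbers:

```r
# Toy next-token distribution over a tiny vocabulary (values are made up).
spa_vocab_toy <- c("[start]", "[end]", "hola", "adios", "gracias")
next_token_probs <- c(0.01, 0.04, 0.80, 0.10, 0.05)
spa_vocab_toy[which.max(next_token_probs)]   # "hola" -- appended to the decoded sentence
```
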
@@ -527,10 +536,8 @@ for (i in seq(20)) {
   cat(" Model Translation:", input_sentence %>% as_tensor() %>%
         tf_decode_sequence() %>% as.character(), "\n")
 }
-
 ```
 
-
 After 30 epochs, we get results such as:
 
 ```