Commit 5c3b2db

update neural_machine_translation_with_transformer guide
1 parent bdf93b0 commit 5c3b2db

3 files changed (+129, -99 lines)

.tether/vignettes-src/examples/nlp/neural_machine_translation_with_transformer.Rmd

Lines changed: 29 additions & 18 deletions

@@ -2,7 +2,7 @@
 title: English-to-Spanish translation with a sequence-to-sequence Transformer
 author: '[fchollet](https://twitter.com/fchollet)'
 date-created: 2021/05/26
-last-modified: 2023/02/25
+last-modified: 2024/11/18
 description: Implementing a sequence-to-sequence Transformer and training it on a
   machine translation task.
 accelerator: GPU
@@ -174,7 +174,7 @@ using the source sentence and the target words 0 to N.
 As such, the training dataset will yield a tuple `(inputs, targets)`, where:
 
 - `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
-`encoder_inputs` is the vectorized source sentence and `encoder_inputs` is the target sentence "so far",
+`encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence "so far",
 that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
 - `target` is the target sentence offset by one step:
 it provides the next words in the target sentence -- what the model will try to predict.
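
To make the corrected pairing concrete, here is a minimal sketch of one training example; the token ids below are made up for illustration and are not the guide's real vectorized data:

```python
import numpy as np

# Toy vectorized sentences; id 0 is padding, and the Spanish sentence is framed
# by [start] / [end] tokens as in the guide.
eng_tokens = np.array([[5, 8, 3, 0, 0, 0]])   # vectorized source sentence
spa_tokens = np.array([[2, 7, 9, 4, 0, 0]])   # [start] w1 w2 [end] pad pad

inputs = {
    "encoder_inputs": eng_tokens,
    "decoder_inputs": spa_tokens[:, :-1],     # target words 0..N ("so far")
}
targets = spa_tokens[:, 1:]                   # same sentence shifted one step ahead
```
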
@@ -304,10 +304,7 @@ class PositionalEmbedding(layers.Layer):
         return embedded_tokens + embedded_positions
 
     def compute_mask(self, inputs, mask=None):
-        if mask is None:
-            return None
-        else:
-            return ops.not_equal(inputs, 0)
+        return ops.not_equal(inputs, 0)
 
     def get_config(self):
         config = super().get_config()
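
Dropping the `if mask is None` branch means `compute_mask()` now always reports which positions hold real tokens, so downstream masked layers see a padding mask even when no mask was propagated in. A minimal sketch with a hypothetical toy layer (not the guide's full `PositionalEmbedding`):

```python
import keras
from keras import layers, ops

class TokenEmbeddingWithMask(layers.Layer):  # hypothetical layer for illustration
    def __init__(self, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.embedding = layers.Embedding(vocab_size, embed_dim)
        self.supports_masking = True

    def call(self, inputs):
        return self.embedding(inputs)

    def compute_mask(self, inputs, mask=None):
        # Token id 0 is reserved for padding, so the mask is simply "non-zero".
        return ops.not_equal(inputs, 0)

x = ops.convert_to_tensor([[5, 7, 0, 0]])
layer = TokenEmbeddingWithMask(vocab_size=10, embed_dim=4)
print(layer.compute_mask(x))  # [[True, True, False, False]]
```
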
@@ -344,24 +341,30 @@ class TransformerDecoder(layers.Layer):
         self.layernorm_3 = layers.LayerNormalization()
         self.supports_masking = True
 
-    def call(self, inputs, encoder_outputs, mask=None):
+    def call(self, inputs, mask=None):
+        inputs, encoder_outputs = inputs
         causal_mask = self.get_causal_attention_mask(inputs)
-        if mask is not None:
-            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
-            padding_mask = ops.minimum(padding_mask, causal_mask)
+
+        if mask is None:
+            inputs_padding_mask, encoder_outputs_padding_mask = None, None
         else:
-            padding_mask = None
+            inputs_padding_mask, encoder_outputs_padding_mask = mask
 
         attention_output_1 = self.attention_1(
-            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
+            query=inputs,
+            value=inputs,
+            key=inputs,
+            attention_mask=causal_mask,
+            query_mask=inputs_padding_mask,
         )
         out_1 = self.layernorm_1(inputs + attention_output_1)
 
         attention_output_2 = self.attention_2(
             query=out_1,
             value=encoder_outputs,
             key=encoder_outputs,
-            attention_mask=padding_mask,
+            query_mask=inputs_padding_mask,
+            key_mask=encoder_outputs_padding_mask,
         )
         out_2 = self.layernorm_2(out_1 + attention_output_2)
 
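The rewritten `call()` relies on `MultiHeadAttention`'s `query_mask` / `key_mask` arguments rather than combining a broadcast padding mask with the causal mask by hand. A minimal standalone sketch of that usage (the shapes and random data here are arbitrary):

```python
import numpy as np
from keras import layers

attention = layers.MultiHeadAttention(num_heads=2, key_dim=8)

batch, q_len, kv_len, dim = 2, 4, 6, 8
query = np.random.rand(batch, q_len, dim).astype("float32")
value = np.random.rand(batch, kv_len, dim).astype("float32")

# Boolean padding masks: True where a position holds a real (non-padding) token.
query_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=bool)
key_mask = np.array([[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]], dtype=bool)

# The layer builds the combined attention mask from the per-sequence masks.
out = attention(query=query, value=value,
                query_mask=query_mask, key_mask=key_mask)
print(out.shape)  # (2, 4, 8)
```
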
@@ -408,14 +411,15 @@ encoder = keras.Model(encoder_inputs, encoder_outputs)
 decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
 encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
 x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
-x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
+x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])
 x = layers.Dropout(0.5)(x)
 decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
 decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)
 
-decoder_outputs = decoder([decoder_inputs, encoder_outputs])
 transformer = keras.Model(
-    [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
+    {"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
+    decoder_outputs,
+    name="transformer",
 )
 ```
 
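Because the model is now built from a dictionary of named inputs, it can be called (and fed by the `tf.data` pipeline) with a matching named dictionary. A minimal sketch with hypothetical toy shapes, not the guide's actual encoder/decoder:

```python
import numpy as np
import keras
from keras import layers

a = keras.Input(shape=(3,), name="encoder_inputs")
b = keras.Input(shape=(3,), name="decoder_inputs")
out = layers.Dense(1)(layers.concatenate([a, b]))
model = keras.Model({"encoder_inputs": a, "decoder_inputs": b}, out)

# Call the model with a dict keyed the same way the training dataset is keyed.
preds = model({
    "encoder_inputs": np.zeros((2, 3), dtype="float32"),
    "decoder_inputs": np.ones((2, 3), dtype="float32"),
})
print(preds.shape)  # (2, 1)
```
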
@@ -432,7 +436,9 @@ epochs = 1 # This should be at least 30 for convergence
 
 transformer.summary()
 transformer.compile(
-    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
+    "rmsprop",
+    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
+    metrics=["accuracy"],
 )
 transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
 ```
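
A minimal sketch of what `ignore_class=0` changes, assuming `keras.losses.SparseCategoricalCrossentropy`: target positions holding the padding id 0 are excluded from the loss, so the model is no longer trained to predict padding:

```python
import numpy as np
import keras

loss_fn = keras.losses.SparseCategoricalCrossentropy(ignore_class=0)

y_true = np.array([[2, 1, 0, 0]])                       # trailing zeros are padding
y_pred = np.random.uniform(size=(1, 4, 5)).astype("float32")
y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)    # stand-in softmax output

print(float(loss_fn(y_true, y_pred)))  # averaged over the two non-padding steps only
```
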
@@ -455,7 +461,12 @@ def decode_sequence(input_sentence):
     decoded_sentence = "[start]"
     for i in range(max_decoded_sentence_length):
         tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
-        predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])
+        predictions = transformer(
+            {
+                "encoder_inputs": tokenized_input_sentence,
+                "decoder_inputs": tokenized_target_sentence,
+            }
+        )
 
         # ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
         sampled_token_index = ops.convert_to_numpy(

vignettes-src/examples/nlp/neural_machine_translation_with_transformer.Rmd

Lines changed: 47 additions & 40 deletions

@@ -2,7 +2,7 @@
 title: English-to-Spanish translation with a sequence-to-sequence Transformer
 author: '[fchollet](https://twitter.com/fchollet), t-kalinowski'
 date-created: 2021/05/26
-last-modified: 2023/02/25
+last-modified: 2024/11/18
 description: Implementing a sequence-to-sequence Transformer and training it on a
   machine translation task.
 accelerator: GPU
@@ -199,11 +199,11 @@ using the source sentence and the target words from 1 to N.
 
 As such, the training dataset will yield a tuple `(inputs, targets)`, where:
 
-- `inputs` is a dictionary (named list) with the keys (names) `encoder_inputs` and `decoder_inputs`.
-`encoder_inputs` is the vectorized source sentence and `encoder_inputs` is the target sentence "so far",
-that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
-- `target` is the target sentence offset by one step:
-it provides the next words in the target sentence -- what the model will try to predict.
+* `inputs` is a dictionary (named list) with the keys (names) `encoder_inputs` and `decoder_inputs`.
+  `encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence "so far",
+  that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
+* `target` is the target sentence offset by one step:
+  it provides the next words in the target sentence -- what the model will try to predict.
 
 ```{r}
 format_pair <- function(pair) {
@@ -347,29 +347,37 @@ layer_transformer_decoder <- Layer(
     repeats <- op_stack(c(batch_size, 1L, 1L))
     op_tile(mask[NULL, , ], repeats)
   },
-  call = function(inputs, encoder_outputs, mask = NULL) {
-    causal_mask <- self$get_causal_attention_mask(inputs)
 
-    if (is.null(mask))
-      mask <- causal_mask
-    else
-      mask %<>% { op_minimum(op_cast(.[, NULL, ], "int32"),
-                             causal_mask) }
-
-    inputs %>%
-      { self$attention_1(query = ., value = ., key = .,
-                         attention_mask = causal_mask) + . } %>%
-      self$layernorm_1() %>%
-
-      { self$attention_2(query = .,
-                         value = encoder_outputs,
-                         key = encoder_outputs,
-                         attention_mask = mask) + . } %>%
-      self$layernorm_2() %>%
-
-      { self$dense_proj(.) + . } %>%
-      self$layernorm_3()
+  call = function(inputs, mask = NULL) {
+    c(inputs_seq, encoder_outputs) %<-% inputs
+    causal_mask <- self$get_causal_attention_mask(inputs_seq)
+
+    if (is.null(mask)) {
+      inputs_padding_mask <- NULL
+      encoder_outputs_padding_mask <- NULL
+    } else {
+      c(inputs_padding_mask, encoder_outputs_padding_mask) %<-% mask
+    }
+
+    attention_output_1 <- self$attention_1(
+      query = inputs_seq,
+      value = inputs_seq,
+      key = inputs_seq,
+      attention_mask = causal_mask,
+      query_mask = inputs_padding_mask
+    )
+    out_1 <- self$layernorm_1(inputs_seq + attention_output_1)
+
+    attention_output_2 <- self$attention_2(
+      query = out_1,
+      value = encoder_outputs,
+      key = encoder_outputs,
+      query_mask = inputs_padding_mask,
+      key_mask = encoder_outputs_padding_mask
+    )
+    out_2 <- self$layernorm_2(out_1 + attention_output_2)
 
+    self$layernorm_3(out_2 + self$dense_proj(out_2))
   }
 )
 
@@ -398,7 +406,6 @@ layer_positional_embedding <- Layer(
   },
 
   compute_mask = function(inputs, mask = NULL) {
-    if (is.null(mask)) return (NULL)
     inputs != 0L
   },
 
@@ -437,21 +444,22 @@ transformer_decoder <- layer_transformer_decoder(NULL,
 
 decoder_outputs <- decoder_inputs %>%
   layer_positional_embedding(sequence_length, vocab_size, embed_dim) %>%
-  transformer_decoder(., encoded_seq_inputs) %>%
+  { transformer_decoder(list(., encoded_seq_inputs)) } %>%
   layer_dropout(0.5) %>%
   layer_dense(vocab_size, activation="softmax")
 
 decoder <- keras_model(inputs = list(decoder_inputs, encoded_seq_inputs),
                        outputs = decoder_outputs)
 
-decoder_outputs = decoder(list(decoder_inputs, encoder_outputs))
+decoder_outputs <- decoder(list(decoder_inputs, encoder_outputs))
 
-transformer <- keras_model(list(encoder_inputs, decoder_inputs),
-                           decoder_outputs,
-                           name = "transformer")
+transformer <- keras_model(
+  inputs = list(encoder_inputs = encoder_inputs, decoder_inputs = decoder_inputs),
+  outputs = decoder_outputs,
+  name = "transformer"
+)
 ```
 
-
 ## Training our model
 
 We'll use accuracy as a quick way to monitor training progress on the validation data.
@@ -466,15 +474,14 @@ epochs <- 1 # This should be at least 30 for convergence
 transformer
 transformer |> compile(
   "rmsprop",
-  loss = "sparse_categorical_crossentropy",
+  loss = loss_sparse_categorical_crossentropy(ignore_class = 0),
   metrics = "accuracy"
 )
 
 transformer |> fit(train_ds, epochs = epochs,
                    validation_data = val_ds)
 ```
 
-
 ## Decoding test sentences
 
 Finally, let's demonstrate how to translate brand new English sentences.
@@ -500,8 +507,10 @@ tf_decode_sequence <- tf_function(function(input_sentence) {
       spa_vectorization(decoded_sentence)[,NA:-1]
 
     next_token_predictions <-
-      transformer(list(tokenized_input_sentence,
-                       tokenized_target_sentence))
+      transformer(list(
+        encoder_inputs = tokenized_input_sentence,
+        decoder_inputs = tokenized_target_sentence
+      ))
 
     sampled_token_index <- tf$argmax(next_token_predictions[0, i, ])
     sampled_token <- spa_vocab[sampled_token_index]
@@ -527,10 +536,8 @@ for (i in seq(20)) {
   cat(" Model Translation:", input_sentence %>% as_tensor() %>%
         tf_decode_sequence() %>% as.character(), "\n")
 }
-
 ```
 
-
 After 30 epochs, we get results such as:
 
 ```
