Skip to content

Commit 68c9343

Browse files
authored
Merge pull request #1505 from rstudio/update-vignettes-v3.9.2
Update-vignettes-v3.9.2
2 parents 130ec29 + 2d1385d commit 68c9343

File tree

819 files changed

+8125
-5136
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

819 files changed

+8125
-5136
lines changed

.tether/vignettes-src/examples/nlp/neural_machine_translation_with_transformer.Rmd

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
title: English-to-Spanish translation with a sequence-to-sequence Transformer
33
author: '[fchollet](https://twitter.com/fchollet)'
44
date-created: 2021/05/26
5-
last-modified: 2023/02/25
5+
last-modified: 2024/11/18
66
description: Implementing a sequence-to-sequence Transformer and training it on a
77
machine translation task.
88
accelerator: GPU
@@ -174,7 +174,7 @@ using the source sentence and the target words 0 to N.
174174
As such, the training dataset will yield a tuple `(inputs, targets)`, where:
175175

176176
- `inputs` is a dictionary with the keys `encoder_inputs` and `decoder_inputs`.
177-
`encoder_inputs` is the vectorized source sentence and `encoder_inputs` is the target sentence "so far",
177+
`encoder_inputs` is the vectorized source sentence and `decoder_inputs` is the target sentence "so far",
178178
that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
179179
- `targets` is the target sentence offset by one step:
180180
it provides the next words in the target sentence -- what the model will try to predict.
@@ -304,10 +304,7 @@ class PositionalEmbedding(layers.Layer):
304304
return embedded_tokens + embedded_positions
305305

306306
def compute_mask(self, inputs, mask=None):
307-
if mask is None:
308-
return None
309-
else:
310-
return ops.not_equal(inputs, 0)
307+
return ops.not_equal(inputs, 0)
311308

312309
def get_config(self):
313310
config = super().get_config()
@@ -344,24 +341,30 @@ class TransformerDecoder(layers.Layer):
344341
self.layernorm_3 = layers.LayerNormalization()
345342
self.supports_masking = True
346343

347-
def call(self, inputs, encoder_outputs, mask=None):
344+
def call(self, inputs, mask=None):
345+
inputs, encoder_outputs = inputs
348346
causal_mask = self.get_causal_attention_mask(inputs)
349-
if mask is not None:
350-
padding_mask = ops.cast(mask[:, None, :], dtype="int32")
351-
padding_mask = ops.minimum(padding_mask, causal_mask)
347+
348+
if mask is None:
349+
inputs_padding_mask, encoder_outputs_padding_mask = None, None
352350
else:
353-
padding_mask = None
351+
inputs_padding_mask, encoder_outputs_padding_mask = mask
354352

355353
attention_output_1 = self.attention_1(
356-
query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
354+
query=inputs,
355+
value=inputs,
356+
key=inputs,
357+
attention_mask=causal_mask,
358+
query_mask=inputs_padding_mask,
357359
)
358360
out_1 = self.layernorm_1(inputs + attention_output_1)
359361

360362
attention_output_2 = self.attention_2(
361363
query=out_1,
362364
value=encoder_outputs,
363365
key=encoder_outputs,
364-
attention_mask=padding_mask,
366+
query_mask=inputs_padding_mask,
367+
key_mask=encoder_outputs_padding_mask,
365368
)
366369
out_2 = self.layernorm_2(out_1 + attention_output_2)
367370

@@ -408,14 +411,15 @@ encoder = keras.Model(encoder_inputs, encoder_outputs)
408411
decoder_inputs = keras.Input(shape=(None,), dtype="int64", name="decoder_inputs")
409412
encoded_seq_inputs = keras.Input(shape=(None, embed_dim), name="decoder_state_inputs")
410413
x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(decoder_inputs)
411-
x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
414+
x = TransformerDecoder(embed_dim, latent_dim, num_heads)([x, encoder_outputs])
412415
x = layers.Dropout(0.5)(x)
413416
decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
414417
decoder = keras.Model([decoder_inputs, encoded_seq_inputs], decoder_outputs)
415418

416-
decoder_outputs = decoder([decoder_inputs, encoder_outputs])
417419
transformer = keras.Model(
418-
[encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
420+
{"encoder_inputs": encoder_inputs, "decoder_inputs": decoder_inputs},
421+
decoder_outputs,
422+
name="transformer",
419423
)
420424
```
421425

@@ -432,7 +436,9 @@ epochs = 1 # This should be at least 30 for convergence
432436

433437
transformer.summary()
434438
transformer.compile(
435-
"rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
439+
"rmsprop",
440+
loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=0),
441+
metrics=["accuracy"],
436442
)
437443
transformer.fit(train_ds, epochs=epochs, validation_data=val_ds)
438444
```
@@ -455,7 +461,12 @@ def decode_sequence(input_sentence):
455461
decoded_sentence = "[start]"
456462
for i in range(max_decoded_sentence_length):
457463
tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
458-
predictions = transformer([tokenized_input_sentence, tokenized_target_sentence])
464+
predictions = transformer(
465+
{
466+
"encoder_inputs": tokenized_input_sentence,
467+
"decoder_inputs": tokenized_target_sentence,
468+
}
469+
)
459470

460471
# ops.argmax(predictions[0, i, :]) is not a concrete value for jax here
461472
sampled_token_index = ops.convert_to_numpy(

.tether/vignettes-src/parked/_distributed_training_with_jax.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ optimizer.build(model.trainable_variables)
174174
# Keras provides a pure functional forward pass: model.stateless_call
175175
def compute_loss(trainable_variables, non_trainable_variables, x, y):
176176
y_pred, updated_non_trainable_variables = model.stateless_call(
177-
trainable_variables, non_trainable_variables, x
177+
trainable_variables, non_trainable_variables, x, training=True
178178
)
179179
loss_value = loss(y, y_pred)
180180
return loss_value, updated_non_trainable_variables

.tether/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ variables.
175175
```python
176176
def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):
177177
y_pred, non_trainable_variables = model.stateless_call(
178-
trainable_variables, non_trainable_variables, x
178+
trainable_variables, non_trainable_variables, x, training=True
179179
)
180180
loss = loss_fn(y, y_pred)
181181
return loss, non_trainable_variables

R/install.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ use_backend <- function(backend, gpu = NA) {
226226
},
227227

228228
Linux_jax = {
229-
py_require(c("tensorflow", "tensorflow[and-cuda]"), action = "remove")
229+
py_require(c("tensorflow", "tensorflow[and-cuda]", "jax[cuda12]", "jax[cpu]"), action = "remove")
230230

231231
if (is.na(gpu))
232232
gpu <- has_gpu()

docs/dev/LICENSE-text.html

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/dev/articles/custom_train_step_in_tensorflow.html

Lines changed: 13 additions & 14 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)