@@ -50,13 +50,11 @@ We'll be working with an English-to-Spanish translation dataset
 provided by [Anki](https://www.manythings.org/anki/). Let's download it:
 
 ```{r}
-zipfile <- get_file("spa-eng.zip", origin =
-  "http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip")
+zip_path <-
+  "http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip" |>
+  get_file(origin = _, extract = TRUE)
 
-zip::zip_list(zipfile) # See what's in the zipfile
-zip::unzip(zipfile, exdir = ".") # unzip into the current directory
-
-text_file <- fs::path("./spa-eng/spa.txt")
+text_path <- fs::path(zip_path, "spa-eng/spa.txt")
 ```
 
 ## Parsing the data
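
A quick way to sanity-check the new download path (not part of the diff; this assumes the chunk above has been run and that `get_file(..., extract = TRUE)` returns the extracted directory, as the `fs::path(zip_path, ...)` call implies):

```r
# Not from the diff -- verify the extracted dataset is where text_path points.
library(fs)

file_exists(text_path)       # should be TRUE if extraction succeeded
readLines(text_path, n = 3)  # tab-separated "English\tSpanish" pairs
```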
@@ -209,36 +207,31 @@ it provides the next words in the target sentence -- what the model will try to
 
 ```{r}
 format_pair <- function(pair) {
-  # the vectorization layers requrie batched inputs,
-  # reshape scalar string tensor to add a batch dim
-  pair %<>% lapply(op_expand_dims, 1)
-
-  # vectorize
-  eng <- eng_vectorization(pair$english)
-  spa <- spa_vectorization(pair$spanish)
-
-  # drop the batch dim
-  eng %<>% tf$ensure_shape(shape(1, sequence_length)) %>% op_squeeze(1)
-  spa %<>% tf$ensure_shape(shape(1, sequence_length+1)) %>% op_squeeze(1)
-
-  inputs <- list(encoder_inputs = eng,
-                 decoder_inputs = spa[NA:-2])
-  targets <- spa[2:NA]
-  list(inputs, targets)
-}
+  eng <- pair$english |> eng_vectorization()
+  spa <- pair$spanish |> spa_vectorization()
+
+  spa_feature <- spa@r[NA:-2] # <1>
+  spa_target <- spa@r[2:NA]   # <2>
 
+  features <- list(encoder_inputs = eng, decoder_inputs = spa_feature)
+  labels <- spa_target
+  sample_weight <- labels != 0
+
+  tuple(features, labels, sample_weight)
+}
 
 batch_size <- 64
 
 library(tfdatasets, exclude = "shape")
 make_dataset <- function(pairs) {
-  tensor_slices_dataset(pairs) %>%
-    dataset_map(format_pair, num_parallel_calls = 4) %>%
-    dataset_cache() %>%
-    dataset_shuffle(2048) %>%
-    dataset_batch(batch_size) %>%
-    dataset_prefetch(2)
+  tensor_slices_dataset(pairs) |>
+    dataset_map(format_pair, num_parallel_calls = 4) |>
+    dataset_cache() |>
+    dataset_shuffle(2048) |>
+    dataset_batch(batch_size) |>
+    dataset_prefetch(16)
 }
+
 train_ds <- make_dataset(train_pairs)
 val_ds <- make_dataset(val_pairs)
 ```
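
For review purposes, a small base-R sketch (not from the diff) of the offset the rewritten `format_pair()` sets up: the decoder is fed tokens 1..(n-1) of the vectorized Spanish sequence and trained to predict tokens 2..n, with padded positions masked out by the new sample weights:

```r
# Illustrative only -- hypothetical token ids standing in for a vectorized
# Spanish sequence; 0 is the padding id.
spa_tokens <- c(2L, 15L, 7L, 42L, 3L, 0L, 0L)

decoder_inputs <- head(spa_tokens, -1)  # what the decoder sees (drop last step)
targets        <- tail(spa_tokens, -1)  # what it must predict (drop first step)
sample_weight  <- targets != 0          # padded targets contribute no loss

rbind(decoder_inputs, targets, sample_weight)
```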
@@ -248,7 +241,7 @@ Let's take a quick look at the sequence shapes
 (we have batches of 64 pairs, and all sequences are 20 steps long):
 
 ```{r}
-c(inputs, targets) %<-% iter_next(as_iterator(train_ds))
+c(inputs, targets, weights) %<-% iter_next(as_iterator(train_ds))
 str(inputs)
 str(targets)
 ```
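
Since the dataset now yields a third element, it might be worth inspecting it in this chunk as well; a possible addition (not in the diff), which should show the same `(batch_size, sequence_length)` shape as the targets:

```r
str(weights)  # the per-token padding mask returned as the dataset's third element
```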
@@ -346,7 +339,7 @@ layer_transformer_decoder <- Layer(
   get_causal_attention_mask = function(inputs) {
     c(batch_size, sequence_length, encoding_length) %<-% op_shape(inputs)
 
-    x <- op_arange(sequence_length)
+    x <- op_arange(0L, sequence_length, include_end = FALSE)
     i <- x[, NULL]
     j <- x[NULL, ]
     mask <- op_cast(i >= j, "int32")
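
A small sketch (not part of the diff) of what this mask computes, using plain R for a length-4 sequence: position i may attend only to positions j <= i, i.e. a lower-triangular matrix built from 0-based indices 0..(sequence_length - 1):

```r
sequence_length <- 4
i <- matrix(0:(sequence_length - 1), sequence_length, sequence_length)  # 0-based row index
j <- t(i)                                                               # 0-based column index
mask <- (i >= j) * 1L  # 1 where attention is allowed
mask
#      [,1] [,2] [,3] [,4]
# [1,]    1    0    0    0
# [2,]    1    1    0    0
# [3,]    1    1    1    0
# [4,]    1    1    1    1
```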
@@ -398,7 +391,7 @@ layer_positional_embedding <- Layer(
 
   call = function(inputs) {
     c(., len) %<-% op_shape(inputs) # (batch_size, seq_len)
-    positions <- op_arange(0, len, dtype = "int32")
+    positions <- op_arange(0, len, dtype = "int32", include_end = FALSE)
     embedded_tokens <- self$token_embeddings(inputs)
     embedded_positions <- self$position_embeddings(positions)
     embedded_tokens + embedded_positions
@@ -476,6 +469,7 @@ transformer |> compile(
476469 loss = "sparse_categorical_crossentropy",
477470 metrics = "accuracy"
478471)
472+
479473transformer |> fit(train_ds, epochs = epochs,
480474 validation_data = val_ds)
481475```
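
One consequence of `format_pair()` now returning `(features, labels, sample_weight)` is that Keras weights the per-token loss by that third element, so padding tokens no longer contribute to the training signal. A toy illustration of the weighting (not from the diff; the per-position loss values are hypothetical):

```r
# Hypothetical per-position cross-entropy values for one target sequence.
per_token_loss <- c(0.8, 1.2, 0.5, 2.0)
weights        <- c(1, 1, 1, 0)  # last position is padding (label == 0)

sum(per_token_loss * weights) / sum(weights)  # the padded step is ignored
```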
@@ -544,3 +538,7 @@ English: I'm sure everything will be fine.
 Correct Translation: [start] estoy segura de que todo irá bien. [end]
   Model Translation: [start] estoy seguro de que todo va bien [end]
 ```
+```{r}
+
+```
+