Skip to content

Commit a411196

Browse files
committed
minor fixes
1 parent dd57faa commit a411196

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

vignettes/examples/nmt_attention.R

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ library(tibble)
2121
# Assumes you've downloaded and unzipped one of the bilingual datasets offered at
2222
# http://www.manythings.org/anki/ and put it into a directory "data"
2323
# This example translates English to Dutch.
24+
download_data = function(){
25+
if(!dir.exists('data')) {
26+
dir.create('data')
27+
}
28+
if(!file.exists('data/nld-eng.zip')) {
29+
download.file('http://www.manythings.org/anki/nld-eng.zip',
30+
destfile = file.path("data", basename('nld-eng.zip')))
31+
unzip('data/nld-eng.zip', exdir = 'data')
32+
}
33+
}
34+
download_data()
2435

2536
filepath <- file.path("data", "nld.txt")
2637

@@ -290,7 +301,7 @@ evaluate <-
290301
attention_matrix[t,] <- attention_weights %>% as.double()
291302

292303
pred_idx <-
293-
tf$compat$v1$multinomial(k_exp(preds), num_samples = 1L)[1, 1] %>% as.double()
304+
tf$random$categorical(k_exp(preds), num_samples = 1L)[1, 1] %>% as.double()
294305
pred_word <- index2word(pred_idx, target_index)
295306

296307
if (pred_word == '<stop>') {
@@ -387,7 +398,7 @@ for (epoch in seq_len(n_epochs)) {
387398
": ",
388399
(loss / k_cast_to_floatx(dim(y)[2])) %>% as.double() %>% round(4),
389400
"\n"
390-
) %>% print()
401+
) %>% cat()
391402

392403
variables <- c(encoder$variables, decoder$variables)
393404
gradients <- tape$gradient(loss, variables)
@@ -402,7 +413,7 @@ for (epoch in seq_len(n_epochs)) {
402413
": ",
403414
(total_loss / k_cast_to_floatx(buffer_size)) %>% as.double() %>% round(4),
404415
"\n"
405-
) %>% print()
416+
) %>% cat()
406417

407418
walk(train_sentences[1:5], function(pair)
408419
translate(pair[1]))

0 commit comments

Comments
 (0)