Commit bdf93b0

reknit all guides and examples

1 parent: 43e7094

35 files changed (+882, -812 lines)

vignettes/custom_train_step_in_tensorflow.Rmd

Lines changed: 6 additions & 6 deletions
@@ -127,7 +127,7 @@ model |> fit(x, y, epochs = 3)
 
 ```
 ## Epoch 1/3
-## 32/32 - 1s - 29ms/step - mae: 1.4339 - loss: 3.2271
+## 32/32 - 1s - 23ms/step - mae: 1.4339 - loss: 3.2271
 ## Epoch 2/3
 ## 32/32 - 0s - 2ms/step - mae: 1.3605 - loss: 2.9034
 ## Epoch 3/3
@@ -282,11 +282,11 @@ model |> fit(x, y, sample_weight = sw, epochs = 3)
 
 ```
 ## Epoch 1/3
-## 32/32 - 1s - 28ms/step - mae: 1.3434 - loss: 0.1681
+## 32/32 - 1s - 23ms/step - mae: 1.3434 - loss: 0.1681
 ## Epoch 2/3
-## 32/32 - 0s - 3ms/step - mae: 1.3364 - loss: 0.1394
+## 32/32 - 0s - 2ms/step - mae: 1.3364 - loss: 0.1394
 ## Epoch 3/3
-## 32/32 - 0s - 3ms/step - mae: 1.3286 - loss: 0.1148
+## 32/32 - 0s - 4ms/step - mae: 1.3286 - loss: 0.1148
 ```
 
 ## Providing your own evaluation step
@@ -332,7 +332,7 @@ model |> evaluate(x, y)
 ```
 
 ```
-## 32/32 - 0s - 9ms/step - mae: 1.3871 - loss: 0.0000e+00
+## 32/32 - 0s - 10ms/step - mae: 1.3871 - loss: 0.0000e+00
 ```
 
 ```
@@ -508,7 +508,7 @@ gan |> fit(
 ```
 
 ```
-## 100/100 - 6s - 57ms/step - d_loss: 0.0000e+00 - g_loss: 0.0000e+00
+## 100/100 - 5s - 53ms/step - d_loss: 0.0000e+00 - g_loss: 0.0000e+00
 ```
 
 The ideas behind deep learning are simple, so why should their implementation be painful?

vignettes/distributed_training_with_tensorflow.Rmd

Lines changed: 17 additions & 14 deletions
@@ -124,20 +124,16 @@ get_compiled_model <- function() {
   model |> compile(
     optimizer = optimizer_adam(),
     loss = loss_sparse_categorical_crossentropy(from_logits = TRUE),
-    metrics = list(metric_sparse_categorical_accuracy()),
-
-    # XLA compilation is temporarily disabled due to a bug
-    # https://github.com/keras-team/keras/issues/19005
-    jit_compile = FALSE
+    metrics = list(metric_sparse_categorical_accuracy())
   )
   model
 }
 
 get_dataset <- function(batch_size = 64) {
 
   c(c(x_train, y_train), c(x_test, y_test)) %<-% dataset_mnist()
-  x_train <- array_reshape(x_train, c(-1, 784))
-  x_test <- array_reshape(x_test, c(-1, 784))
+  x_train <- array_reshape(x_train, c(-1, 784)) / 255
+  x_test <- array_reshape(x_test, c(-1, 784)) / 255
 
   # Reserve 10,000 samples for validation.
   val_i <- sample.int(nrow(x_train), 10000)
@@ -146,18 +142,25 @@ get_dataset <- function(batch_size = 64) {
   x_train = x_train[-val_i,]
   y_train = y_train[-val_i]
 
+  y_train <- array_reshape(y_train, c(-1, 1))
+  y_val <- array_reshape(y_val, c(-1, 1))
+  y_test <- array_reshape(y_test, c(-1, 1))
+
   # Prepare the training dataset.
   train_dataset <- list(x_train, y_train) |>
+    lapply(np_array, "float32") |>
     tensor_slices_dataset() |>
     dataset_batch(batch_size)
 
   # Prepare the validation dataset.
   val_dataset <- list(x_val, y_val) |>
+    lapply(np_array, "float32") |>
     tensor_slices_dataset() |>
     dataset_batch(batch_size)
 
   # Prepare the test dataset.
   test_dataset <- list(x_test, y_test) |>
+    lapply(np_array, "float32") |>
     tensor_slices_dataset() |>
     dataset_batch(batch_size)
 
@@ -193,18 +196,18 @@ with(strategy$scope(), {
 
 ```
 ## Epoch 1/2
-## 782/782 - 4s - 6ms/step - loss: 2.1409 - sparse_categorical_accuracy: 0.8896 - val_loss: 0.7223 - val_sparse_categorical_accuracy: 0.9216
+## 782/782 - 7s - 9ms/step - loss: nan - sparse_categorical_accuracy: nan - val_loss: nan - val_sparse_categorical_accuracy: nan
 ## Epoch 2/2
-## 782/782 - 3s - 4ms/step - loss: 0.4292 - sparse_categorical_accuracy: 0.9387 - val_loss: 0.3693 - val_sparse_categorical_accuracy: 0.9404
-## 157/157 - 0s - 2ms/step - loss: 0.3976 - sparse_categorical_accuracy: 0.9386
+## 782/782 - 5s - 7ms/step - loss: nan - sparse_categorical_accuracy: nan - val_loss: nan - val_sparse_categorical_accuracy: nan
+## 157/157 - 1s - 5ms/step - loss: nan - sparse_categorical_accuracy: nan
 ```
 
 ```
 ## $loss
-## [1] 0.3976028
+## [1] NaN
 ##
 ## $sparse_categorical_accuracy
-## [1] 0.9386
+## [1] NaN
 ```
 
 ## Using callbacks to ensure fault tolerance
@@ -274,7 +277,7 @@ run_training(epochs = 1)
 ```
 
 ```
-## 782/782 - 4s - 5ms/step - loss: 0.1485 - sparse_categorical_accuracy: 0.9627 - val_loss: 0.2062 - val_sparse_categorical_accuracy: 0.9560
+## 782/782 - 5s - 7ms/step - loss: nan - sparse_categorical_accuracy: nan - val_loss: nan - val_sparse_categorical_accuracy: nan
 ```
 
 ``` r
@@ -283,7 +286,7 @@ run_training(epochs = 1)
 ```
 
 ```
-## 782/782 - 4s - 5ms/step - loss: 0.1227 - sparse_categorical_accuracy: 0.9673 - val_loss: 0.2007 - val_sparse_categorical_accuracy: 0.9602
+## 782/782 - 6s - 7ms/step - loss: nan - sparse_categorical_accuracy: nan - val_loss: nan - val_sparse_categorical_accuracy: nan
 ```
 
 ## `tf$data` performance tips

vignettes/distribution.Rmd

Lines changed: 91 additions & 11 deletions
@@ -56,7 +56,7 @@ Sys.setenv("XLA_FLAGS" = "--xla_force_host_platform_device_count=8")
 library(keras3)
 
 # The distribution API is only implemented for the JAX backend for now.
-use_backend("jax")
+use_backend("jax", FALSE)
 jax <- reticulate::import("jax")
 
 library(tfdatasets, exclude = "shape") # For dataset input.
@@ -184,24 +184,24 @@ model |> fit(dataset, epochs = 3)
 
 ```
 ## Epoch 1/3
-## 8/8 - 0s - 38ms/step - loss: 1.1533
+## 8/8 - 0s - 40ms/step - loss: 1.1536
 ## Epoch 2/3
-## 8/8 - 0s - 5ms/step - loss: 1.0621
+## 8/8 - 0s - 5ms/step - loss: 1.0540
 ## Epoch 3/3
-## 8/8 - 0s - 7ms/step - loss: 1.0163
+## 8/8 - 0s - 6ms/step - loss: 1.0072
 ```
 
 ``` r
 model |> evaluate(dataset)
 ```
 
 ```
-## 8/8 - 0s - 7ms/step - loss: 0.9673
+## 8/8 - 0s - 9ms/step - loss: 0.9620
 ```
 
 ```
 ## $loss
-## [1] 0.9673058
+## [1] 0.9620273
 ```
 
 
@@ -269,7 +269,85 @@ outputs <- inputs |>
               name = "d2")
 
 model <- keras_model(inputs = inputs, outputs = outputs)
+```
+
+We can visualize how individual weights will be sharded:
+
+``` r
+d1 <- get_layer(model, "d1")
+d1$kernel$value |> jax$debug$visualize_array_sharding()
+```
+
+```
+## ┌───────┬───────┬───────┬───────┐
+## │       │       │       │       │
+## │       │       │       │       │
+## │       │       │       │       │
+## │       │       │       │       │
+## │CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│
+## │       │       │       │       │
+## │       │       │       │       │
+## │       │       │       │       │
+## │       │       │       │       │
+## └───────┴───────┴───────┴───────┘
+```
+
+``` r
+d2 <- get_layer(model, "d2")
+d2$kernel$value |> jax$debug$visualize_array_sharding()
+```
+
+```
+## ┌───────────────────┐
+## │                   │
+## │                   │
+## │                   │
+## │                   │
+## │CPU 0,1,2,3,4,5,6,7│
+## │                   │
+## │                   │
+## │                   │
+## │                   │
+## └───────────────────┘
+```
+
+``` r
+d2$bias$value |> jax$debug$visualize_array_sharding()
+```
 
+```
+## ┌───────────────────┐
+## │CPU 0,1,2,3,4,5,6,7│
+## └───────────────────┘
+```
+
+``` r
+x_batch <- dataset |>
+  as_iterator() |> iter_next() |>
+  _[[1]] |> op_convert_to_tensor()
+
+output_array <- model(x_batch)
+output_array |> jax$debug$visualize_array_sharding()
+```
+
+```
+## ┌─────────────┐
+## │             │
+## │ CPU 0,1,2,3 │
+## │             │
+## │             │
+## ├─────────────┤
+## │             │
+## │ CPU 4,5,6,7 │
+## │             │
+## │             │
+## └─────────────┘
+```
+
+
+
+
+``` r
 # The data will be sharded across the "data" dimension of the method, which
 # has 2 devices.
 model |> compile(loss = "mse")
@@ -278,24 +356,24 @@ model |> fit(dataset, epochs = 3)
 
 ```
 ## Epoch 1/3
-## 8/8 - 0s - 42ms/step - loss: 1.1424
+## 8/8 - 0s - 46ms/step - loss: 1.1676
 ## Epoch 2/3
-## 8/8 - 0s - 7ms/step - loss: 1.0528
+## 8/8 - 0s - 4ms/step - loss: 1.1134
 ## Epoch 3/3
-## 8/8 - 0s - 7ms/step - loss: 1.0393
+## 8/8 - 0s - 5ms/step - loss: 1.1034
 ```
 
 ``` r
 model |> evaluate(dataset)
 ```
 
 ```
-## 8/8 - 0s - 9ms/step - loss: 1.0088
+## 8/8 - 0s - 8ms/step - loss: 1.0676
 ```
 
 ```
 ## $loss
-## [1] 1.008847
+## [1] 1.067567
 ```
 
 
@@ -335,3 +413,5 @@ full_model_parallel_mesh <- keras$distribution$DeviceMesh(
 3. [TensorFlow Distributed training with DTensors](https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial)
 4. [TensorFlow DTensor concepts](https://www.tensorflow.org/guide/dtensor_overview)
 5. [Using DTensors with tf.keras](https://www.tensorflow.org/tutorials/distribute/dtensor_keras_tutorial)
+
+
vignettes/examples/index.Rmd

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 ---
 title: Keras examples
 output: rmarkdown::html_vignette
-date: 'Last Modified: 2023-11-30; Last Rendered: 2025-01-23'
+date: 'Last Modified: 2023-11-30; Last Rendered: 2025-05-02'
 vignette: >
   %\VignetteIndexEntry{Keras examples}
   %\VignetteEngine{knitr::rmarkdown}
