
Commit 8d4e9af

retether vignettes
1 parent 6ff11d2 commit 8d4e9af

7 files changed (+12, -15 lines)

.tether/vignettes-src/distribution.Rmd

Lines changed: 1 addition & 3 deletions
@@ -188,9 +188,7 @@ layout_map["d1/bias"] = ("model",)
 # You can also set the layout for the layer output like
 layout_map["d2/output"] = ("data", None)

-model_parallel = keras.distribution.ModelParallel(
-    mesh_2d, layout_map, batch_dim_name="data"
-)
+model_parallel = keras.distribution.ModelParallel(layout_map, batch_dim_name="data")

 keras.distribution.set_distribution(model_parallel)

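For context, the updated call no longer takes the device mesh as a positional argument; the mesh travels with the `LayoutMap` instead. A minimal sketch under that assumption (device count, mesh shape, and the variable path are illustrative, not taken from the vignette):

```python
import keras

# Sketch of the updated API: the DeviceMesh is attached to the LayoutMap,
# and ModelParallel receives only the layout_map and batch_dim_name.
devices = keras.distribution.list_devices()
mesh_2d = keras.distribution.DeviceMesh(
    shape=(2, len(devices) // 2),   # assumes an even number of accelerators
    axis_names=("data", "model"),
    devices=devices,
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
layout_map["d1/kernel"] = (None, "model")   # illustrative variable path

model_parallel = keras.distribution.ModelParallel(
    layout_map, batch_dim_name="data"
)
keras.distribution.set_distribution(model_parallel)
```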
.tether/vignettes-src/parked/_custom_train_step_in_torch.Rmd

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 title: Customizing what happens in `fit()` with PyTorch
 author: '[fchollet](https://twitter.com/fchollet)'
 date-created: 2023/06/27
-last-modified: 2023/06/27
+last-modified: 2024/08/01
 description: Overriding the training step of the Model class with PyTorch.
 accelerator: GPU
 output: rmarkdown::html_vignette
@@ -390,7 +390,7 @@ class GAN(keras.Model):

     def train_step(self, real_images):
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        if isinstance(real_images, tuple):
+        if isinstance(real_images, tuple) or isinstance(real_images, list):
             real_images = real_images[0]
         # Sample random points in the latent space
         batch_size = real_images.shape[0]

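The widened check handles batches that arrive as either a tuple or a list; the idiomatic equivalent is a single `isinstance` call with a tuple of types. A standalone sketch of that unwrapping (function name and shapes are made up for illustration):

```python
import torch

def unwrap_batch(batch):
    # A batch may be a bare tensor, a (x,) tuple, or a [x] list;
    # keep only the element in the first position.
    if isinstance(batch, (tuple, list)):
        batch = batch[0]
    return batch

x = torch.zeros(64, 28, 28, 1)
assert unwrap_batch(x) is x
assert unwrap_batch((x,)) is x
assert unwrap_batch([x]) is x
```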
.tether/vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y)
 ```

 Once you have such a function, you can get the gradient function by
-specifying `hax_aux` in `value_and_grad`: it tells JAX that the loss
+specifying `has_aux` in `value_and_grad`: it tells JAX that the loss
 computation function returns more outputs than just the loss. Note that the loss
 should always be the first output.

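The corrected keyword is the one `jax.value_and_grad` actually accepts: with `has_aux=True` the wrapped function is expected to return `(loss, aux)`, and the call returns `((loss, aux), grads)`. A minimal sketch with a toy loss function (names and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def compute_loss_and_aux(params, x, y):
    # The loss must be the first output; everything else is auxiliary.
    pred = x @ params
    loss = jnp.mean((pred - y) ** 2)
    return loss, {"pred_mean": pred.mean()}

grad_fn = jax.value_and_grad(compute_loss_and_aux, has_aux=True)

params = jnp.ones((3,))
x, y = jnp.ones((5, 3)), jnp.zeros((5,))
(loss, aux), grads = grad_fn(params, x, y)
```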
.tether/vignettes-src/writing_your_own_callbacks.Rmd

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
         # The epoch the training stops at.
         self.stopped_epoch = 0
         # Initialize the best as infinity.
-        self.best = np.Inf
+        self.best = np.inf

     def on_epoch_end(self, epoch, logs=None):
         current = logs.get("loss")

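`np.Inf` is one of the aliases removed in NumPy 2.0, so the lowercase `np.inf` keeps the callback working on current NumPy. A minimal sketch of the same best-loss bookkeeping outside Keras (the loss values are made up):

```python
import numpy as np

best = np.inf  # lowercase spelling; np.Inf was removed in NumPy 2.0
for loss in [0.9, 0.5, 0.7, 0.4]:
    if np.less(loss, best):
        best = loss
print(best)  # 0.4
```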
vignettes-src/distribution.Rmd

Lines changed: 4 additions & 5 deletions
@@ -93,7 +93,7 @@ mesh <- keras$distribution$DeviceMesh(
 # "data" as columns, and it is a [4, 2] grid when it mapped to the physical
 # devices on the mesh.
 layout_2d <- keras$distribution$TensorLayout(
-  axes = c("model", "data"), 
+  axes = c("model", "data"),
   device_mesh = mesh
 )

@@ -131,8 +131,8 @@ data_parallel <- keras$distribution$DataParallel(devices = devices)

 # Or you can choose to create DataParallel with a 1D `DeviceMesh`.
 mesh_1d <- keras$distribution$DeviceMesh(
-  shape = shape(8), 
-  axis_names = list("data"), 
+  shape = shape(8),
+  axis_names = list("data"),
   devices = devices
 )
 data_parallel <- keras$distribution$DataParallel(device_mesh = mesh_1d)
@@ -213,8 +213,7 @@ layout_map["d1/bias"] <- tuple("model")
 layout_map["d2/output"] <- tuple("data", NULL)

 model_parallel <- keras$distribution$ModelParallel(
-  layout_map = layout_map,
-  batch_dim_name = "data"
+  layout_map, batch_dim_name = "data"
 )

 keras$distribution$set_distribution(model_parallel)

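The R snippet mirrors the Python distribution API. As a point of comparison, here is a minimal Python sketch of the same 1D-mesh `DataParallel` setup (it assumes all visible accelerators should sit on a single "data" axis):

```python
import keras

devices = keras.distribution.list_devices()
# One "data" axis spanning every visible device.
mesh_1d = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=("data",), devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)
keras.distribution.set_distribution(data_parallel)
```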
vignettes-src/parked/_custom_train_step_in_torch.Rmd

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
 title: Customizing what happens in `fit()` with PyTorch
 author: '[fchollet](https://twitter.com/fchollet)'
 date-created: 2023/06/27
-last-modified: 2023/06/27
+last-modified: 2024/08/01
 description: Overriding the training step of the Model class with PyTorch.
 accelerator: GPU
 output: rmarkdown::html_vignette
@@ -390,7 +390,7 @@ class GAN(keras.Model):

     def train_step(self, real_images):
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        if isinstance(real_images, tuple):
+        if isinstance(real_images, tuple) or isinstance(real_images, list):
             real_images = real_images[0]
         # Sample random points in the latent space
         batch_size = real_images.shape[0]

vignettes-src/parked/_writing_a_custom_training_loop_in_jax.Rmd

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y)
 ```

 Once you have such a function, you can get the gradient function by
-specifying `hax_aux` in `value_and_grad`: it tells JAX that the loss
+specifying `has_aux` in `value_and_grad`: it tells JAX that the loss
 computation function returns more outputs than just the loss. Note that the loss
 should always be the first output.
