Skip to content

Commit 82e0ad2

Browse files
Merge pull request #1995 from dreavjr:deepdream-correct-shape
PiperOrigin-RevId: 421487418
2 parents c577477 + 9ae740a commit 82e0ad2

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

site/en/tutorials/generative/deepdream.ipynb

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -512,19 +512,20 @@
512512
" @tf.function(\n",
513513
" input_signature=(\n",
514514
" tf.TensorSpec(shape=[None,None,3], dtype=tf.float32),\n",
515+
" tf.TensorSpec(shape=[2], dtype=tf.int32),\n",
515516
" tf.TensorSpec(shape=[], dtype=tf.int32),)\n",
516517
" )\n",
517-
" def __call__(self, img, tile_size=512):\n",
518+
" def __call__(self, img, img_size, tile_size=512):\n",
518519
" shift, img_rolled = random_roll(img, tile_size)\n",
519520
"\n",
520521
" # Initialize the image gradients to zero.\n",
521522
" gradients = tf.zeros_like(img_rolled)\n",
522523
" \n",
523524
" # Skip the last tile, unless there's only one tile.\n",
524-
" xs = tf.range(0, img_rolled.shape[0], tile_size)[:-1]\n",
525+
" xs = tf.range(0, img_size[1], tile_size)[:-1]\n",
525526
" if not tf.cast(len(xs), bool):\n",
526527
" xs = tf.constant([0])\n",
527-
" ys = tf.range(0, img_rolled.shape[1], tile_size)[:-1]\n",
528+
" ys = tf.range(0, img_size[0], tile_size)[:-1]\n",
528529
" if not tf.cast(len(ys), bool):\n",
529530
" ys = tf.constant([0])\n",
530531
"\n",
@@ -537,7 +538,7 @@
537538
" tape.watch(img_rolled)\n",
538539
"\n",
539540
" # Extract a tile out of the image.\n",
540-
" img_tile = img_rolled[x:x+tile_size, y:y+tile_size]\n",
541+
" img_tile = img_rolled[y:y+tile_size, x:x+tile_size]\n",
541542
" loss = calc_loss(img_tile, self.model)\n",
542543
"\n",
543544
" # Update the image gradients for this tile.\n",
@@ -591,10 +592,11 @@
591592
" for octave in octaves:\n",
592593
" # Scale the image based on the octave\n",
593594
" new_size = tf.cast(tf.convert_to_tensor(base_shape[:-1]), tf.float32)*(octave_scale**octave)\n",
594-
" img = tf.image.resize(img, tf.cast(new_size, tf.int32))\n",
595+
" new_size = tf.cast(new_size, tf.int32)\n",
596+
" img = tf.image.resize(img, new_size)\n",
595597
"\n",
596598
" for step in range(steps_per_octave):\n",
597-
" gradients = get_tiled_gradients(img)\n",
599+
" gradients = get_tiled_gradients(img, new_size)\n",
598600
" img = img + gradients*step_size\n",
599601
" img = tf.clip_by_value(img, -1, 1)\n",
600602
"\n",

0 commit comments

Comments
 (0)