|
512 | 512 | " @tf.function(\n",
|
513 | 513 | " input_signature=(\n",
|
514 | 514 | " tf.TensorSpec(shape=[None,None,3], dtype=tf.float32),\n",
|
| 515 | + " tf.TensorSpec(shape=[2], dtype=tf.int32),\n", |
515 | 516 | " tf.TensorSpec(shape=[], dtype=tf.int32),)\n",
|
516 | 517 | " )\n",
|
517 |
| - " def __call__(self, img, tile_size=512):\n", |
| 518 | + " def __call__(self, img, img_size, tile_size=512):\n", |
518 | 519 | " shift, img_rolled = random_roll(img, tile_size)\n",
|
519 | 520 | "\n",
|
520 | 521 | " # Initialize the image gradients to zero.\n",
|
521 | 522 | " gradients = tf.zeros_like(img_rolled)\n",
|
522 | 523 | " \n",
|
523 | 524 | " # Skip the last tile, unless there's only one tile.\n",
|
524 |
| - " xs = tf.range(0, img_rolled.shape[0], tile_size)[:-1]\n", |
| 525 | + " xs = tf.range(0, img_size[1], tile_size)[:-1]\n", |
525 | 526 | " if not tf.cast(len(xs), bool):\n",
|
526 | 527 | " xs = tf.constant([0])\n",
|
527 |
| - " ys = tf.range(0, img_rolled.shape[1], tile_size)[:-1]\n", |
| 528 | + " ys = tf.range(0, img_size[0], tile_size)[:-1]\n", |
528 | 529 | " if not tf.cast(len(ys), bool):\n",
|
529 | 530 | " ys = tf.constant([0])\n",
|
530 | 531 | "\n",
|
|
537 | 538 | " tape.watch(img_rolled)\n",
|
538 | 539 | "\n",
|
539 | 540 | " # Extract a tile out of the image.\n",
|
540 |
| - " img_tile = img_rolled[x:x+tile_size, y:y+tile_size]\n", |
| 541 | + " img_tile = img_rolled[y:y+tile_size, x:x+tile_size]\n", |
541 | 542 | " loss = calc_loss(img_tile, self.model)\n",
|
542 | 543 | "\n",
|
543 | 544 | " # Update the image gradients for this tile.\n",
|
|
591 | 592 | " for octave in octaves:\n",
|
592 | 593 | " # Scale the image based on the octave\n",
|
593 | 594 | " new_size = tf.cast(tf.convert_to_tensor(base_shape[:-1]), tf.float32)*(octave_scale**octave)\n",
|
594 |
| - " img = tf.image.resize(img, tf.cast(new_size, tf.int32))\n", |
| 595 | + " new_size = tf.cast(new_size, tf.int32)\n", |
| 596 | + " img = tf.image.resize(img, new_size)\n", |
595 | 597 | "\n",
|
596 | 598 | " for step in range(steps_per_octave):\n",
|
597 |
| - " gradients = get_tiled_gradients(img)\n", |
| 599 | + " gradients = get_tiled_gradients(img, new_size)\n", |
598 | 600 | " img = img + gradients*step_size\n",
|
599 | 601 | " img = tf.clip_by_value(img, -1, 1)\n",
|
600 | 602 | "\n",
|
|
0 commit comments