diff --git a/src/upscaler.py b/src/upscaler.py
new file mode 100644
index 0000000..fa74165
--- /dev/null
+++ b/src/upscaler.py
@@ -0,0 +1,249 @@
+from realesrgan import RealESRGANer
+from basicsr.archs.rrdbnet_arch import RRDBNet
+import torch
+import PIL.Image
+import numpy as np
+import contextlib
+from io import StringIO
+from tqdm.auto import tqdm
+import signal
+import requests
+import urllib.request
+import urllib.parse
+import os
+import re
+
+def download_file(
+ link: str,
+ path: str,
+ block_size: int = 1024,
+ force_download: bool = False,
+ progress: bool = True,
+ interrupt_check: bool = True,
+) -> str:
+    def truncate_string(string: str, length: int) -> str:
+        # Keep the head and tail of the string, joined by "(...)", so the
+        # result fits within `length` characters.
+        length -= 5 if length > 5 else 0
+        truncated = string[: length // 2] + "(...)" + string[-length // 2 :]
+        return string if len(truncated) > len(string) else truncated
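+    # e.g. truncate_string("a" * 60, 50) returns a 50-character string with
+    # "(...)" spliced into the middle; shorter strings come back unchanged.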
+
+ def remove_char(string: str, chars: list):
+ for char in chars:
+ string = string.replace(char, "")
+ return string
+
+ # source: https://github.com/wkentaro/gdown/blob/main/gdown/download.py
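+    # e.g. a share link ".../file/d/<FILE_ID>/view" yields ("<FILE_ID>", False),
+    # while a non-Drive URL yields (False, is_download_link).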
+ def google_drive_parse_url(url: str):
+ parsed = urllib.parse.urlparse(url)
+ query = urllib.parse.parse_qs(parsed.query)
+ is_gdrive = parsed.hostname in ["drive.google.com", "docs.google.com"]
+ is_download_link = parsed.path.endswith("/uc")
+
+ if not is_gdrive:
+ return is_gdrive, is_download_link
+
+ file_id = None
+ if "id" in query:
+ file_ids = query["id"]
+ if len(file_ids) == 1:
+ file_id = file_ids[0]
+ else:
+ patterns = [r"^/file/d/(.*?)/view$", r"^/presentation/d/(.*?)/edit$"]
+ for pattern in patterns:
+ match = re.match(pattern, parsed.path)
+ if match:
+ file_id = match.groups()[0]
+ break
+
+ return file_id, is_download_link
+
+ # source: https://github.com/wkentaro/gdown/blob/main/gdown/download.py
+ def get_url_from_gdrive_confirmation(contents: str):
+ url = ""
+ for line in contents.splitlines():
+ m = re.search(r'href="(/uc\?export=download[^"]+)', line)
+ if m:
+ url = "https://docs.google.com" + m.groups()[0]
+                url = url.replace("&amp;", "&")
+ break
+ m = re.search('id="download-form" action="(.+?)"', line)
+ if m:
+ url = m.groups()[0]
+                url = url.replace("&amp;", "&")
+ break
+ m = re.search('"downloadUrl":"([^"]+)', line)
+ if m:
+ url = m.groups()[0]
+ url = url.replace("\\u003d", "=")
+ url = url.replace("\\u0026", "&")
+ break
+            m = re.search('<p class="uc-error-subcaption">(.*)</p>', line)
+ if m:
+ error = m.groups()[0]
+ raise RuntimeError(error)
+ if not url:
+ raise RuntimeError("Cannot retrieve the link of the file. ")
+ return url
+
+ def interrupt(*args):
+ if os.path.isfile(filepath):
+ os.remove(filepath)
+ raise KeyboardInterrupt
+
+ # create folder if not exists
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ # check if link is google drive link
+ if not google_drive_parse_url(link)[0]:
+ response = requests.get(link, stream=True, allow_redirects=True)
+ else:
+ if not google_drive_parse_url(link)[1]:
+ # convert to direct link
+ file_id = google_drive_parse_url(link)[0]
+ link = f"https://drive.google.com/uc?id={file_id}"
+ # test if redirect is needed
+ response = requests.get(link, stream=True, allow_redirects=True)
+ if response.headers.get("Content-Disposition") is None:
+ page = urllib.request.urlopen(link)
+ link = get_url_from_gdrive_confirmation(str(page.read()))
+ response = requests.get(link, stream=True, allow_redirects=True)
+
+ if response.status_code == 404:
+ raise FileNotFoundError(f"File not found at {link}")
+
+ # get filename
+ content_disposition = response.headers.get("Content-Disposition")
+ if content_disposition:
+ filename = re.findall(r"filename=(.*?)(?:[;\n]|$)", content_disposition)[0]
+ else:
+ filename = os.path.basename(link)
+
+ filename = remove_char(
+ filename, ["/", "\\", ":", "*", "?", '"', "'", "<", ">", "|", ";"]
+ )
+ filename = filename.replace(" ", "_")
+
+ filepath = os.path.join(path, filename)
+
+ # download file
+ if os.path.isfile(filepath) and not force_download:
+ print(f"{filename} already exists. Skipping download.")
+ else:
+ text = f"Downloading {truncate_string(filename, 50)}"
+ with open(filepath, "wb") as file:
+ total_size = int(response.headers.get("content-length", 0))
+ with tqdm(
+ total=total_size,
+ unit="B",
+ unit_scale=True,
+ desc=text,
+ unit_divisor=1024,
+ disable=not progress,
+ ) as pb:
+ if interrupt_check:
+ signal.signal(signal.SIGINT, lambda signum, frame: interrupt())
+ for data in response.iter_content(block_size):
+ pb.update(len(data))
+ file.write(data)
+ del response
+ return filename
+
+def factorize(num: float, max_value: int) -> list[float]:
+    # Split an overall scale factor into a list of per-pass scales, each no
+    # larger than max_value, whose product equals the original factor.
+    result = []
+    while num > max_value:
+        result.append(max_value)
+        num /= max_value
+    result.append(round(num, 4))
+    return result
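+
+# e.g. factorize(10, 4) == [4, 2.5]: one full x4 pass followed by a x2.5 pass,
+# since the per-pass scales must multiply back to the requested factor.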
+
+def upscale(
+ img_list: list[PIL.Image.Image],
+ model_name: str = "RealESRGAN_x4plus_anime_6B",
+ scale_factor: float = 4,
+ half_precision: bool = False,
+ tile: int = 0,
+ tile_pad: int = 10,
+ pre_pad: int = 0,
+) -> list[PIL.Image.Image]:
+ # check model
+ if model_name == "RealESRGAN_x4plus":
+ upscale_model = RRDBNet(
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_block=23,
+ num_grow_ch=32,
+ scale=4,
+ )
+ netscale = 4
+ file_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
+ elif model_name == "RealESRNet_x4plus":
+ upscale_model = RRDBNet(
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_block=23,
+ num_grow_ch=32,
+ scale=4,
+ )
+ netscale = 4
+ file_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
+ elif model_name == "RealESRGAN_x4plus_anime_6B":
+ upscale_model = RRDBNet(
+ num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4
+ )
+ netscale = 4
+ file_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
+ elif model_name == "RealESRGAN_x2plus":
+ upscale_model = RRDBNet(
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_block=23,
+ num_grow_ch=32,
+ scale=2,
+ )
+ netscale = 2
+ file_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
+ else:
+ raise NotImplementedError("Model name not supported")
+
+ # download model
+ model_path = download_file(
+ file_url, path="./upscaler-model", progress=False, interrupt_check=False
+ )
+
+ # declare the upscaler
+ upsampler = RealESRGANer(
+ scale=netscale,
+ model_path=os.path.join("./upscaler-model", model_path),
+ dni_weight=None,
+ model=upscale_model,
+ tile=tile,
+ tile_pad=tile_pad,
+ pre_pad=pre_pad,
+ half=half_precision,
+ gpu_id=None,
+ )
+
+ # upscale
+ torch.cuda.empty_cache()
+ upscaled_imgs = []
+ with tqdm(total=len(img_list)) as pb:
+        for img in img_list:
+ img = np.array(img)
+ outscale_list = factorize(scale_factor, netscale)
+ with contextlib.redirect_stdout(StringIO()):
+ for outscale in outscale_list:
+ curr_img = upsampler.enhance(img, outscale=outscale)[0]
+ img = curr_img
+ upscaled_imgs.append(PIL.Image.fromarray(img))
+
+ pb.update(1)
+ torch.cuda.empty_cache()
+
+ return upscaled_imgs
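+
+# Minimal usage sketch (assumes a CUDA-capable GPU; the file names and scale
+# factor below are illustrative):
+#
+#   from PIL import Image
+#   upscaled = upscale([Image.open("input.png")], scale_factor=2)
+#   upscaled[0].save("output.png")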
diff --git a/stable_diffusion_interactive_notebook_enhanced.ipynb b/stable_diffusion_interactive_notebook_enhanced.ipynb
new file mode 100644
index 0000000..9aed231
--- /dev/null
+++ b/stable_diffusion_interactive_notebook_enhanced.ipynb
@@ -0,0 +1,455 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wILzWiWRfTX8"
+ },
+ "source": [
+ "# Stable Diffusion Interactive Notebook 📓 🤖\n",
+ "\n",
+ "A widgets-based interactive notebook for Google Colab that lets users generate AI images from prompts (Text2Image) using [Stable Diffusion (by Stability AI, Runway & CompVis)](https://en.wikipedia.org/wiki/Stable_Diffusion).\n",
+ "\n",
+ "This notebook aims to be an alternative to WebUIs while offering a simple and lightweight GUI for anyone to get started with Stable Diffusion.\n",
+ "\n",
+ "Uses Stable Diffusion, [HuggingFace](https://huggingface.co/) Diffusers and [Jupyter widgets](https://github.com/jupyter-widgets/ipywidgets).\n",
+ "\n",
+ "
\n",
+ "\n",
+ "Made with ❤️ by Nemo1166, origin by [redromnon](https://github.com/redromnon/stable-diffusion-interactive-notebook)\n",
+ "\n",
+ "To-do:\n",
+ "- Model settings\n",
+ " - [x] Add more sampler (scheduler)\n",
+ " - [x] Add support for loading model from single file (\\*.ckpt / \\*.safetensors)\n",
+ " - [ ] Add VAE selection\n",
+ " - [ ] Add support for loading LoRA(s)\n",
+ "- Text2image\n",
+ " - [x] Add CLIP skip\n",
+ "- Others\n",
+ " - [ ] Add Hires.fix (txt2img -> upscale -> img2img)\n",
+ " - [ ] Colab form-format input"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vCR176NNfn0o"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 👇 Installing dependencies { display-mode: \"form\" }\n",
+ "#@markdown ---\n",
+ "#@markdown Make sure to select **GPU** as the runtime type:
\n",
+ "#@markdown *Runtime->Change Runtime Type->Under Hardware accelerator, select GPU*\n",
+ "#@markdown\n",
+ "#@markdown ---\n",
+ "\n",
+ "%pip -q install diffusers transformers accelerate scipy safetensors xformers mediapy ipywidgets==8.1.1 omegaconf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "CV_UTS40oD1k"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 👇 Selecting Model { form-width: \"20%\", display-mode: \"form\" }\n",
+ "#@markdown ---\n",
+ "#@markdown - **Select Model** - A list of Stable Diffusion models to choose from.\n",
+ "#@markdown - **Link** - Load model from Stable-diffusion single file (support .ckpt or .safetensors).\n",
+ "#@markdown - **Selector** - Choose a pre-trained diffusers model.\n",
+ "#@markdown - **Safety Checker** - Enable/Disable uncensored content\n",
+ "#@markdown\n",
+ "#@markdown ---\n",
+ "\n",
+ "from diffusers import StableDiffusionPipeline\n",
+ "from diffusers.models import AutoencoderKL\n",
+ "import torch\n",
+ "import ipywidgets as widgets\n",
+ "import importlib\n",
+ "\n",
+ "#Enable third party widget support\n",
+ "from google.colab import output\n",
+ "output.enable_custom_widget_manager()\n",
+ "\n",
+ "#Pipe\n",
+ "pipe = None\n",
+ "\n",
+ "# Model source selector\n",
+ "model_src_switcher = widgets.Button(\n",
+ " description=\"Switch model source\",\n",
+ " button_style=\"info\"\n",
+ ")\n",
+ "\n",
+ "# Single model file src\n",
+ "model_src = widgets.Textarea(\n",
+ " value=\"https://huggingface.co/nemo1166/cetusmix-wf2/blob/main/cetusMix_Whalefall2.safetensors\",\n",
+ " placeholder=\"HTTP link to single file\",\n",
+ " description=\"Model link:\",\n",
+ " rows=1,\n",
+ " layout=widgets.Layout(width=\"600px\", visibility='hidden')\n",
+ ")\n",
+ "\n",
+ "#Models\n",
+ "select_model = widgets.Dropdown(\n",
+ " options=[\n",
+ " (\"Stable Diffusion 2.1 Base\" , \"stabilityai/stable-diffusion-2-1-base\"),\n",
+ " (\"Stable Diffusion 2.1\" , \"stabilityai/stable-diffusion-2-1\"),\n",
+ " (\"Stable Diffusion 1.5\", \"runwayml/stable-diffusion-v1-5\"),\n",
+ " (\"Dreamlike Photoreal 2.0\" , \"dreamlike-art/dreamlike-photoreal-2.0\"),\n",
+ " (\"OpenJourney v4\" , \"prompthero/openjourney-v4\")\n",
+ " ],\n",
+ " description=\"Select Model:\"\n",
+ ")\n",
+ "\n",
+ "#Safety Checker\n",
+ "safety_check = widgets.Checkbox(\n",
+ " value=True,\n",
+ " description=\"Enable Safety Check\",\n",
+ " layout=widgets.Layout(margin=\"0px 0px 0px -85px\")\n",
+ ")\n",
+ "\n",
+ "#Output\n",
+ "out = widgets.Output()\n",
+ "\n",
+ "#Apply Settings\n",
+ "apply_btn = widgets.Button(\n",
+ " description=\"Apply\",\n",
+ " button_style=\"info\"\n",
+ ")\n",
+ "\n",
+ "# Switch model loading method\n",
+ "use_SFM = False\n",
+ "def switch_model_src(p):\n",
+ " global use_SFM\n",
+ " out.clear_output()\n",
+ "\n",
+ " with out:\n",
+ " use_SFM = not use_SFM\n",
+ " if use_SFM:\n",
+ " select_model.layout.visibility = 'hidden'\n",
+ " model_src.layout.visibility = 'visible'\n",
+ " print(\"Mode: Loading model from single-file stable-diffusion model.\")\n",
+ " else:\n",
+ " select_model.layout.visibility = 'visible'\n",
+ " model_src.layout.visibility = 'hidden'\n",
+ " print(\"Mode: Loading model from pre-trained diffusers.\")\n",
+ "\n",
+ "#Run pipeline\n",
+ "def pipeline(p):\n",
+ "\n",
+ " global pipe\n",
+ " global use_SFM\n",
+ "\n",
+ " out.clear_output()\n",
+ " apply_btn.disabled = True\n",
+ "\n",
+ " with out:\n",
+ "\n",
+ " print(\"Running, please wait...\")\n",
+ "\n",
+ " if use_SFM:\n",
+ " try:\n",
+ " pipe = StableDiffusionPipeline.from_single_file(\n",
+ " model_src.value,\n",
+ " torch_dtype=torch.float16,\n",
+ " vae=AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16).to(\"cuda\")\n",
+ " ).to(\"cuda\")\n",
+ " except:\n",
+ " raise Exception(\"Invalid model link.\")\n",
+ " else:\n",
+ " pipe = StableDiffusionPipeline.from_pretrained(\n",
+ " select_model.value,\n",
+ " torch_dtype=torch.float16,\n",
+ " vae=AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16).to(\"cuda\")\n",
+ " ).to(\"cuda\")\n",
+ "\n",
+ " if not safety_check.value:\n",
+ " pipe.safety_checker = None\n",
+ "\n",
+ " # Optimization trickcs\n",
+ " pipe.enable_xformers_memory_efficient_attention()\n",
+ " pipe.enable_vae_tiling()\n",
+ "\n",
+ " print(\"Finished!\")\n",
+ "\n",
+ " apply_btn.disabled = False\n",
+ "\n",
+ "\n",
+ "#Display\n",
+ "apply_btn.on_click(pipeline)\n",
+ "model_src_switcher.on_click(switch_model_src)\n",
+ "\n",
+ "widgets.VBox(\n",
+ " [\n",
+ " widgets.HTML(value=\"Configure Pipeline
\"), model_src_switcher, model_src,\n",
+ " select_model, safety_check, apply_btn, out\n",
+ " ]\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-BjkIMgOnplH"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 👇 Model customisation { display-mode: \"form\" }\n",
+ "#@markdown ---\n",
+ "#@markdown - **Select Sampler** - A list of schedulers to choose from. Default is EulerAncestralScheduler.\n",
+ "#@markdown\n",
+ "#@markdown ---\n",
+ "\n",
+ "#Schedulers\n",
+ "select_sampler = widgets.RadioButtons(\n",
+ " options=[\n",
+ " 'Euler a', 'Euler', 'UniPC', 'DDIM', 'DDPM', 'PNDM', 'LMS',\n",
+ " 'DPM++2M Karras', 'DPM++2S', 'DPM2', 'DPM2 a',\n",
+ " 'Heun', 'UniPC',\n",
+ " ],\n",
+ " value='Euler a',\n",
+ " description=\"Select Schedular:\",\n",
+ " orientation='horizontal',\n",
+ " layout=widgets.Layout()\n",
+ ")\n",
+ "select_sampler.style.description_width = \"auto\"\n",
+ "\n",
+ "# button\n",
+ "submit = widgets.Button(\n",
+ " description=\"Submit\",\n",
+ " button_style=\"info\"\n",
+ ")\n",
+ "\n",
+ "#Output\n",
+ "customization = widgets.Output()\n",
+ "\n",
+ "#Get scheduler\n",
+ "def set_scheduler(name):\n",
+ " submit.disabled = True\n",
+ " global pipe\n",
+ " customization.clear_output()\n",
+ "\n",
+ " SCHEDULER_MAP ={\n",
+ " 'PNDM':'PNDMScheduler',\n",
+ " 'LMS':'LMSDiscretescheduler',\n",
+ " 'DDIM':'DDIMScheduler',\n",
+ " 'DDPM':'DDPMScheduler',\n",
+ " 'DPM++2M Karras':'DPMSolverMultistepScheduler',\n",
+ " 'DPM++2S':'DPMSolversinglestepscheduler',\n",
+ " 'DPM2':'KDPM2DiscreteScheduler',\n",
+ " 'DPM2 a':'KDPM2AncestralDiscretescheduler',\n",
+ " 'Euler':'EulerDiscreteScheduler',\n",
+ " 'Euler a':'EulerAncestralDiscreteScheduler',\n",
+ " 'Heun':'HeunDiscreteScheduler',\n",
+ " 'UniPC':'UniPCMultistepScheduler',\n",
+ " }\n",
+ "\n",
+ " with customization:\n",
+ " try:\n",
+ " selected_scheduler = SCHEDULER_MAP[select_sampler.value]\n",
+ " print(f\"Using Sampler: {select_sampler.value} ({selected_scheduler}).\")\n",
+ " exec(f\"from diffusers import {selected_scheduler}\")\n",
+ " pipe.scheduler = eval(f\"{selected_scheduler}.from_config(pipe.scheduler.config)\")\n",
+ " print(\"Done.\")\n",
+ " submit.disabled = False\n",
+ " except:\n",
+ " raise Exception(\"Sampler unavailable.\")\n",
+ "\n",
+ "submit.on_click(set_scheduler)\n",
+ "\n",
+ "widgets.VBox(\n",
+ " [\n",
+ " widgets.HTML(value=\"Customize Pipeline
\"),\n",
+ " select_sampler, submit, customization\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "atmx0PNQ78Wa"
+ },
+ "outputs": [],
+ "source": [
+ "#@title 👇 Generating Images { form-width: \"20%\", display-mode: \"form\" }\n",
+ "#@markdown ---\n",
+ "#@markdown - **Prompt** - Description of the image\n",
+ "#@markdown - **Negative Prompt** - Things you don't want to see or ignore in the image\n",
+ "#@markdown - **Steps** - Number of denoising steps. Higher steps may lead to better results but takes longer time to generate the image. Default is `30`.\n",
+ "#@markdown - **CFG** - Guidance scale ranging from `0` to `20`. Lower values allow the AI to be more creative and less strict at following the prompt. Default is `7.5`.\n",
+ "#@markdown - **Seed** - A random value that controls image generation. The same seed and prompt produce the same images. Set `-1` for using random seed values.\n",
+ "#@markdown ---\n",
+ "import ipywidgets as widgets, mediapy, random\n",
+ "import IPython.display\n",
+ "\n",
+ "\n",
+ "#PARAMETER WIDGETS\n",
+ "width = \"300px\"\n",
+ "\n",
+ "prompt = widgets.Textarea(\n",
+ " value=\"best quality, highly detailed, masterpiece, ultra-detailed, illustration\",\n",
+ " placeholder=\"Enter prompt\",\n",
+ " #description=\"Prompt:\",\n",
+ " rows=5,\n",
+ " layout=widgets.Layout(width=\"600px\")\n",
+ ")\n",
+ "\n",
+ "neg_prompt = widgets.Textarea(\n",
+ " value=\"lowres,bad anatomy,bad hands,text,error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,jpeg artifacts,signature,watermark,username,missing arms\",\n",
+ " placeholder=\"Enter negative prompt\",\n",
+ " #description=\"Negative Prompt:\",\n",
+ " rows=5,\n",
+ " layout=widgets.Layout(width=\"600px\")\n",
+ ")\n",
+ "\n",
+ "num_images = widgets.IntText(\n",
+ " value=1,\n",
+ " description=\"Images:\",\n",
+ " layout=widgets.Layout(width=width),\n",
+ ")\n",
+ "\n",
+ "clip_skip = widgets.IntSlider(\n",
+ " value=2,\n",
+ " min=0,\n",
+ " max=11,\n",
+ " step=1,\n",
+ " description=\"CLIP skip:\",\n",
+ " orientation='horizontal',\n",
+ " readout=True,\n",
+ " readout_format='d',\n",
+ " layout=widgets.Layout(width=width),\n",
+ ")\n",
+ "\n",
+ "steps = widgets.IntText(\n",
+ " value=20,\n",
+ " description=\"Steps:\",\n",
+ " layout=widgets.Layout(width=width)\n",
+ ")\n",
+ "\n",
+ "CFG = widgets.FloatText(\n",
+ " value=7.5,\n",
+ " description=\"CFG:\",\n",
+ " layout=widgets.Layout(width=width)\n",
+ ")\n",
+ "\n",
+ "img_height = widgets.IntSlider(\n",
+ " min=0,\n",
+ " max=1024,\n",
+ " step=8,\n",
+ " value=512,\n",
+ " description=\"Height:\",\n",
+ " orientation='horizontal',\n",
+ " readout=True,\n",
+ " readout_format='d',\n",
+ " layout=widgets.Layout(width=width)\n",
+ ")\n",
+ "\n",
+ "img_width = widgets.IntSlider(\n",
+ " min=0,\n",
+ " max=1024,\n",
+ " step=8,\n",
+ " value=512,\n",
+ " description=\"Width:\",\n",
+ " orientation='horizontal',\n",
+ " readout=True,\n",
+ " readout_format='d',\n",
+ " layout=widgets.Layout(width=width)\n",
+ ")\n",
+ "\n",
+ "random_seed = widgets.IntText(\n",
+ " value=-1,\n",
+ " description=\"Seed:\",\n",
+ " layout=widgets.Layout(width=width),\n",
+ " disabled=False\n",
+ ")\n",
+ "\n",
+ "generate = widgets.Button(\n",
+ " description=\"Generate\",\n",
+ " disabled=False,\n",
+ " button_style=\"primary\"\n",
+ ")\n",
+ "\n",
+ "display_imgs = widgets.Output()\n",
+ "\n",
+ "\n",
+ "#RUN\n",
+ "def generate_img(i):\n",
+ "\n",
+ " #Clear output\n",
+ " display_imgs.clear_output()\n",
+ " generate.disabled = True\n",
+ "\n",
+ " #Calculate seed\n",
+ " seed = random.randint(0, 2147483647) if random_seed.value == -1 else random_seed.value\n",
+ "\n",
+ " with display_imgs:\n",
+ "\n",
+ " print(\"Running...\")\n",
+ "\n",
+ " images = pipe(\n",
+ " prompt.value,\n",
+ " height = img_height.value,\n",
+ " width = img_width.value,\n",
+ " num_inference_steps = steps.value,\n",
+ " guidance_scale = CFG.value,\n",
+ " num_images_per_prompt = num_images.value,\n",
+ " negative_prompt = neg_prompt.value,\n",
+ " generator = torch.Generator(\"cuda\").manual_seed(seed),\n",
+ " clip_skip = clip_skip.value,\n",
+ " ).images\n",
+ " mediapy.show_images(images)\n",
+ "\n",
+ " print(f\"Seed:\\n{seed}\")\n",
+ "\n",
+ " generate.disabled = False\n",
+ "\n",
+ "#Display\n",
+ "generate.on_click(generate_img)\n",
+ "\n",
+ "widgets.VBox(\n",
+ " [\n",
+ " widgets.AppLayout(\n",
+ " header=widgets.HTML(\n",
+ " value=\"Stable Diffusion
\",\n",
+ " ),\n",
+ " left_sidebar=widgets.VBox(\n",
+ " [num_images, steps, CFG, img_height, img_width, clip_skip, random_seed]\n",
+ " ),\n",
+ " center=widgets.VBox(\n",
+ " [prompt, neg_prompt, generate]\n",
+ " ),\n",
+ " right_sidebar=None,\n",
+ " footer=None\n",
+ " ),\n",
+ " display_imgs\n",
+ " ]\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "private_outputs": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}