diff --git a/src/upscaler.py b/src/upscaler.py new file mode 100644 index 0000000..fa74165 --- /dev/null +++ b/src/upscaler.py @@ -0,0 +1,249 @@ +from realesrgan import RealESRGANer +from basicsr.archs.rrdbnet_arch import RRDBNet +import torch +import PIL +import numpy as np +import contextlib +from io import StringIO +from tqdm.auto import tqdm +import signal +import requests +import urllib.request +import urllib.parse +import os +import re + +def download_file( + link: str, + path: str, + block_size: int = 1024, + force_download: bool = False, + progress: bool = True, + interrupt_check: bool = True, +) -> str: + def truncate_string(string: str, length: int): + length -= 5 if length - 5 > 0 else 0 + curr_len = len(string) + new_len = len(string[: length // 2] + "(...)" + string[-length // 2 :]) + if new_len > curr_len: + return string + else: + return string[: length // 2] + "(...)" + string[-length // 2 :] + + def remove_char(string: str, chars: list): + for char in chars: + string = string.replace(char, "") + return string + + # source: https://github.com/wkentaro/gdown/blob/main/gdown/download.py + def google_drive_parse_url(url: str): + parsed = urllib.parse.urlparse(url) + query = urllib.parse.parse_qs(parsed.query) + is_gdrive = parsed.hostname in ["drive.google.com", "docs.google.com"] + is_download_link = parsed.path.endswith("/uc") + + if not is_gdrive: + return is_gdrive, is_download_link + + file_id = None + if "id" in query: + file_ids = query["id"] + if len(file_ids) == 1: + file_id = file_ids[0] + else: + patterns = [r"^/file/d/(.*?)/view$", r"^/presentation/d/(.*?)/edit$"] + for pattern in patterns: + match = re.match(pattern, parsed.path) + if match: + file_id = match.groups()[0] + break + + return file_id, is_download_link + + # source: https://github.com/wkentaro/gdown/blob/main/gdown/download.py + def get_url_from_gdrive_confirmation(contents: str): + url = "" + for line in contents.splitlines(): + m = re.search(r'href="(/uc\?export=download[^"]+)', line) + if m: + url = "https://docs.google.com" + m.groups()[0] + url = url.replace("&", "&") + break + m = re.search('id="download-form" action="(.+?)"', line) + if m: + url = m.groups()[0] + url = url.replace("&", "&") + break + m = re.search('"downloadUrl":"([^"]+)', line) + if m: + url = m.groups()[0] + url = url.replace("\\u003d", "=") + url = url.replace("\\u0026", "&") + break + m = re.search('

(.*)

', line) + if m: + error = m.groups()[0] + raise RuntimeError(error) + if not url: + raise RuntimeError("Cannot retrieve the link of the file. ") + return url + + def interrupt(*args): + if os.path.isfile(filepath): + os.remove(filepath) + raise KeyboardInterrupt + + # create folder if not exists + if not os.path.exists(path): + os.makedirs(path) + + # check if link is google drive link + if not google_drive_parse_url(link)[0]: + response = requests.get(link, stream=True, allow_redirects=True) + else: + if not google_drive_parse_url(link)[1]: + # convert to direct link + file_id = google_drive_parse_url(link)[0] + link = f"https://drive.google.com/uc?id={file_id}" + # test if redirect is needed + response = requests.get(link, stream=True, allow_redirects=True) + if response.headers.get("Content-Disposition") is None: + page = urllib.request.urlopen(link) + link = get_url_from_gdrive_confirmation(str(page.read())) + response = requests.get(link, stream=True, allow_redirects=True) + + if response.status_code == 404: + raise FileNotFoundError(f"File not found at {link}") + + # get filename + content_disposition = response.headers.get("Content-Disposition") + if content_disposition: + filename = re.findall(r"filename=(.*?)(?:[;\n]|$)", content_disposition)[0] + else: + filename = os.path.basename(link) + + filename = remove_char( + filename, ["/", "\\", ":", "*", "?", '"', "'", "<", ">", "|", ";"] + ) + filename = filename.replace(" ", "_") + + filepath = os.path.join(path, filename) + + # download file + if os.path.isfile(filepath) and not force_download: + print(f"{filename} already exists. Skipping download.") + else: + text = f"Downloading {truncate_string(filename, 50)}" + with open(filepath, "wb") as file: + total_size = int(response.headers.get("content-length", 0)) + with tqdm( + total=total_size, + unit="B", + unit_scale=True, + desc=text, + unit_divisor=1024, + disable=not progress, + ) as pb: + if interrupt_check: + signal.signal(signal.SIGINT, lambda signum, frame: interrupt()) + for data in response.iter_content(block_size): + pb.update(len(data)) + file.write(data) + del response + return filename + +def factorize(num: int, max_value: int) -> list[float]: + result = [] + while num > max_value: + result.append(max_value) + num /= max_value + result.append(round(num, 4)) + return result + +def upscale( + img_list: list[PIL.Image.Image], + model_name: str = "RealESRGAN_x4plus_anime_6B", + scale_factor: float = 4, + half_precision: bool = False, + tile: int = 0, + tile_pad: int = 10, + pre_pad: int = 0, +) -> list[PIL.Image.Image]: + # check model + if model_name == "RealESRGAN_x4plus": + upscale_model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ) + netscale = 4 + file_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth" + elif model_name == "RealESRNet_x4plus": + upscale_model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ) + netscale = 4 + file_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth" + elif model_name == "RealESRGAN_x4plus_anime_6B": + upscale_model = RRDBNet( + num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4 + ) + netscale = 4 + file_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth" + elif model_name == "RealESRGAN_x2plus": + upscale_model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + netscale = 2 + file_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth" + else: + raise NotImplementedError("Model name not supported") + + # download model + model_path = download_file( + file_url, path="./upscaler-model", progress=False, interrupt_check=False + ) + + # declare the upscaler + upsampler = RealESRGANer( + scale=netscale, + model_path=os.path.join("./upscaler-model", model_path), + dni_weight=None, + model=upscale_model, + tile=tile, + tile_pad=tile_pad, + pre_pad=pre_pad, + half=half_precision, + gpu_id=None, + ) + + # upscale + torch.cuda.empty_cache() + upscaled_imgs = [] + with tqdm(total=len(img_list)) as pb: + for i, img in enumerate(img_list): + img = np.array(img) + outscale_list = factorize(scale_factor, netscale) + with contextlib.redirect_stdout(StringIO()): + for outscale in outscale_list: + curr_img = upsampler.enhance(img, outscale=outscale)[0] + img = curr_img + upscaled_imgs.append(PIL.Image.fromarray(img)) + + pb.update(1) + torch.cuda.empty_cache() + + return upscaled_imgs diff --git a/stable_diffusion_interactive_notebook_enhanced.ipynb b/stable_diffusion_interactive_notebook_enhanced.ipynb new file mode 100644 index 0000000..9aed231 --- /dev/null +++ b/stable_diffusion_interactive_notebook_enhanced.ipynb @@ -0,0 +1,455 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "wILzWiWRfTX8" + }, + "source": [ + "# Stable Diffusion Interactive Notebook 📓 🤖\n", + "\n", + "A widgets-based interactive notebook for Google Colab that lets users generate AI images from prompts (Text2Image) using [Stable Diffusion (by Stability AI, Runway & CompVis)](https://en.wikipedia.org/wiki/Stable_Diffusion).\n", + "\n", + "This notebook aims to be an alternative to WebUIs while offering a simple and lightweight GUI for anyone to get started with Stable Diffusion.\n", + "\n", + "Uses Stable Diffusion, [HuggingFace](https://huggingface.co/) Diffusers and [Jupyter widgets](https://github.com/jupyter-widgets/ipywidgets).\n", + "\n", + "
\n", + "\n", + "Made with ❤️ by Nemo1166, origin by [redromnon](https://github.com/redromnon/stable-diffusion-interactive-notebook)\n", + "\n", + "To-do:\n", + "- Model settings\n", + " - [x] Add more sampler (scheduler)\n", + " - [x] Add support for loading model from single file (\\*.ckpt / \\*.safetensors)\n", + " - [ ] Add VAE selection\n", + " - [ ] Add support for loading LoRA(s)\n", + "- Text2image\n", + " - [x] Add CLIP skip\n", + "- Others\n", + " - [ ] Add Hires.fix (txt2img -> upscale -> img2img)\n", + " - [ ] Colab form-format input" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vCR176NNfn0o" + }, + "outputs": [], + "source": [ + "#@title 👇 Installing dependencies { display-mode: \"form\" }\n", + "#@markdown ---\n", + "#@markdown Make sure to select **GPU** as the runtime type:
\n", + "#@markdown *Runtime->Change Runtime Type->Under Hardware accelerator, select GPU*\n", + "#@markdown\n", + "#@markdown ---\n", + "\n", + "%pip -q install diffusers transformers accelerate scipy safetensors xformers mediapy ipywidgets==8.1.1 omegaconf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CV_UTS40oD1k" + }, + "outputs": [], + "source": [ + "#@title 👇 Selecting Model { form-width: \"20%\", display-mode: \"form\" }\n", + "#@markdown ---\n", + "#@markdown - **Select Model** - A list of Stable Diffusion models to choose from.\n", + "#@markdown - **Link** - Load model from Stable-diffusion single file (support .ckpt or .safetensors).\n", + "#@markdown - **Selector** - Choose a pre-trained diffusers model.\n", + "#@markdown - **Safety Checker** - Enable/Disable uncensored content\n", + "#@markdown\n", + "#@markdown ---\n", + "\n", + "from diffusers import StableDiffusionPipeline\n", + "from diffusers.models import AutoencoderKL\n", + "import torch\n", + "import ipywidgets as widgets\n", + "import importlib\n", + "\n", + "#Enable third party widget support\n", + "from google.colab import output\n", + "output.enable_custom_widget_manager()\n", + "\n", + "#Pipe\n", + "pipe = None\n", + "\n", + "# Model source selector\n", + "model_src_switcher = widgets.Button(\n", + " description=\"Switch model source\",\n", + " button_style=\"info\"\n", + ")\n", + "\n", + "# Single model file src\n", + "model_src = widgets.Textarea(\n", + " value=\"https://huggingface.co/nemo1166/cetusmix-wf2/blob/main/cetusMix_Whalefall2.safetensors\",\n", + " placeholder=\"HTTP link to single file\",\n", + " description=\"Model link:\",\n", + " rows=1,\n", + " layout=widgets.Layout(width=\"600px\", visibility='hidden')\n", + ")\n", + "\n", + "#Models\n", + "select_model = widgets.Dropdown(\n", + " options=[\n", + " (\"Stable Diffusion 2.1 Base\" , \"stabilityai/stable-diffusion-2-1-base\"),\n", + " (\"Stable Diffusion 2.1\" , \"stabilityai/stable-diffusion-2-1\"),\n", + " (\"Stable Diffusion 1.5\", \"runwayml/stable-diffusion-v1-5\"),\n", + " (\"Dreamlike Photoreal 2.0\" , \"dreamlike-art/dreamlike-photoreal-2.0\"),\n", + " (\"OpenJourney v4\" , \"prompthero/openjourney-v4\")\n", + " ],\n", + " description=\"Select Model:\"\n", + ")\n", + "\n", + "#Safety Checker\n", + "safety_check = widgets.Checkbox(\n", + " value=True,\n", + " description=\"Enable Safety Check\",\n", + " layout=widgets.Layout(margin=\"0px 0px 0px -85px\")\n", + ")\n", + "\n", + "#Output\n", + "out = widgets.Output()\n", + "\n", + "#Apply Settings\n", + "apply_btn = widgets.Button(\n", + " description=\"Apply\",\n", + " button_style=\"info\"\n", + ")\n", + "\n", + "# Switch model loading method\n", + "use_SFM = False\n", + "def switch_model_src(p):\n", + " global use_SFM\n", + " out.clear_output()\n", + "\n", + " with out:\n", + " use_SFM = not use_SFM\n", + " if use_SFM:\n", + " select_model.layout.visibility = 'hidden'\n", + " model_src.layout.visibility = 'visible'\n", + " print(\"Mode: Loading model from single-file stable-diffusion model.\")\n", + " else:\n", + " select_model.layout.visibility = 'visible'\n", + " model_src.layout.visibility = 'hidden'\n", + " print(\"Mode: Loading model from pre-trained diffusers.\")\n", + "\n", + "#Run pipeline\n", + "def pipeline(p):\n", + "\n", + " global pipe\n", + " global use_SFM\n", + "\n", + " out.clear_output()\n", + " apply_btn.disabled = True\n", + "\n", + " with out:\n", + "\n", + " print(\"Running, please wait...\")\n", + "\n", + " if use_SFM:\n", + " try:\n", + " pipe = StableDiffusionPipeline.from_single_file(\n", + " model_src.value,\n", + " torch_dtype=torch.float16,\n", + " vae=AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16).to(\"cuda\")\n", + " ).to(\"cuda\")\n", + " except:\n", + " raise Exception(\"Invalid model link.\")\n", + " else:\n", + " pipe = StableDiffusionPipeline.from_pretrained(\n", + " select_model.value,\n", + " torch_dtype=torch.float16,\n", + " vae=AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16).to(\"cuda\")\n", + " ).to(\"cuda\")\n", + "\n", + " if not safety_check.value:\n", + " pipe.safety_checker = None\n", + "\n", + " # Optimization trickcs\n", + " pipe.enable_xformers_memory_efficient_attention()\n", + " pipe.enable_vae_tiling()\n", + "\n", + " print(\"Finished!\")\n", + "\n", + " apply_btn.disabled = False\n", + "\n", + "\n", + "#Display\n", + "apply_btn.on_click(pipeline)\n", + "model_src_switcher.on_click(switch_model_src)\n", + "\n", + "widgets.VBox(\n", + " [\n", + " widgets.HTML(value=\"

Configure Pipeline

\"), model_src_switcher, model_src,\n", + " select_model, safety_check, apply_btn, out\n", + " ]\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-BjkIMgOnplH" + }, + "outputs": [], + "source": [ + "#@title 👇 Model customisation { display-mode: \"form\" }\n", + "#@markdown ---\n", + "#@markdown - **Select Sampler** - A list of schedulers to choose from. Default is EulerAncestralScheduler.\n", + "#@markdown\n", + "#@markdown ---\n", + "\n", + "#Schedulers\n", + "select_sampler = widgets.RadioButtons(\n", + " options=[\n", + " 'Euler a', 'Euler', 'UniPC', 'DDIM', 'DDPM', 'PNDM', 'LMS',\n", + " 'DPM++2M Karras', 'DPM++2S', 'DPM2', 'DPM2 a',\n", + " 'Heun', 'UniPC',\n", + " ],\n", + " value='Euler a',\n", + " description=\"Select Schedular:\",\n", + " orientation='horizontal',\n", + " layout=widgets.Layout()\n", + ")\n", + "select_sampler.style.description_width = \"auto\"\n", + "\n", + "# button\n", + "submit = widgets.Button(\n", + " description=\"Submit\",\n", + " button_style=\"info\"\n", + ")\n", + "\n", + "#Output\n", + "customization = widgets.Output()\n", + "\n", + "#Get scheduler\n", + "def set_scheduler(name):\n", + " submit.disabled = True\n", + " global pipe\n", + " customization.clear_output()\n", + "\n", + " SCHEDULER_MAP ={\n", + " 'PNDM':'PNDMScheduler',\n", + " 'LMS':'LMSDiscretescheduler',\n", + " 'DDIM':'DDIMScheduler',\n", + " 'DDPM':'DDPMScheduler',\n", + " 'DPM++2M Karras':'DPMSolverMultistepScheduler',\n", + " 'DPM++2S':'DPMSolversinglestepscheduler',\n", + " 'DPM2':'KDPM2DiscreteScheduler',\n", + " 'DPM2 a':'KDPM2AncestralDiscretescheduler',\n", + " 'Euler':'EulerDiscreteScheduler',\n", + " 'Euler a':'EulerAncestralDiscreteScheduler',\n", + " 'Heun':'HeunDiscreteScheduler',\n", + " 'UniPC':'UniPCMultistepScheduler',\n", + " }\n", + "\n", + " with customization:\n", + " try:\n", + " selected_scheduler = SCHEDULER_MAP[select_sampler.value]\n", + " print(f\"Using Sampler: {select_sampler.value} ({selected_scheduler}).\")\n", + " exec(f\"from diffusers import {selected_scheduler}\")\n", + " pipe.scheduler = eval(f\"{selected_scheduler}.from_config(pipe.scheduler.config)\")\n", + " print(\"Done.\")\n", + " submit.disabled = False\n", + " except:\n", + " raise Exception(\"Sampler unavailable.\")\n", + "\n", + "submit.on_click(set_scheduler)\n", + "\n", + "widgets.VBox(\n", + " [\n", + " widgets.HTML(value=\"

Customize Pipeline

\"),\n", + " select_sampler, submit, customization\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "atmx0PNQ78Wa" + }, + "outputs": [], + "source": [ + "#@title 👇 Generating Images { form-width: \"20%\", display-mode: \"form\" }\n", + "#@markdown ---\n", + "#@markdown - **Prompt** - Description of the image\n", + "#@markdown - **Negative Prompt** - Things you don't want to see or ignore in the image\n", + "#@markdown - **Steps** - Number of denoising steps. Higher steps may lead to better results but takes longer time to generate the image. Default is `30`.\n", + "#@markdown - **CFG** - Guidance scale ranging from `0` to `20`. Lower values allow the AI to be more creative and less strict at following the prompt. Default is `7.5`.\n", + "#@markdown - **Seed** - A random value that controls image generation. The same seed and prompt produce the same images. Set `-1` for using random seed values.\n", + "#@markdown ---\n", + "import ipywidgets as widgets, mediapy, random\n", + "import IPython.display\n", + "\n", + "\n", + "#PARAMETER WIDGETS\n", + "width = \"300px\"\n", + "\n", + "prompt = widgets.Textarea(\n", + " value=\"best quality, highly detailed, masterpiece, ultra-detailed, illustration\",\n", + " placeholder=\"Enter prompt\",\n", + " #description=\"Prompt:\",\n", + " rows=5,\n", + " layout=widgets.Layout(width=\"600px\")\n", + ")\n", + "\n", + "neg_prompt = widgets.Textarea(\n", + " value=\"lowres,bad anatomy,bad hands,text,error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,jpeg artifacts,signature,watermark,username,missing arms\",\n", + " placeholder=\"Enter negative prompt\",\n", + " #description=\"Negative Prompt:\",\n", + " rows=5,\n", + " layout=widgets.Layout(width=\"600px\")\n", + ")\n", + "\n", + "num_images = widgets.IntText(\n", + " value=1,\n", + " description=\"Images:\",\n", + " layout=widgets.Layout(width=width),\n", + ")\n", + "\n", + "clip_skip = widgets.IntSlider(\n", + " value=2,\n", + " min=0,\n", + " max=11,\n", + " step=1,\n", + " description=\"CLIP skip:\",\n", + " orientation='horizontal',\n", + " readout=True,\n", + " readout_format='d',\n", + " layout=widgets.Layout(width=width),\n", + ")\n", + "\n", + "steps = widgets.IntText(\n", + " value=20,\n", + " description=\"Steps:\",\n", + " layout=widgets.Layout(width=width)\n", + ")\n", + "\n", + "CFG = widgets.FloatText(\n", + " value=7.5,\n", + " description=\"CFG:\",\n", + " layout=widgets.Layout(width=width)\n", + ")\n", + "\n", + "img_height = widgets.IntSlider(\n", + " min=0,\n", + " max=1024,\n", + " step=8,\n", + " value=512,\n", + " description=\"Height:\",\n", + " orientation='horizontal',\n", + " readout=True,\n", + " readout_format='d',\n", + " layout=widgets.Layout(width=width)\n", + ")\n", + "\n", + "img_width = widgets.IntSlider(\n", + " min=0,\n", + " max=1024,\n", + " step=8,\n", + " value=512,\n", + " description=\"Width:\",\n", + " orientation='horizontal',\n", + " readout=True,\n", + " readout_format='d',\n", + " layout=widgets.Layout(width=width)\n", + ")\n", + "\n", + "random_seed = widgets.IntText(\n", + " value=-1,\n", + " description=\"Seed:\",\n", + " layout=widgets.Layout(width=width),\n", + " disabled=False\n", + ")\n", + "\n", + "generate = widgets.Button(\n", + " description=\"Generate\",\n", + " disabled=False,\n", + " button_style=\"primary\"\n", + ")\n", + "\n", + "display_imgs = widgets.Output()\n", + "\n", + "\n", + "#RUN\n", + "def generate_img(i):\n", + "\n", + " #Clear output\n", + " display_imgs.clear_output()\n", + " generate.disabled = True\n", + "\n", + " #Calculate seed\n", + " seed = random.randint(0, 2147483647) if random_seed.value == -1 else random_seed.value\n", + "\n", + " with display_imgs:\n", + "\n", + " print(\"Running...\")\n", + "\n", + " images = pipe(\n", + " prompt.value,\n", + " height = img_height.value,\n", + " width = img_width.value,\n", + " num_inference_steps = steps.value,\n", + " guidance_scale = CFG.value,\n", + " num_images_per_prompt = num_images.value,\n", + " negative_prompt = neg_prompt.value,\n", + " generator = torch.Generator(\"cuda\").manual_seed(seed),\n", + " clip_skip = clip_skip.value,\n", + " ).images\n", + " mediapy.show_images(images)\n", + "\n", + " print(f\"Seed:\\n{seed}\")\n", + "\n", + " generate.disabled = False\n", + "\n", + "#Display\n", + "generate.on_click(generate_img)\n", + "\n", + "widgets.VBox(\n", + " [\n", + " widgets.AppLayout(\n", + " header=widgets.HTML(\n", + " value=\"

Stable Diffusion

\",\n", + " ),\n", + " left_sidebar=widgets.VBox(\n", + " [num_images, steps, CFG, img_height, img_width, clip_skip, random_seed]\n", + " ),\n", + " center=widgets.VBox(\n", + " [prompt, neg_prompt, generate]\n", + " ),\n", + " right_sidebar=None,\n", + " footer=None\n", + " ),\n", + " display_imgs\n", + " ]\n", + ")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}