diff --git a/krita_plugin/krita_diff/krita_diff.py b/krita_plugin/krita_diff/krita_diff.py index c8b39baf9..8ec50261b 100644 --- a/krita_plugin/krita_diff/krita_diff.py +++ b/krita_plugin/krita_diff/krita_diff.py @@ -3,15 +3,24 @@ import urllib.parse import urllib.request import json +from urllib import request, parse +import requests from krita import * default_url = "http://127.0.0.1:8000" +# get an OS-sane tmp dir +default_tmp_dir = os.path.join(os.path.expanduser("~"), "tmp") samplers = ["DDIM", "PLMS", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'] samplers_img2img = ["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'] upscalers = ["None", "Lanczos"] face_restorers = ["None", "CodeFormer", "GFPGAN"] +realesrgan_models = ['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'] + +MODE_IMG2IMG = 0 +MODE_INPAINT = 1 +MODE_SD_UPSCALE = 2 class Script(QObject): def __init__(self): @@ -36,6 +45,7 @@ def restore_defaults(self, if_empty=False): self.set_cfg('png_quality', -1, if_empty) self.set_cfg('fix_aspect_ratio', True, if_empty) self.set_cfg('only_full_img_tiling', True, if_empty) + self.set_cfg('tmp_dir', default_tmp_dir, if_empty) self.set_cfg('face_restorer_model', face_restorers.index("CodeFormer"), if_empty) self.set_cfg('codeformer_weight', 0.5, if_empty) @@ -136,7 +146,7 @@ def txt2img(self): def img2img(self, path, mask_path, mode): tiling = self.cfg('txt2img_tiling', bool) - if mode == 2 or (self.cfg("only_full_img_tiling", bool) and self.selection is not None): + if mode == MODE_SD_UPSCALE or (self.cfg("only_full_img_tiling", bool) and self.selection is not None): tiling = False params = { @@ -258,7 +268,16 @@ def image_to_ba(self, image): ptr.setsize(image.byteCount()) return QByteArray(ptr.asstring()) - def insert_img(self, layer_name, path, visible=True): + def insert_img(self, layer_name, qimage, visible=True): + layer = self.create_layer(layer_name) + ba = self.image_to_ba(qimage) + + if not visible: + layer.setVisible(False) + layer.setPixelData(ba, self.x, self.y, self.width, self.height) + print(f"inserted image to layer: {layer}") + + def insert_img_from_path(self, layer_name, path, visible=True): image = QImage() image.load(path, "PNG") ba = self.image_to_ba(image) @@ -273,43 +292,92 @@ def insert_img(self, layer_name, path, visible=True): def apply_txt2img(self): response = self.txt2img() outputs = response['outputs'] - print(f"Getting images: {outputs}") + + print(f'Fetching remote images from server. Images: {outputs}') for i, output in enumerate(outputs): - self.insert_img(f"txt2img {i + 1}: {os.path.basename(output)}", output, i + 1 == len(outputs)) - self.clear_temp_images(outputs) + url = self.cfg('base_url', str) + '/result' + + data = json.dumps({'file_name': output}).encode('utf-8') + req = urllib.request.Request(url, data=data, headers={'content-type': 'application/json'}) + response = urllib.request.urlopen(req) + + image = QImage() + image.loadFromData(response.read(), "PNG") + + layer = self.create_layer(f"txt2img {i + 1}: {output}") + layer.setPixelData(self.image_to_ba(image), self.x, self.y, self.width, self.height) + self.doc.refreshProjection() def apply_img2img(self, mode): - path = self.opt['new_img'] - mask_path = self.opt['new_img_mask'] - self.save_img(path) - if mode == 1: + path = self.cfg('tmp_dir', str) + '/img2img.png' + mask_path = self.cfg('tmp_dir', str) + '/img2img_mask.png' + path_filename = os.path.basename(path) + mask_filename = os.path.basename(mask_path) + + if mode == MODE_INPAINT: + print(f'Saving mask locally: {mask_path}') self.save_img(mask_path, is_mask=True) + self.node.setVisible(False) + self.doc.refreshProjection() + self.save_img(path) + + response = requests.request("POST", self.cfg('base_url', str) + '/saveimg', files=[('file',(path_filename,open(path,'rb'),'image/png'))]) + if response.status_code != 200: + print(f"Error while uploading image to server: {response.status_code}") + return - response = self.img2img(path, mask_path, mode) + if mode == MODE_INPAINT: + response = requests.request("POST", self.cfg('base_url', str) + '/saveimg', files=[('file',(mask_filename,open(mask_path,'rb'),'image/png'))]) + if response.status_code != 200: + print(f"Error while uploading image to server: {response.status_code}") + return + + response = self.img2img(path_filename, mask_filename, mode) outputs = response['outputs'] + print(f"Getting images: {outputs}") - layer_name_prefix = "inpaint" if mode == 1 else "sd upscale" if mode == 2 else "img2img" + layer_name_prefix = "inpaint" if mode == MODE_INPAINT else "sd upscale" if mode == MODE_SD_UPSCALE else "img2img" for i, output in enumerate(outputs): - self.insert_img(f"{layer_name_prefix} {i + 1}: {os.path.basename(output)}", output, i + 1 == len(outputs)) + req = urllib.request.Request(url = self.cfg('base_url', str) + '/result', + data=json.dumps({'file_name': os.path.basename(output)}).encode('utf-8'), + headers={'content-type': 'application/json'}) + response = urllib.request.urlopen(req) + qimage = QImage() + qimage.loadFromData(response.read(), "PNG") + + self.insert_img(f"{layer_name_prefix} {i + 1}: {os.path.basename(output)}", qimage, i + 1 == len(outputs)) - if mode == 1: - self.clear_temp_images([path, mask_path, *outputs]) + if mode == MODE_INPAINT: + self.clear_temp_images([path, mask_path]) else: - self.clear_temp_images([path, *outputs]) + self.clear_temp_images([path]) self.doc.refreshProjection() def apply_simple_upscale(self): - path = self.opt['new_img'] + path = self.cfg('tmp_dir', str) + '/simple_upscale.png' + path_filename = os.path.basename(path) + self.save_img(path) + response = requests.request("POST", self.cfg('base_url', str) + '/saveimg', + files=[('file',(path_filename,open(path,'rb'),'image/png'))]) + if response.status_code != 200: + print(f"Error while uploading image to server: {response.status_code}") + return - response = self.simple_upscale(path) + response = self.simple_upscale(path_filename) output = response['output'] print(f"Getting image: {output}") - - self.insert_img(f"upscale: {os.path.basename(output)}", output) - self.clear_temp_images([path, output]) + req = urllib.request.Request(url = self.cfg('base_url', str) + '/result', + data=json.dumps({'file_name': os.path.basename(output)}).encode('utf-8'), + headers={'content-type': 'application/json'}) + response = urllib.request.urlopen(req) + qimage = QImage() + qimage.loadFromData(response.read(), "PNG") + + self.insert_img(f"upscale: {path_filename}", qimage) + self.clear_temp_images([path]) self.doc.refreshProjection() def create_mask_layer_internal(self): @@ -345,14 +413,14 @@ def action_img2img(self): pass self.update_config() self.try_fix_aspect_ratio() - self.apply_img2img(mode=0) + self.apply_img2img(mode=MODE_IMG2IMG) self.create_mask_layer_workaround() def action_sd_upscale(self): if self.working: pass self.update_config() - self.apply_img2img(mode=2) + self.apply_img2img(mode=MODE_SD_UPSCALE) self.create_mask_layer_workaround() def action_inpaint(self): @@ -360,7 +428,7 @@ def action_inpaint(self): pass self.update_config() self.try_fix_aspect_ratio() - self.apply_img2img(mode=1) + self.apply_img2img(mode=MODE_INPAINT) def action_simple_upscale(self): if self.working: diff --git a/krita_plugin/krita_diff/krita_diff_ui.py b/krita_plugin/krita_diff/krita_diff_ui.py index 3c2bb8c4c..9610a4a5e 100644 --- a/krita_plugin/krita_diff/krita_diff_ui.py +++ b/krita_plugin/krita_diff/krita_diff_ui.py @@ -79,19 +79,13 @@ def create_interface(self): self.widget = QWidget(self) self.widget.setLayout(self.layout) - # TODO: Add necessary UI components to match up with upstream changes. def create_txt2img_interface(self): self.txt2img_prompt_label = QLabel("Prompt:") self.txt2img_prompt_text = QPlainTextEdit() self.txt2img_prompt_text.setPlaceholderText("krita_config.yaml value will be used") - self.txt2img_negative_prompt_label = QLabel("Negative Prompt:") - self.txt2img_negative_prompt_text = QLineEdit() - self.txt2img_negative_prompt_text.setPlaceholderText("krita_config.yaml value will be used") - self.txt2img_prompt_layout = QVBoxLayout() + self.txt2img_prompt_layout = QHBoxLayout() self.txt2img_prompt_layout.addWidget(self.txt2img_prompt_label) self.txt2img_prompt_layout.addWidget(self.txt2img_prompt_text) - self.txt2img_prompt_layout.addWidget(self.txt2img_negative_prompt_label) - self.txt2img_prompt_layout.addWidget(self.txt2img_negative_prompt_text) self.txt2img_sampler_name_label = QLabel("Sampler:") self.txt2img_sampler_name = QComboBox() @@ -158,7 +152,7 @@ def create_txt2img_interface(self): self.txt2img_seed_layout.addWidget(self.txt2img_seed_label) self.txt2img_seed_layout.addWidget(self.txt2img_seed) - self.txt2img_use_gfpgan = QCheckBox("Restore faces") + self.txt2img_use_gfpgan = QCheckBox("Enable GFPGAN (may fix faces)") self.txt2img_use_gfpgan.setTristate(False) self.txt2img_tiling = QCheckBox("Enable tiling mode") @@ -186,7 +180,6 @@ def create_txt2img_interface(self): def init_txt2img_interface(self): self.txt2img_prompt_text.setPlainText(script.cfg('txt2img_prompt', str)) - self.txt2img_negative_prompt_text.setText(script.cfg('txt2img_negative_prompt', str)) self.txt2img_sampler_name.setCurrentIndex(script.cfg('txt2img_sampler', int)) self.txt2img_steps.setValue(script.cfg('txt2img_steps', int)) self.txt2img_cfg_scale.setValue(script.cfg('txt2img_cfg_scale', float)) @@ -204,10 +197,6 @@ def connect_txt2img_interface(self): self.txt2img_prompt_text.textChanged.connect( lambda: script.set_cfg("txt2img_prompt", self.txt2img_prompt_text.toPlainText()) ) - self.txt2img_negative_prompt_text.textChanged.connect( - lambda: script.set_cfg("txt2img_negative_prompt", - re.sub(r'\n', ', ', self.txt2img_negative_prompt_text.text())) - ) self.txt2img_sampler_name.currentIndexChanged.connect( partial(script.set_cfg, "txt2img_sampler") ) @@ -246,14 +235,9 @@ def create_img2img_interface(self): self.img2img_prompt_label = QLabel("Prompt:") self.img2img_prompt_text = QPlainTextEdit() self.img2img_prompt_text.setPlaceholderText("krita_config.yaml value will be used") - self.img2img_negative_prompt_label = QLabel("Negative Prompt:") - self.img2img_negative_prompt_text = QLineEdit() - self.img2img_negative_prompt_text.setPlaceholderText("krita_config.yaml value will be used") - self.img2img_prompt_layout = QVBoxLayout() + self.img2img_prompt_layout = QHBoxLayout() self.img2img_prompt_layout.addWidget(self.img2img_prompt_label) self.img2img_prompt_layout.addWidget(self.img2img_prompt_text) - self.img2img_prompt_layout.addWidget(self.img2img_negative_prompt_label) - self.img2img_prompt_layout.addWidget(self.img2img_negative_prompt_text) self.img2img_sampler_name_label = QLabel("Sampler:") self.img2img_sampler_name = QComboBox() @@ -329,17 +313,12 @@ def create_img2img_interface(self): self.img2img_seed_layout.addWidget(self.img2img_seed_label) self.img2img_seed_layout.addWidget(self.img2img_seed) - self.img2img_checkboxes_layout = QHBoxLayout() self.img2img_tiling = QCheckBox("Enable tiling mode") self.img2img_tiling.setTristate(False) - self.img2img_invert_mask = QCheckBox("Invert mask") - self.img2img_invert_mask.setTristate(False) - self.img2img_checkboxes_layout.addWidget(self.img2img_tiling) - # self.img2img_checkboxes_layout.addWidget(self.img2img_invert_mask) - self.img2img_use_gfpgan = QCheckBox("Restore faces") + self.img2img_use_gfpgan = QCheckBox("Enable GFPGAN (may fix faces)") self.img2img_use_gfpgan.setTristate(False) - + self.img2img_upscaler_name_label = QLabel("Prescaler for SD upscale:") self.img2img_upscaler_name = QComboBox() self.img2img_upscaler_name.addItems(upscalers) @@ -367,8 +346,7 @@ def create_img2img_interface(self): self.img2img_layout.addLayout(self.img2img_size_layout) self.img2img_layout.addLayout(self.img2img_seed_layout) self.img2img_layout.addWidget(self.img2img_use_gfpgan) - - self.img2img_layout.addLayout(self.img2img_checkboxes_layout) + self.img2img_layout.addWidget(self.img2img_tiling) self.img2img_layout.addLayout(self.img2img_upscaler_name_layout) self.img2img_layout.addLayout(self.img2img_button_layout) self.img2img_layout.addStretch() @@ -378,7 +356,6 @@ def create_img2img_interface(self): def init_img2img_interface(self): self.img2img_prompt_text.setPlainText(script.cfg('img2img_prompt', str)) - self.img2img_negative_prompt_text.setText(script.cfg('img2img_negative_prompt', str)) self.img2img_sampler_name.setCurrentIndex(script.cfg('img2img_sampler', int)) self.img2img_steps.setValue(script.cfg('img2img_steps', int)) self.img2img_cfg_scale.setValue(script.cfg('img2img_cfg_scale', float)) @@ -392,8 +369,6 @@ def init_img2img_interface(self): Qt.CheckState.Checked if script.cfg('img2img_use_gfpgan', bool) else Qt.CheckState.Unchecked) self.img2img_tiling.setCheckState( Qt.CheckState.Checked if script.cfg('img2img_tiling', bool) else Qt.CheckState.Unchecked) - self.img2img_invert_mask.setCheckState( - Qt.CheckState.Checked if script.cfg('img2img_invert_mask', bool) else Qt.CheckState.Unchecked) self.img2img_upscaler_name.addItems(upscalers[self.img2img_upscaler_name.count():]) self.img2img_upscaler_name.setCurrentIndex(script.cfg('img2img_upscaler_name', int)) @@ -401,10 +376,6 @@ def connect_img2img_interface(self): self.img2img_prompt_text.textChanged.connect( lambda: script.set_cfg("img2img_prompt", self.img2img_prompt_text.toPlainText()) ) - self.img2img_negative_prompt_text.textChanged.connect( - lambda: script.set_cfg("img2img_negative_prompt", - re.sub(r'\n', ', ', self.img2img_negative_prompt_text.text().strip())) - ) self.img2img_sampler_name.currentIndexChanged.connect( partial(script.set_cfg, "img2img_sampler") ) @@ -438,9 +409,6 @@ def connect_img2img_interface(self): self.img2img_tiling.toggled.connect( partial(script.set_cfg, "img2img_tiling") ) - self.img2img_invert_mask.toggled.connect( - partial(script.set_cfg, "img2img_invert_mask") - ) self.img2img_upscaler_name.currentIndexChanged.connect( partial(script.set_cfg, "img2img_upscaler_name") ) @@ -504,7 +472,7 @@ def connect_upscale_interface(self): ) def create_config_interface(self): - self.config_base_url_label = QLabel("Backend url (only local now):") + self.config_base_url_label = QLabel("Backend url:") self.config_base_url = QLineEdit() self.config_base_url_reset = QPushButton("Default") self.config_base_url_layout = QHBoxLayout() @@ -522,21 +490,12 @@ def create_config_interface(self): self.config_only_full_img_tiling = QCheckBox("Allow tiling only with no selection (on full image)") self.config_only_full_img_tiling.setTristate(False) - self.config_face_restorer_model_label = QLabel("Face restorer model:") - self.config_face_restorer_model = QComboBox() - self.config_face_restorer_model.addItems(face_restorers) - self.config_face_restorer_model_layout = QHBoxLayout() - self.config_face_restorer_model_layout.addWidget(self.config_face_restorer_model_label) - self.config_face_restorer_model_layout.addWidget(self.config_face_restorer_model) - - self.config_codeformer_weight_label = QLabel("CodeFormer weight (0 - max effect, 1 - min effect)") - self.config_codeformer_weight = QDoubleSpinBox() - self.config_codeformer_weight.setMinimum(0.0) - self.config_codeformer_weight.setMaximum(1.0) - self.config_codeformer_weight.setSingleStep(0.01) - self.config_codeformer_weight_layout = QHBoxLayout() - self.config_codeformer_weight_layout.addWidget(self.config_codeformer_weight_label) - self.config_codeformer_weight_layout.addWidget(self.config_codeformer_weight) + self.config_tmp_dir_label = QLabel("Temporary directory:") + self.config_tmp_dir = QLineEdit() + self.config_tmp_dir_reset = QPushButton("Default") + self.config_tmp_dir_layout = QHBoxLayout() + self.config_tmp_dir_layout.addWidget(self.config_tmp_dir) + self.config_tmp_dir_layout.addWidget(self.config_tmp_dir_reset) self.config_restore_defaults = QPushButton("Restore Defaults") @@ -548,9 +507,9 @@ def create_config_interface(self): self.config_layout.addWidget(self.config_delete_temp_files) self.config_layout.addWidget(self.config_fix_aspect_ratio) self.config_layout.addWidget(self.config_only_full_img_tiling) - self.config_layout.addLayout(self.config_face_restorer_model_layout) - self.config_layout.addLayout(self.config_codeformer_weight_layout) self.config_layout.addWidget(self.config_restore_defaults) + self.config_layout.addWidget(self.config_tmp_dir_label) + self.config_layout.addLayout(self.config_tmp_dir_layout) self.config_layout.addStretch() self.config_widget = QWidget() @@ -568,8 +527,7 @@ def init_config_interface(self): Qt.CheckState.Checked if script.cfg('fix_aspect_ratio', bool) else Qt.CheckState.Unchecked) self.config_only_full_img_tiling.setCheckState( Qt.CheckState.Checked if script.cfg('only_full_img_tiling', bool) else Qt.CheckState.Unchecked) - self.config_face_restorer_model.setCurrentIndex(script.cfg('face_restorer_model', int)) - self.config_codeformer_weight.setValue(script.cfg('codeformer_weight', float)) + self.config_tmp_dir.setText(script.cfg('tmp_dir', str)) def connect_config_interface(self): self.config_base_url.textChanged.connect( @@ -593,11 +551,11 @@ def connect_config_interface(self): self.config_only_full_img_tiling.toggled.connect( partial(script.set_cfg, "only_full_img_tiling") ) - self.config_face_restorer_model.currentIndexChanged.connect( - partial(script.set_cfg, "face_restorer_model") + self.config_tmp_dir.textChanged.connect( + partial(script.set_cfg, "tmp_dir") ) - self.config_codeformer_weight.valueChanged.connect( - partial(script.set_cfg, "codeformer_weight") + self.config_tmp_dir_reset.released.connect( + lambda: self.config_tmp_dir.setText(default_tmp_dir) ) self.config_restore_defaults.released.connect( lambda: self.restore_defaults() diff --git a/krita_server.py b/krita_server.py index bd772957b..0d410ad50 100644 --- a/krita_server.py +++ b/krita_server.py @@ -1,10 +1,12 @@ import contextlib import threading import math +import shutil import time import yaml -import os from typing import Optional +from fastapi.responses import FileResponse +from fastapi import UploadFile import numpy as np from pydantic import BaseModel @@ -26,7 +28,7 @@ def load_config(): def save_img(image, sample_path, filename): path = os.path.join(sample_path, filename) image.save(path) - return os.path.abspath(path) + return os.path.basename(path) def fix_aspect_ratio(base_size, max_size, orig_width, orig_height): @@ -70,6 +72,8 @@ def collect_prompt(opts, key): return prompt raise Exception("wtf man, fix your prompts") +class ImageRequest(BaseModel): + file_name: str class Txt2ImgRequest(BaseModel): orig_width: int @@ -164,12 +168,21 @@ async def read_item(): "upscalers": [upscaler.name for upscaler in shared.sd_upscalers], **opt} +@app.post("/result") +async def get_result(req: ImageRequest): + print(f'get_result: {req.file_name}') + opt = load_config()['txt2img'] + print(f'sample_path: {opt["sample_path"]}') + path = os.path.join(opt['sample_path'], req.file_name) + print(f'loading {path}') + return FileResponse(path) @app.post("/txt2img") async def f_txt2img(req: Txt2ImgRequest): print(f"txt2img: {req}") opt = load_config()['txt2img'] + sample_path = opt['sample_path'] set_face_restorer(req.face_restorer or opt['face_restorer'], req.codeformer_weight or opt['codeformer_weight']) @@ -204,7 +217,6 @@ async def f_txt2img(req: Txt2ImgRequest): 0 ) - sample_path = opt['sample_path'] os.makedirs(sample_path, exist_ok=True) resized_images = [modules.images.resize_image(0, image, req.orig_width, req.orig_height) for image in output_images] outputs = [save_img(image, sample_path, filename=f"{int(time.time())}_{i}.png") @@ -212,12 +224,22 @@ async def f_txt2img(req: Txt2ImgRequest): print(f"finished: {outputs}\n{info}") return {"outputs": outputs, "info": info} +@app.post("/saveimg") +async def f_saveimg(file: UploadFile): + print(f'saveimg: {file.filename}') + opt = load_config()['plugin'] + path = os.path.join(opt['sample_path'], file.filename) + print(f'saving {path}') + with open(path, 'wb') as f: + shutil.copyfileobj(file.file, f) + return {"path": path} @app.post("/img2img") async def f_img2img(req: Img2ImgRequest): print(f"img2img: {req}") opt = load_config()['img2img'] + opt_plugin = load_config()['plugin'] set_face_restorer(req.face_restorer or opt['face_restorer'], req.codeformer_weight or opt['codeformer_weight']) @@ -229,11 +251,13 @@ async def f_img2img(req: Img2ImgRequest): mode = req.mode or opt['mode'] - image = Image.open(req.src_path) + path = os.path.join(opt_plugin['sample_path'], req.src_path) + image = Image.open(path) orig_width, orig_height = image.size if mode == 1: - mask = Image.open(req.mask_path).convert('L') + mask_path = os.path.join(opt_plugin['sample_path'], req.mask_path) + mask = Image.open(mask_path).convert('L') else: mask = None @@ -310,7 +334,9 @@ async def f_upscale(req: UpscaleRequest): print(f"upscale: {req}") opt = load_config()['upscale'] - image = Image.open(req.src_path).convert('RGB') + opt_plugin = load_config()['plugin'] + path = os.path.join(opt_plugin['sample_path'], req.src_path) + image = Image.open(path).convert('RGB') orig_width, orig_height = image.size upscaler_index = get_upscaler_index(req.upscaler_name or opt['upscaler_name'])