From e13a827bf87301d4030cd2cf8e2814dd35174dfc Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Tue, 31 Jan 2023 11:46:07 +0900 Subject: [PATCH 01/13] update aliases --- paint_with_words/helper/aliases.py | 11 ++++++++++- tests/helper/aliases_test.py | 16 +++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/paint_with_words/helper/aliases.py b/paint_with_words/helper/aliases.py index ef8826c..2085a8c 100644 --- a/paint_with_words/helper/aliases.py +++ b/paint_with_words/helper/aliases.py @@ -1,9 +1,18 @@ -from typing import Dict, Tuple, Union +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union import torch +import torch as th from paint_with_words.weight_functions import WeightFunction RGB = Tuple[int, int, int] PaintWithWordsHiddenStates = Dict[str, Union[torch.Tensor, WeightFunction]] + + +@dataclass +class SeparatedImageContext(object): + word: str + token_ids: List[int] + color_map_th: th.Tensor diff --git a/tests/helper/aliases_test.py b/tests/helper/aliases_test.py index 7515f2c..a6edbf4 100644 --- a/tests/helper/aliases_test.py +++ b/tests/helper/aliases_test.py @@ -2,7 +2,11 @@ import torch as th -from paint_with_words.helper.aliases import RGB, PaintWithWordsHiddenStates +from paint_with_words.helper.aliases import ( + RGB, + PaintWithWordsHiddenStates, + SeparatedImageContext, +) from paint_with_words.weight_functions import WeightFunction @@ -12,3 +16,13 @@ def test_rgb(): def test_paint_with_words_hidden_states(): assert PaintWithWordsHiddenStates == Dict[str, Union[th.Tensor, WeightFunction]] + + +def test_separated_image_context(): + + separated_image_context = SeparatedImageContext( + word="cat", + token_ids=[2368], + color_map_th=th.zeros((512, 512)), + ) + assert isinstance(separated_image_context, SeparatedImageContext) From 419d0f25bfd80302c9234ad77a7a7795bc4d8773 Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Tue, 31 Jan 2023 18:06:01 +0900 Subject: [PATCH 02/13] [WIP] update files --- paint_with_words/helper/__Init__.py | 2 + paint_with_words/helper/aliases.py | 1 + paint_with_words/helper/attention.py | 12 ++ paint_with_words/helper/images.py | 138 ++++++++++++- .../pipelines/paint_with_words_pipeline.py | 192 +++++------------- tests/helper/attention_test.py | 34 ++++ tests/helper/images_test.py | 111 +++++++++- .../paint_with_words_pipeline_test.py | 41 ++++ 8 files changed, 385 insertions(+), 146 deletions(-) create mode 100644 paint_with_words/helper/attention.py create mode 100644 tests/helper/attention_test.py diff --git a/paint_with_words/helper/__Init__.py b/paint_with_words/helper/__Init__.py index e69de29..139597f 100644 --- a/paint_with_words/helper/__Init__.py +++ b/paint_with_words/helper/__Init__.py @@ -0,0 +1,2 @@ + + diff --git a/paint_with_words/helper/aliases.py b/paint_with_words/helper/aliases.py index 2085a8c..987bf17 100644 --- a/paint_with_words/helper/aliases.py +++ b/paint_with_words/helper/aliases.py @@ -9,6 +9,7 @@ RGB = Tuple[int, int, int] PaintWithWordsHiddenStates = Dict[str, Union[torch.Tensor, WeightFunction]] +ColorContext = Dict[RGB, str] @dataclass diff --git a/paint_with_words/helper/attention.py b/paint_with_words/helper/attention.py new file mode 100644 index 0000000..19034aa --- /dev/null +++ b/paint_with_words/helper/attention.py @@ -0,0 +1,12 @@ +from diffusers.models import UNet2DConditionModel + +from paint_with_words.models.attention import paint_with_words_forward + + +def replace_cross_attention( + unet: UNet2DConditionModel, cross_attention_name: str = 
"CrossAttention" +) -> None: + + for m in unet.modules(): + if m.__class__.__name__ == cross_attention_name: + m.__class__.__call__ = paint_with_words_forward diff --git a/paint_with_words/helper/images.py b/paint_with_words/helper/images.py index df167bf..f1d261d 100644 --- a/paint_with_words/helper/images.py +++ b/paint_with_words/helper/images.py @@ -1,9 +1,27 @@ -from typing import Tuple +import logging +import os +from typing import Dict, List, Tuple, Union +import numpy as np import torch as th import torch.nn.functional as F +from PIL import Image from PIL.Image import Image as PilImage from PIL.Image import Resampling +from transformers.tokenization_utils import PreTrainedTokenizer + +from paint_with_words.helper.aliases import RGB, SeparatedImageContext + +logger = logging.getLogger(__name__) + + +def load_image(image: Union[str, os.PathLike, PilImage]) -> PilImage: + if isinstance(image, str) or isinstance(image, os.PathLike): + image = Image.open(image) + + if image.mode != "RGB": + image = image.convert("RGB") + return image def get_resize_size(img: PilImage) -> Tuple[int, int]: @@ -34,3 +52,121 @@ def flatten_image_importance(img_th: th.Tensor, ratio: int) -> th.Tensor: ret = ret.squeeze() return ret + + +def separate_image_context( + tokenizer: PreTrainedTokenizer, + img: PilImage, + color_context: Dict[RGB, str], + device: str, +) -> List[SeparatedImageContext]: + + assert img.width % 32 == 0 and img.height % 32 == 0, img.size + + separated_image_and_context: List[SeparatedImageContext] = [] + + for rgb_color, word_with_weight in color_context.items(): + + # e.g., + # rgb_color: (0, 0, 0) + # word_with_weight: cat,1.0 + + # cat,1.0 -> ["cat", "1.0"] + word_and_weight = word_with_weight.split(",") + # ["cat", "1.0"] -> 1.0 + word_weight = float(word_and_weight[-1]) + # ["cat", "1.0"] -> cat + word = ",".join(word_and_weight[:-1]) + + logger.info( + f"input = {word_with_weight}; word = {word}; weight = {word_weight}" + ) + + word_input = tokenizer( + word, + max_length=tokenizer.model_max_length, + truncation=True, + add_special_tokens=False, + ) + word_as_tokens = word_input["input_ids"] + + img_where_color_np = (np.array(img) == rgb_color).all(axis=-1) + if not img_where_color_np.sum() > 0: + logger.warning( + f"Warning : not a single color {rgb_color} not found in image" + ) + + img_where_color_th = th.tensor( + img_where_color_np, + dtype=th.float32, + device=device, + ) + img_where_color_th = img_where_color_th * word_weight + + breakpoint() + image_context = SeparatedImageContext( + word=word, + token_ids=word_as_tokens, + color_map_th=img_where_color_th, + ) + separated_image_and_context.append(image_context) + + if len(separated_image_and_context) == 0: + image_context = SeparatedImageContext( + word="", + token_ids=[-1], + color_map_th=th.zeros((img.width, img.height), dtype=th.float32), + ) + separated_image_and_context.append(image_context) + + return separated_image_and_context + + +def calculate_tokens_image_attention_weight( + tokenizer: PreTrainedTokenizer, + input_prompt: str, + separated_image_context_list: List[SeparatedImageContext], + ratio: int, + device: str, +) -> th.Tensor: + + prompt_token_ids = tokenizer( + input_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + w, h = separated_image_context_list[0].color_map_th.shape + w_r, h_r = w // ratio, h // ratio + + ret_tensor = th.zeros( + (w_r * h_r, len(prompt_token_ids)), dtype=th.float32, device=device + ) + + for separated_image_context in 
separated_image_context_list: + is_in = False + context_token_ids = separated_image_context.token_ids + context_image_map = separated_image_context.color_map_th + + for i, token_id in enumerate(prompt_token_ids): + if prompt_token_ids[i : i + len(context_token_ids)] == context_token_ids: + is_in = True + + # shape: (w * 1/ratio, h * 1/ratio) + img_importance = flatten_image_importance( + img_th=context_image_map, ratio=ratio + ) + # shape: ((w * 1/ratio) * (h * 1/ratio), 1) + img_importance = img_importance.view(-1, 1) + # shape: ((w * 1/ratio) * (h * 1/ratio), len(context_token_ids)) + img_importance = img_importance.repeat(1, len(context_token_ids)) + + ret_tensor[:, i : i + len(context_token_ids)] += img_importance + + if not is_in: + logger.warning( + f"Warning ratio {ratio} : tokens {context_token_ids} not found in text" + ) + + return ret_tensor diff --git a/paint_with_words/pipelines/paint_with_words_pipeline.py b/paint_with_words/pipelines/paint_with_words_pipeline.py index eb511cf..6ccd183 100644 --- a/paint_with_words/pipelines/paint_with_words_pipeline.py +++ b/paint_with_words/pipelines/paint_with_words_pipeline.py @@ -1,9 +1,6 @@ import logging -import os -from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Union -import numpy as np import torch as th from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( @@ -19,17 +16,18 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from PIL import Image from PIL.Image import Image as PilImage from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from paint_with_words.helper.aliases import RGB +from paint_with_words.helper.aliases import RGB, ColorContext, SeparatedImageContext +from paint_with_words.helper.attention import replace_cross_attention from paint_with_words.helper.images import ( - flatten_image_importance, + calculate_tokens_image_attention_weight, get_resize_size, + load_image, resize_image, + separate_image_context, ) -from paint_with_words.models.attention import paint_with_words_forward from paint_with_words.weight_functions import ( PaintWithWordsWeightFunction, UnconditionedWeightFunction, @@ -39,13 +37,6 @@ logger = logging.getLogger(__name__) -@dataclass -class SeparatedImageContext(object): - word: str - token_ids: List[int] - color_map_th: th.Tensor - - class PaintWithWordsPipeline(StableDiffusionPipeline): def __init__( self, @@ -86,71 +77,20 @@ def __init__( def replace_cross_attention( self, cross_attention_name: str = "CrossAttention" ) -> None: + replace_cross_attention( + unet=self.unet, + cross_attention_name=cross_attention_name, + ) - for m in self.unet.modules(): - if m.__class__.__name__ == cross_attention_name: - m.__class__.__call__ = paint_with_words_forward - - def separate_image_context(self, img: PilImage, color_context: Dict[RGB, str]): - - assert img.width % 32 == 0 and img.height % 32 == 0, img.size - - separated_image_and_context: List[SeparatedImageContext] = [] - - for rgb_color, word_with_weight in color_context.items(): - - # e.g., - # rgb_color: (0, 0, 0) - # word_with_weight: cat,1.0 - - # cat,1.0 -> ["cat", "1.0"] - word_and_weight = word_with_weight.split(",") - # ["cat", "1.0"] -> 1.0 - word_weight = float(word_and_weight[-1]) - # ["cat", "1.0"] -> cat - word = ",".join(word_and_weight[:-1]) - - logger.info( - f"input = {word_with_weight}; word = {word}; weight = {word_weight}" - ) - - word_input = self.tokenizer( - word, - max_length=self.tokenizer.model_max_length, - 
truncation=True, - add_special_tokens=False, - ) - word_as_tokens = word_input["input_ids"] - - img_where_color_np = (np.array(img) == rgb_color).all(axis=-1) - if not img_where_color_np.sum() > 0: - logger.warning( - f"Warning : not a single color {rgb_color} not found in image" - ) - - img_where_color_th = th.tensor( - img_where_color_np, - dtype=th.float32, - device=self.device, - ) - img_where_color_th = img_where_color_th * word_weight - - image_context = SeparatedImageContext( - word=word, - token_ids=word_as_tokens, - color_map_th=img_where_color_th, - ) - separated_image_and_context.append(image_context) - - if len(separated_image_and_context) == 0: - image_context = SeparatedImageContext( - word="", - token_ids=[-1], - color_map_th=th.zeros((img.width, img.height), dtype=th.float32), - ) - separated_image_and_context.append(image_context) - - return separated_image_and_context + def separate_image_context( + self, img: PilImage, color_context: ColorContext + ) -> List[SeparatedImageContext]: + return separate_image_context( + tokenizer=self.tokenizer, + img=img, + color_context=color_context, + device=self.device, + ) def calculate_tokens_image_attention_weight( self, @@ -159,65 +99,20 @@ def calculate_tokens_image_attention_weight( ratio: int, ) -> th.Tensor: - prompt_token_ids = self.tokenizer( - input_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - ).input_ids - - w, h = separated_image_context_list[0].color_map_th.shape - w_r, h_r = w // ratio, h // ratio - - ret_tensor = th.zeros( - (w_r * h_r, len(prompt_token_ids)), dtype=th.float32, device=self.device + return calculate_tokens_image_attention_weight( + tokenizer=self.tokenizer, + input_prompt=input_prompt, + separated_image_context_list=separated_image_context_list, + ratio=ratio, + device=self.device, ) - for separated_image_context in separated_image_context_list: - is_in = False - context_token_ids = separated_image_context.token_ids - context_image_map = separated_image_context.color_map_th - - for i, token_id in enumerate(prompt_token_ids): - if ( - prompt_token_ids[i : i + len(context_token_ids)] - == context_token_ids - ): - is_in = True - - # shape: (w * 1/ratio, h * 1/ratio) - img_importance = flatten_image_importance( - img_th=context_image_map, ratio=ratio - ) - # shape: ((w * 1/ratio) * (h * 1/ratio), 1) - img_importance = img_importance.view(-1, 1) - # shape: ((w * 1/ratio) * (h * 1/ratio), len(context_token_ids)) - img_importance = img_importance.repeat(1, len(context_token_ids)) - - ret_tensor[:, i : i + len(context_token_ids)] += img_importance - - if not is_in: - logger.warning( - f"Warning ratio {ratio} : tokens {context_token_ids} not found in text" - ) - - return ret_tensor - - def load_image(self, image: Union[str, os.PathLike, PilImage]) -> PilImage: - if isinstance(image, str) or isinstance(image, os.PathLike): - image = Image.open(image) - - if image.mode != "RGB": - image = image.convert("RGB") - - return image - @th.no_grad() def __call__( self, - prompt: str, - color_context: Dict[RGB, str], - color_map_image: PilImage, + prompts: Union[str, List[str]], + color_contexts: Union[ColorContext, List[ColorContext]], + color_map_images: Union[PilImage, List[PilImage]], weight_function: WeightFunction = PaintWithWordsWeightFunction(), num_inference_steps: int = 50, guidance_scale: float = 7.5, @@ -232,46 +127,55 @@ def __call__( callback_steps: Optional[int] = 1, ) -> StableDiffusionPipelineOutput: - assert isinstance(prompt, str), type(prompt) + if not 
isinstance(prompts, list): + prompts = [prompts] + if not isinstance(color_contexts, list): + color_contexts = [color_contexts] + if not isinstance(color_map_images, list): + color_map_images = [color_map_images] + assert guidance_scale > 1.0, guidance_scale # 0. Default height and width to unet and resize the color map image - color_map_image = self.load_image(image=color_map_image) - width, height = get_resize_size(img=color_map_image) - color_map_image = resize_image(img=color_map_image, w=width, h=height) + color_map_images = [load_image(image=image) for image in color_map_images] + sizes = [get_resize_size(img=img) for img in color_map_images] + color_map_images = [ + resize_image(img=img, w=w, h=h) + for img, (w, h) in zip(color_map_images, sizes) + ] separated_image_context_list = self.separate_image_context( - img=color_map_image, color_context=color_context + img=color_map_images, color_context=color_contexts ) cross_attention_weight_8 = self.calculate_tokens_image_attention_weight( - input_prompt=prompt, + input_prompt=prompts, separated_image_context_list=separated_image_context_list, ratio=8, ) cross_attention_weight_16 = self.calculate_tokens_image_attention_weight( - input_prompt=prompt, + input_prompt=prompts, separated_image_context_list=separated_image_context_list, ratio=16, ) cross_attention_weight_32 = self.calculate_tokens_image_attention_weight( - input_prompt=prompt, + input_prompt=prompts, separated_image_context_list=separated_image_context_list, ratio=32, ) cross_attention_weight_64 = self.calculate_tokens_image_attention_weight( - input_prompt=prompt, + input_prompt=prompts, separated_image_context_list=separated_image_context_list, ratio=64, ) # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) + self.check_inputs(prompts, height, width, callback_steps) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + batch_size = 1 if isinstance(prompts, str) else len(prompts) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -280,7 +184,7 @@ def __call__( # 3. 
Encode input prompt text_embeddings = self._encode_prompt( - prompt, + prompts, device, num_images_per_prompt, do_classifier_free_guidance, diff --git a/tests/helper/attention_test.py b/tests/helper/attention_test.py new file mode 100644 index 0000000..987871e --- /dev/null +++ b/tests/helper/attention_test.py @@ -0,0 +1,34 @@ +import copy + +import pytest +from diffusers.models import UNet2DConditionModel + +from paint_with_words.helper.attention import replace_cross_attention + + +@pytest.fixture +def model_name() -> str: + return "CompVis/stable-diffusion-v1-4" + + +def test_replace_cross_attention( + model_name: str, cross_attention_name: str = "CrossAttention" +): + unet_original = UNet2DConditionModel.from_pretrained( + model_name, subfolder="unet", revision="fp16" + ) + + unet_proposed = copy.deepcopy(unet_original) + # unet_proposed = UNet2DConditionModel.from_pretrained( + # model_name, subfolder="unet", revision="fp16" + # ) + unet_proposed = replace_cross_attention(unet=unet_proposed) + + for m_orig, m_prop in zip(unet_original.modules(), unet_proposed.modules()): + cond1 = m_orig.__class__.__name__ == cross_attention_name + cond2 = m_prop.__class__.__name__ == cross_attention_name + + if cond1 and cond2: + breakpoint() + assert m_orig.__class__.__call__.__name__ == "_call_impl" + assert m_prop.__class__.__call__.__name__ == "paint_with_words_forward" diff --git a/tests/helper/images_test.py b/tests/helper/images_test.py index dd0b5c1..8788a17 100644 --- a/tests/helper/images_test.py +++ b/tests/helper/images_test.py @@ -1,7 +1,18 @@ +from typing import Dict + import pytest +import torch as th from PIL import Image +from transformers import CLIPTokenizer -from paint_with_words.helper.images import get_resize_size, resize_image +from paint_with_words.helper.aliases import RGB, SeparatedImageContext +from paint_with_words.helper.images import ( + calculate_tokens_image_attention_weight, + get_resize_size, + load_image, + resize_image, + separate_image_context, +) @pytest.fixture @@ -41,3 +52,101 @@ def test_resize_image(): w=129, h=129, ) + + +def test_separate_image_context( + model_name: str, color_context: Dict[RGB, str], color_map_image_path: str +): + + tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer") + + color_map_image = load_image(color_map_image_path) + + ret_list = separate_image_context( + tokenizer=tokenizer, + img=color_map_image, + color_context=color_context, + device="cpu", + ) + + for ret in ret_list: + assert isinstance(ret, SeparatedImageContext) + assert isinstance(ret.word, str) + assert isinstance(ret.token_ids, list) + assert isinstance(ret.color_map_th, th.Tensor) + + token_ids = tokenizer( + ret.word, + max_length=tokenizer.model_max_length, + truncation=True, + add_special_tokens=False, + ).input_ids + assert ret.token_ids == token_ids + + +def test_calculate_tokens_image_attention_weight( + model_name: str, + color_context: Dict[RGB, str], + color_map_image_path: str, + input_prompt: str, +): + + tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer") + + color_map_image = load_image(color_map_image_path) + w, h = color_map_image.size + + separated_image_context_list = separate_image_context( + tokenizer=tokenizer, + img=color_map_image, + color_context=color_context, + device="cpu", + ) + + cross_attention_weight_8 = calculate_tokens_image_attention_weight( + tokenizer=tokenizer, + input_prompt=input_prompt, + separated_image_context_list=separated_image_context_list, + ratio=8, + device="cpu", + ) + 
assert cross_attention_weight_8.size() == ( + int((w * 1 / 8) * (h * 1 / 8)), + tokenizer.model_max_length, + ) + + cross_attention_weight_16 = calculate_tokens_image_attention_weight( + tokenizer=tokenizer, + input_prompt=input_prompt, + separated_image_context_list=separated_image_context_list, + ratio=16, + device="cpu", + ) + assert cross_attention_weight_16.size() == ( + int((w * 1 / 16) * (h * 1 / 16)), + tokenizer.model_max_length, + ) + + cross_attention_weight_32 = calculate_tokens_image_attention_weight( + tokenizer=tokenizer, + input_prompt=input_prompt, + separated_image_context_list=separated_image_context_list, + ratio=32, + device="cpu", + ) + assert cross_attention_weight_32.size() == ( + int((w * 1 / 32) * (h * 1 / 32)), + tokenizer.model_max_length, + ) + + cross_attention_weight_64 = calculate_tokens_image_attention_weight( + tokenizer=tokenizer, + input_prompt=input_prompt, + separated_image_context_list=separated_image_context_list, + ratio=64, + device="cpu", + ) + assert cross_attention_weight_64.size() == ( + int((w * 1 / 64) * (h * 1 / 64)), + tokenizer.model_max_length, + ) diff --git a/tests/pipelines/paint_with_words_pipeline_test.py b/tests/pipelines/paint_with_words_pipeline_test.py index 9199f5e..18cdbc9 100644 --- a/tests/pipelines/paint_with_words_pipeline_test.py +++ b/tests/pipelines/paint_with_words_pipeline_test.py @@ -243,3 +243,44 @@ def test_calculate_tokens_image_attention_weight( int((w * 1 / 64) * (h * 1 / 64)), pipe.tokenizer.model_max_length, ) + + +def test_batch_pipeline(model_name: str): + + # load pre-trained weight with paint with words pipeline + pipe = PaintWithWordsPipeline.from_pretrained( + model_name, + revision="fp16", + torch_dtype=torch.float16, + ) + pipe.safety_checker = None # disable the safety checker + pipe.to(gpu_device) + + # check the scheduler is LMSDiscreteScheduler + assert isinstance(pipe.scheduler, LMSDiscreteScheduler), type(pipe.scheduler) + + # generate latents with seed-fixed generator + generator = torch.manual_seed(0) + latents = torch.randn((1, 4, 64, 64), generator=generator) + latents = latents.repeat(2, 1, 1, 1) # shape: (1, 4, 64, 64) -> (2, 4, 64, 64) + + color_map_image_1 = EXAMPLE_SETTING_1["color_map_image_path"] + color_map_image_2 = EXAMPLE_SETTING_2["color_map_image_path"] + + with th.autocast("cuda"): + images = pipe( + prompt=[ + EXAMPLE_SETTING_1["input_prompt"], + EXAMPLE_SETTING_1["input_prompt"], + ], + color_context=[ + EXAMPLE_SETTING_1["color_context"], + EXAMPLE_SETTING_2["color_context"], + ], + color_map_image=[ + color_map_image_1, + color_map_image_2, + ], + latents=latents, + num_inference_steps=30, + ) From c01e2064bba22ccb608630c8217badf2a18fb979 Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Wed, 1 Feb 2023 21:14:38 +0900 Subject: [PATCH 03/13] remove breakpoint --- paint_with_words/helper/images.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paint_with_words/helper/images.py b/paint_with_words/helper/images.py index f1d261d..c4661ed 100644 --- a/paint_with_words/helper/images.py +++ b/paint_with_words/helper/images.py @@ -103,7 +103,6 @@ def separate_image_context( ) img_where_color_th = img_where_color_th * word_weight - breakpoint() image_context = SeparatedImageContext( word=word, token_ids=word_as_tokens, From 046d0bd43ed14db9179964c61e67d63388e6cbef Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Wed, 1 Feb 2023 21:14:49 +0900 Subject: [PATCH 04/13] fix typo --- paint_with_words/helper/{__Init__.py => __init__.py} | 0 1 file changed, 0 insertions(+), 0 
deletions(-) rename paint_with_words/helper/{__Init__.py => __init__.py} (100%) diff --git a/paint_with_words/helper/__Init__.py b/paint_with_words/helper/__init__.py similarity index 100% rename from paint_with_words/helper/__Init__.py rename to paint_with_words/helper/__init__.py From ed0d235eb5f4880f357c22357d8bf177a0bbb30f Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Wed, 1 Feb 2023 21:15:16 +0900 Subject: [PATCH 05/13] [WIP] temp commit --- .../pipelines/paint_with_words_pipeline.py | 155 ++++++++++++------ .../paint_with_words_pipeline_test.py | 27 ++- 2 files changed, 121 insertions(+), 61 deletions(-) diff --git a/paint_with_words/pipelines/paint_with_words_pipeline.py b/paint_with_words/pipelines/paint_with_words_pipeline.py index 6ccd183..b77276b 100644 --- a/paint_with_words/pipelines/paint_with_words_pipeline.py +++ b/paint_with_words/pipelines/paint_with_words_pipeline.py @@ -1,4 +1,5 @@ import logging +import os from typing import Callable, Dict, List, Optional, Union import torch as th @@ -107,13 +108,91 @@ def calculate_tokens_image_attention_weight( device=self.device, ) + def calculate_cross_attention_weight( + self, + prompt: str, + color_map_image: Union[str, os.PathLike, PilImage], + color_context: ColorContext, + ): + + # 0. Default height and width to unet and resize the color map image + color_map_image = load_image(image=color_map_image) + width, height = get_resize_size(img=color_map_image) + assert width == 512 and height == 512 + color_map_image = resize_image(img=color_map_image, w=width, h=height) + + separated_image_context_list = self.separate_image_context( + img=color_map_image, color_context=color_context + ) + + cross_attention_weight_8 = self.calculate_tokens_image_attention_weight( + input_prompt=prompt, + separated_image_context_list=separated_image_context_list, + ratio=8, + ) + + cross_attention_weight_16 = self.calculate_tokens_image_attention_weight( + input_prompt=prompt, + separated_image_context_list=separated_image_context_list, + ratio=16, + ) + + cross_attention_weight_32 = self.calculate_tokens_image_attention_weight( + input_prompt=prompt, + separated_image_context_list=separated_image_context_list, + ratio=32, + ) + + cross_attention_weight_64 = self.calculate_tokens_image_attention_weight( + input_prompt=prompt, + separated_image_context_list=separated_image_context_list, + ratio=64, + ) + + breakpoint() + + return { + f"cross_attention_weight_{height * width // (8*8)}": cross_attention_weight_8, + f"cross_attention_weight_{height * width // (16*16)}": cross_attention_weight_16, + f"cross_attention_weight_{height * width // (32*32)}": cross_attention_weight_32, + f"cross_attention_weight_{height * width // (64*64)}": cross_attention_weight_64, + } + + def batch_calculate_cross_attention_weight( + self, + prompts: List[str], + color_map_images: Union[List[str], List[os.PathLike], List[PilImage]], + color_contexts: List[ColorContext], + ): + assert len(prompts) == len(color_map_images) == len(color_contexts) + it = zip(prompts, color_map_images, color_contexts) + + cross_attention_weight_dict = {} + + for i, (prompt, color_map_image, color_context) in enumerate(it): + output_dict = self.calculate_cross_attention_weight( + prompt=prompt, + color_map_image=color_map_image, + color_context=color_context, + ) + if i == 0: + cross_attention_weight_dict.update(output_dict) + else: + breakpoint() + + return cross_attention_weight_dict + @th.no_grad() def __call__( self, prompts: Union[str, List[str]], color_contexts: Union[ColorContext, 
List[ColorContext]], - color_map_images: Union[PilImage, List[PilImage]], + color_map_images: Union[ + PilImage, List[PilImage], str, List[str], os.PathLike, List[os.PathLike] + ], weight_function: WeightFunction = PaintWithWordsWeightFunction(), + height: int = 512, + width: int = 512, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -136,41 +215,6 @@ def __call__( assert guidance_scale > 1.0, guidance_scale - # 0. Default height and width to unet and resize the color map image - color_map_images = [load_image(image=image) for image in color_map_images] - sizes = [get_resize_size(img=img) for img in color_map_images] - color_map_images = [ - resize_image(img=img, w=w, h=h) - for img, (w, h) in zip(color_map_images, sizes) - ] - - separated_image_context_list = self.separate_image_context( - img=color_map_images, color_context=color_contexts - ) - cross_attention_weight_8 = self.calculate_tokens_image_attention_weight( - input_prompt=prompts, - separated_image_context_list=separated_image_context_list, - ratio=8, - ) - - cross_attention_weight_16 = self.calculate_tokens_image_attention_weight( - input_prompt=prompts, - separated_image_context_list=separated_image_context_list, - ratio=16, - ) - - cross_attention_weight_32 = self.calculate_tokens_image_attention_weight( - input_prompt=prompts, - separated_image_context_list=separated_image_context_list, - ratio=32, - ) - - cross_attention_weight_64 = self.calculate_tokens_image_attention_weight( - input_prompt=prompts, - separated_image_context_list=separated_image_context_list, - ratio=64, - ) - # 1. Check inputs. Raise error if not correct self.check_inputs(prompts, height, width, callback_steps) @@ -191,8 +235,11 @@ def __call__( negative_prompt, ) # Ensure classifier free guidance is performed and - # the batch size of the text embedding is 2 (conditional + unconditional) - assert do_classifier_free_guidance and text_embeddings.size(dim=0) == 2 + # the batch size of the text embedding is batch_size * 2 (conditional + unconditional) + assert ( + do_classifier_free_guidance + and text_embeddings.size(dim=0) == batch_size * 2 + ) uncond_embeddings, cond_embeddings = text_embeddings.chunk(2) # 4. Prepare timesteps @@ -215,6 +262,12 @@ def __call__( # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + cross_attention_weights = self.batch_calculate_cross_attention_weight( + prompts=prompts, + color_map_images=color_map_images, + color_contexts=color_contexts, + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -231,19 +284,27 @@ def __call__( width // self.vae_scale_factor, ) + encoder_hidden_states = { + "context_tensor": cond_embeddings, + "sigma": sigma, + "weight_function": weight_function, + } + encoder_hidden_states.update(cross_attention_weights) + # predict the noise residual noise_pred_text = self.unet( latent_model_input, t, - encoder_hidden_states={ - "context_tensor": cond_embeddings, - f"cross_attention_weight_{height * width // (8*8)}": cross_attention_weight_8, - f"cross_attention_weight_{height * width // (16*16)}": cross_attention_weight_16, - f"cross_attention_weight_{height * width // (32*32)}": cross_attention_weight_32, - f"cross_attention_weight_{height * width // (64*64)}": cross_attention_weight_64, - "sigma": sigma, - "weight_function": weight_function, - }, + # encoder_hidden_states={ + # "context_tensor": cond_embeddings, + # f"cross_attention_weight_{height * width // (8*8)}": cross_attention_weight_8, + # f"cross_attention_weight_{height * width // (16*16)}": cross_attention_weight_16, + # f"cross_attention_weight_{height * width // (32*32)}": cross_attention_weight_32, + # f"cross_attention_weight_{height * width // (64*64)}": cross_attention_weight_64, + # "sigma": sigma, + # "weight_function": weight_function, + # }, + encoder_hidden_states=encoder_hidden_states, ).sample latent_model_input = self.scheduler.scale_model_input(latents, t) diff --git a/tests/pipelines/paint_with_words_pipeline_test.py b/tests/pipelines/paint_with_words_pipeline_test.py index 18cdbc9..487591a 100644 --- a/tests/pipelines/paint_with_words_pipeline_test.py +++ b/tests/pipelines/paint_with_words_pipeline_test.py @@ -1,7 +1,6 @@ from typing import Dict import pytest -import torch import torch as th from diffusers.schedulers import LMSDiscreteScheduler from PIL import Image @@ -113,7 +112,7 @@ def test_pipeline( pipe = PaintWithWordsPipeline.from_pretrained( model_name, revision="fp16", - torch_dtype=torch.float16, + torch_dtype=th.float16, ) pipe.safety_checker = None # disable the safety checker pipe.to(gpu_device) @@ -122,8 +121,8 @@ def test_pipeline( assert isinstance(pipe.scheduler, LMSDiscreteScheduler), type(pipe.scheduler) # generate latents with seed-fixed generator - generator = torch.manual_seed(0) - latents = torch.randn((1, 4, 64, 64), generator=generator) + generator = th.manual_seed(0) + latents = th.randn((1, 4, 64, 64), generator=generator) # load color map image color_map_image = Image.open(color_map_image_path) @@ -156,7 +155,7 @@ def test_separate_image_context( pipe = PaintWithWordsPipeline.from_pretrained( model_name, revision="fp16", - torch_dtype=torch.float16, + torch_dtype=th.float16, ) color_map_image = pipe.load_image(color_map_image_path) @@ -169,7 +168,7 @@ def test_separate_image_context( assert isinstance(ret, SeparatedImageContext) assert isinstance(ret.word, str) assert isinstance(ret.token_ids, list) - assert isinstance(ret.color_map_th, torch.Tensor) + assert isinstance(ret.color_map_th, th.Tensor) token_ids = pipe.tokenizer( ret.word, @@ -194,7 +193,7 @@ def 
test_calculate_tokens_image_attention_weight( pipe = PaintWithWordsPipeline.from_pretrained( model_name, revision="fp16", - torch_dtype=torch.float16, + torch_dtype=th.float16, ) color_map_image = pipe.load_image(color_map_image_path) @@ -245,13 +244,13 @@ def test_calculate_tokens_image_attention_weight( ) -def test_batch_pipeline(model_name: str): +def test_batch_pipeline(model_name: str, gpu_device: str): # load pre-trained weight with paint with words pipeline pipe = PaintWithWordsPipeline.from_pretrained( model_name, revision="fp16", - torch_dtype=torch.float16, + torch_dtype=th.float16, ) pipe.safety_checker = None # disable the safety checker pipe.to(gpu_device) @@ -260,8 +259,8 @@ def test_batch_pipeline(model_name: str): assert isinstance(pipe.scheduler, LMSDiscreteScheduler), type(pipe.scheduler) # generate latents with seed-fixed generator - generator = torch.manual_seed(0) - latents = torch.randn((1, 4, 64, 64), generator=generator) + generator = th.manual_seed(0) + latents = th.randn((1, 4, 64, 64), generator=generator) latents = latents.repeat(2, 1, 1, 1) # shape: (1, 4, 64, 64) -> (2, 4, 64, 64) color_map_image_1 = EXAMPLE_SETTING_1["color_map_image_path"] @@ -269,15 +268,15 @@ def test_batch_pipeline(model_name: str): with th.autocast("cuda"): images = pipe( - prompt=[ + prompts=[ EXAMPLE_SETTING_1["input_prompt"], EXAMPLE_SETTING_1["input_prompt"], ], - color_context=[ + color_contexts=[ EXAMPLE_SETTING_1["color_context"], EXAMPLE_SETTING_2["color_context"], ], - color_map_image=[ + color_map_images=[ color_map_image_1, color_map_image_2, ], From c92b3c4879f5c71d9f13a6485784bfd8adce24cf Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Thu, 2 Feb 2023 19:08:21 +0900 Subject: [PATCH 06/13] update --- paint_with_words/helper/images.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paint_with_words/helper/images.py b/paint_with_words/helper/images.py index 83364c5..196f89b 100644 --- a/paint_with_words/helper/images.py +++ b/paint_with_words/helper/images.py @@ -59,13 +59,11 @@ def separate_image_context( color_context: Dict[RGB, str], device: str, ) -> List[SeparatedImageContext]: - assert img.width % 32 == 0 and img.height % 32 == 0, img.size separated_image_and_context: List[SeparatedImageContext] = [] for rgb_color, word_with_weight in color_context.items(): - # e.g., # rgb_color: (0, 0, 0) # word_with_weight: cat,1.0 @@ -127,7 +125,6 @@ def calculate_tokens_image_attention_weight( ratio: int, device: str, ) -> th.Tensor: - prompt_token_ids = tokenizer( input_prompt, padding="max_length", @@ -167,4 +164,8 @@ def calculate_tokens_image_attention_weight( f"Warning ratio {ratio} : tokens {context_token_ids} not found in text" ) + # add dimension for the batch + # shape: (w_r * h_r, len(prompt_token_ids)) -> (1, w_r * h_r, len(prompt_token_ids)) + ret_tensor = ret_tensor.unsqueeze(dim=0) + return ret_tensor From 927019f56b436133d3444b2a2bdd7fdcf4afb5bb Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Thu, 2 Feb 2023 19:08:30 +0900 Subject: [PATCH 07/13] update --- paint_with_words/models/attention.py | 31 ++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/paint_with_words/models/attention.py b/paint_with_words/models/attention.py index 6f19662..915cc0e 100644 --- a/paint_with_words/models/attention.py +++ b/paint_with_words/models/attention.py @@ -24,8 +24,6 @@ def paint_with_words_forward( else: context_tensor = hidden_states - # batch_size, sequence_length, _ = hidden_states.shape - query = 
self.to_q(hidden_states) key = self.to_k(context_tensor) @@ -37,6 +35,7 @@ def paint_with_words_forward( key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) + # shape: (batch_size * self.heads, 64 * 64, 77) attention_scores = th.matmul(query, key.transpose(-1, -2)) attention_size_of_img = attention_scores.shape[-2] @@ -53,6 +52,34 @@ def paint_with_words_forward( else: cross_attention_weight = 0.0 + if not isinstance(cross_attention_weight, float): + # shape: (batch_size, 64 * 64, 77) -> (batch_size * self.heads, 64 * 64, 77) + # + # example: + # >>> x = torch.arange(20).reshape(2, 10) + # tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]]) + # + # >>> x.repeat_interleave(2, dim=0) + # tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]]) + # + cross_attention_weight = cross_attention_weight.repeat_interleave( + self.heads, dim=0 + ) + + # Example: + # shape (attention_scores): (16, 4096, 77) + # scores1: (8, 4096, 77), scores2: (8, 4096, 77) + # shape (cross_attention_weights): (2, 4096, 77) + # weights1: (1, 4096, 77), weights2: (1, 4096, 77) + # + # We want to calculate the following: + # scores1 + weights1 + # scores2 + weights2 + attention_scores = (attention_scores + cross_attention_weight) * self.scale attention_probs = attention_scores.softmax(dim=-1) From 552e2ae5e37ef16ff27be9be127d0c64f972233a Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Thu, 2 Feb 2023 19:18:43 +0900 Subject: [PATCH 08/13] update --- .../pipelines/paint_with_words_pipeline.py | 38 ++++++++++--------- .../paint_with_words_pipeline_test.py | 34 ++++++++--------- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/paint_with_words/pipelines/paint_with_words_pipeline.py b/paint_with_words/pipelines/paint_with_words_pipeline.py index 26a1901..0df4e3d 100644 --- a/paint_with_words/pipelines/paint_with_words_pipeline.py +++ b/paint_with_words/pipelines/paint_with_words_pipeline.py @@ -20,7 +20,7 @@ from PIL.Image import Image as PilImage from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from paint_with_words.helper.aliases import RGB, ColorContext, SeparatedImageContext +from paint_with_words.helper.aliases import ColorContext, SeparatedImageContext from paint_with_words.helper.attention import replace_cross_attention from paint_with_words.helper.images import ( calculate_tokens_image_attention_weight, @@ -147,8 +147,6 @@ def calculate_cross_attention_weight( ratio=64, ) - breakpoint() - return { f"cross_attention_weight_{height * width // (8*8)}": cross_attention_weight_8, f"cross_attention_weight_{height * width // (16*16)}": cross_attention_weight_16, @@ -176,7 +174,19 @@ def batch_calculate_cross_attention_weight( if i == 0: cross_attention_weight_dict.update(output_dict) else: - breakpoint() + for weight_key in output_dict.keys(): + tensor_tuple = ( + cross_attention_weight_dict[weight_key], + output_dict[weight_key], + ) + cross_attention_weight_dict[weight_key] = th.cat( + tensor_tuple, dim=0 + ) + + for w_k, w_v in cross_attention_weight_dict.items(): + assert w_v.size(dim=0) == len( + prompts + ), f"Invalid batch dim at {w_k}: {w_v.size(dim=0)} != {len(prompts)}" return cross_attention_weight_dict @@ -259,6 +269,7 @@ def __call__( # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + # 7. Calculate weights for the cross attention cross_attention_weights = self.batch_calculate_cross_attention_weight( prompts=prompts, color_map_images=color_map_images, @@ -274,7 +285,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latents, t) assert latent_model_input.size() == ( - 1, + batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor, @@ -289,25 +300,16 @@ def __call__( # predict the noise residual noise_pred_text = self.unet( - latent_model_input, - t, - # encoder_hidden_states={ - # "context_tensor": cond_embeddings, - # f"cross_attention_weight_{height * width // (8*8)}": cross_attention_weight_8, - # f"cross_attention_weight_{height * width // (16*16)}": cross_attention_weight_16, - # f"cross_attention_weight_{height * width // (32*32)}": cross_attention_weight_32, - # f"cross_attention_weight_{height * width // (64*64)}": cross_attention_weight_64, - # "sigma": sigma, - # "weight_function": weight_function, - # }, + sample=latent_model_input, + timestep=t, encoder_hidden_states=encoder_hidden_states, ).sample latent_model_input = self.scheduler.scale_model_input(latents, t) noise_pred_uncond = self.unet( - latent_model_input, - t, + sample=latent_model_input, + timestep=t, encoder_hidden_states={ "context_tensor": uncond_embeddings, f"cross_attention_weight_{height * width // (8*8)}": 0.0, diff --git a/tests/pipelines/paint_with_words_pipeline_test.py b/tests/pipelines/paint_with_words_pipeline_test.py index 8fa0c16..72ce2f7 100644 --- a/tests/pipelines/paint_with_words_pipeline_test.py +++ b/tests/pipelines/paint_with_words_pipeline_test.py @@ -129,9 +129,9 @@ def test_pipeline( # generate image using the pipeline with th.autocast("cuda"): image = pipe( - prompt=input_prompt, - color_context=color_context, - color_map_image=color_map_image, + prompts=input_prompt, + color_contexts=color_context, + color_map_images=color_map_image, latents=latents, num_inference_steps=30, ).images[0] @@ -247,8 +247,11 @@ def test_calculate_tokens_image_attention_weight( ) +@pytest.mark.skipif( + not th.cuda.is_available(), + reason="No GPUs available for testing.", +) def test_batch_pipeline(model_name: str, gpu_device: str): - # load pre-trained weight with paint with words pipeline pipe = PaintWithWordsPipeline.from_pretrained( model_name, @@ -266,23 +269,20 @@ def test_batch_pipeline(model_name: str, gpu_device: str): latents = th.randn((1, 4, 64, 64), generator=generator) latents = latents.repeat(2, 1, 1, 1) # shape: (1, 4, 64, 64) -> (2, 4, 64, 64) - color_map_image_1 = EXAMPLE_SETTING_1["color_map_image_path"] - color_map_image_2 = EXAMPLE_SETTING_2["color_map_image_path"] + batch_examples = [EXAMPLE_SETTING_1, EXAMPLE_SETTING_3] with th.autocast("cuda"): - images = pipe( - prompts=[ - EXAMPLE_SETTING_1["input_prompt"], - EXAMPLE_SETTING_1["input_prompt"], - ], - color_contexts=[ - EXAMPLE_SETTING_1["color_context"], - EXAMPLE_SETTING_2["color_context"], - ], + pipe_output = pipe( + prompts=[example["input_prompt"] for example in batch_examples], + color_contexts=[example["color_context"] for example in batch_examples], color_map_images=[ - color_map_image_1, - color_map_image_2, + example["color_map_image_path"] for example in batch_examples ], latents=latents, num_inference_steps=30, ) + images = pipe_output.images + + for image, example in zip(images, batch_examples): + content_dir, 
image_filename = example["output_image_path"].split("/") + image.save(f"{content_dir}/batch_{image_filename}") From 7cbfe975ad2c31d704a62b40733cbe9f3eab747f Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Thu, 2 Feb 2023 19:25:16 +0900 Subject: [PATCH 09/13] apply formatter --- paint_with_words/helper/__init__.py | 1 - paint_with_words/helper/attention.py | 1 - tests/helper/aliases_test.py | 1 - tests/helper/images_test.py | 2 -- 4 files changed, 5 deletions(-) diff --git a/paint_with_words/helper/__init__.py b/paint_with_words/helper/__init__.py index 139597f..8b13789 100644 --- a/paint_with_words/helper/__init__.py +++ b/paint_with_words/helper/__init__.py @@ -1,2 +1 @@ - diff --git a/paint_with_words/helper/attention.py b/paint_with_words/helper/attention.py index 19034aa..08b3203 100644 --- a/paint_with_words/helper/attention.py +++ b/paint_with_words/helper/attention.py @@ -6,7 +6,6 @@ def replace_cross_attention( unet: UNet2DConditionModel, cross_attention_name: str = "CrossAttention" ) -> None: - for m in unet.modules(): if m.__class__.__name__ == cross_attention_name: m.__class__.__call__ = paint_with_words_forward diff --git a/tests/helper/aliases_test.py b/tests/helper/aliases_test.py index a6edbf4..af5e9f6 100644 --- a/tests/helper/aliases_test.py +++ b/tests/helper/aliases_test.py @@ -19,7 +19,6 @@ def test_paint_with_words_hidden_states(): def test_separated_image_context(): - separated_image_context = SeparatedImageContext( word="cat", token_ids=[2368], diff --git a/tests/helper/images_test.py b/tests/helper/images_test.py index ea5e528..90741fe 100644 --- a/tests/helper/images_test.py +++ b/tests/helper/images_test.py @@ -55,7 +55,6 @@ def test_resize_image(): def test_separate_image_context( model_name: str, color_context: Dict[RGB, str], color_map_image_path: str ): - tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer") color_map_image = load_image(color_map_image_path) @@ -88,7 +87,6 @@ def test_calculate_tokens_image_attention_weight( color_map_image_path: str, input_prompt: str, ): - tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer") color_map_image = load_image(color_map_image_path) From 57f0bc1b7a3e2109064a7a2bf3909555f85016ce Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Thu, 2 Feb 2023 19:25:53 +0900 Subject: [PATCH 10/13] update --- tests/pipelines/paint_with_words_pipeline_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/paint_with_words_pipeline_test.py b/tests/pipelines/paint_with_words_pipeline_test.py index 72ce2f7..fb7dea3 100644 --- a/tests/pipelines/paint_with_words_pipeline_test.py +++ b/tests/pipelines/paint_with_words_pipeline_test.py @@ -284,5 +284,5 @@ def test_batch_pipeline(model_name: str, gpu_device: str): images = pipe_output.images for image, example in zip(images, batch_examples): - content_dir, image_filename = example["output_image_path"].split("/") + content_dir, image_filename = example["output_image_path"].split("/") # type: ignore image.save(f"{content_dir}/batch_{image_filename}") From aefb37d6ce0735bac45fd85cd79910098d071037 Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Thu, 2 Feb 2023 19:29:09 +0900 Subject: [PATCH 11/13] fix for formatting --- paint_with_words/helper/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paint_with_words/helper/__init__.py b/paint_with_words/helper/__init__.py index 8b13789..e69de29 100644 --- a/paint_with_words/helper/__init__.py +++ b/paint_with_words/helper/__init__.py @@ -1 +0,0 @@ - 
From 3f0d8b3ae76911986a305679eb66af745475f017 Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Thu, 2 Feb 2023 19:29:32 +0900 Subject: [PATCH 12/13] remove unused import --- paint_with_words/pipelines/paint_with_words_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paint_with_words/pipelines/paint_with_words_pipeline.py b/paint_with_words/pipelines/paint_with_words_pipeline.py index 0df4e3d..41aa876 100644 --- a/paint_with_words/pipelines/paint_with_words_pipeline.py +++ b/paint_with_words/pipelines/paint_with_words_pipeline.py @@ -1,6 +1,6 @@ import logging import os -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, List, Optional, Union import torch as th from diffusers.models import AutoencoderKL, UNet2DConditionModel From 42b379aa05a8c4c4b78ff7be9157454ac8eb0211 Mon Sep 17 00:00:00 2001 From: Shunsuke KITADA Date: Thu, 2 Feb 2023 19:34:02 +0900 Subject: [PATCH 13/13] remove file --- tests/helper/attention_test.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 tests/helper/attention_test.py diff --git a/tests/helper/attention_test.py b/tests/helper/attention_test.py deleted file mode 100644 index 987871e..0000000 --- a/tests/helper/attention_test.py +++ /dev/null @@ -1,34 +0,0 @@ -import copy - -import pytest -from diffusers.models import UNet2DConditionModel - -from paint_with_words.helper.attention import replace_cross_attention - - -@pytest.fixture -def model_name() -> str: - return "CompVis/stable-diffusion-v1-4" - - -def test_replace_cross_attention( - model_name: str, cross_attention_name: str = "CrossAttention" -): - unet_original = UNet2DConditionModel.from_pretrained( - model_name, subfolder="unet", revision="fp16" - ) - - unet_proposed = copy.deepcopy(unet_original) - # unet_proposed = UNet2DConditionModel.from_pretrained( - # model_name, subfolder="unet", revision="fp16" - # ) - unet_proposed = replace_cross_attention(unet=unet_proposed) - - for m_orig, m_prop in zip(unet_original.modules(), unet_proposed.modules()): - cond1 = m_orig.__class__.__name__ == cross_attention_name - cond2 = m_prop.__class__.__name__ == cross_attention_name - - if cond1 and cond2: - breakpoint() - assert m_orig.__class__.__call__.__name__ == "_call_impl" - assert m_prop.__class__.__call__.__name__ == "paint_with_words_forward"
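
Note: with the attention test removed in the final patch, the class-level patching that replace_cross_attention relies on is worth pinning down. The toy sketch below is independent of diffusers, with a hypothetical Toy class standing in for CrossAttention: because Python looks up dunder methods on the type, assigning to __class__.__call__ reroutes every instance of the class at once, which is what the deleted test asserted by comparing __call__.__name__ before ("_call_impl", inherited from torch.nn.Module) and after ("paint_with_words_forward").

class Toy:  # hypothetical stand-in for diffusers' CrossAttention
    pass


def paint_with_words_forward(self):  # stand-in for the real replacement forward
    return "patched"


a, b = Toy(), Toy()
Toy.__call__ = paint_with_words_forward  # same effect as m.__class__.__call__ = ...
assert a() == "patched" and b() == "patched"  # all instances change together
assert Toy.__call__.__name__ == "paint_with_words_forward"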
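
Note: end to end, the series moves the pipeline's public entry point from singular arguments (prompt, color_context, color_map_image) to batched ones. A minimal usage sketch against the final state of the series follows. The import path is inferred from the package layout, the prompts, color contexts, and color-map paths are illustrative placeholders, and the fp16/CUDA setup simply mirrors test_batch_pipeline.

import torch as th

from paint_with_words.pipelines.paint_with_words_pipeline import (
    PaintWithWordsPipeline,
)

pipe = PaintWithWordsPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # model name used by the test fixtures
    revision="fp16",
    torch_dtype=th.float16,
)
pipe.safety_checker = None  # the tests disable the safety checker
pipe.to("cuda")

# one RGB -> "word,weight" color context and one color map (PIL image or
# path) per prompt, as parsed by separate_image_context
prompts = ["A cat sitting on a lawn", "A dog sitting on a lawn"]  # placeholders
color_contexts = [
    {(0, 0, 0): "cat,1.0"},  # placeholder contexts
    {(0, 0, 0): "dog,1.0"},
]
# hypothetical paths; the pipeline asserts the maps resize to 512x512
color_map_images = ["cat_color_map.png", "dog_color_map.png"]

# seed-fixed latents repeated across the batch, as in test_batch_pipeline
generator = th.manual_seed(0)
latents = th.randn((1, 4, 64, 64), generator=generator).repeat(len(prompts), 1, 1, 1)

with th.autocast("cuda"):
    output = pipe(
        prompts=prompts,
        color_contexts=color_contexts,
        color_map_images=color_map_images,
        latents=latents,
        num_inference_steps=30,
    )

for i, image in enumerate(output.images):
    image.save(f"batch_{i}.png")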