diff --git a/paint_with_words/helper/__Init__.py b/paint_with_words/helper/__init__.py
similarity index 100%
rename from paint_with_words/helper/__Init__.py
rename to paint_with_words/helper/__init__.py
diff --git a/paint_with_words/helper/aliases.py b/paint_with_words/helper/aliases.py
index ef8826c..987bf17 100644
--- a/paint_with_words/helper/aliases.py
+++ b/paint_with_words/helper/aliases.py
@@ -1,9 +1,19 @@
-from typing import Dict, Tuple, Union
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
 
 import torch
+import torch as th
 
 from paint_with_words.weight_functions import WeightFunction
 
 RGB = Tuple[int, int, int]
 
 PaintWithWordsHiddenStates = Dict[str, Union[torch.Tensor, WeightFunction]]
+ColorContext = Dict[RGB, str]
+
+
+@dataclass
+class SeparatedImageContext(object):
+    word: str
+    token_ids: List[int]
+    color_map_th: th.Tensor
diff --git a/paint_with_words/helper/attention.py b/paint_with_words/helper/attention.py
new file mode 100644
index 0000000..08b3203
--- /dev/null
+++ b/paint_with_words/helper/attention.py
@@ -0,0 +1,11 @@
+from diffusers.models import UNet2DConditionModel
+
+from paint_with_words.models.attention import paint_with_words_forward
+
+
+def replace_cross_attention(
+    unet: UNet2DConditionModel, cross_attention_name: str = "CrossAttention"
+) -> None:
+    for m in unet.modules():
+        if m.__class__.__name__ == cross_attention_name:
+            m.__class__.__call__ = paint_with_words_forward
diff --git a/paint_with_words/helper/images.py b/paint_with_words/helper/images.py
index fac0956..196f89b 100644
--- a/paint_with_words/helper/images.py
+++ b/paint_with_words/helper/images.py
@@ -1,9 +1,27 @@
-from typing import Tuple
+import logging
+import os
+from typing import Dict, List, Tuple, Union
 
+import numpy as np
 import torch as th
 import torch.nn.functional as F
+from PIL import Image
 from PIL.Image import Image as PilImage
 from PIL.Image import Resampling
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+from paint_with_words.helper.aliases import RGB, SeparatedImageContext
+
+logger = logging.getLogger(__name__)
+
+
+def load_image(image: Union[str, os.PathLike, PilImage]) -> PilImage:
+    if isinstance(image, str) or isinstance(image, os.PathLike):
+        image = Image.open(image)
+
+    if image.mode != "RGB":
+        image = image.convert("RGB")
+    return image
 
 
 def get_resize_size(img: PilImage) -> Tuple[int, int]:
@@ -33,3 +51,121 @@ def flatten_image_importance(img_th: th.Tensor, ratio: int) -> th.Tensor:
     ret = ret.squeeze()
 
     return ret
+
+
+def separate_image_context(
+    tokenizer: PreTrainedTokenizer,
+    img: PilImage,
+    color_context: Dict[RGB, str],
+    device: str,
+) -> List[SeparatedImageContext]:
+    assert img.width % 32 == 0 and img.height % 32 == 0, img.size
+
+    separated_image_and_context: List[SeparatedImageContext] = []
+
+    for rgb_color, word_with_weight in color_context.items():
+        # e.g.,
+        # rgb_color: (0, 0, 0)
+        # word_with_weight: cat,1.0
+
+        # cat,1.0 -> ["cat", "1.0"]
+        word_and_weight = word_with_weight.split(",")
+        # ["cat", "1.0"] -> 1.0
+        word_weight = float(word_and_weight[-1])
+        # ["cat", "1.0"] -> cat
+        word = ",".join(word_and_weight[:-1])
+
+        logger.info(
+            f"input = {word_with_weight}; word = {word}; weight = {word_weight}"
+        )
+
+        word_input = tokenizer(
+            word,
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        word_as_tokens = word_input["input_ids"]
+
+        img_where_color_np = (np.array(img) == rgb_color).all(axis=-1)
+        if not img_where_color_np.sum() > 0:
+            logger.warning(
+                f"Warning: not a single pixel with color {rgb_color} found in image"
+            )
+
+        img_where_color_th = th.tensor(
+            img_where_color_np,
+            dtype=th.float32,
+            device=device,
+        )
+        img_where_color_th = img_where_color_th * word_weight
+
+        image_context = SeparatedImageContext(
+            word=word,
+            token_ids=word_as_tokens,
+            color_map_th=img_where_color_th,
+        )
+        separated_image_and_context.append(image_context)
+
+    if len(separated_image_and_context) == 0:
+        image_context = SeparatedImageContext(
+            word="",
+            token_ids=[-1],
+            color_map_th=th.zeros((img.width, img.height), dtype=th.float32),
+        )
+        separated_image_and_context.append(image_context)
+
+    return separated_image_and_context
+
+
+def calculate_tokens_image_attention_weight(
+    tokenizer: PreTrainedTokenizer,
+    input_prompt: str,
+    separated_image_context_list: List[SeparatedImageContext],
+    ratio: int,
+    device: str,
+) -> th.Tensor:
+    prompt_token_ids = tokenizer(
+        input_prompt,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+    ).input_ids
+
+    w, h = separated_image_context_list[0].color_map_th.shape
+    w_r, h_r = w // ratio, h // ratio
+
+    ret_tensor = th.zeros(
+        (w_r * h_r, len(prompt_token_ids)), dtype=th.float32, device=device
+    )
+
+    for separated_image_context in separated_image_context_list:
+        is_in = False
+        context_token_ids = separated_image_context.token_ids
+        context_image_map = separated_image_context.color_map_th
+
+        for i, token_id in enumerate(prompt_token_ids):
+            if prompt_token_ids[i : i + len(context_token_ids)] == context_token_ids:
+                is_in = True
+
+                # shape: (w * 1/ratio, h * 1/ratio)
+                img_importance = flatten_image_importance(
+                    img_th=context_image_map, ratio=ratio
+                )
+                # shape: ((w * 1/ratio) * (h * 1/ratio), 1)
+                img_importance = img_importance.view(-1, 1)
+                # shape: ((w * 1/ratio) * (h * 1/ratio), len(context_token_ids))
+                img_importance = img_importance.repeat(1, len(context_token_ids))
+
+                ret_tensor[:, i : i + len(context_token_ids)] += img_importance
+
+        if not is_in:
+            logger.warning(
+                f"Warning ratio {ratio} : tokens {context_token_ids} not found in text"
+            )
+
+    # add dimension for the batch
+    # shape: (w_r * h_r, len(prompt_token_ids)) -> (1, w_r * h_r, len(prompt_token_ids))
+    ret_tensor = ret_tensor.unsqueeze(dim=0)
+
+    return ret_tensor
diff --git a/paint_with_words/models/attention.py b/paint_with_words/models/attention.py
index 6f19662..915cc0e 100644
--- a/paint_with_words/models/attention.py
+++ b/paint_with_words/models/attention.py
@@ -24,8 +24,6 @@ def paint_with_words_forward(
     else:
         context_tensor = hidden_states
 
-    # batch_size, sequence_length, _ = hidden_states.shape
-
     query = self.to_q(hidden_states)
 
     key = self.to_k(context_tensor)
@@ -37,6 +35,7 @@ def paint_with_words_forward(
     key = self.reshape_heads_to_batch_dim(key)
     value = self.reshape_heads_to_batch_dim(value)
 
+    # shape: (batch_size * self.heads, 64 * 64, 77)
     attention_scores = th.matmul(query, key.transpose(-1, -2))
 
     attention_size_of_img = attention_scores.shape[-2]
@@ -53,6 +52,34 @@ def paint_with_words_forward(
     else:
         cross_attention_weight = 0.0
 
+    if not isinstance(cross_attention_weight, float):
+        # shape: (batch_size, 64 * 64, 77) -> (batch_size * self.heads, 64 * 64, 77)
+        #
+        # example:
+        # >>> x = torch.arange(20).reshape(2, 10)
+        # tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
+        #         [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])
+        #
+        # >>> x.repeat_interleave(2, dim=0)
+        # tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
+        #         [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
+        #         [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
+        #         [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])
+        #
+        cross_attention_weight = cross_attention_weight.repeat_interleave(
+            self.heads, dim=0
+        )
+
+        # Example:
+        # shape (attention_scores): (16, 4096, 77)
+        #   scores1: (8, 4096, 77), scores2: (8, 4096, 77)
+        # shape (cross_attention_weights): (2, 4096, 77)
+        #   weights1: (1, 4096, 77), weights2: (1, 4096, 77)
+        #
+        # We want to calculate the following:
+        #   scores1 + weights1
+        #   scores2 + weights2
+
     attention_scores = (attention_scores + cross_attention_weight) * self.scale
 
     attention_probs = attention_scores.softmax(dim=-1)
diff --git a/paint_with_words/pipelines/paint_with_words_pipeline.py b/paint_with_words/pipelines/paint_with_words_pipeline.py
index fdb8d7d..41aa876 100644
--- a/paint_with_words/pipelines/paint_with_words_pipeline.py
+++ b/paint_with_words/pipelines/paint_with_words_pipeline.py
@@ -1,9 +1,7 @@
 import logging
 import os
-from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
 
-import numpy as np
 import torch as th
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import (
@@ -19,17 +17,18 @@
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from PIL import Image
 from PIL.Image import Image as PilImage
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
-from paint_with_words.helper.aliases import RGB
+from paint_with_words.helper.aliases import ColorContext, SeparatedImageContext
+from paint_with_words.helper.attention import replace_cross_attention
 from paint_with_words.helper.images import (
-    flatten_image_importance,
+    calculate_tokens_image_attention_weight,
     get_resize_size,
+    load_image,
     resize_image,
+    separate_image_context,
 )
-from paint_with_words.models.attention import paint_with_words_forward
 from paint_with_words.weight_functions import (
     PaintWithWordsWeightFunction,
     UnconditionedWeightFunction,
@@ -39,13 +38,6 @@
 logger = logging.getLogger(__name__)
 
 
-@dataclass
-class SeparatedImageContext(object):
-    word: str
-    token_ids: List[int]
-    color_map_th: th.Tensor
-
-
 class PaintWithWordsPipeline(StableDiffusionPipeline):
     def __init__(
         self,
@@ -86,68 +78,20 @@ def __init__(
     def replace_cross_attention(
         self, cross_attention_name: str = "CrossAttention"
     ) -> None:
-        for m in self.unet.modules():
-            if m.__class__.__name__ == cross_attention_name:
-                m.__class__.__call__ = paint_with_words_forward
-
-    def separate_image_context(self, img: PilImage, color_context: Dict[RGB, str]):
-        assert img.width % 32 == 0 and img.height % 32 == 0, img.size
-
-        separated_image_and_context: List[SeparatedImageContext] = []
-
-        for rgb_color, word_with_weight in color_context.items():
-            # e.g.,
-            # rgb_color: (0, 0, 0)
-            # word_with_weight: cat,1.0
-
-            # cat,1.0 -> ["cat", "1.0"]
-            word_and_weight = word_with_weight.split(",")
-            # ["cat", "1.0"] -> 1.0
-            word_weight = float(word_and_weight[-1])
-            # ["cat", "1.0"] -> cat
-            word = ",".join(word_and_weight[:-1])
-
-            logger.info(
-                f"input = {word_with_weight}; word = {word}; weight = {word_weight}"
-            )
-
-            word_input = self.tokenizer(
-                word,
-                max_length=self.tokenizer.model_max_length,
-                truncation=True,
-                add_special_tokens=False,
-            )
-            word_as_tokens = word_input["input_ids"]
-
-            img_where_color_np = (np.array(img) == rgb_color).all(axis=-1)
-            if not img_where_color_np.sum() > 0:
-                logger.warning(
-                    f"Warning : not a single color {rgb_color} not found in image"
-                )
-
-            img_where_color_th = th.tensor(
-                img_where_color_np,
-                dtype=th.float32,
-                device=self.device,
-            )
-            img_where_color_th = img_where_color_th * word_weight
-
-            image_context = SeparatedImageContext(
-                word=word,
-                token_ids=word_as_tokens,
-                color_map_th=img_where_color_th,
-            )
-            separated_image_and_context.append(image_context)
-
-        if len(separated_image_and_context) == 0:
-            image_context = SeparatedImageContext(
-                word="",
-                token_ids=[-1],
-                color_map_th=th.zeros((img.width, img.height), dtype=th.float32),
-            )
-            separated_image_and_context.append(image_context)
+        replace_cross_attention(
+            unet=self.unet,
+            cross_attention_name=cross_attention_name,
+        )
 
-        return separated_image_and_context
+    def separate_image_context(
+        self, img: PilImage, color_context: ColorContext
+    ) -> List[SeparatedImageContext]:
+        return separate_image_context(
+            tokenizer=self.tokenizer,
+            img=img,
+            color_context=color_context,
+            device=self.device,
+        )
 
     def calculate_tokens_image_attention_weight(
         self,
@@ -155,93 +99,30 @@ def __call__(
         separated_image_context_list: List[SeparatedImageContext],
         ratio: int,
     ) -> th.Tensor:
-        prompt_token_ids = self.tokenizer(
-            input_prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-        ).input_ids
-
-        w, h = separated_image_context_list[0].color_map_th.shape
-        w_r, h_r = w // ratio, h // ratio
-
-        ret_tensor = th.zeros(
-            (w_r * h_r, len(prompt_token_ids)), dtype=th.float32, device=self.device
+        return calculate_tokens_image_attention_weight(
+            tokenizer=self.tokenizer,
+            input_prompt=input_prompt,
+            separated_image_context_list=separated_image_context_list,
+            ratio=ratio,
+            device=self.device,
         )
 
-        for separated_image_context in separated_image_context_list:
-            is_in = False
-            context_token_ids = separated_image_context.token_ids
-            context_image_map = separated_image_context.color_map_th
-
-            for i, token_id in enumerate(prompt_token_ids):
-                if (
-                    prompt_token_ids[i : i + len(context_token_ids)]
-                    == context_token_ids
-                ):
-                    is_in = True
-
-                    # shape: (w * 1/ratio, h * 1/ratio)
-                    img_importance = flatten_image_importance(
-                        img_th=context_image_map, ratio=ratio
-                    )
-                    # shape: ((w * 1/ratio) * (h * 1/ratio), 1)
-                    img_importance = img_importance.view(-1, 1)
-                    # shape: ((w * 1/ratio) * (h * 1/ratio), len(context_token_ids))
-                    img_importance = img_importance.repeat(1, len(context_token_ids))
-
-                    ret_tensor[:, i : i + len(context_token_ids)] += img_importance
-
-            if not is_in:
-                logger.warning(
-                    f"Warning ratio {ratio} : tokens {context_token_ids} not found in text"
-                )
-
-        # add dimension for the batch
-        # shape: (w_r * h_r, len(prompt_token_ids)) -> (1, w_r * h_r, len(prompt_token_ids))
-        ret_tensor = ret_tensor.unsqueeze(dim=0)
-
-        return ret_tensor
-
-    def load_image(self, image: Union[str, os.PathLike, PilImage]) -> PilImage:
-        if isinstance(image, str) or isinstance(image, os.PathLike):
-            image = Image.open(image)
-
-        if image.mode != "RGB":
-            image = image.convert("RGB")
-
-        return image
-
-    @th.no_grad()
-    def __call__(
+    def calculate_cross_attention_weight(
         self,
         prompt: str,
-        color_context: Dict[RGB, str],
-        color_map_image: PilImage,
-        weight_function: WeightFunction = PaintWithWordsWeightFunction(),
-        num_inference_steps: int = 50,
-        guidance_scale: float = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: int = 1,
-        eta: float = 0,
-        generator: Optional[th.Generator] = None,
-        latents: Optional[th.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        callback: Optional[Callable[[int, int, th.FloatTensor], None]] = None,
-        callback_steps: Optional[int] = 1,
-    ) -> StableDiffusionPipelineOutput:
-        assert isinstance(prompt, str), type(prompt)
-        assert guidance_scale > 1.0, guidance_scale
-
+        color_map_image: Union[str, os.PathLike, PilImage],
+        color_context: ColorContext,
+    ):
         # 0. Default height and width to unet and resize the color map image
-        color_map_image = self.load_image(image=color_map_image)
+        color_map_image = load_image(image=color_map_image)
         width, height = get_resize_size(img=color_map_image)
+        assert width == 512 and height == 512
         color_map_image = resize_image(img=color_map_image, w=width, h=height)
 
         separated_image_context_list = self.separate_image_context(
            img=color_map_image, color_context=color_context
         )
+
         cross_attention_weight_8 = self.calculate_tokens_image_attention_weight(
             input_prompt=prompt,
             separated_image_context_list=separated_image_context_list,
             ratio=8,
@@ -266,11 +147,86 @@ def __call__(
             ratio=64,
         )
 
+        return {
+            f"cross_attention_weight_{height * width // (8*8)}": cross_attention_weight_8,
+            f"cross_attention_weight_{height * width // (16*16)}": cross_attention_weight_16,
+            f"cross_attention_weight_{height * width // (32*32)}": cross_attention_weight_32,
+            f"cross_attention_weight_{height * width // (64*64)}": cross_attention_weight_64,
+        }
+
+    def batch_calculate_cross_attention_weight(
+        self,
+        prompts: List[str],
+        color_map_images: Union[List[str], List[os.PathLike], List[PilImage]],
+        color_contexts: List[ColorContext],
+    ):
+        assert len(prompts) == len(color_map_images) == len(color_contexts)
+        it = zip(prompts, color_map_images, color_contexts)
+
+        cross_attention_weight_dict = {}
+
+        for i, (prompt, color_map_image, color_context) in enumerate(it):
+            output_dict = self.calculate_cross_attention_weight(
+                prompt=prompt,
+                color_map_image=color_map_image,
+                color_context=color_context,
+            )
+            if i == 0:
+                cross_attention_weight_dict.update(output_dict)
+            else:
+                for weight_key in output_dict.keys():
+                    tensor_tuple = (
+                        cross_attention_weight_dict[weight_key],
+                        output_dict[weight_key],
+                    )
+                    cross_attention_weight_dict[weight_key] = th.cat(
+                        tensor_tuple, dim=0
+                    )
+
+        for w_k, w_v in cross_attention_weight_dict.items():
+            assert w_v.size(dim=0) == len(
+                prompts
+            ), f"Invalid batch dim at {w_k}: {w_v.size(dim=0)} != {len(prompts)}"
+
+        return cross_attention_weight_dict
+
+    @th.no_grad()
+    def __call__(
+        self,
+        prompts: Union[str, List[str]],
+        color_contexts: Union[ColorContext, List[ColorContext]],
+        color_map_images: Union[
+            PilImage, List[PilImage], str, List[str], os.PathLike, List[os.PathLike]
+        ],
+        weight_function: WeightFunction = PaintWithWordsWeightFunction(),
+        height: int = 512,
+        width: int = 512,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: int = 1,
+        eta: float = 0,
+        generator: Optional[th.Generator] = None,
+        latents: Optional[th.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, th.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+    ) -> StableDiffusionPipelineOutput:
+        if not isinstance(prompts, list):
+            prompts = [prompts]
+        if not isinstance(color_contexts, list):
+            color_contexts = [color_contexts]
+        if not isinstance(color_map_images, list):
+            color_map_images = [color_map_images]
+
+        assert guidance_scale > 1.0, guidance_scale
+
         # 1. Check inputs. Raise error if not correct
-        self.check_inputs(prompt, height, width, callback_steps)
+        self.check_inputs(prompts, height, width, callback_steps)
 
         # 2. Define call parameters
-        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        batch_size = 1 if isinstance(prompts, str) else len(prompts)
         device = self._execution_device
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -279,15 +235,18 @@ def __call__(
 
         # 3. Encode input prompt
         text_embeddings = self._encode_prompt(
-            prompt,
+            prompts,
             device,
             num_images_per_prompt,
             do_classifier_free_guidance,
             negative_prompt,
         )
         # Ensure classifier free guidance is performed and
-        # the batch size of the text embedding is 2 (conditional + unconditional)
-        assert do_classifier_free_guidance and text_embeddings.size(dim=0) == 2
+        # the batch size of the text embedding is batch_size * 2 (conditional + unconditional)
+        assert (
+            do_classifier_free_guidance
+            and text_embeddings.size(dim=0) == batch_size * 2
+        )
         uncond_embeddings, cond_embeddings = text_embeddings.chunk(2)
 
         # 4. Prepare timesteps
@@ -310,6 +269,13 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
+        # 7. Calculate weights for the cross attention
+        cross_attention_weights = self.batch_calculate_cross_attention_weight(
+            prompts=prompts,
+            color_map_images=color_map_images,
+            color_contexts=color_contexts,
+        )
+
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -319,32 +285,31 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latents, t)
 
                 assert latent_model_input.size() == (
-                    1,
+                    batch_size,
                     num_channels_latents,
                     height // self.vae_scale_factor,
                     width // self.vae_scale_factor,
                 )
 
+                encoder_hidden_states = {
+                    "context_tensor": cond_embeddings,
+                    "sigma": sigma,
+                    "weight_function": weight_function,
+                }
+                encoder_hidden_states.update(cross_attention_weights)
+
                 # predict the noise residual
                 noise_pred_text = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states={
-                        "context_tensor": cond_embeddings,
-                        f"cross_attention_weight_{height * width // (8*8)}": cross_attention_weight_8,
-                        f"cross_attention_weight_{height * width // (16*16)}": cross_attention_weight_16,
-                        f"cross_attention_weight_{height * width // (32*32)}": cross_attention_weight_32,
-                        f"cross_attention_weight_{height * width // (64*64)}": cross_attention_weight_64,
-                        "sigma": sigma,
-                        "weight_function": weight_function,
-                    },
+                    sample=latent_model_input,
+                    timestep=t,
+                    encoder_hidden_states=encoder_hidden_states,
                 ).sample
 
                 latent_model_input = self.scheduler.scale_model_input(latents, t)
                 noise_pred_uncond = self.unet(
-                    latent_model_input,
-                    t,
+                    sample=latent_model_input,
+                    timestep=t,
                     encoder_hidden_states={
                         "context_tensor": uncond_embeddings,
                         f"cross_attention_weight_{height * width // (8*8)}": 0.0,
diff --git a/tests/helper/aliases_test.py b/tests/helper/aliases_test.py
index 7515f2c..af5e9f6 100644
--- a/tests/helper/aliases_test.py
+++ b/tests/helper/aliases_test.py
@@ -2,7 +2,11 @@
 
 import torch as th
 
-from paint_with_words.helper.aliases import RGB, PaintWithWordsHiddenStates
+from paint_with_words.helper.aliases import (
+    RGB,
+    PaintWithWordsHiddenStates,
+    SeparatedImageContext,
+)
 from paint_with_words.weight_functions import WeightFunction
 
 
@@ -12,3 +16,12 @@ def test_rgb():
 
 def test_paint_with_words_hidden_states():
     assert PaintWithWordsHiddenStates == Dict[str, Union[th.Tensor, WeightFunction]]
+
+
+def test_separated_image_context():
+    separated_image_context = SeparatedImageContext(
+        word="cat",
+        token_ids=[2368],
+        color_map_th=th.zeros((512, 512)),
+    )
+    assert isinstance(separated_image_context, SeparatedImageContext)
diff --git a/tests/helper/images_test.py b/tests/helper/images_test.py
index d21323f..90741fe 100644
--- a/tests/helper/images_test.py
+++ b/tests/helper/images_test.py
@@ -1,7 +1,18 @@
+from typing import Dict
+
 import pytest
+import torch as th
 from PIL import Image
+from transformers import CLIPTokenizer
 
-from paint_with_words.helper.images import get_resize_size, resize_image
+from paint_with_words.helper.aliases import RGB, SeparatedImageContext
+from paint_with_words.helper.images import (
+    calculate_tokens_image_attention_weight,
+    get_resize_size,
+    load_image,
+    resize_image,
+    separate_image_context,
+)
 
 
 @pytest.fixture
@@ -39,3 +50,99 @@ def test_resize_image():
         w=129,
         h=129,
     )
+
+
+def test_separate_image_context(
+    model_name: str, color_context: Dict[RGB, str], color_map_image_path: str
+):
+    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
+
+    color_map_image = load_image(color_map_image_path)
+
+    ret_list = separate_image_context(
+        tokenizer=tokenizer,
+        img=color_map_image,
+        color_context=color_context,
+        device="cpu",
+    )
+
+    for ret in ret_list:
+        assert isinstance(ret, SeparatedImageContext)
+        assert isinstance(ret.word, str)
+        assert isinstance(ret.token_ids, list)
+        assert isinstance(ret.color_map_th, th.Tensor)
+
+        token_ids = tokenizer(
+            ret.word,
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            add_special_tokens=False,
+        ).input_ids
+        assert ret.token_ids == token_ids
+
+
+def test_calculate_tokens_image_attention_weight(
+    model_name: str,
+    color_context: Dict[RGB, str],
+    color_map_image_path: str,
+    input_prompt: str,
+):
+    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
+
+    color_map_image = load_image(color_map_image_path)
+    w, h = color_map_image.size
+
+    separated_image_context_list = separate_image_context(
+        tokenizer=tokenizer,
+        img=color_map_image,
+        color_context=color_context,
+        device="cpu",
+    )
+
+    cross_attention_weight_8 = calculate_tokens_image_attention_weight(
+        tokenizer=tokenizer,
+        input_prompt=input_prompt,
+        separated_image_context_list=separated_image_context_list,
+        ratio=8,
+        device="cpu",
+    )
+    assert cross_attention_weight_8.size() == (
+        int((w * 1 / 8) * (h * 1 / 8)),
+        tokenizer.model_max_length,
+    )
+
+    cross_attention_weight_16 = calculate_tokens_image_attention_weight(
+        tokenizer=tokenizer,
+        input_prompt=input_prompt,
+        separated_image_context_list=separated_image_context_list,
+        ratio=16,
+        device="cpu",
+    )
+    assert cross_attention_weight_16.size() == (
+        int((w * 1 / 16) * (h * 1 / 16)),
+        tokenizer.model_max_length,
+    )
+
+    cross_attention_weight_32 = calculate_tokens_image_attention_weight(
+        tokenizer=tokenizer,
+        input_prompt=input_prompt,
+        separated_image_context_list=separated_image_context_list,
+        ratio=32,
+        device="cpu",
+    )
+    assert cross_attention_weight_32.size() == (
+        int((w * 1 / 32) * (h * 1 / 32)),
+        tokenizer.model_max_length,
+    )
+
+    cross_attention_weight_64 = calculate_tokens_image_attention_weight(
+        tokenizer=tokenizer,
+        input_prompt=input_prompt,
+        separated_image_context_list=separated_image_context_list,
+        ratio=64,
+        device="cpu",
+    )
+    assert cross_attention_weight_64.size() == (
+        int((w * 1 / 64) * (h * 1 / 64)),
+        tokenizer.model_max_length,
+    )
diff --git a/tests/pipelines/paint_with_words_pipeline_test.py b/tests/pipelines/paint_with_words_pipeline_test.py
index 41df741..fb7dea3 100644
--- a/tests/pipelines/paint_with_words_pipeline_test.py
+++ b/tests/pipelines/paint_with_words_pipeline_test.py
@@ -1,7 +1,6 @@
 from typing import Dict
 
 import pytest
-import torch
 import torch as th
 from diffusers.schedulers import LMSDiscreteScheduler
 from PIL import Image
@@ -112,7 +111,7 @@ def test_pipeline(
     pipe = PaintWithWordsPipeline.from_pretrained(
         model_name,
         revision="fp16",
-        torch_dtype=torch.float16,
+        torch_dtype=th.float16,
     )
     pipe.safety_checker = None  # disable the safety checker
     pipe.to(gpu_device)
@@ -121,8 +120,8 @@
     assert isinstance(pipe.scheduler, LMSDiscreteScheduler), type(pipe.scheduler)
 
     # generate latents with seed-fixed generator
-    generator = torch.manual_seed(0)
-    latents = torch.randn((1, 4, 64, 64), generator=generator)
+    generator = th.manual_seed(0)
+    latents = th.randn((1, 4, 64, 64), generator=generator)
 
     # load color map image
     color_map_image = Image.open(color_map_image_path)
@@ -130,9 +129,9 @@
     # generate image using the pipeline
     with th.autocast("cuda"):
         image = pipe(
-            prompt=input_prompt,
-            color_context=color_context,
-            color_map_image=color_map_image,
+            prompts=input_prompt,
+            color_contexts=color_context,
+            color_map_images=color_map_image,
             latents=latents,
             num_inference_steps=30,
         ).images[0]
@@ -155,7 +154,7 @@ def test_separate_image_context(
     pipe = PaintWithWordsPipeline.from_pretrained(
         model_name,
         revision="fp16",
-        torch_dtype=torch.float16,
+        torch_dtype=th.float16,
     )
 
     color_map_image = pipe.load_image(color_map_image_path)
@@ -168,7 +167,7 @@ def test_separate_image_context(
         assert isinstance(ret, SeparatedImageContext)
         assert isinstance(ret.word, str)
         assert isinstance(ret.token_ids, list)
-        assert isinstance(ret.color_map_th, torch.Tensor)
+        assert isinstance(ret.color_map_th, th.Tensor)
 
         token_ids = pipe.tokenizer(
             ret.word,
@@ -193,7 +192,7 @@ def test_calculate_tokens_image_attention_weight(
     pipe = PaintWithWordsPipeline.from_pretrained(
         model_name,
         revision="fp16",
-        torch_dtype=torch.float16,
+        torch_dtype=th.float16,
     )
 
     color_map_image = pipe.load_image(color_map_image_path)
@@ -246,3 +245,44 @@ def test_calculate_tokens_image_attention_weight(
         int((w * 1 / 64) * (h * 1 / 64)),
         pipe.tokenizer.model_max_length,
     )
+
+
+@pytest.mark.skipif(
+    not th.cuda.is_available(),
+    reason="No GPUs available for testing.",
+)
+def test_batch_pipeline(model_name: str, gpu_device: str):
+    # load pre-trained weight with paint with words pipeline
+    pipe = PaintWithWordsPipeline.from_pretrained(
+        model_name,
+        revision="fp16",
+        torch_dtype=th.float16,
+    )
+    pipe.safety_checker = None  # disable the safety checker
+    pipe.to(gpu_device)
+
+    # check the scheduler is LMSDiscreteScheduler
+    assert isinstance(pipe.scheduler, LMSDiscreteScheduler), type(pipe.scheduler)
+
+    # generate latents with seed-fixed generator
+    generator = th.manual_seed(0)
+    latents = th.randn((1, 4, 64, 64), generator=generator)
+    latents = latents.repeat(2, 1, 1, 1)  # shape: (1, 4, 64, 64) -> (2, 4, 64, 64)
+
+    batch_examples = [EXAMPLE_SETTING_1, EXAMPLE_SETTING_3]
+
+    with th.autocast("cuda"):
+        pipe_output = pipe(
+            prompts=[example["input_prompt"] for example in batch_examples],
+            color_contexts=[example["color_context"] for example in batch_examples],
+            color_map_images=[
+                example["color_map_image_path"] for example in batch_examples
+            ],
+            latents=latents,
+            num_inference_steps=30,
+        )
+    images = pipe_output.images
+
+    for image, example in zip(images, batch_examples):
+        content_dir, image_filename = example["output_image_path"].split("/")  # type: ignore
+        image.save(f"{content_dir}/batch_{image_filename}")
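
Note for reviewers: the snippet below is a minimal usage sketch of the batched `__call__` API introduced in this patch, modeled on `test_batch_pipeline`. The checkpoint name, prompts, color contexts, color-map image paths, and output filenames are illustrative assumptions, not values taken from this repository.

```python
# Hypothetical usage sketch of the batched PaintWithWordsPipeline API from this diff.
# Checkpoint, file paths, prompts, and color contexts below are assumptions.
import torch as th

from paint_with_words.pipelines.paint_with_words_pipeline import PaintWithWordsPipeline

pipe = PaintWithWordsPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # assumed checkpoint name
    revision="fp16",
    torch_dtype=th.float16,
)
pipe.safety_checker = None  # optional: disable the safety checker
pipe.to("cuda")

prompts = [
    "a cat sitting on a sofa",
    "a dog running on the beach",
]
# Each color context maps an RGB color in the color map to a "word,weight" string.
color_contexts = [
    {(0, 0, 0): "cat,1.0", (255, 255, 255): "sofa,0.5"},
    {(0, 0, 0): "dog,1.0", (255, 255, 255): "beach,0.5"},
]
# Paths to 512x512 color map images (assumed to exist).
color_map_images = ["contents/cat_color_map.png", "contents/dog_color_map.png"]

# Seed-fixed latents, repeated along the batch dimension as in test_batch_pipeline.
generator = th.manual_seed(0)
latents = th.randn((1, 4, 64, 64), generator=generator).repeat(len(prompts), 1, 1, 1)

with th.autocast("cuda"):
    images = pipe(
        prompts=prompts,
        color_contexts=color_contexts,
        color_map_images=color_map_images,
        latents=latents,
        num_inference_steps=30,
    ).images

for i, image in enumerate(images):
    image.save(f"batch_output_{i}.png")
```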