12 changes: 11 additions & 1 deletion paint_with_words/helper/aliases.py
@@ -1,9 +1,19 @@
-from typing import Dict, Tuple, Union
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
 
 import torch
+import torch as th
 
 from paint_with_words.weight_functions import WeightFunction
 
 RGB = Tuple[int, int, int]
 
 PaintWithWordsHiddenStates = Dict[str, Union[torch.Tensor, WeightFunction]]
+ColorContext = Dict[RGB, str]
+
+
+@dataclass
+class SeparatedImageContext(object):
+    word: str
+    token_ids: List[int]
+    color_map_th: th.Tensor
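For orientation, a minimal sketch of what one SeparatedImageContext instance holds; the token id and mask shape are illustrative, not taken from this PR:

import torch as th

from paint_with_words.helper.aliases import SeparatedImageContext

# hypothetical values for a 512x512 color map and the word "cat"
ctx = SeparatedImageContext(
    word="cat",
    token_ids=[2368],                 # illustrative tokenizer output, not a verified CLIP id
    color_map_th=th.zeros(512, 512),  # per-pixel weight mask for the color region
)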
11 changes: 11 additions & 0 deletions paint_with_words/helper/attention.py
@@ -0,0 +1,11 @@
+from diffusers.models import UNet2DConditionModel
+
+from paint_with_words.models.attention import paint_with_words_forward
+
+
+def replace_cross_attention(
+    unet: UNet2DConditionModel, cross_attention_name: str = "CrossAttention"
+) -> None:
+    for m in unet.modules():
+        if m.__class__.__name__ == cross_attention_name:
+            m.__class__.__call__ = paint_with_words_forward
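A minimal usage sketch (the pretrained model id is illustrative). Note that the helper assigns to m.__class__.__call__, so the patch applies to every instance of the matched class process-wide, not only the modules inside this particular unet:

from diffusers.models import UNet2DConditionModel

from paint_with_words.helper.attention import replace_cross_attention

# "runwayml/stable-diffusion-v1-5" is an illustrative model id
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
replace_cross_attention(unet=unet)
# every CrossAttention module now dispatches to paint_with_words_forward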
138 changes: 137 additions & 1 deletion paint_with_words/helper/images.py
@@ -1,9 +1,27 @@
-from typing import Tuple
+import logging
+import os
+from typing import Dict, List, Tuple, Union
 
 import numpy as np
 import torch as th
 import torch.nn.functional as F
+from PIL import Image
 from PIL.Image import Image as PilImage
 from PIL.Image import Resampling
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+from paint_with_words.helper.aliases import RGB, SeparatedImageContext
+
+logger = logging.getLogger(__name__)
+
+
+def load_image(image: Union[str, os.PathLike, PilImage]) -> PilImage:
+    if isinstance(image, str) or isinstance(image, os.PathLike):
+        image = Image.open(image)
+
+    if image.mode != "RGB":
+        image = image.convert("RGB")
+    return image
 
+
 def get_resize_size(img: PilImage) -> Tuple[int, int]:
@@ -33,3 +51,121 @@ def flatten_image_importance(img_th: th.Tensor, ratio: int) -> th.Tensor:
     ret = ret.squeeze()
 
     return ret
+
+
+def separate_image_context(
+    tokenizer: PreTrainedTokenizer,
+    img: PilImage,
+    color_context: Dict[RGB, str],
+    device: str,
+) -> List[SeparatedImageContext]:
+    assert img.width % 32 == 0 and img.height % 32 == 0, img.size
+
+    separated_image_and_context: List[SeparatedImageContext] = []
+
+    for rgb_color, word_with_weight in color_context.items():
+        # e.g.,
+        # rgb_color: (0, 0, 0)
+        # word_with_weight: "cat,1.0"
+
+        # "cat,1.0" -> ["cat", "1.0"]
+        word_and_weight = word_with_weight.split(",")
+        # ["cat", "1.0"] -> 1.0
+        word_weight = float(word_and_weight[-1])
+        # ["cat", "1.0"] -> "cat"
+        word = ",".join(word_and_weight[:-1])
+
+        logger.info(
+            f"input = {word_with_weight}; word = {word}; weight = {word_weight}"
+        )
+
+        word_input = tokenizer(
+            word,
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        word_as_tokens = word_input["input_ids"]
+
+        img_where_color_np = (np.array(img) == rgb_color).all(axis=-1)
+        if not img_where_color_np.sum() > 0:
+            logger.warning(
+                f"color {rgb_color} not found in image"
+            )
+
+        img_where_color_th = th.tensor(
+            img_where_color_np,
+            dtype=th.float32,
+            device=device,
+        )
+        img_where_color_th = img_where_color_th * word_weight
+
+        image_context = SeparatedImageContext(
+            word=word,
+            token_ids=word_as_tokens,
+            color_map_th=img_where_color_th,
+        )
+        separated_image_and_context.append(image_context)
+
+    if len(separated_image_and_context) == 0:
+        image_context = SeparatedImageContext(
+            word="",
+            token_ids=[-1],
+            color_map_th=th.zeros((img.height, img.width), dtype=th.float32, device=device),
+        )
+        separated_image_and_context.append(image_context)
+
+    return separated_image_and_context
+
+
+def calculate_tokens_image_attention_weight(
+    tokenizer: PreTrainedTokenizer,
+    input_prompt: str,
+    separated_image_context_list: List[SeparatedImageContext],
+    ratio: int,
+    device: str,
+) -> th.Tensor:
+    prompt_token_ids = tokenizer(
+        input_prompt,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+    ).input_ids
+
+    h, w = separated_image_context_list[0].color_map_th.shape
+    h_r, w_r = h // ratio, w // ratio
+
+    ret_tensor = th.zeros(
+        (h_r * w_r, len(prompt_token_ids)), dtype=th.float32, device=device
+    )
+
+    for separated_image_context in separated_image_context_list:
+        is_in = False
+        context_token_ids = separated_image_context.token_ids
+        context_image_map = separated_image_context.color_map_th
+
+        for i, token_id in enumerate(prompt_token_ids):
+            if prompt_token_ids[i : i + len(context_token_ids)] == context_token_ids:
+                is_in = True
+
+                # shape: (h // ratio, w // ratio)
+                img_importance = flatten_image_importance(
+                    img_th=context_image_map, ratio=ratio
+                )
+                # shape: ((h // ratio) * (w // ratio), 1)
+                img_importance = img_importance.view(-1, 1)
+                # shape: ((h // ratio) * (w // ratio), len(context_token_ids))
+                img_importance = img_importance.repeat(1, len(context_token_ids))
+
+                ret_tensor[:, i : i + len(context_token_ids)] += img_importance
+
+        if not is_in:
+            logger.warning(
+                f"tokens {context_token_ids} not found in the prompt (ratio = {ratio})"
+            )
+
+    # add a batch dimension
+    # shape: (h_r * w_r, len(prompt_token_ids)) -> (1, h_r * w_r, len(prompt_token_ids))
+    ret_tensor = ret_tensor.unsqueeze(dim=0)
+
+    return ret_tensor
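Taken together, a sketch of how these two helpers might be wired up. The file path, colors, and prompt are illustrative; ratio=8 assumes a 512x512 color map, which yields the 64x64 = 4096-position attention resolution, and 77 is the CLIP tokenizer's model_max_length:

from transformers import CLIPTokenizer

from paint_with_words.helper.images import (
    load_image,
    separate_image_context,
    calculate_tokens_image_attention_weight,
)

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
color_map = load_image("examples/color_map.png")  # hypothetical path
color_context = {(7, 9, 182): "sky,0.5", (136, 178, 92): "grass,1.0"}

contexts = separate_image_context(tokenizer, color_map, color_context, device="cpu")
weight_4096 = calculate_tokens_image_attention_weight(
    tokenizer, "A photo of sky and grass", contexts, ratio=8, device="cpu"
)
# for a 512x512 color map: weight_4096.shape == (1, 4096, 77)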
31 changes: 29 additions & 2 deletions paint_with_words/models/attention.py
@@ -24,8 +24,6 @@ def paint_with_words_forward(
     else:
         context_tensor = hidden_states
 
-    # batch_size, sequence_length, _ = hidden_states.shape
-
     query = self.to_q(hidden_states)
 
     key = self.to_k(context_tensor)
@@ -37,6 +35,7 @@
     key = self.reshape_heads_to_batch_dim(key)
     value = self.reshape_heads_to_batch_dim(value)
 
+    # shape: (batch_size * self.heads, 64 * 64, 77)
     attention_scores = th.matmul(query, key.transpose(-1, -2))
 
     attention_size_of_img = attention_scores.shape[-2]
@@ -53,6 +52,34 @@
     else:
         cross_attention_weight = 0.0
 
+    if not isinstance(cross_attention_weight, float):
+        # shape: (batch_size, 64 * 64, 77) -> (batch_size * self.heads, 64 * 64, 77)
+        #
+        # example:
+        # >>> x = torch.arange(20).reshape(2, 10)
+        # tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
+        #         [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])
+        #
+        # >>> x.repeat_interleave(2, dim=0)
+        # tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
+        #         [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
+        #         [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
+        #         [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])
+        #
+        cross_attention_weight = cross_attention_weight.repeat_interleave(
+            self.heads, dim=0
+        )
+
+    # Example (batch_size = 2, self.heads = 8):
+    # shape (attention_scores): (16, 4096, 77)
+    #   scores1: (8, 4096, 77), scores2: (8, 4096, 77)
+    # shape (cross_attention_weight, before repeat_interleave): (2, 4096, 77)
+    #   weights1: (1, 4096, 77), weights2: (1, 4096, 77)
+    #
+    # We want to calculate:
+    #   scores1 + weights1
+    #   scores2 + weights2
 
     attention_scores = (attention_scores + cross_attention_weight) * self.scale
 
     attention_probs = attention_scores.softmax(dim=-1)
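To see why repeat_interleave (rather than repeat) is the right expansion here, a toy check of the alignment, assuming the sample-major (batch * heads, ...) layout that reshape_heads_to_batch_dim produces:

import torch as th

batch, heads, n_img, n_txt = 2, 2, 4, 3

# per-sample weights: sample 0 -> all zeros, sample 1 -> all ones
weight = th.arange(batch, dtype=th.float32).view(batch, 1, 1).expand(batch, n_img, n_txt)
weight = weight.repeat_interleave(heads, dim=0)  # shape: (batch * heads, n_img, n_txt)

# rows [0, heads) now belong to sample 0 and rows [heads, 2 * heads) to sample 1,
# matching how the attention scores are grouped after reshape_heads_to_batch_dim
assert (weight[:heads] == 0).all() and (weight[heads:] == 1).all()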