
Commit 4c01c9d

refactor pt. 2
1 parent 90d29c7 commit 4c01c9d

File tree

  src/diffusers/models/transformers/transformer_cogview4.py
  src/diffusers/pipelines/cogview4/pipeline_cogview4.py

2 files changed: +92 / -117 lines

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 85 additions & 86 deletions
@@ -26,7 +26,6 @@
 from ...utils import logging
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
-from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -81,6 +80,53 @@ def forward(
         return hidden_states, prompt_embeds, negative_prompt_embeds
 
 
+class CogView4AdaLayerNormZero(nn.Module):
+    def __init__(self, embedding_dim: int, dim: int) -> None:
+        super().__init__()
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+
+    def forward(
+        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        norm_hidden_states = self.norm(hidden_states)
+        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+
+        emb = self.linear(temb)
+        (
+            shift_msa,
+            c_shift_msa,
+            scale_msa,
+            c_scale_msa,
+            gate_msa,
+            c_gate_msa,
+            shift_mlp,
+            c_shift_mlp,
+            scale_mlp,
+            c_scale_mlp,
+            gate_mlp,
+            c_gate_mlp,
+        ) = emb.chunk(12, dim=1)
+
+        hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+        encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
+
+        return (
+            hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        )
+
+
 class CogView4AttnProcessor:
     """
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -89,7 +135,7 @@ class CogView4AttnProcessor:
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
@@ -153,10 +199,8 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
-        self.adaln = self.norm1.linear
-        self.layernorm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
-
+        # 1. Attention
+        self.norm1 = CogView4AdaLayerNormZero(time_embed_dim, dim)
         self.attn1 = Attention(
             query_dim=dim,
             heads=num_attention_heads,
@@ -169,97 +213,52 @@ def __init__(
             processor=CogView4AttnProcessor(),
         )
 
+        # 2. Feedforward
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
-    def multi_modulate(self, hidden_states, encoder_hidden_states, factors) -> torch.Tensor:
-        _, _, h = factors[0].shape
-        shift_factor, scale_factor = factors[0].view(-1, h), factors[1].view(-1, h)
-
-        shift_factor_hidden_states, shift_factor_encoder_hidden_states = shift_factor.chunk(2, dim=0)
-        scale_factor_hidden_states, scale_factor_encoder_hidden_states = scale_factor.chunk(2, dim=0)
-        shift_factor_hidden_states = shift_factor_hidden_states.unsqueeze(1)
-        scale_factor_hidden_states = scale_factor_hidden_states.unsqueeze(1)
-        hidden_states = torch.addcmul(shift_factor_hidden_states, hidden_states, (1 + scale_factor_hidden_states))
-
-        shift_factor_encoder_hidden_states = shift_factor_encoder_hidden_states.unsqueeze(1)
-        scale_factor_encoder_hidden_states = scale_factor_encoder_hidden_states.unsqueeze(1)
-        encoder_hidden_states = torch.addcmul(
-            shift_factor_encoder_hidden_states, encoder_hidden_states, (1 + scale_factor_encoder_hidden_states)
-        )
-
-        return hidden_states, encoder_hidden_states
-
-    def multi_gate(self, hidden_states, encoder_hidden_states, factor):
-        _, _, hidden_dim = hidden_states.shape
-        gate_factor = factor.view(-1, hidden_dim)
-        gate_factor_hidden_states, gate_factor_encoder_hidden_states = gate_factor.chunk(2, dim=0)
-        gate_factor_hidden_states = gate_factor_hidden_states.unsqueeze(1)
-        gate_factor_encoder_hidden_states = gate_factor_encoder_hidden_states.unsqueeze(1)
-        hidden_states = gate_factor_hidden_states * hidden_states
-        encoder_hidden_states = gate_factor_encoder_hidden_states * encoder_hidden_states
-
-        return hidden_states, encoder_hidden_states
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        batch_size, encoder_hidden_states_len, hidden_dim = encoder_hidden_states.shape
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-        residual = hidden_states
-        layernorm_factor = (
-            self.adaln(temb)
-            .view(
-                temb.shape[0],
-                6,
-                2,
-                hidden_states.shape[-1],
-            )
-            .permute(1, 2, 0, 3)
-            .contiguous()
-        )
-        hidden_states = self.layernorm(hidden_states)
-        hidden_states, encoder_hidden_states = self.multi_modulate(
-            hidden_states=hidden_states[:, encoder_hidden_states_len:],
-            encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
-            factors=(layernorm_factor[0], layernorm_factor[1]),
-        )
-        hidden_states, encoder_hidden_states = self.attn1(
-            hidden_states=hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
+        # 1. Timestep conditioning
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            norm_encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = self.norm1(hidden_states, encoder_hidden_states, temb)
+
+        # 2. Attention
+        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
-        hidden_states, encoder_hidden_states = self.multi_gate(
-            hidden_states=hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            factor=layernorm_factor[2],
-        )
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-        hidden_states += residual
-
-        residual = hidden_states
-        hidden_states = self.layernorm(hidden_states)
-        hidden_states, encoder_hidden_states = self.multi_modulate(
-            hidden_states=hidden_states[:, encoder_hidden_states_len:],
-            encoder_hidden_states=hidden_states[:, :encoder_hidden_states_len],
-            factors=(layernorm_factor[3], layernorm_factor[4]),
-        )
-        hidden_states = self.ff(hidden_states)
-        encoder_hidden_states = self.ff(encoder_hidden_states)
-        hidden_states, encoder_hidden_states = self.multi_gate(
-            hidden_states=hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            factor=layernorm_factor[5],
-        )
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-        hidden_states += residual
-        hidden_states, encoder_hidden_states = (
-            hidden_states[:, encoder_hidden_states_len:],
-            hidden_states[:, :encoder_hidden_states_len],
-        )
+        hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
+
+        # 3. Feedforward
+        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
+            1 + c_scale_mlp.unsqueeze(1)
+        ) + c_shift_mlp.unsqueeze(1)
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output_context = self.ff(norm_encoder_hidden_states)
+        hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
+
         return hidden_states, encoder_hidden_states
 
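Note: the following is a minimal, self-contained sketch of the AdaLN-Zero conditioning pattern introduced above, not code from this commit. A single linear projection of the timestep embedding is chunked into per-stream shift/scale/gate factors; the normalized image and text streams are modulated with their shifts and scales, and the gates later scale the attention (or feedforward) output before the residual add. The name AdaLayerNormZeroSketch, the 6-way chunk, and all tensor sizes are illustrative assumptions (the commit's module produces 12 factors so that both the attention and feedforward steps are covered).

import torch
import torch.nn as nn


class AdaLayerNormZeroSketch(nn.Module):
    # Simplified stand-in for the new CogView4AdaLayerNormZero: one linear layer
    # maps the timestep embedding to shift/scale/gate factors for the image
    # stream and the text (context) stream.
    def __init__(self, embedding_dim: int, dim: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.linear = nn.Linear(embedding_dim, 6 * dim, bias=True)  # 6 factors here; the commit uses 12

    def forward(self, x: torch.Tensor, context: torch.Tensor, temb: torch.Tensor):
        shift, scale, gate, c_shift, c_scale, c_gate = self.linear(temb).chunk(6, dim=1)
        # Modulate each normalized stream with its own shift/scale; return the
        # gates so the caller can apply them on the residual branch.
        x = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        context = self.norm_context(context) * (1 + c_scale.unsqueeze(1)) + c_shift.unsqueeze(1)
        return x, gate, context, c_gate


if __name__ == "__main__":
    # Dummy sizes: batch 2, 16 image tokens, 8 text tokens, hidden dim 64, time embedding 128.
    norm1 = AdaLayerNormZeroSketch(embedding_dim=128, dim=64)
    hidden, context, temb = torch.randn(2, 16, 64), torch.randn(2, 8, 64), torch.randn(2, 128)
    norm_hidden, gate, norm_context, c_gate = norm1(hidden, context, temb)

    attn_out = torch.randn_like(norm_hidden)        # stand-in for the attention output
    hidden = hidden + attn_out * gate.unsqueeze(1)  # gated residual, as in the refactored block
    print(norm_hidden.shape, norm_context.shape, hidden.shape)  # (2, 16, 64) (2, 8, 64) (2, 16, 64)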
src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 7 additions & 31 deletions
@@ -78,12 +78,7 @@ class CogView4Pipeline(DiffusionPipeline):
 
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
-
-    _callback_tensor_inputs = [
-        "latents",
-        "prompt_embeds",
-        "negative_prompt_embeds",
-    ]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
@@ -159,9 +154,9 @@ def encode_prompt(
         num_images_per_prompt: int = 1,
         prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 1024,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 1024,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -184,12 +179,12 @@ def encode_prompt(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            max_sequence_length (`int`, defaults to `1024`):
-                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
             device: (`torch.device`, *optional*):
                 torch device
             dtype: (`torch.dtype`, *optional*):
                 torch dtype
+            max_sequence_length (`int`, defaults to `1024`):
+                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
         """
         device = device or self._execution_device
 
@@ -200,24 +195,10 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
-            prompt_embeds = self._get_glm_embeds(
-                prompt=prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
-                dtype=dtype,
-            )
-
-        if do_classifier_free_guidance and negative_prompt is None:
-            negative_prompt_embeds = self._get_glm_embeds(
-                prompt="",
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
-                dtype=dtype,
-            )
+            prompt_embeds = self._get_glm_embeds(prompt, num_images_per_prompt, max_sequence_length, device, dtype)
 
         if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
 
             if prompt is not None and type(prompt) is not type(negative_prompt):
@@ -233,11 +214,7 @@ def encode_prompt(
                 )
 
             negative_prompt_embeds = self._get_glm_embeds(
-                prompt=negative_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-                device=device,
-                dtype=dtype,
+                negative_prompt, num_images_per_prompt, max_sequence_length, device, dtype
             )
 
         return prompt_embeds, negative_prompt_embeds
@@ -347,7 +324,6 @@ def __call__(
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 5.0,
         num_images_per_prompt: int = 1,
-        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
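Note: a minimal sketch of the negative-prompt handling after this refactor, assuming a stand-in embedding function; encode_prompt_sketch and get_embeds are hypothetical names invented for illustration, not the pipeline's API. With classifier-free guidance enabled and no negative_prompt_embeds supplied, an unset negative prompt falls back to the empty string and is broadcast to the batch size before being embedded.

from typing import List, Optional, Union

import torch


def encode_prompt_sketch(
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
):
    # Stand-in for the pipeline's text encoder call; returns dummy embeddings.
    def get_embeds(text: Union[str, List[str]]) -> torch.Tensor:
        texts = [text] if isinstance(text, str) else text
        return torch.zeros(len(texts), 4, 8)  # (batch, seq_len, hidden) placeholders

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds = get_embeds(prompt)

    if do_classifier_free_guidance and negative_prompt_embeds is None:
        # Mirrors the refactor: an unset negative prompt falls back to the empty
        # string and is broadcast to the batch size before embedding.
        negative_prompt = negative_prompt or ""
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        if len(negative_prompt) != batch_size:
            raise ValueError("`negative_prompt` must match the prompt batch size.")
        negative_prompt_embeds = get_embeds(negative_prompt)

    return prompt_embeds, negative_prompt_embeds


# Example: CFG with no explicit negative prompt -> empty-string negatives.
pe, npe = encode_prompt_sketch(["a photo of a cat", "a watercolor fox"])
print(pe.shape, npe.shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4, 8])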
