
Commit 545642f

Enable Runtime Selection of Attention Functions (Comfy-Org#9639)
* Looking into a @wrap_attn decorator to look for 'optimized_attention_override' entry in transformer_options
* Created logging code for this branch so that it can be used to track down all the code paths where transformer_options would need to be added
* Fix memory usage issue with inspect
* Made WAN attention receive transformer_options, test node added to wan to test out attention override later
* Added **kwargs to all attention functions so transformer_options could potentially be passed through
* Make sure wrap_attn doesn't make itself recurse infinitely, attempt to load SageAttention and FlashAttention if not enabled so that they can be marked as available or not, create registry for available attention
* Turn off attention logging for now, make AttentionOverrideTestNode have a dropdown with available attention (this is a test node only)
* Make flux work with optimized_attention_override
* Add logs to verify optimized_attention_override is passed all the way into attention function
* Make Qwen work with optimized_attention_override
* Made hidream work with optimized_attention_override
* Made wan patches_replace work with optimized_attention_override
* Made SD3 work with optimized_attention_override
* Made HunyuanVideo work with optimized_attention_override
* Made Mochi work with optimized_attention_override
* Made LTX work with optimized_attention_override
* Made StableAudio work with optimized_attention_override
* Made optimized_attention_override work with ACE Step
* Made Hunyuan3D work with optimized_attention_override
* Make CosmosPredict2 work with optimized_attention_override
* Made CosmosVideo work with optimized_attention_override
* Made Omnigen 2 work with optimized_attention_override
* Made StableCascade work with optimized_attention_override
* Made AuraFlow work with optimized_attention_override
* Made Lumina work with optimized_attention_override
* Made Chroma work with optimized_attention_override
* Made SVD work with optimized_attention_override
* Fix WanI2VCrossAttention so that it expects to receive transformer_options
* Fixed Wan2.1 Fun Camera transformer_options passthrough
* Fixed WAN 2.1 VACE transformer_options passthrough
* Add optimized to get_attention_function
* Disable attention logs for now
* Remove attention logging code
* Remove _register_core_attention_functions, as we wouldn't want someone to call that, just in case
* Satisfy ruff
* Remove AttentionOverrideTest node, that's something to cook up for later
1 parent 81a2585 commit 545642f
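
Note: the dispatch the commit message describes boils down to checking transformer_options for an 'optimized_attention_override' callable before calling the regular attention backend. The sketch below is for orientation only; the actual wrap_attn decorator in the commit also builds the registry of available backends (SageAttention/FlashAttention) and handles recursion protection, and the signatures here assume the post-commit convention that core attention functions accept **kwargs.

import functools

def wrap_attn_sketch(attn_fn):
    # Illustrative stand-in for the wrap_attn idea; not the commit's implementation.
    @functools.wraps(attn_fn)
    def wrapper(q, k, v, heads, *args, transformer_options={}, **kwargs):
        override = transformer_options.get("optimized_attention_override", None)
        if override is not None and override is not wrapper:
            # Hand the full context to the override so it can inspect or delegate.
            return override(q, k, v, heads, *args, transformer_options=transformer_options, **kwargs)
        # Forwarding transformer_options assumes the backend takes **kwargs (added in this commit).
        return attn_fn(q, k, v, heads, *args, transformer_options=transformer_options, **kwargs)
    return wrapper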

File tree

26 files changed: +316 -179 lines changed

comfy/ldm/ace/attention.py

Lines changed: 8 additions & 1 deletion
@@ -133,13 +133,15 @@ def forward(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        transformer_options={},
         **cross_attention_kwargs,
     ) -> torch.Tensor:
         return self.processor(
             self,
             hidden_states,
             encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
+            transformer_options=transformer_options,
             **cross_attention_kwargs,
         )

@@ -366,6 +368,7 @@ def __call__(
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
         rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+        transformer_options={},
         *args,
         **kwargs,
     ) -> torch.Tensor:
@@ -433,7 +436,7 @@ def __call__(

         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         hidden_states = optimized_attention(
-            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+            query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
         ).to(query.dtype)

         # linear proj
@@ -697,6 +700,7 @@ def forward(
         rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
         rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
         temb: torch.FloatTensor = None,
+        transformer_options={},
     ):

         N = hidden_states.shape[0]
@@ -720,6 +724,7 @@ def forward(
                 encoder_attention_mask=encoder_attention_mask,
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+                transformer_options=transformer_options,
             )
         else:
             attn_output, _ = self.attn(
@@ -729,6 +734,7 @@ def forward(
                 encoder_attention_mask=None,
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=None,
+                transformer_options=transformer_options,
             )

         if self.use_adaln_single:
@@ -743,6 +749,7 @@ def forward(
                 encoder_attention_mask=encoder_attention_mask,
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+                transformer_options=transformer_options,
             )
         hidden_states = attn_output + hidden_states

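Note: the call site above shows the contract an override callable has to meet: it is invoked like optimized_attention, i.e. positional (q, k, v, heads) plus keyword arguments such as mask, skip_reshape and transformer_options, which vary per call site. A hedged sketch of a compatible override follows; the shape handling mirrors ComfyUI's built-in backends ((batch, heads, seq_len, head_dim) in when skip_reshape=True, (batch, seq_len, heads * head_dim) out) and should be read as an assumption, not a spec from this commit.

import torch

def example_override(q, k, v, heads, mask=None, skip_reshape=False, **kwargs):
    # Sketch of an override that just runs PyTorch SDPA with the expected layouts.
    if skip_reshape:
        # Inputs already (batch, heads, seq_len, head_dim), as in the ACE call site above.
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        b, h, l, d = out.shape
        return out.transpose(1, 2).reshape(b, l, h * d)
    # Inputs are (batch, seq_len, heads * head_dim): split heads, attend, merge back.
    b, l_q, _ = q.shape
    q, k, v = (t.reshape(b, t.shape[1], heads, -1).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, l_q, -1)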
comfy/ldm/ace/model.py

Lines changed: 4 additions & 0 deletions
@@ -314,6 +314,7 @@ def decode(
         output_length: int = 0,
         block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
         controlnet_scale: Union[float, torch.Tensor] = 1.0,
+        transformer_options={},
     ):
         embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
         temb = self.t_block(embedded_timestep)
@@ -339,6 +340,7 @@ def decode(
                 rotary_freqs_cis=rotary_freqs_cis,
                 rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
                 temb=temb,
+                transformer_options=transformer_options,
             )

         output = self.final_layer(hidden_states, embedded_timestep, output_length)
@@ -393,6 +395,7 @@ def _forward(

         output_length = hidden_states.shape[-1]

+        transformer_options = kwargs.get("transformer_options", {})
         output = self.decode(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -402,6 +405,7 @@ def _forward(
             output_length=output_length,
             block_controlnet_hidden_states=block_controlnet_hidden_states,
             controlnet_scale=controlnet_scale,
+            transformer_options=transformer_options,
         )

         return output

comfy/ldm/audio/dit.py

Lines changed: 14 additions & 11 deletions
@@ -298,7 +298,8 @@ def forward(
         mask = None,
         context_mask = None,
         rotary_pos_emb = None,
-        causal = None
+        causal = None,
+        transformer_options={},
     ):
         h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

@@ -363,7 +364,7 @@ def forward(
             heads_per_kv_head = h // kv_h
             k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

-        out = optimized_attention(q, k, v, h, skip_reshape=True)
+        out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
        out = self.to_out(out)

        if mask is not None:
@@ -488,7 +489,8 @@ def forward(
         global_cond=None,
         mask = None,
         context_mask = None,
-        rotary_pos_emb = None
+        rotary_pos_emb = None,
+        transformer_options={}
     ):
         if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

@@ -498,12 +500,12 @@ def forward(
             residual = x
             x = self.pre_norm(x)
             x = x * (1 + scale_self) + shift_self
-            x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
+            x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
             x = x * torch.sigmoid(1 - gate_self)
             x = x + residual

             if context is not None:
-                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)

             if self.conformer is not None:
                 x = x + self.conformer(x)
@@ -517,10 +519,10 @@ def forward(
             x = x + residual

         else:
-            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
+            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)

             if context is not None:
-                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
+                x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)

             if self.conformer is not None:
                 x = x + self.conformer(x)
@@ -606,7 +608,8 @@ def forward(
         return_info = False,
         **kwargs
     ):
-        patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
+        transformer_options = kwargs.get("transformer_options", {})
+        patches_replace = transformer_options.get("patches_replace", {})
         batch, seq, device = *x.shape[:2], x.device
         context = kwargs["context"]

@@ -645,13 +648,13 @@ def forward(
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
+                    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
                     return out

-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
+                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
                 # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

        if return_info:
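
Note: the double_block hunk above also threads transformer_options into the patches_replace path: block_wrap now reads args["transformer_options"], and the args dict handed to a replacement patch carries it. A hypothetical patch compatible with the override mechanism could look like the sketch below; the registration key layout follows the usual ComfyUI patches_replace convention and is an assumption, since it is not shown in this diff.

def scaled_double_block_patch(args, extra):
    # transformer_options now travels inside args, so any attention override stored
    # there is still honored when this patch calls back into the original block.
    out = extra["original_block"](args)
    out["img"] = out["img"] * 1.0  # placeholder; a real patch would transform the result
    return out

# Registration (key names assumed to follow the usual ComfyUI layout):
# transformer_options["patches_replace"]["dit"][("double_block", 0)] = scaled_double_block_patch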

comfy/ldm/aura/mmdit.py

Lines changed: 15 additions & 14 deletions
@@ -85,7 +85,7 @@ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, opera
        )

    #@torch.compile()
-    def forward(self, c):
+    def forward(self, c, transformer_options={}):

        bsz, seqlen1, _ = c.shape

@@ -95,7 +95,7 @@ def forward(self, c):
        v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
        q, k = self.q_norm1(q), self.k_norm1(k)

-        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
        c = self.w1o(output)
        return c

@@ -144,7 +144,7 @@ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, opera


    #@torch.compile()
-    def forward(self, c, x):
+    def forward(self, c, x, transformer_options={}):

        bsz, seqlen1, _ = c.shape
        bsz, seqlen2, _ = x.shape
@@ -168,7 +168,7 @@ def forward(self, c, x):
            torch.cat([cv, xv], dim=1),
        )

-        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
+        output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)

        c, x = output.split([seqlen1, seqlen2], dim=1)
        c = self.w1o(c)
@@ -207,7 +207,7 @@ def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None,
        self.is_last = is_last

    #@torch.compile()
-    def forward(self, c, x, global_cond, **kwargs):
+    def forward(self, c, x, global_cond, transformer_options={}, **kwargs):

        cres, xres = c, x

@@ -225,7 +225,7 @@ def forward(self, c, x, global_cond, **kwargs):
        x = modulate(self.normX1(x), xshift_msa, xscale_msa)

        # attention
-        c, x = self.attn(c, x)
+        c, x = self.attn(c, x, transformer_options=transformer_options)


        c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -255,13 +255,13 @@ def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, o
        self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

    #@torch.compile()
-    def forward(self, cx, global_cond, **kwargs):
+    def forward(self, cx, global_cond, transformer_options={}, **kwargs):
        cxres = cx
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
            global_cond
        ).chunk(6, dim=1)
        cx = modulate(self.norm1(cx), shift_msa, scale_msa)
-        cx = self.attn(cx)
+        cx = self.attn(cx, transformer_options=transformer_options)
        cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
        mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
        cx = gate_mlp.unsqueeze(1) * mlpout
@@ -473,13 +473,14 @@ def block_wrap(args):
                    out = {}
                    out["txt"], out["img"] = layer(args["txt"],
                                                   args["img"],
-                                                  args["vec"])
+                                                  args["vec"],
+                                                  transformer_options=args["transformer_options"])
                    return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
                c = out["txt"]
                x = out["img"]
            else:
-                c, x = layer(c, x, global_cond, **kwargs)
+                c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)

        if len(self.single_layers) > 0:
            c_len = c.size(1)
@@ -488,13 +489,13 @@ def block_wrap(args):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
-                    out["img"] = layer(args["img"], args["vec"])
+                    out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
                    return out

-                out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
                cx = out["img"]
            else:
-                cx = layer(cx, global_cond, **kwargs)
+                cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)

        x = cx[:, c_len:]

comfy/ldm/cascade/common.py

Lines changed: 6 additions & 6 deletions
@@ -32,12 +32,12 @@ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=No

        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

-    def forward(self, q, k, v):
+    def forward(self, q, k, v, transformer_options={}):
        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)

-        out = optimized_attention(q, k, v, self.heads)
+        out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)

        return self.out_proj(out)

@@ -47,13 +47,13 @@ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=No
        self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
        # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)

-    def forward(self, x, kv, self_attn=False):
+    def forward(self, x, kv, self_attn=False, transformer_options={}):
        orig_shape = x.shape
        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # Bx4xHxW -> Bx(HxW)x4
        if self_attn:
            kv = torch.cat([x, kv], dim=1)
        # x = self.attn(x, kv, kv, need_weights=False)[0]
-        x = self.attn(x, kv, kv)
+        x = self.attn(x, kv, kv, transformer_options=transformer_options)
        x = x.permute(0, 2, 1).view(*orig_shape)
        return x

@@ -114,9 +114,9 @@ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, de
            operations.Linear(c_cond, c, dtype=dtype, device=device)
        )

-    def forward(self, x, kv):
+    def forward(self, x, kv, transformer_options={}):
        kv = self.kv_mapper(kv)
-        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
+        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
        return x

comfy/ldm/cascade/stage_b.py

Lines changed: 7 additions & 7 deletions
@@ -173,7 +173,7 @@ def gen_c_embeddings(self, clip):
            clip = self.clip_norm(clip)
        return clip

-    def _down_encode(self, x, r_embed, clip):
+    def _down_encode(self, x, r_embed, clip, transformer_options={}):
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
@@ -187,7 +187,7 @@ def _down_encode(self, x, r_embed, clip):
                elif isinstance(block, AttnBlock) or (
                        hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                              AttnBlock)):
-                    x = block(x, clip)
+                    x = block(x, clip, transformer_options=transformer_options)
                elif isinstance(block, TimestepBlock) or (
                        hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                              TimestepBlock)):
@@ -199,7 +199,7 @@ def _down_encode(self, x, r_embed, clip):
            level_outputs.insert(0, x)
        return level_outputs

-    def _up_decode(self, level_outputs, r_embed, clip):
+    def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -216,7 +216,7 @@ def _up_decode(self, level_outputs, r_embed, clip):
                elif isinstance(block, AttnBlock) or (
                        hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                              AttnBlock)):
-                    x = block(x, clip)
+                    x = block(x, clip, transformer_options=transformer_options)
                elif isinstance(block, TimestepBlock) or (
                        hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
                                                                              TimestepBlock)):
@@ -228,7 +228,7 @@ def _up_decode(self, level_outputs, r_embed, clip):
            x = upscaler(x)
        return x

-    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
+    def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
        if pixels is None:
            pixels = x.new_zeros(x.size(0), 3, 8, 8)

@@ -245,8 +245,8 @@ def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
                              nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
        x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
                                          align_corners=True)
-        level_outputs = self._down_encode(x, r_embed, clip)
-        x = self._up_decode(level_outputs, r_embed, clip)
+        level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
+        x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
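
Note: none of the diffs above set the override themselves; something upstream has to place a callable under transformer_options["optimized_attention_override"] (the commit's AttentionOverrideTest node did that before it was removed as "something to cook up for later"). A hypothetical custom node along those lines is sketched below; the ModelPatcher clone / model_options["transformer_options"] pattern and the attention_pytorch import path are assumptions to check against the registry API (get_attention_function) this commit adds.

from comfy.ldm.modules.attention import attention_pytorch  # assumed import path

class AttentionOverrideExample:
    # Hypothetical node, not part of this commit.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"model": ("MODEL",)}}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "_for_testing"

    def patch(self, model):
        def override(q, k, v, heads, mask=None, skip_reshape=False, **kwargs):
            # Force the plain PyTorch SDPA backend regardless of the global default.
            return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape)

        m = model.clone()
        transformer_options = m.model_options.setdefault("transformer_options", {})
        transformer_options["optimized_attention_override"] = override
        return (m,)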
