
Commit ca0dcf8

Reduce Peak WAN inference VRAM usage - part II (Comfy-Org#10062)
* flux: math: Use addcmul_ to avoid an expensive VRAM intermediate. The rope step can be the VRAM peak, and allocating a new tensor for the addition result before the inputs can be released may OOM; fuse the addition in place with addcmul_ instead.

* wan: Delete the self-attention output before the cross-attention. This saves VRAM when the cross-attention and FFN are the VRAM peak, since the dead tensor is no longer held across them.
Parent: d4176c4
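For reference, a minimal sketch (not part of the commit) of the in-place semantics both changes rely on: torch.Tensor.addcmul_(t1, t2) computes self += t1 * t2 without allocating a separate result tensor. The names a, b, c below are placeholders.

import torch

# Minimal sketch, not from the commit: Tensor.addcmul_(t1, t2) performs
# self += t1 * t2 in place, so no fresh tensor is allocated for the result.
a = torch.randn(4, 8)
b = torch.randn(4, 8)
c = torch.randn(4, 8)

reference = a + b * c   # out of place: temporaries for b * c and for the sum
a.addcmul_(b, c)        # in place: the result lands back in a's storage

assert torch.allclose(a, reference)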

File tree: 2 files changed (+5, -1 lines)

comfy/ldm/flux/math.py

Lines changed: 4 additions & 1 deletion
@@ -37,7 +37,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
 
 def apply_rope1(x: Tensor, freqs_cis: Tensor):
     x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-    x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
+
+    x_out = freqs_cis[..., 0] * x_[..., 0]
+    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+
     return x_out.reshape(*x.shape).type_as(x)
 
 def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
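A rough way to see why the rewrite lowers the peak (a sketch, not part of the commit; the shapes below are invented for illustration and a CUDA device is required for the measurement): the old expression holds two broadcasted products plus their sum alive at once, while the in-place version accumulates the second product into the first buffer.

import torch

# Sketch, not part of the commit: compare peak allocations of the two
# formulations. Shapes are illustrative only, not the ones Flux uses.
def mix_old(freqs_cis, x_):
    return freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]

def mix_new(freqs_cis, x_):
    out = freqs_cis[..., 0] * x_[..., 0]
    out.addcmul_(freqs_cis[..., 1], x_[..., 1])  # fuses "+ b * c" into out
    return out

if torch.cuda.is_available():
    x_ = torch.randn(1, 24, 4096, 64, 1, 2, device="cuda")
    freqs_cis = torch.randn(1, 1, 4096, 64, 2, 2, device="cuda")
    for fn in (mix_old, mix_new):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        out = fn(freqs_cis, x_)
        print(fn.__name__, torch.cuda.max_memory_allocated() // 2**20, "MiB peak")
        del out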

comfy/ldm/wan/model.py

Lines changed: 1 addition & 0 deletions
@@ -237,6 +237,7 @@ def forward(
             freqs, transformer_options=transformer_options)
 
         x = torch.addcmul(x, y, repeat_e(e[2], x))
+        del y
 
         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
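A sketch of the pattern this one-line change applies (the tiny module below is an illustrative stand-in, not WAN's real block): once the gated self-attention output has been folded into x, dropping the last reference to it before the cross-attention/FFN runs lets PyTorch's caching allocator reuse that block instead of carrying it through the next peak. This mainly matters at inference; under autograd the tensor would be kept alive for backward anyway.

import torch
import torch.nn as nn

# Illustrative stand-in, not WAN's real block: the point is the `del y`
# between folding in the self-attention result and running the next layers.
class TinyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.self_attn = nn.Linear(dim, dim)   # stand-in for self-attention
        self.cross_attn = nn.Linear(dim, dim)  # stand-in for cross-attention + FFN
        self.norm1 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.gate = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        y = self.self_attn(self.norm1(x))
        x = torch.addcmul(x, y, self.gate)  # x = x + y * gate
        del y  # last reference; the allocator can reuse y's block below
        return x + self.cross_attn(self.norm3(x))

block = TinyBlock(64)
with torch.inference_mode():   # del only helps when autograd isn't retaining y
    out = block(torch.randn(2, 16, 64))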
