Skip to content

Commit 4d8e290

Browse files
authored
[Bugfix] fix qwen image oom (#1168)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
1 parent 3bc3c95 commit 4d8e290

File tree

1 file changed

+5
-13
lines changed

1 file changed

+5
-13
lines changed

vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4-
import functools
54
from collections.abc import Iterable
5+
from functools import lru_cache
66
from math import prod
77
from typing import Any
88

@@ -249,7 +249,7 @@ def forward(self, video_fhw, txt_seq_lens, device):
249249

250250
return vid_freqs, txt_freqs
251251

252-
@functools.cache
252+
@lru_cache(maxsize=16)
253253
def _compute_video_freqs(self, frame, height, width, idx=0):
254254
seq_lens = frame * height * width
255255
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -268,7 +268,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0):
268268
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
269269
return freqs.clone().contiguous()
270270

271-
@functools.cache
271+
@lru_cache(maxsize=16)
272272
def _compute_condition_freqs(self, frame, height, width):
273273
seq_lens = frame * height * width
274274
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -311,7 +311,6 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
311311
],
312312
dim=1,
313313
)
314-
self.rope_cache = {}
315314

316315
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
317316
self.scale_rope = scale_rope
@@ -349,14 +348,7 @@ def forward(self, video_fhw, txt_seq_lens, device):
349348
max_vid_index = 0
350349
for idx, fhw in enumerate(video_fhw):
351350
frame, height, width = fhw
352-
rope_key = f"{idx}_{height}_{width}"
353-
354-
if not torch.compiler.is_compiling():
355-
if rope_key not in self.rope_cache:
356-
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
357-
video_freq = self.rope_cache[rope_key]
358-
else:
359-
video_freq = self._compute_video_freqs(frame, height, width, idx)
351+
video_freq = self._compute_video_freqs(frame, height, width, idx)
360352
video_freq = video_freq.to(device)
361353
vid_freqs.append(video_freq)
362354

@@ -371,7 +363,7 @@ def forward(self, video_fhw, txt_seq_lens, device):
371363

372364
return vid_freqs, txt_freqs
373365

374-
@functools.cache
366+
@lru_cache(maxsize=16)
375367
def _compute_video_freqs(self, frame, height, width, idx=0):
376368
seq_lens = frame * height * width
377369
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)

0 commit comments

Comments
 (0)