11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
4- import functools
54from collections .abc import Iterable
5+ from functools import lru_cache
66from math import prod
77from typing import Any
88
@@ -249,7 +249,7 @@ def forward(self, video_fhw, txt_seq_lens, device):
249249
250250 return vid_freqs , txt_freqs
251251
252- @functools . cache
252+ @lru_cache ( maxsize = 16 )
253253 def _compute_video_freqs (self , frame , height , width , idx = 0 ):
254254 seq_lens = frame * height * width
255255 freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
@@ -268,7 +268,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0):
268268 freqs = torch .cat ([freqs_frame , freqs_height , freqs_width ], dim = - 1 ).reshape (seq_lens , - 1 )
269269 return freqs .clone ().contiguous ()
270270
271- @functools . cache
271+ @lru_cache ( maxsize = 16 )
272272 def _compute_condition_freqs (self , frame , height , width ):
273273 seq_lens = frame * height * width
274274 freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
@@ -311,7 +311,6 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
311311 ],
312312 dim = 1 ,
313313 )
314- self .rope_cache = {}
315314
316315 # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
317316 self .scale_rope = scale_rope
@@ -349,14 +348,7 @@ def forward(self, video_fhw, txt_seq_lens, device):
349348 max_vid_index = 0
350349 for idx , fhw in enumerate (video_fhw ):
351350 frame , height , width = fhw
352- rope_key = f"{ idx } _{ height } _{ width } "
353-
354- if not torch .compiler .is_compiling ():
355- if rope_key not in self .rope_cache :
356- self .rope_cache [rope_key ] = self ._compute_video_freqs (frame , height , width , idx )
357- video_freq = self .rope_cache [rope_key ]
358- else :
359- video_freq = self ._compute_video_freqs (frame , height , width , idx )
351+ video_freq = self ._compute_video_freqs (frame , height , width , idx )
360352 video_freq = video_freq .to (device )
361353 vid_freqs .append (video_freq )
362354
@@ -371,7 +363,7 @@ def forward(self, video_fhw, txt_seq_lens, device):
371363
372364 return vid_freqs , txt_freqs
373365
374- @functools . cache
366+ @lru_cache ( maxsize = 16 )
375367 def _compute_video_freqs (self , frame , height , width , idx = 0 ):
376368 seq_lens = frame * height * width
377369 freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
0 commit comments