@@ -42,6 +42,7 @@ def __init__(
         # self.init_cache_engine().
         self.cache_config = None
         self.block_size = None
+        self.sliding_window = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
@@ -136,10 +137,13 @@ def profile_num_available_blocks(
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
         self.block_size = cache_config.block_size
+        self.sliding_window = cache_config.sliding_window

-        max_seq_len = min(self.scheduler_config.max_model_len,
-                          cache_config.sliding_window or float("inf"))
-
+        if self.sliding_window is None:
+            max_seq_len = self.scheduler_config.max_model_len
+        else:
+            max_seq_len = min(self.scheduler_config.max_model_len,
+                              self.sliding_window)
         _check_if_can_support_max_seq_len(max_seq_len, self.block_size)

         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
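
A standalone sketch of the new branch, with illustrative values (not from this PR): the old expression `cache_config.sliding_window or float("inf")` leaned on `or`, which treats a falsy window such as `0` the same as `None` and mixes a float into an otherwise integer `min()`; the explicit `None` check keeps both the intent and the types obvious.

```python
# Illustrative values only; max_model_len and the window sizes are assumptions.
max_model_len = 4096

for sliding_window in (None, 1024):
    if sliding_window is None:
        max_seq_len = max_model_len                      # no window: full model length
    else:
        max_seq_len = min(max_model_len, sliding_window)
    print(sliding_window, "->", max_seq_len)             # None -> 4096, 1024 -> 1024
```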
@@ -151,6 +155,8 @@ def _prepare_inputs(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+        assert self.block_size is not None
+
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         input_tokens: List[int] = []
         input_positions: List[int] = []
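
The new assert adds no behavior; it records an ordering contract: `_prepare_inputs` is only valid after `init_cache_engine` has populated `block_size` (and now `sliding_window`), and the assert also lets type checkers narrow `Optional[int]` to `int`. A minimal stand-in class (hypothetical, not vLLM's `Worker`) showing that contract:

```python
class _WorkerSketch:
    """Hypothetical stand-in illustrating the init-before-use contract."""

    def __init__(self):
        self.block_size = None            # unknown until the cache engine exists
        self.sliding_window = None

    def init_cache_engine(self, block_size, sliding_window):
        self.block_size = block_size
        self.sliding_window = sliding_window

    def prepare_inputs(self):
        assert self.block_size is not None, "init_cache_engine() must run first"

w = _WorkerSketch()
w.init_cache_engine(block_size=16, sliding_window=None)
w.prepare_inputs()                        # passes; calling before init would raise
```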
@@ -193,9 +199,6 @@ def _prepare_inputs(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-        sliding_window = getattr(self.model_config.hf_config, "sliding_window",
-                                 float("inf"))
-
         # Add generation tokens.
         max_context_len = 0
         max_num_blocks_per_seq = 0
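
Removing the `getattr` also closes a latent trap in the old default: `float("inf")` is truthy, so the later `if sliding_window:` branch ran even for models with no window at all, and an infinite window does not yield a usable block count. A small standalone illustration:

```python
cfg = object()                                     # stand-in: hf_config without the attribute
sliding_window = getattr(cfg, "sliding_window", float("inf"))
if sliding_window:                                 # inf is truthy, so this branch always ran
    blocks = sliding_window // 16                  # inf // 16 == inf: not a usable count
```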
@@ -216,8 +219,8 @@ def _prepare_inputs(

                 context_len = seq_data.get_len()
                 position = context_len - 1
-                if sliding_window:
-                    context_len = min(context_len, sliding_window)
+                if self.sliding_window is not None:
+                    context_len = min(context_len, self.sliding_window)
                 input_positions.append(position)

                 block_table = seq_group_metadata.block_tables[seq_id]
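
Worth noting what the hunk keeps asymmetric: `position` is computed from the full sequence length before the clamp, so position embeddings still see the absolute index, while `context_len` is capped at the window so attention only covers tokens still in the cache. A worked example with assumed numbers:

```python
sliding_window = 4096        # assumed window size
seq_len = 5000               # sequence has outgrown the window

position = seq_len - 1                        # 4999: absolute position, unclamped
context_len = min(seq_len, sliding_window)    # 4096: tokens actually attended to
```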
@@ -232,10 +235,9 @@ def _prepare_inputs(
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append(slot)

-                if sliding_window:
-                    assert self.cache_config is not None
-                    sliding_window_blocks = (sliding_window //
-                                             self.cache_config.block_size)
+                if self.sliding_window is not None:
+                    sliding_window_blocks = (self.sliding_window //
+                                             self.block_size)
                     block_table = block_table[-sliding_window_blocks:]
                 generation_block_tables.append(block_table)

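The trim itself is plain integer arithmetic on the per-sequence block table; a worked example with assumed sizes:

```python
sliding_window = 4096
block_size = 16
block_table = list(range(300))   # 300 physical block numbers for a long sequence

sliding_window_blocks = sliding_window // block_size   # 256 blocks cover the window
block_table = block_table[-sliding_window_blocks:]     # keep only the newest 256
assert len(block_table) == 256 and block_table[0] == 44  # blocks 0..43 aged out
```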
@@ -277,7 +279,7 @@ def _prepare_inputs(
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
-            sliding_window=sliding_window,
+            sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata
