@@ -15,14 +15,12 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
15
15
use_merged_prefill = get_config ().merged_prefill
16
16
prefix_caching = get_config ().prefix_caching
17
17
18
- max_prompt_seq = max_model_len
19
-
20
18
prompt_bs_bucket_cfg = read_bucket_settings (
21
19
'prompt' , 'bs' , min = 1 , step = 32 ,
22
20
max = max_num_prefill_seqs )
23
21
prompt_seq_bucket_cfg = read_bucket_settings (
24
22
'prompt' , 'seq' , min = block_size ,
25
- step = block_size , max = max_prompt_seq )
23
+ step = block_size , max = max_model_len )
26
24
27
25
if use_merged_prefill :
28
26
prev_prompt_bs_bucket_cfg = tuple (prompt_bs_bucket_cfg )
@@ -56,10 +54,8 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
56
54
57
55
def get_decode_buckets (self , max_num_seqs , block_size ,
58
56
max_num_batched_tokens , max_model_len ,
59
- num_max_blocks ):
57
+ max_blocks ):
60
58
prefix_caching = get_config ().prefix_caching
61
-
62
- max_blocks = num_max_blocks
63
59
64
60
decode_bs_bucket_cfg = read_bucket_settings (
65
61
'decode' , 'bs' , min = 1 , step = 32 ,
@@ -75,7 +71,7 @@ def get_decode_buckets(self, max_num_seqs, block_size,
75
71
76
72
decode_buckets = generate_decode_buckets (
77
73
decode_bs_bucket_cfg ,
78
- decode_block_bucket_cfg , num_max_blocks )
74
+ decode_block_bucket_cfg , max_blocks , max_model_len , block_size )
79
75
80
76
return sorted (decode_buckets )
81
77
@@ -190,23 +186,24 @@ def generate_prompt_buckets(bs_bucket_config,
190
186
191
187
192
188
def generate_decode_buckets (bs_bucket_config , blocks_bucket_config ,
193
- max_blocks ):
189
+ max_blocks , max_model_len , block_size ):
194
190
buckets = []
195
191
bs_buckets = warmup_range (bs_bucket_config )
196
192
use_contiguous_pa = get_config ().use_contiguous_pa
197
- if os .environ .get ('VLLM_DECODE_BLOCK_BUCKET_MAX' ) is None \
198
- and use_contiguous_pa :
199
- blocks_bucket_config [2 ] = max_blocks
200
193
block_buckets = warmup_range (blocks_bucket_config )
201
- if os .environ .get ('VLLM_DECODE_BLOCK_BUCKET_MAX' ) is None \
202
- and max_blocks not in block_buckets and use_contiguous_pa :
194
+ if max_blocks not in block_buckets and use_contiguous_pa :
203
195
block_buckets .append (max_blocks )
204
196
last_bucket = max_blocks
205
197
for bs in bs_buckets :
198
+ max_blocks_including_max_model_len = bs * math .ceil (max_model_len / block_size )
206
199
for blocks in block_buckets :
207
200
if bs > blocks :
208
201
# Skip a dummy case when bs > blocks, which cannot occur in real execution
209
202
continue
203
+ if not use_contiguous_pa and blocks > max_blocks_including_max_model_len :
204
+ # Skip case when user wants to have bigger blocks than max model len
205
+ # case cn only occur with contiguous PA
206
+ continue
210
207
if blocks >= last_bucket :
211
208
buckets .append ((bs , 1 , last_bucket ))
212
209
break
0 commit comments