Commit c204d32

Warmup fix - for non-contiguous PA runs, don't take more context blocks than possible (#97)
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent fe7cd43 commit c204d32
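
The bound this commit enforces is simple arithmetic: without contiguous PA, a decode batch of bs sequences can never reference more KV-cache blocks than bs * ceil(max_model_len / block_size), so warming up larger block buckets wastes warmup time. A minimal sketch of the bound, with illustrative numbers that are not taken from the commit:

import math

# Illustrative values only (not from the commit):
max_model_len = 4096   # longest sequence the model accepts
block_size = 128       # tokens per KV-cache block
bs = 8                 # decode batch size

# Each sequence needs at most ceil(max_model_len / block_size) blocks,
# so bs sequences can never touch more than bs times that many.
max_usable_blocks = bs * math.ceil(max_model_len / block_size)
print(max_usable_blocks)  # 8 * 32 = 256 blocks

Block counts above this bound are reachable only with contiguous PA, which is why the checks in the diffs below are gated on use_contiguous_pa.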

3 files changed (+17 additions, -18 deletions)


vllm_gaudi/extension/bucketing/common.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def generate_decode_buckets(self):
                 block_size = self.block_size,
                 max_num_batched_tokens = self.max_num_batched_tokens,
                 max_model_len = self.max_model_len,
-                num_max_blocks = self.num_hpu_blocks)
+                max_blocks = self.num_hpu_blocks)
             self.log_generate_info(False)
         else:
             logger().info("Bucketing is off - skipping decode buckets generation")

vllm_gaudi/extension/bucketing/exponential.py

Lines changed: 6 additions & 4 deletions
@@ -62,15 +62,17 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
 
     def get_decode_buckets(self, max_num_seqs, block_size,
                            max_num_batched_tokens, max_model_len,
-                           num_max_blocks):
+                           max_blocks):
         self.check_for_user_flags('decode')
         prefix_caching = get_config().prefix_caching
+        use_contiguous_pa = get_config().use_contiguous_pa
 
         # cfgs shape: [min, step, max, limit]
         decode_bs_limit = math.ceil(math.log2(max_num_seqs)) + 1
         decode_bs_bucket_cfg = [1, 2, max_num_seqs, decode_bs_limit]
-        max_decode_block_limit = math.ceil(math.log2(num_max_blocks)) + 1
-        max_decode_blocks = min((max_model_len // block_size * max_num_seqs), num_max_blocks)
+        max_decode_block_limit = math.ceil(math.log2(max_blocks)) + 1
+        max_decode_blocks = max_blocks if use_contiguous_pa else \
+            min((max_model_len // block_size * max_num_seqs), max_blocks)
         decode_block_bucket_cfg = [1, max_num_seqs, max_decode_blocks, max_decode_block_limit]
 
         msg = ("Decode bucket config (min, step, max_warmup, limit) "
@@ -80,7 +82,7 @@ def get_decode_buckets(self, max_num_seqs, block_size,
 
         decode_buckets = generate_decode_buckets(
             decode_bs_bucket_cfg, decode_block_bucket_cfg,
-            num_max_blocks, max_model_len, block_size)
+            max_blocks, max_model_len, block_size)
 
         return sorted(decode_buckets)
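
A hedged sketch of the new max_decode_blocks selection above, with get_config() replaced by a plain boolean argument (the other names mirror the diff):

import math

def pick_max_decode_blocks(max_blocks, max_model_len, block_size,
                           max_num_seqs, use_contiguous_pa):
    # Contiguous PA may legitimately touch every HPU block.
    if use_contiguous_pa:
        return max_blocks
    # Otherwise cap at what max_num_seqs full-length sequences could need.
    return min(max_model_len // block_size * max_num_seqs, max_blocks)

# Example: 4096 HPU blocks available, but 16 sequences of at most 2048
# tokens at block size 128 can only occupy 16 * (2048 // 128) = 256 blocks.
print(pick_max_decode_blocks(4096, 2048, 128, 16, use_contiguous_pa=False))  # 256
print(pick_max_decode_blocks(4096, 2048, 128, 16, use_contiguous_pa=True))   # 4096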

vllm_gaudi/extension/bucketing/linear.py

Lines changed: 10 additions & 13 deletions
@@ -15,14 +15,12 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
         use_merged_prefill = get_config().merged_prefill
         prefix_caching = get_config().prefix_caching
 
-        max_prompt_seq = max_model_len
-
         prompt_bs_bucket_cfg = read_bucket_settings(
             'prompt', 'bs', min=1, step=32,
             max=max_num_prefill_seqs)
         prompt_seq_bucket_cfg = read_bucket_settings(
             'prompt', 'seq', min=block_size,
-            step=block_size, max=max_prompt_seq)
+            step=block_size, max=max_model_len)
 
         if use_merged_prefill:
             prev_prompt_bs_bucket_cfg = tuple(prompt_bs_bucket_cfg)
@@ -56,10 +54,8 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
 
     def get_decode_buckets(self, max_num_seqs, block_size,
                            max_num_batched_tokens, max_model_len,
-                           num_max_blocks):
+                           max_blocks):
         prefix_caching = get_config().prefix_caching
-
-        max_blocks = num_max_blocks
 
         decode_bs_bucket_cfg = read_bucket_settings(
             'decode', 'bs', min=1, step=32,
@@ -75,7 +71,7 @@ def get_decode_buckets(self, max_num_seqs, block_size,
 
         decode_buckets = generate_decode_buckets(
             decode_bs_bucket_cfg,
-            decode_block_bucket_cfg, num_max_blocks)
+            decode_block_bucket_cfg, max_blocks, max_model_len, block_size)
 
         return sorted(decode_buckets)
 
@@ -190,23 +186,24 @@ def generate_prompt_buckets(bs_bucket_config,
 
 
 def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
-                            max_blocks):
+                            max_blocks, max_model_len, block_size):
     buckets = []
     bs_buckets = warmup_range(bs_bucket_config)
     use_contiguous_pa = get_config().use_contiguous_pa
-    if os.environ.get('VLLM_DECODE_BLOCK_BUCKET_MAX') is None\
-            and use_contiguous_pa:
-        blocks_bucket_config[2] = max_blocks
     block_buckets = warmup_range(blocks_bucket_config)
-    if os.environ.get('VLLM_DECODE_BLOCK_BUCKET_MAX') is None\
-            and max_blocks not in block_buckets and use_contiguous_pa:
+    if max_blocks not in block_buckets and use_contiguous_pa:
         block_buckets.append(max_blocks)
     last_bucket = max_blocks
     for bs in bs_buckets:
+        max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
         for blocks in block_buckets:
             if bs > blocks:
                 # Skip a dummy case when bs > blocks, which cannot occur in real execution
                 continue
+            if not use_contiguous_pa and blocks > max_blocks_including_max_model_len:
+                # Skip case when user wants to have bigger blocks than max model len;
+                # this case can only occur with contiguous PA
+                continue
             if blocks >= last_bucket:
                 buckets.append((bs, 1, last_bucket))
                 break