8 changes: 6 additions & 2 deletions vllm_gaudi/extension/bucketing/common.py
@@ -10,7 +10,7 @@
 from vllm_gaudi.extension.runtime import get_config
 
 
-def calc_fallback_value(n: int, base_step: int):
+def calc_fallback_value(n: int, base_step: int, warmup_max: int = None) -> int:
     """ Calculate next bucket for yet unbucketized value"""
     if n <= 1:
         return n
@@ -30,6 +30,9 @@ def calc_fallback_value(n: int, base_step: int):
     # => bucket_size = ceil(4001^1/3) * 32 = 16 * 32 = 512
     # => next_value = round_up(4001, 512) = 4096
     bucket_size = math.ceil(math.pow(n, power)) * base_step
+    num_blocks = math.ceil(n / bucket_size) * bucket_size
+    if warmup_max is not None and num_blocks > warmup_max and warmup_max >= n:
+        bucket_size = warmup_max
     return math.ceil(n / bucket_size) * bucket_size
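
To make the effect of the new `warmup_max` argument concrete, here is a standalone sketch of the updated helper plus a worked example. It assumes the cube-root power shown in the comments above (the hidden lines of the function are not part of this diff), and the warmup limit of 4032 is invented for illustration.

import math

def calc_fallback_value(n: int, base_step: int, warmup_max: int = None) -> int:
    """Sketch mirroring the updated helper above (power assumed to be 1/3)."""
    if n <= 1:
        return n
    power = 1 / 3  # assumption: taken from the worked example in the comments
    bucket_size = math.ceil(math.pow(n, power)) * base_step
    num_blocks = math.ceil(n / bucket_size) * bucket_size
    # New clamp: if rounding up would overshoot the largest warmed-up bucket that
    # can still hold n, use warmup_max itself as the rounding granularity.
    if warmup_max is not None and num_blocks > warmup_max and warmup_max >= n:
        bucket_size = warmup_max
    return math.ceil(n / bucket_size) * bucket_size

print(calc_fallback_value(4001, 32))                    # 4096, as in the comment above
print(calc_fallback_value(4001, 32, warmup_max=4032))   # 4032: the overshoot to 4096 is avoided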


@@ -121,7 +124,8 @@ def generate_fallback_bucket(self, batch_size, seq_len, ctx):
         if self.num_hpu_blocks is None:
             new_ctx = 0
         else:
-            new_ctx = min(calc_fallback_value(ctx, self.fallback_blocks_base_step),
+            decode_block_max = self.decode_buckets[-1][2] if len(self.decode_buckets) > 0 else self.decode_block_max
+            new_ctx = min(calc_fallback_value(ctx, self.fallback_blocks_base_step, decode_block_max),
                           self.num_hpu_blocks)
         return (new_batch_size, new_seq_len, new_ctx)
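
For context on where the cap comes from: `generate_fallback_bucket` now passes the block dimension of the largest generated decode bucket as `warmup_max`, falling back to `self.decode_block_max` when no decode buckets exist, and the result is still bounded by the number of available HPU blocks. A hedged sketch of that selection with invented values (it reuses the `calc_fallback_value` sketch above):

# Hypothetical values, for illustration only.
decode_buckets = [(1, 1, 128), (4, 1, 1024), (8, 1, 2048)]  # (bs, query_len, blocks)
decode_block_max = 4096
num_hpu_blocks = 1800
ctx = 1900

block_cap = decode_buckets[-1][2] if len(decode_buckets) > 0 else decode_block_max  # 2048
new_ctx = min(calc_fallback_value(ctx, 32, block_cap), num_hpu_blocks)
print(new_ctx)  # 1800: the fallback bucket (2048) is capped by the available HPU blocks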

2 changes: 2 additions & 0 deletions vllm_gaudi/extension/bucketing/linear.py
@@ -197,13 +197,15 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
     last_bucket = max_blocks
     for bs in bs_buckets:
         max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)
+        print(f"{bs=} {max_model_len=} {block_size=} {max_blocks_including_max_model_len=}, {block_buckets=}")
         for blocks in block_buckets:
             if bs > blocks:
                 # Skip a dummy case when bs > blocks, which cannot occur in real execution
                 continue
             if not use_contiguous_pa and blocks > max_blocks_including_max_model_len:
                 # Skip case when user wants to have bigger blocks than max model len
                 # case can only occur with contiguous PA
+                buckets.append((bs, 1, max_blocks_including_max_model_len))
                 continue
             if blocks >= last_bucket:
                 buckets.append((bs, 1, last_bucket))
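
A small sketch of the changed branch in `generate_decode_buckets`: with contiguous PA disabled, a block bucket larger than what the max model length can ever require for a given batch size is now clamped to that bound and kept, instead of being dropped. The configuration values below are invented, and the final else-append is assumed from the surrounding code that is not shown in this diff.

import math

# Invented configuration, for illustration only.
max_model_len, block_size = 4096, 128
use_contiguous_pa = False
bs, block_buckets, last_bucket = 4, [64, 256, 512], 512
buckets = []

max_blocks_including_max_model_len = bs * math.ceil(max_model_len / block_size)  # 4 * 32 = 128
for blocks in block_buckets:
    if bs > blocks:
        continue  # dummy case, cannot occur in real execution
    if not use_contiguous_pa and blocks > max_blocks_including_max_model_len:
        # New behaviour: clamp to the model-length bound instead of skipping the bucket.
        buckets.append((bs, 1, max_blocks_including_max_model_len))
        continue
    if blocks >= last_bucket:
        buckets.append((bs, 1, last_bucket))
        continue
    buckets.append((bs, 1, blocks))  # assumed default case

print(buckets)  # [(4, 1, 64), (4, 1, 128), (4, 1, 128)] -- 256 and 512 both clamp down to 128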