Skip to content

Commit 9343f35

Browse files
authored
Port #301 and #313 from extension (#55)
HabanaAI/vllm-hpu-extension#301 and HabanaAI/vllm-hpu-extension#313 Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent 079e659 commit 9343f35

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

vllm_gaudi/extension/bucketing/common.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,29 @@
1010
from vllm_gaudi.extension.runtime import get_config
1111

1212

13+
def calc_fallback_value(n: int, base_step: int):
14+
""" Calculate next bucket for yet unbucketized value"""
15+
if n <= 1:
16+
return n
17+
power = 1/3
18+
# The basic idea is that we first estimate bucket size based
19+
# on exponent of the number, so higher numbers will generate
20+
# bigger gaps between individual buckets, but it's not as steep
21+
# as exponential bucketing. Additionally this has a nice
22+
# property that generated values are guaranteed to be divisible
23+
# by base_step
24+
#
25+
# examples:
26+
# n=31, base_step=32
27+
# => bucket_size = ceil(31^1/3) * 32 = 4 * 32 = 128
28+
# => next_value = round_up(31, 128) = 128
29+
# n=4001, base_step=32
30+
# => bucket_size = ceil(4001^1/3) * 32 = 16 * 32 = 512
31+
# => next_value = round_up(4001, 512) = 4096
32+
bucket_size = math.ceil(math.pow(n, power)) * base_step
33+
return math.ceil(n / bucket_size) * bucket_size
34+
35+
1336
class HPUBucketingManager():
1437
_instance = None
1538
prompt_buckets: List[Tuple[int, int, int]] = []
@@ -31,6 +54,10 @@ def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size,
3154
self.max_model_len = max_model_len
3255
self.initialized = True
3356

57+
self.fallback_bs_base_step = 2
58+
self.fallback_seq_base_step = 32
59+
self.fallback_blocks_base_step = 32
60+
3461
def get_bucketing_strategy(self):
3562
strategy = None
3663
# TODO - we can use different strategies for decode and prompt
@@ -86,14 +113,23 @@ def log_generate_info(self, is_prompt):
86113
f"{list(buckets)}")
87114
logger().info(msg)
88115

116+
def generate_fallback_bucket(self, batch_size, seq_len, ctx):
117+
assert self.max_num_batched_tokens is not None
118+
new_batch_size = calc_fallback_value(batch_size, self.fallback_bs_base_step)
119+
new_seq_len = min(calc_fallback_value(seq_len, self.fallback_seq_base_step),
120+
self.max_num_batched_tokens)
121+
if self.num_hpu_blocks is None:
122+
new_ctx = 0
123+
else:
124+
new_ctx = min(calc_fallback_value(ctx, self.fallback_blocks_base_step),
125+
self.num_hpu_blocks)
126+
return (new_batch_size, new_seq_len, new_ctx)
127+
89128
def find_prompt_bucket(self, batch_size, seq_len, ctx=0):
90129
if self.initialized:
91130
found_bucket = find_equal_or_closest_greater_config(self.prompt_buckets, (batch_size, seq_len, ctx))
92131
if found_bucket is None:
93-
new_batch_size = 2 ** math.ceil(math.log2(batch_size))
94-
new_seq_len = math.ceil(seq_len / self.block_size) * self.block_size
95-
new_ctx = math.ceil(ctx / 2) * 2
96-
new_bucket = (new_batch_size, new_seq_len, new_ctx)
132+
new_bucket = self.generate_fallback_bucket(batch_size, seq_len, ctx)
97133
logger().warning(f"Prompt bucket for {batch_size, seq_len, ctx}"
98134
f" was not prepared. Adding new bucket: {new_bucket}")
99135
self.prompt_buckets.append(new_bucket)
@@ -106,9 +142,7 @@ def find_decode_bucket(self, batch_size, num_blocks):
106142
if self.initialized:
107143
found_bucket = find_equal_or_closest_greater_config(self.decode_buckets, (batch_size, 1, num_blocks))
108144
if found_bucket is None:
109-
new_batch_size = 2 ** math.ceil(math.log2(batch_size))
110-
new_num_blocks = math.ceil(num_blocks / 2) * 2
111-
new_bucket = (new_batch_size, 1, new_num_blocks)
145+
new_bucket = self.generate_fallback_bucket(batch_size, 1, num_blocks)
112146
logger().warning(f"Decode bucket for {batch_size, 1, num_blocks}"
113147
f" was not prepared. Adding new bucket: {new_bucket}")
114148
self.decode_buckets.append(new_bucket)

vllm_gaudi/extension/bucketing/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def generate_prompt_buckets(bs_bucket_config,
139139
for bs in batch_size_buckets:
140140
for b in seq_bucket_config:
141141
max_blocks_range = (bmax - b) // block_size
142-
for i in range(0, max_blocks_range + 1):
142+
for i in range(0, max_blocks_range + 2):
143143
buckets_3d.append((bs, b, i))
144144
buckets = buckets_3d
145145
else:

0 commit comments

Comments
 (0)