10
10
from vllm_gaudi .extension .runtime import get_config
11
11
12
12
13
def calc_fallback_value(n: int, base_step: int):
    """Return the fallback bucket boundary for a value that has no bucket yet.

    Values of 1 or less are returned unchanged. For anything larger, the
    bucket width is ``ceil(n ** (1/3)) * base_step`` and ``n`` is rounded
    up to the next multiple of that width. Tying the width to the cube
    root of ``n`` makes the gaps between buckets grow with the value, but
    less steeply than exponential bucketing would. Because the width is a
    whole multiple of ``base_step``, so is the returned value.

    Worked examples (base_step=32):
        n=31   -> width = ceil(31 ** (1/3)) * 32 = 4 * 32 = 128
               -> result = round_up(31, 128) = 128
        n=4001 -> width = ceil(4001 ** (1/3)) * 32 = 16 * 32 = 512
               -> result = round_up(4001, 512) = 4096

    Args:
        n: The value that needs a bucket.
        base_step: Granularity that the bucket width is a multiple of.

    Returns:
        ``n`` itself when ``n <= 1``, otherwise ``n`` rounded up to a
        multiple of the computed bucket width.
    """
    if n > 1:
        cube_root_exponent = 1 / 3
        # Number of base steps that make up one bucket at this magnitude.
        steps_per_bucket = math.ceil(math.pow(n, cube_root_exponent))
        width = steps_per_bucket * base_step
        # Round n up to the nearest multiple of the bucket width.
        multiples = math.ceil(n / width)
        return multiples * width
    return n
34
+
35
+
13
36
class HPUBucketingManager ():
14
37
_instance = None
15
38
prompt_buckets : List [Tuple [int , int , int ]] = []
@@ -31,6 +54,10 @@ def initialize(self, max_num_seqs, max_num_prefill_seqs, block_size,
31
54
self .max_model_len = max_model_len
32
55
self .initialized = True
33
56
57
+ self .fallback_bs_base_step = 2
58
+ self .fallback_seq_base_step = 32
59
+ self .fallback_blocks_base_step = 32
60
+
34
61
def get_bucketing_strategy (self ):
35
62
strategy = None
36
63
# TODO - we can use different strategies for decode and prompt
@@ -86,14 +113,23 @@ def log_generate_info(self, is_prompt):
86
113
f"{ list (buckets )} " )
87
114
logger ().info (msg )
88
115
116
def generate_fallback_bucket(self, batch_size, seq_len, ctx):
    """Build a (batch_size, seq_len, ctx) bucket for a shape with no prepared bucket.

    Each dimension is passed through ``calc_fallback_value`` with its own
    base step (``fallback_bs_base_step``, ``fallback_seq_base_step``,
    ``fallback_blocks_base_step``). The sequence length is then clamped to
    ``max_num_batched_tokens`` and the context size to ``num_hpu_blocks``.
    When ``num_hpu_blocks`` is ``None`` the context size is 0 regardless
    of the ``ctx`` argument.

    Args:
        batch_size: Requested batch size.
        seq_len: Requested sequence length.
        ctx: Requested context size.

    Returns:
        A ``(new_batch_size, new_seq_len, new_ctx)`` tuple.
    """
    assert self.max_num_batched_tokens is not None

    bs_bucket = calc_fallback_value(batch_size, self.fallback_bs_base_step)

    seq_bucket = calc_fallback_value(seq_len, self.fallback_seq_base_step)
    # Clamp the sequence bucket to the batched-token limit.
    seq_bucket = min(seq_bucket, self.max_num_batched_tokens)

    ctx_bucket = 0
    if self.num_hpu_blocks is not None:
        ctx_bucket = calc_fallback_value(ctx, self.fallback_blocks_base_step)
        # Clamp the context bucket to the configured block count.
        ctx_bucket = min(ctx_bucket, self.num_hpu_blocks)

    return (bs_bucket, seq_bucket, ctx_bucket)
127
+
89
128
def find_prompt_bucket (self , batch_size , seq_len , ctx = 0 ):
90
129
if self .initialized :
91
130
found_bucket = find_equal_or_closest_greater_config (self .prompt_buckets , (batch_size , seq_len , ctx ))
92
131
if found_bucket is None :
93
- new_batch_size = 2 ** math .ceil (math .log2 (batch_size ))
94
- new_seq_len = math .ceil (seq_len / self .block_size ) * self .block_size
95
- new_ctx = math .ceil (ctx / 2 ) * 2
96
- new_bucket = (new_batch_size , new_seq_len , new_ctx )
132
+ new_bucket = self .generate_fallback_bucket (batch_size , seq_len , ctx )
97
133
logger ().warning (f"Prompt bucket for { batch_size , seq_len , ctx } "
98
134
f" was not prepared. Adding new bucket: { new_bucket } " )
99
135
self .prompt_buckets .append (new_bucket )
@@ -106,9 +142,7 @@ def find_decode_bucket(self, batch_size, num_blocks):
106
142
if self .initialized :
107
143
found_bucket = find_equal_or_closest_greater_config (self .decode_buckets , (batch_size , 1 , num_blocks ))
108
144
if found_bucket is None :
109
- new_batch_size = 2 ** math .ceil (math .log2 (batch_size ))
110
- new_num_blocks = math .ceil (num_blocks / 2 ) * 2
111
- new_bucket = (new_batch_size , 1 , new_num_blocks )
145
+ new_bucket = self .generate_fallback_bucket (batch_size , 1 , num_blocks )
112
146
logger ().warning (f"Decode bucket for { batch_size , 1 , num_blocks } "
113
147
f" was not prepared. Adding new bucket: { new_bucket } " )
114
148
self .decode_buckets .append (new_bucket )
0 commit comments