Skip to content

Commit 8400f82

Browse files
authored
Change warmup scenario for execute dummy scenario (#54)
Change the warmup scenario to execute a dummy scenario. This way we more accurately simulate the real behaviour of vLLM inference by executing the precise run that happens in real inference, but with a dummy config that we want to warm up during the warm-up process. There is no need to artificially create an inference scenario, as we are now utilizing the real execution flow. --------- Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent 0492c55 commit 8400f82

File tree

2 files changed

+59
-213
lines changed

2 files changed

+59
-213
lines changed

vllm_gaudi/extension/bucketing/exponential.py

Lines changed: 29 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ def get_decode_buckets(self, max_num_seqs, block_size,
6565
num_max_blocks):
6666
self.check_for_user_flags('decode')
6767
prefix_caching = get_config().prefix_caching
68-
max_blocks = num_max_blocks
6968

7069
# cfgs shape: [min, step, max, limit]
7170
decode_bs_limit = math.ceil(math.log2(max_num_seqs)) + 1
7271
decode_bs_bucket_cfg = [1, 2, max_num_seqs, decode_bs_limit]
73-
max_decode_block_limit = math.ceil(math.log2(max_blocks)) + 1
74-
decode_block_bucket_cfg = [block_size, block_size, max_blocks, max_decode_block_limit]
72+
max_decode_block_limit = math.ceil(math.log2(num_max_blocks)) + 1
73+
max_decode_blocks = min((max_model_len // block_size * max_num_seqs), num_max_blocks)
74+
decode_block_bucket_cfg = [1, max_num_seqs, max_decode_blocks, max_decode_block_limit]
7575

7676
msg = ("Decode bucket config (min, step, max_warmup, limit) "
7777
f"bs:{decode_bs_bucket_cfg}, "
@@ -163,54 +163,37 @@ def generate_prompt_buckets(bs_bucket_config,
163163

164164

165165
def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
166-
max_blocks, max_model_len, block_size,
167-
skip_invalid=False):
166+
max_blocks, max_model_len, block_size):
168167
buckets = []
169168
bs_buckets = warmup_range_with_limit(bs_bucket_config)
170-
tmp_blocks_bucket_config = blocks_bucket_config
171-
tmp_blocks_bucket_config = (*tmp_blocks_bucket_config[:2], max_blocks, tmp_blocks_bucket_config[-1])
172-
block_buckets = warmup_range_with_limit(tmp_blocks_bucket_config)
173-
last_bucket = max_blocks
169+
block_buckets = warmup_range_with_limit(blocks_bucket_config)
174170
valid_blocks = set()
175-
if not skip_invalid:
176-
#NOTE(kzawora): this case will generate all possible combinations of
177-
# exponentially-spaced bs and blocks, even if combination is
178-
# invalid (exceeds max_model_len). Unfortunately, this is necessary
179-
# to handle scenario where bucket dimensions are determined by
180-
# get_padded_decode_num_blocks or get_padded_decode_batch_size,
181-
# since they don't include information about the other dimension.
182-
# This will need to be refactored at some point in the model runner,
183-
# but for now, we are dealing with this.
184-
valid_blocks = set((bs, 1, x) for x in sorted(block_buckets) for bs in bs_buckets)
185-
else:
186-
#NOTE(kzawora): this case will generate only valid combinations of
187-
# exponentially-spaced bs and blocks, where the product of bs and blocks
188-
# is less than or equal to max_model_len. To handle corner cases
189-
# (e.g. longer context due to fragmentation), we're adding an additional
190-
# bucket with max_blocks for each batch size.
191-
# For this to work properly, bucket dimensions need be requested as
192-
# a combination of (batch_size, num_blocks), not separately.
193-
for bs in bs_buckets:
194-
max_blocks_per_bs = min(bs * math.ceil(max_model_len / block_size), last_bucket)
195-
upper_bucket_bound = next(x for x in sorted(block_buckets) if x >= max_blocks_per_bs)
196-
valid_blocks = set((bs, 1, x) for x in sorted(block_buckets) if x <= upper_bucket_bound)
197-
198-
buckets.extend(list(valid_blocks))
199-
return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
200-
201-
202-
def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=False, fill=True):
171+
#NOTE(kzawora): generate only valid combinations of
172+
# exponentially-spaced bs and blocks, where the product of bs and blocks
173+
# is less than or equal to max_model_len. To handle corner cases
174+
# (e.g. longer context due to fragmentation), we're adding an additional
175+
# bucket with max_blocks for each batch size.
176+
# For this to work properly, bucket dimensions need be requested as
177+
# a combination of (batch_size, num_blocks), not separately.
178+
for bs in bs_buckets:
179+
max_blocks_per_bs = min(bs * math.ceil(max_model_len / block_size), max_blocks)
180+
try:
181+
upper_bucket_bound = max(x for x in sorted(block_buckets) if x <= max_blocks_per_bs)
182+
except ValueError:
183+
continue
184+
valid_blocks = set((bs, 1, x) for x in sorted(block_buckets) if x <= upper_bucket_bound \
185+
and bs <= x)
186+
buckets.extend(valid_blocks)
187+
return list(buckets)
188+
189+
190+
def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=False):
203191
"""
204192
NOTE(kzawora): we'll use exponential spacing for buckets in which scaled
205193
power will return bmin for first bucket iteration, and bmax for last
206194
iteration, with elements between determined by the exponent, and base being
207-
unchanged. Note that after padding to bstep, duplicates may occur.
208-
Handling of duplicates is configured by fill parameter.
209-
If fill is False, duplicates are removed and less buckets are returned.
210-
211-
If fill is True, duplicates are resolved by selecting the closest (larger
212-
or smaller) bucket. If duplicate resolution is not possible, less buckets
213-
are returned. In that case, buckets are guaranteed to be linearly spaced.
195+
unchanged. Note that after padding to bstep, duplicates may occur, and
196+
then shall be removed.
214197
Example (bmin=128, bstep=128, bmax=2048, num_buckets=10):
215198
There are 16 possible buckets (2048/128), and we'll attempt to select 10 of
216199
them with exponential spacing.
@@ -226,37 +209,13 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals
226209
scaled_powers_unpadded = [bmin*base^0(==bmin), bmin*base^1, bmin*base^2, ..., bmin*base^9(==bmax)]
227210
scaled_powers_unpadded = [128.00, 174.18, 237.02, 322.54, 438.91, 597.26, 812.75, 1105.98, 1505.01, 2048.00]
228211
229-
if fill is False:
212+
We then remove duplicate buckets:
230213
scaled_powers_padded = [ 128, 256, 256, 384, 512, 640, 896, 1152, 1536, 2048]
231214
^_______^
232215
duplicates
233216
buckets = [ 128, 256, 384, 512, 640, 896, 1152, 1536, 2048]
234217
^
235218
duplicate bucket removed
236-
len(buckets) = 9, num_buckets = 10
237-
if fill is True:
238-
buckets = [ 128, 256, 384, 512, 640, 768, 896, 1152, 1536, 2048]
239-
^_______^_______^_______^
240-
closest unused buckets selected
241-
^_______^_______^
242-
these become duplicates once previous duplicates are resolved
243-
244-
In this case we'll have four duplicated buckets:
245-
174.18 -> 256, optimal bucket,
246-
237.02 -> (256) -> 384, taking closest available bucket,
247-
as optimal bucket 256 was already captured by 174.18,
248-
322.54 -> (384) -> 512, taking closest available bucket,
249-
as optimal bucket 384 was already captured by 237.02,
250-
438.91 -> (512) -> 640, taking closest available bucket,
251-
as optimal bucket 512 was already captured by 322.54,
252-
597.26 -> (640) -> 768, taking closest available bucket,
253-
as optimal bucket 640 was already captured by 438.91,
254-
812.75 -> 896, optimal bucket
255-
len(buckets) = 10, num_buckets = 10
256-
In this case, the end result has the same buckets as fill=False,
257-
but with additional bucket 768 added.
258-
The difference is more pronounced for larger ranges and larger number
259-
of buckets.
260219
""" # noqa: E501
261220

262221
bmin, bstep, bmax, num_buckets = config
@@ -281,15 +240,7 @@ def warmup_range_with_limit(config: Tuple[int, int, int, int], long_context=Fals
281240
bucket = bmax
282241
else:
283242
bucket = math.ceil(power_unpadded / bstep) * bstep
284-
if fill and bucket in buckets:
285-
available_buckets = linear_buckets.difference(buckets)
286-
if len(available_buckets) == 0:
287-
break # there are no more unique buckets, let's exit now
288-
new_bucket = min(available_buckets,
289-
key=lambda x: abs(x - power_unpadded))
290-
buckets.add(new_bucket)
291-
else:
292-
buckets.add(bucket)
243+
buckets.add(bucket)
293244

294245
if long_context:
295246
#tmp_step = bmax / num_buckets

0 commit comments

Comments
 (0)