Commit b3f2fdd

[TPU][V1] Fix exponential padding when max-num-batched-tokens is not a power of 2 (#16596)
Signed-off-by: NickLucche <[email protected]>
1 parent aa29841 commit b3f2fdd

2 files changed (+15 −1 lines)

tests/v1/tpu/worker/test_tpu_model_runner.py

Lines changed: 12 additions & 0 deletions
@@ -299,6 +299,18 @@ def test_get_paddings():
     actual_paddings = _get_token_paddings(min_token_size, max_token_size,
                                           padding_gap)
     assert actual_paddings == expected_paddings
+    # Exponential padding.
+    max_token_size, padding_gap = 1024, 0
+    expected_paddings = [16, 32, 64, 128, 256, 512, 1024]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+    assert actual_paddings == expected_paddings
+    # Exponential padding with max_token_size not a power of two.
+    max_token_size = 317
+    expected_paddings = [16, 32, 64, 128, 256, 512]
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
+    assert actual_paddings == expected_paddings
 
 
 def test_get_padded_token_len():
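The new 317 case pins down the invariant the padding list must satisfy: it has to end with a bucket >= max_token_size, because every batch's token count is rounded up to the nearest bucket. A minimal sketch of that round-up lookup, in the spirit of the _get_padded_token_len helper exercised by the neighboring test (the standalone name and assert are illustrative, not the library's exact code):

import bisect

def padded_token_len(paddings: list[int], x: int) -> int:
    """Round x up to the smallest padding bucket that can hold it."""
    # With the pre-fix list for max_token_size=317 ([16, ..., 256]),
    # x=300 would fall past the end of the list and trip the assert.
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings), f"no padding bucket >= {x}"
    return paddings[index]

print(padded_token_len([16, 32, 64, 128, 256, 512], 300))  # -> 512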

vllm/v1/worker/tpu_model_runner.py

Lines changed: 3 additions & 1 deletion
@@ -1040,9 +1040,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int,
 
     if padding_gap == 0:
         logger.info("Using exponential token paddings:")
-        while num <= max_token_size:
+        while True:
             logger.info(" %d", num)
             paddings.append(num)
+            if num >= max_token_size:
+                break
             num *= 2
 
     else:
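Before this change, "while num <= max_token_size" exited as soon as num overshot the limit, so a non-power-of-two max_token_size such as 317 left the list capped at 256, with no bucket large enough for a full batch. The reordered loop appends first and only breaks once a value >= max_token_size has been recorded. A standalone sketch of the fixed exponential branch (logging dropped, illustrative function name; assumes min_token_size is a power of two, as in the tests):

def exponential_token_paddings(min_token_size: int,
                               max_token_size: int) -> list[int]:
    """Powers-of-two buckets from min_token_size up to the first
    value that covers max_token_size."""
    paddings = []
    num = min_token_size
    while True:
        paddings.append(num)
        if num >= max_token_size:
            # Break *after* appending, so the last bucket covers
            # max_token_size even when it is not a power of two.
            break
        num *= 2
    return paddings

assert exponential_token_paddings(16, 1024) == [16, 32, 64, 128, 256, 512, 1024]
assert exponential_token_paddings(16, 317) == [16, 32, 64, 128, 256, 512]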
