Skip to content

Commit f331b00

Browse files
authored
[UT] Fix test args for bucketing tests (#105)
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent c204d32 commit f331b00

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

tests/unit_tests/test_bucketing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ def test_generate_decode_buckets():
6363
bs_bucket_config = [1, 32, 128]
6464
blocks_bucket_config = [128, 128, 2048]
6565
max_blocks = 1024
66+
max_model_len = 131072
67+
block_size = 128
6668
buckets = linear.generate_decode_buckets(bs_bucket_config,
67-
blocks_bucket_config, max_blocks)
69+
blocks_bucket_config, max_blocks,
70+
max_model_len, block_size)
6871
assert len(buckets) == 72
6972
assert all(blocks <= max_blocks for _, _, blocks in buckets)

vllm_gaudi/extension/bucketing/linear.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22
import operator
33
import os
4+
import math
45
from dataclasses import dataclass, field
56
from typing import List, Tuple
67

0 commit comments

Comments
 (0)