We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c204d32 commit f331b00 (Copy full SHA for f331b00)
tests/unit_tests/test_bucketing.py
@@ -63,7 +63,10 @@ def test_generate_decode_buckets():
63
bs_bucket_config = [1, 32, 128]
64
blocks_bucket_config = [128, 128, 2048]
65
max_blocks = 1024
66
+ max_model_len = 131072
67
+ block_size = 128
68
buckets = linear.generate_decode_buckets(bs_bucket_config,
- blocks_bucket_config, max_blocks)
69
+ blocks_bucket_config, max_blocks,
70
+ max_model_len, block_size)
71
assert len(buckets) == 72
72
assert all(blocks <= max_blocks for _, _, blocks in buckets)
vllm_gaudi/extension/bucketing/linear.py
@@ -1,6 +1,7 @@
1
import itertools
2
import operator
3
import os
4
+import math
5
from dataclasses import dataclass, field
6
from typing import List, Tuple
7
0 commit comments