File tree Expand file tree Collapse file tree 2 files changed +18
-9
lines changed
model_executor/layers/quantization/utils Expand file tree Collapse file tree 2 files changed +18
-9
lines changed Original file line number Diff line number Diff line change 6
6
7
7
8
8
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
    """Return the CTA tile size (tokens per tile) for the FlashInfer MoE kernel.

    NOTE: the arguments are currently ignored — FlashInfer 0.2.10 has issues
    with larger tile sizes, so the tile dimension is pinned to 8 for now.
    TODO: Revert this to dynamic calculation once a new version of FlashInfer
    with the necessary kernels is released.
    """
    pinned_tile_dim = 8

    # Previous dynamic computation, kept for the planned revert:
    # from flashinfer import next_positive_power_of_2
    #
    # # Guess tokens per expert assuming perfect expert distribution first.
    # num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # # And pad the number to the next power of 2.
    # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    # # Cap to 8-64 tokens per CTA tile as it's the range supported by the
    # # kernel.
    # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

    return pinned_tile_dim
18
26
19
27
Original file line number Diff line number Diff line change @@ -524,7 +524,8 @@ def build(self,
524
524
head_dim = self .kv_cache_spec .head_size
525
525
526
526
# currently prefill trtllm attention does not support fp8 kv cache
527
- prefill_use_trtllm = use_trtllm_attention (
527
+ prefill_use_trtllm = not cache_dtype .startswith ("fp8" ) \
528
+ and use_trtllm_attention (
528
529
num_prefill_tokens , max_seq_len , cache_dtype ,
529
530
num_qo_heads , num_kv_heads , head_dim )
530
531
decode_use_trtllm = use_trtllm_attention (
You can’t perform that action at this time.
0 commit comments