Skip to content

Commit 4ae17bf

Browse files
authored
Revert "Use Cache Hinting for fused_moe kernel (#15511)" (#15645)
Signed-off-by: Wes Medford <[email protected]>
1 parent 8a49eea commit 4ae17bf

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,7 @@ def fused_moe_kernel_gptq_awq(
189189
mask=token_mask[:, None] &
190190
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
191191
other=0.0)
192-
b = tl.load(
193-
b_ptrs,
194-
cache_modifier=".cg",
195-
eviction_policy="evict_last",
196-
)
192+
b = tl.load(b_ptrs)
197193
if use_int4_w4a16:
198194
b = (b >> b_shifter) & 0xF
199195

@@ -395,13 +391,9 @@ def fused_moe_kernel(
395391
mask=token_mask[:, None] &
396392
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
397393
other=0.0)
398-
b = tl.load(
399-
b_ptrs,
400-
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
401-
other=0.0,
402-
cache_modifier=".cg",
403-
eviction_policy="evict_last",
404-
)
394+
b = tl.load(b_ptrs,
395+
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
396+
other=0.0)
405397
# We accumulate along the K dimension.
406398
if use_int8_w8a16:
407399
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)

0 commit comments

Comments
 (0)