Commit ec6b435
[Kernels] Force mxfp4->bf16 conversion to use mul.bf16x2 for scaling (#8967)
LLVM doesn't auto-vectorize this very well and ends up with a mix of
vector and scalar muls. I think the cost heuristic gets tripped up by
the scale broadcasting, which requires unpacking and duplicating the
scales; for this we generate PTX like
```
mov.b32 {%rs0, %rs1}, %packed_scales
mov.b32 %r1, {%rs0, %rs0}
mov.b32 %r2, {%rs1, %rs1}
```
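The unpack-and-duplicate pattern above can be sketched in plain Python. This is a hypothetical illustration (the function name and structure are mine, not from the kernel): it pulls two 16-bit scales out of one 32-bit word and packs each into both halves of its own 32-bit word, mirroring the three `mov.b32` instructions.

```python
def broadcast_packed_scales(packed_scales: int) -> tuple[int, int]:
    """Sketch of the PTX scale broadcast:

        mov.b32 {%rs0, %rs1}, %packed_scales
        mov.b32 %r1, {%rs0, %rs0}
        mov.b32 %r2, {%rs1, %rs1}

    PTX vector notation puts the first component in the low half.
    """
    rs0 = packed_scales & 0xFFFF          # low 16 bits
    rs1 = (packed_scales >> 16) & 0xFFFF  # high 16 bits
    r1 = (rs0 << 16) | rs0                # {rs0, rs0}
    r2 = (rs1 << 16) | rs1                # {rs1, rs1}
    return r1, r2
```

On hardware these duplications need no separate instructions once ptxas folds them into the multiply's register modifier, which is the point of the commit.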
However, ptxas can fuse this into the multiply, e.g.
```
HMUL2.BF16_V2 R90, R90, R100.H0_H0
HMUL2.BF16_V2 R91, R91, R100.H1_H1
```
where the movs have become the register modifier in the instruction.
This gives a modest 1% speedup on non-persistent bf16xmxfp4 MoE.

Parent: 81526ff
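For context on what the conversion computes, here is a hedged sketch of mxfp4 dequantization, assuming the OCP Microscaling (MX) layout: each fp4 (e2m1) element is multiplied by a block-shared e8m0 power-of-two scale. The function names are illustrative, not the kernel's actual API.

```python
# e2m1 magnitudes for codes 0..7 (sign is the fourth bit)
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_fp4(nibble: int) -> float:
    """Decode one 4-bit e2m1 value (bit 3 = sign, bits 2..0 = magnitude code)."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_VALUES[nibble & 0x7]

def dequantize_mxfp4(nibbles: list[int], e8m0_scale: int) -> list[float]:
    """Multiply each decoded fp4 element by the shared power-of-two scale.

    e8m0 stores a biased exponent: scale = 2**(e - 127).
    """
    scale = 2.0 ** (e8m0_scale - 127)
    return [decode_fp4(n) * scale for n in nibbles]
```

In the kernel this per-element multiply is what the commit forces onto paired `mul.bf16x2` instructions rather than a mix of vector and scalar muls.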
File tree:
- python/triton_kernels/triton_kernels/tensor_details/layout_details

1 file changed: 17 additions & 1 deletion