
Commit 86f0de6

Authored by bingyizh233 and chsigg
Add the python + lit test cases for Ampere small-tile-size mixed precision gemm (bf16 x s8) (#5337)
In the closed PR #4768, I wrote the python + lit test cases for Ampere small-tile-size mixed precision gemm (bf16 x s8). The compilation crash was fixed by another PR, so the test cases can now be added.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: Christian Sigg <[email protected]>
1 parent a4f1854 · commit 86f0de6

2 files changed: +30 −1 lines


python/test/regression/test_cast_matmul.py (1 addition, 1 deletion)

@@ -13,7 +13,7 @@
 import triton.language as tl
 from triton._internal_testing import is_hip_mi300, is_cuda, is_hip
 
-input_dtypes = ["float16", "float32", "float64"]
+input_dtypes = ["bfloat16", "float16", "float32", "float64"]
 if is_cuda():
     input_dtypes += ["int8", "float8_e5m2"]
     cc = torch.cuda.get_device_capability(0)
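
For context on what the new `bfloat16` entry exercises: this regression test builds matmuls whose two inputs have different dtypes and are upcast to a common type before the dot. Below is only a minimal, self-contained sketch of that bf16 x s8 pattern with assumed names, shapes, and tile sizes; it is not the code of `test_cast_matmul.py`:

# Illustrative sketch (assumed shapes/tile sizes, no masking, sizes divisible by blocks):
# one operand arrives as bfloat16, the other as int8, and the int8 tile is cast up
# before the dot, which is the mixed-precision path the new dtype exercises.
import torch
import triton
import triton.language as tl


@triton.jit
def cast_matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])  # bf16 tile
        b = tl.load(b_ptr + (k + offs_k)[:, None] * N + offs_n[None, :])  # s8 tile
        # Cast the int8 operand to bfloat16 so tl.dot sees a single input precision.
        acc += tl.dot(a, b.to(tl.bfloat16))
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


M, N, K = 64, 64, 64  # small tile sizes, in the spirit of the regression test
a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
b = torch.randint(-8, 8, (K, N), device="cuda", dtype=torch.int8)
c = torch.empty((M, N), device="cuda", dtype=torch.float32)
cast_matmul_kernel[(M // 32, N // 32)](a, b, c, M, N, K, BLOCK_M=32, BLOCK_N=32, BLOCK_K=32)

The upcast of the integer operand right before the dot is the same kind of integer-to-float operand conversion that the lit cases below check at the LLVM level.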

test/Conversion/tritongpu_to_llvm.mlir (29 additions, 0 deletions)

@@ -1999,3 +1999,32 @@ tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: te
 }
 
 }
+
+// -----
+
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @ampere_s8_to_fp16_conversion_opIdx1(%1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
+    // CHECK-LABEL: ampere_s8_to_fp16_conversion_opIdx1
+    // CHECK: llvm.sitofp %{{.*}} : i8 to f16
+    %2 = arith.sitofp %1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    tt.return
+  }
+}
+
+// -----
+
+#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @ampere_s8_to_fp16_conversion_opIdx0(%1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
+    // CHECK-LABEL: @ampere_s8_to_fp16_conversion_opIdx0
+    // CHECK: llvm.sitofp %{{.*}} : i8 to f16
+    %2 = arith.sitofp %1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> to tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    tt.return
+  }
+}
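
For reference, lit cases in this file are checked by running `triton-opt` over the file and piping the result into `FileCheck`. The actual RUN line sits at the top of `tritongpu_to_llvm.mlir` and is not part of this diff; the line below is only an assumed, typical shape for it:

// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s

Each `// -----` marker splits the file into independent inputs under `-split-input-file`, so the two new functions are lowered and checked in isolation.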
