Skip to content

Commit d53bd97

Browse files
[Backport to llvm_release_180] Add FP4/FP8 operand support for SubgroupMatrixMultiplyAccumulateINTEL (KhronosGroup#3609) (KhronosGroup#3630)
Extend SubgroupMatrixMultiplyAccumulateINTEL to support packed 4-bit and 8-bit floating-point matrix operands by implementing extensions: - SPV_INTEL_subgroup_matrix_multiply_accumulate_float4 - SPV_INTEL_subgroup_matrix_multiply_accumulate_float8 These extensions add operand flags that interpret packed integer data as FP4/FP8 without requiring actual FP4/FP8 type support added by SPV_INTEL_float4 or SPV_EXT_float8. FP4 operands: `MatrixAPackedFloat4E2M1INTEL` (0x40000) / `MatrixBPackedFloat4E2M1INTEL` (0x80000) FP8 operands: `MatrixAPackedFloat8E4M3INTEL` (0x4000) / `MatrixBPackedFloat8E4M3INTEL` (0x8000) `MatrixAPackedFloat8E5M2INTEL` (0x10000) / `MatrixBPackedFloat8E5M2INTEL` (0x20000) Specs: https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate_float4.asciidoc https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate_float8.asciidoc Co-authored-by: Viktoria Maximova <viktoria.maksimova@intel.com>
1 parent d83d445 commit d53bd97

File tree

5 files changed

+187
-0
lines changed

5 files changed

+187
-0
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ EXT(SPV_INTEL_maximum_registers)
7777
EXT(SPV_INTEL_bindless_images)
7878
EXT(SPV_INTEL_2d_block_io)
7979
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
80+
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate_float4)
81+
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate_float8)
8082
EXT(SPV_KHR_bfloat16)
8183
EXT(SPV_INTEL_bfloat16_arithmetic)
8284
EXT(SPV_INTEL_ternary_bitwise_function)

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4155,6 +4155,66 @@ class SPIRVSubgroupMatrixMultiplyAccumulateINTELInst
41554155
std::optional<ExtensionID> getRequiredExtension() const override {
41564156
return ExtensionID::SPV_INTEL_subgroup_matrix_multiply_accumulate;
41574157
}
4158+
4159+
protected:
4160+
void validate() const override {
4161+
SPIRVInstTemplateBase::validate();
4162+
4163+
// Check if FP4 or FP8 matrix operands are used
4164+
// Operands parameter is the last operand (index 4)
4165+
auto *NonConstThis =
4166+
const_cast<SPIRVSubgroupMatrixMultiplyAccumulateINTELInst *>(this);
4167+
if (NonConstThis->getOperands().size() > 4) {
4168+
const SPIRVConstant *OperandsConst =
4169+
static_cast<const SPIRVConstant *>(NonConstThis->getOperand(4));
4170+
uint64_t OperandsMask = OperandsConst->getZExtIntValue();
4171+
4172+
// FP4 operand bits
4173+
constexpr uint64_t FP4Mask =
4174+
spv::internal::
4175+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat4E2M1INTELMask |
4176+
spv::internal::
4177+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat4E2M1INTELMask;
4178+
4179+
// FP8 operand bits
4180+
constexpr uint64_t FP8Mask =
4181+
spv::internal::
4182+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat8E4M3INTELMask |
4183+
spv::internal::
4184+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat8E4M3INTELMask |
4185+
spv::internal::
4186+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat8E5M2INTELMask |
4187+
spv::internal::
4188+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat8E5M2INTELMask;
4189+
4190+
if ((OperandsMask & FP4Mask) != 0) {
4191+
getModule()->getErrorLog().checkError(
4192+
getModule()->isAllowedToUseExtension(
4193+
ExtensionID::
4194+
SPV_INTEL_subgroup_matrix_multiply_accumulate_float4),
4195+
SPIRVEC_RequiresExtension,
4196+
"SPV_INTEL_subgroup_matrix_multiply_accumulate_float4\n"
4197+
"SubgroupMatrixMultiplyAccumulateINTEL with FP4 operand flags "
4198+
"requires this extension");
4199+
getModule()->addExtension(
4200+
ExtensionID::SPV_INTEL_subgroup_matrix_multiply_accumulate_float4);
4201+
}
4202+
4203+
if ((OperandsMask & FP8Mask) != 0) {
4204+
getModule()->getErrorLog().checkError(
4205+
getModule()->isAllowedToUseExtension(
4206+
ExtensionID::
4207+
SPV_INTEL_subgroup_matrix_multiply_accumulate_float8),
4208+
SPIRVEC_RequiresExtension,
4209+
"SPV_INTEL_subgroup_matrix_multiply_accumulate_float8\n"
4210+
"SubgroupMatrixMultiplyAccumulateINTEL with FP8 operand flags "
4211+
"requires this extension");
4212+
getModule()->addExtension(
4213+
ExtensionID::SPV_INTEL_subgroup_matrix_multiply_accumulate_float8);
4214+
}
4215+
}
4216+
}
4217+
41584218
SPIRVCapVec getRequiredCapability() const override {
41594219
return getVec(CapabilitySubgroupMatrixMultiplyAccumulateINTEL);
41604220
}

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,17 @@ enum InternalBuiltIn {
171171
IBuiltInDeviceBarrierValidINTEL = 6186,
172172
};
173173

174+
enum InternalMatrixMultiplyAccumulateOperandsMask {
175+
// FP8 matrix operands
176+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat8E4M3INTELMask = 0x4000,
177+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat8E4M3INTELMask = 0x8000,
178+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat8E5M2INTELMask = 0x10000,
179+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat8E5M2INTELMask = 0x20000,
180+
// FP4 matrix operands
181+
IMatrixMultiplyAccumulateOperandsMatrixAPackedFloat4E2M1INTELMask = 0x40000,
182+
IMatrixMultiplyAccumulateOperandsMatrixBPackedFloat4E2M1INTELMask = 0x80000,
183+
};
184+
174185
#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
175186
_SPIRV_OP(Capability, JointMatrixINTEL)
176187
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
; This test checks that SubgroupMatrixMultiplyAccumulateINTEL with FP4 operand flags
2+
; requires the SPV_INTEL_subgroup_matrix_multiply_accumulate_float4 extension.
3+
4+
; RUN: llvm-as %s -o %t.bc
5+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate,+SPV_INTEL_subgroup_matrix_multiply_accumulate_float4
6+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
7+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
8+
9+
; RUN: not llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
10+
11+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
12+
; CHECK-ERROR: SPV_INTEL_subgroup_matrix_multiply_accumulate_float4
13+
14+
; CHECK-SPIRV-DAG: Capability SubgroupMatrixMultiplyAccumulateINTEL
15+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_subgroup_matrix_multiply_accumulate"
16+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_subgroup_matrix_multiply_accumulate_float4"
17+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 262144
18+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 524288
19+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 786432
20+
21+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
22+
target triple = "spir64-unknown-unknown"
23+
24+
; Test MatrixAPackedFloat4E2M1INTEL operand (0x40000 = 262144)
25+
define spir_func <4 x float> @test_fp4_matrix_a(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
26+
entry:
27+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 262144)
28+
ret <4 x float> %result
29+
}
30+
31+
; Test MatrixBPackedFloat4E2M1INTEL operand (0x80000 = 524288)
32+
define spir_func <4 x float> @test_fp4_matrix_b(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
33+
entry:
34+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 524288)
35+
ret <4 x float> %result
36+
}
37+
38+
; Test both FP4 operands (0xC0000 = 786432)
39+
define spir_func <4 x float> @test_fp4_matrix_both(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
40+
entry:
41+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 786432)
42+
ret <4 x float> %result
43+
}
44+
45+
declare spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32, <4 x i8>, <8 x i8>, <4 x float>, i32)
46+
47+
!opencl.spir.version = !{!0}
48+
!spirv.Source = !{!1}
49+
!llvm.ident = !{!2}
50+
51+
!0 = !{i32 1, i32 0}
52+
!1 = !{i32 4, i32 100000}
53+
!2 = !{!"clang version 17.0.0"}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
; This test checks that SubgroupMatrixMultiplyAccumulateINTEL with FP8 operand flags
2+
; requires the SPV_INTEL_subgroup_matrix_multiply_accumulate_float8 extension.
3+
4+
; RUN: llvm-as %s -o %t.bc
5+
; RUN: llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate,+SPV_INTEL_subgroup_matrix_multiply_accumulate_float8
6+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
7+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
8+
9+
; RUN: not llvm-spirv %t.bc -o %t.spv --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
10+
11+
; CHECK-ERROR: RequiresExtension: Feature requires the following SPIR-V extension:
12+
; CHECK-ERROR: SPV_INTEL_subgroup_matrix_multiply_accumulate_float8
13+
14+
; CHECK-SPIRV-DAG: Capability SubgroupMatrixMultiplyAccumulateINTEL
15+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_subgroup_matrix_multiply_accumulate"
16+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_subgroup_matrix_multiply_accumulate_float8"
17+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 16384
18+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 32768
19+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 65536
20+
; CHECK-SPIRV-DAG: SubgroupMatrixMultiplyAccumulateINTEL {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} {{[0-9]+}} 131072
21+
22+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
23+
target triple = "spir64-unknown-unknown"
24+
25+
; Test MatrixAPackedFloat8E4M3INTEL operand (0x4000 = 16384)
26+
define spir_func <4 x float> @test_fp8_e4m3_matrix_a(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
27+
entry:
28+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 16384)
29+
ret <4 x float> %result
30+
}
31+
32+
; Test MatrixBPackedFloat8E4M3INTEL operand (0x8000 = 32768)
33+
define spir_func <4 x float> @test_fp8_e4m3_matrix_b(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
34+
entry:
35+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 32768)
36+
ret <4 x float> %result
37+
}
38+
39+
; Test MatrixAPackedFloat8E5M2INTEL operand (0x10000 = 65536)
40+
define spir_func <4 x float> @test_fp8_e5m2_matrix_a(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
41+
entry:
42+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 65536)
43+
ret <4 x float> %result
44+
}
45+
46+
; Test MatrixBPackedFloat8E5M2INTEL operand (0x20000 = 131072)
47+
define spir_func <4 x float> @test_fp8_e5m2_matrix_b(<4 x float> %c, <4 x i8> %a, <8 x i8> %b) {
48+
entry:
49+
%result = call spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32 8, <4 x i8> %a, <8 x i8> %b, <4 x float> %c, i32 131072)
50+
ret <4 x float> %result
51+
}
52+
53+
declare spir_func <4 x float> @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_hDv8_hDv4_fi(i32, <4 x i8>, <8 x i8>, <4 x float>, i32)
54+
55+
!opencl.spir.version = !{!0}
56+
!spirv.Source = !{!1}
57+
!llvm.ident = !{!2}
58+
59+
!0 = !{i32 1, i32 0}
60+
!1 = !{i32 4, i32 100000}
61+
!2 = !{!"clang version 17.0.0"}

0 commit comments

Comments
 (0)