@@ -18279,65 +18279,216 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1827918279 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
1828018280 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
1828118281 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18282- case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
18282+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18283+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18284+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18285+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18286+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18287+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18288+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18289+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18290+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18291+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18292+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18293+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18294+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18295+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18296+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18297+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18298+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18299+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18300+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18301+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18302+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18303+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18304+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18305+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18306+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18307+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18308+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18309+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18310+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18311+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18312+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18313+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18314+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18315+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18316+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18317+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18318+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18319+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18320+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18321+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18322+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18323+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18324+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18325+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18326+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {
1828318327
1828418328 // These operations perform a matrix multiplication and accumulation of
1828518329 // the form:
1828618330 // D = A * B + C
18287- // The return type always matches the type of matrix C.
18288- unsigned ArgForMatchingRetType;
18331+ // We need to specify one type for matrices AB and one for matrices CD.
18332+ // Sparse matrix operations can have different types for A and B as well as
18333+ // an additional type for sparsity index.
18334+ // Destination type should be put before types used for source operands.
18335+ SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
18336+ // On GFX12, the intrinsics with 16-bit accumulator use a packed layout.
18337+ // There is no need for the variable opsel argument, so always set it to
18338+ // "false".
18339+ bool AppendFalseForOpselArg = false;
1828918340 unsigned BuiltinWMMAOp;
1829018341
1829118342 switch (BuiltinID) {
1829218343 case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
1829318344 case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
18294- ArgForMatchingRetType = 2;
18345+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18346+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18347+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1829518348 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
1829618349 break;
1829718350 case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
1829818351 case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
18299- ArgForMatchingRetType = 2;
18352+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18353+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18354+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1830018355 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
1830118356 break;
18357+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18358+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18359+ AppendFalseForOpselArg = true;
18360+ LLVM_FALLTHROUGH;
1830218361 case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
1830318362 case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
18304- ArgForMatchingRetType = 2;
18363+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1830518364 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
1830618365 break;
18366+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18367+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18368+ AppendFalseForOpselArg = true;
18369+ LLVM_FALLTHROUGH;
1830718370 case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
1830818371 case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
18309- ArgForMatchingRetType = 2;
18372+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1831018373 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
1831118374 break;
1831218375 case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
1831318376 case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
18314- ArgForMatchingRetType = 2;
18377+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1831518378 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
1831618379 break;
1831718380 case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
1831818381 case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
18319- ArgForMatchingRetType = 2;
18382+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1832018383 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
1832118384 break;
1832218385 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
1832318386 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18324- ArgForMatchingRetType = 4;
18387+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18388+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18389+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
1832518390 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
1832618391 break;
1832718392 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
1832818393 case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18329- ArgForMatchingRetType = 4;
18394+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18395+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18396+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
1833018397 BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
1833118398 break;
18399+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18400+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18401+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18402+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
18403+ break;
18404+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18405+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18406+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18407+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
18408+ break;
18409+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18410+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18411+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18412+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
18413+ break;
18414+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18415+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18416+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18417+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
18418+ break;
18419+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18420+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18421+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18422+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
18423+ break;
18424+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18425+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18426+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18427+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
18428+ break;
18429+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18430+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18431+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18432+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
18433+ break;
18434+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18435+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18436+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18437+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
18438+ break;
18439+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18440+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18441+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18442+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
18443+ break;
18444+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18445+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18446+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18447+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
18448+ break;
18449+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18450+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18451+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18452+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
18453+ break;
18454+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18455+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18456+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18457+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
18458+ break;
18459+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18460+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18461+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18462+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
18463+ break;
18464+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18465+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18466+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18467+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
18468+ break;
18469+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18470+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18471+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18472+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
18473+ break;
18474+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18475+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
18476+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18477+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
18478+ break;
1833218479 }
1833318480
1833418481 SmallVector<Value *, 6> Args;
1833518482 for (int i = 0, e = E->getNumArgs(); i != e; ++i)
1833618483 Args.push_back(EmitScalarExpr(E->getArg(i)));
18484+ if (AppendFalseForOpselArg)
18485+ Args.push_back(Builder.getFalse());
1833718486
18338- Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
18339- {Args[ArgForMatchingRetType]->getType()});
18487+ SmallVector<llvm::Type *, 6> ArgTypes;
18488+ for (auto ArgIdx : ArgsForMatchingMatrixTypes)
18489+ ArgTypes.push_back(Args[ArgIdx]->getType());
1834018490
18491+ Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
1834118492 return Builder.CreateCall(F, Args);
1834218493 }
1834318494
0 commit comments