@@ -872,9 +872,11 @@ llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
872872}
873873
874874llvm.func @rocdl.wmma (%arg0 : vector <8 xf32 >, %arg1 : vector <16 x f16 >, %arg2 : vector <16 x i16 >, %arg3 : vector <8 x i32 >,
875- %arg4 : vector <2 xi32 >, %arg5 : vector <4 xi32 >, %arg6 : vector <4 xf32 >, %arg7 : vector <8 xf16 >, %arg8 : vector <8 xi16 >) -> vector <8 xf32 > {
875+ %arg4 : vector <2 xi32 >, %arg5 : vector <4 xi32 >, %arg6 : vector <4 xf32 >, %arg7 : vector <8 xf16 >, %arg8 : vector <8 xi16 >,
876+ %arg9 : vector <32 xf16 >, %arg10 : vector <16 xf32 >, %arg11 : vector <4 xf32 >, %arg12 : vector <32 xf32 >, %arg13 : vector <64 xf32 >,
877+ %arg14 : vector <64 xi32 >, %arg15 : vector <64 xf16 >, %arg16 : vector <16 xbf16 >, %arg17 : vector <32 xbf16 >) -> vector <8 xf32 > {
876878 %zero = llvm.mlir.constant (false ) : i1
877-
879+ %zero_i16 = llvm.mlir.constant ( 0 : i16 ) : i16
878880 // ---- Wave32 -----
879881
880882 // f16 -> f32
@@ -905,6 +907,83 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
905907 // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
906908 %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero , %arg4 , %zero , %arg4 , %arg3 , %zero : (i1 , vector <2 xi32 >, i1 , vector <2 xi32 >, vector <8 xi32 >, i1 ) -> vector <8 xi32 >
907909
910+ // f32 -> f32
911+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 {{.*}}, <16 x float> %{{.*}}, i1 {{.*}}, <16 x float> %{{.*}}, i16 0, <4 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
912+ %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero , %arg10 , %zero , %arg10 , %zero_i16 , %arg11 , %zero , %zero : (i1 , vector <16 xf32 >, i1 , vector <16 xf32 >, i16 , vector <4 xf32 >, i1 , i1 ) -> vector <4 xf32 >
913+
914+ // f16 -> f32
915+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
916+ %r2.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero , %arg1 , %zero , %arg1 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xf16 >, i1 , vector <16 xf16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xf32 >
917+
918+ // bf16 -> f32
919+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
920+ %r3.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero , %arg16 , %zero , %arg16 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xbf16 >, i1 , vector <16 xbf16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xf32 >
921+
922+ // f16 -> f16
923+ // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 {{.*}}, <16 x half> %{{.*}}, i1 {{.*}}, <16 x half> %{{.*}}, i16 0, <32 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
924+ %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero , %arg1 , %zero , %arg1 , %zero_i16 , %arg9 , %zero , %zero : (i1 , vector <16 xf16 >, i1 , vector <16 xf16 >, i16 , vector <32 xf16 >, i1 , i1 ) -> vector <32 xf16 >
925+
926+ // bf16 -> bf16
927+ // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v32bf16.v16bf16(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x bfloat> %{{.*}}, i1 {{.*}}, i1 {{.*}})
928+ %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero , %arg16 , %zero , %arg16 , %zero_i16 , %arg17 , %zero , %zero : (i1 , vector <16 xbf16 >, i1 , vector <16 xbf16 >, i16 , vector <32 xbf16 >, i1 , i1 ) -> vector <32 xbf16 >
929+
930+ // bf16 -> bf16 (f32 accumulator operand)
931+ // CHECK: call <32 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v32bf16.v16bf16.v32f32(i1 {{.*}}, <16 x bfloat> %{{.*}}, i1 {{.*}}, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
932+ %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero , %arg16 , %zero , %arg16 , %zero_i16 , %arg12 , %zero , %zero : (i1 , vector <16 xbf16 >, i1 , vector <16 xbf16 >, i16 , vector <32 xf32 >, i1 , i1 ) -> vector <32 xbf16 >
933+
934+ // f8/bf8 -> f16/f32
935+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
936+ %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
937+
938+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
939+ %r8.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_bf8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
940+
941+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
942+ %r9.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_fp8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
943+
944+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
945+ %r10.gfx1250 = rocdl.wmma.f32.16x16x64.bf8_bf8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
946+
947+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
948+ %r11.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
949+
950+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
951+ %r12.gfx1250 = rocdl.wmma.f16.16x16x64.fp8_bf8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
952+
953+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
954+ %r13.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_fp8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
955+
956+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x64.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
957+ %r14.gfx1250 = rocdl.wmma.f16.16x16x64.bf8_bf8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
958+
959+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
960+ %r15.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
961+
962+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.fp8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
963+ %r16.gfx1250 = rocdl.wmma.f32.16x16x128.fp8_bf8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
964+
965+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
966+ %r17.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_fp8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
967+
968+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x128.bf8.bf8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 {{.*}}, i1 {{.*}})
969+ %r18.gfx1250 = rocdl.wmma.f32.16x16x128.bf8_bf8 %arg5 , %arg5 , %zero_i16 , %arg13 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf32 >, i1 , i1 ) -> vector <64 xf32 >
970+
971+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
972+ %r19.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_fp8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
973+
974+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.fp8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
975+ %r20.gfx1250 = rocdl.wmma.f16.16x16x128.fp8_bf8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
976+
977+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.fp8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
978+ %r21.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_fp8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
979+
980+ // CHECK: call <64 x half> @llvm.amdgcn.wmma.f16.16x16x128.bf8.bf8.v64f16.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x half> %{{.*}}, i1 {{.*}}, i1 {{.*}})
981+ %r22.gfx1250 = rocdl.wmma.f16.16x16x128.bf8_bf8 %arg5 , %arg5 , %zero_i16 , %arg15 , %zero , %zero : (vector <4 xi32 >, vector <4 xi32 >, i16 , vector <64 xf16 >, i1 , i1 ) -> vector <64 xf16 >
982+
983+ // iu8 -> i32
984+ // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <64 x i32> %{{.*}}, i1 {{.*}}, i1 {{.*}})
985+ %r23.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero , %arg5 , %zero , %arg5 , %arg14 , %zero , %zero : (i1 , vector <4 xi32 >, i1 , vector <4 xi32 >, vector <64 xi32 >, i1 , i1 ) -> vector <64 xi32 >
986+
908987 // ---- Wave64 -----
909988
910989 // f16 -> f32
0 commit comments