@@ -1348,3 +1348,131 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
13481348 tt.return
13491349 }
13501350}
1351+
1352+ // -----
1353+
1354+ #blocked = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
1355+ #blocked1 = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [4 , 8 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
1356+ #mma = #ttg.nvidia_mma <{versionMajor = 2 , versionMinor = 0 , warpsPerCTA = [4 , 1 ], instrShape = []}>
1357+ #shared = #ttg.swizzled_shared <{vec = 8 , perPhase = 2 , maxPhase = 4 , order = [1 , 0 ]}>
1358+ #shared1 = #ttg.swizzled_shared <{vec = 8 , perPhase = 1 , maxPhase = 8 , order = [1 , 0 ]}>
1359+ #smem = #ttg.shared_memory
1360+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
1361+ tt.func @assume_matmul (%arg0: index , %arg1: index , %arg2: index , %arg3: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }, %arg4: !tt.ptr <f16 > {tt.divisibility = 16 : i32 }) -> tensor <128 x128 xf32 , #mma > {
1362+ // expected-remark@+1 {{unsigned : [18446744073709551615, 18446744073709551615] signed : [-1, -1]}}
1363+ %c -1 = arith.constant -1 : index
1364+ // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}}
1365+ // expected-remark@+1 {{non-neg}}
1366+ %c1 = arith.constant 1 : index
1367+ // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}}
1368+ // expected-remark@+1 {{non-neg}}
1369+ %c0 = arith.constant 0 : index
1370+ // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}}
1371+ // expected-remark@+1 {{non-neg}}
1372+ %c1_i32 = arith.constant 1 : i32
1373+ // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}}
1374+ // expected-remark@+1 {{non-neg}}
1375+ %c0_i32 = arith.constant 0 : i32
1376+ // expected-remark@+1 {{unsigned : [1, 1] signed : [-1, -1]}}
1377+ %true = arith.constant true
1378+ %cst = arith.constant dense <4.000000e+00 > : tensor <32 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
1379+ // expected-remark@+2 {{unsigned : [4, 4] signed : [4, 4]}}
1380+ // expected-remark@+1 {{non-neg}}
1381+ %cst_0 = arith.constant dense <4 > : tensor <32 x128 xi32 , #blocked >
1382+ // expected-remark@+2 {{unsigned : [4, 4] signed : [4, 4]}}
1383+ // expected-remark@+1 {{non-neg}}
1384+ %cst_1 = arith.constant dense <4 > : tensor <128 x32 xi32 , #blocked1 >
1385+ %cst_2 = arith.constant dense <0.000000e+00 > : tensor <128 x128 xf32 , #mma >
1386+ %cst_3 = arith.constant dense <0.000000e+00 > : tensor <32 x128 xf16 , #blocked >
1387+ %0 = tt.splat %arg3 : !tt.ptr <f16 > -> tensor <128 x32 x!tt.ptr <f16 >, #blocked1 >
1388+ // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}}
1389+ // expected-remark@+1 {{non-neg}}
1390+ %1 = tt.make_range {end = 32 : i32 , start = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>>
1391+ // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}}
1392+ // expected-remark@+1 {{non-neg}}
1393+ %2 = tt.expand_dims %1 {axis = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #blocked1 }>> -> tensor <1 x32 xi32 , #blocked1 >
1394+ // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}}
1395+ // expected-remark@+1 {{non-neg}}
1396+ %3 = tt.broadcast %2 : tensor <1 x32 xi32 , #blocked1 > -> tensor <128 x32 xi32 , #blocked1 >
1397+ %4 = tt.addptr %0 , %3 : tensor <128 x32 x!tt.ptr <f16 >, #blocked1 >, tensor <128 x32 xi32 , #blocked1 >
1398+ %5 = tt.splat %arg4 : !tt.ptr <f16 > -> tensor <32 x128 x!tt.ptr <f16 >, #blocked >
1399+ // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}}
1400+ // expected-remark@+1 {{non-neg}}
1401+ %6 = tt.make_range {end = 128 : i32 , start = 0 : i32 } : tensor <128 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>>
1402+ // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}}
1403+ // expected-remark@+1 {{non-neg}}
1404+ %7 = tt.expand_dims %6 {axis = 0 : i32 } : tensor <128 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x128 xi32 , #blocked >
1405+ // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}}
1406+ // expected-remark@+1 {{non-neg}}
1407+ %8 = tt.broadcast %7 : tensor <1 x128 xi32 , #blocked > -> tensor <32 x128 xi32 , #blocked >
1408+ %9 = tt.addptr %5 , %8 : tensor <32 x128 x!tt.ptr <f16 >, #blocked >, tensor <32 x128 xi32 , #blocked >
1409+ %10 = ttg.local_alloc : () -> !ttg.memdesc <1 x128 x32 xf16 , #shared , #smem , mutable >
1410+ %11 = ttg.local_alloc : () -> !ttg.memdesc <1 x32 x128 xf16 , #shared1 , #smem , mutable >
1411+ // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
1412+ %12 = arith.cmpi slt , %arg0 , %arg1 : index
1413+ // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
1414+ %13 = tt.splat %12 : i1 -> tensor <128 x32 xi1 , #blocked1 >
1415+ %14 = tt.load %4 , %13 {OpIdx = #amdgpu.OpIdx <0 >} : tensor <128 x32 x!tt.ptr <f16 >, #blocked1 >
1416+ // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
1417+ %15 = tt.splat %12 : i1 -> tensor <32 x128 xi1 , #blocked >
1418+ %16 = tt.load %9 , %15 , %cst_3 : tensor <32 x128 x!tt.ptr <f16 >, #blocked >
1419+ %17 = ttg.memdesc_subview %10 [%c0_i32 , %c0_i32 , %c0_i32 ] : !ttg.memdesc <1 x128 x32 xf16 , #shared , #smem , mutable > -> !ttg.memdesc <128 x32 xf16 , #shared , #smem , mutable >
1420+ ttg.local_store %14 , %17 {OpIdx = #amdgpu.OpIdx <0 >} : tensor <128 x32 xf16 , #blocked1 > -> !ttg.memdesc <128 x32 xf16 , #shared , #smem , mutable >
1421+ %18 = ttg.memdesc_subview %11 [%c0_i32 , %c0_i32 , %c0_i32 ] : !ttg.memdesc <1 x32 x128 xf16 , #shared1 , #smem , mutable > -> !ttg.memdesc <32 x128 xf16 , #shared1 , #smem , mutable >
1422+ ttg.local_store %16 , %18 : tensor <32 x128 xf16 , #blocked > -> !ttg.memdesc <32 x128 xf16 , #shared1 , #smem , mutable >
1423+ // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}}
1424+ %19 = arith.subi %arg1 , %arg2 : index
1425+ %20:6 = scf.for %arg5 = %arg0 to %19 step %arg2 iter_args (%arg6 = %4 , %arg7 = %9 , %arg8 = %cst_2 , %arg9 = %c0_i32 , %arg10 = %17 , %arg11 = %18 ) -> (tensor <128 x32 x!tt.ptr <f16 >, #blocked1 >, tensor <32 x128 x!tt.ptr <f16 >, #blocked >, tensor <128 x128 xf32 , #mma >, i32 , !ttg.memdesc <128 x32 xf16 , #shared , #smem , mutable >, !ttg.memdesc <32 x128 xf16 , #shared1 , #smem , mutable >) {
1426+ %33 = tt.addptr %arg6 , %cst_1 : tensor <128 x32 x!tt.ptr <f16 >, #blocked1 >, tensor <128 x32 xi32 , #blocked1 >
1427+ %34 = tt.addptr %arg7 , %cst_0 : tensor <32 x128 x!tt.ptr <f16 >, #blocked >, tensor <32 x128 xi32 , #blocked >
1428+ llvm.intr.assume %true : i1
1429+ %35 = tt.load %33 {OpIdx = #amdgpu.OpIdx <0 >} : tensor <128 x32 x!tt.ptr <f16 >, #blocked1 >
1430+ %36 = ttg.local_load %arg10 : !ttg.memdesc <128 x32 xf16 , #shared , #smem , mutable > -> tensor <128 x32 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
1431+ %37 = tt.load %34 : tensor <32 x128 x!tt.ptr <f16 >, #blocked >
1432+ %38 = ttg.local_load %arg11 : !ttg.memdesc <32 x128 xf16 , #shared1 , #smem , mutable > -> tensor <32 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
1433+ %39 = arith.mulf %38 , %cst : tensor <32 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
1434+ %40 = tt.dot %36 , %39 , %arg8 : tensor <128 x32 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <32 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x128 xf32 , #mma >
1435+ // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}}
1436+ %41 = arith.addi %arg9 , %c1_i32 : i32
1437+ // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
1438+ %42 = arith.cmpi slt , %41 , %c1_i32 : i32
1439+ // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}}
1440+ %43 = arith.select %42 , %41 , %c0_i32 : i32
1441+ %44 = ttg.memdesc_subview %10 [%43 , %c0_i32 , %c0_i32 ] : !ttg.memdesc <1 x128 x32 xf16 , #shared , #smem , mutable > -> !ttg.memdesc <128 x32 xf16 , #shared , #smem , mutable >
1442+ ttg.local_store %35 , %44 {OpIdx = #amdgpu.OpIdx <0 >} : tensor <128 x32 xf16 , #blocked1 > -> !ttg.memdesc <128 x32 xf16 , #shared , #smem , mutable >
1443+ %45 = ttg.memdesc_subview %11 [%43 , %c0_i32 , %c0_i32 ] : !ttg.memdesc <1 x32 x128 xf16 , #shared1 , #smem , mutable > -> !ttg.memdesc <32 x128 xf16 , #shared1 , #smem , mutable >
1444+ ttg.local_store %37 , %45 : tensor <32 x128 xf16 , #blocked > -> !ttg.memdesc <32 x128 xf16 , #shared1 , #smem , mutable >
1445+ scf.yield %33 , %34 , %40 , %43 , %44 , %45 : tensor <128 x32 x!tt.ptr <f16 >, #blocked1 >, tensor <32 x128 x!tt.ptr <f16 >, #blocked >, tensor <128 x128 xf32 , #mma >, i32 , !ttg.memdesc <128 x32 xf16 , #shared , #smem , mutable >, !ttg.memdesc <32 x128 xf16 , #shared1 , #smem , mutable >
1446+ }
1447+ // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
1448+ %21 = arith.cmpi slt , %arg2 , %c0 : index
1449+ // expected-remark@+1 {{unsigned : [1, 18446744073709551615] signed : [-1, 1]}}
1450+ %22 = arith.select %21 , %c1 , %c -1 : index
1451+ // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}}
1452+ %23 = arith.subi %arg1 , %arg0 : index
1453+ // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}}
1454+ %24 = arith.addi %23 , %arg2 : index
1455+ // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}}
1456+ %25 = arith.addi %24 , %22 : index
1457+ // expected-remark@+2 {{unsigned : [1, 9223372036854775807] signed : [1, 9223372036854775807]}}
1458+ // expected-remark@+1 {{non-neg}}
1459+ %26 = arith.divsi %25 , %arg2 : index
1460+ %28 = ttg.local_load %20#4 : !ttg.memdesc <128 x32 xf16 , #shared , #smem , mutable > -> tensor <128 x32 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
1461+ %29 = ttg.local_load %20#5 : !ttg.memdesc <32 x128 xf16 , #shared1 , #smem , mutable > -> tensor <32 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
1462+ %30 = arith.mulf %29 , %cst : tensor <32 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>
1463+ // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}}
1464+ // expected-remark@+1 {{result is true}}
1465+ %27 = arith.cmpi sge , %26 , %c1 : index
1466+ llvm.intr.assume %27 : i1
1467+ %31 = scf.if %27 -> (tensor <128 x128 xf32 , #mma >) {
1468+ %33 = tt.dot %28 , %30 , %20#2 : tensor <128 x32 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <32 x128 xf16 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x128 xf32 , #mma >
1469+ scf.yield %33 : tensor <128 x128 xf32 , #mma >
1470+ } else {
1471+ scf.yield %20#2 : tensor <128 x128 xf32 , #mma >
1472+ }
1473+ %32 = arith.select %27 , %31 , %20#2 : tensor <128 x128 xf32 , #mma >
1474+ ttg.local_dealloc %10 : !ttg.memdesc <1 x128 x32 xf16 , #shared , #smem , mutable >
1475+ ttg.local_dealloc %11 : !ttg.memdesc <1 x32 x128 xf16 , #shared1 , #smem , mutable >
1476+ tt.return %32 : tensor <128 x128 xf32 , #mma >
1477+ }
1478+ }
0 commit comments