Commit 15df1ce

[AMD] DCE/canonicalize true epilogue conditionals
1 parent 7af8cad commit 15df1ce

1 file changed: +128 −0

test/TritonGPU/amd/amd-range-analysis.mlir

Lines changed: 128 additions & 0 deletions
@@ -1348,3 +1348,131 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
    tt.return
  }
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = []}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  tt.func @assume_matmul(%arg0: index, %arg1: index, %arg2: index, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #mma> {
    // expected-remark@+1 {{unsigned : [18446744073709551615, 18446744073709551615] signed : [-1, -1]}}
    %c-1 = arith.constant -1 : index
    // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}}
    // expected-remark@+1 {{non-neg}}
    %c1 = arith.constant 1 : index
    // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}}
    // expected-remark@+1 {{non-neg}}
    %c0 = arith.constant 0 : index
    // expected-remark@+2 {{unsigned : [1, 1] signed : [1, 1]}}
    // expected-remark@+1 {{non-neg}}
    %c1_i32 = arith.constant 1 : i32
    // expected-remark@+2 {{unsigned : [0, 0] signed : [0, 0]}}
    // expected-remark@+1 {{non-neg}}
    %c0_i32 = arith.constant 0 : i32
    // expected-remark@+1 {{unsigned : [1, 1] signed : [-1, -1]}}
    %true = arith.constant true
    %cst = arith.constant dense<4.000000e+00> : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    // expected-remark@+2 {{unsigned : [4, 4] signed : [4, 4]}}
    // expected-remark@+1 {{non-neg}}
    %cst_0 = arith.constant dense<4> : tensor<32x128xi32, #blocked>
    // expected-remark@+2 {{unsigned : [4, 4] signed : [4, 4]}}
    // expected-remark@+1 {{non-neg}}
    %cst_1 = arith.constant dense<4> : tensor<128x32xi32, #blocked1>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x128xf16, #blocked>
    %0 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<128x32x!tt.ptr<f16>, #blocked1>
    // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}}
    // expected-remark@+1 {{non-neg}}
    %1 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}}
    // expected-remark@+1 {{non-neg}}
    %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1>
    // expected-remark@+2 {{unsigned : [0, 32] signed : [0, 32]}}
    // expected-remark@+1 {{non-neg}}
    %3 = tt.broadcast %2 : tensor<1x32xi32, #blocked1> -> tensor<128x32xi32, #blocked1>
    %4 = tt.addptr %0, %3 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
    %5 = tt.splat %arg4 : !tt.ptr<f16> -> tensor<32x128x!tt.ptr<f16>, #blocked>
    // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}}
    // expected-remark@+1 {{non-neg}}
    %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}}
    // expected-remark@+1 {{non-neg}}
    %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked>
    // expected-remark@+2 {{unsigned : [0, 128] signed : [0, 128]}}
    // expected-remark@+1 {{non-neg}}
    %8 = tt.broadcast %7 : tensor<1x128xi32, #blocked> -> tensor<32x128xi32, #blocked>
    %9 = tt.addptr %5, %8 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
    %10 = ttg.local_alloc : () -> !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
    %11 = ttg.local_alloc : () -> !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
    // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
    %12 = arith.cmpi slt, %arg0, %arg1 : index
    // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
    %13 = tt.splat %12 : i1 -> tensor<128x32xi1, #blocked1>
    %14 = tt.load %4, %13 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr<f16>, #blocked1>
    // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
    %15 = tt.splat %12 : i1 -> tensor<32x128xi1, #blocked>
    %16 = tt.load %9, %15, %cst_3 : tensor<32x128x!tt.ptr<f16>, #blocked>
    %17 = ttg.memdesc_subview %10[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
    ttg.local_store %14, %17 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
    %18 = ttg.memdesc_subview %11[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
    ttg.local_store %16, %18 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
    // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}}
    %19 = arith.subi %arg1, %arg2 : index
    %20:6 = scf.for %arg5 = %arg0 to %19 step %arg2 iter_args(%arg6 = %4, %arg7 = %9, %arg8 = %cst_2, %arg9 = %c0_i32, %arg10 = %17, %arg11 = %18) -> (tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>) {
      %33 = tt.addptr %arg6, %cst_1 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<128x32xi32, #blocked1>
      %34 = tt.addptr %arg7, %cst_0 : tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<32x128xi32, #blocked>
      llvm.intr.assume %true : i1
      %35 = tt.load %33 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32x!tt.ptr<f16>, #blocked1>
      %36 = ttg.local_load %arg10 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
      %37 = tt.load %34 : tensor<32x128x!tt.ptr<f16>, #blocked>
      %38 = ttg.local_load %arg11 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %39 = arith.mulf %38, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %40 = tt.dot %36, %39, %arg8 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
      // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}}
      %41 = arith.addi %arg9, %c1_i32 : i32
      // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
      %42 = arith.cmpi slt, %41, %c1_i32 : i32
      // expected-remark@+1 {{unsigned : [0, 4294967295] signed : [-2147483648, 2147483647]}}
      %43 = arith.select %42, %41, %c0_i32 : i32
      %44 = ttg.memdesc_subview %10[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
      ttg.local_store %35, %44 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x32xf16, #blocked1> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable>
      %45 = ttg.memdesc_subview %11[%43, %c0_i32, %c0_i32] : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
      ttg.local_store %37, %45 : tensor<32x128xf16, #blocked> -> !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
      scf.yield %33, %34, %40, %43, %44, %45 : tensor<128x32x!tt.ptr<f16>, #blocked1>, tensor<32x128x!tt.ptr<f16>, #blocked>, tensor<128x128xf32, #mma>, i32, !ttg.memdesc<128x32xf16, #shared, #smem, mutable>, !ttg.memdesc<32x128xf16, #shared1, #smem, mutable>
    }
    // expected-remark@+1 {{unsigned : [0, 1] signed : [-1, 0]}}
    %21 = arith.cmpi slt, %arg2, %c0 : index
    // expected-remark@+1 {{unsigned : [1, 18446744073709551615] signed : [-1, 1]}}
    %22 = arith.select %21, %c1, %c-1 : index
    // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}}
    %23 = arith.subi %arg1, %arg0 : index
    // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}}
    %24 = arith.addi %23, %arg2 : index
    // expected-remark@+1 {{unsigned : [0, 18446744073709551615] signed : [-9223372036854775808, 9223372036854775807]}}
    %25 = arith.addi %24, %22 : index
    // expected-remark@+2 {{unsigned : [1, 9223372036854775807] signed : [1, 9223372036854775807]}}
    // expected-remark@+1 {{non-neg}}
    %26 = arith.divsi %25, %arg2 : index
    %28 = ttg.local_load %20#4 : !ttg.memdesc<128x32xf16, #shared, #smem, mutable> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
    %29 = ttg.local_load %20#5 : !ttg.memdesc<32x128xf16, #shared1, #smem, mutable> -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    %30 = arith.mulf %29, %cst : tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
    // expected-remark@+2 {{unsigned : [1, 1] signed : [-1, -1]}}
    // expected-remark@+1 {{result is true}}
    %27 = arith.cmpi sge, %26, %c1 : index
    llvm.intr.assume %27 : i1
    %31 = scf.if %27 -> (tensor<128x128xf32, #mma>) {
      %33 = tt.dot %28, %30, %20#2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
      scf.yield %33 : tensor<128x128xf32, #mma>
    } else {
      scf.yield %20#2 : tensor<128x128xf32, #mma>
    }
    %32 = arith.select %27, %31, %20#2 : tensor<128x128xf32, #mma>
    ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
    ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
    tt.return %32 : tensor<128x128xf32, #mma>
  }
}
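
For context, a minimal hedged sketch of the intended effect, not part of this diff: the test checks that range analysis proves the epilogue guard %27 (the `sge %26, %c1` comparison) is always true, which is what lets a later canonicalization/DCE step drop the conditional. The snippet below reuses value names from the test above purely for illustration; the actual rewrite performed by the pass may differ.

// Illustrative post-canonicalization epilogue (assumption, not from this commit):
// with %27 proven true, the scf.if collapses to its then-region and the
// arith.select on %27 folds away, leaving the unconditional peeled tt.dot.
%31 = tt.dot %28, %30, %20#2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
ttg.local_dealloc %10 : !ttg.memdesc<1x128x32xf16, #shared, #smem, mutable>
ttg.local_dealloc %11 : !ttg.memdesc<1x32x128xf16, #shared1, #smem, mutable>
tt.return %31 : tensor<128x128xf32, #mma>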
