Skip to content

Commit f39b472

Browse files
committed
[mlir][arith][tosa] Use extended mul in 32-bit tosa.apply_scale
To not introduce 64-bit types that may be difficult to handle for some targets. Reviewed By: rsuderman, antiagainst Differential Revision: https://reviews.llvm.org/D139777
1 parent 11b9c79 commit f39b472

File tree

2 files changed

+5
-19
lines changed

2 files changed

+5
-19
lines changed

mlir/lib/Conversion/TosaToArith/TosaToArith.cpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
127127

128128
Type resultTy = op.getType();
129129
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
130-
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
131130

132131
Value value = op.getValue();
133132
if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
@@ -144,20 +143,13 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
144143
Value two32 = getConstantValue(loc, i32Ty, 2, rewriter);
145144
Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter);
146145
Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter);
147-
Value thirtyTwo64 = getConstantValue(loc, i64Ty, 32, rewriter);
148146

149147
// Compute the multiplication in 64-bits then select the high / low parts.
150-
Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32);
151-
Value multiplier64 =
152-
rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
153-
Value multiply64 =
154-
rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
155-
156148
// Grab out the high/low of the computation
157-
Value high64 =
158-
rewriter.create<arith::ShRUIOp>(loc, multiply64, thirtyTwo64);
159-
Value high32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, high64);
160-
Value low32 = rewriter.create<arith::MulIOp>(loc, value32, multiplier32);
149+
auto value64 =
150+
rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
151+
Value low32 = value64.getLow();
152+
Value high32 = value64.getHigh();
161153

162154
// Determine the direction and amount to shift the high bits.
163155
Value shiftOver32 = rewriter.create<arith::CmpIOp>(

mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,9 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
2121
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
2222
// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : i32
2323
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
24-
// CHECK-DAG: %[[C32L:.+]] = arith.constant 32 : i64
2524

2625
// Compute the high-low values of the matmul in 64-bits.
27-
// CHECK-DAG: %[[V64:.+]] = arith.extsi %arg0 : i32 to i64
28-
// CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
29-
// CHECK-DAG: %[[MUL64:.+]] = arith.muli %[[V64]], %[[M64]]
30-
// CHECK-DAG: %[[HI64:.+]] = arith.shrui %[[MUL64]], %[[C32L]]
31-
// CHECK-DAG: %[[HI:.+]] = arith.trunci %[[HI64]] : i64 to i32
32-
// CHECK-DAG: %[[LOW:.+]] = arith.muli %arg0, %arg1
26+
// CHECK-DAG: %[[LOW:.+]], %[[HI:.+]] = arith.mulsi_extended %arg0, %arg1
3327

3428
// Determine whether the high bits need to shift left or right and by how much.
3529
// CHECK-DAG: %[[OVER31:.+]] = arith.cmpi sge, %[[S32]], %[[C32]]

0 commit comments

Comments
 (0)