@@ -1134,26 +1134,28 @@ class LowerMatrixIntrinsics {
11341134      if  (FusedInsts.count (Inst))
11351135        continue ;
11361136
1137-       IRBuilder<> Builder (Inst);
1138- 
11391137      const  ShapeInfo &SI = ShapeMap.at (Inst);
11401138
11411139      Value *Op1;
11421140      Value *Op2;
1141+       MatrixTy Result;
11431142      if  (auto  *BinOp = dyn_cast<BinaryOperator>(Inst))
1144-         VisitBinaryOperator (BinOp, SI);
1143+         Result =  VisitBinaryOperator (BinOp, SI);
11451144      else  if  (auto  *Cast = dyn_cast<CastInst>(Inst))
1146-         VisitCastInstruction (Cast, SI);
1145+         Result =  VisitCastInstruction (Cast, SI);
11471146      else  if  (auto  *UnOp = dyn_cast<UnaryOperator>(Inst))
1148-         VisitUnaryOperator (UnOp, SI);
1149-       else  if  (IntrinsicInst  *Intr = dyn_cast<IntrinsicInst>(Inst))
1150-         VisitIntrinsicInst (Intr, SI);
1147+         Result =  VisitUnaryOperator (UnOp, SI);
1148+       else  if  (auto  *Intr = dyn_cast<IntrinsicInst>(Inst))
1149+         Result =  VisitIntrinsicInst (Intr, SI);
11511150      else  if  (match (Inst, m_Load (m_Value (Op1))))
1152-         VisitLoad (cast<LoadInst>(Inst), SI, Op1, Builder );
1151+         Result =  VisitLoad (cast<LoadInst>(Inst), SI, Op1);
11531152      else  if  (match (Inst, m_Store (m_Value (Op1), m_Value (Op2))))
1154-         VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2, Builder );
1153+         Result =  VisitStore (cast<StoreInst>(Inst), SI, Op1, Op2);
11551154      else 
11561155        continue ;
1156+ 
1157+       IRBuilder<> Builder (Inst);
1158+       finalizeLowering (Inst, Result, Builder);
11571159      Changed = true ;
11581160    }
11591161
@@ -1193,25 +1195,24 @@ class LowerMatrixIntrinsics {
11931195  }
11941196
11951197  // / Replace intrinsic calls.
1196-   void  VisitIntrinsicInst (IntrinsicInst *Inst, const  ShapeInfo &Shape) {
1197-     switch  (Inst->getIntrinsicID ()) {
1198+   MatrixTy VisitIntrinsicInst (IntrinsicInst *Inst, const  ShapeInfo &SI) {
1199+     assert (Inst->getCalledFunction () &&
1200+            Inst->getCalledFunction ()->isIntrinsic ());
1201+ 
1202+     switch  (Inst->getCalledFunction ()->getIntrinsicID ()) {
11981203    case  Intrinsic::matrix_multiply:
1199-       LowerMultiply (Inst);
1200-       return ;
1204+       return  LowerMultiply (Inst);
12011205    case  Intrinsic::matrix_transpose:
1202-       LowerTranspose (Inst);
1203-       return ;
1206+       return  LowerTranspose (Inst);
12041207    case  Intrinsic::matrix_column_major_load:
1205-       LowerColumnMajorLoad (Inst);
1206-       return ;
1208+       return  LowerColumnMajorLoad (Inst);
12071209    case  Intrinsic::matrix_column_major_store:
1208-       LowerColumnMajorStore (Inst);
1209-       return ;
1210+       return  LowerColumnMajorStore (Inst);
12101211    case  Intrinsic::abs:
12111212    case  Intrinsic::fabs: {
12121213      IRBuilder<> Builder (Inst);
12131214      MatrixTy Result;
1214-       MatrixTy M = getMatrix (Inst->getOperand (0 ), Shape , Builder);
1215+       MatrixTy M = getMatrix (Inst->getOperand (0 ), SI , Builder);
12151216      Builder.setFastMathFlags (getFastMathFlags (Inst));
12161217
12171218      for  (auto  &Vector : M.vectors ()) {
@@ -1229,16 +1230,14 @@ class LowerMatrixIntrinsics {
12291230        }
12301231      }
12311232
1232-       finalizeLowering (Inst,
1233-                        Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
1234-                                                Result.getNumVectors ()),
1235-                        Builder);
1236-       return ;
1233+       return  Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
1234+                                      Result.getNumVectors ());
12371235    }
12381236    default :
1239-       llvm_unreachable (
1240-           " only intrinsics supporting shape info should be seen here"  );
1237+       break ;
12411238    }
1239+     llvm_unreachable (
1240+         " only intrinsics supporting shape info should be seen here"  );
12421241  }
12431242
12441243  // / Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -1304,26 +1303,24 @@ class LowerMatrixIntrinsics {
13041303  }
13051304
13061305  // / Lower a load instruction with shape information.
1307-   void  LowerLoad (Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride ,
1308-                  bool  IsVolatile, ShapeInfo Shape) {
1306+   MatrixTy  LowerLoad (Instruction *Inst, Value *Ptr, MaybeAlign Align,
1307+                      Value *Stride,  bool  IsVolatile, ShapeInfo Shape) {
13091308    IRBuilder<> Builder (Inst);
1310-     finalizeLowering (Inst,
1311-                      loadMatrix (Inst->getType (), Ptr, Align, Stride, IsVolatile,
1312-                                 Shape, Builder),
1313-                      Builder);
1309+     return  loadMatrix (Inst->getType (), Ptr, Align, Stride, IsVolatile, Shape,
1310+                       Builder);
13141311  }
13151312
13161313  // / Lowers llvm.matrix.column.major.load.
13171314  // /
13181315  // / The intrinsic loads a matrix from memory using a stride between columns.
1319-   void  LowerColumnMajorLoad (CallInst *Inst) {
1316+   MatrixTy  LowerColumnMajorLoad (CallInst *Inst) {
13201317    assert (MatrixLayout == MatrixLayoutTy::ColumnMajor &&
13211318           " Intrinsic only supports column-major layout!"  );
13221319    Value *Ptr = Inst->getArgOperand (0 );
13231320    Value *Stride = Inst->getArgOperand (1 );
1324-     LowerLoad (Inst, Ptr, Inst->getParamAlign (0 ), Stride,
1325-               cast<ConstantInt>(Inst->getArgOperand (2 ))->isOne (),
1326-               {Inst->getArgOperand (3 ), Inst->getArgOperand (4 )});
1321+     return   LowerLoad (Inst, Ptr, Inst->getParamAlign (0 ), Stride,
1322+                       cast<ConstantInt>(Inst->getArgOperand (2 ))->isOne (),
1323+                       {Inst->getArgOperand (3 ), Inst->getArgOperand (4 )});
13271324  }
13281325
13291326  // / Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
@@ -1366,28 +1363,27 @@ class LowerMatrixIntrinsics {
13661363  }
13671364
13681365  // / Lower a store instruction with shape information.
1369-   void  LowerStore (Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1370-                   Value *Stride, bool  IsVolatile, ShapeInfo Shape) {
1366+   MatrixTy LowerStore (Instruction *Inst, Value *Matrix, Value *Ptr,
1367+                       MaybeAlign A, Value *Stride, bool  IsVolatile,
1368+                       ShapeInfo Shape) {
13711369    IRBuilder<> Builder (Inst);
13721370    auto  StoreVal = getMatrix (Matrix, Shape, Builder);
1373-     finalizeLowering (Inst,
1374-                      storeMatrix (Matrix->getType (), StoreVal, Ptr, A, Stride,
1375-                                  IsVolatile, Builder),
1376-                      Builder);
1371+     return  storeMatrix (Matrix->getType (), StoreVal, Ptr, A, Stride, IsVolatile,
1372+                        Builder);
13771373  }
13781374
13791375  // / Lowers llvm.matrix.column.major.store.
13801376  // /
13811377  // / The intrinsic store a matrix back memory using a stride between columns.
1382-   void  LowerColumnMajorStore (CallInst *Inst) {
1378+   MatrixTy  LowerColumnMajorStore (CallInst *Inst) {
13831379    assert (MatrixLayout == MatrixLayoutTy::ColumnMajor &&
13841380           " Intrinsic only supports column-major layout!"  );
13851381    Value *Matrix = Inst->getArgOperand (0 );
13861382    Value *Ptr = Inst->getArgOperand (1 );
13871383    Value *Stride = Inst->getArgOperand (2 );
1388-     LowerStore (Inst, Matrix, Ptr, Inst->getParamAlign (1 ), Stride,
1389-                cast<ConstantInt>(Inst->getArgOperand (3 ))->isOne (),
1390-                {Inst->getArgOperand (4 ), Inst->getArgOperand (5 )});
1384+     return   LowerStore (Inst, Matrix, Ptr, Inst->getParamAlign (1 ), Stride,
1385+                        cast<ConstantInt>(Inst->getArgOperand (3 ))->isOne (),
1386+                        {Inst->getArgOperand (4 ), Inst->getArgOperand (5 )});
13911387  }
13921388
13931389  //  Set elements I..I+NumElts-1 to Block
@@ -2162,7 +2158,7 @@ class LowerMatrixIntrinsics {
21622158  }
21632159
21642160  // / Lowers llvm.matrix.multiply.
2165-   void  LowerMultiply (CallInst *MatMul) {
2161+   MatrixTy  LowerMultiply (CallInst *MatMul) {
21662162    IRBuilder<> Builder (MatMul);
21672163    auto  *EltType = cast<FixedVectorType>(MatMul->getType ())->getElementType ();
21682164    ShapeInfo LShape (MatMul->getArgOperand (2 ), MatMul->getArgOperand (3 ));
@@ -2184,11 +2180,11 @@ class LowerMatrixIntrinsics {
21842180
21852181    emitMatrixMultiply (Result, Lhs, Rhs, Builder, false , false ,
21862182                       getFastMathFlags (MatMul));
2187-     finalizeLowering (MatMul,  Result, Builder) ;
2183+     return  Result;
21882184  }
21892185
21902186  // / Lowers llvm.matrix.transpose.
2191-   void  LowerTranspose (CallInst *Inst) {
2187+   MatrixTy  LowerTranspose (CallInst *Inst) {
21922188    MatrixTy Result;
21932189    IRBuilder<> Builder (Inst);
21942190    Value *InputVal = Inst->getArgOperand (0 );
@@ -2218,28 +2214,26 @@ class LowerMatrixIntrinsics {
22182214    //  TODO: Improve estimate of operations needed for transposes. Currently we
22192215    //  just count the insertelement/extractelement instructions, but do not
22202216    //  account for later simplifications/combines.
2221-     finalizeLowering (
2222-         Inst,
2223-         Result.addNumComputeOps (2  * ArgShape.NumRows  * ArgShape.NumColumns )
2224-             .addNumExposedTransposes (1 ),
2225-         Builder);
2217+     return  Result.addNumComputeOps (2  * ArgShape.NumRows  * ArgShape.NumColumns )
2218+         .addNumExposedTransposes (1 );
22262219  }
22272220
22282221  // / Lower load instructions.
2229-   void  VisitLoad (LoadInst *Inst, const  ShapeInfo &SI, Value *Ptr, 
2230-                   IRBuilder<> & Builder) { 
2231-     LowerLoad (Inst, Ptr, Inst->getAlign (), Builder. getInt64 (SI. getStride () ),
2232-               Inst->isVolatile (), SI);
2222+   MatrixTy  VisitLoad (LoadInst *Inst, const  ShapeInfo &SI, Value *Ptr) { 
2223+     IRBuilder<> Builder (Inst); 
2224+     return   LowerLoad (Inst, Ptr, Inst->getAlign (),
2225+                      Builder. getInt64 (SI. getStride ()),  Inst->isVolatile (), SI);
22332226  }
22342227
2235-   void  VisitStore (StoreInst *Inst, const  ShapeInfo &SI, Value *StoredVal,
2236-                   Value *Ptr, IRBuilder<> &Builder) {
2237-     LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2238-                Builder.getInt64 (SI.getStride ()), Inst->isVolatile (), SI);
2228+   MatrixTy VisitStore (StoreInst *Inst, const  ShapeInfo &SI, Value *StoredVal,
2229+                       Value *Ptr) {
2230+     IRBuilder<> Builder (Inst);
2231+     return  LowerStore (Inst, StoredVal, Ptr, Inst->getAlign (),
2232+                       Builder.getInt64 (SI.getStride ()), Inst->isVolatile (), SI);
22392233  }
22402234
22412235  // / Lower binary operators.
2242-   void  VisitBinaryOperator (BinaryOperator *Inst, const  ShapeInfo &SI) {
2236+   MatrixTy  VisitBinaryOperator (BinaryOperator *Inst, const  ShapeInfo &SI) {
22432237    Value *Lhs = Inst->getOperand (0 );
22442238    Value *Rhs = Inst->getOperand (1 );
22452239
@@ -2258,14 +2252,12 @@ class LowerMatrixIntrinsics {
22582252      Result.addVector (Builder.CreateBinOp (Inst->getOpcode (), A.getVector (I),
22592253                                           B.getVector (I)));
22602254
2261-     finalizeLowering (Inst,
2262-                      Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2263-                                              Result.getNumVectors ()),
2264-                      Builder);
2255+     return  Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2256+                                    Result.getNumVectors ());
22652257  }
22662258
22672259  // / Lower unary operators.
2268-   void  VisitUnaryOperator (UnaryOperator *Inst, const  ShapeInfo &SI) {
2260+   MatrixTy  VisitUnaryOperator (UnaryOperator *Inst, const  ShapeInfo &SI) {
22692261    Value *Op = Inst->getOperand (0 );
22702262
22712263    IRBuilder<> Builder (Inst);
@@ -2288,14 +2280,12 @@ class LowerMatrixIntrinsics {
22882280    for  (unsigned  I = 0 ; I < SI.getNumVectors (); ++I)
22892281      Result.addVector (BuildVectorOp (M.getVector (I)));
22902282
2291-     finalizeLowering (Inst,
2292-                      Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2293-                                              Result.getNumVectors ()),
2294-                      Builder);
2283+     return  Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2284+                                    Result.getNumVectors ());
22952285  }
22962286
22972287  // / Lower cast instructions.
2298-   void  VisitCastInstruction (CastInst *Inst, const  ShapeInfo &Shape) {
2288+   MatrixTy  VisitCastInstruction (CastInst *Inst, const  ShapeInfo &Shape) {
22992289    Value *Op = Inst->getOperand (0 );
23002290
23012291    IRBuilder<> Builder (Inst);
@@ -2312,10 +2302,8 @@ class LowerMatrixIntrinsics {
23122302    for  (auto  &Vector : M.vectors ())
23132303      Result.addVector (Builder.CreateCast (Inst->getOpcode (), Vector, NewVTy));
23142304
2315-     finalizeLowering (Inst,
2316-                      Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2317-                                              Result.getNumVectors ()),
2318-                      Builder);
2305+     return  Result.addNumComputeOps (getNumOps (Result.getVectorTy ()) *
2306+                                    Result.getNumVectors ());
23192307  }
23202308
23212309  // / Helper to linearize a matrix expression tree into a string. Currently
0 commit comments