@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
9797 return cast<DILocalScope>(Scope)->getSubprogram ();
9898}
9999
100- // / Erase \p V from \p BB and move \II forward to avoid invalidating
101- // / iterators.
102- static void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
103- BasicBlock &BB) {
104- auto *Inst = cast<Instruction>(V);
105- // Still used, don't erase.
106- if (!Inst->use_empty ())
107- return ;
108- if (II != BB.rend () && Inst == &*II)
109- ++II;
110- Inst->eraseFromParent ();
111- }
112-
113100// / Return true if V is a splat of a value (which is used when multiplying a
114101// / matrix with a scalar).
115102static bool isSplat (Value *V) {
@@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) {
259246// / Return the ShapeInfo for the result of \p I, it it can be determined.
260247static std::optional<ShapeInfo>
261248computeShapeInfoForInst (Instruction *I,
262- const ValueMap <Value *, ShapeInfo> &ShapeMap) {
249+ const DenseMap <Value *, ShapeInfo> &ShapeMap) {
263250 Value *M;
264251 Value *N;
265252 Value *K;
@@ -492,10 +479,16 @@ class LowerMatrixIntrinsics {
492479 // / the result value of the instruction, with the only exceptions being store
493480 // / instructions and the matrix_column_major_store intrinsics. For those, the
494481 // / shape information indicates that those instructions should be lowered
495- // / using shape information as well. A ValueMap is used so that when
496- // / sub-passes like optimizeTransposes performs RAUW the map stays
497- // / up-to-date.
498- ValueMap<Value *, ShapeInfo> ShapeMap;
482+ // / using shape information as well. Note that extra care is needed when
483+ // / erasing or RAUW'ing a value that is present in ShapeMap. If the
484+ // / replacement is also a matrix operation, use
485+ // / updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
486+ // / ShapeMap. We don't use ValueMap, as there are also cases where we do not
487+ // / want to add shape information for a replacement instruction. When directly
488+ // / erasing a value with an entry in ShapeMap, use
489+ // / eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
490+ // / accordingly.
491+ DenseMap<Value *, ShapeInfo> ShapeMap;
499492
500493 // / List of instructions to remove. While lowering, we are not replacing all
501494 // / users of a lowered instruction, if shape information is available and
@@ -759,6 +752,30 @@ class LowerMatrixIntrinsics {
759752 return Operation (T0, Shape0.t (), T1, Shape1.t ());
760753 }
761754
755+ // / Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
756+ // / itself.
757+ void eraseFromParentAndRemoveFromShapeMap (Instruction *Inst) {
758+ auto Iter = ShapeMap.find (Inst);
759+ if (Iter != ShapeMap.end ())
760+ ShapeMap.erase (Iter);
761+ Inst->eraseFromParent ();
762+ }
763+
764+ // / Erase \p V from \p BB and move \II forward to avoid invalidating
765+ // / iterators.
766+ void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
767+ BasicBlock &BB) {
768+ auto *Inst = cast<Instruction>(V);
769+ // Still used, don't erase.
770+ if (!Inst->use_empty ())
771+ return ;
772+ if (II != BB.rend () && Inst == &*II)
773+ ++II;
774+ eraseFromParentAndRemoveFromShapeMap (Inst);
775+ }
776+
777+ // / Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
778+ // / entry for \p Old and replace all uses of \p Old with \p New.
762779 void updateShapeAndReplaceAllUsesWith (Instruction &Old, Value *New) {
763780 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
764781 // with New. We should only add New it it supportsShapeInfo so we insert
@@ -872,13 +889,13 @@ class LowerMatrixIntrinsics {
872889
873890 void liftTranspose (Instruction &I) {
874891 // Erase dead Instructions after lifting transposes from binops.
875- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
892+ auto CleanupBinOp = [this ](Instruction &T, Value *A, Value *B) {
876893 if (T.use_empty ())
877- T. eraseFromParent ( );
894+ eraseFromParentAndRemoveFromShapeMap (&T );
878895 if (A->use_empty ())
879- cast<Instruction>(A)-> eraseFromParent ( );
896+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(A));
880897 if (A != B && B->use_empty ())
881- cast<Instruction>(B)-> eraseFromParent ( );
898+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(B));
882899 };
883900
884901 Value *A, *B, *AT, *BT;
@@ -908,8 +925,7 @@ class LowerMatrixIntrinsics {
908925 match (B, m_Intrinsic<Intrinsic::matrix_transpose>(
909926 m_Value (BT), m_ConstantInt (), m_ConstantInt ()))) {
910927 IRBuilder<> Builder (&I);
911- auto *Add = cast<Instruction>(Builder.CreateFAdd (AT, BT, " mfadd" ));
912- setShapeInfo (Add, {R, C});
928+ auto *Add = Builder.CreateFAdd (AT, BT, " mfadd" );
913929 MatrixBuilder MBuilder (Builder);
914930 Instruction *NewInst = MBuilder.CreateMatrixTranspose (
915931 Add, R->getZExtValue (), C->getZExtValue (), " mfadd_t" );
@@ -918,9 +934,13 @@ class LowerMatrixIntrinsics {
918934 computeShapeInfoForInst (&I, ShapeMap) &&
919935 " Shape of new instruction doesn't match original shape." );
920936 CleanupBinOp (I, A, B);
921- assert (computeShapeInfoForInst (Add, ShapeMap).value_or (ShapeMap[Add]) ==
922- ShapeMap[Add] &&
923- " Shape of updated addition doesn't match cached shape." );
937+ if (auto *AddI = dyn_cast<Instruction>(Add)) {
938+ setShapeInfo (AddI, {R, C});
939+ assert (
940+ computeShapeInfoForInst (AddI, ShapeMap).value_or (ShapeMap[AddI]) ==
941+ ShapeMap[AddI] &&
942+ " Shape of updated addition doesn't match cached shape." );
943+ }
924944 }
925945 }
926946
@@ -1014,7 +1034,8 @@ class LowerMatrixIntrinsics {
10141034
10151035 // Third, try to fuse candidates.
10161036 for (CallInst *CI : MaybeFusableInsts)
1017- LowerMatrixMultiplyFused (CI, FusedInsts, LifetimeEnds);
1037+ if (!FusedInsts.contains (CI))
1038+ LowerMatrixMultiplyFused (CI, FusedInsts, LifetimeEnds);
10181039
10191040 Changed = !FusedInsts.empty ();
10201041
@@ -1475,7 +1496,7 @@ class LowerMatrixIntrinsics {
14751496 m_Value (Arg)))) {
14761497 auto *NewLoad = Builder.CreateLoad (Op->getType (), Arg);
14771498 Op->replaceAllUsesWith (NewLoad);
1478- cast<Instruction>(Op)-> eraseFromParent ( );
1499+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(Op));
14791500 return ;
14801501 } else if (match (Op, m_Intrinsic<Intrinsic::matrix_transpose>(
14811502 m_Value (Arg)))) {
@@ -1844,15 +1865,15 @@ class LowerMatrixIntrinsics {
18441865 // Mark eliminated instructions as fused and remove them.
18451866 FusedInsts.insert (Store);
18461867 FusedInsts.insert (MatMul);
1847- Store-> eraseFromParent ( );
1848- MatMul-> eraseFromParent ( );
1868+ eraseFromParentAndRemoveFromShapeMap (Store );
1869+ eraseFromParentAndRemoveFromShapeMap (MatMul );
18491870 if (LoadOp0->hasNUses (0 )) {
18501871 FusedInsts.insert (LoadOp0);
1851- LoadOp0-> eraseFromParent ( );
1872+ eraseFromParentAndRemoveFromShapeMap (LoadOp0 );
18521873 }
18531874 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses (0 )) {
18541875 FusedInsts.insert (LoadOp1);
1855- LoadOp1-> eraseFromParent ( );
1876+ eraseFromParentAndRemoveFromShapeMap (LoadOp1 );
18561877 }
18571878 }
18581879
0 commit comments