@@ -97,19 +97,6 @@ static DISubprogram *getSubprogram(DIScope *Scope) {
9797 return cast<DILocalScope>(Scope)->getSubprogram ();
9898}
9999
100- // / Erase \p V from \p BB and move \II forward to avoid invalidating
101- // / iterators.
102- static void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
103- BasicBlock &BB) {
104- auto *Inst = cast<Instruction>(V);
105- // Still used, don't erase.
106- if (!Inst->use_empty ())
107- return ;
108- if (II != BB.rend () && Inst == &*II)
109- ++II;
110- Inst->eraseFromParent ();
111- }
112-
113100// / Return true if V is a splat of a value (which is used when multiplying a
114101// / matrix with a scalar).
115102static bool isSplat (Value *V) {
@@ -259,7 +246,7 @@ static bool isUniformShape(Value *V) {
259246// / Return the ShapeInfo for the result of \p I, it it can be determined.
260247static std::optional<ShapeInfo>
261248computeShapeInfoForInst (Instruction *I,
262- const ValueMap <Value *, ShapeInfo> &ShapeMap) {
249+ const DenseMap <Value *, ShapeInfo> &ShapeMap) {
263250 Value *M;
264251 Value *N;
265252 Value *K;
@@ -492,10 +479,16 @@ class LowerMatrixIntrinsics {
492479 // / the result value of the instruction, with the only exceptions being store
493480 // / instructions and the matrix_column_major_store intrinsics. For those, the
494481 // / shape information indicates that those instructions should be lowered
495- // / using shape information as well. A ValueMap is used so that when
496- // / sub-passes like optimizeTransposes performs RAUW the map stays
497- // / up-to-date.
498- ValueMap<Value *, ShapeInfo> ShapeMap;
482+ // / using shape information as well. Note that extra care is needed when
483+ // / erasing or RAUW'ing a value that is present in ShapeMap. If the
484+ // / replacement is also a matrix operation, use
485+ // / updateShapeAndReplaceAllUsesWith to make sure the replacement is added to
486+ // / ShapeMap. We don't use ValueMap, as there are also cases where we do not
487+ // / want to add shape information for a replacement instruction. When directly
488+ // / erasing a value with an entry in ShapeMap, use
489+ // / eraseFromParentAndRemoveFromShapeMap to make sure ShapeMap is also updated
490+ // / accordingly.
491+ DenseMap<Value *, ShapeInfo> ShapeMap;
499492
500493 // / List of instructions to remove. While lowering, we are not replacing all
501494 // / users of a lowered instruction, if shape information is available and
@@ -759,6 +752,30 @@ class LowerMatrixIntrinsics {
759752 return Operation (T0, Shape0.t (), T1, Shape1.t ());
760753 }
761754
755+ // / Erase \p Inst from both ShapeMap (if an entry exists) and erase \p Inst
756+ // / itself.
757+ void eraseFromParentAndRemoveFromShapeMap (Instruction *Inst) {
758+ auto Iter = ShapeMap.find (Inst);
759+ if (Iter != ShapeMap.end ())
760+ ShapeMap.erase (Iter);
761+ Inst->eraseFromParent ();
762+ }
763+
764+ // / Erase \p V from \p BB and move \II forward to avoid invalidating
765+ // / iterators.
766+ void eraseFromParentAndMove (Value *V, BasicBlock::reverse_iterator &II,
767+ BasicBlock &BB) {
768+ auto *Inst = cast<Instruction>(V);
769+ // Still used, don't erase.
770+ if (!Inst->use_empty ())
771+ return ;
772+ if (II != BB.rend () && Inst == &*II)
773+ ++II;
774+ eraseFromParentAndRemoveFromShapeMap (Inst);
775+ }
776+
777+ // / Add a new entry to ShapeMap for \p New with \p Old's shape info, erase the
778+ // / entry for \p Old and replace all uses of \p Old with \p New.
762779 void updateShapeAndReplaceAllUsesWith (Instruction &Old, Value *New) {
763780 // We need to remove Old from the ShapeMap otherwise RAUW will replace it
764781 // with New. We should only add New it it supportsShapeInfo so we insert
@@ -872,13 +889,13 @@ class LowerMatrixIntrinsics {
872889
873890 void liftTranspose (Instruction &I) {
874891 // Erase dead Instructions after lifting transposes from binops.
875- auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
892+ auto CleanupBinOp = [this ](Instruction &T, Value *A, Value *B) {
876893 if (T.use_empty ())
877- T. eraseFromParent ( );
894+ eraseFromParentAndRemoveFromShapeMap (&T );
878895 if (A->use_empty ())
879- cast<Instruction>(A)-> eraseFromParent ( );
896+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(A));
880897 if (A != B && B->use_empty ())
881- cast<Instruction>(B)-> eraseFromParent ( );
898+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(B));
882899 };
883900
884901 Value *A, *B, *AT, *BT;
@@ -1476,7 +1493,7 @@ class LowerMatrixIntrinsics {
14761493 m_Value (Arg)))) {
14771494 auto *NewLoad = Builder.CreateLoad (Op->getType (), Arg);
14781495 Op->replaceAllUsesWith (NewLoad);
1479- cast<Instruction>(Op)-> eraseFromParent ( );
1496+ eraseFromParentAndRemoveFromShapeMap ( cast<Instruction>(Op));
14801497 return ;
14811498 } else if (match (Op, m_Intrinsic<Intrinsic::matrix_transpose>(
14821499 m_Value (Arg)))) {
@@ -1845,15 +1862,15 @@ class LowerMatrixIntrinsics {
18451862 // Mark eliminated instructions as fused and remove them.
18461863 FusedInsts.insert (Store);
18471864 FusedInsts.insert (MatMul);
1848- Store-> eraseFromParent ( );
1849- MatMul-> eraseFromParent ( );
1865+ eraseFromParentAndRemoveFromShapeMap (Store );
1866+ eraseFromParentAndRemoveFromShapeMap (MatMul );
18501867 if (LoadOp0->hasNUses (0 )) {
18511868 FusedInsts.insert (LoadOp0);
1852- LoadOp0-> eraseFromParent ( );
1869+ eraseFromParentAndRemoveFromShapeMap (LoadOp0 );
18531870 }
18541871 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses (0 )) {
18551872 FusedInsts.insert (LoadOp1);
1856- LoadOp1-> eraseFromParent ( );
1873+ eraseFromParentAndRemoveFromShapeMap (LoadOp1 );
18571874 }
18581875 }
18591876
0 commit comments