@@ -462,18 +462,28 @@ struct FatPointers {
462462
463463 friend bool operator ==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
464464 return lhs.canNarrow == rhs.canNarrow &&
465- lhs.attributes == rhs.attributes &&
466- lhs.smallTensorBase == rhs.smallTensorBase ;
465+ lhs.isSmallTensor == rhs.isSmallTensor &&
466+ lhs.attributes == rhs.attributes ;
467467 }
468468
469469 friend bool operator !=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
470470 return !(lhs == rhs);
471471 }
472472
473+ static FatPtrAttrs merge (const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) {
474+ FatPtrAttrs result;
475+ result.canNarrow = lhs.canNarrow && rhs.canNarrow ;
476+ result.isSmallTensor = lhs.isSmallTensor && rhs.isSmallTensor ;
477+ for (const auto &attr : lhs.attributes ) {
478+ auto it = rhs.attributes .find (attr.first );
479+ if (it != rhs.attributes .end () && it->second == attr.second )
480+ result.attributes [attr.first ] = attr.second ;
481+ }
482+ return result;
483+ }
484+
473485 llvm::DenseMap<StringRef, Attribute> attributes;
474- // If the fat-pointer points to somewhere in a small-tensor, keep track the
475- // base of the tensor.
476- Value smallTensorBase;
486+ bool isSmallTensor = false ;
477487 bool canNarrow = false ;
478488 };
479489
@@ -745,7 +755,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
745755 RewriterBase::InsertionGuard guard (rewriter);
746756 rewriter.setInsertionPoint (addPtrOp);
747757
748- if (fatPtrs.at ({fatPtrBase, fatPtrOffset}).smallTensorBase )
758+ if (fatPtrs.at ({fatPtrBase, fatPtrOffset}).isSmallTensor )
749759 return rewriteSmallTensorPtr (addPtrOp, adaptor, rewriter);
750760
751761 // Query all discardable attributes that we want to preserve
@@ -861,7 +871,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
861871 const auto &oldAttr = fatPtrs.at ({fatPtrBase, fatPtrOffset});
862872
863873 LDBG (" smal-tensor addPtr: " << addPtrOp);
864- LDBG (" - with tensor-base : " << oldAttr.smallTensorBase );
874+ LDBG (" - isSmallTensor : " << oldAttr.isSmallTensor );
865875 LDBG (" - with originl offset: " << origOffset);
866876 LDBG (" - fatPtr base: " << fatPtrBase);
867877 LDBG (" - fatPtr offst: " << fatPtrOffset);
@@ -1362,17 +1372,6 @@ class ConvertArithSelectOp
13621372 // select of base and offset
13631373 ValueRange fatPtrFalse = adaptor.getFalseValue ();
13641374 ValueRange fatPtrTrue = adaptor.getTrueValue ();
1365- // Simple case of a scalar select: update the base pointer
1366- if (!isa<RankedTensorType>(selectOp.getType ())) {
1367- auto newSelectOp = arith::SelectOp::create (
1368- rewriter, selectOp.getLoc (), selectOp.getType (),
1369- selectOp.getCondition (), fatPtrTrue[0 ], selectOp.getFalseValue ());
1370- rewriter.replaceOpWithMultiple (selectOp, {{newSelectOp, fatPtrTrue[1 ]}});
1371- fatPtrs[{newSelectOp, /* fatPtrOffset*/ fatPtrTrue[1 ]}] =
1372- fatPtrs.at ({fatPtrTrue[0 ], fatPtrTrue[1 ]});
1373- return success ();
1374- }
1375-
13761375 // Rewrite to select(fatBaseT, fatBaseF) and select(fatOffsetT, fatOffsetF)
13771376 auto newBase = arith::SelectOp::create (rewriter, selectOp.getLoc (),
13781377 selectOp.getCondition (),
@@ -1381,12 +1380,10 @@ class ConvertArithSelectOp
13811380 selectOp.getCondition (),
13821381 fatPtrTrue[1 ], fatPtrFalse[1 ]);
13831382
1384- assert ((fatPtrs.at ({fatPtrTrue[0 ], fatPtrTrue[1 ]}) ==
1385- fatPtrs.at ({fatPtrFalse[0 ], fatPtrFalse[1 ]})) &&
1386- " expected can narrow to be the same for both fatPtrT and fatPtrF" );
1387-
13881383 rewriter.replaceOpWithMultiple (selectOp, {{newBase, newOffset}});
1389- fatPtrs[{newBase, newOffset}] = fatPtrs.at ({fatPtrTrue[0 ], fatPtrTrue[1 ]});
1384+ fatPtrs[{newBase, newOffset}] = FatPointers::FatPtrAttrs::merge (
1385+ fatPtrs.at ({fatPtrTrue[0 ], fatPtrTrue[1 ]}),
1386+ fatPtrs.at ({fatPtrFalse[0 ], fatPtrFalse[1 ]}));
13901387
13911388 return success ();
13921389 }
@@ -1434,14 +1431,6 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
14341431 assert (i < ifOp.thenYield ().getNumOperands () &&
14351432 i + 1 < ifOp.thenYield ().getNumOperands () &&
14361433 " expected idx to be within bounds of IfOp's results" );
1437- Value thenFatPtrBase = ifOp.thenYield ().getOperand (i);
1438- Value thenFatPtrOffset = ifOp.thenYield ().getOperand (i + 1 );
1439- Value elseFatPtrBase = ifOp.elseYield ().getOperand (i);
1440- Value elseFatPtrOffset = ifOp.elseYield ().getOperand (i + 1 );
1441- assert ((fatPtrs.at ({thenFatPtrBase, thenFatPtrOffset}) ==
1442- fatPtrs.at ({elseFatPtrBase, elseFatPtrOffset})) &&
1443- " expected then fat ptr canNarrow and else fat ptr canNarrow "
1444- " to be equal" );
14451434 }
14461435 }
14471436 }
@@ -1467,8 +1456,17 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
14671456 for (int64_t idx : yieldPtrOffsets) {
14681457 Value thenFatPtrBase = newIfOp.thenYield ().getOperand (idx);
14691458 Value thenFatPtrOffset = newIfOp.thenYield ().getOperand (idx + 1 );
1470- fatPtrs[{newIfOp.getResult (idx), newIfOp.getResult (idx + 1 )}] =
1471- fatPtrs.at ({thenFatPtrBase, thenFatPtrOffset});
1459+ const auto &thenAttrs = fatPtrs.at ({thenFatPtrBase, thenFatPtrOffset});
1460+ if (withElseRegion) {
1461+ Value elseFatPtrBase = newIfOp.elseYield ().getOperand (idx);
1462+ Value elseFatPtrOffset = newIfOp.elseYield ().getOperand (idx + 1 );
1463+ const auto &elseAttrs = fatPtrs.at ({elseFatPtrBase, elseFatPtrOffset});
1464+ fatPtrs[{newIfOp.getResult (idx), newIfOp.getResult (idx + 1 )}] =
1465+ FatPointers::FatPtrAttrs::merge (thenAttrs, elseAttrs);
1466+ } else {
1467+ fatPtrs[{newIfOp.getResult (idx), newIfOp.getResult (idx + 1 )}] =
1468+ thenAttrs;
1469+ }
14721470 }
14731471
14741472 ResultRange results = newIfOp.getResults ();
@@ -1708,7 +1706,7 @@ struct InitFuncPtrArgs : OpRewritePattern<tt::FuncOp> {
17081706 rewriter.replaceAllUsesExcept (arg, dummyCast.getResult (0 ), dummyCast);
17091707 fatPtrs[{arg, zeroOffset}].canNarrow = true ;
17101708 if (bitness != 64 )
1711- fatPtrs[{arg, zeroOffset}].smallTensorBase = arg ;
1709+ fatPtrs[{arg, zeroOffset}].isSmallTensor = true ;
17121710 }
17131711
17141712 newOp->setDiscardableAttr (kInitFuncArgsRewritten , rewriter.getUnitAttr ());
0 commit comments