@@ -1704,18 +1704,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17041704 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
17051705 LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
17061706 PatternRewriter &rewriter) const override {
1707- auto warpOpYield = cast<gpu::YieldOp>(
1707+ auto yield = cast<gpu::YieldOp>(
17081708 warpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1709- // Only pick up `ForOp` if it is the last op in the region.
1710- Operation *lastNode = warpOpYield ->getPrevNode ();
1709+ // Only pick up forOp if it is the last op in the region.
1710+ Operation *lastNode = yield ->getPrevNode ();
17111711 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
17121712 if (!forOp)
17131713 return failure ();
1714- // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
1715- // Those Values need to be returned by the new warp op.
1714+ // Collect Values that come from the warp op but are outside the forOp.
1715+ // Those Value needs to be returned by the original warpOp and passed to
1716+ // the new op.
17161717 llvm::SmallSetVector<Value, 32 > escapingValues;
1717- SmallVector<Type> escapingValueInputTypes ;
1718- SmallVector<Type> escapingValueDistTypes ;
1718+ SmallVector<Type> inputTypes ;
1719+ SmallVector<Type> distTypes ;
17191720 mlir::visitUsedValuesDefinedAbove (
17201721 forOp.getBodyRegion (), [&](OpOperand *operand) {
17211722 Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
@@ -1727,153 +1728,81 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17271728 AffineMap map = distributionMapFn (operand->get ());
17281729 distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
17291730 }
1730- escapingValueInputTypes .push_back (operand->get ().getType ());
1731- escapingValueDistTypes .push_back (distType);
1731+ inputTypes .push_back (operand->get ().getType ());
1732+ distTypes .push_back (distType);
17321733 }
17331734 });
17341735
1735- if (llvm::is_contained (escapingValueDistTypes , Type{}))
1736+ if (llvm::is_contained (distTypes , Type{}))
17361737 return failure ();
1737- // `WarpOp` can yield two types of values:
1738- // 1. Values that are not results of the `ForOp`:
1739- // These values must also be yielded by the new `WarpOp`. Also, we need
1740- // to record the index mapping for these values to replace them later.
1741- // 2. Values that are results of the `ForOp`:
1742- // In this case, we record the index mapping between the `WarpOp` result
1743- // index and matching `ForOp` result index.
1744- SmallVector<Value> nonForYieldedValues;
1745- SmallVector<unsigned > nonForResultIndices;
1746- llvm::SmallDenseMap<unsigned , unsigned > forResultMapping;
1747- for (OpOperand &yieldOperand : warpOpYield->getOpOperands ()) {
1748- // Yielded value is not a result of the forOp.
1749- if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ()) {
1750- nonForYieldedValues.push_back (yieldOperand.get ());
1751- nonForResultIndices.push_back (yieldOperand.getOperandNumber ());
1738+
1739+ SmallVector<size_t > newRetIndices;
1740+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1741+ rewriter, warpOp, escapingValues.getArrayRef (), distTypes,
1742+ newRetIndices);
1743+ yield = cast<gpu::YieldOp>(
1744+ newWarpOp.getBodyRegion ().getBlocks ().begin ()->getTerminator ());
1745+
1746+ SmallVector<Value> newOperands;
1747+ SmallVector<unsigned > resultIdx;
1748+ // Collect all the outputs coming from the forOp.
1749+ for (OpOperand &yieldOperand : yield->getOpOperands ()) {
1750+ if (yieldOperand.get ().getDefiningOp () != forOp.getOperation ())
17521751 continue ;
1753- }
1754- OpResult forResult = cast<OpResult>(yieldOperand.get ());
1755- forResultMapping[yieldOperand.getOperandNumber ()] =
1756- forResult.getResultNumber ();
1752+ auto forResult = cast<OpResult>(yieldOperand.get ());
1753+ newOperands.push_back (
1754+ newWarpOp.getResult (yieldOperand.getOperandNumber ()));
1755+ yieldOperand.set (forOp.getInitArgs ()[forResult.getResultNumber ()]);
1756+ resultIdx.push_back (yieldOperand.getOperandNumber ());
17571757 }
17581758
1759- // Newly created `WarpOp` will yield values in following order:
1760- // 1. All init args of the `ForOp`.
1761- // 2. All escaping values.
1762- // 3. All non-`ForOp` yielded values.
1763- SmallVector<Value> newWarpOpYieldValues;
1764- SmallVector<Type> newWarpOpDistTypes;
1765- for (auto [i, initArg] : llvm::enumerate (forOp.getInitArgs ())) {
1766- newWarpOpYieldValues.push_back (initArg);
1767- // Compute the distributed type for this init arg.
1768- Type distType = initArg.getType ();
1769- if (auto vecType = dyn_cast<VectorType>(distType)) {
1770- AffineMap map = distributionMapFn (initArg);
1771- distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1772- }
1773- newWarpOpDistTypes.push_back (distType);
1774- }
1775- // Insert escaping values and their distributed types.
1776- newWarpOpYieldValues.insert (newWarpOpYieldValues.end (),
1777- escapingValues.begin (), escapingValues.end ());
1778- newWarpOpDistTypes.insert (newWarpOpDistTypes.end (),
1779- escapingValueDistTypes.begin (),
1780- escapingValueDistTypes.end ());
1781- // Next, we insert all non-`ForOp` yielded values and their distributed
1782- // types. We also create a mapping between the non-`ForOp` yielded value
1783- // index and the corresponding new `WarpOp` yield value index (needed to
1784- // update users later).
1785- llvm::SmallDenseMap<unsigned , unsigned > nonForResultMapping;
1786- for (auto [i, v] :
1787- llvm::zip_equal (nonForResultIndices, nonForYieldedValues)) {
1788- nonForResultMapping[i] = newWarpOpYieldValues.size ();
1789- newWarpOpYieldValues.push_back (v);
1790- newWarpOpDistTypes.push_back (warpOp.getResult (i).getType ());
1791- }
1792- // Create the new `WarpOp` with the updated yield values and types.
1793- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1794- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1795-
1796- // Next, we create a new `ForOp` with the init args yielded by the new
1797- // `WarpOp`.
1798- const unsigned escapingValuesStartIdx =
1799- forOp.getInitArgs ().size (); // `ForOp` init args are positioned before
1800- // escaping values in the new `WarpOp`.
1801- SmallVector<Value> newForOpOperands;
1802- for (size_t i = 0 ; i < escapingValuesStartIdx; ++i)
1803- newForOpOperands.push_back (newWarpOp.getResult (i));
1804-
1805- // Create a new `ForOp` outside the new `WarpOp` region.
18061759 OpBuilder::InsertionGuard g (rewriter);
18071760 rewriter.setInsertionPointAfter (newWarpOp);
1761+
1762+ // Create a new for op outside the region with a WarpExecuteOnLane0Op
1763+ // region inside.
18081764 auto newForOp = rewriter.create <scf::ForOp>(
18091765 forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1810- forOp.getStep (), newForOpOperands);
1811- // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
1812- // newly created `ForOp`. This `WarpOp` will contain all ops that were
1813- // contained within the original `ForOp` body.
1766+ forOp.getStep (), newOperands);
18141767 rewriter.setInsertionPointToStart (newForOp.getBody ());
18151768
1816- SmallVector<Value> innerWarpInput (newForOp.getRegionIterArgs ().begin (),
1817- newForOp.getRegionIterArgs ().end ());
1818- SmallVector<Type> innerWarpInputType (forOp.getResultTypes ().begin (),
1819- forOp.getResultTypes ().end ());
1820- // Escaping values are forwarded to the inner `WarpOp` as its (additional)
1821- // arguments. We keep track of the mapping between these values and their
1822- // argument index in the inner `WarpOp` (to replace users later).
1769+ SmallVector<Value> warpInput (newForOp.getRegionIterArgs ().begin (),
1770+ newForOp.getRegionIterArgs ().end ());
1771+ SmallVector<Type> warpInputType (forOp.getResultTypes ().begin (),
1772+ forOp.getResultTypes ().end ());
18231773 llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1824- for (size_t i = escapingValuesStartIdx;
1825- i < escapingValuesStartIdx + escapingValues.size (); ++i) {
1826- innerWarpInput.push_back (newWarpOp.getResult (i));
1827- argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1828- innerWarpInputType.size ();
1829- innerWarpInputType.push_back (
1830- escapingValueInputTypes[i - escapingValuesStartIdx]);
1774+ for (auto [i, retIdx] : llvm::enumerate (newRetIndices)) {
1775+ warpInput.push_back (newWarpOp.getResult (retIdx));
1776+ argIndexMapping[escapingValues[i]] = warpInputType.size ();
1777+ warpInputType.push_back (inputTypes[i]);
18311778 }
1832- // Create the inner `WarpOp` with the new input values and types.
18331779 auto innerWarp = rewriter.create <WarpExecuteOnLane0Op>(
18341780 newWarpOp.getLoc (), newForOp.getResultTypes (), newWarpOp.getLaneid (),
1835- newWarpOp.getWarpSize (), innerWarpInput, innerWarpInputType );
1781+ newWarpOp.getWarpSize (), warpInput, warpInputType );
18361782
1837- // Inline the `ForOp` body into the inner `WarpOp` body.
18381783 SmallVector<Value> argMapping;
18391784 argMapping.push_back (newForOp.getInductionVar ());
1840- for (Value args : innerWarp.getBody ()->getArguments ())
1785+ for (Value args : innerWarp.getBody ()->getArguments ()) {
18411786 argMapping.push_back (args);
1842-
1787+ }
18431788 argMapping.resize (forOp.getBody ()->getNumArguments ());
18441789 SmallVector<Value> yieldOperands;
18451790 for (Value operand : forOp.getBody ()->getTerminator ()->getOperands ())
18461791 yieldOperands.push_back (operand);
1847-
18481792 rewriter.eraseOp (forOp.getBody ()->getTerminator ());
18491793 rewriter.mergeBlocks (forOp.getBody (), innerWarp.getBody (), argMapping);
1850-
1851- // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
1852- // original `ForOp` results.
18531794 rewriter.setInsertionPointToEnd (innerWarp.getBody ());
18541795 rewriter.create <gpu::YieldOp>(innerWarp.getLoc (), yieldOperands);
18551796 rewriter.setInsertionPointAfter (innerWarp);
1856- // Insert a scf.yield op at the end of the new `ForOp` body that yields
1857- // the inner `WarpOp` results.
18581797 if (!innerWarp.getResults ().empty ())
18591798 rewriter.create <scf::YieldOp>(forOp.getLoc (), innerWarp.getResults ());
1860-
1861- // Update the users of original `WarpOp` results that were coming from the
1862- // original `ForOp` to the corresponding new `ForOp` result.
1863- for (auto [origIdx, newIdx] : forResultMapping)
1864- rewriter.replaceAllUsesExcept (warpOp.getResult (origIdx),
1865- newForOp.getResult (newIdx), newForOp);
1866- // Similarly, update any users of the `WarpOp` results that were not
1867- // results of the `ForOp`.
1868- for (auto [origIdx, newIdx] : nonForResultMapping)
1869- rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
1870- newWarpOp.getResult (newIdx));
1871- // Remove the original `WarpOp` and `ForOp`, they should not have any uses
1872- // at this point.
18731799 rewriter.eraseOp (forOp);
1874- rewriter.eraseOp (warpOp);
1875- // Update any users of escaping values that were forwarded to the
1876- // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
1800+ // Replace the warpOp result coming from the original ForOp.
1801+ for (const auto &res : llvm::enumerate (resultIdx)) {
1802+ rewriter.replaceAllUsesWith (newWarpOp.getResult (res.value ()),
1803+ newForOp.getResult (res.index ()));
1804+ newForOp->setOperand (res.index () + 3 , newWarpOp.getResult (res.value ()));
1805+ }
18771806 newForOp.walk ([&](Operation *op) {
18781807 for (OpOperand &operand : op->getOpOperands ()) {
18791808 auto it = argIndexMapping.find (operand.get ());
@@ -1883,7 +1812,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
18831812 }
18841813 });
18851814
1886- // Finally, hoist out any now uniform code from the inner `WarpOp` .
1815+ // Finally, hoist out any now uniform code from the inner warp op .
18871816 mlir::vector::moveScalarUniformCode (innerWarp);
18881817 return success ();
18891818 }
0 commit comments