@@ -552,9 +552,9 @@ static void insertViewForTTNNDRAMTensor(Value operand,
552552 fakeShardedShape, metalTensor.getElementType (), viewOutputLayout);
553553
554554 builder.setInsertionPointAfter (castOp);
555- auto viewOp = builder. create < d2m::ViewLayoutOp>(
556- castOp. getLoc (), viewOutputTensor, castOp.getResult (),
557- AffineMapAttr::get (reblockMap));
555+ auto viewOp = d2m::ViewLayoutOp::create (builder, castOp. getLoc (),
556+ viewOutputTensor, castOp.getResult (),
557+ AffineMapAttr::get (reblockMap));
558558 castOp.getResult ().replaceAllUsesExcept (viewOp.getResult (), viewOp);
559559}
560560
@@ -580,18 +580,20 @@ static void optimizeTTNNMetalLayoutCastOpGrid(
580580
581581 builder.setInsertionPointAfter (castOp);
582582
583- auto newViewLayoutOp = builder.create <d2m::ViewLayoutOp>(
584- castOp.getLoc (), newTensorType, castOp.getResult (), gridRemapping);
583+ auto newViewLayoutOp =
584+ d2m::ViewLayoutOp::create (builder, castOp.getLoc (), newTensorType,
585+ castOp.getResult (), gridRemapping);
585586
586587 // Reblock it back to original shape to preserve IR correctness.
587588 auto viewOutputType = utils::reblockTensor (
588589 newTensorType, outputLayout.getGridShape (outputType));
589590 auto reblockMap = ttmlir::utils::calculateReblockMap (
590591 newTensorType.getShape (), viewOutputType.getShape (),
591592 builder.getContext ());
592- auto revertingView = builder.create <d2m::ViewLayoutOp>(
593- castOp.getLoc (), viewOutputType, newViewLayoutOp.getResult (), reblockMap,
594- /* reinterpretLayout=*/ false );
593+ auto revertingView =
594+ d2m::ViewLayoutOp::create (builder, castOp.getLoc (), viewOutputType,
595+ newViewLayoutOp.getResult (), reblockMap,
596+ /* reinterpretLayout=*/ false );
595597
596598 castOp.getResult ().replaceAllUsesExcept (revertingView.getResult (),
597599 newViewLayoutOp);
@@ -1244,223 +1246,6 @@ recreateGenericOp(d2m::GenericOp genericOp,
12441246 }
12451247}
12461248
1247- static bool hasTTNNOperands (d2m::GenericOp genericOp) {
1248- for (Value operand : genericOp.getInputsAndOutputs ()) {
1249- if (operand.getDefiningOp <ttir::TTNNMetalLayoutCastOp>()) {
1250- return true ;
1251- }
1252- // Check if view operand's input is the result of a TTNNMetalLayoutCastOp.
1253- if (auto view = operand.getDefiningOp <d2m::ViewLayoutOp>();
1254- view && view.getInput ().getDefiningOp <ttir::TTNNMetalLayoutCastOp>()) {
1255- return true ;
1256- }
1257- }
1258- return false ;
1259- }
1260-
1261- // Computes the expected TTNN generic output grid shape for the given tensor.
1262- static llvm::SmallVector<llvm::SmallVector<int64_t >>
1263- computeTTNNGenericGridShapes (GenericOp genericOp,
1264- ArrayRef<int64_t > targetSquareGridShape) {
1265-
1266- auto optimalOperandGrids = llvm::SmallVector<llvm::SmallVector<int64_t >>(
1267- genericOp.getInputsAndOutputs ().size ());
1268-
1269- // Determine dim size constraints based on L1 operands. L1 operands are
1270- // assumed fixed and already legal; DRAM operand streams are aligned to match
1271- // L1 shapes.
1272- auto maybeConstrainedDims = genericOp.computeGridDimConstraints (
1273- [&](ttcore::MetalLayoutAttr baseMetalLayout, bool isOutputOperand) {
1274- return baseMetalLayout.getMemorySpace () ==
1275- ttcore::MemorySpace::DeviceL1;
1276- });
1277- // this should be guaranteed if GenericOp verification is working.
1278- TT_assertv (maybeConstrainedDims.has_value (),
1279- " GenericOp dim constraints are cannot be satisfied." );
1280- auto constrainedDims = maybeConstrainedDims.value ();
1281-
1282- auto indexingMaps = genericOp.getIndexingMapsValue ();
1283- auto getConstrainedDims = [&](int64_t operandIdx) {
1284- return indexingMaps[operandIdx].compose (constrainedDims);
1285- };
1286- auto allDimsConstrained = [&](int64_t operandIdx) {
1287- return llvm::all_of (getConstrainedDims (operandIdx),
1288- [](int64_t dim) { return dim != 0 ; });
1289- };
1290-
1291- // Set all grid shapes according to constraints
1292- OpBuilder builder (genericOp->getContext ());
1293- for (auto [operandIdx, operand] :
1294- llvm::enumerate (genericOp.getInputsAndOutputs ())) {
1295-
1296- auto constrainedDims = getConstrainedDims (operandIdx);
1297- // if all dims are constrained, use the constrained dims.
1298- if (allDimsConstrained (operandIdx)) {
1299- optimalOperandGrids[operandIdx] = getConstrainedDims (operandIdx);
1300- } else {
1301- // if not all dims are constrained, shard to an optimal grid.
1302- auto metalTensorType =
1303- mlir::cast<mlir::RankedTensorType>(operand.getType ());
1304- auto baseMetalLayout =
1305- mlir::cast<ttcore::MetalLayoutAttr>(metalTensorType.getEncoding ());
1306- auto constrainedDims = getConstrainedDims (operandIdx);
1307-
1308- // Compute constrained target grid shape as min of targetSquareGridShape
1309- // and constrainedDims (if constrainedDim > 0), else use
1310- // targetSquareGridShape.
1311- llvm::SmallVector<int64_t > constrainedTargetGridShape =
1312- llvm::to_vector (targetSquareGridShape);
1313- if (constrainedDims.size () == targetSquareGridShape.size ()) {
1314- for (size_t i = 0 ; i < targetSquareGridShape.size (); ++i) {
1315- if (constrainedDims[i] > 0 ) {
1316- constrainedTargetGridShape[i] =
1317- std::min (constrainedDims[i], targetSquareGridShape[i]);
1318- }
1319- }
1320- }
1321-
1322- llvm::SmallVector<int64_t > physicalShape;
1323- // If operand is DRAM interleaved operand that is the result of a
1324- // ttnn->metal cast, we must generate a view of the underlying ttnn tensor
1325- // _without_ padding, as the underlying tensor also is unpadded.
1326- bool isNonPaddableTTNNDRAMOperand =
1327- operand.getDefiningOp <ttir::TTNNMetalLayoutCastOp>() &&
1328- baseMetalLayout.getMemorySpace () == ttcore::MemorySpace::DeviceDRAM &&
1329- baseMetalLayout.getMemoryLayout () ==
1330- ttcore::TensorMemoryLayout::Interleaved;
1331- if (isNonPaddableTTNNDRAMOperand) {
1332- llvm::SmallVector<int64_t > tileShape;
1333- if (auto tileType = mlir::dyn_cast<ttcore::TileType>(
1334- metalTensorType.getElementType ())) {
1335- tileShape = llvm::to_vector (tileType.getShape ());
1336- } else {
1337- tileShape = llvm::to_vector (ttcore::TileType::getDefaultShape ());
1338- }
1339- physicalShape = baseMetalLayout.getPhysicalShape (tileShape);
1340- } else {
1341- physicalShape =
1342- computePhysicalShape (baseMetalLayout, metalTensorType,
1343- constrainedTargetGridShape, builder);
1344- }
1345-
1346- optimalOperandGrids[operandIdx] = computeOptimalGrid (
1347- metalTensorType, physicalShape, constrainedTargetGridShape);
1348- }
1349- }
1350-
1351- return optimalOperandGrids;
1352- }
1353-
1354- // Finds and erases all unit reblocking views inserted by TTIRToD2M,
1355- // passing each view's input as the new operands.
1356- static void eraseUnitGridReblockingViews (d2m::GenericOp genericOp) {
1357- // Use vector here to avoid invalidating iterator with erasures.
1358- auto operands = llvm::to_vector (genericOp.getInputsAndOutputs ());
1359- for (Value operand : operands) {
1360- if (auto viewOp = operand.getDefiningOp <d2m::ViewLayoutOp>()) {
1361- auto originalOperand = viewOp.getInput ();
1362- viewOp.getResult ().replaceAllUsesWith (originalOperand);
1363- viewOp.erase ();
1364- }
1365- }
1366- }
1367-
1368- // TTNN DRAM interleaved tensors are represented as having a 1x1 grid. This
1369- // leads to the genericOp having a worker grid of 1x1 since it must match the
1370- // output tensor grid. This is obviously not optimal. We match genericOps that
1371- // have TTNN DRAM interleaved tensors as operands and:
1372- // 1. Compute the "optimal" grid for the tensor as if it were a regular Metal
1373- // sharded tensor.
1374- // 2. Insert a view layout op to represent the tensor with the "optimal" grid.
1375- // 3. Update the genericOp to use the view output as an operand.
1376- //
1377- // Note the cast op is NOT erased as it represents the canonical layout mapping
1378- // between TTNN and Metal layouts.
1379- //
1380- // For a given TTNN DRAM interleaved tensor, we end up with the following
1381- // representations:
1382- // 1. The canonical translation of the TTNN tensor to a Metal tensor, having
1383- // a metal layout, DRAM memory space, and a 1x1 grid.
1384- //
1385- // 2. The "reblocked" version of tensor 1, having a metal layout, DRAM memory
1386- // space, an inferred grid, and an index map to index into the original
1387- // tensor.
1388- //
1389- // A view layout op is used here so that the Allocator pass retains
1390- // ownership of stream insertion and buffer count selection.
1391- static llvm::SmallVector<llvm::SmallVector<int64_t >>
1392- insertTTNNDRAMViews (d2m::GenericOp genericOp,
1393- ArrayRef<int64_t > targetSquareGridShape) {
1394-
1395- eraseUnitGridReblockingViews (genericOp);
1396-
1397- auto optimalOperandGrids =
1398- computeTTNNGenericGridShapes (genericOp, targetSquareGridShape);
1399-
1400- OpBuilder builder (genericOp->getContext ());
1401- for (auto [operandIdx, operand] :
1402- llvm::enumerate (genericOp.getInputsAndOutputs ())) {
1403- auto metalTensor = mlir::cast<mlir::RankedTensorType>(operand.getType ());
1404- auto baseMetalLayout =
1405- mlir::cast<ttcore::MetalLayoutAttr>(metalTensor.getEncoding ());
1406- if (baseMetalLayout.getMemorySpace () != ttcore::MemorySpace::DeviceDRAM) {
1407- continue ;
1408- }
1409-
1410- // Do not "restream" metal -> ttnn -> metal sequences. This happens when the
1411- // output of a generic is the input to another generic. The output is
1412- // already streamed, but the cast back to ttnn silently erases the index
1413- // map. Instead, we just forward the already streamed metal tensor to the
1414- // current generic.
1415- auto castOp = operand.getDefiningOp <ttir::TTNNMetalLayoutCastOp>();
1416- TT_assertv (
1417- castOp,
1418- " If one d2m.generic operand is from TTNN, they must all be from TTNN." );
1419- auto producerCastOp =
1420- castOp.getInput ().getDefiningOp <ttir::TTNNMetalLayoutCastOp>();
1421- if (producerCastOp) {
1422- castOp.getResult ().replaceAllUsesExcept (producerCastOp.getInput (),
1423- producerCastOp);
1424- continue ;
1425- }
1426-
1427- // TTNN DRAM interleaved tensors are represented as having a 1x1 grid.
1428- llvm::SmallVector<int64_t > unitGridShape{1 , 1 };
1429- llvm::SmallVector<int64_t > unShardedShapeWithGrid =
1430- baseMetalLayout.getDeviceShape (unitGridShape,
1431- ttcore::TileType::getDefaultShape ());
1432-
1433- llvm::SmallVector<int64_t > fakeShardedShape =
1434- baseMetalLayout.getDeviceShape (optimalOperandGrids[operandIdx],
1435- ttcore::TileType::getDefaultShape ());
1436-
1437- auto reblockMap = ttmlir::utils::calculateReblockMap (
1438- unShardedShapeWithGrid, fakeShardedShape, builder.getContext ());
1439- auto viewOutputLayout = ttcore::MetalLayoutAttr::get (
1440- builder.getContext (), baseMetalLayout.getLogicalShape (),
1441- baseMetalLayout.getOobVal (), ttcore::MemorySpace::DeviceDRAM,
1442- ttcore::TensorMemoryLayout::Interleaved,
1443- baseMetalLayout.getCollapsedIntervals (),
1444- baseMetalLayout.getDimAlignments ());
1445-
1446- auto viewOutputTensor = mlir::RankedTensorType::get (
1447- fakeShardedShape, metalTensor.getElementType (), viewOutputLayout);
1448-
1449- builder.setInsertionPointAfter (castOp);
1450- auto viewOp =
1451- d2m::ViewLayoutOp::create (builder, castOp.getLoc (), viewOutputTensor,
1452- castOp.getResult (), reblockMap);
1453- castOp.getResult ().replaceAllUsesExcept (viewOp.getResult (), viewOp);
1454- }
1455-
1456- TT_assertv (llvm::all_of (optimalOperandGrids,
1457- [](const llvm::SmallVector<int64_t > &grid) {
1458- return !grid.empty ();
1459- }),
1460- " Optimal grids must be populated for all operands." );
1461- return optimalOperandGrids;
1462- }
1463-
14641249// Assign optimized grids to all ToLayoutOps feeding into a GenericOp by
14651250// computing the optimal grid per tensor independently, mirroring the old
14661251// TTIRToD2M behavior.
0 commit comments