Skip to content

Commit b2e0a2b

Browse files
committed
rebase
1 parent a193b28 commit b2e0a2b

File tree

2 files changed: +11 additions, −226 deletions

lib/Dialect/D2M/Transforms/GridSelection.cpp

Lines changed: 10 additions & 225 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,9 @@ static void insertViewForTTNNDRAMTensor(Value operand,
552552
fakeShardedShape, metalTensor.getElementType(), viewOutputLayout);
553553

554554
builder.setInsertionPointAfter(castOp);
555-
auto viewOp = builder.create<d2m::ViewLayoutOp>(
556-
castOp.getLoc(), viewOutputTensor, castOp.getResult(),
557-
AffineMapAttr::get(reblockMap));
555+
auto viewOp = d2m::ViewLayoutOp::create(builder, castOp.getLoc(),
556+
viewOutputTensor, castOp.getResult(),
557+
AffineMapAttr::get(reblockMap));
558558
castOp.getResult().replaceAllUsesExcept(viewOp.getResult(), viewOp);
559559
}
560560

@@ -580,18 +580,20 @@ static void optimizeTTNNMetalLayoutCastOpGrid(
580580

581581
builder.setInsertionPointAfter(castOp);
582582

583-
auto newViewLayoutOp = builder.create<d2m::ViewLayoutOp>(
584-
castOp.getLoc(), newTensorType, castOp.getResult(), gridRemapping);
583+
auto newViewLayoutOp =
584+
d2m::ViewLayoutOp::create(builder, castOp.getLoc(), newTensorType,
585+
castOp.getResult(), gridRemapping);
585586

586587
// Reblock it back to original shape to preserve IR correctness.
587588
auto viewOutputType = utils::reblockTensor(
588589
newTensorType, outputLayout.getGridShape(outputType));
589590
auto reblockMap = ttmlir::utils::calculateReblockMap(
590591
newTensorType.getShape(), viewOutputType.getShape(),
591592
builder.getContext());
592-
auto revertingView = builder.create<d2m::ViewLayoutOp>(
593-
castOp.getLoc(), viewOutputType, newViewLayoutOp.getResult(), reblockMap,
594-
/*reinterpretLayout=*/false);
593+
auto revertingView =
594+
d2m::ViewLayoutOp::create(builder, castOp.getLoc(), viewOutputType,
595+
newViewLayoutOp.getResult(), reblockMap,
596+
/*reinterpretLayout=*/false);
595597

596598
castOp.getResult().replaceAllUsesExcept(revertingView.getResult(),
597599
newViewLayoutOp);
@@ -1244,223 +1246,6 @@ recreateGenericOp(d2m::GenericOp genericOp,
12441246
}
12451247
}
12461248

1247-
static bool hasTTNNOperands(d2m::GenericOp genericOp) {
1248-
for (Value operand : genericOp.getInputsAndOutputs()) {
1249-
if (operand.getDefiningOp<ttir::TTNNMetalLayoutCastOp>()) {
1250-
return true;
1251-
}
1252-
// Check if view operand's input is the result of a TTNNMetalLayoutCastOp.
1253-
if (auto view = operand.getDefiningOp<d2m::ViewLayoutOp>();
1254-
view && view.getInput().getDefiningOp<ttir::TTNNMetalLayoutCastOp>()) {
1255-
return true;
1256-
}
1257-
}
1258-
return false;
1259-
}
1260-
1261-
// Computes the expected TTNN generic output grid shape for the given tensor.
1262-
static llvm::SmallVector<llvm::SmallVector<int64_t>>
1263-
computeTTNNGenericGridShapes(GenericOp genericOp,
1264-
ArrayRef<int64_t> targetSquareGridShape) {
1265-
1266-
auto optimalOperandGrids = llvm::SmallVector<llvm::SmallVector<int64_t>>(
1267-
genericOp.getInputsAndOutputs().size());
1268-
1269-
// Determine dim size constraints based on L1 operands. L1 operands are
1270-
// assumed fixed and already legal; DRAM operand streams are aligned to match
1271-
// L1 shapes.
1272-
auto maybeConstrainedDims = genericOp.computeGridDimConstraints(
1273-
[&](ttcore::MetalLayoutAttr baseMetalLayout, bool isOutputOperand) {
1274-
return baseMetalLayout.getMemorySpace() ==
1275-
ttcore::MemorySpace::DeviceL1;
1276-
});
1277-
// this should be guaranteed if GenericOp verification is working.
1278-
TT_assertv(maybeConstrainedDims.has_value(),
1279-
"GenericOp dim constraints are cannot be satisfied.");
1280-
auto constrainedDims = maybeConstrainedDims.value();
1281-
1282-
auto indexingMaps = genericOp.getIndexingMapsValue();
1283-
auto getConstrainedDims = [&](int64_t operandIdx) {
1284-
return indexingMaps[operandIdx].compose(constrainedDims);
1285-
};
1286-
auto allDimsConstrained = [&](int64_t operandIdx) {
1287-
return llvm::all_of(getConstrainedDims(operandIdx),
1288-
[](int64_t dim) { return dim != 0; });
1289-
};
1290-
1291-
// Set all grid shapes according to constraints
1292-
OpBuilder builder(genericOp->getContext());
1293-
for (auto [operandIdx, operand] :
1294-
llvm::enumerate(genericOp.getInputsAndOutputs())) {
1295-
1296-
auto constrainedDims = getConstrainedDims(operandIdx);
1297-
// if all dims are constrained, use the constrained dims.
1298-
if (allDimsConstrained(operandIdx)) {
1299-
optimalOperandGrids[operandIdx] = getConstrainedDims(operandIdx);
1300-
} else {
1301-
// if not all dims are constrained, shard to an optimal grid.
1302-
auto metalTensorType =
1303-
mlir::cast<mlir::RankedTensorType>(operand.getType());
1304-
auto baseMetalLayout =
1305-
mlir::cast<ttcore::MetalLayoutAttr>(metalTensorType.getEncoding());
1306-
auto constrainedDims = getConstrainedDims(operandIdx);
1307-
1308-
// Compute constrained target grid shape as min of targetSquareGridShape
1309-
// and constrainedDims (if constrainedDim > 0), else use
1310-
// targetSquareGridShape.
1311-
llvm::SmallVector<int64_t> constrainedTargetGridShape =
1312-
llvm::to_vector(targetSquareGridShape);
1313-
if (constrainedDims.size() == targetSquareGridShape.size()) {
1314-
for (size_t i = 0; i < targetSquareGridShape.size(); ++i) {
1315-
if (constrainedDims[i] > 0) {
1316-
constrainedTargetGridShape[i] =
1317-
std::min(constrainedDims[i], targetSquareGridShape[i]);
1318-
}
1319-
}
1320-
}
1321-
1322-
llvm::SmallVector<int64_t> physicalShape;
1323-
// If operand is DRAM interleaved operand that is the result of a
1324-
// ttnn->metal cast, we must generate a view of the underlying ttnn tensor
1325-
// _without_ padding, as the underlying tensor also is unpadded.
1326-
bool isNonPaddableTTNNDRAMOperand =
1327-
operand.getDefiningOp<ttir::TTNNMetalLayoutCastOp>() &&
1328-
baseMetalLayout.getMemorySpace() == ttcore::MemorySpace::DeviceDRAM &&
1329-
baseMetalLayout.getMemoryLayout() ==
1330-
ttcore::TensorMemoryLayout::Interleaved;
1331-
if (isNonPaddableTTNNDRAMOperand) {
1332-
llvm::SmallVector<int64_t> tileShape;
1333-
if (auto tileType = mlir::dyn_cast<ttcore::TileType>(
1334-
metalTensorType.getElementType())) {
1335-
tileShape = llvm::to_vector(tileType.getShape());
1336-
} else {
1337-
tileShape = llvm::to_vector(ttcore::TileType::getDefaultShape());
1338-
}
1339-
physicalShape = baseMetalLayout.getPhysicalShape(tileShape);
1340-
} else {
1341-
physicalShape =
1342-
computePhysicalShape(baseMetalLayout, metalTensorType,
1343-
constrainedTargetGridShape, builder);
1344-
}
1345-
1346-
optimalOperandGrids[operandIdx] = computeOptimalGrid(
1347-
metalTensorType, physicalShape, constrainedTargetGridShape);
1348-
}
1349-
}
1350-
1351-
return optimalOperandGrids;
1352-
}
1353-
1354-
// Finds and erases all unit reblocking views inserted by TTIRToD2M,
1355-
// passing each view's input as the new operands.
1356-
static void eraseUnitGridReblockingViews(d2m::GenericOp genericOp) {
1357-
// Use vector here to avoid invalidating iterator with erasures.
1358-
auto operands = llvm::to_vector(genericOp.getInputsAndOutputs());
1359-
for (Value operand : operands) {
1360-
if (auto viewOp = operand.getDefiningOp<d2m::ViewLayoutOp>()) {
1361-
auto originalOperand = viewOp.getInput();
1362-
viewOp.getResult().replaceAllUsesWith(originalOperand);
1363-
viewOp.erase();
1364-
}
1365-
}
1366-
}
1367-
1368-
// TTNN DRAM interleaved tensors are represented as having a 1x1 grid. This
1369-
// leads to the genericOp having a worker grid of 1x1 since it must match the
1370-
// output tensor grid. This is obviously not optimal. We match genericOps that
1371-
// have TTNN DRAM interleaved tensors as operands and:
1372-
// 1. Compute the "optimal" grid for the tensor as if it were a regular Metal
1373-
// sharded tensor.
1374-
// 2. Insert a view layout op to represent the tensor with the "optimal" grid.
1375-
// 3. Update the genericOp to use the view output as an operand.
1376-
//
1377-
// Note the cast op is NOT erased as it represents the canonical layout mapping
1378-
// between TTNN and Metal layouts.
1379-
//
1380-
// For a given TTNN DRAM interleaved tensor, we end up with the following
1381-
// representations:
1382-
// 1. The canonical translation of the TTNN tensor to a Metal tensor, having
1383-
// a metal layout, DRAM memory space, and a 1x1 grid.
1384-
//
1385-
// 2. The "reblocked" version of tensor 1, having a metal layout, DRAM memory
1386-
// space, an inferred grid, and an index map to index into the original
1387-
// tensor.
1388-
//
1389-
// A view layout op is used here so that the Allocator pass retains
1390-
// ownership of stream insertion and buffer count selection.
1391-
static llvm::SmallVector<llvm::SmallVector<int64_t>>
1392-
insertTTNNDRAMViews(d2m::GenericOp genericOp,
1393-
ArrayRef<int64_t> targetSquareGridShape) {
1394-
1395-
eraseUnitGridReblockingViews(genericOp);
1396-
1397-
auto optimalOperandGrids =
1398-
computeTTNNGenericGridShapes(genericOp, targetSquareGridShape);
1399-
1400-
OpBuilder builder(genericOp->getContext());
1401-
for (auto [operandIdx, operand] :
1402-
llvm::enumerate(genericOp.getInputsAndOutputs())) {
1403-
auto metalTensor = mlir::cast<mlir::RankedTensorType>(operand.getType());
1404-
auto baseMetalLayout =
1405-
mlir::cast<ttcore::MetalLayoutAttr>(metalTensor.getEncoding());
1406-
if (baseMetalLayout.getMemorySpace() != ttcore::MemorySpace::DeviceDRAM) {
1407-
continue;
1408-
}
1409-
1410-
// Do not "restream" metal -> ttnn -> metal sequences. This happens when the
1411-
// output of a generic is the input to another generic. The output is
1412-
// already streamed, but the cast back to ttnn silently erases the index
1413-
// map. Instead, we just forward the already streamed metal tensor to the
1414-
// current generic.
1415-
auto castOp = operand.getDefiningOp<ttir::TTNNMetalLayoutCastOp>();
1416-
TT_assertv(
1417-
castOp,
1418-
"If one d2m.generic operand is from TTNN, they must all be from TTNN.");
1419-
auto producerCastOp =
1420-
castOp.getInput().getDefiningOp<ttir::TTNNMetalLayoutCastOp>();
1421-
if (producerCastOp) {
1422-
castOp.getResult().replaceAllUsesExcept(producerCastOp.getInput(),
1423-
producerCastOp);
1424-
continue;
1425-
}
1426-
1427-
// TTNN DRAM interleaved tensors are represented as having a 1x1 grid.
1428-
llvm::SmallVector<int64_t> unitGridShape{1, 1};
1429-
llvm::SmallVector<int64_t> unShardedShapeWithGrid =
1430-
baseMetalLayout.getDeviceShape(unitGridShape,
1431-
ttcore::TileType::getDefaultShape());
1432-
1433-
llvm::SmallVector<int64_t> fakeShardedShape =
1434-
baseMetalLayout.getDeviceShape(optimalOperandGrids[operandIdx],
1435-
ttcore::TileType::getDefaultShape());
1436-
1437-
auto reblockMap = ttmlir::utils::calculateReblockMap(
1438-
unShardedShapeWithGrid, fakeShardedShape, builder.getContext());
1439-
auto viewOutputLayout = ttcore::MetalLayoutAttr::get(
1440-
builder.getContext(), baseMetalLayout.getLogicalShape(),
1441-
baseMetalLayout.getOobVal(), ttcore::MemorySpace::DeviceDRAM,
1442-
ttcore::TensorMemoryLayout::Interleaved,
1443-
baseMetalLayout.getCollapsedIntervals(),
1444-
baseMetalLayout.getDimAlignments());
1445-
1446-
auto viewOutputTensor = mlir::RankedTensorType::get(
1447-
fakeShardedShape, metalTensor.getElementType(), viewOutputLayout);
1448-
1449-
builder.setInsertionPointAfter(castOp);
1450-
auto viewOp =
1451-
d2m::ViewLayoutOp::create(builder, castOp.getLoc(), viewOutputTensor,
1452-
castOp.getResult(), reblockMap);
1453-
castOp.getResult().replaceAllUsesExcept(viewOp.getResult(), viewOp);
1454-
}
1455-
1456-
TT_assertv(llvm::all_of(optimalOperandGrids,
1457-
[](const llvm::SmallVector<int64_t> &grid) {
1458-
return !grid.empty();
1459-
}),
1460-
"Optimal grids must be populated for all operands.");
1461-
return optimalOperandGrids;
1462-
}
1463-
14641249
// Assign optimized grids to all ToLayoutOps feeding into a GenericOp by
14651250
// computing the optimal grid per tensor independently, mirroring the old
14661251
// TTIRToD2M behavior.

lib/Target/TTKernel/TTKernelToCpp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ void dprint(Arg &&arg, ArgV&&... argv) {
264264
auto experimentalPackUntilizeLLKs =
265265
StringRef(experimental_pack_untilize_llks_generated,
266266
experimental_pack_untilize_llks_generated_len);
267-
builder->create<emitc::VerbatimOp>(loc, experimentalPackUntilizeLLKs);
267+
emitc::VerbatimOp::create(*builder, loc, experimentalPackUntilizeLLKs);
268268
}
269269

270270
if (hasCall("experimental::get_noc_multicast_addr")) {

Comments (0)