Skip to content

Commit 413bc25

Browse files
committed
rebase
1 parent 62d5fa1 commit 413bc25

File tree

5 files changed

+77
-78
lines changed

5 files changed

+77
-78
lines changed

lib/Conversion/D2MToTTNN/D2MToTTNN.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -784,8 +784,8 @@ class MemrefAllocRewriter : public OpConversionPattern<memref::AllocOp> {
784784
auto memcfg = ttnn::MemoryConfigAttr::get(emptyLayoutAttr,
785785
deviceAttr.getWorkerGrid());
786786

787-
auto emptyOp = rewriter.create<ttnn::EmptyOp>(
788-
op.getLoc(), emptyTensorType, device,
787+
auto emptyOp = ttnn::EmptyOp::create(
788+
rewriter, op.getLoc(), emptyTensorType, device,
789789
ttnn::ShapeAttr::get(ctx, emptyTensorType.getShape()),
790790
ttcore::DataTypeAttr::get(ctx, emptyLayoutAttr.getDataType()),
791791
ttnn::LayoutAttr::get(ctx, emptyLayoutAttr.getLayout()), memcfg);

lib/Dialect/D2M/Transforms/LowerDMAToFullyIndexedForm.cpp

Lines changed: 42 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -28,14 +28,14 @@ namespace mlir::tt::d2m {
2828

2929
static std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
3030
getLoopBounds(OpBuilder &builder, Location loc, ArrayRef<int64_t> shardShape) {
31-
Value zero = builder.create<arith::ConstantOp>(loc, builder.getIndexType(),
32-
builder.getIndexAttr(0));
33-
Value one = builder.create<arith::ConstantOp>(loc, builder.getIndexType(),
34-
builder.getIndexAttr(1));
31+
Value zero = arith::ConstantOp::create(builder, loc, builder.getIndexType(),
32+
builder.getIndexAttr(0));
33+
Value one = arith::ConstantOp::create(builder, loc, builder.getIndexType(),
34+
builder.getIndexAttr(1));
3535
SmallVector<Value> lbs(shardShape.size(), zero);
3636
SmallVector<Value> ubs(llvm::map_range(shardShape, [&](int64_t dim) {
37-
return builder.create<arith::ConstantOp>(loc, builder.getIndexType(),
38-
builder.getIndexAttr(dim));
37+
return arith::ConstantOp::create(builder, loc, builder.getIndexType(),
38+
builder.getIndexAttr(dim));
3939
}));
4040
SmallVector<Value> step(shardShape.size(), one);
4141
return std::make_tuple(lbs, ubs, step);
@@ -93,7 +93,7 @@ static SmallVector<Value> applyMap(Builder &builder, Location loc,
9393
AffineMap map, ValueRange index,
9494
bool isRemote) {
9595
auto affineApply = [&](AffineMap map, ValueRange index) {
96-
return builder.template create<affine::AffineApplyOp>(loc, map, index);
96+
return affine::AffineApplyOp::create(builder, loc, map, index);
9797
};
9898

9999
if (isRemote) {
@@ -217,8 +217,8 @@ static Value generateFullyIndexedDMAOps(
217217
SmallVector<Value> remoteIndices = gridIndices;
218218
SmallVector<Value> localIndices;
219219

220-
Value zero = builder.create<arith::ConstantOp>(loc, builder.getIndexType(),
221-
builder.getIndexAttr(0));
220+
Value zero = arith::ConstantOp::create(builder, loc, builder.getIndexType(),
221+
builder.getIndexAttr(0));
222222
for (size_t i = 0; i < shardShape.size(); ++i) {
223223
remoteIndices.push_back(zero);
224224
localIndices.push_back(zero);
@@ -234,7 +234,7 @@ static Value generateFullyIndexedDMAOps(
234234

235235
// Strided/non-contiguous: generate loops with guarded DMAs.
236236
auto [lbs, ubs, steps] = getLoopBounds(builder, loc, shardShape);
237-
auto nullDmaTx = builder.create<NullTxOp>(loc);
237+
auto nullDmaTx = NullTxOp::create(builder, loc);
238238

239239
scf::LoopNest loopNest = scf::buildLoopNest(
240240
builder, loc, lbs, ubs, steps, ValueRange(nullDmaTx),
@@ -252,50 +252,49 @@ static Value generateFullyIndexedDMAOps(
252252
localIndices, false);
253253

254254
// Create guarded DMA operation based on coalescing factor.
255-
Value cfExpr = loopBuilder.create<arith::ConstantOp>(
256-
innerLoc, loopBuilder.getIndexType(),
255+
Value cfExpr = arith::ConstantOp::create(
256+
loopBuilder, innerLoc, loopBuilder.getIndexType(),
257257
loopBuilder.getIndexAttr(coalescingFactor));
258-
Value zero = loopBuilder.create<arith::ConstantOp>(
259-
innerLoc, loopBuilder.getIndexType(),
258+
Value zero = arith::ConstantOp::create(
259+
loopBuilder, innerLoc, loopBuilder.getIndexType(),
260260
loopBuilder.getIntegerAttr(loopBuilder.getIndexType(), 0));
261261

262262
// Construct guard function: flat_index(iters) % coalescingFactor == 0
263263
auto totalIterCount = zero;
264264
size_t currStride = 1;
265265
for (int i = iters.size() - 1; i >= 0; i--) {
266-
Value currStrideExpr = loopBuilder.create<arith::ConstantOp>(
267-
innerLoc, loopBuilder.getIndexType(),
266+
Value currStrideExpr = arith::ConstantOp::create(
267+
loopBuilder, innerLoc, loopBuilder.getIndexType(),
268268
loopBuilder.getIndexAttr(currStride));
269-
auto scaledCount =
270-
loopBuilder
271-
.create<arith::MulIOp>(innerLoc, currStrideExpr, iters[i])
272-
.getResult();
273-
totalIterCount =
274-
loopBuilder
275-
.create<arith::AddIOp>(innerLoc, scaledCount, totalIterCount)
276-
.getResult();
269+
auto scaledCount = arith::MulIOp::create(loopBuilder, innerLoc,
270+
currStrideExpr, iters[i])
271+
.getResult();
272+
totalIterCount = arith::AddIOp::create(loopBuilder, innerLoc,
273+
scaledCount, totalIterCount)
274+
.getResult();
277275
currStride *= shardShape[i];
278276
}
279-
auto moduloIterCount =
280-
loopBuilder.create<arith::RemSIOp>(innerLoc, totalIterCount, cfExpr)
281-
.getResult();
282-
auto predicate = loopBuilder.create<arith::CmpIOp>(
283-
innerLoc, arith::CmpIPredicate::eq, moduloIterCount, zero);
277+
auto moduloIterCount = arith::RemSIOp::create(loopBuilder, innerLoc,
278+
totalIterCount, cfExpr)
279+
.getResult();
280+
auto predicate = arith::CmpIOp::create(loopBuilder, innerLoc,
281+
arith::CmpIPredicate::eq,
282+
moduloIterCount, zero);
284283

285-
auto nulltx = loopBuilder.create<NullTxOp>(innerLoc);
284+
auto nulltx = NullTxOp::create(loopBuilder, innerLoc);
286285

287286
// Build guarded DMA.
288-
auto ifExpr = loopBuilder.create<scf::IfOp>(
289-
innerLoc, TypeRange(SmallVector<Value>{nulltx}), predicate,
290-
true /*addThenBlock*/, true /*addElseBlock*/);
287+
auto ifExpr = scf::IfOp::create(
288+
loopBuilder, innerLoc, TypeRange(SmallVector<Value>{nulltx}),
289+
predicate, true /*addThenBlock*/, true /*addElseBlock*/);
291290

292291
auto thenBuilder = ifExpr.getThenBodyBuilder();
293292
Value dmaTx = createDMAOp(thenBuilder, innerLoc, remoteIndices,
294293
localIndices, coalescingFactor);
295-
thenBuilder.create<scf::YieldOp>(innerLoc, dmaTx);
294+
scf::YieldOp::create(thenBuilder, innerLoc, dmaTx);
296295

297296
auto elseBuilder = ifExpr.getElseBodyBuilder();
298-
elseBuilder.create<scf::YieldOp>(innerLoc, args[0]);
297+
scf::YieldOp::create(elseBuilder, innerLoc, args[0]);
299298

300299
return SmallVector<Value>{ifExpr.getResult(0)};
301300
});
@@ -356,8 +355,8 @@ class D2MLowerDMAReadToFullyIndexed : public OpRewritePattern<DMAReadOp> {
356355
coalescingFactor, shardVolume,
357356
[&](OpBuilder &b, Location l, SmallVector<Value> &remoteIdx,
358357
SmallVector<Value> &localIdx, size_t cf) {
359-
return b.create<DMAReadOp>(l, remoteMemref, remoteIdx, localMemref,
360-
localIdx, b.getI64IntegerAttr(cf));
358+
return DMAReadOp::create(b, l, remoteMemref, remoteIdx, localMemref,
359+
localIdx, b.getI64IntegerAttr(cf));
361360
});
362361

363362
rewriter.replaceOp(op, newTx);
@@ -398,16 +397,16 @@ class D2MLowerDMAWriteToFullyIndexed : public OpRewritePattern<DMAWriteOp> {
398397
size_t shardVolume = ttmlir::utils::volume(shardShape);
399398

400399
SmallVector<Value> localIndices;
401-
Value zero = rewriter.create<arith::ConstantOp>(
402-
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
400+
Value zero = arith::ConstantOp::create(
401+
rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
403402
for (size_t i = 0; i < shardShape.size(); ++i) {
404403
localIndices.push_back(zero);
405404
}
406405
localIndices =
407406
applyMap(rewriter, loc, localMemoryMap, localIndices, false);
408407

409-
Value newTx = rewriter.create<DMAWriteOp>(
410-
loc, localMemref, localIndices, dstMemref, localIndices,
408+
Value newTx = DMAWriteOp::create(
409+
rewriter, loc, localMemref, localIndices, dstMemref, localIndices,
411410
op.getMcastStartIndex(), op.getMcastShape(), shardVolume);
412411
rewriter.replaceOp(op, newTx);
413412
return success();
@@ -443,8 +442,8 @@ class D2MLowerDMAWriteToFullyIndexed : public OpRewritePattern<DMAWriteOp> {
443442
coalescingFactor, shardVolume,
444443
[&](OpBuilder &b, Location l, SmallVector<Value> &remoteIdx,
445444
SmallVector<Value> &localIdx, size_t cf) {
446-
return b.create<DMAWriteOp>(l, localMemref, localIdx, dstMemref,
447-
remoteIdx, cf);
445+
return DMAWriteOp::create(b, l, localMemref, localIdx, dstMemref,
446+
remoteIdx, cf);
448447
});
449448

450449
rewriter.replaceOp(op, newTx);

lib/Dialect/D2M/Transforms/LowerLoadStoreOpsToDMA.cpp

Lines changed: 26 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -135,37 +135,37 @@ class D2MLowerRemoteLoadRewritePattern : public OpRewritePattern<RemoteLoadOp> {
135135

136136
// Reserve CB unconditionally before branching - both sender and receiver
137137
// need to reserve to maintain proper circular buffer semantics.
138-
Value localMemref = rewriter.create<ReserveOp>(loc, cb).getResult();
138+
Value localMemref = ReserveOp::create(rewriter, loc, cb).getResult();
139139

140140
SmallVector<Value> gridIndices = remoteLoad.getIndices();
141141

142-
rewriter.create<scf::IfOp>(
143-
loc, isSender,
142+
scf::IfOp::create(
143+
rewriter, loc, isSender,
144144
[&](OpBuilder &builder, Location loc) {
145145
// Sender: shard-level DMA read from remote.
146-
Value dmaTx = builder.create<DMAReadOp>(loc, remoteMemref,
147-
gridIndices, localMemref);
148-
builder.create<DMAWaitOp>(loc, dmaTx);
146+
Value dmaTx = DMAReadOp::create(builder, loc, remoteMemref,
147+
gridIndices, localMemref);
148+
DMAWaitOp::create(builder, loc, dmaTx);
149149

150150
// Wait for all receivers to be ready (mcastVolume - 1, excluding
151151
// sender).
152-
builder.create<SemaphoreWaitOp>(loc, receiversReadySemaphore,
153-
numReceiversVal, zero);
152+
SemaphoreWaitOp::create(builder, loc, receiversReadySemaphore,
153+
numReceiversVal, zero);
154154

155155
// Perform shard-level multicast DMA write: from local CB to local CB
156156
// with multicast parameters. The multicast parameters specify that
157157
// the data should be sent to other cores. We use localMemref (from
158158
// ReserveOp) as both source and destination - this is the Producer
159159
// buffer that was just filled by the DMA read above.
160-
Value mcastTx = builder.create<DMAWriteOp>(
161-
loc, localMemref, localMemref, remoteLoad.getMcastStartIndex(),
162-
remoteLoad.getMcastShape());
160+
Value mcastTx = DMAWriteOp::create(
161+
builder, loc, localMemref, localMemref,
162+
remoteLoad.getMcastStartIndex(), remoteLoad.getMcastShape());
163163
DMAWaitOp::create(builder, loc, mcastTx);
164164

165165
// Signal receivers that sender is finished.
166-
builder.create<SemaphoreSetOp>(loc, senderFinishedSemaphore, one,
167-
remoteLoad.getMcastStartIndex(),
168-
remoteLoad.getMcastShape());
166+
SemaphoreSetOp::create(builder, loc, senderFinishedSemaphore, one,
167+
remoteLoad.getMcastStartIndex(),
168+
remoteLoad.getMcastShape());
169169

170170
scf::YieldOp::create(builder, loc);
171171
},
@@ -185,8 +185,8 @@ class D2MLowerRemoteLoadRewritePattern : public OpRewritePattern<RemoteLoadOp> {
185185
senderCoreIndex.push_back(mcastStartIndex[i]);
186186
} else {
187187
// Non-multicast dimension - use current core's position.
188-
Value currentCoreIdx = builder.create<CoreIndexOp>(
189-
loc, static_cast<int64_t>(i), gridMapping);
188+
Value currentCoreIdx = CoreIndexOp::create(
189+
builder, loc, static_cast<int64_t>(i), gridMapping);
190190
senderCoreIndex.push_back(currentCoreIdx);
191191
}
192192
}
@@ -238,15 +238,15 @@ class D2MLowerRemoteLoadRewritePattern : public OpRewritePattern<RemoteLoadOp> {
238238
Value remoteMemref = remoteLoad.getMemref();
239239
SmallVector<Value> gridIndices = remoteLoad.getIndices();
240240

241-
Value localMemref = rewriter.create<ReserveOp>(loc, cb).getResult();
242-
Value dmaTx =
243-
rewriter.create<DMAReadOp>(loc, remoteMemref, gridIndices, localMemref);
241+
Value localMemref = ReserveOp::create(rewriter, loc, cb).getResult();
242+
Value dmaTx = DMAReadOp::create(rewriter, loc, remoteMemref, gridIndices,
243+
localMemref);
244244

245245
rewriter.eraseOp(remoteLoad);
246246

247247
// Wait for DMA to complete.
248-
rewriter.create<DMAWaitOp>(loc, dmaTx);
249-
rewriter.create<PushOp>(loc, cb);
248+
DMAWaitOp::create(rewriter, loc, dmaTx);
249+
PushOp::create(rewriter, loc, cb);
250250
return success();
251251
}
252252
};
@@ -283,16 +283,16 @@ class D2MLowerRemoteStoreRewritePattern
283283
SmallVector<Value> gridIndices = remoteStore.getIndices();
284284

285285
// Wait on CB, emit shard-level dma_write, wait, pop
286-
Value localMemref = rewriter.create<WaitOp>(loc, cb).getResult();
287-
Value dmaTx = rewriter.create<DMAWriteOp>(loc, localMemref, remoteMemref,
288-
gridIndices);
286+
Value localMemref = WaitOp::create(rewriter, loc, cb).getResult();
287+
Value dmaTx = DMAWriteOp::create(rewriter, loc, localMemref, remoteMemref,
288+
gridIndices);
289289

290290
rewriter.eraseOp(remoteStore);
291291

292292
// Wait for DMA to complete.
293-
rewriter.create<DMAWaitOp>(loc, dmaTx);
293+
DMAWaitOp::create(rewriter, loc, dmaTx);
294294
// Pop the circular buffer to signal consumption.
295-
rewriter.create<PopOp>(loc, cb);
295+
PopOp::create(rewriter, loc, cb);
296296
return success();
297297
}
298298
};

lib/Dialect/D2M/Transforms/LowerToLayout.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -847,9 +847,9 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern<ToLayoutOp> {
847847
}
848848

849849
auto layout = mlir::dyn_cast<ttcore::MetalLayoutAttr>(type.getEncoding());
850-
auto emptyOp = rewriter.create<d2m::EmptyOp>(op.getLoc(), type.getShape(),
851-
type.getElementType(),
852-
layout, targetGridShape);
850+
auto emptyOp =
851+
d2m::EmptyOp::create(rewriter, op.getLoc(), type.getShape(),
852+
type.getElementType(), layout, targetGridShape);
853853
return emptyOp.getResult();
854854
};
855855

@@ -924,8 +924,8 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern<ToLayoutOp> {
924924
// buffers via createEmpty().
925925
auto layout = mlir::dyn_cast<ttcore::MetalLayoutAttr>(
926926
currentInfo.type.getEncoding());
927-
auto maskedEmptyOp = rewriter.create<d2m::EmptyOp>(
928-
op.getLoc(), currentInfo.type.getShape(),
927+
auto maskedEmptyOp = d2m::EmptyOp::create(
928+
rewriter, op.getLoc(), currentInfo.type.getShape(),
929929
currentInfo.type.getElementType(), layout, targetGridShape);
930930
auto maskedEmpty = maskedEmptyOp.getResult();
931931
currentValue =

lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ExplicateOperandBroadcastsRewritePattern.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,8 @@ LogicalResult ExplicateOperandBroadcastsRewritePattern::matchAndRewrite(
3232
auto broadcastDims = ttmlir::utils::getBroadcastDimensions<int64_t>(
3333
operandShape, resultShape);
3434
auto shapeAttr = ttnn::ShapeAttr::get(rewriter.getContext(), broadcastDims);
35-
auto repeatOp = rewriter.create<ttnn::RepeatOp>(
36-
srcOp->getLoc(), newOutputType, operand, shapeAttr);
35+
auto repeatOp = ttnn::RepeatOp::create(rewriter, srcOp->getLoc(),
36+
newOutputType, operand, shapeAttr);
3737

3838
rewriter.modifyOpInPlace(srcOp, [&]() { srcOp->setOperand(i, repeatOp); });
3939
hasChanged = true;

0 commit comments

Comments
 (0)