Skip to content

Commit 8cc488a

Browse files
committed
rebase
1 parent 4dcf08a commit 8cc488a

File tree

5 files changed

+79
-71
lines changed

5 files changed

+79
-71
lines changed

lib/Conversion/D2MToTTKernel/D2MToTTKernel.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,8 +1174,8 @@ class D2MTilizeUntilizeRewriter : public OpConversionPattern<ConcreteOp> {
11741174

11751175
if constexpr (std::is_same_v<BlockOp,
11761176
ttkernel::ExperimentalTilizeBlockOp>) {
1177-
rewriter.create<ttkernel::TilizeInitOp>(op->getLoc(), src, blockC, dst);
1178-
rewriter.create<BlockOp>(op->getLoc(), src, dst, blockR, blockC);
1177+
ttkernel::TilizeInitOp::create(rewriter, op->getLoc(), src, blockC, dst);
1178+
BlockOp::create(rewriter, op->getLoc(), src, dst, blockR, blockC);
11791179
} else if constexpr (std::is_same_v<
11801180
BlockOp,
11811181
ttkernel::ExperimentalPackUntilizeBlockOp>) {
@@ -1199,11 +1199,12 @@ class D2MTilizeUntilizeRewriter : public OpConversionPattern<ConcreteOp> {
11991199
auto totalColTilesAttr =
12001200
rewriter.getI32IntegerAttr(static_cast<int32_t>(totalColTiles));
12011201

1202-
rewriter.create<ttkernel::PackUntilizeInitOp>(
1203-
op->getLoc(), src, dst, colsPerDstPassAttr, totalColTilesAttr);
1204-
rewriter.create<BlockOp>(op->getLoc(), src, dst, blockR, blockC,
1205-
colsPerDstPassAttr, totalColTilesAttr);
1206-
rewriter.create<ttkernel::PackUntilizeUninitOp>(op->getLoc(), dst);
1202+
ttkernel::PackUntilizeInitOp::create(rewriter, op->getLoc(), src, dst,
1203+
colsPerDstPassAttr,
1204+
totalColTilesAttr);
1205+
BlockOp::create(rewriter, op->getLoc(), src, dst, blockR, blockC,
1206+
colsPerDstPassAttr, totalColTilesAttr);
1207+
ttkernel::PackUntilizeUninitOp::create(rewriter, op->getLoc(), dst);
12071208
} else {
12081209
llvm_unreachable("unsupported tilize/untilize op");
12091210
}

lib/Conversion/D2MToTTNN/D2MToTTNN.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,8 +750,8 @@ class MemrefAllocRewriter : public OpConversionPattern<memref::AllocOp> {
750750
auto memcfg = ttnn::MemoryConfigAttr::get(emptyLayoutAttr,
751751
deviceAttr.getWorkerGrid());
752752

753-
auto emptyOp = rewriter.create<ttnn::EmptyOp>(
754-
op.getLoc(), emptyTensorType, device,
753+
auto emptyOp = ttnn::EmptyOp::create(
754+
rewriter, op.getLoc(), emptyTensorType, device,
755755
ttnn::ShapeAttr::get(ctx, emptyTensorType.getShape()),
756756
ttcore::DataTypeAttr::get(ctx, emptyLayoutAttr.getDataType()),
757757
ttnn::LayoutAttr::get(ctx, emptyLayoutAttr.getLayout()), memcfg);

lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,8 @@ class ShardyToTTIRManualComputationOpConversionPattern
228228
// Create a new mesh shard op.
229229
auto outputType = mlir::cast<mlir::RankedTensorType>(
230230
getTypeConverter()->convertType(opResult.getType()));
231-
auto meshShardOp = rewriter.create<mlir::tt::ttir::MeshShardOp>(
232-
loc, outputType, returnOperand.get(),
231+
auto meshShardOp = mlir::tt::ttir::MeshShardOp::create(
232+
rewriter, loc, outputType, returnOperand.get(),
233233
shardyMeshSharding->getShardType(),
234234
shardyMeshSharding->getShardDirection(),
235235
shardyMeshSharding->getShardShape(),

lib/Dialect/D2M/Transforms/GridSelection.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,9 @@ static void insertViewForTTNNDRAMTensor(Value operand,
552552
fakeShardedShape, metalTensor.getElementType(), viewOutputLayout);
553553

554554
builder.setInsertionPointAfter(castOp);
555-
auto viewOp = builder.create<d2m::ViewLayoutOp>(
556-
castOp.getLoc(), viewOutputTensor, castOp.getResult(),
557-
AffineMapAttr::get(reblockMap));
555+
auto viewOp = d2m::ViewLayoutOp::create(builder, castOp.getLoc(),
556+
viewOutputTensor, castOp.getResult(),
557+
AffineMapAttr::get(reblockMap));
558558
castOp.getResult().replaceAllUsesExcept(viewOp.getResult(), viewOp);
559559
}
560560

@@ -580,18 +580,20 @@ static void optimizeTTNNMetalLayoutCastOpGrid(
580580

581581
builder.setInsertionPointAfter(castOp);
582582

583-
auto newViewLayoutOp = builder.create<d2m::ViewLayoutOp>(
584-
castOp.getLoc(), newTensorType, castOp.getResult(), gridRemapping);
583+
auto newViewLayoutOp =
584+
d2m::ViewLayoutOp::create(builder, castOp.getLoc(), newTensorType,
585+
castOp.getResult(), gridRemapping);
585586

586587
// Reblock it back to original shape to preserve IR correctness.
587588
auto viewOutputType = utils::reblockTensor(
588589
newTensorType, outputLayout.getGridShape(outputType));
589590
auto reblockMap = ttmlir::utils::calculateReblockMap(
590591
newTensorType.getShape(), viewOutputType.getShape(),
591592
builder.getContext());
592-
auto revertingView = builder.create<d2m::ViewLayoutOp>(
593-
castOp.getLoc(), viewOutputType, newViewLayoutOp.getResult(), reblockMap,
594-
/*reinterpretLayout=*/false);
593+
auto revertingView =
594+
d2m::ViewLayoutOp::create(builder, castOp.getLoc(), viewOutputType,
595+
newViewLayoutOp.getResult(), reblockMap,
596+
/*reinterpretLayout=*/false);
595597

596598
castOp.getResult().replaceAllUsesExcept(revertingView.getResult(),
597599
newViewLayoutOp);

lib/Target/TTKernel/TTKernelToCpp.cpp

Lines changed: 57 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -103,57 +103,62 @@ class ScopedModuleHelper {
103103
emitc::IncludeOp::create(
104104
*builder, loc, "api/compute/eltwise_unary/sfpu_split_includes.h",
105105
/*isStandard=*/false);
106-
builder->create<emitc::IncludeOp>(loc,
107-
"api/compute/eltwise_unary/recip.h",
108-
/*isStandard=*/false);
109-
builder->create<emitc::IncludeOp>(loc, "api/compute/eltwise_unary/fill.h",
110-
/*isStandard=*/false);
111-
builder->create<emitc::IncludeOp>(loc,
112-
"api/compute/eltwise_unary/negative.h",
113-
/*isStandard=*/false);
114-
builder->create<emitc::IncludeOp>(loc, "api/compute/eltwise_unary/sqrt.h",
115-
/*isStandard=*/false);
116-
builder->create<emitc::IncludeOp>(loc,
117-
"api/compute/eltwise_unary/rounding.h",
118-
/*isStandard=*/false);
119-
builder->create<emitc::IncludeOp>(
120-
loc, "api/compute/eltwise_unary/trigonometry.h",
121-
/*isStandard=*/false);
122-
builder->create<emitc::IncludeOp>(loc, "api/compute/eltwise_unary/gelu.h",
123-
/*isStandard=*/false);
124-
builder->create<emitc::IncludeOp>(loc,
125-
"api/compute/eltwise_unary/erf_erfc.h",
126-
/*isStandard=*/false);
127-
builder->create<emitc::IncludeOp>(
128-
loc, "api/compute/eltwise_unary/logical_not.h",
129-
/*isStandard=*/false);
130-
builder->create<emitc::IncludeOp>(loc, "api/compute/eltwise_unary/comp.h",
131-
/*isStandard=*/false);
132-
builder->create<emitc::IncludeOp>(loc,
133-
"api/compute/eltwise_unary/rsqrt.h",
134-
/*isStandard=*/false);
135-
builder->create<emitc::IncludeOp>(loc,
136-
"api/compute/eltwise_unary/typecast.h",
137-
/*isStandard=*/false);
138-
builder->create<emitc::IncludeOp>(loc,
139-
"api/compute/binary_bitwise_sfpu.h",
140-
/*isStandard=*/false);
141-
builder->create<emitc::IncludeOp>(
142-
loc, "api/compute/eltwise_unary/bitwise_not.h",
143-
/*isStandard=*/false);
144-
builder->create<emitc::IncludeOp>(loc, "api/compute/eltwise_unary/relu.h",
145-
/*isStandard=*/false);
146-
builder->create<emitc::IncludeOp>(
147-
loc, "api/compute/eltwise_unary/binop_with_scalar.h",
148-
/*isStandard=*/false);
149-
builder->create<emitc::IncludeOp>(loc,
150-
"api/compute/eltwise_unary/where.h",
151-
/*isStandard=*/false);
152-
builder->create<emitc::IncludeOp>(loc,
153-
"api/compute/eltwise_unary/clamp.h",
154-
/*isStandard=*/false);
155-
builder->create<emitc::IncludeOp>(loc, "api/compute/pack_untilize.h",
156-
/*isStandard=*/false);
106+
emitc::IncludeOp::create(*builder, loc,
107+
"api/compute/eltwise_unary/recip.h",
108+
/*isStandard=*/false);
109+
emitc::IncludeOp::create(*builder, loc,
110+
"api/compute/eltwise_unary/fill.h",
111+
/*isStandard=*/false);
112+
emitc::IncludeOp::create(*builder, loc,
113+
"api/compute/eltwise_unary/negative.h",
114+
/*isStandard=*/false);
115+
emitc::IncludeOp::create(*builder, loc,
116+
"api/compute/eltwise_unary/sqrt.h",
117+
/*isStandard=*/false);
118+
emitc::IncludeOp::create(*builder, loc,
119+
"api/compute/eltwise_unary/rounding.h",
120+
/*isStandard=*/false);
121+
emitc::IncludeOp::create(*builder, loc,
122+
"api/compute/eltwise_unary/trigonometry.h",
123+
/*isStandard=*/false);
124+
emitc::IncludeOp::create(*builder, loc,
125+
"api/compute/eltwise_unary/gelu.h",
126+
/*isStandard=*/false);
127+
emitc::IncludeOp::create(*builder, loc,
128+
"api/compute/eltwise_unary/erf_erfc.h",
129+
/*isStandard=*/false);
130+
emitc::IncludeOp::create(*builder, loc,
131+
"api/compute/eltwise_unary/logical_not.h",
132+
/*isStandard=*/false);
133+
emitc::IncludeOp::create(*builder, loc,
134+
"api/compute/eltwise_unary/comp.h",
135+
/*isStandard=*/false);
136+
emitc::IncludeOp::create(*builder, loc,
137+
"api/compute/eltwise_unary/rsqrt.h",
138+
/*isStandard=*/false);
139+
emitc::IncludeOp::create(*builder, loc,
140+
"api/compute/eltwise_unary/typecast.h",
141+
/*isStandard=*/false);
142+
emitc::IncludeOp::create(*builder, loc,
143+
"api/compute/binary_bitwise_sfpu.h",
144+
/*isStandard=*/false);
145+
emitc::IncludeOp::create(*builder, loc,
146+
"api/compute/eltwise_unary/bitwise_not.h",
147+
/*isStandard=*/false);
148+
emitc::IncludeOp::create(*builder, loc,
149+
"api/compute/eltwise_unary/relu.h",
150+
/*isStandard=*/false);
151+
emitc::IncludeOp::create(*builder, loc,
152+
"api/compute/eltwise_unary/binop_with_scalar.h",
153+
/*isStandard=*/false);
154+
emitc::IncludeOp::create(*builder, loc,
155+
"api/compute/eltwise_unary/where.h",
156+
/*isStandard=*/false);
157+
emitc::IncludeOp::create(*builder, loc,
158+
"api/compute/eltwise_unary/clamp.h",
159+
/*isStandard=*/false);
160+
emitc::IncludeOp::create(*builder, loc, "api/compute/pack_untilize.h",
161+
/*isStandard=*/false);
157162
// Helper for float-to-uint32 bit reinterpretation (used by scalar tile
158163
// ops).
159164
emitc::VerbatimOp::create(
@@ -261,7 +266,7 @@ void dprint(Arg &&arg, ArgV&&... argv) {
261266
auto experimentalPackUntilizeLLKs =
262267
StringRef(experimental_pack_untilize_llks_generated,
263268
experimental_pack_untilize_llks_generated_len);
264-
builder->create<emitc::VerbatimOp>(loc, experimentalPackUntilizeLLKs);
269+
emitc::VerbatimOp::create(*builder, loc, experimentalPackUntilizeLLKs);
265270
}
266271

267272
if (hasCall("experimental::get_noc_multicast_addr")) {

0 commit comments

Comments
 (0)