diff --git a/.github/get-docker-tag.sh b/.github/get-docker-tag.sh index 3e2ca8b2422..92b42e09b3f 100755 --- a/.github/get-docker-tag.sh +++ b/.github/get-docker-tag.sh @@ -5,6 +5,6 @@ # Calculate hash from the following files. This hash is used to tag the docker images. # Any change in these files will result in a new docker image build -DOCKERFILE_HASH_FILES=".github/Dockerfile.base .github/Dockerfile.ci .github/Dockerfile.ird .github/Dockerfile.cibuildwheel env/CMakeLists.txt env/init_venv.sh env/build-requirements.txt env/ttnn-requirements.txt env/patches/shardy.patch env/patches/shardy_mpmd_pybinds.patch test/python/requirements.txt env/install-tt-triage.sh" +DOCKERFILE_HASH_FILES=".github/Dockerfile.base .github/Dockerfile.ci .github/Dockerfile.ird .github/Dockerfile.cibuildwheel env/CMakeLists.txt env/init_venv.sh env/build-requirements.txt env/ttnn-requirements.txt env/patches/shardy.patch test/python/requirements.txt env/install-tt-triage.sh" DOCKERFILE_HASH=$(sha256sum $DOCKERFILE_HASH_FILES | sha256sum | cut -d ' ' -f 1) echo dt-$DOCKERFILE_HASH diff --git a/env/CMakeLists.txt b/env/CMakeLists.txt index 5a0fd1537a1..863d83916d0 100644 --- a/env/CMakeLists.txt +++ b/env/CMakeLists.txt @@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.20.0) project(ttmlir-toolchain LANGUAGES CXX C) set(FLATBUFFERS_VERSION "fb9afbafc7dfe226b9db54d4923bfb8839635274") -set(LLVM_PROJECT_VERSION "4efe170d858eb54432f520abb4e7f0086236748b") -set(STABLEHLO_VERSION "0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d") -set(SHARDY_VERSION "edfd6730ddfc39da5fbea8b6b202357fdf1cdb90") +set(LLVM_PROJECT_VERSION "1053047a4be7d1fece3adaf5e7597f838058c947") +set(STABLEHLO_VERSION "43550117525e77084ac1f83ba50febc0688f7958") +set(SHARDY_VERSION "4069e470c94d7b08f129ef1607aa2be8f4c06b53") set(LLVM_BUILD_TYPE MinSizeRel CACHE STRING "Build type for LLVM") include(ExternalProject) @@ -52,7 +52,7 @@ if(TTMLIR_BUILD_LLVM) llvm-project # Super hacky way to install the python dependencies before the build # 
Sync nanobind with tt-metal's pinned version to avoid ODR violations - PATCH_COMMAND bash -c "source ${CMAKE_CURRENT_SOURCE_DIR}/activate && pip install -r mlir/python/requirements.txt && pip install --force-reinstall 'nanobind==2.10.2' && git config user.email \"tt-mlir@tenstorrent.com\" && git config user.name \"tenstorrent\" && git apply --index \"${CMAKE_CURRENT_LIST_DIR}/patches/affine-allow-symbol-vars.patch\" && git commit -m \"tt-mlir related patch\"" + PATCH_COMMAND bash -c "source ${CMAKE_CURRENT_SOURCE_DIR}/activate && pip install -r mlir/python/requirements.txt && pip install --force-reinstall 'nanobind==2.10.2' && git config user.email \"tt-mlir@tenstorrent.com\" && git config user.name \"tenstorrent\" " CMAKE_GENERATOR Ninja CMAKE_ARGS -DCMAKE_BUILD_TYPE=${LLVM_BUILD_TYPE} @@ -97,7 +97,7 @@ ExternalProject_Add(shardy CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND "" - PATCH_COMMAND git config user.email "tt-mlir@tenstorrent.com" && git config user.name "tenstorrent" && git apply --index "${CMAKE_CURRENT_LIST_DIR}/patches/shardy.patch" && git commit -m "tt-mlir related patch" && git apply --index "${CMAKE_CURRENT_LIST_DIR}/patches/shardy_mpmd_pybinds.patch" && git commit -m "tt-mlir related patch" + PATCH_COMMAND git config user.email "tt-mlir@tenstorrent.com" && git config user.name "tenstorrent" && git apply --index "${CMAKE_CURRENT_LIST_DIR}/patches/shardy.patch" && git commit -m "tt-mlir shardy.patch" ) if(TTMLIR_BUILD_LLVM) diff --git a/env/patches/shardy.patch b/env/patches/shardy.patch index c118388ca24..4d1eb4fd236 100644 --- a/env/patches/shardy.patch +++ b/env/patches/shardy.patch @@ -270,7 +270,7 @@ new file mode 100755 index 0000000..f23330d --- /dev/null +++ b/shardy/dialect/sdy/transforms/export/CMakeLists.txt -@@ -0,0 +1,94 @@ +@@ -0,0 +1,96 @@ +# Shardy MLIR Transform Export Passes and Pipeline + +set(LLVM_TARGET_DEFINITIONS passes.td) @@ -318,6 +318,8 @@ index 0000000..f23330d + sharding_constraint_to_reshard.cc + 
sink_data_flow_edges.cc + update_non_divisible_input_output_shardings.cc ++ convert_global_to_local.cc ++ remove_sub_axes_in_input_output_shardings.cc + + DEPENDS + SdyExplicitReshardsUtil @@ -635,106 +637,6 @@ index 0000000..dc33501 + SdyDialect + SdyTransformsPropagationShardingProjection +) -diff --git a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc -index a73917b..78cc827 100644 ---- a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc -+++ b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc -@@ -346,6 +346,7 @@ void saveShardingOriginsOnModule( - // the case for the target of the edge, because if the source appears multiple - // times, then it's because it effects multiple other operands/results in the - // op. -+[[maybe_unused]] - bool insertSeenValue(Operation* op, const PropagationEdge& edge, - llvm::SmallDenseSet& seenValues) { - EdgeNode target = edge.target; -diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -index ab93067..b181797 100644 ---- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc -@@ -187,12 +187,12 @@ void addGatherScatterFactors( - - // We add factors for all collapsed slice dimensions. - for (int64_t collapsedSliceDim : collapsedSliceDims) { -- // To keep the operand dim sharded for gather, we need an all-reduce on the -- // result. -+ // For gather: use kNeedReplication to all-gather operand before gather, -+ // instead of kReduction which would insert all-reduce after gather. - addUnblockedFactorFn( - collapsedSliceDim, /*indicesDim=*/kNullDim, - /*slicesDim=*/kNullDim, inputType.getDimSize(collapsedSliceDim), -- isScatter ? FactorType::kPassThrough : FactorType::kReduction); -+ isScatter ? 
FactorType::kPassThrough : FactorType::kNeedReplication); - } - - // Add a factor for the index-vector-dim, if it's present. -@@ -303,6 +303,37 @@ OpShardingRuleAttr createOpShardingRule(Operation* op, - } - return builder.build(); - }) -+ .Case( -+ [conservativePropagation](stablehlo::BatchNormInferenceOp bn) { -+ auto inTy = llvm::cast(bn.getOperand().getType()); -+ auto outTy = llvm::cast(bn.getResult().getType()); -+ -+ OpShardingRuleBuilder builder(bn); -+ -+ const int64_t numOperands = static_cast(bn->getNumOperands()); -+ llvm::SmallVector opDims(numOperands, kNullDim); -+ -+ for (auto [dU, dimSize] : llvm::enumerate(inTy.getShape())) { -+ const int64_t d = static_cast(dU); -+ std::fill(opDims.begin(), opDims.end(), kNullDim); -+ opDims[0] = d; -+ builder.addFactor(opDims, d, dimSize); -+ } -+ -+ const int64_t featAxis = static_cast(bn.getFeatureIndex()); -+ const int64_t C = outTy.getDimSize(featAxis); -+ -+ for (int64_t paramIdx : {1LL, 2LL, 3LL, 4LL}) { -+ std::fill(opDims.begin(), opDims.end(), kNullDim); -+ opDims[paramIdx] = 0; -+ auto factorType = conservativePropagation ? FactorType::kNeedReplication -+ : FactorType::kPassThrough; -+ builder.addFactor(opDims, kNullDim, C, -+ factorType, true); -+ } -+ -+ return builder.build(); -+ }) - .Case( - [](stablehlo::BitcastConvertOp bitcastConvert) { - ArrayRef inShape = -@@ -685,6 +716,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op, - /*isBlocked=*/usedByRngBitGenerator) - .build(); - } -+ // Check if the custom call implements the ShardingRuleOpInterface. -+ if (auto shardingRuleOp = -+ llvm::dyn_cast(customCall.getOperation())) { -+ return shardingRuleOp.getShardingRule(); -+ } -+ - // TODO(b/327191011): output unregistered op stats instead. 
- static llvm::once_flag onceFlag; - emitOpWarningOnce( -@@ -1093,6 +1130,16 @@ OpShardingRuleAttr createOpShardingRule(Operation* op, - return builder.build(); - }) - .Case([](stablehlo::ScatterOp scatter) { -+ // Check if the scatter op implements the ShardingRuleOpInterface. -+ if (auto shardingRuleOp = -+ llvm::dyn_cast(scatter.getOperation())) { -+ // Try to get custom rule - if it returns non-null, use it. -+ if (auto customRule = shardingRuleOp.getShardingRule()) { -+ return customRule; -+ } -+ // If custom rule returns null, fall through to default. -+ } -+ - OpShardingRuleBuilder builder(scatter); - - // Since all inputs and results have compatible shapes, we can look at diff --git a/shardy/integrations/c/CMakeLists.txt b/shardy/integrations/c/CMakeLists.txt new file mode 100644 index 0000000..fdd50c4 @@ -1191,10 +1093,10 @@ index 0000000..30a5cf9 + +#endif diff --git a/shardy/integrations/python/ir/sdy_module.cc b/shardy/integrations/python/ir/sdy_module.cc -index da451fa..44c0ea2 100644 +index cd7fdc8..1b5aa5b 100644 --- a/shardy/integrations/python/ir/sdy_module.cc +++ b/shardy/integrations/python/ir/sdy_module.cc -@@ -109,7 +109,15 @@ NB_MODULE(_sdy, m) { +@@ -110,7 +110,15 @@ NB_MODULE(_sdy, m) { }) .def_property_readonly("size", [](MlirAttribute self) { return sdyMeshAxisAttrGetSize(self); @@ -1211,7 +1113,7 @@ index da451fa..44c0ea2 100644 mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "MeshAttr", sdyAttributeIsAMeshAttr) -@@ -133,7 +141,15 @@ NB_MODULE(_sdy, m) { +@@ -134,7 +142,15 @@ NB_MODULE(_sdy, m) { .def_property_readonly("axes", [](MlirAttribute self) { return propertyVector(self, sdyMeshAttrGetAxesSize, sdyMeshAttrGetAxesElem); @@ -1228,7 +1130,7 @@ index da451fa..44c0ea2 100644 mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "SubAxisInfoAttr", sdyAttributeIsASubAxisInfoAttr) -@@ -150,7 +166,15 @@ NB_MODULE(_sdy, m) { +@@ -151,7 +167,15 @@ NB_MODULE(_sdy, m) { [](MlirAttribute self) { return 
sdySubAxisInfoAttrGetPreSize(self); }) .def_property_readonly("size", [](MlirAttribute self) { return sdySubAxisInfoAttrGetSize(self); @@ -1245,7 +1147,7 @@ index da451fa..44c0ea2 100644 mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "AxisRefAttr", sdyAttributeIsAnAxisRefAttr) -@@ -175,7 +199,15 @@ NB_MODULE(_sdy, m) { +@@ -176,7 +200,15 @@ NB_MODULE(_sdy, m) { MlirAttribute subAxisInfo = sdyAxisRefAttrGetSubAxisInfo(self); return subAxisInfo.ptr == nullptr ? std::nullopt : std::optional(subAxisInfo); @@ -1262,7 +1164,7 @@ index da451fa..44c0ea2 100644 mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "DimensionShardingAttr", sdyAttributeIsADimensionShardingAttr) -@@ -205,7 +237,15 @@ NB_MODULE(_sdy, m) { +@@ -206,7 +238,15 @@ NB_MODULE(_sdy, m) { .def_property_readonly("priority", [](MlirAttribute self) { int64_t priority = sdyDimensionShardingAttrGetPriority(self); return priority == -1 ? std::nullopt : std::optional(priority); @@ -1279,7 +1181,7 @@ index da451fa..44c0ea2 100644 mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "TensorShardingAttr", sdyAttributeIsATensorShardingAttr) -@@ -251,7 +291,15 @@ NB_MODULE(_sdy, m) { +@@ -252,7 +292,15 @@ NB_MODULE(_sdy, m) { return propertyVector( self, sdyTensorShardingAttrGetUnreducedAxesSize, sdyTensorShardingAttrGetUnreducedAxesElem); @@ -1296,7 +1198,7 @@ index da451fa..44c0ea2 100644 mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "TensorShardingPerValueAttr", -@@ -270,7 +318,15 @@ NB_MODULE(_sdy, m) { +@@ -271,7 +319,15 @@ NB_MODULE(_sdy, m) { return propertyVector( self, sdyTensorShardingPerValueAttrGetShardingsSize, sdyTensorShardingPerValueAttrGetShardingsElem); @@ -1313,7 +1215,7 @@ index da451fa..44c0ea2 100644 mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "DimMappingAttr", sdyAttributeIsADimMappingAttr) -@@ -288,7 +344,15 @@ NB_MODULE(_sdy, m) { +@@ -289,7 +345,15 @@ NB_MODULE(_sdy, m) { return propertyVector(self, 
sdyDimMappingAttrGetFactorIndicesSize, sdyDimMappingAttrGetFactorIndicesElem); @@ -1330,7 +1232,7 @@ index da451fa..44c0ea2 100644 mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "TensorMappingAttr", sdyAttributeIsATensorMappingAttr) -@@ -310,7 +374,15 @@ NB_MODULE(_sdy, m) { +@@ -311,7 +375,15 @@ NB_MODULE(_sdy, m) { }) .def_property_readonly("rank", [](MlirAttribute self) { return sdyTensorMappingAttrGetRank(self); @@ -1347,7 +1249,7 @@ index da451fa..44c0ea2 100644 mlir::python::nanobind_adaptors::mlir_attribute_subclass( m, "OpShardingRuleAttr", sdyAttributeIsAOpShardingRuleAttr) -@@ -394,6 +466,14 @@ NB_MODULE(_sdy, m) { +@@ -395,6 +467,14 @@ NB_MODULE(_sdy, m) { return propertyVector( self, sdyOpShardingRuleAttrGetBlockedPropagationFactorsSize, sdyOpShardingRuleAttrGetBlockedPropagationFactorsElem); @@ -1362,7 +1264,7 @@ index da451fa..44c0ea2 100644 }); mlir::python::nanobind_adaptors::mlir_attribute_subclass( -@@ -417,7 +497,67 @@ NB_MODULE(_sdy, m) { +@@ -418,7 +498,67 @@ NB_MODULE(_sdy, m) { }) .def("__len__", [](MlirAttribute& self) { return sdyManualAxesAttrGetAxesSize(self); diff --git a/env/patches/shardy_mpmd_pybinds.patch b/env/patches/shardy_mpmd_pybinds.patch deleted file mode 100644 index 7f4132af4ad..00000000000 --- a/env/patches/shardy_mpmd_pybinds.patch +++ /dev/null @@ -1,3772 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index ec8d310..2405027 100755 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -53,5 +53,7 @@ add_compile_options(-Wno-deprecated-declarations -Wno-unused-but-set-variable -W - add_subdirectory(shardy/common) - add_subdirectory(shardy/dialect/sdy/ir) - add_subdirectory(shardy/dialect/sdy/transforms) -+add_subdirectory(shardy/dialect/mpmd/ir) -+add_subdirectory(shardy/dialect/mpmd/transforms) - add_subdirectory(shardy/integrations/python/ir) - add_subdirectory(shardy/integrations/c) -\ No newline at end of file -diff --git a/shardy/dialect/mpmd/ir/CMakeLists.txt b/shardy/dialect/mpmd/ir/CMakeLists.txt 
-new file mode 100644 -index 0000000..e19a236 ---- /dev/null -+++ b/shardy/dialect/mpmd/ir/CMakeLists.txt -@@ -0,0 +1,104 @@ -+# Shardy MLIR MPMD dialect. -+ -+set(LLVM_TARGET_DEFINITIONS dialect.td) -+mlir_tablegen(dialect.h.inc -gen-dialect-decls -dialect=mpmd) -+mlir_tablegen(dialect.cc.inc -gen-dialect-defs -dialect=mpmd) -+add_public_tablegen_target(MpmdDialectIncGen) -+add_dependencies(mlir-headers MpmdDialectIncGen) -+add_mlir_doc(dialect MpmdDialect src/autogen/md/Dialect/ -gen-dialect-doc) -+ -+set(LLVM_TARGET_DEFINITIONS canonicalization.td) -+mlir_tablegen(canonicalization.cc.inc -gen-rewriters) -+add_public_tablegen_target(MpmdCanonicalizationIncGen) -+add_dependencies(mlir-headers MpmdCanonicalizationIncGen) -+ -+set(LLVM_TARGET_DEFINITIONS ops.td) -+mlir_tablegen(ops.h.inc -gen-op-decls) -+mlir_tablegen(ops.cc.inc -gen-op-defs) -+add_public_tablegen_target(MpmdOpsIncGen) -+add_dependencies(mlir-headers MpmdOpsIncGen) -+ -+set(LLVM_TARGET_DEFINITIONS types.td) -+mlir_tablegen(types.h.inc -gen-typedef-decls) -+mlir_tablegen(types.cc.inc -gen-typedef-defs) -+add_public_tablegen_target(MpmdTypesIncGen) -+add_dependencies(mlir-headers MpmdTypesIncGen) -+ -+set(LLVM_TARGET_DEFINITIONS attrs.td) -+mlir_tablegen(attrs.h.inc -gen-attrdef-decls) -+mlir_tablegen(attrs.cc.inc -gen-attrdef-defs) -+add_public_tablegen_target(MpmdAttrsIncGen) -+add_dependencies(mlir-headers MpmdAttrsIncGen) -+ -+set(LLVM_TARGET_DEFINITIONS enums.td) -+mlir_tablegen(enums.h.inc -gen-enum-decls) -+mlir_tablegen(enums.cc.inc -gen-enum-defs) -+add_public_tablegen_target(MpmdEnumsIncGen) -+add_dependencies(mlir-headers MpmdEnumsIncGen) -+ -+add_mlir_dialect_library(MpmdDialect -+ dialect.cc -+ utils.cc -+ -+ DEPENDS -+ MpmdDialectIncGen -+ MpmdOpsIncGen -+ MpmdAttrsIncGen -+ MpmdEnumsIncGen -+ MpmdTypesIncGen -+ MpmdCanonicalizationIncGen -+ -+ LINK_LIBS PUBLIC -+ LLVMSupport -+ MLIRBytecodeOpInterface -+ MLIRFuncDialect -+ MLIRIR -+ MLIRInferTypeOpInterface -+ MLIRTransformUtils -+ 
MLIRShapeDialect -+ MLIRSideEffectInterfaces -+ MLIRSupport -+ StablehloAssemblyFormat -+ StablehloBase -+ StablehloOps -+ StablehloTypeInference -+) -+ -+target_include_directories(MpmdDialect INTERFACE -+ $ -+ $ -+) -+ -+add_mlir_dialect_library(MpmdRegister -+ register.cc -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ MLIRFuncDialect -+ MLIRFuncAllExtensions -+ MLIRIR -+ StablehloOps -+) -+ -+target_include_directories(MpmdRegister INTERFACE -+ $ -+ $ -+) -+ -+add_mlir_dialect_library(MpmdFragmentExecutionRules -+ fragment_execution_rules.cc -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ MpmdTransformsCommonUtils -+ LLVMSupport -+ MLIRIR -+ MLIRSupport -+ MLIRParser -+) -+ -+target_include_directories(MpmdFragmentExecutionRules INTERFACE -+ $ -+ $ -+) -+ -diff --git a/shardy/dialect/mpmd/ir/dialect.cc b/shardy/dialect/mpmd/ir/dialect.cc -index 5fc8c90..e211b65 100644 ---- a/shardy/dialect/mpmd/ir/dialect.cc -+++ b/shardy/dialect/mpmd/ir/dialect.cc -@@ -887,9 +887,9 @@ FragmentOp CreateMeshFragmentWithBody( - // Only user defined fragments can be assigned to a stage and any fragment - // created by the compiler is considered to be an inferred fragment. - // Therefore, the created fragment isn't assigned to a stage. 
-- FragmentOp fragment_op = FragmentOp::create(builder, loc, result_types, -- tensors, origin_attr, mesh_name, -- /*stage_id=*/IntegerAttr()); -+ FragmentOp fragment_op = builder.create( -+ loc, result_types, tensors, origin_attr, mesh_name, -+ /*stage_id=*/IntegerAttr()); - Block& fragment_block = fragment_op.getRegion().emplaceBlock(); - sdy::MeshAttr mesh_attr = GetMeshOrFail(fragment_op, mesh_name); - -@@ -901,8 +901,7 @@ FragmentOp CreateMeshFragmentWithBody( - fragment_block.args_end()); - - OpBuilder block_builder = OpBuilder::atBlockBegin(&fragment_block); -- ReturnOp::create(block_builder, loc, -- body_populator(arguments, block_builder)); -+ block_builder.create(loc, body_populator(arguments, block_builder)); - return fragment_op; - } - } // namespace -@@ -1354,9 +1353,9 @@ ForOp ForOp::create(Location loc, ValueRange tensors, uint32_t iterations, - OpBuilder& builder, ForOpBodyPopulator body_populator, - uint32_t unroll_factor) { - TypeRange result_types = tensors.getTypes(); -- auto op = ForOp::create( -- builder, loc, result_types, tensors, iterations, -- unroll_factor == 1 ? nullptr : builder.getUI32IntegerAttr(unroll_factor)); -+ auto op = builder.create( -+ loc, result_types, tensors, iterations, -+ unroll_factor == 1 ? 
nullptr : builder.getUI32IntegerAttr(unroll_factor)); - - Block& block = op.getRegion().emplaceBlock(); - for (Value operand : tensors) { -@@ -1369,8 +1368,8 @@ ForOp ForOp::create(Location loc, ValueRange tensors, uint32_t iterations, - ArrayRef args(block.args_begin(), block.args_end()); - - OpBuilder block_builder = OpBuilder::atBlockBegin(&block); -- ReturnOp::create( -- block_builder, loc, -+ block_builder.create( -+ loc, - body_populator(args.drop_back(), /*index=*/args.back(), block_builder)); - return op; - } -diff --git a/shardy/dialect/mpmd/transforms/CMakeLists.txt b/shardy/dialect/mpmd/transforms/CMakeLists.txt -new file mode 100644 -index 0000000..0728af8 ---- /dev/null -+++ b/shardy/dialect/mpmd/transforms/CMakeLists.txt -@@ -0,0 +1,31 @@ -+# Shardy MLIR MPMD Transforms Passes -+ -+add_subdirectory(common) -+add_subdirectory(export) -+add_subdirectory(import) -+add_subdirectory(optimize) -+add_subdirectory(sharding_propagation) -+ -+add_mlir_library(MpmdTransformsPasses -+ passes.cc -+ -+ DEPENDS -+ MpmdTransformsCommonPasses -+ MpmdTransformsExportPasses -+ MpmdTransformsImportPasses -+ MpmdTransformsOptimizePasses -+ MpmdTransformsShardingPropagationPasses -+ -+ LINK_LIBS PUBLIC -+ MLIRPass -+ MpmdTransformsCommonPasses -+ MpmdTransformsExportPasses -+ MpmdTransformsImportPasses -+ MpmdTransformsOptimizePasses -+ MpmdTransformsShardingPropagationPasses -+) -+ -+target_include_directories(MpmdTransformsPasses INTERFACE -+ $ -+ $ -+) -diff --git a/shardy/dialect/mpmd/transforms/common/CMakeLists.txt b/shardy/dialect/mpmd/transforms/common/CMakeLists.txt -new file mode 100644 -index 0000000..f717544 ---- /dev/null -+++ b/shardy/dialect/mpmd/transforms/common/CMakeLists.txt -@@ -0,0 +1,88 @@ -+# Shardy MLIR MPMD Transforms Common -+ -+set(LLVM_TARGET_DEFINITIONS passes.td) -+mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdCommon) -+add_public_tablegen_target(MpmdTransformsCommonPassesIncGen) -+add_dependencies(mlir-headers 
MpmdTransformsCommonPassesIncGen) -+ -+add_mlir_library(MpmdTransformsCommonUtils -+ utils.cc -+ -+ DEPENDS -+ MpmdDialect -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ LLVMSupport -+ MLIRFuncDialect -+ MLIRIR -+ MLIRTransformUtils -+ MLIRSupport -+ MLIRPass -+ MLIRTransforms -+) -+ -+add_mlir_library(MpmdTransformsCommonSimplifyRegionOpBase -+ simplify_region_op_base.cc -+ -+ DEPENDS -+ MpmdTransformsCommonUtils -+ -+ LINK_LIBS PUBLIC -+ MpmdTransformsCommonUtils -+ LLVMSupport -+ MLIRIR -+ MLIRSideEffectInterfaces -+ MLIRSupport -+ MLIRTransformUtils -+) -+ -+add_mlir_library(MpmdTransformsCommonDistributedFunctionPass -+ distributed_function_pass.cc -+ -+ DEPENDS -+ MpmdDialect -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ MLIRFuncDialect -+ MLIRPass -+) -+ -+add_mlir_library(MpmdTransformsCommonPasses -+ absorb_inferred_fragments.cc -+ call_rewrites.cc -+ copy_constants.cc -+ fragment_dce.cc -+ fragment_dedup.cc -+ merge_fragments.cc -+ merge_transfers.cc -+ remove_transfer_cycles.cc -+ rule_based_merge.cc -+ split_bwd_fragments.cc -+ uniquify_function_inputs_outputs.cc -+ unroll_for_loops.cc -+ scheduler_preprocess.cc -+ -+ DEPENDS -+ MpmdTransformsCommonPassesIncGen -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdTransformsCommonUtils -+ -+ LINK_LIBS PUBLIC -+ LLVMSupport -+ StablehloOps -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdTransformsCommonUtils -+ MLIRAnalysis -+ MLIRDataLayoutInterfaces -+ MLIRFuncDialect -+ MLIRIR -+ MLIRTransformUtils -+ MLIRLoopLikeInterface -+ MLIRPass -+ MLIRRewrite -+ MLIRSideEffectInterfaces -+ MLIRSupport -+ MLIRTransforms -+) -diff --git a/shardy/dialect/mpmd/transforms/common/call_rewrites.cc b/shardy/dialect/mpmd/transforms/common/call_rewrites.cc -index 67022ca..d1da0a6 100644 ---- a/shardy/dialect/mpmd/transforms/common/call_rewrites.cc -+++ b/shardy/dialect/mpmd/transforms/common/call_rewrites.cc -@@ -277,9 +277,9 @@ class EraseUnusedCalleeBlockArgumentsPass - // Alas, we cannot directly erase results of an op, so 
we need to create - // a new call op, and use it to replace the old one. - rewriter.setInsertionPoint(call_op); -- auto new_call_op = -- CallOp::create(rewriter, call_op.getLoc(), result_types, -- call_op->getOperands(), call_op.getCalleeAttr()); -+ auto new_call_op = rewriter.create(call_op.getLoc(), result_types, -+ call_op->getOperands(), -+ call_op.getCalleeAttr()); - new_call_op->setDiscardableAttrs(call_op->getDiscardableAttrDictionary()); - for (auto [new_result, old_result_index] : - llvm::zip_equal(new_call_op.getResults(), kept_results.set_bits())) { -diff --git a/shardy/dialect/mpmd/transforms/common/merge_fragments.cc b/shardy/dialect/mpmd/transforms/common/merge_fragments.cc -index 4447b06..e68d4ca 100644 ---- a/shardy/dialect/mpmd/transforms/common/merge_fragments.cc -+++ b/shardy/dialect/mpmd/transforms/common/merge_fragments.cc -@@ -89,7 +89,7 @@ bool TypeHasOneElement(Type type) { - - // Returns true if `op` is an inter-mesh TransferOp whose global type has only - // one element. 
--bool IsNonScalarInterMeshTransfer(Operation* op) { -+[[ maybe_unused ]] bool IsNonScalarInterMeshTransfer(Operation* op) { - TransferOp transfer_op = DynCastInterMeshTransfer(op); - return transfer_op && - !TypeHasOneElement(transfer_op.getType().getGlobalTensorType()); -@@ -336,7 +336,7 @@ FailureOr MergeFragmentBasePass::MergeFragmentsRewrite( - producer_op.getMeshNameAttr(), - /*stage_id=*/GetMergedStageIdAttribute(producer_op, mergeable_user)); - -- for (const auto [attr_name, attr] : merged_attributes) { -+ for (const auto &[attr_name, attr] : merged_attributes) { - merged_fragment->setAttr(attr_name, attr); - } - -diff --git a/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc b/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc -index 7aaa133..f76f2e1 100644 ---- a/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc -+++ b/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc -@@ -71,8 +71,8 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op, - } - - auto loc = return_op->getLoc(); -- auto fragment_op = FragmentOp::create( -- builder, loc, fragment_return_types, fragment_operands, -+ auto fragment_op = builder.create( -+ loc, fragment_return_types, fragment_operands, - /*user_origin=*/ArrayAttr::get(builder.getContext(), {}), - /*mesh_name=*/mesh_name, /*stage_id=*/IntegerAttr()); - Block& fragment_block = fragment_op.getRegion().emplaceBlock(); -@@ -98,7 +98,7 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op, - } - } - auto block_builder = OpBuilder::atBlockEnd(&fragment_block); -- ReturnOp::create(block_builder, loc, returned_values); -+ block_builder.create(loc, returned_values); - } - - // Replaces the return values of the function with transfer ops. 
-diff --git a/shardy/dialect/mpmd/transforms/export/CMakeLists.txt b/shardy/dialect/mpmd/transforms/export/CMakeLists.txt -new file mode 100644 -index 0000000..a42701e ---- /dev/null -+++ b/shardy/dialect/mpmd/transforms/export/CMakeLists.txt -@@ -0,0 +1,75 @@ -+# Shardy MLIR MPMD Transform Export Passes and Pipeline -+ -+set(LLVM_TARGET_DEFINITIONS passes.td) -+mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdExport) -+add_public_tablegen_target(MpmdTransformsExportPassesIncGen) -+add_dependencies(mlir-headers MpmdTransformsExportPassesIncGen) -+ -+add_mlir_library(MpmdTransformsExportUtils -+ utils.cc -+ -+ DEPENDS -+ MpmdDialect -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ MLIRAnalysis -+ MLIRFuncDialect -+ MLIRIR -+ MLIRSupport -+) -+ -+add_mlir_library(MpmdTransformsExportNamingUtils -+ naming_utils.cc -+ -+ DEPENDS -+ MpmdDialect -+ MpmdTransformsCommonUtils -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ MpmdTransformsCommonUtils -+ MLIRFuncDialect -+ MLIRPass -+ LLVMSupport -+ MLIRIR -+ MLIRSupport -+) -+ -+add_mlir_library(MpmdTransformsExportPasses -+ export_pipeline.cc -+ lower_to_fragment_calls.cc -+ mark_aliasing_and_donation.cc -+ mark_fragment_reserved_memory.cc -+ mark_input_output_with_layouts.cc -+ mark_offloaded_input_output.cc -+ reschedule_ops.cc -+ -+ DEPENDS -+ MpmdDialect -+ MpmdTransformsExportPassesIncGen -+ MpmdTransformsExportNamingUtils -+ MpmdTransformsExportUtils -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdTransformsCommonPasses -+ MpmdTransformsCommonUtils -+ SdyDialect -+ -+ LINK_LIBS PUBLIC -+ MpmdTransformsExportNamingUtils -+ MpmdTransformsExportUtils -+ MpmdDialect -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdTransformsCommonPasses -+ MpmdTransformsCommonUtils -+ SdyDialect -+ LLVMSupport -+ MLIRAnalysis -+ MLIRFuncDialect -+ MLIRIR -+ MLIRPass -+ MLIRSupport -+ MLIRTransformUtils -+ MLIRTransforms -+ StablehloOps -+) -diff --git a/shardy/dialect/mpmd/transforms/import/CMakeLists.txt 
b/shardy/dialect/mpmd/transforms/import/CMakeLists.txt -new file mode 100644 -index 0000000..c4c4e81 ---- /dev/null -+++ b/shardy/dialect/mpmd/transforms/import/CMakeLists.txt -@@ -0,0 +1,125 @@ -+# Shardy MLIR MPMD Transforms Import Passes and Pipeline -+ -+set(LLVM_TARGET_DEFINITIONS passes.td) -+mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdImport) -+add_public_tablegen_target(MpmdTransformsImportPassesIncGen) -+add_dependencies(mlir-headers MpmdTransformsImportPassesIncGen) -+ -+add_mlir_library(MpmdTransformsImportMeshAssignmentMap -+ mesh_assignment_map.cc -+ -+ LINK_LIBS PUBLIC -+ LLVMSupport -+) -+ -+add_mlir_library(MpmdTransformsImportMeshInferenceOrigins -+ mesh_inference_origins.cc -+ -+ DEPENDS -+ MpmdDialect -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ LLVMSupport -+ MLIRIR -+ MLIRPass -+ MLIRSupport -+) -+ -+add_mlir_library(MpmdTransformsImportMeshesWithOrigins -+ meshes_with_origins.cc -+ -+ DEPENDS -+ MpmdDialect -+ MpmdTransformsImportMeshInferenceOrigins -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ MpmdTransformsImportMeshInferenceOrigins -+ LLVMSupport -+ MLIRIR -+ MLIRSupport -+) -+ -+ -+add_mlir_library(MpmdTransformsImportMeshInferenceUtils -+ mesh_inference_utils.cc -+ -+ DEPENDS -+ MpmdTransformsImportMeshesWithOrigins -+ MpmdDialect -+ MpmdTransformsCommonUtils -+ SdyDialect -+ -+ LINK_LIBS PUBLIC -+ MpmdTransformsImportMeshesWithOrigins -+ MpmdDialect -+ MpmdTransformsCommonUtils -+ SdyDialect -+ LLVMSupport -+ MLIRFuncDialect -+ MLIRIR -+ MLIRPass -+ MLIRSupport -+) -+ -+add_mlir_library(MpmdTransformsImportShardingConstraints -+ sharding_constraints.cc -+ -+ LINK_LIBS PUBLIC -+ LLVMSupport -+) -+ -+add_mlir_library(MpmdTransformsImportPasses -+ copy_topology_from_main.cc -+ enforce_equisharding.cc -+ import_pipeline.cc -+ infer_mesh_assignment.cc -+ infer_mesh_validation.cc -+ insert_nameless_clones_of_negligible_ops.cc -+ introduce_transfers.cc -+ map_input_output_to_mesh.cc -+ map_named_ops_to_mpmd_ops.cc -+ 
simplify_named_computations.cc -+ validate_named_ops_in_mpmd_func.cc -+ generate_sdy_meshes_from_topology_pass.cc -+ -+ DEPENDS -+ MpmdTransformsImportMeshAssignmentMap -+ MpmdTransformsImportMeshInferenceOrigins -+ MpmdTransformsImportMeshInferenceUtils -+ MpmdTransformsImportMeshesWithOrigins -+ MpmdTransformsImportPassesIncGen -+ MpmdTransformsImportShardingConstraints -+ MpmdDialect -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdTransformsCommonPasses -+ MpmdTransformsCommonUtils -+ MpmdTransformsCommonSimplifyRegionOpBase -+ SdyDialect -+ -+ LINK_LIBS PUBLIC -+ MpmdTransformsImportMeshAssignmentMap -+ MpmdTransformsImportMeshInferenceOrigins -+ MpmdTransformsImportMeshInferenceUtils -+ MpmdTransformsImportMeshesWithOrigins -+ MpmdTransformsImportShardingConstraints -+ MpmdDialect -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdTransformsCommonPasses -+ MpmdTransformsCommonUtils -+ MpmdTransformsCommonSimplifyRegionOpBase -+ SdyDialect -+ LLVMSupport -+ MLIRFuncDialect -+ MLIRIR -+ MLIRPass -+ MLIRRewrite -+ MLIRSideEffectInterfaces -+ MLIRSupport -+ MLIRTransforms -+ MLIRTransformUtils -+ StablehloOps -+ StablehloPasses -+ StablehloOptimizationPasses -+) -\ No newline at end of file -diff --git a/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc b/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc -index 45f6523..9ef9270 100644 ---- a/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc -+++ b/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc -@@ -462,8 +462,8 @@ class LowerMpmdReducePattern final : public OpRewritePattern { - if (reduced_val.getType() == user_type) { - transferred_intermediates.push_back(reduced_val); - } else { -- transferred_intermediates.push_back(TransferOp::create( -- rewriter, reduced_val.getLoc(), user_type, reduced_val)); -+ transferred_intermediates.push_back(rewriter.create( -+ reduced_val.getLoc(), user_type, reduced_val)); - } - } - -@@ -1096,8 +1096,7 @@ void 
AssignInputAndOutputToMesh(FuncOp func, BlockArgument input_arg, - // Assign the output to the mesh. - if (!isa(return_operand.get().getType())) { - rewriter.setInsertionPoint(return_operand.getOwner()); -- return_operand.set(AssignOp::create( -- rewriter, -+ return_operand.set(rewriter.create( - GetResultInfoLoc(func, return_operand.getOperandNumber()) - .value_or(return_operand.get().getLoc()), - return_operand.get(), mesh_name, mesh_attr, kIoConstraintOutputOrigin)); -@@ -1109,8 +1108,8 @@ void AssignInputAndOutputToMesh(FuncOp func, BlockArgument input_arg, - input_arg.setType(MeshTensorType::getFullyReplicated( - input_arg.getContext(), mesh_name, mesh_attr, - cast(input_arg.getType()))); -- auto unassign = UnassignOp::create(rewriter, input_arg.getLoc(), input_arg, -- kIoConstraintInputOrigin); -+ auto unassign = rewriter.create(input_arg.getLoc(), input_arg, -+ kIoConstraintInputOrigin); - rewriter.replaceAllUsesExcept(input_arg, unassign, unassign); - } - } -@@ -1293,8 +1292,7 @@ class InferMeshAssignMeshForFuncLeavesPass - } - mesh_name = first_mesh_name; - } -- return_op_operand.set(AssignOp::create( -- builder, -+ return_op_operand.set(builder.create( - GetResultInfoLoc(func, return_op_operand.getOperandNumber()) - .value_or(return_operand.getLoc()), - return_operand, *mesh_name, GetMeshByName(meshes_by_name, *mesh_name), -@@ -1405,8 +1403,8 @@ class InferMeshAssignMeshForFuncLeavesPass - rewriter.setInsertionPointAfter(op); - sdy::MeshAttr mesh = GetMeshByName(meshes_by_name, mesh_name); - for (Value res : op->getResults()) { -- AssignOp::create(rewriter, op->getLoc(), res, mesh_name, mesh, -- kInferredUnusedOrigin); -+ rewriter.create(op->getLoc(), res, mesh_name, mesh, -+ kInferredUnusedOrigin); - } - - ClearUseSet(op); -@@ -1491,10 +1489,10 @@ class InferMeshAssignMeshForFuncLeavesPass - preferred_meshes.GetPrioritizedMeshName().value_or(first_mesh_name); - } - Value operand_val = operand.get(); -- AssignOp assign = AssignOp::create( -- builder, 
operand_val.getLoc(), operand_val, *mesh_name, -+ AssignOp assign = builder.create( -+ operand_val.getLoc(), operand_val, *mesh_name, - GetMeshByName(meshes_by_name, *mesh_name), TerminalNodesOrigin(op)); -- operand.set(UnassignOp::create(builder, operand_val.getLoc(), assign)); -+ operand.set(builder.create(operand_val.getLoc(), assign)); - } - } - -@@ -1537,8 +1535,8 @@ void ConvertConcatReduceOp(Operation* op, RewriterBase& rewriter) { - SmallVector reshaped_operands; - reshaped_operands.reserve(concat.getOperands().size()); - for (Value operand : concat.getOperands()) { -- auto reshape = stablehlo::ReshapeOp::create( -- rewriter, operand.getLoc(), reduce->getResultTypes().front(), operand); -+ auto reshape = rewriter.create( -+ operand.getLoc(), reduce->getResultTypes().front(), operand); - if (operand.getDefiningOp()) { - reshape->setDiscardableAttrs( - operand.getDefiningOp()->getDiscardableAttrDictionary()); -@@ -1811,8 +1809,8 @@ void AssignCalleeFuncResultsUsingAnalysis( - // meshes, we copy it such that each result corresponds to a single mesh. 
- for (auto [i, mesh_name] : llvm::enumerate(mesh_names.getArrayRef())) { - auto assign = -- AssignOp::create(rewriter, return_val.getLoc(), return_val, mesh_name, -- GetMeshByName(meshes_by_name, mesh_name)); -+ rewriter.create(return_val.getLoc(), return_val, mesh_name, -+ GetMeshByName(meshes_by_name, mesh_name)); - if (i == 0) { - new_operands[res_num] = assign; - } else { -@@ -1891,10 +1889,10 @@ void AssignCalleeFuncArgsToAssignUsers( - UnassignOp unassign_op; - if (i == 0) { - arg.setType(mesh_type); -- unassign_op = UnassignOp::create(rewriter, arg.getLoc(), arg); -+ unassign_op = rewriter.create(arg.getLoc(), arg); - } else { -- unassign_op = UnassignOp::create( -- rewriter, arg.getLoc(), body.addArgument(mesh_type, arg.getLoc())); -+ unassign_op = rewriter.create( -+ arg.getLoc(), body.addArgument(mesh_type, arg.getLoc())); - } - - if (auto users_it = assign_users_by_mesh_name.find(mesh_name); -@@ -1928,25 +1926,24 @@ void RewriteAccordingToUpdatedCallee(CallOp call_op, RewriterBase& rewriter) { - continue; - } - SDY_CHECK(isa(call_body.getArgument(arg_num).getType())); -- new_operands[arg_num] = -- AssignOp::create(rewriter, operand.getLoc(), -- call_body.getArgument(arg_num).getType(), operand); -+ new_operands[arg_num] = rewriter.create( -+ operand.getLoc(), call_body.getArgument(arg_num).getType(), operand); - - if (auto copies = - callee.getArgAttrOfType(arg_num, kMpmdCopied)) { - for (int64_t cloned_arg_index : copies.asArrayRef()) { - SDY_CHECK(isa( - call_body.getArgument(cloned_arg_index).getType())); -- new_operands[cloned_arg_index] = AssignOp::create( -- rewriter, operand.getLoc(), -- call_body.getArgument(cloned_arg_index).getType(), operand); -+ new_operands[cloned_arg_index] = rewriter.create( -+ operand.getLoc(), call_body.getArgument(cloned_arg_index).getType(), -+ operand); - } - } - } - - // Create the new call and copy attrs over. 
-- auto new_call_op = CallOp::create( -- rewriter, call_op.getLoc(), call_body.getTerminator()->getOperandTypes(), -+ auto new_call_op = rewriter.create( -+ call_op.getLoc(), call_body.getTerminator()->getOperandTypes(), - new_operands, call_op.getCalleeAttr()); - new_call_op->setDiscardableAttrs(call_op->getDiscardableAttrDictionary()); - -@@ -1972,9 +1969,8 @@ void RewriteAccordingToUpdatedCallee(CallOp call_op, RewriterBase& rewriter) { - SDY_CHECK(arg_num_it != type_to_arg_num.end()) - << "Argument number for type " << debugString(assign_user.getType()) - << " not found"; -- assign_user.setOperand( -- UnassignOp::create(rewriter, assign_user.getLoc(), -- new_call_op.getResult(arg_num_it->second))); -+ assign_user.setOperand(rewriter.create( -+ assign_user.getLoc(), new_call_op.getResult(arg_num_it->second))); - } - } - } -@@ -2061,7 +2057,7 @@ bool AssignEntrypointFuncArgsToAssignUsers(FuncOp entrypoint_func, - cast(arg.getType()), - memory_kind)); - -- UnassignOp unassign_op = UnassignOp::create(rewriter, arg.getLoc(), arg); -+ UnassignOp unassign_op = rewriter.create(arg.getLoc(), arg); - rewriter.replaceAllUsesExcept(arg, unassign_op, unassign_op); - } - return true; -@@ -2187,12 +2183,12 @@ void AbsorbMeshlessProducer(FragmentOp consumer, Operation* op, - } - rewriter.setInsertionPoint(consumer); - for (Value operand : op_operands_and_free_vars) { -- new_consumer_operands.push_back( -- AssignOp::create(rewriter, operand.getLoc(), -- MeshTensorType::getFullyReplicated( -- operand.getContext(), mesh_name, mesh_attr, -- cast(operand.getType())), -- operand)); -+ new_consumer_operands.push_back(rewriter.create( -+ operand.getLoc(), -+ MeshTensorType::getFullyReplicated( -+ operand.getContext(), mesh_name, mesh_attr, -+ cast(operand.getType())), -+ operand)); - } - consumer->setOperands(new_consumer_operands); - } -@@ -2316,8 +2312,8 @@ void RewriteForOpTerminator( - SDY_CHECK_LE(mesh_names.size(), 1) - << "Multiple mesh names found for return value"; - -- 
new_operands.push_back(AssignOp::create( -- rewriter, return_val.getLoc(), return_val, mesh_names[0], -+ new_operands.push_back(rewriter.create( -+ return_val.getLoc(), return_val, mesh_names[0], - GetMeshByName(meshes_by_name, mesh_names[0]))); - } - -@@ -2382,7 +2378,7 @@ void RewriteForOpArgsAndTypes( - arg.getContext(), mesh_names[0], - GetMeshByName(meshes_by_name, mesh_names[0]), local_type); - arg.setType(mesh_type); -- UnassignOp unassign_op = UnassignOp::create(rewriter, arg.getLoc(), arg); -+ UnassignOp unassign_op = rewriter.create(arg.getLoc(), arg); - - if (auto users_it = assign_users_by_mesh_name.find(mesh_names[0]); - users_it != assign_users_by_mesh_name.end()) { -@@ -2414,9 +2410,8 @@ void RewriteForOpOperands(ForOp for_op, RewriterBase& rewriter) { - new_operands[arg_num] = operand; - continue; - } -- new_operands[arg_num] = -- AssignOp::create(rewriter, operand.getLoc(), -- for_body.getArgument(arg_num).getType(), operand); -+ new_operands[arg_num] = rewriter.create( -+ operand.getLoc(), for_body.getArgument(arg_num).getType(), operand); - } - - for_op->setOperands(new_operands); -@@ -2430,7 +2425,7 @@ void RewriteForOpResults(ForOp for_op, RewriterBase& rewriter) { - for (Operation* user : res.getUsers()) { - if (auto assign_user = dyn_cast(user)) { - assign_user.setOperand( -- UnassignOp::create(rewriter, assign_user.getLoc(), res)); -+ rewriter.create(assign_user.getLoc(), res)); - } - } - } -diff --git a/shardy/dialect/mpmd/transforms/optimize/CMakeLists.txt b/shardy/dialect/mpmd/transforms/optimize/CMakeLists.txt -new file mode 100644 -index 0000000..0607049 ---- /dev/null -+++ b/shardy/dialect/mpmd/transforms/optimize/CMakeLists.txt -@@ -0,0 +1,72 @@ -+# Shardy MLIR MPMD Transforms Optimize -+ -+set(LLVM_TARGET_DEFINITIONS passes.td) -+mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdOptimize) -+add_public_tablegen_target(MpmdTransformsOptimizePassesIncGen) -+add_dependencies(mlir-headers MpmdTransformsOptimizePassesIncGen) -+ 
-+add_mlir_library(MpmdTransformsOptimizeUtils -+ utils.cc -+ -+ DEPENDS -+ MpmdDialect -+ MpmdTransformsCommonUtils -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ MpmdTransformsCommonUtils -+ LLVMSupport -+ MLIRIR -+ MLIRSupport -+) -+ -+add_mlir_library(MpmdTransformsOptimizePipelineSchedule -+ pipeline_schedule.cc -+ -+ DEPENDS -+ MpmdTransformsOptimizeUtils -+ MpmdDialect -+ MpmdTransformsCommonUtils -+ -+ LINK_LIBS PUBLIC -+ MpmdTransformsOptimizeUtils -+ MpmdDialect -+ MpmdTransformsCommonUtils -+ LLVMSupport -+ MLIRIR -+ MLIRSupport -+) -+ -+add_mlir_library(MpmdTransformsOptimizePasses -+ optimize_pipeline.cc -+ remat_fragment.cc -+ scheduler.cc -+ rule_based_schedule.cc -+ -+ DEPENDS -+ MpmdTransformsOptimizePassesIncGen -+ MpmdTransformsOptimizePipelineSchedule -+ MpmdTransformsOptimizeUtils -+ MpmdDialect -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdFragmentExecutionRules -+ MpmdTransformsCommonPasses -+ MpmdTransformsCommonUtils -+ -+ LINK_LIBS PUBLIC -+ MpmdTransformsOptimizePipelineSchedule -+ MpmdTransformsOptimizeUtils -+ MpmdDialect -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdFragmentExecutionRules -+ MpmdTransformsCommonPasses -+ MpmdTransformsCommonUtils -+ LLVMSupport -+ MLIRAnalysis -+ MLIRFuncDialect -+ MLIRIR -+ MLIRPass -+ MLIRSupport -+ MLIRTransforms -+ MLIRTransformUtils -+) -diff --git a/shardy/dialect/mpmd/transforms/sharding_propagation/CMakeLists.txt b/shardy/dialect/mpmd/transforms/sharding_propagation/CMakeLists.txt -new file mode 100644 -index 0000000..a8a8b05 ---- /dev/null -+++ b/shardy/dialect/mpmd/transforms/sharding_propagation/CMakeLists.txt -@@ -0,0 +1,41 @@ -+# Shardy MLIR MPMD Transforms Sharding Propagation -+ -+set(LLVM_TARGET_DEFINITIONS passes.td) -+mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdShardingPropagation) -+add_public_tablegen_target(MpmdTransformsShardingPropagationPassesIncGen) -+add_dependencies(mlir-headers MpmdTransformsShardingPropagationPassesIncGen) -+ 
-+add_mlir_library(MpmdTransformsShardingPropagationPasses -+ convert_sdy_constants.cc -+ convert_sdy_shardings_to_mpmd_types.cc -+ enforce_user_shardings.cc -+ extract_reshards_from_inter_mesh_transfers.cc -+ sharding_propagation_pipeline.cc -+ -+ DEPENDS -+ MpmdTransformsShardingPropagationPassesIncGen -+ MpmdDialect -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdTransformsCommonPasses -+ MpmdTransformsCommonUtils -+ SdyDialect -+ SdyExplicitReshardsUtil -+ SdyTransformsPropagationPasses -+ -+ LINK_LIBS PUBLIC -+ MpmdDialect -+ MpmdTransformsCommonDistributedFunctionPass -+ MpmdTransformsCommonPasses -+ MpmdTransformsCommonUtils -+ SdyDialect -+ SdyExplicitReshardsUtil -+ SdyTransformsPropagationPasses -+ LLVMSupport -+ MLIRFuncDialect -+ MLIRIR -+ MLIRPass -+ MLIRRewrite -+ MLIRSupport -+ MLIRTransformUtils -+ StablehloOps -+) -diff --git a/shardy/integrations/c/CMakeLists.txt b/shardy/integrations/c/CMakeLists.txt -index fdd50c4..1c3a624 100644 ---- a/shardy/integrations/c/CMakeLists.txt -+++ b/shardy/integrations/c/CMakeLists.txt -@@ -1,8 +1,39 @@ - add_mlir_public_c_api_library(SdyCAPI - PARTIAL_SOURCES_INTENDED -- attributes.cc -- dialect.cc -- passes.cc -+ attributes_sdy.cc -+ dialect_sdy.cc -+ passes_sdy.cc -+ -+ DEPENDS -+ SdyDialect -+ SdyTransformsPasses -+ -+ LINK_LIBS PUBLIC -+ LLVMSupport -+ MLIRBytecodeOpInterface -+ MLIRFuncDialect -+ MLIRIR -+ MLIRInferTypeOpInterface -+ MLIRTransformUtils -+ MLIRShapeDialect -+ MLIRSideEffectInterfaces -+ MLIRSupport -+ StablehloAssemblyFormat -+ StablehloOps -+ StablehloTypeInference -+ SdyDialect -+ SdyTransformsPasses -+) -+ -+add_mlir_public_c_api_library(MpmdCAPI -+ PARTIAL_SOURCES_INTENDED -+ passes_mpmd.cc -+ dialect_mpmd.cc -+ attributes_mpmd.cc -+ -+ DEPENDS -+ MpmdDialect -+ MpmdTransformsPasses - - LINK_LIBS PUBLIC - LLVMSupport -@@ -17,4 +48,6 @@ add_mlir_public_c_api_library(SdyCAPI - StablehloAssemblyFormat - StablehloOps - StablehloTypeInference -+ MpmdDialect -+ MpmdTransformsPasses - ) 
-diff --git a/shardy/integrations/c/attributes_mpmd.cc b/shardy/integrations/c/attributes_mpmd.cc -new file mode 100644 -index 0000000..355107f ---- /dev/null -+++ b/shardy/integrations/c/attributes_mpmd.cc -@@ -0,0 +1,126 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#include "shardy/integrations/c/attributes_mpmd.h" -+ -+#include -+#include -+ -+#include "mlir-c/IR.h" -+#include "mlir-c/Support.h" -+#include "mlir/CAPI/IR.h" -+#include "mlir/CAPI/Support.h" -+#include "mlir/IR/Attributes.h" -+#include "mlir/Support/LLVM.h" -+#include "shardy/dialect/mpmd/ir/dialect.h" -+ -+namespace { -+ -+namespace mpmd = ::mlir::mpmd; -+ -+template -+AttrTy unwrapAttr(MlirAttribute attr) { -+ return mlir::cast(unwrap(attr)); -+} -+ -+template -+mlir::ArrayRef unwrapAttrs(const MlirAttribute* attrs, -+ intptr_t nAttrs) { -+ return mlir::ArrayRef(reinterpret_cast(attrs), nAttrs); -+} -+ -+} // namespace -+ -+extern "C" { -+ -+//===----------------------------------------------------------------------===// -+// NamedMeshAttr -+//===----------------------------------------------------------------------===// -+ -+bool mpmdAttributeIsANamedMeshAttr(MlirAttribute attr) { -+ return mlir::isa(unwrap(attr)); -+} -+ -+MlirAttribute mpmdNamedMeshAttrGet(MlirContext ctx, MlirStringRef name, MlirAttribute mesh) { -+ return wrap(mpmd::NamedMeshAttr::get(unwrap(ctx), 
unwrap(name), unwrapAttr(mesh))); -+} -+ -+MlirStringRef mpmdNamedMeshAttrGetName(MlirAttribute attr) { -+ return wrap(unwrapAttr(attr).getName()); -+} -+ -+MlirAttribute mpmdNamedMeshAttrGetMesh(MlirAttribute attr) { -+ mlir::sdy::MeshAttr mesh = unwrapAttr(attr).getMesh(); -+ return wrap(mesh); -+} -+ -+//===----------------------------------------------------------------------===// -+// TopologyAttr -+//===----------------------------------------------------------------------===// -+ -+bool mpmdAttributeIsATopologyAttr(MlirAttribute attr) { -+ return mlir::isa(unwrap(attr)); -+} -+ -+MlirAttribute mpmdTopologyAttrGet(MlirContext ctx, intptr_t nMeshes, const MlirAttribute* meshes) { -+ return wrap(mpmd::TopologyAttr::get( -+ unwrap(ctx), unwrapAttrs(meshes, nMeshes))); -+} -+ -+int64_t mpmdTopologyAttrGetMeshesSize(MlirAttribute attr) { -+ return unwrapAttr(attr).getMeshes().size(); -+} -+ -+MlirAttribute mpmdTopologyAttrGetMeshesElem(MlirAttribute attr, int64_t pos) { -+ return wrap(unwrapAttr(attr).getMeshes()[pos]); -+} -+ -+//===----------------------------------------------------------------------===// -+// UserOriginAttr -+//===----------------------------------------------------------------------===// -+ -+bool mpmdAttributeIsAUserOriginAttr(MlirAttribute attr) { -+ return mlir::isa(unwrap(attr)); -+} -+ -+MlirAttribute mpmdUserOriginAttrGet(MlirContext ctx, MlirAttribute userName, int64_t transposeCount) { -+ return wrap(mpmd::UserOriginAttr::get(unwrap(ctx), unwrapAttr(userName), transposeCount)); -+} -+ -+MlirStringRef mpmdUserOriginAttrGetUserName(MlirAttribute attr) { -+ return wrap(unwrapAttr(attr).getUserName().getValue()); -+} -+ -+int64_t mpmdUserOriginAttrGetTransposeCount(MlirAttribute attr) { -+ return unwrapAttr(attr).getTransposeCount(); -+} -+ -+//===----------------------------------------------------------------------===// -+// OriginAttr -+//===----------------------------------------------------------------------===// -+ -+bool 
mpmdAttributeIsAOriginAttr(MlirAttribute attr) { -+ return mlir::isa(unwrap(attr)); -+} -+ -+MlirAttribute mpmdOriginAttrGet(MlirContext ctx, MlirStringRef originLabel) { -+ return wrap(mpmd::OriginAttr::get(unwrap(ctx), unwrap(originLabel))); -+} -+ -+MlirStringRef mpmdOriginAttrGetOriginLabel(MlirAttribute attr) { -+ return wrap(unwrapAttr(attr).getOriginLabel()); -+} -+ -+} // extern "C" -diff --git a/shardy/integrations/c/attributes_mpmd.h b/shardy/integrations/c/attributes_mpmd.h -new file mode 100644 -index 0000000..0d638fd ---- /dev/null -+++ b/shardy/integrations/c/attributes_mpmd.h -@@ -0,0 +1,84 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. 
-+==============================================================================*/ -+ -+#ifndef SHARDY_INTEGRATIONS_C_ATTRIBUTES_MPMD_H_ -+#define SHARDY_INTEGRATIONS_C_ATTRIBUTES_MPMD_H_ -+ -+#include -+#include -+ -+#include "mlir-c/IR.h" -+#include "mlir-c/Support.h" -+ -+#ifdef __cplusplus -+extern "C" { -+#endif -+ -+//===----------------------------------------------------------------------===// -+// NamedMeshAttr -+//===----------------------------------------------------------------------===// -+ -+MLIR_CAPI_EXPORTED bool mpmdAttributeIsANamedMeshAttr(MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED MlirAttribute mpmdNamedMeshAttrGet(MlirContext ctx, -+ MlirStringRef name, -+ MlirAttribute mesh); -+ -+MLIR_CAPI_EXPORTED MlirStringRef mpmdNamedMeshAttrGetName(MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED MlirAttribute mpmdNamedMeshAttrGetMesh(MlirAttribute attr); -+ -+//===----------------------------------------------------------------------===// -+// TopologyAttr -+//===----------------------------------------------------------------------===// -+ -+MLIR_CAPI_EXPORTED bool mpmdAttributeIsATopologyAttr(MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED MlirAttribute mpmdTopologyAttrGet(MlirContext ctx, -+ intptr_t nMeshes, -+ const MlirAttribute* meshes); -+ -+MLIR_CAPI_EXPORTED int64_t mpmdTopologyAttrGetMeshesSize(MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED MlirAttribute mpmdTopologyAttrGetMeshesElem(MlirAttribute attr, -+ int64_t pos); -+ -+//===----------------------------------------------------------------------===// -+// UserOriginAttr -+//===----------------------------------------------------------------------===// -+ -+MLIR_CAPI_EXPORTED bool mpmdAttributeIsAUserOriginAttr(MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED MlirAttribute mpmdUserOriginAttrGet(MlirContext ctx, MlirAttribute userName, int64_t transposeCount); -+ -+MLIR_CAPI_EXPORTED MlirStringRef mpmdUserOriginAttrGetUserName(MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED int64_t 
mpmdUserOriginAttrGetTransposeCount(MlirAttribute attr); -+ -+//===----------------------------------------------------------------------===// -+// OriginAttr -+//===----------------------------------------------------------------------===// -+ -+MLIR_CAPI_EXPORTED bool mpmdAttributeIsAOriginAttr(MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED MlirAttribute mpmdOriginAttrGet(MlirContext ctx, MlirStringRef originLabel); -+ -+MLIR_CAPI_EXPORTED MlirStringRef mpmdOriginAttrGetOriginLabel(MlirAttribute attr); -+ -+#ifdef __cplusplus -+} -+#endif -+ -+#endif // SHARDY_INTEGRATIONS_C_ATTRIBUTES_MPMD_H_ -diff --git a/shardy/integrations/c/attributes.cc b/shardy/integrations/c/attributes_sdy.cc -similarity index 99% -rename from shardy/integrations/c/attributes.cc -rename to shardy/integrations/c/attributes_sdy.cc -index b683d09..417ed66 100644 ---- a/shardy/integrations/c/attributes.cc -+++ b/shardy/integrations/c/attributes_sdy.cc -@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ - --#include "shardy/integrations/c/attributes.h" -+#include "shardy/integrations/c/attributes_sdy.h" - - #include - #include -diff --git a/shardy/integrations/c/attributes.h b/shardy/integrations/c/attributes_sdy.h -similarity index 98% -rename from shardy/integrations/c/attributes.h -rename to shardy/integrations/c/attributes_sdy.h -index b6e77c9..d2c5c72 100644 ---- a/shardy/integrations/c/attributes.h -+++ b/shardy/integrations/c/attributes_sdy.h -@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ - --#ifndef SHARDY_INTEGRATIONS_C_ATTRIBUTES_H_ --#define SHARDY_INTEGRATIONS_C_ATTRIBUTES_H_ -+#ifndef SHARDY_INTEGRATIONS_C_ATTRIBUTES_SDY_H_ -+#define SHARDY_INTEGRATIONS_C_ATTRIBUTES_SDY_H_ - - #include - #include -@@ -276,4 +276,4 @@ MLIR_CAPI_EXPORTED MlirStringRef sdyManualAxesAttrGetAxesElem( - } - #endif - --#endif // SHARDY_INTEGRATIONS_C_ATTRIBUTES_H_ -+#endif // SHARDY_INTEGRATIONS_C_ATTRIBUTES_SDY_H_ -diff --git a/shardy/integrations/c/dialect_mpmd.cc b/shardy/integrations/c/dialect_mpmd.cc -new file mode 100644 -index 0000000..d311822 ---- /dev/null -+++ b/shardy/integrations/c/dialect_mpmd.cc -@@ -0,0 +1,21 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#include "shardy/integrations/c/dialect_mpmd.h" // IWYU pragma: keep -+ -+#include "mlir/CAPI/Registration.h" -+#include "shardy/dialect/mpmd/ir/dialect.h" -+ -+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Mpmd, mpmd, mlir::mpmd::MpmdDialect); -diff --git a/shardy/integrations/c/dialect_mpmd.h b/shardy/integrations/c/dialect_mpmd.h -new file mode 100644 -index 0000000..6d699bb ---- /dev/null -+++ b/shardy/integrations/c/dialect_mpmd.h -@@ -0,0 +1,31 @@ -+/* Copyright 2025 The Shardy Authors. 
-+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#ifndef SHARDY_DIALECT_MPMD_IR_C_DIALECT_H_ -+#define SHARDY_DIALECT_MPMD_IR_C_DIALECT_H_ -+ -+#include "mlir-c/IR.h" -+ -+#ifdef __cplusplus -+extern "C" { -+#endif -+ -+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Mpmd, mpmd); -+ -+#ifdef __cplusplus -+} -+#endif -+ -+#endif // SHARDY_DIALECT_MPMD_IR_C_DIALECT_H_ -diff --git a/shardy/integrations/c/dialect.cc b/shardy/integrations/c/dialect_sdy.cc -similarity index 92% -rename from shardy/integrations/c/dialect.cc -rename to shardy/integrations/c/dialect_sdy.cc -index 1408631..5e3dfe3 100644 ---- a/shardy/integrations/c/dialect.cc -+++ b/shardy/integrations/c/dialect_sdy.cc -@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ - --#include "shardy/integrations/c/dialect.h" // IWYU pragma: keep -+#include "shardy/integrations/c/dialect_sdy.h" // IWYU pragma: keep - - #include "mlir/CAPI/Registration.h" - #include "shardy/dialect/sdy/ir/dialect.h" -diff --git a/shardy/integrations/c/dialect.h b/shardy/integrations/c/dialect_sdy.h -similarity index 100% -rename from shardy/integrations/c/dialect.h -rename to shardy/integrations/c/dialect_sdy.h -diff --git a/shardy/integrations/c/passes_mpmd.cc b/shardy/integrations/c/passes_mpmd.cc -new file mode 100644 -index 0000000..e843a16 ---- /dev/null -+++ b/shardy/integrations/c/passes_mpmd.cc -@@ -0,0 +1,22 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#include "shardy/integrations/c/passes_mpmd.h" -+ -+#include "shardy/dialect/mpmd/transforms/passes.h" -+ -+void mlirRegisterAllMpmdPassesAndPipelines() { -+ mlir::mpmd::registerAllMpmdPassesAndPipelines(); -+} -diff --git a/shardy/integrations/c/passes_mpmd.h b/shardy/integrations/c/passes_mpmd.h -new file mode 100644 -index 0000000..2125752 ---- /dev/null -+++ b/shardy/integrations/c/passes_mpmd.h -@@ -0,0 +1,33 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. 
-+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#ifndef SHARDY_INTEGRATIONS_C_PASSES_MPMD_H_ -+#define SHARDY_INTEGRATIONS_C_PASSES_MPMD_H_ -+ -+#include "mlir-c/Support.h" -+ -+#ifdef __cplusplus -+extern "C" { -+#endif -+ -+/// Register all compiler passes and pipelines of Shardy. -+MLIR_CAPI_EXPORTED void mlirRegisterAllMpmdPassesAndPipelines(); -+ -+#ifdef __cplusplus -+} -+#endif -+ -+ -+#endif // SHARDY_INTEGRATIONS_C_PASSES_MPMD_H_ -diff --git a/shardy/integrations/c/passes.cc b/shardy/integrations/c/passes_sdy.cc -similarity index 94% -rename from shardy/integrations/c/passes.cc -rename to shardy/integrations/c/passes_sdy.cc -index 063c1cf..f10b199 100644 ---- a/shardy/integrations/c/passes.cc -+++ b/shardy/integrations/c/passes_sdy.cc -@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and - limitations under the License. - ==============================================================================*/ - --#include "shardy/integrations/c/passes.h" -+#include "shardy/integrations/c/passes_sdy.h" - - #include "shardy/dialect/sdy/transforms/passes.h" - -diff --git a/shardy/integrations/c/passes.h b/shardy/integrations/c/passes_sdy.h -similarity index 86% -rename from shardy/integrations/c/passes.h -rename to shardy/integrations/c/passes_sdy.h -index 6863333..6ffd052 100644 ---- a/shardy/integrations/c/passes.h -+++ b/shardy/integrations/c/passes_sdy.h -@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and - limitations under the License. 
- ==============================================================================*/ - --#ifndef SHARDY_INTEGRATIONS_C_PASSES_H_ --#define SHARDY_INTEGRATIONS_C_PASSES_H_ -+#ifndef SHARDY_INTEGRATIONS_C_PASSES_SDY_H_ -+#define SHARDY_INTEGRATIONS_C_PASSES_SDY_H_ - - #include "mlir-c/Support.h" - -@@ -30,4 +30,4 @@ MLIR_CAPI_EXPORTED void mlirRegisterAllSdyPassesAndPipelines(); - #endif - - --#endif // SHARDY_INTEGRATIONS_C_PASSES_H_ -+#endif // SHARDY_INTEGRATIONS_C_PASSES_SDY_H_ -diff --git a/shardy/integrations/python/ir/CMakeLists.txt b/shardy/integrations/python/ir/CMakeLists.txt -index cbb4d66..1e8d7bd 100644 ---- a/shardy/integrations/python/ir/CMakeLists.txt -+++ b/shardy/integrations/python/ir/CMakeLists.txt -@@ -28,3 +28,32 @@ declare_mlir_python_extension(SdyPythonExtensions.Main - SdyCAPI - LLVMSupport - ) -+ -+declare_mlir_python_sources(MpmdPythonSources) -+declare_mlir_python_sources(MpmdPythonSources.Dialects -+ ADD_TO_PARENT MpmdPythonSources -+) -+ -+declare_mlir_dialect_python_bindings( -+ ADD_TO_PARENT MpmdPythonSources.Dialects -+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" -+ TD_FILE dialects/mpmd_ops.td -+ GEN_ENUM_BINDINGS ON -+ GEN_ENUM_BINDINGS_TD_FILE dialects/mpmd_enums.td -+ SOURCES dialects/mpmd.py -+ DIALECT_NAME mpmd -+) -+ -+declare_mlir_python_sources(MpmdPythonExtensions) -+declare_mlir_python_extension(MpmdPythonExtensions.Main -+ MODULE_NAME _mpmd -+ ADD_TO_PARENT MpmdPythonExtensions -+ PYTHON_BINDINGS_LIBRARY nanobind -+ SOURCES -+ mpmd_module.cc -+ EMBED_CAPI_LINK_LIBS -+ MpmdCAPI -+ PRIVATE_LINK_LIBS -+ MpmdCAPI -+ LLVMSupport -+) -diff --git a/shardy/integrations/python/ir/__init__.py b/shardy/integrations/python/ir/__init__.py -index 7373840..ce9d4ca 100644 ---- a/shardy/integrations/python/ir/__init__.py -+++ b/shardy/integrations/python/ir/__init__.py -@@ -12,7 +12,7 @@ - # See the License for the specific language governing permissions and - # limitations under the License. 
- # ============================================================================== --"""Python bindings for the SDY dialect.""" -+"""Python bindings for the SDY and MPMD dialect.""" - - # pylint: disable=g-multiple-import,g-importing-member,unused-import,useless-import-alias - from ._sdy import ( -@@ -36,3 +36,27 @@ from ._sdy_ops_gen import ( - ReturnOp as ReturnOp, - ShardingConstraintOp as ShardingConstraintOp, - ) -+ -+# pylint: disable=g-multiple-import,g-importing-member,unused-import,useless-import-alias -+from ._mpmd import ( -+ register_dialect as register_dialect, -+ NamedMeshAttr as NamedMeshAttr, -+ TopologyAttr as TopologyAttr, -+) -+ -+from ._mpmd_enums_gen import ReductionType as ReductionType -+ -+from ._mpmd_ops_gen import ( -+ ReturnOp as ReturnOp, -+ NamedComputationOp as NamedComputationOp, -+ NamedTensorOp as NamedTensorOp, -+ FragmentOp as FragmentOp, -+ FragmentCallOp as FragmentCallOp, -+ TransferOp as TransferOp, -+ AssignOp as AssignOp, -+ UnassignOp as UnassignOp, -+ CallOp as CallOp, -+ ForOp as ForOp, -+ BroadcastOp as BroadcastOp, -+ ReduceOp as ReduceOp, -+) -diff --git a/shardy/integrations/python/ir/dialects/mpmd.py b/shardy/integrations/python/ir/dialects/mpmd.py -new file mode 100644 -index 0000000..ab22061 ---- /dev/null -+++ b/shardy/integrations/python/ir/dialects/mpmd.py -@@ -0,0 +1,20 @@ -+# Copyright 2025 The Shardy Authors. -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
-+# ============================================================================== -+"""Python bindings for the MPMD dialect.""" -+ -+# pylint: disable=wildcard-import -+from .._mlir_libs._mpmd import * -+from ._mpmd_enum_gen import * -+from ._mpmd_ops_gen import * -diff --git a/shardy/integrations/python/ir/dialects/mpmd_enums.td b/shardy/integrations/python/ir/dialects/mpmd_enums.td -new file mode 100644 -index 0000000..0dfad9b ---- /dev/null -+++ b/shardy/integrations/python/ir/dialects/mpmd_enums.td -@@ -0,0 +1,21 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#ifndef SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -+#define SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -+ -+include "shardy/dialect/mpmd/ir/enums.td" -+ -+#endif // SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -diff --git a/shardy/integrations/python/ir/dialects/mpmd_ops.td b/shardy/integrations/python/ir/dialects/mpmd_ops.td -new file mode 100644 -index 0000000..86d6dcf ---- /dev/null -+++ b/shardy/integrations/python/ir/dialects/mpmd_ops.td -@@ -0,0 +1,21 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. 
-+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#ifndef SHARDY_INTEGRATIONS_PYTHON_MPMD_OPS -+#define SHARDY_INTEGRATIONS_PYTHON_MPMD_OPS -+ -+include "shardy/dialect/mpmd/ir/ops.td" -+ -+#endif -diff --git a/shardy/integrations/python/ir/mpmd.py b/shardy/integrations/python/ir/mpmd.py -new file mode 100644 -index 0000000..2524fe2 ---- /dev/null -+++ b/shardy/integrations/python/ir/mpmd.py -@@ -0,0 +1,20 @@ -+# Copyright 2025 The Shardy Authors. -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
-+# ============================================================================== -+"""Python bindings for the MPMD dialect.""" -+ -+# pylint: disable=wildcard-import -+from .._mlir_libs._mpmd import * -+from ._mpmd_enums_gen import * -+from ._mpmd_ops_gen import * -diff --git a/shardy/integrations/python/ir/mpmd_enums.td b/shardy/integrations/python/ir/mpmd_enums.td -new file mode 100644 -index 0000000..0dfad9b ---- /dev/null -+++ b/shardy/integrations/python/ir/mpmd_enums.td -@@ -0,0 +1,21 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#ifndef SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -+#define SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -+ -+include "shardy/dialect/mpmd/ir/enums.td" -+ -+#endif // SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -diff --git a/shardy/integrations/python/ir/mpmd_module.cc b/shardy/integrations/python/ir/mpmd_module.cc -new file mode 100644 -index 0000000..9bcab92 ---- /dev/null -+++ b/shardy/integrations/python/ir/mpmd_module.cc -@@ -0,0 +1,165 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. 
-+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#include -+#include -+#include -+#include -+#include -+ -+#include "mlir-c/BuiltinAttributes.h" -+#include "mlir-c/IR.h" -+#include "mlir-c/Support.h" -+#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep -+#include "nanobind/nanobind.h" -+#include "nanobind/stl/optional.h" // IWYU pragma: keep -+#include "nanobind/stl/string.h" // IWYU pragma: keep -+#include "nanobind/stl/variant.h" // IWYU pragma: keep -+#include "nanobind/stl/vector.h" // IWYU pragma: keep -+#include "shardy/integrations/c/attributes_mpmd.h" -+#include "shardy/integrations/c/dialect_mpmd.h" -+ -+namespace mlir { -+namespace mpmd { -+ -+namespace { -+ -+namespace nb = nanobind; -+ -+// Returns a vector containing elements with type T extracted from an attribute -+// using the two provided callbacks. -+template -+std::vector propertyVector( -+ MlirAttribute attr, llvm::function_ref sizeFn, -+ llvm::function_ref getFn) { -+ std::vector result; -+ intptr_t size = sizeFn(attr); -+ result.reserve(size); -+ for (intptr_t i = 0; i < size; ++i) { -+ result.push_back(getFn(attr, i)); -+ } -+ return result; -+} -+ -+nb::str toPyString(MlirStringRef mlirStringRef) { -+ return nb::str(mlirStringRef.data, mlirStringRef.length); -+} -+ -+MlirStringRef toStringRef(const std::string& s) { -+ return mlirStringRefCreate(s.c_str(), s.size()); -+} -+ -+NB_MODULE(_mpmd, m) { -+ m.doc() = "MPMD main Python extension"; -+ -+ // -+ // Dialects. 
-+ // -+ -+ m.def( -+ "register_dialect", -+ [](MlirContext context, bool load) { -+ MlirDialectHandle dialect = mlirGetDialectHandle__mpmd__(); -+ mlirDialectHandleRegisterDialect(dialect, context); -+ if (load) { -+ mlirDialectHandleLoadDialect(dialect, context); -+ } -+ }, -+ nb::arg("context"), nb::arg("load") = true); -+ -+ // -+ // Attributes. -+ // -+ -+ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -+ m, "NamedMeshAttr", mpmdAttributeIsANamedMeshAttr) -+ .def_classmethod( -+ "get", -+ [](nb::object cls, const std::string& name, -+ MlirAttribute meshAttr, MlirContext ctx) { -+ return cls(mpmdNamedMeshAttrGet(ctx, toStringRef(name), meshAttr)); -+ }, -+ nb::arg("cls"), nb::arg("name"), -+ nb::arg("mesh").none() = nb::none(), -+ nb::arg("context").none() = nb::none(), -+ "Creates an NamedMeshAttr with the given name and MeshAttr.") -+ .def_property_readonly("name", -+ [](MlirAttribute self) { -+ return toPyString(mpmdNamedMeshAttrGetName(self)); -+ }) -+ .def_property_readonly("mesh", [](MlirAttribute self) { -+ return mpmdNamedMeshAttrGetMesh(self); -+ }); -+ -+ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -+ m, "TopologyAttr", mpmdAttributeIsATopologyAttr) -+ .def_classmethod( -+ "get", -+ [](nb::object cls, const std::vector& meshes, -+ MlirContext ctx) { -+ return cls(mpmdTopologyAttrGet(ctx, meshes.size(), meshes.data())); -+ }, -+ nb::arg("cls"), nb::arg("meshes"), -+ nb::arg("context").none() = nb::none(), -+ "Creates a TopologyAttr with the given meshes.") -+ .def_property_readonly("meshes", -+ [](MlirAttribute self) { -+ return propertyVector( -+ self, mpmdTopologyAttrGetMeshesSize, -+ mpmdTopologyAttrGetMeshesElem); -+ }) -+ .def_property_readonly("size", [](MlirAttribute self) { -+ return mpmdTopologyAttrGetMeshesSize(self); -+ }); -+ -+ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -+ m, "UserOriginAttr", mpmdAttributeIsAUserOriginAttr) -+ .def_classmethod( -+ "get", -+ [](nb::object cls, MlirAttribute& 
userName, int64_t transposeCount, -+ MlirContext ctx) { -+ return cls(mpmdUserOriginAttrGet(ctx, userName, transposeCount)); -+ }, -+ nb::arg("cls"), nb::arg("user_name"), -+ nb::arg("transpose_count") = 0, -+ nb::arg("context").none() = nb::none(), -+ "Creates a UserOriginAttr with the given user name and transpose count.") -+ .def_property_readonly("user_name", -+ [](MlirAttribute self) { -+ return toPyString(mpmdUserOriginAttrGetUserName(self)); -+ }) -+ .def_property_readonly("transpose_count", [](MlirAttribute self) { -+ return mpmdUserOriginAttrGetTransposeCount(self); -+ }); -+ -+ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -+ m, "OriginAttr", mpmdAttributeIsAOriginAttr) -+ .def_classmethod( -+ "get", -+ [](nb::object cls, const std::string& originLabel, MlirContext ctx) { -+ return cls(mpmdOriginAttrGet(ctx, toStringRef(originLabel))); -+ }, -+ nb::arg("cls"), nb::arg("origin_label"), -+ nb::arg("context").none() = nb::none(), -+ "Creates an OriginAttr with the given origin label.") -+ .def_property_readonly("origin_label", -+ [](MlirAttribute self) { -+ return toPyString(mpmdOriginAttrGetOriginLabel(self)); -+ }); -+} -+ -+} // namespace -+} // namespace mpmd -+} // namespace mlir -diff --git a/shardy/integrations/python/ir/mpmd_ops.td b/shardy/integrations/python/ir/mpmd_ops.td -new file mode 100644 -index 0000000..86d6dcf ---- /dev/null -+++ b/shardy/integrations/python/ir/mpmd_ops.td -@@ -0,0 +1,21 @@ -+/* Copyright 2025 The Shardy Authors. -+ -+Licensed under the Apache License, Version 2.0 (the "License"); -+you may not use this file except in compliance with the License. -+You may obtain a copy of the License at -+ -+ http://www.apache.org/licenses/LICENSE-2.0 -+ -+Unless required by applicable law or agreed to in writing, software -+distributed under the License is distributed on an "AS IS" BASIS, -+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-+See the License for the specific language governing permissions and -+limitations under the License. -+==============================================================================*/ -+ -+#ifndef SHARDY_INTEGRATIONS_PYTHON_MPMD_OPS -+#define SHARDY_INTEGRATIONS_PYTHON_MPMD_OPS -+ -+include "shardy/dialect/mpmd/ir/ops.td" -+ -+#endif -diff --git a/shardy/integrations/python/ir/sdy_module.cc b/shardy/integrations/python/ir/sdy_module.cc -index da451fa..318e14f 100644 ---- a/shardy/integrations/python/ir/sdy_module.cc -+++ b/shardy/integrations/python/ir/sdy_module.cc -@@ -28,8 +28,8 @@ limitations under the License. - #include "nanobind/stl/string.h" // IWYU pragma: keep - #include "nanobind/stl/variant.h" // IWYU pragma: keep - #include "nanobind/stl/vector.h" // IWYU pragma: keep --#include "shardy/integrations/c/attributes.h" --#include "shardy/integrations/c/dialect.h" -+#include "shardy/integrations/c/attributes_sdy.h" -+#include "shardy/integrations/c/dialect_sdy.h" - - namespace mlir { - namespace sdy { -diff --git a/shardy_mpmd_pybinds.patch b/shardy_mpmd_pybinds.patch -new file mode 100644 -index 0000000..58ce9e1 ---- /dev/null -+++ b/shardy_mpmd_pybinds.patch -@@ -0,0 +1,1908 @@ -+From 248452a6152e710a212611f3be9db01f7d3c927b Mon Sep 17 00:00:00 2001 -+From: tenstorrent -+Date: Sat, 16 Aug 2025 19:37:33 +0000 -+Subject: [PATCH] changes -+ -+--- -+ CMakeLists.txt | 2 + -+ shardy/dialect/mpmd/ir/CMakeLists.txt | 86 +++++++++ -+ shardy/dialect/mpmd/ir/dialect.cc | 19 +- -+ shardy/dialect/mpmd/transforms/CMakeLists.txt | 31 ++++ -+ .../mpmd/transforms/common/CMakeLists.txt | 87 +++++++++ -+ .../mpmd/transforms/common/call_rewrites.cc | 6 +- -+ .../mpmd/transforms/common/merge_fragments.cc | 4 +- -+ .../uniquify_function_inputs_outputs.cc | 6 +- -+ .../mpmd/transforms/export/CMakeLists.txt | 75 ++++++++ -+ .../mpmd/transforms/import/CMakeLists.txt | 124 +++++++++++++ -+ .../import/infer_mesh_assignment.cc | 85 +++++---- -+ 
.../mpmd/transforms/optimize/CMakeLists.txt | 69 ++++++++ -+ .../sharding_propagation/CMakeLists.txt | 41 +++++ -+ shardy/integrations/c/CMakeLists.txt | 39 ++++- -+ shardy/integrations/c/attributes_mpmd.cc | 126 +++++++++++++ -+ shardy/integrations/c/attributes_mpmd.h | 84 +++++++++ -+ .../c/{attributes.cc => attributes_sdy.cc} | 2 +- -+ .../c/{attributes.h => attributes_sdy.h} | 6 +- -+ shardy/integrations/c/dialect_mpmd.cc | 21 +++ -+ shardy/integrations/c/dialect_mpmd.h | 31 ++++ -+ .../c/{dialect.cc => dialect_sdy.cc} | 2 +- -+ .../c/{dialect.h => dialect_sdy.h} | 0 -+ shardy/integrations/c/passes_mpmd.cc | 22 +++ -+ shardy/integrations/c/passes_mpmd.h | 33 ++++ -+ .../c/{passes.cc => passes_sdy.cc} | 2 +- -+ .../integrations/c/{passes.h => passes_sdy.h} | 6 +- -+ shardy/integrations/python/ir/CMakeLists.txt | 29 +++ -+ shardy/integrations/python/ir/__init__.py | 26 ++- -+ .../integrations/python/ir/dialects/mpmd.py | 20 +++ -+ .../python/ir/dialects/mpmd_enums.td | 21 +++ -+ .../python/ir/dialects/mpmd_ops.td | 21 +++ -+ shardy/integrations/python/ir/mpmd.py | 20 +++ -+ shardy/integrations/python/ir/mpmd_enums.td | 21 +++ -+ shardy/integrations/python/ir/mpmd_module.cc | 165 ++++++++++++++++++ -+ shardy/integrations/python/ir/mpmd_ops.td | 21 +++ -+ shardy/integrations/python/ir/sdy_module.cc | 4 +- -+ 36 files changed, 1279 insertions(+), 78 deletions(-) -+ create mode 100644 shardy/dialect/mpmd/ir/CMakeLists.txt -+ create mode 100644 shardy/dialect/mpmd/transforms/CMakeLists.txt -+ create mode 100644 shardy/dialect/mpmd/transforms/common/CMakeLists.txt -+ create mode 100644 shardy/dialect/mpmd/transforms/export/CMakeLists.txt -+ create mode 100644 shardy/dialect/mpmd/transforms/import/CMakeLists.txt -+ create mode 100644 shardy/dialect/mpmd/transforms/optimize/CMakeLists.txt -+ create mode 100644 shardy/dialect/mpmd/transforms/sharding_propagation/CMakeLists.txt -+ create mode 100644 shardy/integrations/c/attributes_mpmd.cc -+ create mode 100644 
shardy/integrations/c/attributes_mpmd.h -+ rename shardy/integrations/c/{attributes.cc => attributes_sdy.cc} (99%) -+ rename shardy/integrations/c/{attributes.h => attributes_sdy.h} (98%) -+ create mode 100644 shardy/integrations/c/dialect_mpmd.cc -+ create mode 100644 shardy/integrations/c/dialect_mpmd.h -+ rename shardy/integrations/c/{dialect.cc => dialect_sdy.cc} (92%) -+ rename shardy/integrations/c/{dialect.h => dialect_sdy.h} (100%) -+ create mode 100644 shardy/integrations/c/passes_mpmd.cc -+ create mode 100644 shardy/integrations/c/passes_mpmd.h -+ rename shardy/integrations/c/{passes.cc => passes_sdy.cc} (94%) -+ rename shardy/integrations/c/{passes.h => passes_sdy.h} (86%) -+ create mode 100644 shardy/integrations/python/ir/dialects/mpmd.py -+ create mode 100644 shardy/integrations/python/ir/dialects/mpmd_enums.td -+ create mode 100644 shardy/integrations/python/ir/dialects/mpmd_ops.td -+ create mode 100644 shardy/integrations/python/ir/mpmd.py -+ create mode 100644 shardy/integrations/python/ir/mpmd_enums.td -+ create mode 100644 shardy/integrations/python/ir/mpmd_module.cc -+ create mode 100644 shardy/integrations/python/ir/mpmd_ops.td -+ -+diff --git a/CMakeLists.txt b/CMakeLists.txt -+index b73a2bd..c5cb085 100755 -+--- a/CMakeLists.txt -++++ b/CMakeLists.txt -+@@ -52,5 +52,7 @@ add_compile_options(-Wno-deprecated-declarations -Wno-unused-but-set-variable -W -+ add_subdirectory(shardy/common) -+ add_subdirectory(shardy/dialect/sdy/ir) -+ add_subdirectory(shardy/dialect/sdy/transforms) -++add_subdirectory(shardy/dialect/mpmd/ir) -++add_subdirectory(shardy/dialect/mpmd/transforms) -+ add_subdirectory(shardy/integrations/python/ir) -+ add_subdirectory(shardy/integrations/c) -+\ No newline at end of file -+diff --git a/shardy/dialect/mpmd/ir/CMakeLists.txt b/shardy/dialect/mpmd/ir/CMakeLists.txt -+new file mode 100644 -+index 0000000..83eccc3 -+--- /dev/null -++++ b/shardy/dialect/mpmd/ir/CMakeLists.txt -+@@ -0,0 +1,86 @@ -++# Shardy MLIR MPMD dialect. 
-++ -++set(LLVM_TARGET_DEFINITIONS dialect.td) -++mlir_tablegen(dialect.h.inc -gen-dialect-decls -dialect=mpmd) -++mlir_tablegen(dialect.cc.inc -gen-dialect-defs -dialect=mpmd) -++add_public_tablegen_target(MpmdDialectIncGen) -++add_dependencies(mlir-headers MpmdDialectIncGen) -++add_mlir_doc(dialect MpmdDialect src/autogen/md/Dialect/ -gen-dialect-doc) -++ -++set(LLVM_TARGET_DEFINITIONS canonicalization.td) -++mlir_tablegen(canonicalization.cc.inc -gen-rewriters) -++add_public_tablegen_target(MpmdCanonicalizationIncGen) -++add_dependencies(mlir-headers MpmdCanonicalizationIncGen) -++ -++set(LLVM_TARGET_DEFINITIONS ops.td) -++mlir_tablegen(ops.h.inc -gen-op-decls) -++mlir_tablegen(ops.cc.inc -gen-op-defs) -++add_public_tablegen_target(MpmdOpsIncGen) -++add_dependencies(mlir-headers MpmdOpsIncGen) -++ -++set(LLVM_TARGET_DEFINITIONS types.td) -++mlir_tablegen(types.h.inc -gen-typedef-decls) -++mlir_tablegen(types.cc.inc -gen-typedef-defs) -++add_public_tablegen_target(MpmdTypesIncGen) -++add_dependencies(mlir-headers MpmdTypesIncGen) -++ -++set(LLVM_TARGET_DEFINITIONS attrs.td) -++mlir_tablegen(attrs.h.inc -gen-attrdef-decls) -++mlir_tablegen(attrs.cc.inc -gen-attrdef-defs) -++add_public_tablegen_target(MpmdAttrsIncGen) -++add_dependencies(mlir-headers MpmdAttrsIncGen) -++ -++set(LLVM_TARGET_DEFINITIONS enums.td) -++mlir_tablegen(enums.h.inc -gen-enum-decls) -++mlir_tablegen(enums.cc.inc -gen-enum-defs) -++add_public_tablegen_target(MpmdEnumsIncGen) -++add_dependencies(mlir-headers MpmdEnumsIncGen) -++ -++add_mlir_dialect_library(MpmdDialect -++ dialect.cc -++ utils.cc -++ -++ DEPENDS -++ MpmdDialectIncGen -++ MpmdOpsIncGen -++ MpmdAttrsIncGen -++ MpmdEnumsIncGen -++ MpmdTypesIncGen -++ MpmdCanonicalizationIncGen -++ -++ LINK_LIBS PUBLIC -++ LLVMSupport -++ MLIRBytecodeOpInterface -++ MLIRFuncDialect -++ MLIRIR -++ MLIRInferTypeOpInterface -++ MLIRTransformUtils -++ MLIRShapeDialect -++ MLIRSideEffectInterfaces -++ MLIRSupport -++ StablehloAssemblyFormat -++ 
StablehloBase -++ StablehloOps -++ StablehloTypeInference -++) -++ -++target_include_directories(MpmdDialect INTERFACE -++ $ -++ $ -++) -++ -++add_mlir_dialect_library(MpmdRegister -++ register.cc -++ -++ LINK_LIBS PUBLIC -++ MpmdDialect -++ MLIRFuncDialect -++ MLIRFuncAllExtensions -++ MLIRIR -++ StablehloOps -++) -++ -++target_include_directories(MpmdRegister INTERFACE -++ $ -++ $ -++) -+diff --git a/shardy/dialect/mpmd/ir/dialect.cc b/shardy/dialect/mpmd/ir/dialect.cc -+index af013d6..f125e3c 100644 -+--- a/shardy/dialect/mpmd/ir/dialect.cc -++++ b/shardy/dialect/mpmd/ir/dialect.cc -+@@ -878,9 +878,9 @@ FragmentOp CreateMeshFragmentWithBody( -+ // Only user defined fragments can be assigned to a stage and any fragment -+ // created by the compiler is considered to be an inferred fragment. -+ // Therefore, the created fragment isn't assigned to a stage. -+- FragmentOp fragment_op = FragmentOp::create(builder, loc, result_types, -+- tensors, origin_attr, mesh_name, -+- /*stage_id=*/IntegerAttr()); -++ FragmentOp fragment_op = builder.create( -++ loc, result_types, tensors, origin_attr, mesh_name, -++ /*stage_id=*/IntegerAttr()); -+ Block& fragment_block = fragment_op.getRegion().emplaceBlock(); -+ sdy::MeshAttr mesh_attr = GetMeshOrFail(fragment_op, mesh_name); -+ -+@@ -892,8 +892,7 @@ FragmentOp CreateMeshFragmentWithBody( -+ fragment_block.args_end()); -+ -+ OpBuilder block_builder = OpBuilder::atBlockBegin(&fragment_block); -+- ReturnOp::create(block_builder, loc, -+- body_populator(arguments, block_builder)); -++ block_builder.create(loc, body_populator(arguments, block_builder)); -+ return fragment_op; -+ } -+ } // namespace -+@@ -1345,9 +1344,9 @@ ForOp ForOp::create(Location loc, ValueRange tensors, uint32_t iterations, -+ OpBuilder& builder, ForOpBodyPopulator body_populator, -+ uint32_t unroll_factor) { -+ TypeRange result_types = tensors.getTypes(); -+- auto op = ForOp::create( -+- builder, loc, result_types, tensors, iterations, -+- unroll_factor == 1 ? 
nullptr : builder.getUI32IntegerAttr(unroll_factor)); -++ auto op = builder.create( -++ loc, result_types, tensors, iterations, -++ unroll_factor == 1 ? nullptr : builder.getUI32IntegerAttr(unroll_factor)); -+ -+ Block& block = op.getRegion().emplaceBlock(); -+ for (Value operand : tensors) { -+@@ -1360,8 +1359,8 @@ ForOp ForOp::create(Location loc, ValueRange tensors, uint32_t iterations, -+ ArrayRef args(block.args_begin(), block.args_end()); -+ -+ OpBuilder block_builder = OpBuilder::atBlockBegin(&block); -+- ReturnOp::create( -+- block_builder, loc, -++ block_builder.create( -++ loc, -+ body_populator(args.drop_back(), /*index=*/args.back(), block_builder)); -+ return op; -+ } -+diff --git a/shardy/dialect/mpmd/transforms/CMakeLists.txt b/shardy/dialect/mpmd/transforms/CMakeLists.txt -+new file mode 100644 -+index 0000000..0728af8 -+--- /dev/null -++++ b/shardy/dialect/mpmd/transforms/CMakeLists.txt -+@@ -0,0 +1,31 @@ -++# Shardy MLIR MPMD Transforms Passes -++ -++add_subdirectory(common) -++add_subdirectory(export) -++add_subdirectory(import) -++add_subdirectory(optimize) -++add_subdirectory(sharding_propagation) -++ -++add_mlir_library(MpmdTransformsPasses -++ passes.cc -++ -++ DEPENDS -++ MpmdTransformsCommonPasses -++ MpmdTransformsExportPasses -++ MpmdTransformsImportPasses -++ MpmdTransformsOptimizePasses -++ MpmdTransformsShardingPropagationPasses -++ -++ LINK_LIBS PUBLIC -++ MLIRPass -++ MpmdTransformsCommonPasses -++ MpmdTransformsExportPasses -++ MpmdTransformsImportPasses -++ MpmdTransformsOptimizePasses -++ MpmdTransformsShardingPropagationPasses -++) -++ -++target_include_directories(MpmdTransformsPasses INTERFACE -++ $ -++ $ -++) -+diff --git a/shardy/dialect/mpmd/transforms/common/CMakeLists.txt b/shardy/dialect/mpmd/transforms/common/CMakeLists.txt -+new file mode 100644 -+index 0000000..e26d2aa -+--- /dev/null -++++ b/shardy/dialect/mpmd/transforms/common/CMakeLists.txt -+@@ -0,0 +1,87 @@ -++# Shardy MLIR MPMD Transforms Common -++ 
-++set(LLVM_TARGET_DEFINITIONS passes.td) -++mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdCommon) -++add_public_tablegen_target(MpmdTransformsCommonPassesIncGen) -++add_dependencies(mlir-headers MpmdTransformsCommonPassesIncGen) -++ -++add_mlir_library(MpmdTransformsCommonUtils -++ utils.cc -++ -++ DEPENDS -++ MpmdDialect -++ -++ LINK_LIBS PUBLIC -++ MpmdDialect -++ LLVMSupport -++ MLIRFuncDialect -++ MLIRIR -++ MLIRTransformUtils -++ MLIRSupport -++ MLIRPass -++ MLIRTransforms -++) -++ -++add_mlir_library(MpmdTransformsCommonSimplifyRegionOpBase -++ simplify_region_op_base.cc -++ -++ DEPENDS -++ MpmdTransformsCommonUtils -++ -++ LINK_LIBS PUBLIC -++ MpmdTransformsCommonUtils -++ LLVMSupport -++ MLIRIR -++ MLIRSideEffectInterfaces -++ MLIRSupport -++ MLIRTransformUtils -++) -++ -++add_mlir_library(MpmdTransformsCommonDistributedFunctionPass -++ distributed_function_pass.cc -++ -++ DEPENDS -++ MpmdDialect -++ -++ LINK_LIBS PUBLIC -++ MpmdDialect -++ MLIRFuncDialect -++ MLIRPass -++) -++ -++add_mlir_library(MpmdTransformsCommonPasses -++ absorb_inferred_fragments.cc -++ call_rewrites.cc -++ copy_constants.cc -++ fragment_dce.cc -++ fragment_dedup.cc -++ merge_fragments.cc -++ merge_transfers.cc -++ remove_transfer_cycles.cc -++ rule_based_merge.cc -++ split_bwd_fragments.cc -++ uniquify_function_inputs_outputs.cc -++ unroll_for_loops.cc -++ -++ DEPENDS -++ MpmdTransformsCommonPassesIncGen -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonUtils -++ -++ LINK_LIBS PUBLIC -++ LLVMSupport -++ StablehloOps -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonUtils -++ MLIRAnalysis -++ MLIRDataLayoutInterfaces -++ MLIRFuncDialect -++ MLIRIR -++ MLIRTransformUtils -++ MLIRLoopLikeInterface -++ MLIRPass -++ MLIRRewrite -++ MLIRSideEffectInterfaces -++ MLIRSupport -++ MLIRTransforms -++) -+diff --git a/shardy/dialect/mpmd/transforms/common/call_rewrites.cc b/shardy/dialect/mpmd/transforms/common/call_rewrites.cc -+index 
67022ca..d1da0a6 100644 -+--- a/shardy/dialect/mpmd/transforms/common/call_rewrites.cc -++++ b/shardy/dialect/mpmd/transforms/common/call_rewrites.cc -+@@ -277,9 +277,9 @@ class EraseUnusedCalleeBlockArgumentsPass -+ // Alas, we cannot directly erase results of an op, so we need to create -+ // a new call op, and use it to replace the old one. -+ rewriter.setInsertionPoint(call_op); -+- auto new_call_op = -+- CallOp::create(rewriter, call_op.getLoc(), result_types, -+- call_op->getOperands(), call_op.getCalleeAttr()); -++ auto new_call_op = rewriter.create(call_op.getLoc(), result_types, -++ call_op->getOperands(), -++ call_op.getCalleeAttr()); -+ new_call_op->setDiscardableAttrs(call_op->getDiscardableAttrDictionary()); -+ for (auto [new_result, old_result_index] : -+ llvm::zip_equal(new_call_op.getResults(), kept_results.set_bits())) { -+diff --git a/shardy/dialect/mpmd/transforms/common/merge_fragments.cc b/shardy/dialect/mpmd/transforms/common/merge_fragments.cc -+index 36ed7bb..098c579 100644 -+--- a/shardy/dialect/mpmd/transforms/common/merge_fragments.cc -++++ b/shardy/dialect/mpmd/transforms/common/merge_fragments.cc -+@@ -87,7 +87,7 @@ bool TypeHasOneElement(Type type) { -+ -+ // Returns true if `op` is an inter-mesh TransferOp whose global type has only -+ // one element. 
-+-bool IsNonScalarInterMeshTransfer(Operation* op) { -++[[ maybe_unused ]] bool IsNonScalarInterMeshTransfer(Operation* op) { -+ TransferOp transfer_op = DynCastInterMeshTransfer(op); -+ return transfer_op && -+ !TypeHasOneElement(transfer_op.getType().getGlobalTensorType()); -+@@ -334,7 +334,7 @@ FailureOr MergeFragmentBasePass::MergeFragmentsRewrite( -+ producer_op.getMeshNameAttr(), -+ /*stage_id=*/GetMergedStageIdAttribute(producer_op, mergeable_user)); -+ -+- for (const auto [attr_name, attr] : merged_attributes) { -++ for (const auto &[attr_name, attr] : merged_attributes) { -+ merged_fragment->setAttr(attr_name, attr); -+ } -+ -+diff --git a/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc b/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc -+index 7aaa133..f76f2e1 100644 -+--- a/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc -++++ b/shardy/dialect/mpmd/transforms/common/uniquify_function_inputs_outputs.cc -+@@ -71,8 +71,8 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op, -+ } -+ -+ auto loc = return_op->getLoc(); -+- auto fragment_op = FragmentOp::create( -+- builder, loc, fragment_return_types, fragment_operands, -++ auto fragment_op = builder.create( -++ loc, fragment_return_types, fragment_operands, -+ /*user_origin=*/ArrayAttr::get(builder.getContext(), {}), -+ /*mesh_name=*/mesh_name, /*stage_id=*/IntegerAttr()); -+ Block& fragment_block = fragment_op.getRegion().emplaceBlock(); -+@@ -98,7 +98,7 @@ void CreateReturnFragmentForMesh(StringRef mesh_name, Operation* return_op, -+ } -+ } -+ auto block_builder = OpBuilder::atBlockEnd(&fragment_block); -+- ReturnOp::create(block_builder, loc, returned_values); -++ block_builder.create(loc, returned_values); -+ } -+ -+ // Replaces the return values of the function with transfer ops. 
-+ -+diff --git a/shardy/dialect/mpmd/transforms/export/CMakeLists.txt b/shardy/dialect/mpmd/transforms/export/CMakeLists.txt -+new file mode 100644 -+index 0000000..a42701e -+--- /dev/null -++++ b/shardy/dialect/mpmd/transforms/export/CMakeLists.txt -+@@ -0,0 +1,75 @@ -++# Shardy MLIR MPMD Transform Export Passes and Pipeline -++ -++set(LLVM_TARGET_DEFINITIONS passes.td) -++mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdExport) -++add_public_tablegen_target(MpmdTransformsExportPassesIncGen) -++add_dependencies(mlir-headers MpmdTransformsExportPassesIncGen) -++ -++add_mlir_library(MpmdTransformsExportUtils -++ utils.cc -++ -++ DEPENDS -++ MpmdDialect -++ -++ LINK_LIBS PUBLIC -++ MpmdDialect -++ MLIRAnalysis -++ MLIRFuncDialect -++ MLIRIR -++ MLIRSupport -++) -++ -++add_mlir_library(MpmdTransformsExportNamingUtils -++ naming_utils.cc -++ -++ DEPENDS -++ MpmdDialect -++ MpmdTransformsCommonUtils -++ -++ LINK_LIBS PUBLIC -++ MpmdDialect -++ MpmdTransformsCommonUtils -++ MLIRFuncDialect -++ MLIRPass -++ LLVMSupport -++ MLIRIR -++ MLIRSupport -++) -++ -++add_mlir_library(MpmdTransformsExportPasses -++ export_pipeline.cc -++ lower_to_fragment_calls.cc -++ mark_aliasing_and_donation.cc -++ mark_fragment_reserved_memory.cc -++ mark_input_output_with_layouts.cc -++ mark_offloaded_input_output.cc -++ reschedule_ops.cc -++ -++ DEPENDS -++ MpmdDialect -++ MpmdTransformsExportPassesIncGen -++ MpmdTransformsExportNamingUtils -++ MpmdTransformsExportUtils -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonPasses -++ MpmdTransformsCommonUtils -++ SdyDialect -++ -++ LINK_LIBS PUBLIC -++ MpmdTransformsExportNamingUtils -++ MpmdTransformsExportUtils -++ MpmdDialect -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonPasses -++ MpmdTransformsCommonUtils -++ SdyDialect -++ LLVMSupport -++ MLIRAnalysis -++ MLIRFuncDialect -++ MLIRIR -++ MLIRPass -++ MLIRSupport -++ MLIRTransformUtils -++ MLIRTransforms -++ StablehloOps -++) -+diff --git 
a/shardy/dialect/mpmd/transforms/import/CMakeLists.txt b/shardy/dialect/mpmd/transforms/import/CMakeLists.txt -+new file mode 100644 -+index 0000000..a45baed -+--- /dev/null -++++ b/shardy/dialect/mpmd/transforms/import/CMakeLists.txt -+@@ -0,0 +1,124 @@ -++# Shardy MLIR MPMD Transforms Import Passes and Pipeline -++ -++set(LLVM_TARGET_DEFINITIONS passes.td) -++mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdImport) -++add_public_tablegen_target(MpmdTransformsImportPassesIncGen) -++add_dependencies(mlir-headers MpmdTransformsImportPassesIncGen) -++ -++add_mlir_library(MpmdTransformsImportMeshAssignmentMap -++ mesh_assignment_map.cc -++ -++ LINK_LIBS PUBLIC -++ LLVMSupport -++) -++ -++add_mlir_library(MpmdTransformsImportMeshInferenceOrigins -++ mesh_inference_origins.cc -++ -++ DEPENDS -++ MpmdDialect -++ -++ LINK_LIBS PUBLIC -++ MpmdDialect -++ LLVMSupport -++ MLIRIR -++ MLIRPass -++ MLIRSupport -++) -++ -++add_mlir_library(MpmdTransformsImportMeshesWithOrigins -++ meshes_with_origins.cc -++ -++ DEPENDS -++ MpmdDialect -++ MpmdTransformsImportMeshInferenceOrigins -++ -++ LINK_LIBS PUBLIC -++ MpmdDialect -++ MpmdTransformsImportMeshInferenceOrigins -++ LLVMSupport -++ MLIRIR -++ MLIRSupport -++) -++ -++ -++add_mlir_library(MpmdTransformsImportMeshInferenceUtils -++ mesh_inference_utils.cc -++ -++ DEPENDS -++ MpmdTransformsImportMeshesWithOrigins -++ MpmdDialect -++ MpmdTransformsCommonUtils -++ SdyDialect -++ -++ LINK_LIBS PUBLIC -++ MpmdTransformsImportMeshesWithOrigins -++ MpmdDialect -++ MpmdTransformsCommonUtils -++ SdyDialect -++ LLVMSupport -++ MLIRFuncDialect -++ MLIRIR -++ MLIRPass -++ MLIRSupport -++) -++ -++add_mlir_library(MpmdTransformsImportShardingConstraints -++ sharding_constraints.cc -++ -++ LINK_LIBS PUBLIC -++ LLVMSupport -++) -++ -++add_mlir_library(MpmdTransformsImportPasses -++ copy_topology_from_main.cc -++ enforce_equisharding.cc -++ import_pipeline.cc -++ infer_mesh_assignment.cc -++ infer_mesh_validation.cc -++ 
insert_nameless_clones_of_negligible_ops.cc -++ introduce_transfers.cc -++ map_input_output_to_mesh.cc -++ map_named_ops_to_mpmd_ops.cc -++ simplify_named_computations.cc -++ validate_named_ops_in_mpmd_func.cc -++ -++ DEPENDS -++ MpmdTransformsImportMeshAssignmentMap -++ MpmdTransformsImportMeshInferenceOrigins -++ MpmdTransformsImportMeshInferenceUtils -++ MpmdTransformsImportMeshesWithOrigins -++ MpmdTransformsImportPassesIncGen -++ MpmdTransformsImportShardingConstraints -++ MpmdDialect -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonPasses -++ MpmdTransformsCommonUtils -++ MpmdTransformsCommonSimplifyRegionOpBase -++ SdyDialect -++ -++ LINK_LIBS PUBLIC -++ MpmdTransformsImportMeshAssignmentMap -++ MpmdTransformsImportMeshInferenceOrigins -++ MpmdTransformsImportMeshInferenceUtils -++ MpmdTransformsImportMeshesWithOrigins -++ MpmdTransformsImportShardingConstraints -++ MpmdDialect -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonPasses -++ MpmdTransformsCommonUtils -++ MpmdTransformsCommonSimplifyRegionOpBase -++ SdyDialect -++ LLVMSupport -++ MLIRFuncDialect -++ MLIRIR -++ MLIRPass -++ MLIRRewrite -++ MLIRSideEffectInterfaces -++ MLIRSupport -++ MLIRTransforms -++ MLIRTransformUtils -++ StablehloOps -++ StablehloPasses -++ StablehloOptimizationPasses -++) -+\ No newline at end of file -+diff --git a/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc b/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc -+index 45f6523..9ef9270 100644 -+--- a/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc -++++ b/shardy/dialect/mpmd/transforms/import/infer_mesh_assignment.cc -+@@ -462,8 +462,8 @@ class LowerMpmdReducePattern final : public OpRewritePattern { -+ if (reduced_val.getType() == user_type) { -+ transferred_intermediates.push_back(reduced_val); -+ } else { -+- transferred_intermediates.push_back(TransferOp::create( -+- rewriter, reduced_val.getLoc(), user_type, reduced_val)); -++ 
transferred_intermediates.push_back(rewriter.create( -++ reduced_val.getLoc(), user_type, reduced_val)); -+ } -+ } -+ -+@@ -1096,8 +1096,7 @@ void AssignInputAndOutputToMesh(FuncOp func, BlockArgument input_arg, -+ // Assign the output to the mesh. -+ if (!isa(return_operand.get().getType())) { -+ rewriter.setInsertionPoint(return_operand.getOwner()); -+- return_operand.set(AssignOp::create( -+- rewriter, -++ return_operand.set(rewriter.create( -+ GetResultInfoLoc(func, return_operand.getOperandNumber()) -+ .value_or(return_operand.get().getLoc()), -+ return_operand.get(), mesh_name, mesh_attr, kIoConstraintOutputOrigin)); -+@@ -1109,8 +1108,8 @@ void AssignInputAndOutputToMesh(FuncOp func, BlockArgument input_arg, -+ input_arg.setType(MeshTensorType::getFullyReplicated( -+ input_arg.getContext(), mesh_name, mesh_attr, -+ cast(input_arg.getType()))); -+- auto unassign = UnassignOp::create(rewriter, input_arg.getLoc(), input_arg, -+- kIoConstraintInputOrigin); -++ auto unassign = rewriter.create(input_arg.getLoc(), input_arg, -++ kIoConstraintInputOrigin); -+ rewriter.replaceAllUsesExcept(input_arg, unassign, unassign); -+ } -+ } -+@@ -1293,8 +1292,7 @@ class InferMeshAssignMeshForFuncLeavesPass -+ } -+ mesh_name = first_mesh_name; -+ } -+- return_op_operand.set(AssignOp::create( -+- builder, -++ return_op_operand.set(builder.create( -+ GetResultInfoLoc(func, return_op_operand.getOperandNumber()) -+ .value_or(return_operand.getLoc()), -+ return_operand, *mesh_name, GetMeshByName(meshes_by_name, *mesh_name), -+@@ -1405,8 +1403,8 @@ class InferMeshAssignMeshForFuncLeavesPass -+ rewriter.setInsertionPointAfter(op); -+ sdy::MeshAttr mesh = GetMeshByName(meshes_by_name, mesh_name); -+ for (Value res : op->getResults()) { -+- AssignOp::create(rewriter, op->getLoc(), res, mesh_name, mesh, -+- kInferredUnusedOrigin); -++ rewriter.create(op->getLoc(), res, mesh_name, mesh, -++ kInferredUnusedOrigin); -+ } -+ -+ ClearUseSet(op); -+@@ -1491,10 +1489,10 @@ class 
InferMeshAssignMeshForFuncLeavesPass -+ preferred_meshes.GetPrioritizedMeshName().value_or(first_mesh_name); -+ } -+ Value operand_val = operand.get(); -+- AssignOp assign = AssignOp::create( -+- builder, operand_val.getLoc(), operand_val, *mesh_name, -++ AssignOp assign = builder.create( -++ operand_val.getLoc(), operand_val, *mesh_name, -+ GetMeshByName(meshes_by_name, *mesh_name), TerminalNodesOrigin(op)); -+- operand.set(UnassignOp::create(builder, operand_val.getLoc(), assign)); -++ operand.set(builder.create(operand_val.getLoc(), assign)); -+ } -+ } -+ -+@@ -1537,8 +1535,8 @@ void ConvertConcatReduceOp(Operation* op, RewriterBase& rewriter) { -+ SmallVector reshaped_operands; -+ reshaped_operands.reserve(concat.getOperands().size()); -+ for (Value operand : concat.getOperands()) { -+- auto reshape = stablehlo::ReshapeOp::create( -+- rewriter, operand.getLoc(), reduce->getResultTypes().front(), operand); -++ auto reshape = rewriter.create( -++ operand.getLoc(), reduce->getResultTypes().front(), operand); -+ if (operand.getDefiningOp()) { -+ reshape->setDiscardableAttrs( -+ operand.getDefiningOp()->getDiscardableAttrDictionary()); -+@@ -1811,8 +1809,8 @@ void AssignCalleeFuncResultsUsingAnalysis( -+ // meshes, we copy it such that each result corresponds to a single mesh. 
-+ for (auto [i, mesh_name] : llvm::enumerate(mesh_names.getArrayRef())) { -+ auto assign = -+- AssignOp::create(rewriter, return_val.getLoc(), return_val, mesh_name, -+- GetMeshByName(meshes_by_name, mesh_name)); -++ rewriter.create(return_val.getLoc(), return_val, mesh_name, -++ GetMeshByName(meshes_by_name, mesh_name)); -+ if (i == 0) { -+ new_operands[res_num] = assign; -+ } else { -+@@ -1891,10 +1889,10 @@ void AssignCalleeFuncArgsToAssignUsers( -+ UnassignOp unassign_op; -+ if (i == 0) { -+ arg.setType(mesh_type); -+- unassign_op = UnassignOp::create(rewriter, arg.getLoc(), arg); -++ unassign_op = rewriter.create(arg.getLoc(), arg); -+ } else { -+- unassign_op = UnassignOp::create( -+- rewriter, arg.getLoc(), body.addArgument(mesh_type, arg.getLoc())); -++ unassign_op = rewriter.create( -++ arg.getLoc(), body.addArgument(mesh_type, arg.getLoc())); -+ } -+ -+ if (auto users_it = assign_users_by_mesh_name.find(mesh_name); -+@@ -1928,25 +1926,24 @@ void RewriteAccordingToUpdatedCallee(CallOp call_op, RewriterBase& rewriter) { -+ continue; -+ } -+ SDY_CHECK(isa(call_body.getArgument(arg_num).getType())); -+- new_operands[arg_num] = -+- AssignOp::create(rewriter, operand.getLoc(), -+- call_body.getArgument(arg_num).getType(), operand); -++ new_operands[arg_num] = rewriter.create( -++ operand.getLoc(), call_body.getArgument(arg_num).getType(), operand); -+ -+ if (auto copies = -+ callee.getArgAttrOfType(arg_num, kMpmdCopied)) { -+ for (int64_t cloned_arg_index : copies.asArrayRef()) { -+ SDY_CHECK(isa( -+ call_body.getArgument(cloned_arg_index).getType())); -+- new_operands[cloned_arg_index] = AssignOp::create( -+- rewriter, operand.getLoc(), -+- call_body.getArgument(cloned_arg_index).getType(), operand); -++ new_operands[cloned_arg_index] = rewriter.create( -++ operand.getLoc(), call_body.getArgument(cloned_arg_index).getType(), -++ operand); -+ } -+ } -+ } -+ -+ // Create the new call and copy attrs over. 
-+- auto new_call_op = CallOp::create( -+- rewriter, call_op.getLoc(), call_body.getTerminator()->getOperandTypes(), -++ auto new_call_op = rewriter.create( -++ call_op.getLoc(), call_body.getTerminator()->getOperandTypes(), -+ new_operands, call_op.getCalleeAttr()); -+ new_call_op->setDiscardableAttrs(call_op->getDiscardableAttrDictionary()); -+ -+@@ -1972,9 +1969,8 @@ void RewriteAccordingToUpdatedCallee(CallOp call_op, RewriterBase& rewriter) { -+ SDY_CHECK(arg_num_it != type_to_arg_num.end()) -+ << "Argument number for type " << debugString(assign_user.getType()) -+ << " not found"; -+- assign_user.setOperand( -+- UnassignOp::create(rewriter, assign_user.getLoc(), -+- new_call_op.getResult(arg_num_it->second))); -++ assign_user.setOperand(rewriter.create( -++ assign_user.getLoc(), new_call_op.getResult(arg_num_it->second))); -+ } -+ } -+ } -+@@ -2061,7 +2057,7 @@ bool AssignEntrypointFuncArgsToAssignUsers(FuncOp entrypoint_func, -+ cast(arg.getType()), -+ memory_kind)); -+ -+- UnassignOp unassign_op = UnassignOp::create(rewriter, arg.getLoc(), arg); -++ UnassignOp unassign_op = rewriter.create(arg.getLoc(), arg); -+ rewriter.replaceAllUsesExcept(arg, unassign_op, unassign_op); -+ } -+ return true; -+@@ -2187,12 +2183,12 @@ void AbsorbMeshlessProducer(FragmentOp consumer, Operation* op, -+ } -+ rewriter.setInsertionPoint(consumer); -+ for (Value operand : op_operands_and_free_vars) { -+- new_consumer_operands.push_back( -+- AssignOp::create(rewriter, operand.getLoc(), -+- MeshTensorType::getFullyReplicated( -+- operand.getContext(), mesh_name, mesh_attr, -+- cast(operand.getType())), -+- operand)); -++ new_consumer_operands.push_back(rewriter.create( -++ operand.getLoc(), -++ MeshTensorType::getFullyReplicated( -++ operand.getContext(), mesh_name, mesh_attr, -++ cast(operand.getType())), -++ operand)); -+ } -+ consumer->setOperands(new_consumer_operands); -+ } -+@@ -2316,8 +2312,8 @@ void RewriteForOpTerminator( -+ SDY_CHECK_LE(mesh_names.size(), 1) -+ << 
"Multiple mesh names found for return value"; -+ -+- new_operands.push_back(AssignOp::create( -+- rewriter, return_val.getLoc(), return_val, mesh_names[0], -++ new_operands.push_back(rewriter.create( -++ return_val.getLoc(), return_val, mesh_names[0], -+ GetMeshByName(meshes_by_name, mesh_names[0]))); -+ } -+ -+@@ -2382,7 +2378,7 @@ void RewriteForOpArgsAndTypes( -+ arg.getContext(), mesh_names[0], -+ GetMeshByName(meshes_by_name, mesh_names[0]), local_type); -+ arg.setType(mesh_type); -+- UnassignOp unassign_op = UnassignOp::create(rewriter, arg.getLoc(), arg); -++ UnassignOp unassign_op = rewriter.create(arg.getLoc(), arg); -+ -+ if (auto users_it = assign_users_by_mesh_name.find(mesh_names[0]); -+ users_it != assign_users_by_mesh_name.end()) { -+@@ -2414,9 +2410,8 @@ void RewriteForOpOperands(ForOp for_op, RewriterBase& rewriter) { -+ new_operands[arg_num] = operand; -+ continue; -+ } -+- new_operands[arg_num] = -+- AssignOp::create(rewriter, operand.getLoc(), -+- for_body.getArgument(arg_num).getType(), operand); -++ new_operands[arg_num] = rewriter.create( -++ operand.getLoc(), for_body.getArgument(arg_num).getType(), operand); -+ } -+ -+ for_op->setOperands(new_operands); -+@@ -2430,7 +2425,7 @@ void RewriteForOpResults(ForOp for_op, RewriterBase& rewriter) { -+ for (Operation* user : res.getUsers()) { -+ if (auto assign_user = dyn_cast(user)) { -+ assign_user.setOperand( -+- UnassignOp::create(rewriter, assign_user.getLoc(), res)); -++ rewriter.create(assign_user.getLoc(), res)); -+ } -+ } -+ } -+diff --git a/shardy/dialect/mpmd/transforms/optimize/CMakeLists.txt b/shardy/dialect/mpmd/transforms/optimize/CMakeLists.txt -+new file mode 100644 -+index 0000000..ecf3d7c -+--- /dev/null -++++ b/shardy/dialect/mpmd/transforms/optimize/CMakeLists.txt -+@@ -0,0 +1,69 @@ -++# Shardy MLIR MPMD Transforms Optimize -++ -++set(LLVM_TARGET_DEFINITIONS passes.td) -++mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdOptimize) 
-++add_public_tablegen_target(MpmdTransformsOptimizePassesIncGen) -++add_dependencies(mlir-headers MpmdTransformsOptimizePassesIncGen) -++ -++add_mlir_library(MpmdTransformsOptimizeUtils -++ utils.cc -++ -++ DEPENDS -++ MpmdDialect -++ MpmdTransformsCommonUtils -++ -++ LINK_LIBS PUBLIC -++ MpmdDialect -++ MpmdTransformsCommonUtils -++ LLVMSupport -++ MLIRIR -++ MLIRSupport -++) -++ -++add_mlir_library(MpmdTransformsOptimizePipelineSchedule -++ pipeline_schedule.cc -++ -++ DEPENDS -++ MpmdTransformsOptimizeUtils -++ MpmdDialect -++ MpmdTransformsCommonUtils -++ -++ LINK_LIBS PUBLIC -++ MpmdTransformsOptimizeUtils -++ MpmdDialect -++ MpmdTransformsCommonUtils -++ LLVMSupport -++ MLIRIR -++ MLIRSupport -++) -++ -++add_mlir_library(MpmdTransformsOptimizePasses -++ optimize_pipeline.cc -++ remat_fragment.cc -++ scheduler.cc -++ -++ DEPENDS -++ MpmdTransformsOptimizePassesIncGen -++ MpmdTransformsOptimizePipelineSchedule -++ MpmdTransformsOptimizeUtils -++ MpmdDialect -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonPasses -++ MpmdTransformsCommonUtils -++ -++ LINK_LIBS PUBLIC -++ MpmdTransformsOptimizePipelineSchedule -++ MpmdTransformsOptimizeUtils -++ MpmdDialect -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonPasses -++ MpmdTransformsCommonUtils -++ LLVMSupport -++ MLIRAnalysis -++ MLIRFuncDialect -++ MLIRIR -++ MLIRPass -++ MLIRSupport -++ MLIRTransforms -++ MLIRTransformUtils -++) -+diff --git a/shardy/dialect/mpmd/transforms/sharding_propagation/CMakeLists.txt b/shardy/dialect/mpmd/transforms/sharding_propagation/CMakeLists.txt -+new file mode 100644 -+index 0000000..a8a8b05 -+--- /dev/null -++++ b/shardy/dialect/mpmd/transforms/sharding_propagation/CMakeLists.txt -+@@ -0,0 +1,41 @@ -++# Shardy MLIR MPMD Transforms Sharding Propagation -++ -++set(LLVM_TARGET_DEFINITIONS passes.td) -++mlir_tablegen(passes.h.inc -gen-pass-decls -name=MpmdShardingPropagation) 
-++add_public_tablegen_target(MpmdTransformsShardingPropagationPassesIncGen) -++add_dependencies(mlir-headers MpmdTransformsShardingPropagationPassesIncGen) -++ -++add_mlir_library(MpmdTransformsShardingPropagationPasses -++ convert_sdy_constants.cc -++ convert_sdy_shardings_to_mpmd_types.cc -++ enforce_user_shardings.cc -++ extract_reshards_from_inter_mesh_transfers.cc -++ sharding_propagation_pipeline.cc -++ -++ DEPENDS -++ MpmdTransformsShardingPropagationPassesIncGen -++ MpmdDialect -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonPasses -++ MpmdTransformsCommonUtils -++ SdyDialect -++ SdyExplicitReshardsUtil -++ SdyTransformsPropagationPasses -++ -++ LINK_LIBS PUBLIC -++ MpmdDialect -++ MpmdTransformsCommonDistributedFunctionPass -++ MpmdTransformsCommonPasses -++ MpmdTransformsCommonUtils -++ SdyDialect -++ SdyExplicitReshardsUtil -++ SdyTransformsPropagationPasses -++ LLVMSupport -++ MLIRFuncDialect -++ MLIRIR -++ MLIRPass -++ MLIRRewrite -++ MLIRSupport -++ MLIRTransformUtils -++ StablehloOps -++) -+diff --git a/shardy/integrations/c/CMakeLists.txt b/shardy/integrations/c/CMakeLists.txt -+index fdd50c4..1c3a624 100644 -+--- a/shardy/integrations/c/CMakeLists.txt -++++ b/shardy/integrations/c/CMakeLists.txt -+@@ -1,8 +1,39 @@ -+ add_mlir_public_c_api_library(SdyCAPI -+ PARTIAL_SOURCES_INTENDED -+- attributes.cc -+- dialect.cc -+- passes.cc -++ attributes_sdy.cc -++ dialect_sdy.cc -++ passes_sdy.cc -++ -++ DEPENDS -++ SdyDialect -++ SdyTransformsPasses -++ -++ LINK_LIBS PUBLIC -++ LLVMSupport -++ MLIRBytecodeOpInterface -++ MLIRFuncDialect -++ MLIRIR -++ MLIRInferTypeOpInterface -++ MLIRTransformUtils -++ MLIRShapeDialect -++ MLIRSideEffectInterfaces -++ MLIRSupport -++ StablehloAssemblyFormat -++ StablehloOps -++ StablehloTypeInference -++ SdyDialect -++ SdyTransformsPasses -++) -++ -++add_mlir_public_c_api_library(MpmdCAPI -++ PARTIAL_SOURCES_INTENDED -++ passes_mpmd.cc -++ dialect_mpmd.cc -++ attributes_mpmd.cc -++ -++ DEPENDS -++ 
MpmdDialect -++ MpmdTransformsPasses -+ -+ LINK_LIBS PUBLIC -+ LLVMSupport -+@@ -17,4 +48,6 @@ add_mlir_public_c_api_library(SdyCAPI -+ StablehloAssemblyFormat -+ StablehloOps -+ StablehloTypeInference -++ MpmdDialect -++ MpmdTransformsPasses -+ ) -+diff --git a/shardy/integrations/c/attributes_mpmd.cc b/shardy/integrations/c/attributes_mpmd.cc -+new file mode 100644 -+index 0000000..355107f -+--- /dev/null -++++ b/shardy/integrations/c/attributes_mpmd.cc -+@@ -0,0 +1,126 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. -++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. 
-++==============================================================================*/ -++ -++#include "shardy/integrations/c/attributes_mpmd.h" -++ -++#include -++#include -++ -++#include "mlir-c/IR.h" -++#include "mlir-c/Support.h" -++#include "mlir/CAPI/IR.h" -++#include "mlir/CAPI/Support.h" -++#include "mlir/IR/Attributes.h" -++#include "mlir/Support/LLVM.h" -++#include "shardy/dialect/mpmd/ir/dialect.h" -++ -++namespace { -++ -++namespace mpmd = ::mlir::mpmd; -++ -++template -++AttrTy unwrapAttr(MlirAttribute attr) { -++ return mlir::cast(unwrap(attr)); -++} -++ -++template -++mlir::ArrayRef unwrapAttrs(const MlirAttribute* attrs, -++ intptr_t nAttrs) { -++ return mlir::ArrayRef(reinterpret_cast(attrs), nAttrs); -++} -++ -++} // namespace -++ -++extern "C" { -++ -++//===----------------------------------------------------------------------===// -++// NamedMeshAttr -++//===----------------------------------------------------------------------===// -++ -++bool mpmdAttributeIsANamedMeshAttr(MlirAttribute attr) { -++ return mlir::isa(unwrap(attr)); -++} -++ -++MlirAttribute mpmdNamedMeshAttrGet(MlirContext ctx, MlirStringRef name, MlirAttribute mesh) { -++ return wrap(mpmd::NamedMeshAttr::get(unwrap(ctx), unwrap(name), unwrapAttr(mesh))); -++} -++ -++MlirStringRef mpmdNamedMeshAttrGetName(MlirAttribute attr) { -++ return wrap(unwrapAttr(attr).getName()); -++} -++ -++MlirAttribute mpmdNamedMeshAttrGetMesh(MlirAttribute attr) { -++ mlir::sdy::MeshAttr mesh = unwrapAttr(attr).getMesh(); -++ return wrap(mesh); -++} -++ -++//===----------------------------------------------------------------------===// -++// TopologyAttr -++//===----------------------------------------------------------------------===// -++ -++bool mpmdAttributeIsATopologyAttr(MlirAttribute attr) { -++ return mlir::isa(unwrap(attr)); -++} -++ -++MlirAttribute mpmdTopologyAttrGet(MlirContext ctx, intptr_t nMeshes, const MlirAttribute* meshes) { -++ return wrap(mpmd::TopologyAttr::get( -++ unwrap(ctx), 
unwrapAttrs(meshes, nMeshes))); -++} -++ -++int64_t mpmdTopologyAttrGetMeshesSize(MlirAttribute attr) { -++ return unwrapAttr(attr).getMeshes().size(); -++} -++ -++MlirAttribute mpmdTopologyAttrGetMeshesElem(MlirAttribute attr, int64_t pos) { -++ return wrap(unwrapAttr(attr).getMeshes()[pos]); -++} -++ -++//===----------------------------------------------------------------------===// -++// UserOriginAttr -++//===----------------------------------------------------------------------===// -++ -++bool mpmdAttributeIsAUserOriginAttr(MlirAttribute attr) { -++ return mlir::isa(unwrap(attr)); -++} -++ -++MlirAttribute mpmdUserOriginAttrGet(MlirContext ctx, MlirAttribute userName, int64_t transposeCount) { -++ return wrap(mpmd::UserOriginAttr::get(unwrap(ctx), unwrapAttr(userName), transposeCount)); -++} -++ -++MlirStringRef mpmdUserOriginAttrGetUserName(MlirAttribute attr) { -++ return wrap(unwrapAttr(attr).getUserName().getValue()); -++} -++ -++int64_t mpmdUserOriginAttrGetTransposeCount(MlirAttribute attr) { -++ return unwrapAttr(attr).getTransposeCount(); -++} -++ -++//===----------------------------------------------------------------------===// -++// OriginAttr -++//===----------------------------------------------------------------------===// -++ -++bool mpmdAttributeIsAOriginAttr(MlirAttribute attr) { -++ return mlir::isa(unwrap(attr)); -++} -++ -++MlirAttribute mpmdOriginAttrGet(MlirContext ctx, MlirStringRef originLabel) { -++ return wrap(mpmd::OriginAttr::get(unwrap(ctx), unwrap(originLabel))); -++} -++ -++MlirStringRef mpmdOriginAttrGetOriginLabel(MlirAttribute attr) { -++ return wrap(unwrapAttr(attr).getOriginLabel()); -++} -++ -++} // extern "C" -+diff --git a/shardy/integrations/c/attributes_mpmd.h b/shardy/integrations/c/attributes_mpmd.h -+new file mode 100644 -+index 0000000..0d638fd -+--- /dev/null -++++ b/shardy/integrations/c/attributes_mpmd.h -+@@ -0,0 +1,84 @@ -++/* Copyright 2025 The Shardy Authors. 
-++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. -++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. -++==============================================================================*/ -++ -++#ifndef SHARDY_INTEGRATIONS_C_ATTRIBUTES_MPMD_H_ -++#define SHARDY_INTEGRATIONS_C_ATTRIBUTES_MPMD_H_ -++ -++#include -++#include -++ -++#include "mlir-c/IR.h" -++#include "mlir-c/Support.h" -++ -++#ifdef __cplusplus -++extern "C" { -++#endif -++ -++//===----------------------------------------------------------------------===// -++// NamedMeshAttr -++//===----------------------------------------------------------------------===// -++ -++MLIR_CAPI_EXPORTED bool mpmdAttributeIsANamedMeshAttr(MlirAttribute attr); -++ -++MLIR_CAPI_EXPORTED MlirAttribute mpmdNamedMeshAttrGet(MlirContext ctx, -++ MlirStringRef name, -++ MlirAttribute mesh); -++ -++MLIR_CAPI_EXPORTED MlirStringRef mpmdNamedMeshAttrGetName(MlirAttribute attr); -++ -++MLIR_CAPI_EXPORTED MlirAttribute mpmdNamedMeshAttrGetMesh(MlirAttribute attr); -++ -++//===----------------------------------------------------------------------===// -++// TopologyAttr -++//===----------------------------------------------------------------------===// -++ -++MLIR_CAPI_EXPORTED bool mpmdAttributeIsATopologyAttr(MlirAttribute attr); -++ -++MLIR_CAPI_EXPORTED MlirAttribute mpmdTopologyAttrGet(MlirContext ctx, -++ intptr_t nMeshes, -++ const MlirAttribute* meshes); -++ -++MLIR_CAPI_EXPORTED int64_t mpmdTopologyAttrGetMeshesSize(MlirAttribute attr); -++ -++MLIR_CAPI_EXPORTED MlirAttribute 
mpmdTopologyAttrGetMeshesElem(MlirAttribute attr, -++ int64_t pos); -++ -++//===----------------------------------------------------------------------===// -++// UserOriginAttr -++//===----------------------------------------------------------------------===// -++ -++MLIR_CAPI_EXPORTED bool mpmdAttributeIsAUserOriginAttr(MlirAttribute attr); -++ -++MLIR_CAPI_EXPORTED MlirAttribute mpmdUserOriginAttrGet(MlirContext ctx, MlirAttribute userName, int64_t transposeCount); -++ -++MLIR_CAPI_EXPORTED MlirStringRef mpmdUserOriginAttrGetUserName(MlirAttribute attr); -++ -++MLIR_CAPI_EXPORTED int64_t mpmdUserOriginAttrGetTransposeCount(MlirAttribute attr); -++ -++//===----------------------------------------------------------------------===// -++// OriginAttr -++//===----------------------------------------------------------------------===// -++ -++MLIR_CAPI_EXPORTED bool mpmdAttributeIsAOriginAttr(MlirAttribute attr); -++ -++MLIR_CAPI_EXPORTED MlirAttribute mpmdOriginAttrGet(MlirContext ctx, MlirStringRef originLabel); -++ -++MLIR_CAPI_EXPORTED MlirStringRef mpmdOriginAttrGetOriginLabel(MlirAttribute attr); -++ -++#ifdef __cplusplus -++} -++#endif -++ -++#endif // SHARDY_INTEGRATIONS_C_ATTRIBUTES_MPMD_H_ -+diff --git a/shardy/integrations/c/attributes.cc b/shardy/integrations/c/attributes_sdy.cc -+similarity index 99% -+rename from shardy/integrations/c/attributes.cc -+rename to shardy/integrations/c/attributes_sdy.cc -+index b683d09..417ed66 100644 -+--- a/shardy/integrations/c/attributes.cc -++++ b/shardy/integrations/c/attributes_sdy.cc -+@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and -+ limitations under the License. 
-+ ==============================================================================*/ -+ -+-#include "shardy/integrations/c/attributes.h" -++#include "shardy/integrations/c/attributes_sdy.h" -+ -+ #include -+ #include -+diff --git a/shardy/integrations/c/attributes.h b/shardy/integrations/c/attributes_sdy.h -+similarity index 98% -+rename from shardy/integrations/c/attributes.h -+rename to shardy/integrations/c/attributes_sdy.h -+index b6e77c9..d2c5c72 100644 -+--- a/shardy/integrations/c/attributes.h -++++ b/shardy/integrations/c/attributes_sdy.h -+@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and -+ limitations under the License. -+ ==============================================================================*/ -+ -+-#ifndef SHARDY_INTEGRATIONS_C_ATTRIBUTES_H_ -+-#define SHARDY_INTEGRATIONS_C_ATTRIBUTES_H_ -++#ifndef SHARDY_INTEGRATIONS_C_ATTRIBUTES_SDY_H_ -++#define SHARDY_INTEGRATIONS_C_ATTRIBUTES_SDY_H_ -+ -+ #include -+ #include -+@@ -276,4 +276,4 @@ MLIR_CAPI_EXPORTED MlirStringRef sdyManualAxesAttrGetAxesElem( -+ } -+ #endif -+ -+-#endif // SHARDY_INTEGRATIONS_C_ATTRIBUTES_H_ -++#endif // SHARDY_INTEGRATIONS_C_ATTRIBUTES_SDY_H_ -+diff --git a/shardy/integrations/c/dialect_mpmd.cc b/shardy/integrations/c/dialect_mpmd.cc -+new file mode 100644 -+index 0000000..d311822 -+--- /dev/null -++++ b/shardy/integrations/c/dialect_mpmd.cc -+@@ -0,0 +1,21 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. -++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-++See the License for the specific language governing permissions and -++limitations under the License. -++==============================================================================*/ -++ -++#include "shardy/integrations/c/dialect_mpmd.h" // IWYU pragma: keep -++ -++#include "mlir/CAPI/Registration.h" -++#include "shardy/dialect/mpmd/ir/dialect.h" -++ -++MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Mpmd, mpmd, mlir::mpmd::MpmdDialect); -+diff --git a/shardy/integrations/c/dialect_mpmd.h b/shardy/integrations/c/dialect_mpmd.h -+new file mode 100644 -+index 0000000..6d699bb -+--- /dev/null -++++ b/shardy/integrations/c/dialect_mpmd.h -+@@ -0,0 +1,31 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. -++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. 
-++==============================================================================*/ -++ -++#ifndef SHARDY_DIALECT_MPMD_IR_C_DIALECT_H_ -++#define SHARDY_DIALECT_MPMD_IR_C_DIALECT_H_ -++ -++#include "mlir-c/IR.h" -++ -++#ifdef __cplusplus -++extern "C" { -++#endif -++ -++MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Mpmd, mpmd); -++ -++#ifdef __cplusplus -++} -++#endif -++ -++#endif // SHARDY_DIALECT_MPMD_IR_C_DIALECT_H_ -+diff --git a/shardy/integrations/c/dialect.cc b/shardy/integrations/c/dialect_sdy.cc -+similarity index 92% -+rename from shardy/integrations/c/dialect.cc -+rename to shardy/integrations/c/dialect_sdy.cc -+index 1408631..5e3dfe3 100644 -+--- a/shardy/integrations/c/dialect.cc -++++ b/shardy/integrations/c/dialect_sdy.cc -+@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and -+ limitations under the License. -+ ==============================================================================*/ -+ -+-#include "shardy/integrations/c/dialect.h" // IWYU pragma: keep -++#include "shardy/integrations/c/dialect_sdy.h" // IWYU pragma: keep -+ -+ #include "mlir/CAPI/Registration.h" -+ #include "shardy/dialect/sdy/ir/dialect.h" -+diff --git a/shardy/integrations/c/dialect.h b/shardy/integrations/c/dialect_sdy.h -+similarity index 100% -+rename from shardy/integrations/c/dialect.h -+rename to shardy/integrations/c/dialect_sdy.h -+diff --git a/shardy/integrations/c/passes_mpmd.cc b/shardy/integrations/c/passes_mpmd.cc -+new file mode 100644 -+index 0000000..e843a16 -+--- /dev/null -++++ b/shardy/integrations/c/passes_mpmd.cc -+@@ -0,0 +1,22 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. 
-++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. -++==============================================================================*/ -++ -++#include "shardy/integrations/c/passes_mpmd.h" -++ -++#include "shardy/dialect/mpmd/transforms/passes.h" -++ -++void mlirRegisterAllMpmdPassesAndPipelines() { -++ mlir::mpmd::registerAllMpmdPassesAndPipelines(); -++} -+diff --git a/shardy/integrations/c/passes_mpmd.h b/shardy/integrations/c/passes_mpmd.h -+new file mode 100644 -+index 0000000..2125752 -+--- /dev/null -++++ b/shardy/integrations/c/passes_mpmd.h -+@@ -0,0 +1,33 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. -++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. -++==============================================================================*/ -++ -++#ifndef SHARDY_INTEGRATIONS_C_PASSES_MPMD_H_ -++#define SHARDY_INTEGRATIONS_C_PASSES_MPMD_H_ -++ -++#include "mlir-c/Support.h" -++ -++#ifdef __cplusplus -++extern "C" { -++#endif -++ -++/// Register all compiler passes and pipelines of Shardy. 
-++MLIR_CAPI_EXPORTED void mlirRegisterAllMpmdPassesAndPipelines(); -++ -++#ifdef __cplusplus -++} -++#endif -++ -++ -++#endif // SHARDY_INTEGRATIONS_C_PASSES_MPMD_H_ -+diff --git a/shardy/integrations/c/passes.cc b/shardy/integrations/c/passes_sdy.cc -+similarity index 94% -+rename from shardy/integrations/c/passes.cc -+rename to shardy/integrations/c/passes_sdy.cc -+index 063c1cf..f10b199 100644 -+--- a/shardy/integrations/c/passes.cc -++++ b/shardy/integrations/c/passes_sdy.cc -+@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and -+ limitations under the License. -+ ==============================================================================*/ -+ -+-#include "shardy/integrations/c/passes.h" -++#include "shardy/integrations/c/passes_sdy.h" -+ -+ #include "shardy/dialect/sdy/transforms/passes.h" -+ -+diff --git a/shardy/integrations/c/passes.h b/shardy/integrations/c/passes_sdy.h -+similarity index 86% -+rename from shardy/integrations/c/passes.h -+rename to shardy/integrations/c/passes_sdy.h -+index 6863333..6ffd052 100644 -+--- a/shardy/integrations/c/passes.h -++++ b/shardy/integrations/c/passes_sdy.h -+@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and -+ limitations under the License. 
-+ ==============================================================================*/ -+ -+-#ifndef SHARDY_INTEGRATIONS_C_PASSES_H_ -+-#define SHARDY_INTEGRATIONS_C_PASSES_H_ -++#ifndef SHARDY_INTEGRATIONS_C_PASSES_SDY_H_ -++#define SHARDY_INTEGRATIONS_C_PASSES_SDY_H_ -+ -+ #include "mlir-c/Support.h" -+ -+@@ -30,4 +30,4 @@ MLIR_CAPI_EXPORTED void mlirRegisterAllSdyPassesAndPipelines(); -+ #endif -+ -+ -+-#endif // SHARDY_INTEGRATIONS_C_PASSES_H_ -++#endif // SHARDY_INTEGRATIONS_C_PASSES_SDY_H_ -+diff --git a/shardy/integrations/python/ir/CMakeLists.txt b/shardy/integrations/python/ir/CMakeLists.txt -+index cbb4d66..1e8d7bd 100644 -+--- a/shardy/integrations/python/ir/CMakeLists.txt -++++ b/shardy/integrations/python/ir/CMakeLists.txt -+@@ -28,3 +28,32 @@ declare_mlir_python_extension(SdyPythonExtensions.Main -+ SdyCAPI -+ LLVMSupport -+ ) -++ -++declare_mlir_python_sources(MpmdPythonSources) -++declare_mlir_python_sources(MpmdPythonSources.Dialects -++ ADD_TO_PARENT MpmdPythonSources -++) -++ -++declare_mlir_dialect_python_bindings( -++ ADD_TO_PARENT MpmdPythonSources.Dialects -++ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" -++ TD_FILE dialects/mpmd_ops.td -++ GEN_ENUM_BINDINGS ON -++ GEN_ENUM_BINDINGS_TD_FILE dialects/mpmd_enums.td -++ SOURCES dialects/mpmd.py -++ DIALECT_NAME mpmd -++) -++ -++declare_mlir_python_sources(MpmdPythonExtensions) -++declare_mlir_python_extension(MpmdPythonExtensions.Main -++ MODULE_NAME _mpmd -++ ADD_TO_PARENT MpmdPythonExtensions -++ PYTHON_BINDINGS_LIBRARY nanobind -++ SOURCES -++ mpmd_module.cc -++ EMBED_CAPI_LINK_LIBS -++ MpmdCAPI -++ PRIVATE_LINK_LIBS -++ MpmdCAPI -++ LLVMSupport -++) -+diff --git a/shardy/integrations/python/ir/__init__.py b/shardy/integrations/python/ir/__init__.py -+index 7373840..ce9d4ca 100644 -+--- a/shardy/integrations/python/ir/__init__.py -++++ b/shardy/integrations/python/ir/__init__.py -+@@ -12,7 +12,7 @@ -+ # See the License for the specific language governing permissions and -+ # limitations under the 
License. -+ # ============================================================================== -+-"""Python bindings for the SDY dialect.""" -++"""Python bindings for the SDY and MPMD dialect.""" -+ -+ # pylint: disable=g-multiple-import,g-importing-member,unused-import,useless-import-alias -+ from ._sdy import ( -+@@ -36,3 +36,27 @@ from ._sdy_ops_gen import ( -+ ReturnOp as ReturnOp, -+ ShardingConstraintOp as ShardingConstraintOp, -+ ) -++ -++# pylint: disable=g-multiple-import,g-importing-member,unused-import,useless-import-alias -++from ._mpmd import ( -++ register_dialect as register_dialect, -++ NamedMeshAttr as NamedMeshAttr, -++ TopologyAttr as TopologyAttr, -++) -++ -++from ._mpmd_enums_gen import ReductionType as ReductionType -++ -++from ._mpmd_ops_gen import ( -++ ReturnOp as ReturnOp, -++ NamedComputationOp as NamedComputationOp, -++ NamedTensorOp as NamedTensorOp, -++ FragmentOp as FragmentOp, -++ FragmentCallOp as FragmentCallOp, -++ TransferOp as TransferOp, -++ AssignOp as AssignOp, -++ UnassignOp as UnassignOp, -++ CallOp as CallOp, -++ ForOp as ForOp, -++ BroadcastOp as BroadcastOp, -++ ReduceOp as ReduceOp, -++) -+diff --git a/shardy/integrations/python/ir/dialects/mpmd.py b/shardy/integrations/python/ir/dialects/mpmd.py -+new file mode 100644 -+index 0000000..ab22061 -+--- /dev/null -++++ b/shardy/integrations/python/ir/dialects/mpmd.py -+@@ -0,0 +1,20 @@ -++# Copyright 2025 The Shardy Authors. -++# -++# Licensed under the Apache License, Version 2.0 (the "License"); -++# you may not use this file except in compliance with the License. -++# You may obtain a copy of the License at -++# -++# http://www.apache.org/licenses/LICENSE-2.0 -++# -++# Unless required by applicable law or agreed to in writing, software -++# distributed under the License is distributed on an "AS IS" BASIS, -++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-++# See the License for the specific language governing permissions and -++# limitations under the License. -++# ============================================================================== -++"""Python bindings for the MPMD dialect.""" -++ -++# pylint: disable=wildcard-import -++from .._mlir_libs._mpmd import * -++from ._mpmd_enum_gen import * -++from ._mpmd_ops_gen import * -+diff --git a/shardy/integrations/python/ir/dialects/mpmd_enums.td b/shardy/integrations/python/ir/dialects/mpmd_enums.td -+new file mode 100644 -+index 0000000..0dfad9b -+--- /dev/null -++++ b/shardy/integrations/python/ir/dialects/mpmd_enums.td -+@@ -0,0 +1,21 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. -++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. -++==============================================================================*/ -++ -++#ifndef SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -++#define SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -++ -++include "shardy/dialect/mpmd/ir/enums.td" -++ -++#endif // SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -+diff --git a/shardy/integrations/python/ir/dialects/mpmd_ops.td b/shardy/integrations/python/ir/dialects/mpmd_ops.td -+new file mode 100644 -+index 0000000..86d6dcf -+--- /dev/null -++++ b/shardy/integrations/python/ir/dialects/mpmd_ops.td -+@@ -0,0 +1,21 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. 
-++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. -++==============================================================================*/ -++ -++#ifndef SHARDY_INTEGRATIONS_PYTHON_MPMD_OPS -++#define SHARDY_INTEGRATIONS_PYTHON_MPMD_OPS -++ -++include "shardy/dialect/mpmd/ir/ops.td" -++ -++#endif -+diff --git a/shardy/integrations/python/ir/mpmd.py b/shardy/integrations/python/ir/mpmd.py -+new file mode 100644 -+index 0000000..2524fe2 -+--- /dev/null -++++ b/shardy/integrations/python/ir/mpmd.py -+@@ -0,0 +1,20 @@ -++# Copyright 2025 The Shardy Authors. -++# -++# Licensed under the Apache License, Version 2.0 (the "License"); -++# you may not use this file except in compliance with the License. -++# You may obtain a copy of the License at -++# -++# http://www.apache.org/licenses/LICENSE-2.0 -++# -++# Unless required by applicable law or agreed to in writing, software -++# distributed under the License is distributed on an "AS IS" BASIS, -++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++# See the License for the specific language governing permissions and -++# limitations under the License. 
-++# ============================================================================== -++"""Python bindings for the MPMD dialect.""" -++ -++# pylint: disable=wildcard-import -++from .._mlir_libs._mpmd import * -++from ._mpmd_enums_gen import * -++from ._mpmd_ops_gen import * -+diff --git a/shardy/integrations/python/ir/mpmd_enums.td b/shardy/integrations/python/ir/mpmd_enums.td -+new file mode 100644 -+index 0000000..0dfad9b -+--- /dev/null -++++ b/shardy/integrations/python/ir/mpmd_enums.td -+@@ -0,0 +1,21 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. -++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. -++==============================================================================*/ -++ -++#ifndef SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -++#define SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -++ -++include "shardy/dialect/mpmd/ir/enums.td" -++ -++#endif // SHARDY_INTEGRATIONS_PYTHON_MPMD_ENUMS -+diff --git a/shardy/integrations/python/ir/mpmd_module.cc b/shardy/integrations/python/ir/mpmd_module.cc -+new file mode 100644 -+index 0000000..9bcab92 -+--- /dev/null -++++ b/shardy/integrations/python/ir/mpmd_module.cc -+@@ -0,0 +1,165 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. 
-++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. -++==============================================================================*/ -++ -++#include -++#include -++#include -++#include -++#include -++ -++#include "mlir-c/BuiltinAttributes.h" -++#include "mlir-c/IR.h" -++#include "mlir-c/Support.h" -++#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep -++#include "nanobind/nanobind.h" -++#include "nanobind/stl/optional.h" // IWYU pragma: keep -++#include "nanobind/stl/string.h" // IWYU pragma: keep -++#include "nanobind/stl/variant.h" // IWYU pragma: keep -++#include "nanobind/stl/vector.h" // IWYU pragma: keep -++#include "shardy/integrations/c/attributes_mpmd.h" -++#include "shardy/integrations/c/dialect_mpmd.h" -++ -++namespace mlir { -++namespace mpmd { -++ -++namespace { -++ -++namespace nb = nanobind; -++ -++// Returns a vector containing elements with type T extracted from an attribute -++// using the two provided callbacks. -++template -++std::vector propertyVector( -++ MlirAttribute attr, llvm::function_ref sizeFn, -++ llvm::function_ref getFn) { -++ std::vector result; -++ intptr_t size = sizeFn(attr); -++ result.reserve(size); -++ for (intptr_t i = 0; i < size; ++i) { -++ result.push_back(getFn(attr, i)); -++ } -++ return result; -++} -++ -++nb::str toPyString(MlirStringRef mlirStringRef) { -++ return nb::str(mlirStringRef.data, mlirStringRef.length); -++} -++ -++MlirStringRef toStringRef(const std::string& s) { -++ return mlirStringRefCreate(s.c_str(), s.size()); -++} -++ -++NB_MODULE(_mpmd, m) { -++ m.doc() = "MPMD main Python extension"; -++ -++ // -++ // Dialects. 
-++ // -++ -++ m.def( -++ "register_dialect", -++ [](MlirContext context, bool load) { -++ MlirDialectHandle dialect = mlirGetDialectHandle__mpmd__(); -++ mlirDialectHandleRegisterDialect(dialect, context); -++ if (load) { -++ mlirDialectHandleLoadDialect(dialect, context); -++ } -++ }, -++ nb::arg("context"), nb::arg("load") = true); -++ -++ // -++ // Attributes. -++ // -++ -++ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -++ m, "NamedMeshAttr", mpmdAttributeIsANamedMeshAttr) -++ .def_classmethod( -++ "get", -++ [](nb::object cls, const std::string& name, -++ MlirAttribute meshAttr, MlirContext ctx) { -++ return cls(mpmdNamedMeshAttrGet(ctx, toStringRef(name), meshAttr)); -++ }, -++ nb::arg("cls"), nb::arg("name"), -++ nb::arg("mesh").none() = nb::none(), -++ nb::arg("context").none() = nb::none(), -++ "Creates an NamedMeshAttr with the given name and MeshAttr.") -++ .def_property_readonly("name", -++ [](MlirAttribute self) { -++ return toPyString(mpmdNamedMeshAttrGetName(self)); -++ }) -++ .def_property_readonly("mesh", [](MlirAttribute self) { -++ return mpmdNamedMeshAttrGetMesh(self); -++ }); -++ -++ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -++ m, "TopologyAttr", mpmdAttributeIsATopologyAttr) -++ .def_classmethod( -++ "get", -++ [](nb::object cls, const std::vector& meshes, -++ MlirContext ctx) { -++ return cls(mpmdTopologyAttrGet(ctx, meshes.size(), meshes.data())); -++ }, -++ nb::arg("cls"), nb::arg("meshes"), -++ nb::arg("context").none() = nb::none(), -++ "Creates a TopologyAttr with the given meshes.") -++ .def_property_readonly("meshes", -++ [](MlirAttribute self) { -++ return propertyVector( -++ self, mpmdTopologyAttrGetMeshesSize, -++ mpmdTopologyAttrGetMeshesElem); -++ }) -++ .def_property_readonly("size", [](MlirAttribute self) { -++ return mpmdTopologyAttrGetMeshesSize(self); -++ }); -++ -++ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -++ m, "UserOriginAttr", mpmdAttributeIsAUserOriginAttr) -++ 
.def_classmethod( -++ "get", -++ [](nb::object cls, MlirAttribute& userName, int64_t transposeCount, -++ MlirContext ctx) { -++ return cls(mpmdUserOriginAttrGet(ctx, userName, transposeCount)); -++ }, -++ nb::arg("cls"), nb::arg("user_name"), -++ nb::arg("transpose_count") = 0, -++ nb::arg("context").none() = nb::none(), -++ "Creates a UserOriginAttr with the given user name and transpose count.") -++ .def_property_readonly("user_name", -++ [](MlirAttribute self) { -++ return toPyString(mpmdUserOriginAttrGetUserName(self)); -++ }) -++ .def_property_readonly("transpose_count", [](MlirAttribute self) { -++ return mpmdUserOriginAttrGetTransposeCount(self); -++ }); -++ -++ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -++ m, "OriginAttr", mpmdAttributeIsAOriginAttr) -++ .def_classmethod( -++ "get", -++ [](nb::object cls, const std::string& originLabel, MlirContext ctx) { -++ return cls(mpmdOriginAttrGet(ctx, toStringRef(originLabel))); -++ }, -++ nb::arg("cls"), nb::arg("origin_label"), -++ nb::arg("context").none() = nb::none(), -++ "Creates an OriginAttr with the given origin label.") -++ .def_property_readonly("origin_label", -++ [](MlirAttribute self) { -++ return toPyString(mpmdOriginAttrGetOriginLabel(self)); -++ }); -++} -++ -++} // namespace -++} // namespace mpmd -++} // namespace mlir -+diff --git a/shardy/integrations/python/ir/mpmd_ops.td b/shardy/integrations/python/ir/mpmd_ops.td -+new file mode 100644 -+index 0000000..86d6dcf -+--- /dev/null -++++ b/shardy/integrations/python/ir/mpmd_ops.td -+@@ -0,0 +1,21 @@ -++/* Copyright 2025 The Shardy Authors. -++ -++Licensed under the Apache License, Version 2.0 (the "License"); -++you may not use this file except in compliance with the License. 
-++You may obtain a copy of the License at -++ -++ http://www.apache.org/licenses/LICENSE-2.0 -++ -++Unless required by applicable law or agreed to in writing, software -++distributed under the License is distributed on an "AS IS" BASIS, -++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -++See the License for the specific language governing permissions and -++limitations under the License. -++==============================================================================*/ -++ -++#ifndef SHARDY_INTEGRATIONS_PYTHON_MPMD_OPS -++#define SHARDY_INTEGRATIONS_PYTHON_MPMD_OPS -++ -++include "shardy/dialect/mpmd/ir/ops.td" -++ -++#endif -+diff --git a/shardy/integrations/python/ir/sdy_module.cc b/shardy/integrations/python/ir/sdy_module.cc -+index da451fa..318e14f 100644 -+--- a/shardy/integrations/python/ir/sdy_module.cc -++++ b/shardy/integrations/python/ir/sdy_module.cc -+@@ -28,8 +28,8 @@ limitations under the License. -+ #include "nanobind/stl/string.h" // IWYU pragma: keep -+ #include "nanobind/stl/variant.h" // IWYU pragma: keep -+ #include "nanobind/stl/vector.h" // IWYU pragma: keep -+-#include "shardy/integrations/c/attributes.h" -+-#include "shardy/integrations/c/dialect.h" -++#include "shardy/integrations/c/attributes_sdy.h" -++#include "shardy/integrations/c/dialect_sdy.h" -+ -+ namespace mlir { -+ namespace sdy { -+-- -+2.34.1 diff --git a/include/ttmlir/AffineMapUtils.h b/include/ttmlir/AffineMapUtils.h index f4accafccc2..639c69ee4ad 100644 --- a/include/ttmlir/AffineMapUtils.h +++ b/include/ttmlir/AffineMapUtils.h @@ -117,8 +117,8 @@ fullyApplyAffineMap(mlir::OpBuilder &builder, mlir::Location loc, mlir::AffineMap map, mlir::ValueRange inputs) { llvm::SmallVector results; for (unsigned i = 0; i < map.getNumResults(); i++) { - results.push_back(builder.create( - loc, affineMapSelectOneOutput(map, i), inputs)); + results.push_back(mlir::affine::AffineApplyOp::create( + builder, loc, affineMapSelectOneOutput(map, i), inputs)); } return 
results; } diff --git a/include/ttmlir/Conversion/TTNNToEmitC/EmitCConversion.h b/include/ttmlir/Conversion/TTNNToEmitC/EmitCConversion.h index cb5d2c12605..2ec5b6ce690 100644 --- a/include/ttmlir/Conversion/TTNNToEmitC/EmitCConversion.h +++ b/include/ttmlir/Conversion/TTNNToEmitC/EmitCConversion.h @@ -2082,8 +2082,8 @@ class EmitCTTNNEmitter { // ttnn::distributed::MeshDevice does not support copy/move constructor. // So a reference variable is created to be used as function argument. // ::ttnn::distributed::MeshDevice& deviceRef = *devicePtr; - emitc::ApplyOp meshDeviceOp = rewriter.create( - op.getLoc(), + emitc::ApplyOp meshDeviceOp = emitc::ApplyOp::create( + rewriter, op.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), TypeNameV + "&"), // ::ttnn::distributed::MeshDevice& @@ -2099,8 +2099,8 @@ class EmitCTTNNEmitter { mlir::Value deviceValueFromOperandsList = adaptor.getOperands()[index]; // optional> x = *device_ptr - emitc::ApplyOp meshDeviceOp = rewriter.create( - op.getLoc(), + emitc::ApplyOp meshDeviceOp = emitc::ApplyOp::create( + rewriter, op.getLoc(), rewriter.getType( TypeNameV< ::ttnn::operations::creation::detail::OptionalMeshDevice>), @@ -2135,26 +2135,27 @@ class EmitCTTNNEmitter { ::ttnn::Tensor, std::tuple, std::tuple<::ttnn::Tensor, std::optional<::ttnn::Tensor>>>>; - emitc::ExpressionOp conv2dExpr = rewriter.create( - op.getLoc(), + emitc::ExpressionOp conv2dExpr = emitc::ExpressionOp::create( + rewriter, op.getLoc(), rewriter.getType(TypeNameV<::ttnn::Tensor>), adaptor.getOperands()); mlir::Block &bodyBlock = conv2dExpr.createBody(); rewriter.setInsertionPointToStart(&bodyBlock); - auto conv2dOp = rewriter.create( - op.getLoc(), rewriter.getType(TypeNameV), + auto conv2dOp = emitc::CallOpaqueOp::create( + rewriter, op.getLoc(), + rewriter.getType(TypeNameV), opConversionPattern.convertOpName(op), rewriter.getArrayAttr(args), /*template_args=*/nullptr, bodyBlock.getArguments()); - auto getTensorOp = rewriter.create( - op.getLoc(), + 
auto getTensorOp = emitc::CallOpaqueOp::create( + rewriter, op.getLoc(), rewriter.getType(TypeNameV<::ttnn::Tensor>), "::std::get", /*args=*/nullptr, /*template_args=*/ rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}), conv2dOp.getResult(0)); - rewriter.create(op.getLoc(), getTensorOp.getResult(0)); + emitc::YieldOp::create(rewriter, op.getLoc(), getTensorOp.getResult(0)); rewriter.replaceOp(op, conv2dExpr); @@ -2171,15 +2172,16 @@ class EmitCTTNNEmitter { assert(op->getNumResults() == 1 && "Expected single output for MaxPool2dOp."); using ReturnTy = std::vector<::ttnn::Tensor>; - auto maxPool2dOp = rewriter.create( - op.getLoc(), rewriter.getType(TypeNameV), + auto maxPool2dOp = emitc::CallOpaqueOp::create( + rewriter, op.getLoc(), + rewriter.getType(TypeNameV), opConversionPattern.convertOpName(op), rewriter.getArrayAttr(args), /*template_args=*/nullptr, operands); // Create index to access first/single element. auto indexType = rewriter.getIndexType(); auto indexOp = - rewriter.create(op.getLoc(), indexType, "0"); + emitc::LiteralOp::create(rewriter, op.getLoc(), indexType, "0"); Value indexVal = indexOp.getResult(); // Create LValue type for the tensor reference. @@ -2187,12 +2189,13 @@ class EmitCTTNNEmitter { rewriter.getContext(), TypeNameV)); // Get reference to the first/single element in the result vector. - auto subscriptOp = rewriter.create( - op.getLoc(), lvalueType, maxPool2dOp.getResult(0), indexVal); + auto subscriptOp = + emitc::SubscriptOp::create(rewriter, op.getLoc(), lvalueType, + maxPool2dOp.getResult(0), indexVal); // Load the actual tensor value from the reference. 
- auto loadOp = rewriter.create( - op.getLoc(), + auto loadOp = emitc::LoadOp::create( + rewriter, op.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), TypeNameV), subscriptOp.getResult()); @@ -2209,8 +2212,9 @@ class EmitCTTNNEmitter { assert(op.getNumResults() == 2 && "Expected two outputs (values tensor and indices)."); using ReturnTy = std::vector<::ttnn::Tensor>; - auto callOp = rewriter.create( - op.getLoc(), rewriter.getType(TypeNameV), + auto callOp = emitc::CallOpaqueOp::create( + rewriter, op.getLoc(), + rewriter.getType(TypeNameV), opConversionPattern.convertOpName(op), rewriter.getArrayAttr(args), /*template_args=*/nullptr, operands); @@ -2218,8 +2222,8 @@ class EmitCTTNNEmitter { for (unsigned i = 0; i < op.getNumResults(); ++i) { // Create index to access i-th element. auto indexType = rewriter.getIndexType(); - auto indexOp = rewriter.create(op.getLoc(), indexType, - std::to_string(i)); + auto indexOp = emitc::LiteralOp::create(rewriter, op.getLoc(), + indexType, std::to_string(i)); Value indexVal = indexOp.getResult(); // Create LValue type for the tensor reference. @@ -2227,12 +2231,12 @@ class EmitCTTNNEmitter { rewriter.getContext(), TypeNameV)); // Get reference to the i-th element in the result vector. - auto subscriptOp = rewriter.create( - op.getLoc(), lvalueType, callOp.getResult(0), indexVal); + auto subscriptOp = emitc::SubscriptOp::create( + rewriter, op.getLoc(), lvalueType, callOp.getResult(0), indexVal); // Load the actual tensor value from the reference. - auto loadOp = rewriter.create( - op.getLoc(), + auto loadOp = emitc::LoadOp::create( + rewriter, op.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), TypeNameV), subscriptOp.getResult()); @@ -2308,11 +2312,12 @@ class EmitCTTNNEmitter { // pointers. 
mlir::Value dereferenceToRef(mlir::Value ptrValue, const std::string &refTypeName) { - return rewriter - .create( - op.getLoc(), - emitc::OpaqueType::get(rewriter.getContext(), refTypeName), "*", - ptrValue) + return emitc::ApplyOp::create( + rewriter, + + op.getLoc(), + emitc::OpaqueType::get(rewriter.getContext(), refTypeName), "*", + ptrValue) .getResult(); } @@ -2331,22 +2336,23 @@ class EmitCTTNNEmitter { // Create reference variable auto refType = emitc::OpaqueType::get(rewriter.getContext(), refTypeName); std::string verbatimCode = "auto& " + varName + " = *{};"; - rewriter.create( - op.getLoc(), rewriter.getStringAttr(verbatimCode), - llvm::SmallVector{uniquePtrValue}); + emitc::VerbatimOp::create(rewriter, op.getLoc(), + rewriter.getStringAttr(verbatimCode), + llvm::SmallVector{uniquePtrValue}); - return rewriter.create(op.getLoc(), refType, varName) + return emitc::LiteralOp::create(rewriter, op.getLoc(), refType, varName) .getResult(); } private: mlir::Value createVector(ValueRange operands) { - return rewriter - .create( - op.getLoc(), - emitc::OpaqueType::get(rewriter.getContext(), - TypeNameV>), - kCreateVectorFunctionName, nullptr, nullptr, operands) + return emitc::CallOpaqueOp::create( + rewriter, + + op.getLoc(), + emitc::OpaqueType::get(rewriter.getContext(), + TypeNameV>), + kCreateVectorFunctionName, nullptr, nullptr, operands) ->getResult(0); } diff --git a/include/ttmlir/Conversion/TTNNToEmitPy/EmitPyConversion.h b/include/ttmlir/Conversion/TTNNToEmitPy/EmitPyConversion.h index 26fd70abcbb..d6d74e6aef5 100644 --- a/include/ttmlir/Conversion/TTNNToEmitPy/EmitPyConversion.h +++ b/include/ttmlir/Conversion/TTNNToEmitPy/EmitPyConversion.h @@ -2253,12 +2253,13 @@ class EmitPyTTNNEmitter { } mlir::Value createList(ValueRange operands) { - return rewriter - .create( - op.getLoc(), - emitpy::OpaqueType::get(rewriter.getContext(), - TypeNameV>), - kCreateListFunctionName, operands, nullptr, nullptr) + return emitpy::CallOpaqueOp::create( + rewriter, + 
+ op.getLoc(), + emitpy::OpaqueType::get(rewriter.getContext(), + TypeNameV>), + kCreateListFunctionName, operands, nullptr, nullptr) ->getResult(0); } diff --git a/include/ttmlir/Dialect/D2M/IR/D2MGenericRegionOps.td b/include/ttmlir/Dialect/D2M/IR/D2MGenericRegionOps.td index 15d536a0814..f1bc05987e8 100644 --- a/include/ttmlir/Dialect/D2M/IR/D2MGenericRegionOps.td +++ b/include/ttmlir/Dialect/D2M/IR/D2MGenericRegionOps.td @@ -1801,7 +1801,7 @@ class D2M_CBOp traits = []> : D2M_GenericRegionOpemitOpError(); }); assert(succeeded(cbBufferType)); - auto toBuffer = rewriter.create( + auto toBuffer = bufferization::ToBufferOp::create(rewriter, this->getLoc(), *cbBufferType, getCb()); mlir::bufferization::replaceOpWithNewBufferizedOp<$cppClass>( rewriter, *this, toBuffer.getResult()); diff --git a/include/ttmlir/Dialect/TTIR/Transforms/EraseInverseOps/EraseInverseOps.h b/include/ttmlir/Dialect/TTIR/Transforms/EraseInverseOps/EraseInverseOps.h index fbed96423bb..35ccb3dd44d 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/EraseInverseOps/EraseInverseOps.h +++ b/include/ttmlir/Dialect/TTIR/Transforms/EraseInverseOps/EraseInverseOps.h @@ -263,8 +263,8 @@ inline PermuteOp getInverseTM(PermuteOp permuteOp, Value input, SmallVector outputShape = ttmlir::utils::applyPermutation( inputType.getShape(), ArrayRef(inversePermutation)); RankedTensorType resultType = inputType.clone(outputShape); - return rewriter.create(permuteOp->getLoc(), resultType, input, - inversePermutation); + return PermuteOp::create(rewriter, permuteOp->getLoc(), resultType, input, + inversePermutation); } inline ReshapeOp getInverseTM(ReshapeOp reshapeOp, Value input, @@ -277,8 +277,8 @@ inline ReshapeOp getInverseTM(ReshapeOp reshapeOp, Value input, auto outputShape = reshapeOp.getInput().getType().getShape(); RankedTensorType resultType = inputType.clone(outputShape); - return rewriter.create( - reshapeOp->getLoc(), resultType, input, + return ReshapeOp::create( + rewriter, reshapeOp->getLoc(), 
resultType, input, rewriter.getI32ArrayAttr(SmallVector(outputShape))); } diff --git a/include/ttmlir/Dialect/TTIR/Utils/Utils.h b/include/ttmlir/Dialect/TTIR/Utils/Utils.h index 03cc4c1f41d..7b38ffff389 100644 --- a/include/ttmlir/Dialect/TTIR/Utils/Utils.h +++ b/include/ttmlir/Dialect/TTIR/Utils/Utils.h @@ -67,19 +67,19 @@ struct SplitCaller, // ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, // ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})`. if constexpr (sizeof...(Js) == 0) { - return builder.create(loc, output.getType(), - ttmlir::utils::flatten( - std::get(std::forward_as_tuple( - std::forward(args)...))..., - output)); + return OpTy::create(builder, loc, output.getType(), + ttmlir::utils::flatten( + std::get(std::forward_as_tuple( + std::forward(args)...))..., + output)); } else if constexpr (sizeof...(Js) == 1 && std::is_convertible_v< std::tuple_element_t< sizeof...(Is) + sizeof...(Js) - 1, std::tuple...>>, mlir::ArrayRef>) { - return builder.create( - loc, output.getType(), + return OpTy::create( + builder, loc, output.getType(), ttmlir::utils::flatten( std::get( std::forward_as_tuple(std::forward(args)...))..., @@ -89,8 +89,8 @@ struct SplitCaller, // Otherwise, call the op specific builder that provides positional // `Attribute` arguments. 
} else { - return builder.create( - loc, output.getType(), + return OpTy::create( + builder, loc, output.getType(), std::get(std::forward_as_tuple(std::forward(args)...))..., output, std::get( @@ -132,17 +132,17 @@ constexpr bool has_dps_trait_v = // createDPSOp(rewriter, loc, outputType, operand1, operand2, ..., // operandN, attribute1, attribute2, ..., attributeM); // is equivalent to: -// auto output = rewriter.create(loc, outputType.getShape(), +// auto output = ttir::EmptyOp::create(rewriter, loc, outputType.getShape(), // outputType.getElementType(), outputType.getEncoding()); -// rewriter.create(loc, outputType, operand1, operand2, ..., operandN, +// OpTy::create(rewriter, loc, outputType, operand1, operand2, ..., operandN, // output, attribute1, attribute2, ..., attributeM); template OpTy createDPSOp(mlir::OpBuilder &builder, mlir::Location loc, mlir::RankedTensorType outputType, ArgsTy &&...args) { static_assert(has_dps_trait_v); - auto output = builder.create( - loc, outputType.getShape(), outputType.getElementType(), + auto output = mlir::tt::ttir::EmptyOp::create( + builder, loc, outputType.getShape(), outputType.getElementType(), outputType.getEncoding()); return detail::splitAndCall(builder, loc, output, @@ -159,9 +159,9 @@ OpTy createDPSOp(mlir::OpBuilder &builder, mlir::Location loc, // is equivalent to: // auto outputType = mlir::RankedTensorType::get(outputShape, outputElementType, // outputEncoding); -// auto output = rewriter.create(loc, outputShape, +// auto output = ttir::EmptyOp::create(rewriter, loc, outputShape, // outputElementType, outputEncoding); -// rewriter.create(loc, outputType, operand1, operand2, ..., operandN, +// OpTy::create(rewriter, loc, outputType, operand1, operand2, ..., operandN, // output, attribute1, attribute2, ..., attributeM); template OpTy createDPSOp(mlir::OpBuilder &builder, mlir::Location loc, @@ -183,7 +183,7 @@ OpTy createDPSOp(mlir::OpBuilder &builder, mlir::Location loc, // replaceOpWithNewDPSOp(rewriter, op, 
outputType, operand1, operand2, // ..., operandN, attribute1, attribute2, ..., attributeM); // is equivalent to: -// auto output = rewriter.create(loc, outputType.getShape(), +// auto output = ttir::EmptyOp::create(rewriter, loc, outputType.getShape(), // outputType.getElementType(), outputType.getEncoding()); // rewriter.replaceOpWithNewOp(op, outputType, operand1, operand2, ..., // operandN, output, attribute1, attribute2, ..., attributeM); @@ -210,7 +210,7 @@ OpTy replaceOpWithNewDPSOp(mlir::PatternRewriter &rewriter, mlir::Operation *op, // is equivalent to: // auto outputType = mlir::RankedTensorType::get(outputShape, outputElementType, // outputEncoding); -// auto output = rewriter.create(loc, outputShape, +// auto output = ttir::EmptyOp::create(rewriter, loc, outputShape, // outputElementType, outputEncoding); // rewriter.replaceOpWithNewOp(op, outputType, operand1, operand2, ..., // operandN, output, attribute1, attribute2, ..., attributeM); @@ -290,8 +290,8 @@ inline ttir::ReshapeOp createReshapeOp(PatternRewriter &rewriter, Location loc, auto shapeAttr = rewriter.getI32ArrayAttr(llvm::SmallVector(targetShape)); - return rewriter.create( - loc, + return ttir::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(targetShape, inputType.getElementType(), inputType.getEncoding()), input, shapeAttr); @@ -407,13 +407,13 @@ inline mlir::Value reshapeAndCastToType(mlir::PatternRewriter &rewriter, auto reshapeOutputType = mlir::RankedTensorType::get( targetType.getShape(), valueType.getElementType(), valueType.getEncoding()); - result = rewriter.create(loc, reshapeOutputType, result, - rewriter.getI32ArrayAttr(targetShape)); + result = ReshapeOp::create(rewriter, loc, reshapeOutputType, result, + rewriter.getI32ArrayAttr(targetShape)); } // If dtype differs, add a typecast if (valueType.getElementType() != targetType.getElementType()) { - result = rewriter.create(loc, targetType, result); + result = TypecastOp::create(rewriter, loc, targetType, result); } 
return result; diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Fusing/FusionValidator.h b/include/ttmlir/Dialect/TTNN/Transforms/Fusing/FusionValidator.h index 518fe65801b..62b92722f5c 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Fusing/FusionValidator.h +++ b/include/ttmlir/Dialect/TTNN/Transforms/Fusing/FusionValidator.h @@ -102,12 +102,12 @@ void FusionValidator::createValidationFunc(ModuleOp module, Location loc, // Create an empty function with a return terminator. auto funcType = builder.getFunctionType({}, {}); - auto func = builder.create(module->getLoc(), - "validation_func", funcType); + auto func = mlir::func::FuncOp::create(builder, module->getLoc(), + "validation_func", funcType); func.addEntryBlock(); auto *block = &func.getBody().front(); builder.setInsertionPointToEnd(block); - builder.create(func->getLoc()); + mlir::func::ReturnOp::create(builder, func->getLoc()); // Capture Value args and create corresponding block arguments. builder.setInsertionPointToStart(block); @@ -140,13 +140,13 @@ void FusionValidator::createValidationFunc(ModuleOp module, Location loc, }; // Create the fused op. - auto op = builder.create(loc, resultTypes, sub(args)...); + auto op = FusedOpType::create(builder, loc, resultTypes, sub(args)...); // Pin results: update return and function type so passes don't DCE the op. 
auto returnOp = cast(block->getTerminator()); llvm::SmallVector opResults(op->getResults()); OpBuilder retBuilder(returnOp); - retBuilder.create(returnOp.getLoc(), opResults); + mlir::func::ReturnOp::create(retBuilder, returnOp.getLoc(), opResults); returnOp.erase(); llvm::SmallVector outTypes; diff --git a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h index 7ba189f74b5..fe39e76c2ad 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h @@ -11,6 +11,7 @@ #include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h" #include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" +#include "llvm/ADT/StringMap.h" namespace mlir::tt::ttnn { diff --git a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h index 14ca928b932..7c9c028a8b3 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h @@ -9,6 +9,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/Utils/Conv2dConfigParams.h" +#include "llvm/ADT/StringMap.h" #include "llvm/Support/CommandLine.h" #include diff --git a/include/ttmlir/Support/IRHasher.h b/include/ttmlir/Support/IRHasher.h index e57e7c906f5..426e52a748f 100644 --- a/include/ttmlir/Support/IRHasher.h +++ b/include/ttmlir/Support/IRHasher.h @@ -31,7 +31,7 @@ inline std::string hashFuncOp(func::FuncOp func) { // RAII guard to undo the temporary changes made to the function // during hashing. 
- auto undoTempChanges = llvm::make_scope_exit([&func, originalSymName]() { + auto undoTempChanges = llvm::scope_exit([&func, originalSymName]() { func.setSymName(originalSymName); func.walk([&](func::CallOp callOp) { callOp->removeAttr("callee_hash"); }); }); diff --git a/lib/Conversion/ArithToD2MTileOps/ArithToD2MTileOps.cpp b/lib/Conversion/ArithToD2MTileOps/ArithToD2MTileOps.cpp index ea7464e41fd..2a110c0b400 100644 --- a/lib/Conversion/ArithToD2MTileOps/ArithToD2MTileOps.cpp +++ b/lib/Conversion/ArithToD2MTileOps/ArithToD2MTileOps.cpp @@ -81,8 +81,8 @@ class CmpFTileOpRewriter : public OpConversionPattern { auto tileType = operands[0].getType(); // First, compute (lhs - rhs) - auto subOp = rewriter.create(loc, tileType, operands[0], - operands[1]); + auto subOp = d2m::TileSubOp::create(rewriter, loc, tileType, operands[0], + operands[1]); auto operandTileType = mlir::cast(operands[0].getType()); diff --git a/lib/Conversion/D2MToTTKernel/D2MToTTKernel.cpp b/lib/Conversion/D2MToTTKernel/D2MToTTKernel.cpp index 3f772c52693..65053d4a3c4 100644 --- a/lib/Conversion/D2MToTTKernel/D2MToTTKernel.cpp +++ b/lib/Conversion/D2MToTTKernel/D2MToTTKernel.cpp @@ -35,16 +35,14 @@ namespace mlir::tt::ttkernel { namespace { static Value i32(OpBuilder &rewriter, Location loc, int32_t value) { - return rewriter - .create(loc, rewriter.getI32Type(), - rewriter.getI32IntegerAttr(value)) + return arith::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), + rewriter.getI32IntegerAttr(value)) .getResult(); } static Value index(OpBuilder &rewriter, Location loc, int64_t value) { - return rewriter - .create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(value)) + return arith::ConstantOp::create(rewriter, loc, rewriter.getIndexType(), + rewriter.getIndexAttr(value)) .getResult(); } @@ -52,10 +50,12 @@ static std::pair getVirtualCoordsFromLogicalCoords(OpBuilder &rewriter, Location loc, ttcore::ChipDescAttr chipDesc, ValueRange dstCoreIndex) { - Value virtY = rewriter.create( 
- dstCoreIndex[0].getLoc(), dstCoreIndex[0].getType(), dstCoreIndex[0]); - Value virtX = rewriter.create( - dstCoreIndex[1].getLoc(), dstCoreIndex[1].getType(), dstCoreIndex[1]); + Value virtY = ttkernel::ConvertLogicalYToTranslatedOp::create( + rewriter, dstCoreIndex[0].getLoc(), dstCoreIndex[0].getType(), + dstCoreIndex[0]); + Value virtX = ttkernel::ConvertLogicalXToTranslatedOp::create( + rewriter, dstCoreIndex[1].getLoc(), dstCoreIndex[1].getType(), + dstCoreIndex[1]); return {virtY, virtX}; } @@ -64,16 +64,15 @@ static std::pair getMcastEndCoords(PatternRewriter &rewriter, const Value &nocStartY, const Value &nocStartX, OperandRange mcastShape) { - return {rewriter.create( - nocStartY.getLoc(), - rewriter.create(nocStartY.getLoc(), nocStartY, - mcastShape[0]), - index(rewriter, loc, 1)), - rewriter.create( - nocStartX.getLoc(), - rewriter.create(nocStartX.getLoc(), nocStartX, - mcastShape[1]), - index(rewriter, loc, 1))}; + return { + arith::SubIOp::create(rewriter, nocStartY.getLoc(), + arith::AddIOp::create(rewriter, nocStartY.getLoc(), + nocStartY, mcastShape[0]), + index(rewriter, loc, 1)), + arith::SubIOp::create(rewriter, nocStartX.getLoc(), + arith::AddIOp::create(rewriter, nocStartX.getLoc(), + nocStartX, mcastShape[1]), + index(rewriter, loc, 1))}; } static Value getCB(ConversionPatternRewriter &rewriter, Value cb) { @@ -283,13 +282,13 @@ class MemRefSubviewRewriter : public OpConversionPattern { Value rtIdx = index(rewriter, op.getLoc(), resultTy.getShape()[0]); Value ktIdx = index(rewriter, op.getLoc(), resultTy.getShape()[1]); Value tilesPerBlock = - rewriter.create(op.getLoc(), rtIdx, ktIdx); + arith::MulIOp::create(rewriter, op.getLoc(), rtIdx, ktIdx); // Convert the resolved source row offset to a block-row index. 
Value rowBlockIdx = - rewriter.create(op.getLoc(), sourceIndices[0], rtIdx); - Value rowBase = - rewriter.create(op.getLoc(), rowBlockIdx, tilesPerBlock); + arith::DivSIOp::create(rewriter, op.getLoc(), sourceIndices[0], rtIdx); + Value rowBase = arith::MulIOp::create(rewriter, op.getLoc(), rowBlockIdx, + tilesPerBlock); rewriter.replaceOpWithNewOp(op, rowBase, sourceIndices[1]); return success(); }; @@ -304,7 +303,7 @@ class AcquireDstRewriter : public OpConversionPattern { LogicalResult matchAndRewrite(d2m::AcquireDstOp op, d2m::AcquireDstOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - rewriter.create(op.getLoc()); + ttkernel::TileRegsAcquireOp::create(rewriter, op.getLoc()); // Dst is an implicit resource in TTKernel, so we can just erase it. rewriter.eraseOp(op); return success(); @@ -331,10 +330,10 @@ class D2MSetL1AccumulateRewriter if (func) { func.walk([&](func::ReturnOp returnOp) { OpBuilder builder(returnOp); - Value zero = builder.create( - returnOp.getLoc(), builder.getI32Type(), - builder.getI32IntegerAttr(0)); - builder.create(returnOp.getLoc(), zero); + Value zero = arith::ConstantOp::create(builder, returnOp.getLoc(), + builder.getI32Type(), + builder.getI32IntegerAttr(0)); + ttkernel::PackReconfigL1AccOp::create(builder, returnOp.getLoc(), zero); }); } @@ -353,7 +352,7 @@ static Value computeLinearIndex(Location loc, ArrayRef shape, return indices.front(); } - Value linearIdx = rewriter.create(loc, 0); + Value linearIdx = arith::ConstantIndexOp::create(rewriter, loc, 0); for (size_t i = 0; i < indices.size(); ++i) { int64_t stride = 1; for (size_t j = i + 1; j < shape.size(); ++j) { @@ -362,10 +361,11 @@ static Value computeLinearIndex(Location loc, ArrayRef shape, Value contribution = indices[i]; if (stride != 1) { - auto strideVal = rewriter.create(loc, stride); - contribution = rewriter.create(loc, indices[i], strideVal); + auto strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride); + contribution = + 
arith::MulIOp::create(rewriter, loc, indices[i], strideVal); } - linearIdx = rewriter.create(loc, linearIdx, contribution); + linearIdx = arith::AddIOp::create(rewriter, loc, linearIdx, contribution); } return linearIdx; } @@ -410,10 +410,10 @@ class MemrefStoreRewriter : public OpConversionPattern { rewriter.setInsertionPointToStart(rewriter.getInsertionBlock()); setInsertionPointAfterOperands(rewriter, {inCB, outCB}, /*allowHoisting*/ true); - rewriter.create(store.getLoc(), inCB, outCB); + ttkernel::InitSFPUOp::create(rewriter, store.getLoc(), inCB, outCB); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); - rewriter.create(store.getLoc(), cb); + ttkernel::CopyTileInitOp::create(rewriter, store.getLoc(), cb); rewriter.replaceOpWithNewOp(store, cb, cbIndex, dstIndex); return success(); @@ -629,8 +629,8 @@ class D2MFPUOpsRewriter : public OpConversionPattern { auto outCB = getOutCB(rewriter, op); setInsertionPointAfterOperands(rewriter, {cbA, cbB, outCB}, /*allowHoisting*/ true); - rewriter.create(op->getLoc(), cbA, cbB, - outCB); + ttkernel::BinaryOpInitCommonOp::create(rewriter, op->getLoc(), cbA, cbB, + outCB); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); } else { static_assert(arity == 3 && !ttmlir::utils::always_false(), @@ -650,17 +650,17 @@ class D2MFPUOpsRewriter : public OpConversionPattern { !hasMatmulInit(func)) { setInsertionPointToFuncStart(rewriter, func); auto transpose = i32(rewriter, op->getLoc(), 0); - rewriter.create(op->getLoc(), cbA, cbB, outCB, - transpose); + ttkernel::MatmulInitOp::create(rewriter, op->getLoc(), cbA, cbB, outCB, + transpose); } auto transpose = i32(rewriter, op->getLoc(), 0); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); - rewriter.create(op->getLoc(), cbA, cbB, - transpose); - rewriter.create(op->getLoc(), cbA, cbB, - adaptor.getA(), adaptor.getB(), - adaptor.getC()); + ttkernel::MatmulInitShortOp::create(rewriter, op->getLoc(), cbA, cbB, + 
transpose); + ttkernel::MatmulTilesOp::create(rewriter, op->getLoc(), cbA, cbB, + adaptor.getA(), adaptor.getB(), + adaptor.getC()); } else if constexpr (std::is_same_v) { auto insertionPoint = rewriter.getInsertionPoint(); auto cbA = getCB(rewriter, op.getA()); @@ -694,11 +694,12 @@ class D2MFPUOpsRewriter : public OpConversionPattern { auto transpose = i32(rewriter, op->getLoc(), 0); - rewriter.create( - op->getLoc(), cbA, cbB, outCB, transpose, ct_i32, rt_i32, kt_i32); + ttkernel::MatmulBlockInitOp::create(rewriter, op->getLoc(), cbA, cbB, + outCB, transpose, ct_i32, rt_i32, + kt_i32); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); - rewriter.create( - op->getLoc(), cbA, cbB, transpose, ct_i32, rt_i32, kt_i32); + ttkernel::MatmulBlockInitShortOp::create( + rewriter, op->getLoc(), cbA, cbB, transpose, ct_i32, rt_i32, kt_i32); // Get the tile index for each input in the global memref. This is done by // resolving tile (0,0) from the subview, representing a block, into the @@ -718,9 +719,9 @@ class D2MFPUOpsRewriter : public OpConversionPattern { bTileIndex = index(rewriter, op.getLoc(), 0); } - rewriter.create( - op->getLoc(), cbA, cbB, aTileIndex, bTileIndex, destIndex, transpose, - ct_i32, rt_i32, kt_i32, nt_i32); + ttkernel::ExperimentalMatmulBlockOp::create( + rewriter, op->getLoc(), cbA, cbB, aTileIndex, bTileIndex, destIndex, + transpose, ct_i32, rt_i32, kt_i32, nt_i32); } else if constexpr (std::is_same_v || std::is_same_v) { ttkernel::ReduceType reduce_type; @@ -749,15 +750,15 @@ class D2MFPUOpsRewriter : public OpConversionPattern { auto outCB = getOutCB(rewriter, op); setInsertionPointAfterOperands(rewriter, {cbA, cbB, outCB}, /*allowHoisting*/ true); - rewriter.create(op->getLoc(), cbA, - cbB, outCB); + ttkernel::ComputeKernelHWStartupOp::create(rewriter, op->getLoc(), cbA, + cbB, outCB); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); - rewriter.create(op->getLoc(), cbA, cbB, outCB, - reduce_type, 
kernel_reduce_dim); - rewriter.create( - op->getLoc(), cbA, cbB, adaptor.getA(), adaptor.getB(), + ttkernel::ReduceInitOp::create(rewriter, op->getLoc(), cbA, cbB, outCB, + reduce_type, kernel_reduce_dim); + ttkernel::ReduceTileOp::create( + rewriter, op->getLoc(), cbA, cbB, adaptor.getA(), adaptor.getB(), adaptor.getC(), reduce_type, kernel_reduce_dim); - rewriter.create(op->getLoc()); + ttkernel::ReduceUninitOp::create(rewriter, op->getLoc()); } else if constexpr (std::is_same_v) { ttkernel::BcastType bcastType = ttkernel::BcastType::None; switch (op.getBcastType()) { @@ -776,17 +777,17 @@ class D2MFPUOpsRewriter : public OpConversionPattern { } auto cb = getCB(rewriter, op.getInput()); auto dstIdx = getDstIdxFromResult(op.getResult()); - rewriter.create(op->getLoc(), cb, cb, - bcastType); - rewriter.create( - op->getLoc(), cb, adaptor.getInput(), dstIdx, bcastType); + ttkernel::UnaryBcastInitOp::create(rewriter, op->getLoc(), cb, cb, + bcastType); + ttkernel::UnaryBcastTileOp::create(rewriter, op->getLoc(), cb, + adaptor.getInput(), dstIdx, bcastType); } else if constexpr (arity == 2) { auto dstIdx = getDstIdxFromResult(op.getResult()); - rewriter.create(op->getLoc(), getCB(rewriter, op.getLhs()), - getCB(rewriter, op.getRhs())); - rewriter.create(op->getLoc(), getCB(rewriter, op.getLhs()), - getCB(rewriter, op.getRhs()), adaptor.getLhs(), - adaptor.getRhs(), dstIdx); + InitOp::create(rewriter, op->getLoc(), getCB(rewriter, op.getLhs()), + getCB(rewriter, op.getRhs())); + FPUOp::create(rewriter, op->getLoc(), getCB(rewriter, op.getLhs()), + getCB(rewriter, op.getRhs()), adaptor.getLhs(), + adaptor.getRhs(), dstIdx); } else { return llvm::failure(); } @@ -822,7 +823,7 @@ class D2MSFPUOpsRewriter : public OpConversionPattern { rewriter.setInsertionPointToStart(rewriter.getInsertionBlock()); setInsertionPointAfterOperands(rewriter, {inCB, outCB}, /*allowHoisting*/ true); - rewriter.create(op->getLoc(), inCB, outCB); + ttkernel::InitSFPUOp::create(rewriter, 
op->getLoc(), inCB, outCB); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); // For binary ops (arity == 2), check if rhs is a scalar to create the right @@ -834,9 +835,9 @@ class D2MSFPUOpsRewriter : public OpConversionPattern { if (isScalarRhs) { // Use scalar-specific init ops if constexpr (std::is_same_v) { - rewriter.create(op->getLoc()); + ttkernel::PowerTileInitOp::create(rewriter, op->getLoc()); } else { - rewriter.create(op->getLoc()); + ttkernel::BinopWithScalarTileInitOp::create(rewriter, op->getLoc()); } } else if constexpr (hasMapping) { using IntInit = @@ -845,39 +846,39 @@ class D2MSFPUOpsRewriter : public OpConversionPattern { mlir::cast(op.getLhs().getType()); if (llvm::isa(tileType.getElementType())) { if constexpr (needsDtypeArg) { - rewriter.create(op->getLoc(), tileType.getDataType()); + IntInit::create(rewriter, op->getLoc(), tileType.getDataType()); } else { - rewriter.create(op->getLoc()); + IntInit::create(rewriter, op->getLoc()); } } else { - rewriter.create(op->getLoc()); + InitOp::create(rewriter, op->getLoc()); } } else { - rewriter.create(op->getLoc()); + InitOp::create(rewriter, op->getLoc()); } } else if constexpr (std::is_same_v) { const auto inDtype = mlir::cast(op.getInput().getType()).getDataType(); const auto outDtype = mlir::cast(op.getResult().getType()).getDataType(); - rewriter.create(op->getLoc(), inDtype, - outDtype); + ttkernel::TypecastTileInitOp::create(rewriter, op->getLoc(), inDtype, + outDtype); } else { - rewriter.create(op->getLoc()); + InitOp::create(rewriter, op->getLoc()); } if constexpr (std::is_same_v) { const auto dtype = mlir::cast(op.getInput().getType()).getDataType(); - rewriter.create(op->getLoc(), - adaptor.getInput(), dtype); + ttkernel::LogicalNotTileOp::create(rewriter, op->getLoc(), + adaptor.getInput(), dtype); } else if constexpr (std::is_same_v) { const auto inDtype = mlir::cast(op.getInput().getType()).getDataType(); const auto outDtype = 
mlir::cast(op.getResult().getType()).getDataType(); - rewriter.create( - op->getLoc(), adaptor.getInput(), inDtype, outDtype); + ttkernel::TypecastTileOp::create(rewriter, op->getLoc(), + adaptor.getInput(), inDtype, outDtype); } else if constexpr (std::is_same_v) { auto loc = op->getLoc(); // The hardware clamp API takes i32 params for both int and float clamps. @@ -887,27 +888,27 @@ class D2MSFPUOpsRewriter : public OpConversionPattern { if (mlir::isa(minAttr) && mlir::isa(maxAttr)) { auto intToI32Param = [&](Attribute attr) -> Value { auto intAttr = mlir::cast(attr); - return rewriter.create( - loc, rewriter.getI32Type(), + return arith::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(intAttr.getValue().getSExtValue())); }; auto minParam = intToI32Param(minAttr); auto maxParam = intToI32Param(maxAttr); - rewriter.create( - loc, adaptor.getInput(), minParam, maxParam); + ttkernel::ClampScalarTileInt32Op::create( + rewriter, loc, adaptor.getInput(), minParam, maxParam); } else { auto floatToI32Param = [&](Attribute attr) -> Value { auto floatAttr = mlir::cast(attr); - auto f32Val = rewriter.create( - loc, + auto f32Val = arith::ConstantOp::create( + rewriter, loc, rewriter.getF32FloatAttr(floatAttr.getValue().convertToDouble())); - return rewriter.create(loc, rewriter.getI32Type(), - f32Val); + return arith::BitcastOp::create(rewriter, loc, rewriter.getI32Type(), + f32Val); }; auto minParam = floatToI32Param(minAttr); auto maxParam = floatToI32Param(maxAttr); - rewriter.create(loc, adaptor.getInput(), - minParam, maxParam); + ttkernel::ClampScalarTileOp::create(rewriter, loc, adaptor.getInput(), + minParam, maxParam); } } else if constexpr (arity == 1 && hasMapping) { @@ -919,12 +920,12 @@ class D2MSFPUOpsRewriter : public OpConversionPattern { mlir::cast(op.getInput().getType()) .getElementType(); if (llvm::isa(elemType)) { - rewriter.create(op->getLoc(), adaptor.getInput()); + IntSFPUOp::create(rewriter, op->getLoc(), 
adaptor.getInput()); } else { - rewriter.create(op->getLoc(), adaptor.getInput()); + SFPUOp::create(rewriter, op->getLoc(), adaptor.getInput()); } } else if constexpr (arity == 1) { - rewriter.create(op->getLoc(), adaptor.getInput()); + SFPUOp::create(rewriter, op->getLoc(), adaptor.getInput()); } else if constexpr (arity == 2) { // Check if rhs is a scalar (float or integer) at runtime auto rhsType = adaptor.getRhs().getType(); @@ -938,29 +939,29 @@ class D2MSFPUOpsRewriter : public OpConversionPattern { // Create the appropriate unary scalar op based on the D2M op type if constexpr (std::is_same_v) { // Bitcast the scalar value to i32 to pass as parameter - rewriter.create(loc); - auto scalarParam = rewriter.create( - loc, rewriter.getI32Type(), adaptor.getRhs()); - rewriter.create(loc, dstIdx, scalarParam); + ttkernel::BinopWithScalarTileInitOp::create(rewriter, loc); + auto scalarParam = arith::BitcastOp::create( + rewriter, loc, rewriter.getI32Type(), adaptor.getRhs()); + ttkernel::AddUnaryTileOp::create(rewriter, loc, dstIdx, scalarParam); } else if constexpr (std::is_same_v) { - rewriter.create(loc); - auto scalarParam = rewriter.create( - loc, rewriter.getI32Type(), adaptor.getRhs()); - rewriter.create(loc, dstIdx, scalarParam); + ttkernel::BinopWithScalarTileInitOp::create(rewriter, loc); + auto scalarParam = arith::BitcastOp::create( + rewriter, loc, rewriter.getI32Type(), adaptor.getRhs()); + ttkernel::SubUnaryTileOp::create(rewriter, loc, dstIdx, scalarParam); } else if constexpr (std::is_same_v) { - rewriter.create(loc); - auto scalarParam = rewriter.create( - loc, rewriter.getI32Type(), adaptor.getRhs()); - rewriter.create(loc, dstIdx, scalarParam); + ttkernel::BinopWithScalarTileInitOp::create(rewriter, loc); + auto scalarParam = arith::BitcastOp::create( + rewriter, loc, rewriter.getI32Type(), adaptor.getRhs()); + ttkernel::MulUnaryTileOp::create(rewriter, loc, dstIdx, scalarParam); } else if constexpr (std::is_same_v) { - auto scalarParam = 
rewriter.create( - loc, rewriter.getI32Type(), adaptor.getRhs()); - rewriter.create(loc, dstIdx, scalarParam); + auto scalarParam = arith::BitcastOp::create( + rewriter, loc, rewriter.getI32Type(), adaptor.getRhs()); + ttkernel::DivUnaryTileOp::create(rewriter, loc, dstIdx, scalarParam); } else if constexpr (std::is_same_v) { // For power, convert float value to integer (not bitcast) - auto scalarParam = rewriter.create( - loc, rewriter.getI32Type(), adaptor.getRhs()); - rewriter.create(loc, dstIdx, scalarParam); + auto scalarParam = arith::FPToSIOp::create( + rewriter, loc, rewriter.getI32Type(), adaptor.getRhs()); + ttkernel::PowUnaryTileOp::create(rewriter, loc, dstIdx, scalarParam); } // Scalar ops operate in-place on DST slot - replace with the same // dstIdx. @@ -978,8 +979,8 @@ class D2MSFPUOpsRewriter : public OpConversionPattern { std::is_same_v) { const auto dtype = mlir::cast(op.getLhs().getType()).getDataType(); - rewriter.create(op->getLoc(), adaptor.getLhs(), - adaptor.getRhs(), dstIdx, dtype); + SFPUOp::create(rewriter, op->getLoc(), adaptor.getLhs(), + adaptor.getRhs(), dstIdx, dtype); } else if constexpr (hasMapping) { using IntSFPUOp = typename TTKernelOpPair::second_type; @@ -987,20 +988,19 @@ class D2MSFPUOpsRewriter : public OpConversionPattern { mlir::cast(op.getLhs().getType()); if (llvm::isa(tileType.getElementType())) { if constexpr (needsDtypeArg) { - rewriter.create(op->getLoc(), adaptor.getLhs(), - adaptor.getRhs(), dstIdx, - tileType.getDataType()); + IntSFPUOp::create(rewriter, op->getLoc(), adaptor.getLhs(), + adaptor.getRhs(), dstIdx, tileType.getDataType()); } else { - rewriter.create(op->getLoc(), adaptor.getLhs(), - adaptor.getRhs(), dstIdx); + IntSFPUOp::create(rewriter, op->getLoc(), adaptor.getLhs(), + adaptor.getRhs(), dstIdx); } } else { - rewriter.create(op->getLoc(), adaptor.getLhs(), - adaptor.getRhs(), dstIdx); + SFPUOp::create(rewriter, op->getLoc(), adaptor.getLhs(), + adaptor.getRhs(), dstIdx); } } else { - 
rewriter.create(op->getLoc(), adaptor.getLhs(), - adaptor.getRhs(), dstIdx); + SFPUOp::create(rewriter, op->getLoc(), adaptor.getLhs(), + adaptor.getRhs(), dstIdx); } } else { // Ternary tile operation (arity == 3) @@ -1015,9 +1015,9 @@ class D2MSFPUOpsRewriter : public OpConversionPattern { const auto dtype = mlir::cast(op.getTrueValue().getType()) .getDataType(); - rewriter.create( - op->getLoc(), adaptor.getCondition(), adaptor.getTrueValue(), - adaptor.getFalseValue(), dstIdx, dtype); + ttkernel::WhereTileOp::create( + rewriter, op->getLoc(), adaptor.getCondition(), + adaptor.getTrueValue(), adaptor.getFalseValue(), dstIdx, dtype); } } @@ -1107,23 +1107,23 @@ class D2MFPUBinaryRewriter : public OpConversionPattern { auto insertionPoint = rewriter.getInsertionPoint(); setInsertionPointAfterOperands(rewriter, {cbA, cbB, outCB}, /*allowHoisting*/ true); - rewriter.create(loc, cbA, cbB, outCB); + ttkernel::BinaryOpInitCommonOp::create(rewriter, loc, cbA, cbB, outCB); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); auto dstIdx = getDstIdxFromResult(op.getResult()); if constexpr (std::is_same_v) { - rewriter.create(loc, cbA, cbB); - rewriter.create(loc, cbA, cbB, adaptor.getLhs(), - adaptor.getRhs(), dstIdx); + ttkernel::AddTilesInitOp::create(rewriter, loc, cbA, cbB); + ttkernel::AddTilesOp::create(rewriter, loc, cbA, cbB, adaptor.getLhs(), + adaptor.getRhs(), dstIdx); } else if constexpr (std::is_same_v) { - rewriter.create(loc, cbA, cbB); - rewriter.create(loc, cbA, cbB, adaptor.getLhs(), - adaptor.getRhs(), dstIdx); + ttkernel::SubTilesInitOp::create(rewriter, loc, cbA, cbB); + ttkernel::SubTilesOp::create(rewriter, loc, cbA, cbB, adaptor.getLhs(), + adaptor.getRhs(), dstIdx); } else if constexpr (std::is_same_v) { - rewriter.create(loc, cbA, cbB); - rewriter.create(loc, cbA, cbB, adaptor.getLhs(), - adaptor.getRhs(), dstIdx); + ttkernel::MulTilesInitOp::create(rewriter, loc, cbA, cbB); + ttkernel::MulTilesOp::create(rewriter, loc, cbA, 
cbB, adaptor.getLhs(), + adaptor.getRhs(), dstIdx); } rewriter.eraseOp(op); @@ -1146,25 +1146,25 @@ class D2MFPUBinaryRewriter : public OpConversionPattern { auto insertionPoint = rewriter.getInsertionPoint(); setInsertionPointAfterOperands(rewriter, {cbA, outCB}, /*allowHoisting*/ true); - rewriter.create(loc, cbA, outCB); + ttkernel::InitSFPUOp::create(rewriter, loc, cbA, outCB); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); - rewriter.create(loc, cbA); - rewriter.create(loc, cbA, adaptor.getLhs(), dst0); - rewriter.create(loc, cbB); - rewriter.create(loc, cbB, adaptor.getRhs(), dst1); + ttkernel::CopyTileInitOp::create(rewriter, loc, cbA); + ttkernel::CopyTileOp::create(rewriter, loc, cbA, adaptor.getLhs(), dst0); + ttkernel::CopyTileInitOp::create(rewriter, loc, cbB); + ttkernel::CopyTileOp::create(rewriter, loc, cbB, adaptor.getRhs(), dst1); const auto dtype = mlir::cast(op.getLhs().getType()).getDataType(); if constexpr (needsDtypeArg) { - rewriter.create(loc, dtype); + IntInit::create(rewriter, loc, dtype); } else { - rewriter.create(loc); + IntInit::create(rewriter, loc); } if constexpr (needsDtypeArg) { - rewriter.create(loc, dst0, dst1, dst0, dtype); + IntSFPUOp::create(rewriter, loc, dst0, dst1, dst0, dtype); } else { - rewriter.create(loc, dst0, dst1, dst0); + IntSFPUOp::create(rewriter, loc, dst0, dst1, dst0); } rewriter.eraseOp(op); @@ -1200,7 +1200,7 @@ class D2MFPUBinaryRewriter : public OpConversionPattern { auto insertionPoint = rewriter.getInsertionPoint(); setInsertionPointAfterOperands(rewriter, {cb, outCB}, /*allowHoisting*/ true); - rewriter.create(loc, cb, cb, outCB); + ttkernel::BinaryOpInitCommonOp::create(rewriter, loc, cb, cb, outCB); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); auto eltwiseType = getEltwiseBinaryType(); @@ -1208,14 +1208,14 @@ class D2MFPUBinaryRewriter : public OpConversionPattern { // binary_dest_reuse is an in-place operation. 
If the DST // operand comes from a different slot, copy it first to the output slot. if (dstOperandIdx != dstIdx) { - rewriter.create(loc); - rewriter.create(loc, dstOperandIdx, dstIdx); + ttkernel::CopyDestValuesInitOp::create(rewriter, loc); + ttkernel::CopyDestValuesOp::create(rewriter, loc, dstOperandIdx, dstIdx); } - rewriter.create(loc, cb, eltwiseType, - reuseType); - rewriter.create( - loc, cb, cbTileIdx, dstIdx, eltwiseType, reuseType); + ttkernel::BinaryDestReuseTilesInitOp::create(rewriter, loc, cb, eltwiseType, + reuseType); + ttkernel::BinaryDestReuseTilesOp::create(rewriter, loc, cb, cbTileIdx, + dstIdx, eltwiseType, reuseType); rewriter.eraseOp(op); return success(); @@ -1256,13 +1256,13 @@ class D2MTilizeUntilizeRewriter : public OpConversionPattern { auto blockR = i32(rewriter, op->getLoc(), collapsed2DShape[0]); auto blockC = i32(rewriter, op->getLoc(), collapsed2DShape[1]); - rewriter.create(op->getLoc(), src, - nullptr, dst); + ttkernel::ComputeKernelHWStartupOp::create(rewriter, op->getLoc(), src, + nullptr, dst); if constexpr (std::is_same_v) { - rewriter.create(op->getLoc(), src, blockC, dst); - rewriter.create(op->getLoc(), src, dst, blockR, blockC); + ttkernel::TilizeInitOp::create(rewriter, op->getLoc(), src, blockC, dst); + BlockOp::create(rewriter, op->getLoc(), src, dst, blockR, blockC); } else if constexpr (std::is_same_v< BlockOp, ttkernel::ExperimentalPackUntilizeBlockOp>) { @@ -1286,11 +1286,12 @@ class D2MTilizeUntilizeRewriter : public OpConversionPattern { auto totalColTilesAttr = rewriter.getI32IntegerAttr(static_cast(totalColTiles)); - rewriter.create( - op->getLoc(), src, dst, colsPerDstPassAttr, totalColTilesAttr); - rewriter.create(op->getLoc(), src, dst, blockR, blockC, - colsPerDstPassAttr, totalColTilesAttr); - rewriter.create(op->getLoc(), dst); + ttkernel::PackUntilizeInitOp::create(rewriter, op->getLoc(), src, dst, + colsPerDstPassAttr, + totalColTilesAttr); + BlockOp::create(rewriter, op->getLoc(), src, dst, blockR, 
blockC, + colsPerDstPassAttr, totalColTilesAttr); + ttkernel::PackUntilizeUninitOp::create(rewriter, op->getLoc(), dst); } else { llvm_unreachable("unsupported tilize/untilize op"); } @@ -1313,7 +1314,7 @@ class D2MTileFillRewriter : public OpConversionPattern { Value fillValue = adaptor.getValue(); Location loc = op->getLoc(); - rewriter.create(loc, dstIdx, fillValue); + ttkernel::ExperimentalTileFillOp::create(rewriter, loc, dstIdx, fillValue); // Replace the op with its DST index so users (like TileWhereOp) get the // correct operand value. @@ -1334,11 +1335,11 @@ class D2MWriteRowMaskTileRewriter Location loc = op->getLoc(); Value validRows = adaptor.getValidRows(); if (!validRows.getType().isInteger(32)) { - validRows = rewriter.create( - loc, rewriter.getI32Type(), validRows); + validRows = arith::IndexCastOp::create(rewriter, loc, + rewriter.getI32Type(), validRows); } - rewriter.create( - loc, validRows, adaptor.getOutput()); + ttkernel::ExperimentalWriteRowMaskTileOp::create(rewriter, loc, validRows, + adaptor.getOutput()); rewriter.eraseOp(op); return success(); } @@ -1356,11 +1357,11 @@ class D2MWriteColMaskTileRewriter Location loc = op->getLoc(); Value validCols = adaptor.getValidCols(); if (!validCols.getType().isInteger(32)) { - validCols = rewriter.create( - loc, rewriter.getI32Type(), validCols); + validCols = arith::IndexCastOp::create(rewriter, loc, + rewriter.getI32Type(), validCols); } - rewriter.create( - loc, validCols, adaptor.getOutput()); + ttkernel::ExperimentalWriteColMaskTileOp::create(rewriter, loc, validCols, + adaptor.getOutput()); rewriter.eraseOp(op); return success(); } @@ -1374,8 +1375,8 @@ class D2MExperimentalFillArangeTileRewriter matchAndRewrite(d2m::FillArangeTileOp op, d2m::FillArangeTileOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - rewriter.create( - op->getLoc(), adaptor.getOutput()); + ttkernel::ExperimentalFillArangeTileOp::create(rewriter, op->getLoc(), + adaptor.getOutput()); 
rewriter.eraseOp(op); return success(); } @@ -1401,7 +1402,7 @@ class D2MTileTransposeRewriter auto insertionPoint = rewriter.getInsertionPoint(); setInsertionPointAfterOperands(rewriter, {inCB, outCB}, /*allowHoisting*/ true); - rewriter.create(op->getLoc(), inCB, outCB); + ttkernel::TransposeInitOp::create(rewriter, op->getLoc(), inCB, outCB); rewriter.setInsertionPoint(insertionPoint->getBlock(), insertionPoint); // Get the tile index from the input operand. @@ -1410,8 +1411,8 @@ class D2MTileTransposeRewriter // Get the destination index where the result will be stored. Value dstIdx = getDstIdxFromResult(op.getResult()); - rewriter.create(op->getLoc(), inCB, tileIndex, - dstIdx); + ttkernel::TransposeTileOp::create(rewriter, op->getLoc(), inCB, tileIndex, + dstIdx); rewriter.eraseOp(op); return success(); @@ -1493,13 +1494,13 @@ class D2MCBOpRewriter : public OpConversionPattern { op.getCb().getType().template getUnderlyingAs()); auto numPages = i32(rewriter, op->getLoc(), cbNumPages); - rewriter.create(op.getLoc(), adaptor.getCb(), numPages); + TTKernelAcquireOp::create(rewriter, op.getLoc(), adaptor.getCb(), numPages); // Only insert automatic release if there's no explicit push/pop if (!hasExplicitRelease(op)) { Block *block = op->getBlock(); - auto release = rewriter.create( - op.getLoc(), adaptor.getCb(), numPages); + auto release = TTKernelReleaseOp::create(rewriter, op.getLoc(), + adaptor.getCb(), numPages); if (block->mightHaveTerminator()) { rewriter.moveOpBefore(release, block->getTerminator()); } else { @@ -1549,8 +1550,8 @@ static Value castCBTypeAsAddress(OpBuilder &rewriter, Location loc, Value cb) { // 2. 
It can represent remote data, which we need to lower to a compile time // address (I32 type) // More information on ticket #3172 - return rewriter - .create(loc, rewriter.getI32Type(), cb) + return UnrealizedConversionCastOp::create(rewriter, loc, + rewriter.getI32Type(), cb) ->getResult(0); } @@ -1566,25 +1567,25 @@ static Value buildNocAddress(OpBuilder &rewriter, Location loc, Value cb, auto gridY = index[0]; auto gridX = index[1]; auto offset = index[2]; - auto offsetInt = - rewriter.create(loc, rewriter.getI32Type(), offset); - auto addr = rewriter.create(loc, baseAddr, offsetInt); + auto offsetInt = arith::IndexCastOp::create(rewriter, loc, + rewriter.getI32Type(), offset); + auto addr = arith::AddIOp::create(rewriter, loc, baseAddr, offsetInt); // Translate the src coordinates to virtual coordinates. auto [virtY, virtX] = getVirtualCoordsFromLogicalCoords( rewriter, loc, chipDesc, ValueRange{gridY, gridX}); noc_addr_op = - rewriter.create(loc, virtX, virtY, addr); + ttkernel::GetNocAddrOp::create(rewriter, loc, virtX, virtY, addr); } else { auto bankID = index[1]; - auto bankIDInt = - rewriter.create(loc, rewriter.getI32Type(), bankID); + auto bankIDInt = arith::IndexCastOp::create(rewriter, loc, + rewriter.getI32Type(), bankID); auto offset = index[2]; - auto offsetInt = - rewriter.create(loc, rewriter.getI32Type(), offset); - auto addr = rewriter.create(loc, baseAddr, offsetInt); + auto offsetInt = arith::IndexCastOp::create(rewriter, loc, + rewriter.getI32Type(), offset); + auto addr = arith::AddIOp::create(rewriter, loc, baseAddr, offsetInt); - return rewriter.create(loc, bankIDInt, - addr); + return ttkernel::GetNocAddrFromBankIDOp::create(rewriter, loc, bankIDInt, + addr); } return noc_addr_op; } @@ -1593,10 +1594,10 @@ template static Value buildL1Address(OpBuilder &rewriter, Location loc, Value cb, ValueRange index) { // Use the cb addr as the write address since it is local. 
- Value baseAddr = rewriter.create(loc, cb); - auto offset = - rewriter.create(loc, rewriter.getI32Type(), index[0]); - return rewriter.create(loc, baseAddr, offset); + Value baseAddr = ReadWritePtrOp::create(rewriter, loc, cb); + auto offset = arith::IndexCastOp::create(rewriter, loc, rewriter.getI32Type(), + index[0]); + return arith::AddIOp::create(rewriter, loc, baseAddr, offset); } class D2MDMAReadRewriter : public OpConversionPattern { @@ -1629,8 +1630,8 @@ class D2MDMAReadRewriter : public OpConversionPattern { rewriter, op.getLoc(), adaptor.getDst(), op.getDstIndices()); auto size = i32(rewriter, op->getLoc(), op.getSizeBytes()); - rewriter.create(op.getLoc(), srcNocAddr, - dstL1Addr, size); + ttkernel::NocAsyncReadOp::create(rewriter, op.getLoc(), srcNocAddr, + dstL1Addr, size); // Add attribute marking whether the DMA wait is for a read or write // operation This will be used when loweing the wait ops because the current @@ -1676,11 +1677,11 @@ class D2MDMAWriteRewriter : public OpConversionPattern { Value srcL1Start; auto srcCBMapping = cbProducerConsumer->get(op.getSrc()); if (srcCBMapping == d2m::ThreadCBOrientation::Producer) { - srcL1Start = rewriter.create(op.getLoc(), - adaptor.getSrc()); + srcL1Start = ttkernel::GetWritePtrOp::create(rewriter, op.getLoc(), + adaptor.getSrc()); } else { - srcL1Start = rewriter.create(op.getLoc(), - adaptor.getSrc()); + srcL1Start = ttkernel::GetReadPtrOp::create(rewriter, op.getLoc(), + adaptor.getSrc()); } auto dstCBMapping = cbProducerConsumer->get(op.getDst()); TT_assertv((dstCBMapping == d2m::ThreadCBOrientation::Producer || @@ -1688,8 +1689,8 @@ class D2MDMAWriteRewriter : public OpConversionPattern { dstCBMapping == d2m::ThreadCBOrientation::Default), "Expected dst cb of a write op to have a producer, " "producer-consumer or default orientation, failing."); - Value dstL1Start = rewriter.create( - op.getLoc(), adaptor.getDst()); + Value dstL1Start = ttkernel::GetWritePtrOp::create(rewriter, op.getLoc(), + 
adaptor.getDst()); Value transferSize = i32(rewriter, op->getLoc(), op.getSizeBytes()); if (op.isMcast()) { @@ -1704,43 +1705,43 @@ class D2MDMAWriteRewriter : public OpConversionPattern { op.getMcastStartIndex()[1], op.getMcastShape()); auto [virtMcastEndY, virtMcastEndX] = getVirtualCoordsFromLogicalCoords( rewriter, op.getLoc(), chipDesc, {mcastEndY, mcastEndX}); - auto numDestsIdx = rewriter.create( - op.getLoc(), op.getMcastShape()[0], op.getMcastShape()[1]); - auto numDests = rewriter.create( - op.getLoc(), rewriter.getI32Type(), numDestsIdx); - auto numDestsMinusOne = rewriter.create( - op.getLoc(), numDests, - rewriter.create(op.getLoc(), - rewriter.getI32Type(), - rewriter.getI32IntegerAttr(1))); - auto mcastAddr = - rewriter.create( - op.getLoc(), virtX, virtY, virtMcastEndX, virtMcastEndY, - dstL1Start, nullptr); + auto numDestsIdx = + arith::MulIOp::create(rewriter, op.getLoc(), op.getMcastShape()[0], + op.getMcastShape()[1]); + auto numDests = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), numDestsIdx); + auto numDestsMinusOne = arith::SubIOp::create( + rewriter, op.getLoc(), numDests, + arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getI32Type(), + rewriter.getI32IntegerAttr(1))); + auto mcastAddr = ttkernel::ExperimentalGetNocMulticastAddrOp::create( + rewriter, op.getLoc(), virtX, virtY, virtMcastEndX, virtMcastEndY, + dstL1Start, nullptr); if (adaptor.getSrc() == adaptor.getDst()) { // If src and dst refer to the same memref, we do not loopback mcast // Dests are one less because the sender core is not included - rewriter.create( - op.getLoc(), srcL1Start, mcastAddr, transferSize, + ttkernel::NocAsyncWriteMulticastOp::create( + rewriter, op.getLoc(), srcL1Start, mcastAddr, transferSize, numDestsMinusOne, rewriter.getBoolAttr(true), nullptr, nullptr); } else { // If src != dst, we loopback mcast - rewriter.create( - op.getLoc(), srcL1Start, mcastAddr, transferSize, numDests, - rewriter.getBoolAttr(true), 
nullptr, nullptr); + ttkernel::NocAsyncWriteMulticastLoopbackSrcOp::create( + rewriter, op.getLoc(), srcL1Start, mcastAddr, transferSize, + numDests, rewriter.getBoolAttr(true), nullptr, nullptr); } } else { // Local L1 to Local L1 local data movement lowering // Get local coordinates using myY and myX ops - auto myY = rewriter.create(op.getLoc()); - auto myX = rewriter.create(op.getLoc()); + auto myY = ttkernel::MyLogicalYOp::create(rewriter, op.getLoc()); + auto myX = ttkernel::MyLogicalXOp::create(rewriter, op.getLoc()); // Convert local coordinates to virtual coordinates auto [virtY, virtX] = getVirtualCoordsFromLogicalCoords( rewriter, op.getLoc(), chipDesc, ValueRange{myY, myX}); - auto nocAddr = rewriter.create( - op.getLoc(), virtX, virtY, dstL1Start); - rewriter.create(op.getLoc(), srcL1Start, - nocAddr, transferSize); + auto nocAddr = ttkernel::GetNocAddrOp::create(rewriter, op.getLoc(), + virtX, virtY, dstL1Start); + ttkernel::NocAsyncWriteOp::create(rewriter, op.getLoc(), srcL1Start, + nocAddr, transferSize); } } else if (op.isDstRemote()) { auto srcL1Addr = buildL1Address( @@ -1749,8 +1750,8 @@ class D2MDMAWriteRewriter : public OpConversionPattern { buildNocAddress(rewriter, op.getLoc(), adaptor.getDst(), op.getDstIndices(), chipDesc, op.getDstMemorySpace()); auto size = i32(rewriter, op->getLoc(), op.getSizeBytes()); - rewriter.create(op.getLoc(), srcL1Addr, - dstNocAddr, size); + ttkernel::NocAsyncWriteOp::create(rewriter, op.getLoc(), srcL1Addr, + dstNocAddr, size); } // Add attribute marking whether the DMA wait is for a read or write @@ -1786,8 +1787,8 @@ class D2MCoreIndexRewriter : public OpConversionPattern { matchAndRewrite(d2m::CoreIndexOp op, d2m::CoreIndexOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - Value logicalY = rewriter.create(op.getLoc()); - Value logicalX = rewriter.create(op.getLoc()); + Value logicalY = ttkernel::MyLogicalYOp::create(rewriter, op.getLoc()); + Value logicalX = 
ttkernel::MyLogicalXOp::create(rewriter, op.getLoc()); // If no virtualization mapping, preserve legacy behavior. // Note: phys_to_virt_map is optional on the op. @@ -1815,8 +1816,8 @@ class D2MCoreIndexRewriter : public OpConversionPattern { mlir::AffineMap::get(map.getNumDims(), map.getNumSymbols(), {map.getResult(resultIdx)}, rewriter.getContext()); - Value virtDim = rewriter.create( - op.getLoc(), selectedMap, ValueRange{logicalY, logicalX}); + Value virtDim = mlir::affine::AffineApplyOp::create( + rewriter, op.getLoc(), selectedMap, ValueRange{logicalY, logicalX}); rewriter.replaceOp(op, virtDim); return success(); } @@ -1840,11 +1841,11 @@ class D2MDMAWaitRewriter : public OpConversionPattern { assert(isRead || isWrite || isMcastWrite); if (isRead) { - rewriter.create(op.getLoc()); + ttkernel::NocAsyncReadBarrierOp::create(rewriter, op.getLoc()); } if (isWrite) { - rewriter.create(op.getLoc()); + ttkernel::NocAsyncWriteBarrierOp::create(rewriter, op.getLoc()); } rewriter.eraseOp(op); @@ -2040,11 +2041,11 @@ class D2MKernelFunctionArgsRewriter : public OpConversionPattern { continue; } size_t ctArgIndex = ctArgSpecVector.size(); - auto semaphoreIndex = rewriter.create( - op.getLoc(), rewriter.getI32Type(), + auto semaphoreIndex = GetCompileArgValOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(ctArgIndex)); auto semaphore = - rewriter.create(op.getLoc(), semaphoreIndex); + GetSemaphoreOp::create(rewriter, op.getLoc(), semaphoreIndex); signatureConverter.remapInput(arg.getArgNumber(), semaphore.getResult()); ctArgSpecVector.push_back(rewriter.getAttr( @@ -2091,7 +2092,7 @@ class D2MSemaphoreUpdateRewriter : public OpConversionPattern { // Local semaphore set auto semaphorePtr = - rewriter.create(op.getLoc(), semaphoreAddr); + ttkernel::CastToL1PtrOp::create(rewriter, op.getLoc(), semaphoreAddr); rewriter.replaceOpWithNewOp(op, semaphorePtr, value); @@ -2100,8 +2101,8 @@ class D2MSemaphoreUpdateRewriter : public 
OpConversionPattern { "d2m.semaphore_set to single remote core is illegal."); auto [virtY, virtX] = getVirtualCoordsFromLogicalCoords( rewriter, op.getLoc(), chipDesc, op.getDstCoreIndex()); - auto nocAddr = rewriter.create( - op.getLoc(), virtX, virtY, semaphoreAddr); + auto nocAddr = ttkernel::GetNocAddrOp::create( + rewriter, op.getLoc(), virtX, virtY, semaphoreAddr); rewriter.replaceOpWithNewOp(op, nocAddr, value, nullptr); } else { @@ -2117,23 +2118,23 @@ class D2MSemaphoreUpdateRewriter : public OpConversionPattern { op.getDstCoreIndex()[1], op.getMcastShape()); auto [virtMcastEndY, virtMcastEndX] = getVirtualCoordsFromLogicalCoords( rewriter, op.getLoc(), chipDesc, {mcastEndY, mcastEndX}); - Value numDestsIdx = rewriter.create( - op.getLoc(), op.getMcastShape()[0], op.getMcastShape()[1]); - Value numDests = rewriter.create( - op.getLoc(), rewriter.getI32Type(), numDestsIdx); - Value numDestsMinusOne = rewriter.create( - op.getLoc(), numDests, - rewriter.create(op.getLoc(), rewriter.getI32Type(), - rewriter.getI32IntegerAttr(1))); - auto mcastAddr = - rewriter.create( - op.getLoc(), virtX, virtY, virtMcastEndX, virtMcastEndY, - semaphoreAddr, nullptr); + Value numDestsIdx = arith::MulIOp::create( + rewriter, op.getLoc(), op.getMcastShape()[0], op.getMcastShape()[1]); + Value numDests = arith::IndexCastOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), numDestsIdx); + Value numDestsMinusOne = arith::SubIOp::create( + rewriter, op.getLoc(), numDests, + arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getI32Type(), + rewriter.getI32IntegerAttr(1))); + auto mcastAddr = ttkernel::ExperimentalGetNocMulticastAddrOp::create( + rewriter, op.getLoc(), virtX, virtY, virtMcastEndX, virtMcastEndY, + semaphoreAddr, nullptr); auto semaphorePtr = - rewriter.create(op.getLoc(), semaphoreAddr); - rewriter.create(op.getLoc(), semaphorePtr, - value); + ttkernel::CastToL1PtrOp::create(rewriter, op.getLoc(), semaphoreAddr); + 
ttkernel::NocSemaphoreSetOp::create(rewriter, op.getLoc(), semaphorePtr, + value); rewriter.replaceOpWithNewOp( op, semaphoreAddr, mcastAddr, numDestsMinusOne, nullptr, nullptr); } @@ -2156,13 +2157,13 @@ class D2MSemaphoreWaitRewriter Value semaphoreAddr = adaptor.getSemaphore(); auto semaphorePtr = - rewriter.create(op.getLoc(), semaphoreAddr); + ttkernel::CastToL1PtrOp::create(rewriter, op.getLoc(), semaphoreAddr); rewriter.replaceOpWithNewOp(op, semaphorePtr, op.getValue()); if (op.getResetValue()) { - rewriter.create(op.getLoc(), semaphorePtr, - op.getResetValue()); + ttkernel::NocSemaphoreSetOp::create(rewriter, op.getLoc(), semaphorePtr, + op.getResetValue()); } return success(); diff --git a/lib/Conversion/D2MToTTMetal/D2MToTTMetal.cpp b/lib/Conversion/D2MToTTMetal/D2MToTTMetal.cpp index 2e3faf12b68..62930cb0807 100644 --- a/lib/Conversion/D2MToTTMetal/D2MToTTMetal.cpp +++ b/lib/Conversion/D2MToTTMetal/D2MToTTMetal.cpp @@ -346,7 +346,7 @@ class D2MToHostRewriter : public OpConversionPattern { output); // Insert global barrier to ensure the read completes before subsequent // ops use it. - rewriter.create(op->getLoc()); + ttmetal::FinishOp::create(rewriter, op->getLoc()); return success(); } }; diff --git a/lib/Conversion/D2MToTTNN/D2MToTTNN.cpp b/lib/Conversion/D2MToTTNN/D2MToTTNN.cpp index 44d278a542f..be140c278ad 100644 --- a/lib/Conversion/D2MToTTNN/D2MToTTNN.cpp +++ b/lib/Conversion/D2MToTTNN/D2MToTTNN.cpp @@ -756,8 +756,8 @@ class MemrefAllocRewriter : public OpConversionPattern { // Build a temporary typed Value to feed convertMemrefToTTNNTensor. // We use an unrealized_conversion_cast as a placeholder. 
- auto placeholder = rewriter.create( - op.getLoc(), shardMemrefType, ValueRange{}); + auto placeholder = mlir::UnrealizedConversionCastOp::create( + rewriter, op.getLoc(), shardMemrefType, ValueRange{}); auto convertedTensorType = detail::convertMemrefToTTNNTensor(ctx, placeholder.getResult(0)); rewriter.eraseOp(placeholder); @@ -854,8 +854,8 @@ class MemrefAllocRewriter : public OpConversionPattern { auto memcfg = ttnn::MemoryConfigAttr::get(emptyLayoutAttr, deviceAttr.getWorkerGrid()); - auto emptyOp = rewriter.create( - op.getLoc(), emptyTensorType, device, + auto emptyOp = ttnn::EmptyOp::create( + rewriter, op.getLoc(), emptyTensorType, device, ttnn::ShapeAttr::get(ctx, emptyTensorType.getShape()), ttcore::DataTypeAttr::get(ctx, emptyLayoutAttr.getDataType()), ttnn::LayoutAttr::get(ctx, emptyLayoutAttr.getLayout()), memcfg); diff --git a/lib/Conversion/SFPIToEmitC/SFPIToEmitC.cpp b/lib/Conversion/SFPIToEmitC/SFPIToEmitC.cpp index 065eb13b039..943ef042a2c 100644 --- a/lib/Conversion/SFPIToEmitC/SFPIToEmitC.cpp +++ b/lib/Conversion/SFPIToEmitC/SFPIToEmitC.cpp @@ -58,7 +58,8 @@ class SFPIToEmitCTypeConverter : public TypeConverter { if (inputs.size() != 1) { return nullptr; } - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); @@ -67,7 +68,8 @@ class SFPIToEmitCTypeConverter : public TypeConverter { if (inputs.size() != 1) { return nullptr; } - return builder.create(loc, resultType, inputs) + return UnrealizedConversionCastOp::create(builder, loc, resultType, + inputs) .getResult(0); }); } @@ -110,8 +112,8 @@ class SFPIToEmitCOpConversionPattern : public OpConversionPattern { operands.push_back(*maybeValue); } else if (auto *maybeStr = std::get_if(&valueOrStr); maybeStr) { - auto literalAttribute = rewriter.create( - op.getLoc(), rewriter.getI32Type(), *maybeStr); + auto literalAttribute = emitc::LiteralOp::create( + rewriter, op.getLoc(), rewriter.getI32Type(), 
*maybeStr); operands.push_back(literalAttribute->getResult(0)); } else { llvm_unreachable("Unsupported builtin operand variant"); diff --git a/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp index ff16f42c18d..2e0d519bbae 100644 --- a/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp @@ -195,8 +195,9 @@ class ShardyToTTIRManualComputationOpConversionPattern // Create a new mesh shard op. auto outputType = mlir::cast( getTypeConverter()->convertType(localArgType)); - auto meshShardOp = rewriter.create( - loc, outputType, globalOperand, shardyMeshSharding->getShardType(), + auto meshShardOp = mlir::tt::ttir::MeshShardOp::create( + rewriter, loc, outputType, globalOperand, + shardyMeshSharding->getShardType(), shardyMeshSharding->getShardDirection(), shardyMeshSharding->getShardShape(), shardyMeshSharding->getShardDims()); @@ -227,8 +228,8 @@ class ShardyToTTIRManualComputationOpConversionPattern // Create a new mesh shard op. 
auto outputType = mlir::cast( getTypeConverter()->convertType(opResult.getType())); - auto meshShardOp = rewriter.create( - loc, outputType, returnOperand.get(), + auto meshShardOp = mlir::tt::ttir::MeshShardOp::create( + rewriter, loc, outputType, returnOperand.get(), shardyMeshSharding->getShardType(), shardyMeshSharding->getShardDirection(), shardyMeshSharding->getShardShape(), diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOLegalizeCompositePass.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOLegalizeCompositePass.cpp index 42fc6eaed49..1e368391e72 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOLegalizeCompositePass.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOLegalizeCompositePass.cpp @@ -375,8 +375,8 @@ class ShardyAllSliceToTTIRMeshPartitionConversionPattern auto resultType = mlir::RankedTensorType::get(newShape, currInputType.getElementType(), currInputType.getEncoding()); - currInput = rewriter.create( - srcOp->getLoc(), resultType, currInput, + currInput = ttir::MeshPartitionOp::create( + rewriter, srcOp->getLoc(), resultType, currInput, rewriter.getSI32IntegerAttr(tensorDims[i]), rewriter.getUI32IntegerAttr(clusterAxes[i])); } diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index e97cfc36b53..7f5fafb1911 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -332,8 +332,8 @@ class StableHLOToTTIRReduceOpConversionPattern // We are creating the new ttir op explicitly here, then replace the // original op uses with the new op, and finally, erase the original op // explicitly. 
- ttir::ArgMaxOp newOp = rewriter.create( - srcOp->getLoc(), outputType, adaptor.getInputs().front(), + ttir::ArgMaxOp newOp = tt::ttir::ArgMaxOp::create( + rewriter, srcOp->getLoc(), outputType, adaptor.getInputs().front(), false /* keep_dim */, dimArg); srcOp->getResults().back().replaceAllUsesWith(newOp->getResults().front()); @@ -820,10 +820,11 @@ class StableHLOToBatchNormTrainingOpConversionPattern // Default momentum for batch norm training FloatAttr momentumAttr = rewriter.getF32FloatAttr(1.0f); - auto runningMean = rewriter.create( - loc, meanType, llvm::to_vector_of(meanType.getShape())); - auto runningVariance = rewriter.create( - loc, varianceType, + auto runningMean = + ttir::ZerosOp::create(rewriter, loc, meanType, + llvm::to_vector_of(meanType.getShape())); + auto runningVariance = ttir::OnesOp::create( + rewriter, loc, varianceType, llvm::to_vector_of(varianceType.getShape())); rewriter.replaceOpWithNewOp( @@ -906,24 +907,24 @@ class StableHLOToBatchNormGradOpConversionPattern scalarType, rewriter.getFloatAttr(operandType.getElementType(), epsilon)); auto epsilonConstant = - rewriter.create(loc, scalarType, epsilonDenseAttr); + ttir::ConstantOp::create(rewriter, loc, scalarType, epsilonDenseAttr); Value epsilonBcast = broadcastFeatureToShape(rewriter, loc, epsilonConstant, operandType, featureIndex); // centered_operand = operand - mean - auto centeredOperand = rewriter.create( - loc, operandType, adaptor.getOperand(), meanBcast); + auto centeredOperand = ttir::SubtractOp::create( + rewriter, loc, operandType, adaptor.getOperand(), meanBcast); // stddev = sqrt(variance + epsilon) - auto variancePlusEpsilon = rewriter.create( - loc, operandType, varianceBcast, epsilonBcast); + auto variancePlusEpsilon = ttir::AddOp::create(rewriter, loc, operandType, + varianceBcast, epsilonBcast); auto stddev = - rewriter.create(loc, operandType, variancePlusEpsilon); + ttir::SqrtOp::create(rewriter, loc, operandType, variancePlusEpsilon); // normalized_operand = 
centered_operand / stddev - auto normalizedOperand = - rewriter.create(loc, operandType, centeredOperand, stddev); + auto normalizedOperand = ttir::DivOp::create(rewriter, loc, operandType, + centeredOperand, stddev); // elements_per_feature = total_elements / feature_dim_size int64_t totalElements = operandType.getNumElements(); @@ -933,69 +934,70 @@ class StableHLOToBatchNormGradOpConversionPattern auto elementsPerFeatureAttr = DenseElementsAttr::get( scalarType, rewriter.getFloatAttr(operandType.getElementType(), elementsPerFeature)); - auto elementsPerFeatureConst = rewriter.create( - loc, scalarType, elementsPerFeatureAttr); + auto elementsPerFeatureConst = ttir::ConstantOp::create( + rewriter, loc, scalarType, elementsPerFeatureAttr); auto elementsPerFeatureBcast = broadcastFeatureToShape( rewriter, loc, elementsPerFeatureConst, operandType, featureIndex); // i1 = grad_output * elements_per_feature - auto i1 = rewriter.create( - loc, gradOutputType, adaptor.getGradOutput(), elementsPerFeatureBcast); + auto i1 = ttir::MultiplyOp::create(rewriter, loc, gradOutputType, + adaptor.getGradOutput(), + elementsPerFeatureBcast); // i2 = broadcast(sum(grad_output, reduction_dims)) - auto sumGradOutput = rewriter.create( - loc, scaleType, adaptor.getGradOutput(), rewriter.getBoolAttr(false), - reductionDimsAttr); + auto sumGradOutput = + ttir::SumOp::create(rewriter, loc, scaleType, adaptor.getGradOutput(), + rewriter.getBoolAttr(false), reductionDimsAttr); auto i2 = broadcastFeatureToShape(rewriter, loc, sumGradOutput, operandType, featureIndex); // grad_output * centered_operand - auto gradTimesCentered = rewriter.create( - loc, operandType, adaptor.getGradOutput(), centeredOperand); + auto gradTimesCentered = ttir::MultiplyOp::create( + rewriter, loc, operandType, adaptor.getGradOutput(), centeredOperand); // i3 = broadcast(sum(grad_output * centered_operand)) - auto sumGradTimesCentered = rewriter.create( - loc, scaleType, gradTimesCentered, 
rewriter.getBoolAttr(false), - reductionDimsAttr); + auto sumGradTimesCentered = + ttir::SumOp::create(rewriter, loc, scaleType, gradTimesCentered, + rewriter.getBoolAttr(false), reductionDimsAttr); auto i3 = broadcastFeatureToShape(rewriter, loc, sumGradTimesCentered, operandType, featureIndex); // i4 = i3 * centered_operand - auto i4 = rewriter.create(loc, operandType, i3, - centeredOperand); + auto i4 = ttir::MultiplyOp::create(rewriter, loc, operandType, i3, + centeredOperand); // i5 = i4 / (variance + epsilon) - auto i5 = - rewriter.create(loc, operandType, i4, variancePlusEpsilon); + auto i5 = ttir::DivOp::create(rewriter, loc, operandType, i4, + variancePlusEpsilon); // i6 = i1 - i2 - i5 auto i1MinusI2 = - rewriter.create(loc, operandType, i1, i2); + ttir::SubtractOp::create(rewriter, loc, operandType, i1, i2); auto i6 = - rewriter.create(loc, operandType, i1MinusI2, i5); + ttir::SubtractOp::create(rewriter, loc, operandType, i1MinusI2, i5); // grad_operand = (scale / stddev / elements_per_feature) * i6 auto scaleOverStddev = - rewriter.create(loc, operandType, scaleBcast, stddev); + ttir::DivOp::create(rewriter, loc, operandType, scaleBcast, stddev); - auto scaleOverStddevOverElem = rewriter.create( - loc, operandType, scaleOverStddev, elementsPerFeatureBcast); + auto scaleOverStddevOverElem = ttir::DivOp::create( + rewriter, loc, operandType, scaleOverStddev, elementsPerFeatureBcast); - auto gradOperand = rewriter.create( - loc, gradOperandType, scaleOverStddevOverElem, i6); + auto gradOperand = ttir::MultiplyOp::create(rewriter, loc, gradOperandType, + scaleOverStddevOverElem, i6); // grad_scale = sum(grad_output * normalized_operand) - auto gradTimesNorm = rewriter.create( - loc, operandType, adaptor.getGradOutput(), normalizedOperand); + auto gradTimesNorm = ttir::MultiplyOp::create( + rewriter, loc, operandType, adaptor.getGradOutput(), normalizedOperand); - auto gradScale = rewriter.create( - loc, gradScaleType, gradTimesNorm, 
rewriter.getBoolAttr(false), - reductionDimsAttr); + auto gradScale = + ttir::SumOp::create(rewriter, loc, gradScaleType, gradTimesNorm, + rewriter.getBoolAttr(false), reductionDimsAttr); // grad_offset = sum(grad_output) - auto gradOffset = rewriter.create( - loc, gradOffsetType, adaptor.getGradOutput(), + auto gradOffset = ttir::SumOp::create( + rewriter, loc, gradOffsetType, adaptor.getGradOutput(), rewriter.getBoolAttr(false), reductionDimsAttr); // Replace the operation with the three results. @@ -1091,8 +1093,8 @@ class StableHLOToBatchNormGradOpConversionPattern auto unsqueezeType = RankedTensorType::get(unsqueezeShape, targetType.getElementType()); - return rewriter.create( - loc, unsqueezeType, input, + return ttir::ReshapeOp::create( + rewriter, loc, unsqueezeType, input, rewriter.getI32ArrayAttr(llvm::to_vector_of(unsqueezeShape))); } }; @@ -1481,8 +1483,8 @@ sliceForBatchGroups(ConversionPatternRewriter &rewriter, Location loc, inputShape.end()); inputSliceShape[groupDimensionIndex] = inputSliceSize; - auto inputSlice = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_inputSlice"), + auto inputSlice = ttir::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_inputSlice"), RankedTensorType::get(inputSliceShape, inputType.getElementType(), inputType.getEncoding()), input, rewriter.getI32ArrayAttr(inputBegins), @@ -1502,8 +1504,8 @@ sliceForBatchGroups(ConversionPatternRewriter &rewriter, Location loc, weightShape.end()); weightSliceShape[kernelOutputFeatureDim] = weightSliceSize; - auto weightSlice = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_weightSlice"), + auto weightSlice = ttir::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_weightSlice"), RankedTensorType::get(weightSliceShape, weightType.getElementType(), weightType.getEncoding()), weight, rewriter.getI32ArrayAttr(weightBegins), @@ -1593,8 +1595,8 @@ struct Legalize1DConvolutionPattern : public 
ConvolutionDecompositionPattern { if (batchGroupCount > 1) { int64_t outputFeatureDim = adaptor.getDimensionNumbers().getOutputFeatureDimension(); - auto concatOp = rewriter.create( - op.getLoc(), outputType, results, outputFeatureDim); + auto concatOp = ttir::ConcatOp::create(rewriter, op.getLoc(), outputType, + results, outputFeatureDim); rewriter.replaceOp(op, concatOp); } else { rewriter.replaceOp(op, results[0]); @@ -1730,8 +1732,8 @@ struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { auto shapeAttr = rewriter.getI32ArrayAttr(llvm::SmallVector(targetShape)); - return rewriter.create( - loc, + return ttir::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(targetShape, inputType.getElementType(), inputType.getEncoding()), input, shapeAttr); @@ -1839,8 +1841,8 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { // Find which dimension is the feature dimension. int64_t featureDim = adaptor.getDimensionNumbers().getOutputFeatureDimension(); - finalResult = rewriter.create(op.getLoc(), outputType, - results, featureDim); + finalResult = ttir::ConcatOp::create(rewriter, op.getLoc(), outputType, + results, featureDim); } else { finalResult = results[0]; } @@ -1936,8 +1938,8 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { llvm::ArrayRef(inputLayout), llvm::ArrayRef(outputLayout)); auto permutedShape = ttmlir::utils::applyPermutation(inputType.getShape(), permutation); - convInput = rewriter.create( - ttmlir::utils::appendLocationSuffix(op.getLoc(), "_input"), + convInput = ttir::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_input"), RankedTensorType::get(permutedShape, inputType.getElementType(), inputType.getEncoding()), input, permutation); @@ -1964,8 +1966,8 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { op, isTransposed ? 
conv2dTransposeKernelLayout : conv2dKernelLayout); auto weightOutputShape = ::ttmlir::utils::applyPermutation( weightType.getShape(), kernelPermutation); - permutedWeight = rewriter.create( - ttmlir::utils::appendLocationSuffix(op.getLoc(), "_weight"), + permutedWeight = ttir::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_weight"), RankedTensorType::get(weightOutputShape, weightType.getElementType(), weightType.getEncoding()), permutedWeight, kernelPermutation); @@ -2045,7 +2047,8 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { ttmlir::utils::applyPermutation( llvm::ArrayRef(extractedGroupsShape), llvm::ArrayRef(permuteOrder)); - auto permutedGroups = rewriter.create( + auto permutedGroups = ttir::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_weight_permuted_groups"), permutedWeightType.cloneWith(permutedGroupsShape, @@ -2063,10 +2066,10 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { } // Use full builder with explicit dimension attributes. - newConv = rewriter.create( - op.getLoc(), outputType, convInput, Value(permutedWeight), Value(), - inputDilationAttr, paddingAttr, outputPaddingAttr, dilationAttr, - rewriter.getI32IntegerAttr(groups), + newConv = ttir::ConvTranspose2dOp::create( + rewriter, op.getLoc(), outputType, convInput, Value(permutedWeight), + Value(), inputDilationAttr, paddingAttr, outputPaddingAttr, + dilationAttr, rewriter.getI32IntegerAttr(groups), rewriter.getI64IntegerAttr(batchDim), rewriter.getI64IntegerAttr(heightDim), rewriter.getI64IntegerAttr(widthDim), @@ -2075,9 +2078,9 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { } else { // Use full builder with explicit dimension attributes. 
- newConv = rewriter.create( - op.getLoc(), outputType, convInput, Value(permutedWeight), Value(), - strideAttr, paddingAttr, dilationAttr, + newConv = ttir::Conv2dOp::create( + rewriter, op.getLoc(), outputType, convInput, Value(permutedWeight), + Value(), strideAttr, paddingAttr, dilationAttr, rewriter.getI32IntegerAttr(groups), rewriter.getI64IntegerAttr(batchDim), rewriter.getI64IntegerAttr(heightDim), @@ -2213,15 +2216,15 @@ struct ConvolutionToConv3dPattern : public ConvolutionDecompositionPattern { generateConvKernelPermutation(op, conv3dKernelLayout); auto weightOutputShape = ::ttmlir::utils::applyPermutation( weightType.getShape(), kernelPermutation); - permutedWeight = rewriter.create( - ttmlir::utils::appendLocationSuffix(op.getLoc(), "_weight"), + permutedWeight = ttir::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_weight"), RankedTensorType::get(weightOutputShape, weightType.getElementType(), weightType.getEncoding()), permutedWeight, kernelPermutation); - mlir::Value newConv = rewriter.create( - op.getLoc(), outputType, Value(input), Value(permutedWeight), Value(), - strideAttr, paddingAttr, groupsAttr, + mlir::Value newConv = ttir::Conv3dOp::create( + rewriter, op.getLoc(), outputType, Value(input), Value(permutedWeight), + Value(), strideAttr, paddingAttr, groupsAttr, rewriter.getI64IntegerAttr(batchDim), rewriter.getI64IntegerAttr(depthDim), rewriter.getI64IntegerAttr(heightDim), @@ -2383,8 +2386,8 @@ class StableHLOToTTIRConvolutionOpConversionPattern } // Create slice operation to crop the output. 
- auto sliceOp = rewriter.create( - srcOp.getLoc(), outputType, convOp.getResult(), + auto sliceOp = ttir::SliceStaticOp::create( + rewriter, srcOp.getLoc(), outputType, convOp.getResult(), rewriter.getI32ArrayAttr(sliceBegins), rewriter.getI32ArrayAttr(sliceEnds), rewriter.getI32ArrayAttr(sliceSteps)); @@ -2671,8 +2674,8 @@ class StableHLOToTTIRReduceWindowOpConversionPattern shape4DI64, inputType.getElementType(), inputType.getEncoding()); input = - rewriter.create(srcOp.getLoc(), inputType4D, input, - rewriter.getI32ArrayAttr(shape4D)); + ttir::ReshapeOp::create(rewriter, srcOp.getLoc(), inputType4D, + input, rewriter.getI32ArrayAttr(shape4D)); resultType = RankedTensorType::get( /*shape*/ {1, resultType.getShape()[0], resultType.getShape()[1], 1}, @@ -2682,8 +2685,8 @@ class StableHLOToTTIRReduceWindowOpConversionPattern if (needsPermute) { SmallVector permutedInputShape = ttmlir::utils::applyPermutation(inputType.getShape(), permutation); - input = rewriter.create( - srcOp.getLoc(), + input = ttir::PermuteOp::create( + rewriter, srcOp.getLoc(), RankedTensorType::get(permutedInputShape, inputType.getElementType(), inputType.getEncoding()), @@ -2701,25 +2704,27 @@ class StableHLOToTTIRReduceWindowOpConversionPattern RankedTensorType originalResultType = cast( getTypeConverter()->convertType(srcOp.getResult(i).getType())); if (needsReshape) { - result = rewriter.create( - srcOp.getLoc(), originalResultType, result, + result = ttir::ReshapeOp::create( + rewriter, srcOp.getLoc(), originalResultType, result, rewriter.getI32ArrayAttr( SmallVector(originalResultType.getShape().begin(), originalResultType.getShape().end()))); } if (needsPermute) { - result = rewriter.create( - srcOp.getLoc(), originalResultType, result, inversePermutation); + result = ttir::PermuteOp::create(rewriter, srcOp.getLoc(), + originalResultType, result, + inversePermutation); } return result; }; if (isa(reductionOp) && initVal == NEG_INF) { - result = rewriter - .create( - srcOp.getLoc(), 
resultType, input, kernelForTTIROps, - strideForTTIROps, dilationForTTIROps, - paddingForTTIROps, ceilMode) + result = ttir::MaxPool2dOp::create(rewriter, + + srcOp.getLoc(), resultType, input, + kernelForTTIROps, strideForTTIROps, + dilationForTTIROps, + paddingForTTIROps, ceilMode) .getResult(); } else if (isa(reductionOp) && initVal == ZERO) { // Special case of sum pooling followed by a convenient div op. @@ -2728,8 +2733,8 @@ class StableHLOToTTIRReduceWindowOpConversionPattern if (divOp && i == 0) { // Average pooling: sum pooling followed by division. // Create AvgPool2dOp directly. - ttir::AvgPool2dOp avgPool2dOp = rewriter.create( - srcOp.getLoc(), resultType, input, kernelForTTIROps, + ttir::AvgPool2dOp avgPool2dOp = ttir::AvgPool2dOp::create( + rewriter, srcOp.getLoc(), resultType, input, kernelForTTIROps, strideForTTIROps, dilationForTTIROps, paddingForTTIROps, ceilMode, countIncludesPad); result = restoreOriginalLayout(avgPool2dOp.getResult()); @@ -2740,18 +2745,18 @@ class StableHLOToTTIRReduceWindowOpConversionPattern } // Sum pooling imitated as average pooling followed by multiplication. 
- ttir::AvgPool2dOp avgPool2dOp = rewriter.create( - srcOp.getLoc(), resultType, input, kernelForTTIROps, + ttir::AvgPool2dOp avgPool2dOp = ttir::AvgPool2dOp::create( + rewriter, srcOp.getLoc(), resultType, input, kernelForTTIROps, strideForTTIROps, dilationForTTIROps, paddingForTTIROps, ceilMode, countIncludesPad); int32_t kernelSize = kernelForTTIROps[0] * kernelForTTIROps[1]; DenseElementsAttr splatAttr = DenseElementsAttr::get( resultType, rewriter.getFloatAttr(resultType.getElementType(), static_cast(kernelSize))); - ttir::ConstantOp kernelSizeConst = rewriter.create( - srcOp.getLoc(), resultType, splatAttr); - ttir::MultiplyOp mulOp = rewriter.create( - srcOp.getLoc(), resultType, avgPool2dOp.getResult(), + ttir::ConstantOp kernelSizeConst = ttir::ConstantOp::create( + rewriter, srcOp.getLoc(), resultType, splatAttr); + ttir::MultiplyOp mulOp = ttir::MultiplyOp::create( + rewriter, srcOp.getLoc(), resultType, avgPool2dOp.getResult(), kernelSizeConst.getResult()); result = mulOp.getResult(); } else { @@ -3164,8 +3169,9 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern RankedTensorType unsqueezedType = RankedTensorType::get(unsqueezeShape, inputType.getElementType()); - ttir::ReshapeOp reshapeOp = rewriter.create( - srcOp.getLoc(), unsqueezedType, adaptor.getOperand(), reshapeDimAttr); + ttir::ReshapeOp reshapeOp = + ttir::ReshapeOp::create(rewriter, srcOp.getLoc(), unsqueezedType, + adaptor.getOperand(), reshapeDimAttr); ::llvm::ArrayRef inputShape = unsqueezeShape; ::llvm::ArrayRef outputShape = outputType.getShape(); @@ -3366,8 +3372,8 @@ class StableHLOToTTIRLogicalAndBitwiseOpConversionPattern auto logicalOpType = RankedTensorType::get(outputType.getShape(), rewriter.getI1Type(), outputType.getEncoding()); - auto logicalOp = rewriter.create( - srcOp.getLoc(), logicalOpType, + auto logicalOp = LogicalDestOp::create( + rewriter, srcOp.getLoc(), logicalOpType, ValueRange{ adaptor.getOperands()[0].getDefiningOp()->getOperands()[0], 
adaptor.getOperands()[1].getDefiningOp()->getOperands()[0]}); @@ -3660,8 +3666,8 @@ class StableHLOToTTIRSelectAndScatterOpConversionPattern } } - auto fullTensorOp = rewriter.create( - loc, operandType, rewriter.getF32FloatAttr(fillValue)); + auto fullTensorOp = ttir::FullOp::create( + rewriter, loc, operandType, rewriter.getF32FloatAttr(fillValue)); // Tensor which we scatter into auto fullTensor = fullTensorOp.getResult(); @@ -3773,11 +3779,11 @@ class StableHLOToTTIRSelectAndScatterOpConversionPattern RankedTensorType::get(sourcePermShape, rewriter.getIntegerType(32)); // i32 for indices - Value pooledEmpty = rewriter.create(loc, pooledType); - Value indicesEmpty = rewriter.create(loc, indicesType); + Value pooledEmpty = ttir::EmptyOp::create(rewriter, loc, pooledType); + Value indicesEmpty = ttir::EmptyOp::create(rewriter, loc, indicesType); - auto maxPoolOp = rewriter.create( - loc, TypeRange{pooledType, indicesType}, operand, + auto maxPoolOp = ttir::MaxPool2dWithIndicesOp::create( + rewriter, loc, TypeRange{pooledType, indicesType}, operand, ValueRange{pooledEmpty, indicesEmpty}, kernel, stride, dilations, paddingAttr, ceilMode); @@ -3817,8 +3823,8 @@ class StableHLOToTTIRSelectAndScatterOpConversionPattern } auto reduceTypeAttr = ttcore::ReduceTypeAttr::get(rewriter.getContext(), *scatterReduceType); - auto scatterResult = rewriter.create( - loc, scatterOutputType, + auto scatterResult = ttir::ScatterOp::create( + rewriter, loc, scatterOutputType, reshapedFullTensor, // input tensor reshapedIndices, // index tensor reshapedSource, // source tensor @@ -3867,9 +3873,9 @@ class StableHLOToTTIRSelectAndScatterOpConversionPattern generateReshape(mlir::TypedValue input, RankedTensorType outputType, PatternRewriter &rewriter, StringRef suffix) const { - return rewriter.create( - ttmlir::utils::appendLocationSuffix(input.getLoc(), suffix), outputType, - input, + return ttir::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(input.getLoc(), 
suffix), + outputType, input, rewriter.getI32ArrayAttr(SmallVector( outputType.getShape().begin(), outputType.getShape().end()))); } @@ -3880,9 +3886,9 @@ class StableHLOToTTIRSelectAndScatterOpConversionPattern ArrayRef permutation, StringRef suffix) const { RankedTensorType permuteType = RankedTensorType::get( permutedShape, inputType.getElementType(), inputType.getEncoding()); - return rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, suffix), permuteType, input, - permutation); + return ttir::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, suffix), permuteType, + input, permutation); } LogicalResult verifySelectBlock(mlir::stablehlo::SelectAndScatterOp srcOp, @@ -3956,8 +3962,9 @@ class StableHLOToTTIRAllReduceOpConversionPattern auto outputType = mlir::cast( getTypeConverter()->convertType(resultOperand.getType())); - auto allReduceOp = rewriter.create( - srcOp.getLoc(), outputType, inputOperand, *reduceType, clusterAxis); + auto allReduceOp = mlir::tt::ttir::AllReduceOp::create( + rewriter, srcOp.getLoc(), outputType, inputOperand, *reduceType, + clusterAxis); allReduceOpResults.push_back(allReduceOp.getResult()); } @@ -4290,8 +4297,8 @@ class StableHLOToTTIRDynamicSliceOpConversionPattern RankedTensorType::get({1}, startIndexElementType); for (Value startIndex : startIndicesRange) { - auto reshapedIndex = rewriter.create( - srcOp.getLoc(), + auto reshapedIndex = ttir::ReshapeOp::create( + rewriter, srcOp.getLoc(), RankedTensorType::get(singleElementTensorType.getShape(), startIndexElementType, singleElementTensorType.getEncoding()), @@ -4302,8 +4309,8 @@ class StableHLOToTTIRDynamicSliceOpConversionPattern auto startIndicesTensorType = RankedTensorType::get( {static_cast(startIndicesValues1D.size())}, startIndexElementType); - auto startIndicesTensor = rewriter.create( - srcOp.getLoc(), + auto startIndicesTensor = mlir::tt::ttir::ConcatOp::create( + rewriter, srcOp.getLoc(), 
RankedTensorType::get(startIndicesTensorType.getShape(), startIndexElementType, startIndicesTensorType.getEncoding()), @@ -4316,13 +4323,13 @@ class StableHLOToTTIRDynamicSliceOpConversionPattern {static_cast(sliceSizesInt32.size())}, rewriter.getI32Type()); auto sliceSizesAttr = mlir::DenseElementsAttr::get( sliceSizesTensorType, llvm::ArrayRef(sliceSizesInt32)); - auto sliceSizesConstant = rewriter.create( - srcOp.getLoc(), sliceSizesTensorType, sliceSizesAttr); + auto sliceSizesConstant = mlir::tt::ttir::ConstantOp::create( + rewriter, srcOp.getLoc(), sliceSizesTensorType, sliceSizesAttr); // Create an add op that adds the slice sizes to start indices to get end // indices. - auto endIndicesTensor = rewriter.create( - srcOp.getLoc(), + auto endIndicesTensor = mlir::tt::ttir::AddOp::create( + rewriter, srcOp.getLoc(), RankedTensorType::get(startIndicesTensorType.getShape(), startIndexElementType, startIndicesTensorType.getEncoding()), @@ -4513,14 +4520,14 @@ class CacheFillUpdatePattern sliceOutputShape, updatesType.getElementType(), nullptr); // Create slice op. - auto slicedUpdates = rewriter.create( - scatterOp.getLoc(), slicedUpdatesType, updates, + auto slicedUpdates = ttir::SliceStaticOp::create( + rewriter, scatterOp.getLoc(), slicedUpdatesType, updates, rewriter.getI32ArrayAttr(sliceStarts), rewriter.getI32ArrayAttr(sliceEnds), rewriter.getI32ArrayAttr(sliceSteps)); // create fill cache op for this batch. 
- cache = rewriter.create( - scatterOp.getLoc(), + cache = mlir::tt::ttir::FillCacheOp::create( + rewriter, scatterOp.getLoc(), scatterOp.getResult(0).getType(), // Result type cache, // Cache tensor slicedUpdates, // Updates tensor @@ -4528,11 +4535,12 @@ class CacheFillUpdatePattern ); } } else { - cache = rewriter.create( - scatterOp.getLoc(), scatterOp.getResult(0).getType(), // Result type - cache, // Cache tensor - updates, // Updates tensor - 0 // Batch offset + cache = mlir::tt::ttir::FillCacheOp::create( + rewriter, scatterOp.getLoc(), + scatterOp.getResult(0).getType(), // Result type + cache, // Cache tensor + updates, // Updates tensor + 0 // Batch offset ); } } else { @@ -4549,12 +4557,12 @@ class CacheFillUpdatePattern "Encoding should not be set when this pass is run"); RankedTensorType permutedUpdatesType = RankedTensorType::get( permutedShape, updatesType.getElementType(), nullptr); - updates = rewriter.create( - scatterOp.getLoc(), permutedUpdatesType, updates, + updates = ttir::PermuteOp::create( + rewriter, scatterOp.getLoc(), permutedUpdatesType, updates, rewriter.getDenseI64ArrayAttr({2, 1, 0, 3})); } - cache = rewriter.create( - scatterOp.getLoc(), + cache = mlir::tt::ttir::UpdateCacheOp::create( + rewriter, scatterOp.getLoc(), scatterOp.getResult(0).getType(), // Result type cache, // Cache tensor updates, // Updates tensor @@ -4828,8 +4836,8 @@ class StableHLOToTTIRScatterOpConversionPattern rewriter, srcOp.getLoc(), updateTensor, "_update_flatten"); // Scatter scalars on flattened tensors. 
- Value scatterResult = rewriter.create( - srcOp.getLoc(), + Value scatterResult = ttir::ScatterOp::create( + rewriter, srcOp.getLoc(), mlir::cast(flattenedInput.getType()), flattenedInput, flattenedIndices, flattenedUpdate, rewriter.getI32IntegerAttr(SCATTER_DIMENSION), @@ -4966,12 +4974,14 @@ class StableHLOToTTIRScatterOpConversionPattern RankedTensorType strideType = RankedTensorType::get( dimIndicesShape, indexElementType, dimIndicesType.getEncoding()); - Value strideTensor = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, - "_stride_" + std::to_string(d)), - strideType, scalarAttr); + Value strideTensor = + ttir::FullOp::create(rewriter, + ttmlir::utils::appendLocationSuffix( + loc, "_stride_" + std::to_string(d)), + strideType, scalarAttr); - dimIndices = rewriter.create( + dimIndices = ttir::MultiplyOp::create( + rewriter, ttmlir::utils::appendLocationSuffix( loc, "_dim_" + std::to_string(d) + "_stride_mul"), strideType, dimIndices, strideTensor); @@ -4984,10 +4994,11 @@ class StableHLOToTTIRScatterOpConversionPattern RankedTensorType addType = RankedTensorType::get( dimIndicesShape, indexElementType, dimIndicesType.getEncoding()); - flatIndices = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_add_dim_" + - std::to_string(d)), - addType, flatIndices, dimIndices); + flatIndices = + ttir::AddOp::create(rewriter, + ttmlir::utils::appendLocationSuffix( + loc, "_add_dim_" + std::to_string(d)), + addType, flatIndices, dimIndices); } } @@ -5081,8 +5092,8 @@ class StableHLOToTTIRScatterOpConversionPattern numScatterPositions, windowSize, originalIndexSize}; RankedTensorType afterRepeatType = RankedTensorType::get(afterRepeatShape, indicesType.getElementType()); - Value repeatedIndices = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_repeat_indices"), + Value repeatedIndices = ttir::RepeatOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_repeat_indices"), afterRepeatType, reshapedForRepeat, 
rewriter.getDenseI64ArrayAttr({1, windowSize, 1})); @@ -5143,7 +5154,8 @@ class StableHLOToTTIRScatterOpConversionPattern RankedTensorType::get({expandedNumIndices, 1}, indexElementType); auto finalOffsetAttr = DenseIntElementsAttr::get(finalOffsetType, finalOffsetValues); - Value finalOffset = rewriter.create( + Value finalOffset = ttir::ConstantOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_window_offset_" + std::to_string(dim)), finalOffsetType, finalOffsetAttr); @@ -5170,7 +5182,8 @@ class StableHLOToTTIRScatterOpConversionPattern RankedTensorType sliceType = RankedTensorType::get(sliceShape, indicesType.getElementType()); - Value sliced = rewriter.create( + Value sliced = ttir::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix( loc, "_slice_orig_idx_" + std::to_string(operandDim)), sliceType, flatRepeatedIndices, rewriter.getI32ArrayAttr(begins), @@ -5238,8 +5251,8 @@ class StableHLOToTTIRScatterOpConversionPattern indexType.getEncoding()); auto repeatDimsAttr = rewriter.getDenseI64ArrayAttr(repeatDims); - indexTensor = rewriter.create( - op.getLoc(), targetIndexType, indexTensor, repeatDimsAttr); + indexTensor = ttir::RepeatOp::create( + rewriter, op.getLoc(), targetIndexType, indexTensor, repeatDimsAttr); } return indexTensor; @@ -5341,8 +5354,8 @@ class StableHLOToTTIRSortOpConversionPattern // Step 3: Emit SortOp. 
- auto sortOp = rewriter.create( - loc, outputTypes, adaptor.getInputs().front(), + auto sortOp = ttir::SortOp::create( + rewriter, loc, outputTypes, adaptor.getInputs().front(), rewriter.getSI32IntegerAttr(sortDim), rewriter.getBoolAttr(*isDescending), rewriter.getBoolAttr(isStable)); @@ -5420,20 +5433,20 @@ class StableHLOToTTIRSortOpConversionPattern ttmlir::utils::applyPermutation(origShape, perm); auto permType = RankedTensorType::get(permShape, indexElemType); indices2D = - rewriter.create(loc, permType, indices2D, perm); + ttir::PermuteOp::create(rewriter, loc, permType, indices2D, perm); auto indices2DType = RankedTensorType::get({prePost, dSort}, indexElemType); - indices2D = rewriter.create( - loc, indices2DType, indices2D, + indices2D = ttir::ReshapeOp::create( + rewriter, loc, indices2DType, indices2D, rewriter.getI32ArrayAttr( {static_cast(prePost), static_cast(dSort)})); // Flat indices: row i gets offset i*dSort, so flat[i,j] = i*dSort + // idx[i,j]. - auto rowOffsets = rewriter.create( - loc, indices2DType, /*start=*/0, /*end=*/total, + auto rowOffsets = ttir::ArangeOp::create( + rewriter, loc, indices2DType, /*start=*/0, /*end=*/total, /*step=*/dSort, /*arange_dimension=*/0); - Value flatIndices = - rewriter.create(loc, indices2DType, rowOffsets, indices2D); + Value flatIndices = ttir::AddOp::create(rewriter, loc, indices2DType, + rowOffsets, indices2D); // Per value tensor: permute sortDim to last -> flatten -> EmbeddingOp -> // reshape -> permute back. @@ -5450,21 +5463,21 @@ class StableHLOToTTIRSortOpConversionPattern // Permute value tensor so sortDim is last. auto permValType = RankedTensorType::get(permValShape, valType.getElementType()); - val = rewriter.create(loc, permValType, val, perm); + val = ttir::PermuteOp::create(rewriter, loc, permValType, val, perm); // Flatten to [total, 1] - EmbeddingOp requires 2D weights. 
auto weightType = RankedTensorType::get({total, 1}, valType.getElementType()); - val = rewriter.create( - loc, weightType, val, + val = ttir::ReshapeOp::create( + rewriter, loc, weightType, val, rewriter.getI32ArrayAttr({static_cast(total), 1})); // EmbeddingOp: indices [prePost, dSort] * weights [total, 1] // -> output [prePost, dSort, 1]. auto embOutType = RankedTensorType::get({prePost, dSort, 1}, valType.getElementType()); - val = - rewriter.create(loc, embOutType, flatIndices, val); + val = ttir::EmbeddingOp::create(rewriter, loc, embOutType, flatIndices, + val); // Reshape [prePost, dSort, 1] -> permuted shape [...non-sort dims..., // dSort]. @@ -5472,12 +5485,11 @@ class StableHLOToTTIRSortOpConversionPattern permValShape.end()); auto permValResultType = RankedTensorType::get(permValShape, valType.getElementType()); - val = rewriter.create( - loc, permValResultType, val, - rewriter.getI32ArrayAttr(permValShapeI32)); + val = ttir::ReshapeOp::create(rewriter, loc, permValResultType, val, + rewriter.getI32ArrayAttr(permValShapeI32)); // Permute back to original dimension order. 
- val = rewriter.create(loc, valType, val, invPerm); + val = ttir::PermuteOp::create(rewriter, loc, valType, val, invPerm); results.push_back(val); } @@ -5583,8 +5595,8 @@ class StableHLOToTTIROpPadOpConversionPattern sum += eachVal; } if (sum != 0) { - auto fullOp = rewriter.create( - srcOp.getLoc(), outputType, rewriter.getF32FloatAttr(value)); + auto fullOp = ttir::FullOp::create(rewriter, srcOp.getLoc(), outputType, + rewriter.getF32FloatAttr(value)); llvm::SmallVector upperbounds; llvm::copy(outputType.getShape(), std::back_inserter(upperbounds)); int64_t index = 0; @@ -5658,8 +5670,8 @@ class StableHLOToTTIROpPadOpConversionPattern {numIndices}, rewriter.getI64Type(), inputType.getEncoding()); auto flatIndicesAttr = DenseIntElementsAttr::get(flatIndicesType, flatIndices1D); - Value flatIndicesTensor = rewriter.create( - srcOp.getLoc(), flatIndicesType, flatIndicesAttr); + Value flatIndicesTensor = ttir::ConstantOp::create( + rewriter, srcOp.getLoc(), flatIndicesType, flatIndicesAttr); // Flatten input and update tensors to 1D. Value flattenedInput = ttir::utils::flattenTensor( @@ -5674,9 +5686,9 @@ class StableHLOToTTIROpPadOpConversionPattern auto reduceTypeAttr = ttcore::ReduceTypeAttr::get( rewriter.getContext(), ttcore::ReduceType::Invalid); - Value scatterResult = rewriter.create( - srcOp.getLoc(), flattenedInputType, flattenedInput, flatIndicesTensor, - flattenedUpdate, dimAttr, reduceTypeAttr); + Value scatterResult = ttir::ScatterOp::create( + rewriter, srcOp.getLoc(), flattenedInputType, flattenedInput, + flatIndicesTensor, flattenedUpdate, dimAttr, reduceTypeAttr); // Reshape result back to original output shape. rewriter.replaceOpWithNewOp( @@ -5886,15 +5898,16 @@ class StableHLOToTTIRRngBitGeneratorOpConversionPattern // Using any other value would make every flatbuffer run deterministic. 
auto seed = rewriter.getUI32IntegerAttr(0); - auto randOp = rewriter.create( - srcOp.getLoc(), floatOutputType, rewriter.getI32ArrayAttr(size), - mlir::TypeAttr::get(floatElementType), fromFloat, toFloat, seed); + auto randOp = mlir::tt::ttir::RandOp::create( + rewriter, srcOp.getLoc(), floatOutputType, + rewriter.getI32ArrayAttr(size), mlir::TypeAttr::get(floatElementType), + fromFloat, toFloat, seed); // TODO (pglusac): Change to bit cast once we support it or remove if // rand starts supporting uint32. // See https://github.com/tenstorrent/tt-mlir/issues/5078 - auto typecastOp = rewriter.create( - srcOp.getLoc(), outputType, randOp.getResult()); + auto typecastOp = mlir::tt::ttir::TypecastOp::create( + rewriter, srcOp.getLoc(), outputType, randOp.getResult()); // HACK (pglusac): Output state is discarded, initial state is returned as // a result. https://github.com/tenstorrent/tt-mlir/issues/5101 @@ -6415,8 +6428,9 @@ class StableHLOToTTIRPagedScaledDotProductAttentionDecodeOpConversionPattern RankedTensorType outputType = cast( getTypeConverter()->convertType(srcOp.getResult(0).getType())); - ttir::EmptyOp outputTensor = rewriter.create( - srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + ttir::EmptyOp outputTensor = + ttir::EmptyOp::create(rewriter, srcOp.getLoc(), outputType.getShape(), + outputType.getElementType()); rewriter.replaceOpWithNewOp< mlir::tt::ttir::PagedScaledDotProductAttentionDecodeOp>( @@ -7104,8 +7118,8 @@ class StableHLOToTTIRAllToAllDispatchOpConversionPattern RankedTensorType metadataType = cast( getTypeConverter()->convertType(tupleType.getType(1))); - auto newOp = rewriter.create( - srcOp.getLoc(), dispatchedType, metadataType, inputTensor, + auto newOp = ttir::AllToAllDispatchOp::create( + rewriter, srcOp.getLoc(), dispatchedType, metadataType, inputTensor, expertIndices, expertMapping, numDevicesAttr, clusterAxisAttr); // Replace get_tuple_element users with the new op's results @@ -7135,8 +7149,8 @@ class 
StableHLOToTTIRAllToAllDispatchOpConversionPattern RankedTensorType metadataType = cast( getTypeConverter()->convertType(srcOp.getResult(1).getType())); - auto newOp = rewriter.create( - srcOp.getLoc(), dispatchedType, metadataType, inputTensor, + auto newOp = ttir::AllToAllDispatchOp::create( + rewriter, srcOp.getLoc(), dispatchedType, metadataType, inputTensor, expertIndices, expertMapping, numDevicesAttr, clusterAxisAttr); rewriter.replaceOp(srcOp, {newOp.getDispatched(), newOp.getMetadata()}); @@ -7219,9 +7233,9 @@ class StableHLOToTTIRAllToAllCombineOpConversionPattern RankedTensorType outputType = cast( getTypeConverter()->convertType(srcOp.getResult(0).getType())); - auto combineOp = rewriter.create( - srcOp.getLoc(), outputType, inputTensor, expertMetadata, expertMapping, - rewriter.getI64IntegerAttr(numDevices), + auto combineOp = ttir::AllToAllCombineOp::create( + rewriter, srcOp.getLoc(), outputType, inputTensor, expertMetadata, + expertMapping, rewriter.getI64IntegerAttr(numDevices), rewriter.getI64IntegerAttr(clusterAxis), rewriter.getI64IntegerAttr(numExpertsPerTok)); @@ -7240,8 +7254,8 @@ class StableHLOToTTIRAllToAllCombineOpConversionPattern totalDevices / std::max(numDevices, static_cast(1)); if (nonClusterSize > 1 && numDevices > 1) { uint32_t reduceAxis = (clusterAxis == 0) ? 
1 : 0; - auto allReduceOp = rewriter.create( - srcOp.getLoc(), outputType, result, ttcore::ReduceType::Sum, + auto allReduceOp = ttir::AllReduceOp::create( + rewriter, srcOp.getLoc(), outputType, result, ttcore::ReduceType::Sum, reduceAxis); result = allReduceOp.getResult(); } @@ -7313,9 +7327,9 @@ class StableHLOToTTIRMoeExpertTokenRemapOpConversionPattern RankedTensorType reducedType = cast( getTypeConverter()->convertType(tupleType.getType(1))); - auto newOp = rewriter.create( - srcOp.getLoc(), mappingType, reducedType, topkTensor, expertMapping, - expertMetadata, reductionSizeAttr); + auto newOp = ttir::MoeExpertTokenRemapOp::create( + rewriter, srcOp.getLoc(), mappingType, reducedType, topkTensor, + expertMapping, expertMetadata, reductionSizeAttr); for (auto &use : llvm::make_early_inc_range(srcOp.getResult(0).getUses())) { @@ -7343,9 +7357,9 @@ class StableHLOToTTIRMoeExpertTokenRemapOpConversionPattern RankedTensorType reducedType = cast( getTypeConverter()->convertType(srcOp.getResult(1).getType())); - auto newOp = rewriter.create( - srcOp.getLoc(), mappingType, reducedType, topkTensor, expertMapping, - expertMetadata, reductionSizeAttr); + auto newOp = ttir::MoeExpertTokenRemapOp::create( + rewriter, srcOp.getLoc(), mappingType, reducedType, topkTensor, + expertMapping, expertMetadata, reductionSizeAttr); rewriter.replaceOp(srcOp, {newOp.getMapping(), newOp.getReduced()}); return success(); diff --git a/lib/Conversion/TTIRToD2M/TTIRToD2M.cpp b/lib/Conversion/TTIRToD2M/TTIRToD2M.cpp index 54a4196e979..fea00130e68 100644 --- a/lib/Conversion/TTIRToD2M/TTIRToD2M.cpp +++ b/lib/Conversion/TTIRToD2M/TTIRToD2M.cpp @@ -210,8 +210,8 @@ class D2MNamedRewriterCommon { if (isTTNN) { assert(ttnnMode && "Unexpected TTNN tensor as op operand"); auto metalTensorType = getMetalTensorFromTTNNTensor(rewriter, value); - auto metalCastOp = rewriter.create( - value.getLoc(), metalTensorType, value); + auto metalCastOp = ttir::TTNNMetalLayoutCastOp::create( + rewriter, 
value.getLoc(), metalTensorType, value); // Propagate both VGM maps for height/width sharded TTNN layouts // so that downstream passes (GenericOp::build, getMemoryMap, etc.) @@ -257,9 +257,10 @@ class D2MNamedRewriterCommon { auto unitGridType = RankedTensorType::get( newTensorShape, metalTensorType.getElementType(), metalTensorType.getEncoding()); - auto unitReblockingView = rewriter.create( - value.getLoc(), unitGridType, metalCastOp->getResult(0), reblockMap, - /*reinterpretLayout=*/false); + auto unitReblockingView = + d2m::ViewLayoutOp::create(rewriter, value.getLoc(), unitGridType, + metalCastOp->getResult(0), reblockMap, + /*reinterpretLayout=*/false); return unitReblockingView.getResult(); } // For DRAM operands, we can return the metal cast result directly. @@ -306,8 +307,8 @@ class D2MNamedRewriterCommon { llvm::SmallVector shardedShape = layout.getDeviceShape(simpleGrid, tileShape); - auto emptyOp = rewriter.create(value.getLoc(), shardedShape, - elementType, layout); + auto emptyOp = d2m::EmptyOp::create(rewriter, value.getLoc(), shardedShape, + elementType, layout); // For ND tensors (logicalShape.size() > 2), set placeholder virtual grid // mappings on the EmptyOp. 
These will be replaced when GridSelection @@ -320,7 +321,7 @@ class D2MNamedRewriterCommon { emptyOp.setVirtualGridForwardMappingAttr(AffineMapAttr::get(forwardMap)); } - return rewriter.create(value.getLoc(), value, emptyOp) + return d2m::ToLayoutOp::create(rewriter, value.getLoc(), value, emptyOp) ->getResult(0); } @@ -367,15 +368,15 @@ class D2MNamedRewriterCommon { Value fromValue, Type toResultType) const { if (isTTNNTensor(toResultType)) { assert(ttnnMode && "Unexpected TTNN tensor as op result"); - return rewriter.create( - fromValue.getLoc(), toResultType, fromValue); + return ttir::TTNNMetalLayoutCastOp::create(rewriter, fromValue.getLoc(), + toResultType, fromValue); } auto output = - rewriter.create(fromValue.getLoc(), toResultType, - /*virtualGridInverseMapping=*/nullptr, - /*virtualGridForwardMapping=*/nullptr); - return rewriter.create(fromValue.getLoc(), fromValue, - output); + d2m::EmptyOp::create(rewriter, fromValue.getLoc(), toResultType, + /*virtualGridInverseMapping=*/nullptr, + /*virtualGridForwardMapping=*/nullptr); + return d2m::ToLayoutOp::create(rewriter, fromValue.getLoc(), fromValue, + output); } static llvm::SmallVector @@ -384,8 +385,9 @@ class D2MNamedRewriterCommon { llvm::SmallVector dpsOutputs; dpsOutputs.reserve(types.size()); for (auto type : types) { - ttir::EmptyOp empty = builder.create( - loc, type.getShape(), type.getElementType(), type.getEncoding()); + ttir::EmptyOp empty = + ttir::EmptyOp::create(builder, loc, type.getShape(), + type.getElementType(), type.getEncoding()); dpsOutputs.push_back(empty); } return dpsOutputs; @@ -492,8 +494,8 @@ class D2MNamedRewriterCommon { } // Create a buffer for the load result - auto bufferOp = builder.create( - loc, shardType.getShape(), shardType.getElementType()); + auto bufferOp = tensor::EmptyOp::create( + builder, loc, shardType.getShape(), shardType.getElementType()); Value buffer = bufferOp.getResult(); Value loadResult; @@ -503,20 +505,18 @@ class D2MNamedRewriterCommon { 
SmallVector mcastDims; for (int64_t gridDim : mcastGridDims) { mcastDims.push_back( - builder.create(loc, gridDim)); + arith::ConstantIndexOp::create(builder, loc, gridDim)); } // Create remote_load with high-level multicast form loadResult = - builder - .create(loc, shardType, buffer, - genericOperand, indices, mcastDims) + d2m::RemoteLoadOp::create(builder, loc, shardType, buffer, + genericOperand, indices, mcastDims) .getResult(); } else { // Create remote_load without multicast (original behavior) - loadResult = builder - .create(loc, shardType, buffer, - genericOperand, indices) + loadResult = d2m::RemoteLoadOp::create(builder, loc, shardType, buffer, + genericOperand, indices) .getResult(); } @@ -527,8 +527,8 @@ class D2MNamedRewriterCommon { for (size_t i = 0; i < outputs.size(); ++i) { RankedTensorType shardType = getShardType(outputs[i]); - auto emptyOp = builder.create( - loc, shardType.getShape(), shardType.getElementType()); + auto emptyOp = tensor::EmptyOp::create(builder, loc, shardType.getShape(), + shardType.getElementType()); operands.push_back(emptyOp.getResult()); } @@ -724,56 +724,56 @@ class D2MNamedElementwiseRewriter final // Apply broadcast to all operands that need it. for (size_t i = 0; i < numInputs && i < tileBcastTypes.size(); ++i) { if (tileBcastTypes[i] != d2m::TileBcastType::None) { - operands[i] = bbBuilder.create( - loc, resultTypes, operands[i], tileBcastTypes[i]); + operands[i] = d2m::TileBcastOp::create(bbBuilder, loc, resultTypes, + operands[i], tileBcastTypes[i]); } } mlir::Value yield; if constexpr (isComparisonOp) { // For comparison ops, first subtract then compare with zero. - yield = bbBuilder.create(loc, resultTypes, operands); - yield = bbBuilder.create(loc, resultTypes, yield); + yield = d2m::TileSubOp::create(bbBuilder, loc, resultTypes, operands); + yield = TileOp::create(bbBuilder, loc, resultTypes, yield); } else if constexpr (std::is_same_v) { // Decompose into maximum(input, min) then minimum(result, max). 
- yield = bbBuilder.create( - loc, resultTypes, ValueRange{operands[0], operands[1]}); - yield = bbBuilder.create( - loc, resultTypes, ValueRange{yield, operands[2]}); + yield = d2m::TileMaximumOp::create(bbBuilder, loc, resultTypes, + ValueRange{operands[0], operands[1]}); + yield = d2m::TileMinimumOp::create(bbBuilder, loc, resultTypes, + ValueRange{yield, operands[2]}); } else if constexpr (std::is_same_v) { yield = - bbBuilder.create(loc, resultTypes[0], operands[0], opAttrs); + TileOp::create(bbBuilder, loc, resultTypes[0], operands[0], opAttrs); } else if constexpr (std::is_same_v) { // LogicalAnd: NEZ(a) * NEZ(b) - both must be non-zero. auto nezA = - bbBuilder.create(loc, resultTypes, operands[0]); + d2m::TileNezOp::create(bbBuilder, loc, resultTypes, operands[0]); auto nezB = - bbBuilder.create(loc, resultTypes, operands[1]); - yield = bbBuilder.create(loc, resultTypes, - ValueRange{nezA, nezB}); + d2m::TileNezOp::create(bbBuilder, loc, resultTypes, operands[1]); + yield = d2m::TileMulOp::create(bbBuilder, loc, resultTypes, + ValueRange{nezA, nezB}); } else if constexpr (std::is_same_v) { // LogicalOr: NEZ(NEZ(a) + NEZ(b)) - at least one must be non-zero. auto nezA = - bbBuilder.create(loc, resultTypes, operands[0]); + d2m::TileNezOp::create(bbBuilder, loc, resultTypes, operands[0]); auto nezB = - bbBuilder.create(loc, resultTypes, operands[1]); - auto sum = bbBuilder.create(loc, resultTypes, - ValueRange{nezA, nezB}); - yield = bbBuilder.create(loc, resultTypes, sum); + d2m::TileNezOp::create(bbBuilder, loc, resultTypes, operands[1]); + auto sum = d2m::TileAddOp::create(bbBuilder, loc, resultTypes, + ValueRange{nezA, nezB}); + yield = d2m::TileNezOp::create(bbBuilder, loc, resultTypes, sum); } else if constexpr (std::is_same_v) { // LogicalXor: NEZ(NEZ(a) - NEZ(b)) - exactly one must be non-zero. 
auto nezA = - bbBuilder.create(loc, resultTypes, operands[0]); + d2m::TileNezOp::create(bbBuilder, loc, resultTypes, operands[0]); auto nezB = - bbBuilder.create(loc, resultTypes, operands[1]); - auto diff = bbBuilder.create(loc, resultTypes, - ValueRange{nezA, nezB}); - yield = bbBuilder.create(loc, resultTypes, diff); + d2m::TileNezOp::create(bbBuilder, loc, resultTypes, operands[1]); + auto diff = d2m::TileSubOp::create(bbBuilder, loc, resultTypes, + ValueRange{nezA, nezB}); + yield = d2m::TileNezOp::create(bbBuilder, loc, resultTypes, diff); } else { - yield = bbBuilder.create(loc, resultTypes, operands); + yield = TileOp::create(bbBuilder, loc, resultTypes, operands); } - bbBuilder.create(bbLoc, yield); + mlir::linalg::YieldOp::create(bbBuilder, bbLoc, yield); } LogicalResult @@ -816,8 +816,8 @@ class D2MNamedElementwiseRewriter final getIteratorTypesArray(rewriter, physicalRank); // Create 'd2m.generic' accepting 'op's operands. - auto generic = rewriter.create( - loc, inputs, outputs, /*additionalArgs=*/ValueRange(), + auto generic = d2m::GenericOp::create( + rewriter, loc, inputs, outputs, /*additionalArgs=*/ValueRange(), rewriter.getAffineMapArrayAttr(indexingMaps), rewriter.getArrayAttr(iteratorTypes)); @@ -851,8 +851,8 @@ class D2MNamedElementwiseRewriter final opAttrs.push_back(rewriter.getNamedAttr("max", op.getMaxAttr())); } - auto linalgGeneric = rewriter.create( - loc, + auto linalgGeneric = mlir::linalg::GenericOp::create( + rewriter, loc, /* result tensor types */ llvm::to_vector( mlir::ValueRange(blockArgs.take_back(numOutputs)).getTypes()), @@ -875,15 +875,14 @@ class D2MNamedElementwiseRewriter final d2m::utils::buildGridIndices(rewriter, loc, indexingMap); Value genericOperand = generic->getOperand(operandIdx); Value result = linalgGeneric->getResult(outputIdx); - Value storeResult = - rewriter - .create(loc, genericOperand.getType(), - genericOperand, indices, result) - .getResult(); + Value storeResult = d2m::RemoteStoreOp::create( + 
rewriter, loc, genericOperand.getType(), + genericOperand, indices, result) + .getResult(); storeResults.push_back(storeResult); } - rewriter.create(loc, storeResults); + d2m::YieldOp::create(rewriter, loc, storeResults); } } rewriter.finalizeOpModification(generic); @@ -982,8 +981,8 @@ class D2MNamedReductionRewriter final getIteratorTypesArray(rewriter, op, physicalRank); // Create 'd2m.generic' accepting extended operands. - auto generic = rewriter.create( - loc, inputs, outputs, /*additionalArgs=*/ValueRange(), + auto generic = d2m::GenericOp::create( + rewriter, loc, inputs, outputs, /*additionalArgs=*/ValueRange(), rewriter.getAffineMapArrayAttr(indexingMaps), rewriter.getArrayAttr(iteratorTypes)); @@ -1020,8 +1019,8 @@ class D2MNamedReductionRewriter final dimArgAsReduceDim(op, physicalRank))); } - auto linalgGeneric = rewriter.create( - loc, + auto linalgGeneric = mlir::linalg::GenericOp::create( + rewriter, loc, /* result tensor types */ llvm::to_vector( static_cast(blockArgs.take_back(numOutputs)) @@ -1031,11 +1030,11 @@ class D2MNamedReductionRewriter final linalgIteratorTypes, [&](mlir::OpBuilder &bbBuilder, mlir::Location bbLoc, mlir::ValueRange bbArgs) { - mlir::Value yield = bbBuilder.create( - loc, + mlir::Value yield = TileOp::create( + bbBuilder, loc, /* resultTypes */ bbArgs.take_back(numOutputs).getTypes(), /* operands */ bbArgs, attributes); - bbBuilder.create(bbLoc, yield); + mlir::linalg::YieldOp::create(bbBuilder, bbLoc, yield); }); // Insert remote_store operations for each output before yield @@ -1047,15 +1046,14 @@ class D2MNamedReductionRewriter final d2m::utils::buildGridIndices(rewriter, loc, indexingMap); Value genericOperand = generic->getOperand(operandIdx); Value result = linalgGeneric->getResult(outputIdx); - Value storeResult = - rewriter - .create(loc, genericOperand.getType(), - genericOperand, indices, result) - .getResult(); + Value storeResult = d2m::RemoteStoreOp::create( + rewriter, loc, genericOperand.getType(), + 
genericOperand, indices, result) + .getResult(); storeResults.push_back(storeResult); } - rewriter.create(loc, storeResults); + d2m::YieldOp::create(rewriter, loc, storeResults); } } rewriter.finalizeOpModification(generic); @@ -1156,9 +1154,9 @@ class D2MNamedReductionRewriter final llvm_unreachable("unexpected input element type"); } - return builder.create( - loc, scalerType, llvm::to_vector_of(scalerType.getShape()), - one); + return d2m::FullOp::create( + builder, loc, scalerType, + llvm::to_vector_of(scalerType.getShape()), one); } static d2m::ReduceDim dimArgAsReduceDim(ConcreteOp op, std::size_t rank) { @@ -1269,8 +1267,8 @@ class D2MMatmulRewriter final getIteratorTypesArray(rewriter, physicalRank); // Create 'd2m.generic' accepting 'op's operands. - auto generic = rewriter.create( - loc, inputs, outputs, /*additionalArgs=*/ValueRange(), + auto generic = d2m::GenericOp::create( + rewriter, loc, inputs, outputs, /*additionalArgs=*/ValueRange(), rewriter.getAffineMapArrayAttr(indexingMaps), rewriter.getArrayAttr(iteratorTypes)); @@ -1291,9 +1289,9 @@ class D2MMatmulRewriter final // Delegate next level of nesting to a "block" op. 
if constexpr (std::is_same_v) { - rewriter.create(loc, - /* resultTypes */ mlir::TypeRange(), - /* operands */ blockArgs); + TileOp::create(rewriter, loc, + /* resultTypes */ mlir::TypeRange(), + /* operands */ blockArgs); // Insert remote_store operations for each output before yield SmallVector storeResults; @@ -1304,16 +1302,15 @@ class D2MMatmulRewriter final d2m::utils::buildGridIndices(rewriter, loc, indexingMap); Value genericOperand = generic->getOperand(operandIdx); Value result = blockArgs[numInputs + outputIdx]; - Value storeResult = - rewriter - .create(loc, genericOperand.getType(), - genericOperand, indices, result) - .getResult(); + Value storeResult = d2m::RemoteStoreOp::create( + rewriter, loc, genericOperand.getType(), + genericOperand, indices, result) + .getResult(); storeResults.push_back(storeResult); } // In pure tensor semantics, explicitly yield the output shard. - rewriter.create(loc, storeResults); + d2m::YieldOp::create(rewriter, loc, storeResults); } else if constexpr (std::is_same_v) { @@ -1325,8 +1322,8 @@ class D2MMatmulRewriter final SmallVector linalgIteratorTypes = iteratorTypeTTIRToLinalg(rewriter, iteratorTypes); - auto linalgGeneric = rewriter.create( - loc, + auto linalgGeneric = mlir::linalg::GenericOp::create( + rewriter, loc, /* result tensor types */ llvm::to_vector( mlir::ValueRange(blockArgs.take_back(numOutputs)).getTypes()), @@ -1335,12 +1332,12 @@ class D2MMatmulRewriter final linalgIteratorTypes, [&](mlir::OpBuilder &bbBuilder, mlir::Location bbLoc, mlir::ValueRange bbArgs) { - mlir::Value yield = bbBuilder.create( - loc, /* resultTypes */ + mlir::Value yield = TileOp::create( + bbBuilder, loc, /* resultTypes */ bbArgs.take_back(tileOpNumOutputs).getTypes(), /* operands */ bbArgs.take_front(tileOpNumInputs)); - bbBuilder.create(bbLoc, yield); + mlir::linalg::YieldOp::create(bbBuilder, bbLoc, yield); }); // Insert remote_store operations for each output before yield @@ -1352,15 +1349,14 @@ class D2MMatmulRewriter 
final d2m::utils::buildGridIndices(rewriter, loc, indexingMap); Value genericOperand = generic->getOperand(operandIdx); Value result = linalgGeneric->getResult(outputIdx); - Value storeResult = - rewriter - .create(loc, genericOperand.getType(), - genericOperand, indices, result) - .getResult(); + Value storeResult = d2m::RemoteStoreOp::create( + rewriter, loc, genericOperand.getType(), + genericOperand, indices, result) + .getResult(); storeResults.push_back(storeResult); } - rewriter.create(loc, storeResults); + d2m::YieldOp::create(rewriter, loc, storeResults); } } } @@ -1572,11 +1568,11 @@ class D2MPermuteRewriter final permuted.physicalShape, inputTensorType.getElementType(), resultLayout); // For inner permute, we need a streamLayout to do reblocking. - auto storage = rewriter.create( - loc, permuted.physicalShape, inputTensorType.getElementType(), - resultLayout); - auto stream = rewriter.create( - loc, viewType, inputs[0], permuted.transposeMap, storage); + auto storage = + d2m::EmptyOp::create(rewriter, loc, permuted.physicalShape, + inputTensorType.getElementType(), resultLayout); + auto stream = d2m::StreamLayoutOp::create( + rewriter, loc, viewType, inputs[0], permuted.transposeMap, storage); inputs[0] = stream.getResult(); unsigned logicalRank = deviceRank / 2; // For inner permute, we alse need a GenericOp to transpose each individual @@ -1586,8 +1582,8 @@ class D2MPermuteRewriter final Value inputOperand = inputs[0]; Value outputOperand = outputs[0]; - auto generic = rewriter.create( - loc, inputs, outputs, /*additionalArgs=*/ValueRange(), + auto generic = d2m::GenericOp::create( + rewriter, loc, inputs, outputs, /*additionalArgs=*/ValueRange(), [&, inputOperand, outputOperand](OpBuilder &builder, Location bodyLoc, ValueRange blockArgs) { assert(blockArgs.size() == 2); @@ -1603,25 +1599,24 @@ class D2MPermuteRewriter final SmallVector inputIndices = d2m::utils::buildGridIndices(builder, bodyLoc, inputIndexingMap); Value inputBuffer = blockArgs[0]; - 
Value input = builder - .create( - bodyLoc, inputShardType, inputBuffer, - inputOperand, inputIndices) - .getResult(); + Value input = + d2m::RemoteLoadOp::create(builder, bodyLoc, inputShardType, + inputBuffer, inputOperand, inputIndices) + .getResult(); // Use the output tensor.empty directly. Value output = blockArgs[1]; - auto linalgGeneric = builder.create( - bodyLoc, output.getType(), input, output, + auto linalgGeneric = mlir::linalg::GenericOp::create( + builder, bodyLoc, output.getType(), input, output, SmallVector{identityMap, identityMap}, linalgIteratorTypes, [&](mlir::OpBuilder &bbBuilder, mlir::Location bbLoc, mlir::ValueRange bbArgs) { - mlir::Value yield = bbBuilder.create( - bbLoc, bbArgs.take_back(1).getTypes(), + mlir::Value yield = d2m::TileTransposeOp::create( + bbBuilder, bbLoc, bbArgs.take_back(1).getTypes(), bbArgs.take_front(1)); - bbBuilder.create(bbLoc, yield); + mlir::linalg::YieldOp::create(bbBuilder, bbLoc, yield); }); // Insert remote_store for output before yield @@ -1629,13 +1624,14 @@ class D2MPermuteRewriter final SmallVector outputIndices = d2m::utils::buildGridIndices(builder, bodyLoc, outputIndexingMap); Value result = linalgGeneric->getResult(0); - Value storeResult = builder - .create( - bodyLoc, outputOperand.getType(), - outputOperand, outputIndices, result) - .getResult(); + Value storeResult = + d2m::RemoteStoreOp::create(builder, + + bodyLoc, outputOperand.getType(), + outputOperand, outputIndices, result) + .getResult(); - builder.create(bodyLoc, storeResult); + d2m::YieldOp::create(builder, bodyLoc, storeResult); }); rewriter.replaceOp(op, unLayoutResult(rewriter, generic->getResult(0), @@ -1718,11 +1714,11 @@ class D2MToLayoutOpRewriter : public D2MNamedRewriterCommon, if (!ttnnMode) { // When ttnnMode is disabled, we can simply convert ttir.to_layout // directly to d2m.to_layout. 
- Value empty = rewriter.create( - op.getLoc(), outType.getShape(), outType.getElementType(), - outType.getEncoding()); - auto newOp = rewriter.create(op.getLoc(), - adaptor.getInput(), empty); + Value empty = + d2m::EmptyOp::create(rewriter, op.getLoc(), outType.getShape(), + outType.getElementType(), outType.getEncoding()); + auto newOp = d2m::ToLayoutOp::create(rewriter, op.getLoc(), + adaptor.getInput(), empty); rewriter.replaceOp(op, newOp.getResult(0)); return success(); } @@ -1802,8 +1798,8 @@ class D2MToLayoutOpRewriter : public D2MNamedRewriterCommon, if (mlir::isa_and_nonnull(inputType.getEncoding())) { auto inputMetalType = getMetalTensorFromTTNNTensor(rewriter, adaptor.getInput()); - auto inputCast = rewriter.create( - op.getLoc(), inputMetalType, adaptor.getInput()); + auto inputCast = ttir::TTNNMetalLayoutCastOp::create( + rewriter, op.getLoc(), inputMetalType, adaptor.getInput()); propagateVGMToCastOp(rewriter.getContext(), inputCast, inputType.getEncoding()); metalInput = inputCast.getResult(); @@ -1811,20 +1807,20 @@ class D2MToLayoutOpRewriter : public D2MNamedRewriterCommon, auto outputMetalType = getMetalTensorFromTTNNTensor(rewriter, op.getOutput()); // Create d2m.empty for TTNN layout. - Value metalEmpty = rewriter.create( - op.getLoc(), outType.getShape(), outType.getElementType(), - outType.getEncoding()); + Value metalEmpty = + d2m::EmptyOp::create(rewriter, op.getLoc(), outType.getShape(), + outType.getElementType(), outType.getEncoding()); // Cast TTNN empty to Metal layout. - auto metalCast = rewriter.create( - op.getLoc(), outputMetalType, metalEmpty); + auto metalCast = ttir::TTNNMetalLayoutCastOp::create( + rewriter, op.getLoc(), outputMetalType, metalEmpty); propagateVGMToCastOp(rewriter.getContext(), metalCast, outType.getEncoding()); // Create d2m.to_layout with Metal types. 
auto metalToLayout = - rewriter.create(op.getLoc(), metalInput, metalCast); + d2m::ToLayoutOp::create(rewriter, op.getLoc(), metalInput, metalCast); // Cast back to TTNN. - auto ttnnResult = rewriter.create( - op.getLoc(), outType, metalToLayout.getResult(0)); + auto ttnnResult = ttir::TTNNMetalLayoutCastOp::create( + rewriter, op.getLoc(), outType, metalToLayout.getResult(0)); rewriter.replaceOp(op, ttnnResult.getResult()); return success(); } @@ -1858,9 +1854,9 @@ class D2MEmptyOpRewriter : public OpConversionPattern { } // Create d2m.empty with same shape and element type. - auto d2mEmpty = rewriter.create( - op.getLoc(), tensorType.getShape(), tensorType.getElementType(), - tensorType.getEncoding()); + auto d2mEmpty = d2m::EmptyOp::create( + rewriter, op.getLoc(), tensorType.getShape(), + tensorType.getElementType(), tensorType.getEncoding()); rewriter.replaceOp(op, d2mEmpty.getResult()); return success(); @@ -1934,10 +1930,9 @@ class D2MArangeOpRewriter : public OpConversionPattern, rewriter.getContext(), scratchLogicalShape, ttcore::OOBVal::Undef, ttcore::MemorySpace::DeviceL1, ttcore::TensorMemoryLayout::Sharded); - Value indexTileTensor = - rewriter - .create(loc, scratchShape, tileType, scratchLayout) - .getResult(); + Value indexTileTensor = d2m::EmptyOp::create(rewriter, loc, scratchShape, + tileType, scratchLayout) + .getResult(); AffineMap identityMap = rewriter.getMultiDimIdentityMap(physicalRank); SmallVector zeroExprs(physicalRank, @@ -1951,8 +1946,8 @@ class D2MArangeOpRewriter : public OpConversionPattern, SmallVector iteratorTypes(physicalRank, parallel); SmallVector genericInputs = {indexTileTensor}; - auto generic = rewriter.create( - loc, genericInputs, outputs, /*additionalArgs=*/ValueRange(), + auto generic = d2m::GenericOp::create( + rewriter, loc, genericInputs, outputs, /*additionalArgs=*/ValueRange(), rewriter.getAffineMapArrayAttr(indexingMaps), rewriter.getArrayAttr(iteratorTypes)); @@ -1974,21 +1969,19 @@ class D2MArangeOpRewriter : 
public OpConversionPattern, // ArangeBlock operation will be decomposed in a later pass. Value arangeResult = - rewriter - .create(loc, indexTileTensor, outputTensor, - numElements, start, step) + d2m::ArangeBlockOp::create(rewriter, loc, indexTileTensor, + outputTensor, numElements, start, step) .getResult(); AffineMap outputIndexingMap = generic.getIndexingMap(1); SmallVector indices = d2m::utils::buildGridIndices(rewriter, loc, outputIndexingMap); Value storeResult = - rewriter - .create(loc, output.getType(), output, - indices, arangeResult) + d2m::RemoteStoreOp::create(rewriter, loc, output.getType(), output, + indices, arangeResult) .getResult(); - rewriter.create(loc, storeResult); + d2m::YieldOp::create(rewriter, loc, storeResult); } rewriter.finalizeOpModification(generic); rewriter.restoreInsertionPoint(insertPoint); @@ -2061,11 +2054,12 @@ class D2MTensorManipulationOpRewriter outTy.getElementType(), newLayout); auto storage = - rewriter.create(op.getLoc(), outputs[0].getType(), - /*virtualGridInverseMapping=*/nullptr, - /*virtualGridForwardMapping=*/nullptr); - auto view = rewriter.create( - op.getLoc(), newOutTy, inputs[0], deviceMap, storage.getResult()); + d2m::EmptyOp::create(rewriter, op.getLoc(), outputs[0].getType(), + /*virtualGridInverseMapping=*/nullptr, + /*virtualGridForwardMapping=*/nullptr); + auto view = + d2m::StreamLayoutOp::create(rewriter, op.getLoc(), newOutTy, inputs[0], + deviceMap, storage.getResult()); rewriter.replaceOp(op, unLayoutResult(rewriter, view->getResult(0), op->getResult(0).getType())); diff --git a/lib/Conversion/TTIRToLinalg/EltwiseBinary.cpp b/lib/Conversion/TTIRToLinalg/EltwiseBinary.cpp index 44d861ff549..f1c4b0ad60b 100644 --- a/lib/Conversion/TTIRToLinalg/EltwiseBinary.cpp +++ b/lib/Conversion/TTIRToLinalg/EltwiseBinary.cpp @@ -65,7 +65,7 @@ class ElementwiseBinaryOpToTosaPattern : public OpConversionPattern { rewriter); auto result = - rewriter.create(loc, resultType, ValueRange{lhs, rhs}); + 
TosaOpTy::create(rewriter, loc, resultType, ValueRange{lhs, rhs}); rewriter.replaceOp(op, result); return success(); @@ -98,9 +98,9 @@ class DirectComparisonOpToTosaPattern : public OpConversionPattern { auto boolType = RankedTensorType::get(resultType.getShape(), rewriter.getIntegerType(1)); - auto boolResult = rewriter.create(loc, boolType, lhs, rhs); + auto boolResult = TosaOpTy::create(rewriter, loc, boolType, lhs, rhs); - auto result = rewriter.create(loc, resultType, boolResult); + auto result = tosa::CastOp::create(rewriter, loc, resultType, boolResult); rewriter.replaceOp(op, result); return success(); @@ -133,9 +133,9 @@ class SwappedComparisonOpToTosaPattern : public OpConversionPattern { // Swapped operands: rhs, lhs. auto boolType = RankedTensorType::get(resultType.getShape(), rewriter.getIntegerType(1)); - auto boolResult = rewriter.create(loc, boolType, rhs, lhs); + auto boolResult = TosaOpTy::create(rewriter, loc, boolType, rhs, lhs); - auto result = rewriter.create(loc, resultType, boolResult); + auto result = tosa::CastOp::create(rewriter, loc, resultType, boolResult); rewriter.replaceOp(op, result); return success(); @@ -167,12 +167,12 @@ class NegatedComparisonOpToTosaPattern : public OpConversionPattern { auto boolType = RankedTensorType::get(resultType.getShape(), rewriter.getIntegerType(1)); - auto boolResult = rewriter.create(loc, boolType, lhs, rhs); + auto boolResult = TosaOpTy::create(rewriter, loc, boolType, lhs, rhs); auto notResult = - rewriter.create(loc, boolType, boolResult); + tosa::LogicalNotOp::create(rewriter, loc, boolType, boolResult); - auto result = rewriter.create(loc, resultType, notResult); + auto result = tosa::CastOp::create(rewriter, loc, resultType, notResult); rewriter.replaceOp(op, result); return success(); @@ -217,10 +217,11 @@ class LogicalBinaryOpToTosaPattern : public OpConversionPattern { // Apply the logical operation to the boolean tensors. 
auto logicalResult = - rewriter.create(loc, boolType, boolLhs, boolRhs); + TosaOpTy::create(rewriter, loc, boolType, boolLhs, boolRhs); // Convert boolean result back to original type using cast. - auto result = rewriter.create(loc, resultType, logicalResult); + auto result = + tosa::CastOp::create(rewriter, loc, resultType, logicalResult); rewriter.replaceOp(op, result); return success(); @@ -258,8 +259,8 @@ class ElementwiseBinaryOpToNamedLinalgPattern Value rhs = broadcastToShape(adaptor.getRhs(), resultType.getShape(), loc, rewriter); - auto output = rewriter.create(loc, resultType.getShape(), - resultType.getElementType()); + auto output = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + resultType.getElementType()); rewriter.replaceOpWithNewOp( op, resultType, ValueRange{lhs, rhs}, output.getResult()); return success(); @@ -303,15 +304,16 @@ class ElementwiseBinaryOpToLinalgGenericPatternBase SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - auto emptyTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType()); + auto emptyTensor = tensor::EmptyOp::create( + rewriter, loc, resultType.getShape(), resultType.getElementType()); - auto genericOp = rewriter.create( - loc, resultType, ValueRange{lhs, rhs}, ValueRange{emptyTensor}, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, resultType, ValueRange{lhs, rhs}, + ValueRange{emptyTensor}, SmallVector{indexingMap, indexingMap, indexingMap}, iteratorTypes, [&](OpBuilder &b, Location nestedLoc, ValueRange args) { Value result = buildBody(b, nestedLoc, args, resultType); - b.create(nestedLoc, result); + linalg::YieldOp::create(b, nestedLoc, result); }); rewriter.replaceOp(op, genericOp.getResult(0)); @@ -336,7 +338,7 @@ class ElementwiseBinaryOpToMathPattern protected: Value buildBody(OpBuilder &b, Location loc, ValueRange args, RankedTensorType /*resultType*/) const override { - return b.create(loc, args[0], args[1]); + return MathOpTy::create(b, 
loc, args[0], args[1]); } }; } // namespace @@ -360,24 +362,24 @@ class RemainderOpToLinalgGenericPattern if (isa(resultType.getElementType())) { // Python-style float modulo: a - floor(a / b) * b - Value div = b.create(loc, lhs, rhs); - Value floored = b.create(loc, div); - Value prod = b.create(loc, floored, rhs); - return b.create(loc, lhs, prod); + Value div = arith::DivFOp::create(b, loc, lhs, rhs); + Value floored = math::FloorOp::create(b, loc, div); + Value prod = arith::MulFOp::create(b, loc, floored, rhs); + return arith::SubFOp::create(b, loc, lhs, prod); } // Python-style integer modulo: adjust C remainder when signs differ. - Value rem = b.create(loc, lhs, rhs); - Value zero = b.create( - loc, b.getIntegerAttr(resultType.getElementType(), 0)); - Value sum = b.create(loc, rem, rhs); + Value rem = arith::RemSIOp::create(b, loc, lhs, rhs); + Value zero = arith::ConstantOp::create( + b, loc, b.getIntegerAttr(resultType.getElementType(), 0)); + Value sum = arith::AddIOp::create(b, loc, rem, rhs); Value remNeZero = - b.create(loc, arith::CmpIPredicate::ne, rem, zero); - Value xorVal = b.create(loc, rem, rhs); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, rem, zero); + Value xorVal = arith::XOrIOp::create(b, loc, rem, rhs); Value signsDiffer = - b.create(loc, arith::CmpIPredicate::slt, xorVal, zero); - Value needAdjust = b.create(loc, remNeZero, signsDiffer); - return b.create(loc, needAdjust, sum, rem); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, xorVal, zero); + Value needAdjust = arith::AndIOp::create(b, loc, remNeZero, signsDiffer); + return arith::SelectOp::create(b, loc, needAdjust, sum, rem); } }; } // namespace @@ -440,26 +442,26 @@ class GeluBackwardOpToTosaPattern // cdf = 0.5 * (1 + erf(x * invSqrt2)) Value xScaled = - rewriter.create(loc, resultType, x, invSqrt2, shift); - Value erfVal = rewriter.create(loc, resultType, xScaled); + tosa::MulOp::create(rewriter, loc, resultType, x, invSqrt2, shift); + Value erfVal = 
tosa::ErfOp::create(rewriter, loc, resultType, xScaled); Value onePlusErf = - rewriter.create(loc, resultType, one, erfVal); + tosa::AddOp::create(rewriter, loc, resultType, one, erfVal); Value cdf = - rewriter.create(loc, resultType, half, onePlusErf, shift); + tosa::MulOp::create(rewriter, loc, resultType, half, onePlusErf, shift); // pdf = exp(-x^2/2) / sqrt(2*pi) - Value xSq = rewriter.create(loc, resultType, x, x, shift); + Value xSq = tosa::MulOp::create(rewriter, loc, resultType, x, x, shift); Value negHalfXSq = - rewriter.create(loc, resultType, negHalf, xSq, shift); - Value expVal = rewriter.create(loc, resultType, negHalfXSq); - Value pdf = rewriter.create(loc, resultType, invSqrt2Pi, - expVal, shift); + tosa::MulOp::create(rewriter, loc, resultType, negHalf, xSq, shift); + Value expVal = tosa::ExpOp::create(rewriter, loc, resultType, negHalfXSq); + Value pdf = tosa::MulOp::create(rewriter, loc, resultType, invSqrt2Pi, + expVal, shift); // result = grad * (cdf + x * pdf) Value xTimesPdf = - rewriter.create(loc, resultType, x, pdf, shift); + tosa::MulOp::create(rewriter, loc, resultType, x, pdf, shift); Value cdfPlusXPdf = - rewriter.create(loc, resultType, cdf, xTimesPdf); + tosa::AddOp::create(rewriter, loc, resultType, cdf, xTimesPdf); rewriter.replaceOpWithNewOp(op, resultType, grad, cdfPlusXPdf, shift); return success(); @@ -488,48 +490,49 @@ class GeluBackwardOpToTosaPattern Value one = createTosaConst(rewriter, loc, elemTy, rank, 1.0); // x^2 - Value xSq = rewriter.create(loc, resultType, x, x, shift); + Value xSq = tosa::MulOp::create(rewriter, loc, resultType, x, x, shift); // a * x^2 * x = a * x^3 - Value aXSq = rewriter.create(loc, resultType, a, xSq, shift); - Value aXCub = rewriter.create(loc, resultType, aXSq, x, shift); + Value aXSq = tosa::MulOp::create(rewriter, loc, resultType, a, xSq, shift); + Value aXCub = + tosa::MulOp::create(rewriter, loc, resultType, aXSq, x, shift); // inner = k * (x + a*x^3) - Value xPlusAXCub = 
rewriter.create(loc, resultType, x, aXCub); + Value xPlusAXCub = tosa::AddOp::create(rewriter, loc, resultType, x, aXCub); Value inner = - rewriter.create(loc, resultType, k, xPlusAXCub, shift); + tosa::MulOp::create(rewriter, loc, resultType, k, xPlusAXCub, shift); // tanh_val = tanh(inner) - Value tanhVal = rewriter.create(loc, resultType, inner); + Value tanhVal = tosa::TanhOp::create(rewriter, loc, resultType, inner); // sech^2 = 1 - tanh^2 Value tanhSq = - rewriter.create(loc, resultType, tanhVal, tanhVal, shift); - Value negTanhSq = rewriter.create(loc, resultType, tanhSq); + tosa::MulOp::create(rewriter, loc, resultType, tanhVal, tanhVal, shift); + Value negTanhSq = tosa::NegateOp::create(rewriter, loc, resultType, tanhSq); Value sechSq = - rewriter.create(loc, resultType, one, negTanhSq); + tosa::AddOp::create(rewriter, loc, resultType, one, negTanhSq); // left = 0.5 * (1 + tanh_val) Value onePlusTanh = - rewriter.create(loc, resultType, one, tanhVal); - Value left = - rewriter.create(loc, resultType, half, onePlusTanh, shift); + tosa::AddOp::create(rewriter, loc, resultType, one, tanhVal); + Value left = tosa::MulOp::create(rewriter, loc, resultType, half, + onePlusTanh, shift); // right = 0.5 * x * sech^2 * k * (1 + 3*a*x^2) Value threeAXSq = - rewriter.create(loc, resultType, threeA, xSq, shift); + tosa::MulOp::create(rewriter, loc, resultType, threeA, xSq, shift); Value onePlus3AXSq = - rewriter.create(loc, resultType, one, threeAXSq); + tosa::AddOp::create(rewriter, loc, resultType, one, threeAXSq); Value sechK = - rewriter.create(loc, resultType, sechSq, k, shift); - Value sechKTerm = rewriter.create(loc, resultType, sechK, - onePlus3AXSq, shift); + tosa::MulOp::create(rewriter, loc, resultType, sechSq, k, shift); + Value sechKTerm = tosa::MulOp::create(rewriter, loc, resultType, sechK, + onePlus3AXSq, shift); Value xTerm = - rewriter.create(loc, resultType, x, sechKTerm, shift); + tosa::MulOp::create(rewriter, loc, resultType, x, sechKTerm, shift); 
Value right = - rewriter.create(loc, resultType, half, xTerm, shift); + tosa::MulOp::create(rewriter, loc, resultType, half, xTerm, shift); // gelu_bw = grad * (left + right) Value leftPlusRight = - rewriter.create(loc, resultType, left, right); + tosa::AddOp::create(rewriter, loc, resultType, left, right); rewriter.replaceOpWithNewOp(op, resultType, grad, leftPlusRight, shift); return success(); diff --git a/lib/Conversion/TTIRToLinalg/EltwiseUnary.cpp b/lib/Conversion/TTIRToLinalg/EltwiseUnary.cpp index 20da8b3ab4c..71199a69da5 100644 --- a/lib/Conversion/TTIRToLinalg/EltwiseUnary.cpp +++ b/lib/Conversion/TTIRToLinalg/EltwiseUnary.cpp @@ -60,7 +60,7 @@ class ElementwiseUnaryOpToTosaConversionPattern Value input = broadcastToShape(adaptor.getInput(), resultType.getShape(), op.getLoc(), rewriter); - auto result = rewriter.create(op.getLoc(), resultType, input); + auto result = TosaOpTy::create(rewriter, op.getLoc(), resultType, input); rewriter.replaceOp(op, result); return success(); @@ -94,8 +94,8 @@ class ElementwiseUnaryOpToNamedLinalgConversionPattern Value input = broadcastToShape(adaptor.getInput(), resultType.getShape(), loc, rewriter); - auto output = rewriter.create(loc, resultType.getShape(), - resultType.getElementType()); + auto output = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + resultType.getElementType()); rewriter.replaceOpWithNewOp(op, resultType, ValueRange{input}, output.getResult()); return success(); @@ -141,15 +141,15 @@ class ElementwiseUnaryOpToLinalgGenericConversionPatternBase SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - auto emptyTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType()); + auto emptyTensor = tensor::EmptyOp::create( + rewriter, loc, resultType.getShape(), resultType.getElementType()); - auto genericOp = rewriter.create( - loc, resultType, ValueRange{input}, ValueRange{emptyTensor}, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, 
resultType, ValueRange{input}, ValueRange{emptyTensor}, SmallVector{indexingMap, indexingMap}, iteratorTypes, [&](OpBuilder &b, Location nestedLoc, ValueRange args) { Value result = buildBody(b, nestedLoc, args, resultType); - b.create(nestedLoc, result); + linalg::YieldOp::create(b, nestedLoc, result); }); rewriter.replaceOp(op, genericOp.getResult(0)); @@ -174,7 +174,7 @@ class ElementwiseUnaryOpToMathConversionPattern protected: Value buildBody(OpBuilder &b, Location loc, ValueRange args, RankedTensorType /*resultType*/) const override { - return b.create(loc, args[0]); + return MathOpTy::create(b, loc, args[0]); } }; } // namespace @@ -210,10 +210,11 @@ class SignOpConversionPattern : public OpConversionPattern { Value one = createTosaConst(rewriter, loc, elemTy, rank, 1.0); Value negOne = createTosaConst(rewriter, loc, elemTy, rank, -1.0); - Value gtZero = rewriter.create(loc, boolType, input, zero); - Value eqZero = rewriter.create(loc, boolType, input, zero); + Value gtZero = + tosa::GreaterOp::create(rewriter, loc, boolType, input, zero); + Value eqZero = tosa::EqualOp::create(rewriter, loc, boolType, input, zero); Value posOrNeg = - rewriter.create(loc, resultType, gtZero, one, negOne); + tosa::SelectOp::create(rewriter, loc, resultType, gtZero, one, negOne); rewriter.replaceOpWithNewOp(op, resultType, eqZero, zero, posOrNeg); return success(); @@ -237,13 +238,13 @@ class IsFiniteOpConversionPattern // isfinite(x) = (x - x) == 0. // If x is NaN, x - x is NaN and OEQ returns false. // If x is Inf, x - x is NaN and OEQ returns false. 
- Value diff = b.create(loc, elem, elem); + Value diff = arith::SubFOp::create(b, loc, elem, elem); Value zero = - b.create(loc, b.getFloatAttr(elem.getType(), 0.0)); + arith::ConstantOp::create(b, loc, b.getFloatAttr(elem.getType(), 0.0)); Value isFinite = - b.create(loc, arith::CmpFPredicate::OEQ, diff, zero); - return b.create(loc, resultType.getElementType(), - isFinite); + arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OEQ, diff, zero); + return arith::UIToFPOp::create(b, loc, resultType.getElementType(), + isFinite); } }; } // namespace @@ -281,12 +282,12 @@ class GeluOpConversionPattern : public OpConversionPattern { Value invSqrt2 = createTosaConst(rewriter, loc, elemTy, rank, kInvSqrt2); Value xScaled = - rewriter.create(loc, resultType, input, invSqrt2, shift); - Value erfVal = rewriter.create(loc, resultType, xScaled); + tosa::MulOp::create(rewriter, loc, resultType, input, invSqrt2, shift); + Value erfVal = tosa::ErfOp::create(rewriter, loc, resultType, xScaled); Value onePlusErf = - rewriter.create(loc, resultType, one, erfVal); + tosa::AddOp::create(rewriter, loc, resultType, one, erfVal); Value halfX = - rewriter.create(loc, resultType, half, input, shift); + tosa::MulOp::create(rewriter, loc, resultType, half, input, shift); rewriter.replaceOpWithNewOp(op, resultType, halfX, onePlusErf, shift); return success(); @@ -318,7 +319,7 @@ class SiluOpConversionPattern : public OpConversionPattern { input = broadcastToShape(input, resultType.getShape(), loc, rewriter); Value shift = createTosaMulShift(rewriter, loc); - Value sigm = rewriter.create(loc, resultType, input); + Value sigm = tosa::SigmoidOp::create(rewriter, loc, resultType, input); rewriter.replaceOpWithNewOp(op, resultType, input, sigm, shift); return success(); @@ -359,9 +360,9 @@ class LeakyReluOpConversionPattern Value zero = createTosaConst(rewriter, loc, elemTy, rank, 0.0); Value alphaC = createTosaConst(rewriter, loc, elemTy, rank, alpha); Value scaled = - rewriter.create(loc, 
resultType, alphaC, input, shift); + tosa::MulOp::create(rewriter, loc, resultType, alphaC, input, shift); Value positive = - rewriter.create(loc, boolType, input, zero); + tosa::GreaterOp::create(rewriter, loc, boolType, input, zero); rewriter.replaceOpWithNewOp(op, resultType, positive, input, scaled); return success(); @@ -399,9 +400,9 @@ class HardsigmoidOpConversionPattern Value three = createTosaConst(rewriter, loc, elemTy, rank, 3.0); Value sixth = createTosaConst(rewriter, loc, elemTy, rank, 1.0 / 6.0); - Value xPlus3 = rewriter.create(loc, resultType, input, three); + Value xPlus3 = tosa::AddOp::create(rewriter, loc, resultType, input, three); Value scaled = - rewriter.create(loc, resultType, xPlus3, sixth, shift); + tosa::MulOp::create(rewriter, loc, resultType, xPlus3, sixth, shift); rewriter.replaceOpWithNewOp( op, resultType, scaled, rewriter.getFloatAttr(elemTy, 0.0), rewriter.getFloatAttr(elemTy, 1.0)); @@ -441,18 +442,18 @@ class MishOpConversionPattern : public OpConversionPattern { Value one = createTosaConst(rewriter, loc, elemTy, rank, 1.0); // softplus(x) = max(x, 0) + log(1 + exp(-|x|)). 
- Value absX = rewriter.create(loc, resultType, input); - Value negAbsX = rewriter.create(loc, resultType, absX); - Value expNegAbsX = rewriter.create(loc, resultType, negAbsX); + Value absX = tosa::AbsOp::create(rewriter, loc, resultType, input); + Value negAbsX = tosa::NegateOp::create(rewriter, loc, resultType, absX); + Value expNegAbsX = tosa::ExpOp::create(rewriter, loc, resultType, negAbsX); Value onePlusExp = - rewriter.create(loc, resultType, one, expNegAbsX); - Value logPart = rewriter.create(loc, resultType, onePlusExp); + tosa::AddOp::create(rewriter, loc, resultType, one, expNegAbsX); + Value logPart = tosa::LogOp::create(rewriter, loc, resultType, onePlusExp); Value maxXZero = - rewriter.create(loc, resultType, input, zero); + tosa::MaximumOp::create(rewriter, loc, resultType, input, zero); Value softplus = - rewriter.create(loc, resultType, maxXZero, logPart); + tosa::AddOp::create(rewriter, loc, resultType, maxXZero, logPart); - Value tanhSP = rewriter.create(loc, resultType, softplus); + Value tanhSP = tosa::TanhOp::create(rewriter, loc, resultType, softplus); rewriter.replaceOpWithNewOp(op, resultType, input, tanhSP, shift); return success(); @@ -493,10 +494,10 @@ class LogicalNotOpConversionPattern // Apply logical not to the boolean tensor. auto notResult = - rewriter.create(loc, boolType, boolInput); + tosa::LogicalNotOp::create(rewriter, loc, boolType, boolInput); // Convert boolean result back to original type using cast. 
- auto result = rewriter.create(loc, resultType, notResult); + auto result = tosa::CastOp::create(rewriter, loc, resultType, notResult); rewriter.replaceOp(op, result); return success(); @@ -529,10 +530,11 @@ class ReluOpConversionPattern : public OpConversionPattern { op, "Unsupported element type for ReLU zero constant"); } - auto zeroes = rewriter.create(loc, resultType, zeroAttr); + auto zeroes = + arith::ConstantOp::create(rewriter, loc, resultType, zeroAttr); - auto output = rewriter.create(loc, resultType.getShape(), - resultType.getElementType()); + auto output = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), + resultType.getElementType()); rewriter.replaceOpWithNewOp( op, resultType, ValueRange{input, zeroes.getResult()}, ValueRange{output}); diff --git a/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp b/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp index 197040c550a..5b444b44342 100644 --- a/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp +++ b/lib/Conversion/TTIRToLinalg/TTIRToLinalg.cpp @@ -62,14 +62,14 @@ convertToBooleanTensorComparison(Value input, Location loc, auto zeroType = RankedTensorType::get(zeroShape, elementType); DenseElementsAttr zeroAttr = DenseElementsAttr::get(zeroType, rewriter.getF32FloatAttr(0.0f)); - auto zeroConst = rewriter.create(loc, zeroType, zeroAttr); + auto zeroConst = tosa::ConstOp::create(rewriter, loc, zeroType, zeroAttr); // For comparison semantics: positive values are true, so we need: (input > // 0). 
auto boolType = RankedTensorType::get(inputType.getShape(), rewriter.getIntegerType(1)); auto greaterThanZero = - rewriter.create(loc, boolType, input, zeroConst); + tosa::GreaterOp::create(rewriter, loc, boolType, input, zeroConst); return greaterThanZero.getResult(); } @@ -102,9 +102,10 @@ static Value reshapeByPrependingOnes(Value input, int64_t targetRank, auto shapeType = tosa::shapeType::get(rewriter.getContext(), broadcastShape.size()); auto shapeAttr = rewriter.getIndexTensorAttr(broadcastShape); - auto shapeOp = rewriter.create(loc, shapeType, shapeAttr); - return rewriter.create(loc, reshapedType, input, - shapeOp.getResult()); + auto shapeOp = + tosa::ConstShapeOp::create(rewriter, loc, shapeType, shapeAttr); + return tosa::ReshapeOp::create(rewriter, loc, reshapedType, input, + shapeOp.getResult()); } // Get dimensions from the dim_arg attribute; if the attribute is not present or @@ -175,15 +176,16 @@ static Value createReductionOpChain(Value input, RankedTensorType resultType, opResultType = RankedTensorType::get(shape, inputType.getElementType()); // Create the reduction operation - result = rewriter.create(loc, opResultType, result, axisAttr); + result = ReductionOp::create(rewriter, loc, opResultType, result, axisAttr); } if (!keepDim) { ArrayRef newShape = resultType.getShape(); auto shapeType = tosa::shapeType::get(rewriter.getContext(), newShape.size()); auto attr = rewriter.getIndexTensorAttr(newShape); - auto shapeOp = rewriter.create(loc, shapeType, attr); - result = rewriter.create(loc, resultType, result, shapeOp); + auto shapeOp = tosa::ConstShapeOp::create(rewriter, loc, shapeType, attr); + result = + tosa::ReshapeOp::create(rewriter, loc, resultType, result, shapeOp); } return result; } @@ -208,10 +210,10 @@ static Value createTosaReshape(Value input, RankedTensorType targetType, ArrayRef newShape = targetType.getShape(); auto shapeType = tosa::shapeType::get(rewriter.getContext(), newShape.size()); auto shapeAttr = 
rewriter.getIndexTensorAttr(newShape); - auto shapeOp = rewriter.create(loc, shapeType, shapeAttr); - return rewriter - .create(loc, targetType, input, shapeOp.getResult()) - .getResult(); + auto shapeOp = + tosa::ConstShapeOp::create(rewriter, loc, shapeType, shapeAttr); + return tosa::ReshapeOp::create(rewriter, loc, targetType, input, + shapeOp.getResult()); } // Unflatten input from (1, 1, N*H*W, C) to (N, H, W, C) using metadata from @@ -259,8 +261,8 @@ class WhereOpConversionPattern : public OpConversionPattern { } condition = *conditionOrFailure; - auto result = rewriter.create( - op.getLoc(), resultType, condition, trueValue, falseValue); + auto result = tosa::SelectOp::create(rewriter, op.getLoc(), resultType, + condition, trueValue, falseValue); rewriter.replaceOp(op, result); return success(); @@ -289,10 +291,10 @@ class ReshapeOpConversionPattern : public OpConversionPattern { tosa::shapeType::get(rewriter.getContext(), newShape.size()); auto attr = rewriter.getIndexTensorAttr(newShapeValues); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); - auto reshapeOp = rewriter.create( - op.getLoc(), resultType, adaptor.getInput(), shapeOp); + auto reshapeOp = tosa::ReshapeOp::create(rewriter, op.getLoc(), resultType, + adaptor.getInput(), shapeOp); rewriter.replaceOp(op, reshapeOp); @@ -335,8 +337,8 @@ class TransposeOpConversionPattern permutation[dim0] = static_cast(dim1); // Create TransposeOp directly with the permutation array - auto result = rewriter.create(op.getLoc(), resultType, - input, permutation); + auto result = tosa::TransposeOp::create(rewriter, op.getLoc(), resultType, + input, permutation); rewriter.replaceOp(op, result); return success(); @@ -371,7 +373,7 @@ class ConcatOpConversionPattern : public OpConversionPattern { // Concatenate all inputs at once using the final result type. 
Value result = - rewriter.create(op.getLoc(), resultType, inputs, dim); + tosa::ConcatOp::create(rewriter, op.getLoc(), resultType, inputs, dim); rewriter.replaceOp(op, result); return success(); @@ -422,14 +424,14 @@ class BroadcastOpConversionPattern // The broadcast op requires we actually collapse any dimensions with // size 1 we want to broadcast along. if (collapseDimGroups.size() != inputShape.size()) { - broadcastInput = rewriter.create( - loc, input, collapseDimGroups); + broadcastInput = tensor::CollapseShapeOp::create(rewriter, loc, input, + collapseDimGroups); } - auto initTensor = rewriter.create( - loc, targetShape, inputType.getElementType()); - auto broadcastOp = rewriter.create( - loc, broadcastInput, initTensor.getResult(), broadcastDims); + auto initTensor = ttir::EmptyOp::create(rewriter, loc, targetShape, + inputType.getElementType()); + auto broadcastOp = linalg::BroadcastOp::create( + rewriter, loc, broadcastInput, initTensor.getResult(), broadcastDims); rewriter.replaceOp(op, broadcastOp.getResults().front()); return success(); @@ -486,8 +488,8 @@ class MatmulOpConversionPattern : public OpConversionPattern { permutation.push_back(static_cast(lhsShape.size() - 2)); // Create transpose op - lhs = rewriter.create(op.getLoc(), transposedType, - lhs, permutation); + lhs = tosa::TransposeOp::create(rewriter, op.getLoc(), transposedType, + lhs, permutation); lhsType = transposedType; } } @@ -516,8 +518,8 @@ class MatmulOpConversionPattern : public OpConversionPattern { permutation.push_back(static_cast(rhsShape.size() - 2)); // Create transpose op - rhs = rewriter.create(op.getLoc(), transposedType, - rhs, permutation); + rhs = tosa::TransposeOp::create(rewriter, op.getLoc(), transposedType, + rhs, permutation); rhsType = transposedType; } } @@ -544,11 +546,11 @@ class MatmulOpConversionPattern : public OpConversionPattern { lhsType.getDimSize(1)}; auto attr = rewriter.getIndexTensorAttr(shapeValues); auto shapeOp = - rewriter.create(op.getLoc(), 
shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); // Reshape LHS to 3D - lhs3D = rewriter.create(op.getLoc(), newType, lhs, - shapeOp.getResult()); + lhs3D = tosa::ReshapeOp::create(rewriter, op.getLoc(), newType, lhs, + shapeOp.getResult()); lhs3DType = newType; } else if (lhsRank > 3) { // For tensors with rank > 3, collapse all but the last two dimensions @@ -569,11 +571,11 @@ class MatmulOpConversionPattern : public OpConversionPattern { lhsType.getShape()[lhsRank - 1]}; auto attr = rewriter.getIndexTensorAttr(shapeValues); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); // Reshape LHS to 3D - lhs3D = rewriter.create(op.getLoc(), newType, lhs, - shapeOp.getResult()); + lhs3D = tosa::ReshapeOp::create(rewriter, op.getLoc(), newType, lhs, + shapeOp.getResult()); lhs3DType = newType; } @@ -589,11 +591,11 @@ class MatmulOpConversionPattern : public OpConversionPattern { rhsType.getDimSize(1)}; auto attr = rewriter.getIndexTensorAttr(shapeValues); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); // Reshape RHS to 3D - rhs3D = rewriter.create(op.getLoc(), newType, rhs, - shapeOp.getResult()); + rhs3D = tosa::ReshapeOp::create(rewriter, op.getLoc(), newType, rhs, + shapeOp.getResult()); rhs3DType = newType; } else if (rhsRank > 3) { // For tensors with rank > 3, collapse all but the last two dimensions @@ -614,11 +616,11 @@ class MatmulOpConversionPattern : public OpConversionPattern { rhsType.getShape()[rhsRank - 1]}; auto attr = rewriter.getIndexTensorAttr(shapeValues); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); // Reshape RHS to 3D - rhs3D = rewriter.create(op.getLoc(), newType, rhs, - shapeOp.getResult()); + rhs3D = tosa::ReshapeOp::create(rewriter, op.getLoc(), 
newType, rhs, + shapeOp.getResult()); rhs3DType = newType; } @@ -636,11 +638,11 @@ class MatmulOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), 3); auto multiplesAttr = rewriter.getIndexTensorAttr(multiples); - auto multiplesOp = rewriter.create( - op.getLoc(), shapeType, multiplesAttr); + auto multiplesOp = tosa::ConstShapeOp::create(rewriter, op.getLoc(), + shapeType, multiplesAttr); - lhs3D = rewriter.create(op.getLoc(), newType, lhs3D, - multiplesOp); + lhs3D = tosa::TileOp::create(rewriter, op.getLoc(), newType, lhs3D, + multiplesOp); lhs3DType = cast(lhs3D.getType()); } else if (rhs3DType.getShape()[0] == 1 && lhs3DType.getShape()[0] > 1) { // Use TOSA tile operation for broadcasting @@ -652,11 +654,11 @@ class MatmulOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), 3); auto multiplesAttr = rewriter.getIndexTensorAttr(multiples); - auto multiplesOp = rewriter.create( - op.getLoc(), shapeType, multiplesAttr); + auto multiplesOp = tosa::ConstShapeOp::create(rewriter, op.getLoc(), + shapeType, multiplesAttr); - rhs3D = rewriter.create(op.getLoc(), newType, rhs3D, - multiplesOp); + rhs3D = tosa::TileOp::create(rewriter, op.getLoc(), newType, rhs3D, + multiplesOp); rhs3DType = cast(rhs3D.getType()); } } @@ -668,8 +670,8 @@ class MatmulOpConversionPattern : public OpConversionPattern { resultType.getElementType()); // Perform matrix multiplication using tosa.matmul - Value matmulResult = rewriter.create( - op.getLoc(), matmulResultType, lhs3D, rhs3D); + Value matmulResult = tosa::MatMulOp::create(rewriter, op.getLoc(), + matmulResultType, lhs3D, rhs3D); // Reshape result back to original rank if needed if (resultType.getRank() != matmulResultType.getRank()) { @@ -682,11 +684,11 @@ class MatmulOpConversionPattern : public OpConversionPattern { } auto attr = rewriter.getIndexTensorAttr(shapeValues); auto shapeOp = - rewriter.create(op.getLoc(), 
shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); // Reshape result - matmulResult = rewriter.create( - op.getLoc(), resultType, matmulResult, shapeOp.getResult()); + matmulResult = tosa::ReshapeOp::create(rewriter, op.getLoc(), resultType, + matmulResult, shapeOp.getResult()); } rewriter.replaceOp(op, matmulResult); @@ -711,8 +713,8 @@ static Value sliceResultToShape(Value result, RankedTensorType targetType, sizes.push_back(rewriter.getI64IntegerAttr(targetType.getShape()[i])); strides.push_back(rewriter.getI64IntegerAttr(1)); } - return rewriter.create(loc, targetType, result, - offsets, sizes, strides); + return tensor::ExtractSliceOp::create(rewriter, loc, targetType, result, + offsets, sizes, strides); } namespace { @@ -746,8 +748,8 @@ class Conv2dOpConversionPattern : public OpConversionPattern { auto transposedWeightType = RankedTensorType::get(transposedShape, weightType.getElementType()); - auto transposedWeight = rewriter.create( - op.getLoc(), transposedWeightType, weight, permutation); + auto transposedWeight = tosa::TransposeOp::create( + rewriter, op.getLoc(), transposedWeightType, weight, permutation); // Reshape bias from 4D (1,1,1,B) to 1D (B) for TOSA. // If bias is not provided, create a zero bias tensor. 
@@ -767,13 +769,11 @@ class Conv2dOpConversionPattern : public OpConversionPattern {
           RankedTensorType::get(reshapedBiasShape, biasType.getElementType());
       auto shapeType = tosa::shapeType::get(rewriter.getContext(), 1);
       auto shapeAttr = rewriter.getIndexTensorAttr(reshapedBiasShape);
-      auto shapeOp = rewriter.create(op.getLoc(), shapeType,
-                                                        shapeAttr);
-
-      reshapedBias = rewriter
-                         .create(op.getLoc(), reshapedBiasType,
-                                                  bias, shapeOp.getResult())
-                         .getResult();
+      auto shapeOp = tosa::ConstShapeOp::create(rewriter, op.getLoc(),
+                                                shapeType, shapeAttr);
+
+      reshapedBias = tosa::ReshapeOp::create(
+          rewriter, op.getLoc(), reshapedBiasType, bias, shapeOp.getResult());
     } else {
       int64_t outputChannels = weightShape[0];
       auto biasElementType =
@@ -793,8 +799,8 @@ class Conv2dOpConversionPattern : public OpConversionPattern {
     } else {
       return rewriter.notifyMatchFailure(op, "Unsupported bias element type");
    }
-    reshapedBias = rewriter.create(
-        op.getLoc(), biasType, cast(zeroAttr));
+    reshapedBias = tosa::ConstOp::create(rewriter, op.getLoc(), biasType,
+                                         cast(zeroAttr));
  }

  // Expand stride if it contains only one element.
auto stridesResult = ttmlir::utils::getPairOfInteger(strides); @@ -909,10 +915,10 @@ class Conv2dOpConversionPattern : public OpConversionPattern { auto actualResultType = RankedTensorType::get(resultShape, resultType.getElementType()); - auto conv2dOp = rewriter.create( - op.getLoc(), actualResultType, input, transposedWeight.getResult(), - reshapedBias, expandedPaddingAttr, expandedStridesAttr, - expandedDilationsAttr, TypeAttr::get(accType)); + auto conv2dOp = tosa::Conv2DOp::create( + rewriter, op.getLoc(), actualResultType, input, + transposedWeight.getResult(), reshapedBias, expandedPaddingAttr, + expandedStridesAttr, expandedDilationsAttr, TypeAttr::get(accType)); Value result = sliceResultToShape(conv2dOp.getResult(), resultType, rewriter, op.getLoc()); @@ -1052,13 +1058,13 @@ class MaxPool2dOpConversionPattern if (auto floatType = dyn_cast(elementType)) { auto negInf = APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true); - negInfVal = rewriter.create( - loc, rewriter.getFloatAttr(elementType, negInf)); + negInfVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(elementType, negInf)); } else { auto intType = cast(elementType); auto minVal = APInt::getSignedMinValue(intType.getWidth()); - negInfVal = rewriter.create( - loc, rewriter.getIntegerAttr(elementType, minVal)); + negInfVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(elementType, minVal)); } // Pad input with -inf if needed. @@ -1079,15 +1085,15 @@ class MaxPool2dOpConversionPattern rewriter.getIndexAttr(paddingRight), rewriter.getIndexAttr(0)}; auto paddedType = RankedTensorType::get( {batch, paddedH, paddedW, channels}, elementType); - paddedInput = rewriter.create(loc, paddedType, input, - lowPad, highPad, negInfVal); + paddedInput = tensor::PadOp::create(rewriter, loc, paddedType, input, + lowPad, highPad, negInfVal); } // Create kernel tensor, strides, and dilations. 
auto kernelType = RankedTensorType::get({kernelH, kernelW}, rewriter.getF32Type()); - Value kernelTensor = rewriter.create( - loc, kernelType.getShape(), kernelType.getElementType()); + Value kernelTensor = tensor::EmptyOp::create( + rewriter, loc, kernelType.getShape(), kernelType.getElementType()); auto stridesAttr = DenseIntElementsAttr::get( RankedTensorType::get({2}, rewriter.getI64Type()), @@ -1097,15 +1103,16 @@ class MaxPool2dOpConversionPattern ArrayRef{dilationH, dilationW}); // Init output with -inf and run max pooling. - Value outputInit = rewriter.create( - loc, actualResultType.getShape(), elementType); + Value outputInit = tensor::EmptyOp::create( + rewriter, loc, actualResultType.getShape(), elementType); Value poolOutput = - rewriter.create(loc, negInfVal, outputInit) + linalg::FillOp::create(rewriter, loc, negInfVal, outputInit) .getResult(0); - auto poolOp = rewriter.create( - loc, TypeRange{actualResultType}, ValueRange{paddedInput, kernelTensor}, - ValueRange{poolOutput}, stridesAttr, dilationsAttr); + auto poolOp = linalg::PoolingNhwcMaxOp::create( + rewriter, loc, TypeRange{actualResultType}, + ValueRange{paddedInput, kernelTensor}, ValueRange{poolOutput}, + stridesAttr, dilationsAttr); Value result = poolOp.getResult(0); result = sliceResultToShape(result, resultType, rewriter, loc); @@ -1254,11 +1261,11 @@ class AvgPool2dOpConversionPattern // Create zero constant for padding and fill operations. Value zeroVal; if (isa(elementType)) { - zeroVal = rewriter.create( - loc, rewriter.getFloatAttr(elementType, 0.0)); + zeroVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(elementType, 0.0)); } else { - zeroVal = rewriter.create( - loc, rewriter.getIntegerAttr(elementType, 0)); + zeroVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(elementType, 0)); } // Create padding attributes. @@ -1274,15 +1281,16 @@ class AvgPool2dOpConversionPattern // Pad the input tensor if needed. 
Value paddedInput = input; if (hasPadding) { - paddedInput = rewriter.create(loc, paddedType, input, - lowPad, highPad, zeroVal); + paddedInput = tensor::PadOp::create(rewriter, loc, paddedType, input, + lowPad, highPad, zeroVal); } // Create the kernel tensor (shape only, values don't matter for pooling). auto linalgKernelType = RankedTensorType::get({kernelH, kernelW}, rewriter.getF32Type()); - Value kernelTensor = rewriter.create( - loc, linalgKernelType.getShape(), linalgKernelType.getElementType()); + Value kernelTensor = + tensor::EmptyOp::create(rewriter, loc, linalgKernelType.getShape(), + linalgKernelType.getElementType()); // Create strides and dilations attributes for linalg.pooling_nhwc_sum. auto linalgStridesAttr = DenseIntElementsAttr::get( @@ -1293,16 +1301,17 @@ class AvgPool2dOpConversionPattern ArrayRef{dilationH, dilationW}); // Create output tensor initialized to zero for sum accumulation. - Value sumOutputInit = rewriter.create( - loc, actualResultType.getShape(), elementType); + Value sumOutputInit = tensor::EmptyOp::create( + rewriter, loc, actualResultType.getShape(), elementType); Value sumOutput = - rewriter.create(loc, zeroVal, sumOutputInit) + linalg::FillOp::create(rewriter, loc, zeroVal, sumOutputInit) .getResult(0); // Perform sum pooling on input. - auto sumPoolOp = rewriter.create( - loc, TypeRange{actualResultType}, ValueRange{paddedInput, kernelTensor}, - ValueRange{sumOutput}, linalgStridesAttr, dilationsAttr); + auto sumPoolOp = linalg::PoolingNhwcSumOp::create( + rewriter, loc, TypeRange{actualResultType}, + ValueRange{paddedInput, kernelTensor}, ValueRange{sumOutput}, + linalgStridesAttr, dilationsAttr); Value sumResult = sumPoolOp.getResult(0); // Compute divisor tensor by sum-pooling a binary mask of valid positions. @@ -1314,11 +1323,11 @@ class AvgPool2dOpConversionPattern // valid elements fall into each sliding window. 
Value oneVal; if (isa(elementType)) { - oneVal = rewriter.create( - loc, rewriter.getFloatAttr(elementType, 1.0)); + oneVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(elementType, 1.0)); } else { - oneVal = rewriter.create( - loc, rewriter.getIntegerAttr(elementType, 1)); + oneVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(elementType, 1)); } SmallVector onesShape; @@ -1343,32 +1352,34 @@ class AvgPool2dOpConversionPattern } Value onesInit = - rewriter.create(loc, onesShape, elementType); + tensor::EmptyOp::create(rewriter, loc, onesShape, elementType); Value onesTensor = - rewriter.create(loc, oneVal, onesInit).getResult(0); + linalg::FillOp::create(rewriter, loc, oneVal, onesInit).getResult(0); // Pad ones tensor with zeros. - Value paddedOnes = rewriter.create( - loc, paddedType, onesTensor, onesLowPad, onesHighPad, zeroVal); + Value paddedOnes = + tensor::PadOp::create(rewriter, loc, paddedType, onesTensor, onesLowPad, + onesHighPad, zeroVal); // Sum-pool the mask to count valid elements per window. - Value countOutputInit = rewriter.create( - loc, actualResultType.getShape(), elementType); + Value countOutputInit = tensor::EmptyOp::create( + rewriter, loc, actualResultType.getShape(), elementType); Value countOutput = - rewriter.create(loc, zeroVal, countOutputInit) + linalg::FillOp::create(rewriter, loc, zeroVal, countOutputInit) .getResult(0); - auto countPoolOp = rewriter.create( - loc, TypeRange{actualResultType}, ValueRange{paddedOnes, kernelTensor}, - ValueRange{countOutput}, linalgStridesAttr, dilationsAttr); + auto countPoolOp = linalg::PoolingNhwcSumOp::create( + rewriter, loc, TypeRange{actualResultType}, + ValueRange{paddedOnes, kernelTensor}, ValueRange{countOutput}, + linalgStridesAttr, dilationsAttr); Value divisorTensor = countPoolOp.getResult(0); // Divide sum by divisor to get average. 
- Value avgOutputInit = rewriter.create( - loc, actualResultType.getShape(), elementType); - auto divOp = rewriter.create( - loc, actualResultType, ValueRange{sumResult, divisorTensor}, - avgOutputInit); + Value avgOutputInit = tensor::EmptyOp::create( + rewriter, loc, actualResultType.getShape(), elementType); + auto divOp = linalg::DivOp::create(rewriter, loc, actualResultType, + ValueRange{sumResult, divisorTensor}, + avgOutputInit); Value result = divOp.getResult(0); result = sliceResultToShape(result, resultType, rewriter, loc); @@ -1422,13 +1433,14 @@ class GlobalAvgPool2dOpConversionPattern resultType.getElementType()); auto heightAxisAttr = rewriter.getI32IntegerAttr(1); - auto heightReduceResult = rewriter.create( - loc, heightReduceType, input, heightAxisAttr); + auto heightReduceResult = tosa::ReduceSumOp::create( + rewriter, loc, heightReduceType, input, heightAxisAttr); // Then, reduce along width dimension (dim 2). auto widthAxisAttr = rewriter.getI32IntegerAttr(2); - auto widthReduceResult = rewriter.create( - loc, resultType, heightReduceResult.getResult(), widthAxisAttr); + auto widthReduceResult = tosa::ReduceSumOp::create( + rewriter, loc, resultType, heightReduceResult.getResult(), + widthAxisAttr); // Divide by the total number of spatial elements (H * W). double spatialCount = static_cast(inputHeight * inputWidth); @@ -1441,17 +1453,18 @@ class GlobalAvgPool2dOpConversionPattern DenseElementsAttr divisorAttr = DenseElementsAttr::get( divisorType, rewriter.getFloatAttr(elementType, 1.0 / spatialCount)); auto divisorConst = - rewriter.create(loc, divisorType, divisorAttr); + tosa::ConstOp::create(rewriter, loc, divisorType, divisorAttr); // Create shift tensor for tosa::MulOp (requires i8 tensor). 
auto shiftType = RankedTensorType::get({1}, rewriter.getI8Type()); auto shiftAttr = DenseElementsAttr::get(shiftType, rewriter.getI8IntegerAttr(0)); - Value shift = rewriter.create(loc, shiftType, shiftAttr); + Value shift = tosa::ConstOp::create(rewriter, loc, shiftType, shiftAttr); // Multiply by reciprocal to get average. - auto result = rewriter.create( - loc, resultType, widthReduceResult.getResult(), divisorConst, shift); + auto result = + tosa::MulOp::create(rewriter, loc, resultType, + widthReduceResult.getResult(), divisorConst, shift); rewriter.replaceOp(op, result.getResult()); return success(); @@ -1489,8 +1502,8 @@ class GatherOpConversionPattern : public OpConversionPattern { auto indexVectorDim = op.getIndexVectorDim(); // Create initial tensor for result - Value initTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType()); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, resultType.getShape(), resultType.getElementType()); // Build indexing maps for the generic op auto resultRank = resultType.getRank(); @@ -1509,15 +1522,16 @@ class GatherOpConversionPattern : public OpConversionPattern { resultRank, utils::IteratorType::parallel); // Create the indexing logic using linalg.generic - auto genericOp = rewriter.create( - loc, resultType, ValueRange{}, ValueRange{initTensor}, indexingMaps, - iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + auto genericOp = linalg::GenericOp::create( + rewriter, loc, resultType, ValueRange{}, ValueRange{initTensor}, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { // args[0] is the current value in the output tensor // Get the current output indices SmallVector outputIndices; for (int i = 0; i < resultRank; ++i) { - outputIndices.push_back(b.create(loc, i)); + outputIndices.push_back(linalg::IndexOp::create(b, loc, i)); } // Build the input indices for the gather @@ -1525,7 +1539,7 @@ class GatherOpConversionPattern : public 
OpConversionPattern { // Initialize all indices to zero first to avoid null values for (int64_t i = 0; i < inputType.getRank(); ++i) { - inputIndices[i] = b.create(loc, 0); + inputIndices[i] = arith::ConstantIndexOp::create(b, loc, 0); } // Determine which output dimensions are batch dimensions @@ -1561,7 +1575,7 @@ class GatherOpConversionPattern : public OpConversionPattern { if (d == indexVectorDim) { // This is the index vector dimension fullIndices.push_back( - b.create(loc, i)); + arith::ConstantIndexOp::create(b, loc, i)); } else { // This is a batch dimension if (static_cast(batchIdx) < batchDims.size()) { @@ -1574,24 +1588,25 @@ class GatherOpConversionPattern : public OpConversionPattern { // Extract the index value Value idxValue = - b.create(loc, startIndices, fullIndices); + tensor::ExtractOp::create(b, loc, startIndices, fullIndices); // Convert to index type if needed Value idx; if (idxValue.getType().isF32()) { // First convert f32 to i32 Value i32Val = - b.create(loc, b.getI32Type(), idxValue); + arith::FPToSIOp::create(b, loc, b.getI32Type(), idxValue); // Then convert i32 to index - idx = b.create(loc, b.getIndexType(), i32Val); + idx = + arith::IndexCastOp::create(b, loc, b.getIndexType(), i32Val); } else if (idxValue.getType().isInteger(32)) { // Direct cast from i32 to index - idx = - b.create(loc, b.getIndexType(), idxValue); + idx = arith::IndexCastOp::create(b, loc, b.getIndexType(), + idxValue); } else if (idxValue.getType().isInteger(64)) { // Direct cast from i64 to index - idx = - b.create(loc, b.getIndexType(), idxValue); + idx = arith::IndexCastOp::create(b, loc, b.getIndexType(), + idxValue); } else { // Already index type idx = idxValue; @@ -1628,9 +1643,9 @@ class GatherOpConversionPattern : public OpConversionPattern { // Extract the value from input tensor Value extracted = - b.create(loc, input, inputIndices); + tensor::ExtractOp::create(b, loc, input, inputIndices); - b.create(loc, extracted); + linalg::YieldOp::create(b, loc, 
extracted); }); rewriter.replaceOp(op, genericOp.getResult(0)); @@ -1662,8 +1677,8 @@ class EmbeddingOpConversionPattern int64_t resultRank = resultType.getRank(); // Create empty output tensor. - Value initTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType()); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, resultType.getShape(), resultType.getElementType()); // Input indexing map: project result dims to input dims. // result(d0, ..., d_{N-1}, d_N, ..., d_{N+E-1}) -> input(d0, ..., d_{N-1}) @@ -1685,8 +1700,8 @@ class EmbeddingOpConversionPattern SmallVector iteratorTypes( resultRank, utils::IteratorType::parallel); - auto genericOp = rewriter.create( - loc, resultType, ValueRange{input}, ValueRange{initTensor}, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, resultType, ValueRange{input}, ValueRange{initTensor}, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { // args[0] is the input element (index value) read via the affine map. @@ -1697,10 +1712,11 @@ class EmbeddingOpConversionPattern Value idx; if (idxValue.getType().isF32()) { Value i32Val = - b.create(loc, b.getI32Type(), idxValue); - idx = b.create(loc, b.getIndexType(), i32Val); + arith::FPToSIOp::create(b, loc, b.getI32Type(), idxValue); + idx = arith::IndexCastOp::create(b, loc, b.getIndexType(), i32Val); } else { - idx = b.create(loc, b.getIndexType(), idxValue); + idx = + arith::IndexCastOp::create(b, loc, b.getIndexType(), idxValue); } // Build weight indices: @@ -1709,17 +1725,17 @@ class EmbeddingOpConversionPattern // - Last dim is the last result iteration index. 
          SmallVector weightIndices;
          for (int64_t i = 0; i < weightRank - 2; ++i) {
-            weightIndices.push_back(b.create(loc, 0));
+            weightIndices.push_back(arith::ConstantIndexOp::create(b, loc, 0));
          }
          weightIndices.push_back(idx);
          weightIndices.push_back(
-              b.create(loc, resultRank - 1));
+              linalg::IndexOp::create(b, loc, resultRank - 1));

          // Extract the value from weight tensor.
          Value extracted =
-              b.create(loc, weight, weightIndices);
+              tensor::ExtractOp::create(b, loc, weight, weightIndices);

-          b.create(loc, extracted);
+          linalg::YieldOp::create(b, loc, extracted);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
@@ -1916,46 +1932,43 @@ class ArgMaxOpConversionPattern : public OpConversionPattern {
    auto maxIndicesType =
        RankedTensorType::get(reducedShape, rewriter.getI32Type());

-    // Initialize max values to -inf (float) or INT_MIN (integer), and max
-    // indices to 0.
-    Value initMax;
-    if (isa(elementType)) {
-      auto negInfAttr = rewriter.getFloatAttr(
-          elementType,
-          APFloat::getInf(cast(elementType).getFloatSemantics(),
-                          /*Negative=*/true));
-      initMax =
-          rewriter.create(loc, elementType, negInfAttr);
-    } else {
-      auto intType = cast(elementType);
-      unsigned bitWidth = intType.getWidth();
-      // For unsigned integers (including i1), use 0 as the initial minimum.
-      // For signed/signless integers, use the signed minimum value.
-      APInt minValue(bitWidth, /*val=*/0, /*isSigned=*/false);
-      if (!intType.isUnsignedInteger()) {
-        minValue = APInt::getSignedMinValue(bitWidth);
-      }
-      auto minAttr = rewriter.getIntegerAttr(elementType, minValue);
-      initMax = rewriter.create(loc, elementType, minAttr);
-    }
-    Value zero = rewriter.create(
-        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
+    // Initialize max values to -inf (float) or INT_MIN (integer), and max
+    // indices to 0.
+    Value initMax;
+    if (isa(elementType)) {
+      auto negInfAttr = rewriter.getFloatAttr(
+          elementType,
+          APFloat::getInf(cast(elementType).getFloatSemantics(),
+                          /*Negative=*/true));
+      initMax =
+          arith::ConstantOp::create(rewriter, loc, elementType, negInfAttr);
+    } else {
+      auto intType = cast(elementType);
+      unsigned bitWidth = intType.getWidth();
+      // For unsigned integers (including i1), use 0 as the initial minimum.
+      // For signed/signless integers, use the signed minimum value.
+      APInt minValue(bitWidth, /*val=*/0, /*isSigned=*/false);
+      if (!intType.isUnsignedInteger()) {
+        minValue = APInt::getSignedMinValue(bitWidth);
+      }
+      auto minAttr = rewriter.getIntegerAttr(elementType, minValue);
+      initMax = arith::ConstantOp::create(rewriter, loc, elementType, minAttr);
+    }
+    Value zero = arith::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+                                           rewriter.getI32IntegerAttr(0));

    Value maxValuesFilled =
-        rewriter
-            .create(
-                loc, initMax,
-                rewriter.create(loc, reducedShape, elementType)
-                    .getResult())
+        linalg::FillOp::create(
+            rewriter, loc, initMax,
+            tensor::EmptyOp::create(rewriter, loc, reducedShape, elementType)
+                .getResult())
            .getResult(0);
    Value maxIndicesFilled =
-        rewriter
-            .create(
-                loc, zero,
-                rewriter
-                    .create(loc, reducedShape,
-                                             rewriter.getI32Type())
-                    .getResult())
+        linalg::FillOp::create(rewriter, loc, zero,
+                               tensor::EmptyOp::create(rewriter, loc,
+                                                       reducedShape,
+                                                       rewriter.getI32Type())
+                                   .getResult())
            .getResult(0);

    // Indexing maps: identity for input, projection for outputs.
@@ -1964,9 +1962,9 @@ class ArgMaxOpConversionPattern : public OpConversionPattern {
    AffineMap outputMap =
        AffineMap::get(rank, 0, outputExprs, rewriter.getContext());

-    auto genericOp = rewriter.create(
-        loc, TypeRange{maxValuesType, maxIndicesType}, ValueRange{input},
-        ValueRange{maxValuesFilled, maxIndicesFilled},
+    auto genericOp = linalg::GenericOp::create(
+        rewriter, loc, TypeRange{maxValuesType, maxIndicesType},
+        ValueRange{input}, ValueRange{maxValuesFilled, maxIndicesFilled},
        SmallVector{inputMap, outputMap, outputMap}, iteratorTypes,
        [&](OpBuilder &b, Location loc, ValueRange args) {
          Value currentVal = args[0];
@@ -1976,38 +1974,39 @@ class ArgMaxOpConversionPattern : public OpConversionPattern {
          // Compute linearized index across reduce dimensions (row-major).
          Value linearIdx = nullptr;
          for (int64_t d : reduceDims) {
-            Value idx = b.create(
-                loc, b.getI32Type(), b.create(loc, d));
+            Value idx = arith::IndexCastOp::create(
+                b, loc, b.getI32Type(), linalg::IndexOp::create(b, loc, d));
            if (!linearIdx) {
              linearIdx = idx;
            } else {
-              Value dimSize = b.create(
-                  loc, b.getI32Type(),
+              Value dimSize = arith::ConstantOp::create(
+                  b, loc, b.getI32Type(),
                  b.getI32IntegerAttr(inputType.getShape()[d]));
-              linearIdx = b.create(
-                  loc, b.create(loc, linearIdx, dimSize), idx);
+              linearIdx = arith::AddIOp::create(
+                  b, loc, arith::MulIOp::create(b, loc, linearIdx, dimSize),
+                  idx);
            }
          }

-          Value isGreater;
-          if (isa(elementType)) {
-            isGreater = b.create(loc, arith::CmpFPredicate::OGT,
-                                               currentVal, currentMax)
-                            .getResult();
-          } else {
-            auto intType = cast(elementType);
-            arith::CmpIPredicate pred = intType.isUnsignedInteger()
-                                            ? arith::CmpIPredicate::ugt
-                                            : arith::CmpIPredicate::sgt;
-            isGreater =
-                b.create(loc, pred, currentVal, currentMax)
-                    .getResult();
-          }
-          Value newMax =
-              b.create(loc, isGreater, currentVal, currentMax);
+          Value isGreater;
+          if (isa(elementType)) {
+            isGreater = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT,
+                                              currentVal, currentMax)
+                            .getResult();
+          } else {
+            auto intType = cast(elementType);
+            arith::CmpIPredicate pred = intType.isUnsignedInteger()
+                                            ? arith::CmpIPredicate::ugt
+                                            : arith::CmpIPredicate::sgt;
+            isGreater =
+                arith::CmpIOp::create(b, loc, pred, currentVal, currentMax)
+                    .getResult();
+          }
+          Value newMax = arith::SelectOp::create(b, loc, isGreater, currentVal,
+                                                 currentMax);
          Value newIdx =
-              b.create(loc, isGreater, linearIdx, currentIdx);
-          b.create(loc, ValueRange{newMax, newIdx});
+              arith::SelectOp::create(b, loc, isGreater, linearIdx, currentIdx);
+          linalg::YieldOp::create(b, loc, ValueRange{newMax, newIdx});
        });

    Value result = genericOp.getResult(1);
@@ -2021,10 +2008,10 @@ class ArgMaxOpConversionPattern : public OpConversionPattern {
    }
    auto shapeType =
        tosa::shapeType::get(rewriter.getContext(), keepDimShape.size());
-    auto shapeOp = rewriter.create(
-        loc, shapeType, rewriter.getIndexTensorAttr(keepDimShape));
+    auto shapeOp = tosa::ConstShapeOp::create(
+        rewriter, loc, shapeType, rewriter.getIndexTensorAttr(keepDimShape));
    result =
-        rewriter.create(loc, resultType, result,
shapeOp); + tosa::ReshapeOp::create(rewriter, loc, resultType, result, shapeOp); } rewriter.replaceOp(op, result); @@ -2072,7 +2059,7 @@ class CumSumOpConversionPattern : public OpConversionPattern { "Unsupported element type for cumsum"); } Value output = - rewriter.create(loc, resultType, zeroAttr); + arith::ConstantOp::create(rewriter, loc, resultType, zeroAttr); // Compute the slice type (same shape but with dim size = 1). SmallVector sliceShape(inputType.getShape()); @@ -2082,7 +2069,7 @@ class CumSumOpConversionPattern : public OpConversionPattern { // Create a zero-filled tensor for the running sum accumulator. DenseElementsAttr zeroSliceAttr = createDenseElementsAttr(sliceType, 0.0); Value runningSum = - rewriter.create(loc, sliceType, zeroSliceAttr); + arith::ConstantOp::create(rewriter, loc, sliceType, zeroSliceAttr); // Build the static sizes and strides for slice operations. SmallVector staticSizes; @@ -2104,20 +2091,21 @@ class CumSumOpConversionPattern : public OpConversionPattern { offsets[dim] = rewriter.getIndexAttr(idx); // Extract the current slice from input. - Value inputSlice = rewriter.create( - loc, sliceType, input, offsets, staticSizes, staticStrides); + Value inputSlice = tensor::ExtractSliceOp::create( + rewriter, loc, sliceType, input, offsets, staticSizes, staticStrides); // Add current input slice to running sum. auto emptySlice = - rewriter.create(loc, sliceShape, elementType); - auto addOp = rewriter.create( - loc, sliceType, ValueRange{runningSum, inputSlice}, - emptySlice.getResult()); + tensor::EmptyOp::create(rewriter, loc, sliceShape, elementType); + auto addOp = linalg::AddOp::create(rewriter, loc, sliceType, + ValueRange{runningSum, inputSlice}, + emptySlice.getResult()); runningSum = addOp.getResult(0); // Insert the new sum into the output tensor at the current position. 
- output = rewriter.create( - loc, runningSum, output, offsets, staticSizes, staticStrides); + output = + tensor::InsertSliceOp::create(rewriter, loc, runningSum, output, + offsets, staticSizes, staticStrides); } rewriter.replaceOp(op, output); @@ -2167,8 +2155,8 @@ class ConcatenateHeadsOpConversionPattern inputShape[1], inputShape[3]}; auto transposedType = RankedTensorType::get(transposedShape, elementType); - auto transposeOp = rewriter.create(loc, transposedType, - input, permutation); + auto transposeOp = tosa::TransposeOp::create(rewriter, loc, transposedType, + input, permutation); // Step 2: Reshape to [batch_size, sequence_size, num_heads * head_size] ArrayRef outputShape = resultType.getShape(); @@ -2176,10 +2164,10 @@ class ConcatenateHeadsOpConversionPattern auto shapeType = tosa::shapeType::get(rewriter.getContext(), outputShape.size()); auto attr = rewriter.getIndexTensorAttr(newShapeValues); - auto shapeOp = rewriter.create(loc, shapeType, attr); + auto shapeOp = tosa::ConstShapeOp::create(rewriter, loc, shapeType, attr); - auto reshapeOp = - rewriter.create(loc, resultType, transposeOp, shapeOp); + auto reshapeOp = tosa::ReshapeOp::create(rewriter, loc, resultType, + transposeOp, shapeOp); rewriter.replaceOp(op, reshapeOp); return success(); @@ -2234,32 +2222,33 @@ class SoftmaxOpConversionPattern : public OpConversionPattern { // Step 1: Compute max along dimension for numerical stability. auto axisAttr = rewriter.getI32IntegerAttr(dim); Value maxVal = - rewriter.create(loc, reducedType, input, axisAttr); + tosa::ReduceMaxOp::create(rewriter, loc, reducedType, input, axisAttr); // Step 2: Subtract max from input (input - max). // tosa::SubOp handles broadcasting automatically. - Value shifted = rewriter.create(loc, inputType, input, maxVal); + Value shifted = + tosa::SubOp::create(rewriter, loc, inputType, input, maxVal); // Step 3: Compute exp(shifted). 
- Value expVals = rewriter.create(loc, inputType, shifted); + Value expVals = tosa::ExpOp::create(rewriter, loc, inputType, shifted); // Step 4: Compute sum of exp along dimension. - Value sumExp = - rewriter.create(loc, reducedType, expVals, axisAttr); + Value sumExp = tosa::ReduceSumOp::create(rewriter, loc, reducedType, + expVals, axisAttr); // Step 5: Divide exp by sum (exp / sum). // Use reciprocal and multiply with broadcasting. Value reciprocal = - rewriter.create(loc, reducedType, sumExp); + tosa::ReciprocalOp::create(rewriter, loc, reducedType, sumExp); // tosa::MulOp requires a shift tensor (0 for float ops). auto shiftType = RankedTensorType::get({1}, rewriter.getI8Type()); auto shiftAttr = DenseElementsAttr::get(shiftType, rewriter.getI8IntegerAttr(0)); - Value shift = rewriter.create(loc, shiftType, shiftAttr); + Value shift = tosa::ConstOp::create(rewriter, loc, shiftType, shiftAttr); - Value result = rewriter.create(loc, resultType, expVals, - reciprocal, shift); + Value result = tosa::MulOp::create(rewriter, loc, resultType, expVals, + reciprocal, shift); rewriter.replaceOp(op, result); return success(); @@ -2297,8 +2286,9 @@ class PermuteOpConversionPattern : public OpConversionPattern { Value input = adaptor.getInput(); llvm::ArrayRef permutation = op.getPermutation(); - auto output = rewriter.create( - op.getLoc(), resultType.getShape(), resultType.getElementType()); + auto output = + tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), + resultType.getElementType()); rewriter.replaceOpWithNewOp( op, input, output.getResult(), permutation); @@ -2359,8 +2349,8 @@ class SliceStaticOpConversionPattern } // Create the extract_slice operation - Value extractedSlice = rewriter.create( - op.getLoc(), resultType, input, offsets, sizes, strides); + Value extractedSlice = tensor::ExtractSliceOp::create( + rewriter, op.getLoc(), resultType, input, offsets, sizes, strides); rewriter.replaceOp(op, extractedSlice); @@ -2412,11 +2402,11 @@ 
class PadOpConversionPattern : public OpConversionPattern { Type elementType = inputType.getElementType(); Value padConstant; if (isa(elementType)) { - padConstant = rewriter.create( - op.getLoc(), rewriter.getFloatAttr(elementType, padValue)); + padConstant = arith::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getFloatAttr(elementType, padValue)); } else { - padConstant = rewriter.create( - op.getLoc(), + padConstant = arith::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getIntegerAttr(elementType, static_cast(padValue))); } @@ -2473,8 +2463,8 @@ class ConstantOpConversionPattern op, "Expected DenseElementsAttr or DenseResourceElementsAttr"); } - auto newConstant = rewriter.create( - op.getLoc(), resultType, convertedValue); + auto newConstant = arith::ConstantOp::create(rewriter, op.getLoc(), + resultType, convertedValue); rewriter.replaceOp(op, newConstant.getResult()); return success(); @@ -2507,7 +2497,7 @@ class NamedFillOpConversionPattern : public OpConversionPattern { } auto constOp = - rewriter.create(op.getLoc(), resultType, fillAttr); + arith::ConstantOp::create(rewriter, op.getLoc(), resultType, fillAttr); rewriter.replaceOp(op, constOp.getResult()); return success(); @@ -2579,42 +2569,44 @@ class ArangeOpConversionPattern : public OpConversionPattern { int64_t step = adaptor.getStep(); Type elementType = resultType.getElementType(); - Value initTensor = rewriter.create( - loc, resultType.getShape(), elementType); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, resultType.getShape(), elementType); AffineMap outputMap = rewriter.getDimIdentityMap(); SmallVector iteratorTypes = { utils::IteratorType::parallel}; - auto genericOp = rewriter.create( - loc, resultType, ValueRange{}, ValueRange{initTensor}, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, resultType, ValueRange{}, ValueRange{initTensor}, SmallVector{outputMap}, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value idx = 
b.create(loc, 0); + Value idx = linalg::IndexOp::create(b, loc, 0); // Compute: start + idx * step Value result; if (isa(elementType)) { Value idxFloat = - b.create(loc, b.getI64Type(), idx); - Value idxFP = b.create(loc, elementType, idxFloat); - Value startVal = b.create( - loc, b.getFloatAttr(elementType, static_cast(start))); - Value stepVal = b.create( - loc, b.getFloatAttr(elementType, static_cast(step))); - Value scaled = b.create(loc, idxFP, stepVal); - result = b.create(loc, startVal, scaled); + arith::IndexCastOp::create(b, loc, b.getI64Type(), idx); + Value idxFP = + arith::SIToFPOp::create(b, loc, elementType, idxFloat); + Value startVal = arith::ConstantOp::create( + b, loc, + b.getFloatAttr(elementType, static_cast(start))); + Value stepVal = arith::ConstantOp::create( + b, loc, b.getFloatAttr(elementType, static_cast(step))); + Value scaled = arith::MulFOp::create(b, loc, idxFP, stepVal); + result = arith::AddFOp::create(b, loc, startVal, scaled); } else { - Value idxInt = b.create(loc, elementType, idx); - Value startVal = b.create( - loc, b.getIntegerAttr(elementType, start)); - Value stepVal = b.create( - loc, b.getIntegerAttr(elementType, step)); - Value scaled = b.create(loc, idxInt, stepVal); - result = b.create(loc, startVal, scaled); + Value idxInt = arith::IndexCastOp::create(b, loc, elementType, idx); + Value startVal = arith::ConstantOp::create( + b, loc, b.getIntegerAttr(elementType, start)); + Value stepVal = arith::ConstantOp::create( + b, loc, b.getIntegerAttr(elementType, step)); + Value scaled = arith::MulIOp::create(b, loc, idxInt, stepVal); + result = arith::AddIOp::create(b, loc, startVal, scaled); } - b.create(loc, result); + linalg::YieldOp::create(b, loc, result); }); rewriter.replaceOp(op, genericOp.getResult(0)); @@ -2656,11 +2648,11 @@ class MeanOpConversionPattern : public OpConversionPattern { const bool useUnsignedCast = intType.isUnsigned() || intType.getWidth() == 1; if (useUnsignedCast) { - input = 
rewriter.create(op.getLoc(), floatInputType, - input); + input = arith::UIToFPOp::create(rewriter, op.getLoc(), floatInputType, + input); } else { - input = rewriter.create(op.getLoc(), floatInputType, - input); + input = arith::SIToFPOp::create(rewriter, op.getLoc(), floatInputType, + input); } inputType = cast(input.getType()); } @@ -2689,13 +2681,15 @@ class MeanOpConversionPattern : public OpConversionPattern { } auto divisor = - rewriter.create(op.getLoc(), resultType, divisorAttr); + tosa::ConstOp::create(rewriter, op.getLoc(), resultType, divisorAttr); - auto output = rewriter.create( - op.getLoc(), resultType.getShape(), resultType.getElementType()); + auto output = + tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), + resultType.getElementType()); - auto divOp = rewriter.create( - op.getLoc(), resultType, ValueRange{sum, divisor}, output.getResult()); + auto divOp = + linalg::DivOp::create(rewriter, op.getLoc(), resultType, + ValueRange{sum, divisor}, output.getResult()); rewriter.replaceOp(op, divOp.getResult(0)); return success(); @@ -2777,26 +2771,26 @@ class LayerNormOpConversionPattern op, "Unsupported element type for layer norm"); } Value numElementsConst = - rewriter.create(loc, reducedType, numElementsAttr); - Value reciprocalN = - rewriter.create(loc, reducedType, numElementsConst); + tosa::ConstOp::create(rewriter, loc, reducedType, numElementsAttr); + Value reciprocalN = tosa::ReciprocalOp::create(rewriter, loc, reducedType, + numElementsConst); // mean = sum * (1/N) auto shiftType = RankedTensorType::get({1}, rewriter.getI8Type()); auto shiftAttr = DenseElementsAttr::get(shiftType, rewriter.getI8IntegerAttr(0)); - Value shift = rewriter.create(loc, shiftType, shiftAttr); + Value shift = tosa::ConstOp::create(rewriter, loc, shiftType, shiftAttr); - Value mean = - rewriter.create(loc, reducedType, sum, reciprocalN, shift); + Value mean = tosa::MulOp::create(rewriter, loc, reducedType, sum, + reciprocalN, shift); // Step 2: 
centered = input - mean (tosa broadcasts automatically). - Value centered = rewriter.create(loc, inputType, input, mean); + Value centered = tosa::SubOp::create(rewriter, loc, inputType, input, mean); // Step 3: Compute variance = mean(centered^2). // First compute centered^2. - Value centeredSquared = - rewriter.create(loc, inputType, centered, centered, shift); + Value centeredSquared = tosa::MulOp::create(rewriter, loc, inputType, + centered, centered, shift); // Sum of squared differences. Value sumSquared = createReductionOpChain( @@ -2804,25 +2798,25 @@ class LayerNormOpConversionPattern rewriter); // variance = sumSquared * (1/N) - Value variance = rewriter.create(loc, reducedType, sumSquared, - reciprocalN, shift); + Value variance = tosa::MulOp::create(rewriter, loc, reducedType, sumSquared, + reciprocalN, shift); // Step 4: Add epsilon for numerical stability. float epsilon = op.getEpsilon().convertToFloat(); DenseElementsAttr epsilonAttr = createDenseElementsAttr(reducedType, static_cast(epsilon)); Value epsilonConst = - rewriter.create(loc, reducedType, epsilonAttr); + tosa::ConstOp::create(rewriter, loc, reducedType, epsilonAttr); Value variancePlusEps = - rewriter.create(loc, reducedType, variance, epsilonConst); + tosa::AddOp::create(rewriter, loc, reducedType, variance, epsilonConst); // Step 5: inv_std = rsqrt(variance + epsilon). Value invStd = - rewriter.create(loc, reducedType, variancePlusEps); + tosa::RsqrtOp::create(rewriter, loc, reducedType, variancePlusEps); // Step 6: normalized = centered * inv_std (tosa broadcasts automatically). Value normalized = - rewriter.create(loc, resultType, centered, invStd, shift); + tosa::MulOp::create(rewriter, loc, resultType, centered, invStd, shift); // Step 7: Apply weight (gamma) if present. // Weight and bias need to be reshaped to match the input rank for TOSA ops. 
@@ -2832,8 +2826,8 @@ class LayerNormOpConversionPattern if (adaptor.getWeight()) { Value reshapedWeight = reshapeByPrependingOnes( adaptor.getWeight(), rank, numNormDims, elementType, loc, rewriter); - result = rewriter.create(loc, resultType, result, - reshapedWeight, shift); + result = tosa::MulOp::create(rewriter, loc, resultType, result, + reshapedWeight, shift); } // Step 8: Apply bias (beta) if present. @@ -2841,7 +2835,7 @@ class LayerNormOpConversionPattern Value reshapedBias = reshapeByPrependingOnes( adaptor.getBias(), rank, numNormDims, elementType, loc, rewriter); result = - rewriter.create(loc, resultType, result, reshapedBias); + tosa::AddOp::create(rewriter, loc, resultType, result, reshapedBias); } rewriter.replaceOp(op, result); @@ -2882,10 +2876,10 @@ class SqueezeOpConversionPattern : public OpConversionPattern { tosa::shapeType::get(rewriter.getContext(), newShape.size()); auto attr = rewriter.getIndexTensorAttr(newShape); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); - auto reshapeOp = rewriter.create(op.getLoc(), resultType, - input, shapeOp); + auto reshapeOp = tosa::ReshapeOp::create(rewriter, op.getLoc(), resultType, + input, shapeOp); rewriter.replaceOp(op, reshapeOp); @@ -2918,7 +2912,7 @@ class UnsqueezeOpConversionPattern tosa::shapeType::get(rewriter.getContext(), newShape.size()); auto attr = rewriter.getIndexTensorAttr(newShape); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); rewriter.replaceOpWithNewOp(op, resultType, input, shapeOp); @@ -3042,8 +3036,8 @@ class LinearOpConversionPattern : public OpConversionPattern { permutation.push_back(static_cast(lhsShape.size() - 1)); permutation.push_back(static_cast(lhsShape.size() - 2)); - lhs = rewriter.create(op.getLoc(), transposedType, - lhs, permutation); + lhs = tosa::TransposeOp::create(rewriter, op.getLoc(), 
transposedType, + lhs, permutation); lhsType = transposedType; } } @@ -3068,8 +3062,8 @@ class LinearOpConversionPattern : public OpConversionPattern { permutation.push_back(static_cast(rhsShape.size() - 1)); permutation.push_back(static_cast(rhsShape.size() - 2)); - rhs = rewriter.create(op.getLoc(), transposedType, - rhs, permutation); + rhs = tosa::TransposeOp::create(rewriter, op.getLoc(), transposedType, + rhs, permutation); rhsType = transposedType; } } @@ -3092,10 +3086,10 @@ class LinearOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), 3); auto attr = rewriter.getIndexTensorAttr(newShape); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); - lhs3D = rewriter.create(op.getLoc(), newType, lhs, - shapeOp.getResult()); + lhs3D = tosa::ReshapeOp::create(rewriter, op.getLoc(), newType, lhs, + shapeOp.getResult()); lhs3DType = newType; } else if (lhsRank > 3) { // Check for dynamic dimensions in batch dimensions. 
@@ -3119,10 +3113,10 @@ class LinearOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), 3); auto attr = rewriter.getIndexTensorAttr(newShape); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); - lhs3D = rewriter.create(op.getLoc(), newType, lhs, - shapeOp.getResult()); + lhs3D = tosa::ReshapeOp::create(rewriter, op.getLoc(), newType, lhs, + shapeOp.getResult()); lhs3DType = newType; } @@ -3135,10 +3129,10 @@ class LinearOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), 3); auto attr = rewriter.getIndexTensorAttr(newShape); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); - rhs3D = rewriter.create(op.getLoc(), newType, rhs, - shapeOp.getResult()); + rhs3D = tosa::ReshapeOp::create(rewriter, op.getLoc(), newType, rhs, + shapeOp.getResult()); rhs3DType = newType; } else if (rhsRank > 3) { // Check for dynamic dimensions in batch dimensions. 
@@ -3162,10 +3156,10 @@ class LinearOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), 3); auto attr = rewriter.getIndexTensorAttr(newShape); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); - rhs3D = rewriter.create(op.getLoc(), newType, rhs, - shapeOp.getResult()); + rhs3D = tosa::ReshapeOp::create(rewriter, op.getLoc(), newType, rhs, + shapeOp.getResult()); rhs3DType = newType; } @@ -3180,11 +3174,11 @@ class LinearOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), 3); auto multiplesAttr = rewriter.getIndexTensorAttr(multiples); - auto multiplesOp = rewriter.create( - op.getLoc(), shapeType, multiplesAttr); + auto multiplesOp = tosa::ConstShapeOp::create(rewriter, op.getLoc(), + shapeType, multiplesAttr); - lhs3D = rewriter.create(op.getLoc(), newType, lhs3D, - multiplesOp); + lhs3D = tosa::TileOp::create(rewriter, op.getLoc(), newType, lhs3D, + multiplesOp); lhs3DType = cast(lhs3D.getType()); } else if (rhs3DType.getShape()[0] == 1 && lhs3DType.getShape()[0] > 1) { SmallVector multiples = {lhs3DType.getShape()[0], 1, 1}; @@ -3195,11 +3189,11 @@ class LinearOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), 3); auto multiplesAttr = rewriter.getIndexTensorAttr(multiples); - auto multiplesOp = rewriter.create( - op.getLoc(), shapeType, multiplesAttr); + auto multiplesOp = tosa::ConstShapeOp::create(rewriter, op.getLoc(), + shapeType, multiplesAttr); - rhs3D = rewriter.create(op.getLoc(), newType, rhs3D, - multiplesOp); + rhs3D = tosa::TileOp::create(rewriter, op.getLoc(), newType, rhs3D, + multiplesOp); rhs3DType = cast(rhs3D.getType()); } } @@ -3210,8 +3204,8 @@ class LinearOpConversionPattern : public OpConversionPattern { rhs3DType.getShape()[2]}, resultType.getElementType()); - Value matmulResult 
= rewriter.create( - op.getLoc(), matmulResultType, lhs3D, rhs3D); + Value matmulResult = tosa::MatMulOp::create(rewriter, op.getLoc(), + matmulResultType, lhs3D, rhs3D); // Reshape result back to original rank if needed if (resultType.getRank() != matmulResultType.getRank()) { @@ -3223,10 +3217,10 @@ class LinearOpConversionPattern : public OpConversionPattern { } auto attr = rewriter.getIndexTensorAttr(shapeValues); auto shapeOp = - rewriter.create(op.getLoc(), shapeType, attr); + tosa::ConstShapeOp::create(rewriter, op.getLoc(), shapeType, attr); - matmulResult = rewriter.create( - op.getLoc(), resultType, matmulResult, shapeOp.getResult()); + matmulResult = tosa::ReshapeOp::create(rewriter, op.getLoc(), resultType, + matmulResult, shapeOp.getResult()); } // If bias is provided, add it to the result @@ -3249,14 +3243,14 @@ class LinearOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), newBiasShape.size()); auto shapeAttr = rewriter.getIndexTensorAttr(newBiasShape); - auto shapeOp = rewriter.create( - op.getLoc(), shapeType, shapeAttr); - bias = rewriter.create(op.getLoc(), reshapedBiasType, - bias, shapeOp.getResult()); + auto shapeOp = tosa::ConstShapeOp::create(rewriter, op.getLoc(), + shapeType, shapeAttr); + bias = tosa::ReshapeOp::create(rewriter, op.getLoc(), reshapedBiasType, + bias, shapeOp.getResult()); } - matmulResult = rewriter.create(op.getLoc(), resultType, - matmulResult, bias); + matmulResult = tosa::AddOp::create(rewriter, op.getLoc(), resultType, + matmulResult, bias); } rewriter.replaceOp(op, matmulResult); @@ -3287,8 +3281,8 @@ class RepeatOpConversionPattern : public OpConversionPattern { auto shapeType = tosa::shapeType::get(rewriter.getContext(), multiples.size()); auto multiplesAttr = rewriter.getIndexTensorAttr(multiples); - auto multiplesOp = rewriter.create( - op.getLoc(), shapeType, multiplesAttr); + auto multiplesOp = tosa::ConstShapeOp::create(rewriter, op.getLoc(), + 
shapeType, multiplesAttr); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getInput(), multiplesOp); diff --git a/lib/Conversion/TTIRToLinalg/Utils.cpp b/lib/Conversion/TTIRToLinalg/Utils.cpp index 4de9cfdf1c7..dde23e2980c 100644 --- a/lib/Conversion/TTIRToLinalg/Utils.cpp +++ b/lib/Conversion/TTIRToLinalg/Utils.cpp @@ -85,8 +85,8 @@ Value broadcastToShape(Value input, ArrayRef targetShape, Location loc, return input; } - auto initTensor = rewriter.create(loc, targetShape, - inputType.getElementType()); + auto initTensor = ttir::EmptyOp::create(rewriter, loc, targetShape, + inputType.getElementType()); // When all dims need broadcasting (e.g. [1,1] -> [64,128]), extract the // scalar element and use linalg.fill instead of linalg.broadcast, since @@ -94,10 +94,10 @@ Value broadcastToShape(Value input, ArrayRef targetShape, Location loc, // tensor) and tensor.collapse_shape with empty reassociation is invalid. if (broadcastDims.size() == targetShape.size()) { SmallVector zeroIndices( - inputShape.size(), rewriter.create(loc, 0)); - Value scalar = rewriter.create(loc, input, zeroIndices); + inputShape.size(), arith::ConstantIndexOp::create(rewriter, loc, 0)); + Value scalar = tensor::ExtractOp::create(rewriter, loc, input, zeroIndices); auto fillOp = - rewriter.create(loc, scalar, initTensor.getResult()); + linalg::FillOp::create(rewriter, loc, scalar, initTensor.getResult()); return fillOp.getResult(0); } @@ -107,12 +107,12 @@ Value broadcastToShape(Value input, ArrayRef targetShape, Location loc, SmallVector, 2> collapseDimGroups = getCollapseDims(inputShape, targetShape); if (collapseDimGroups.size() != inputShape.size()) { - broadcastInput = - rewriter.create(loc, input, collapseDimGroups); + broadcastInput = tensor::CollapseShapeOp::create(rewriter, loc, input, + collapseDimGroups); } - auto broadcastOp = rewriter.create( - loc, broadcastInput, initTensor.getResult(), broadcastDims); + auto broadcastOp = linalg::BroadcastOp::create( + rewriter, loc, 
broadcastInput, initTensor.getResult(), broadcastDims); return broadcastOp.getResults().front(); } @@ -137,17 +137,17 @@ Value convertToBooleanTensor(Value input, Location loc, SmallVector zeroShape(inputType.getRank(), 1); auto zeroType = RankedTensorType::get(zeroShape, elementType); auto zeroAttr = createDenseElementsAttr(zeroType, 0.0); - auto zeroConst = rewriter.create(loc, zeroType, zeroAttr); + auto zeroConst = tosa::ConstOp::create(rewriter, loc, zeroType, zeroAttr); // For logical operations, non-zero means true. // So we need: (input != 0) which we get by computing !(input == 0). auto boolType = RankedTensorType::get(inputType.getShape(), rewriter.getIntegerType(1)); auto equalZero = - rewriter.create(loc, boolType, input, zeroConst); + tosa::EqualOp::create(rewriter, loc, boolType, input, zeroConst); // Then use LogicalNotOp to invert it, giving us (input != 0). auto notEqualZero = - rewriter.create(loc, boolType, equalZero); + tosa::LogicalNotOp::create(rewriter, loc, boolType, equalZero); return notEqualZero; } @@ -170,13 +170,13 @@ Value createTosaConst(ConversionPatternRewriter &rewriter, Location loc, SmallVector shape(rank, 1); auto type = RankedTensorType::get(shape, elementType); auto attr = createDenseElementsAttr(type, value); - return rewriter.create(loc, type, attr); + return tosa::ConstOp::create(rewriter, loc, type, attr); } Value createTosaMulShift(ConversionPatternRewriter &rewriter, Location loc) { auto type = RankedTensorType::get({1}, rewriter.getI8Type()); auto attr = DenseElementsAttr::get(type, rewriter.getI8IntegerAttr(0)); - return rewriter.create(loc, type, attr); + return tosa::ConstOp::create(rewriter, loc, type, attr); } } // namespace mlir::tt::ttir_to_linalg diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index c16a8f6a4ae..e7291fec023 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ 
b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -67,8 +67,8 @@ struct IndexToSliceConversionPattern } } - auto newOp = rewriter.create( - op.getLoc(), getTypeConverter()->convertType(op.getType()), + auto newOp = ttir::SliceStaticOp::create( + rewriter, op.getLoc(), getTypeConverter()->convertType(op.getType()), adaptor.getInput(), rewriter.getArrayAttr(begins), rewriter.getArrayAttr(ends), rewriter.getArrayAttr(steps)); @@ -114,8 +114,8 @@ struct ReverseOpConversionPattern // Step 1: Permute reversing dims to front. auto permutedShape = ttmlir::utils::applyPermutation(shape, permutation); - current = rewriter.create( - loc, + current = ttir::PermuteOp::create( + rewriter, loc, RankedTensorType::get(permutedShape, inputType.getElementType(), inputType.getEncoding()), current, permutation); @@ -134,7 +134,7 @@ struct ReverseOpConversionPattern auto flatType = RankedTensorType::get(flatShape, inputType.getElementType(), inputType.getEncoding()); current = - rewriter.create(loc, flatType, current, shapeAttr); + ttir::ReshapeOp::create(rewriter, loc, flatType, current, shapeAttr); // Step 3: Create reversed linear indices [N-1, N-2, ..., 0]. SmallVector indices(nReversing); @@ -144,25 +144,25 @@ struct ReverseOpConversionPattern auto idxType = RankedTensorType::get( {nReversing}, rewriter.getIntegerType(32, /*isSigned=*/true)); auto idxAttr = DenseIntElementsAttr::get(idxType, indices); - Value idxConst = rewriter.create(loc, idxType, idxAttr); + Value idxConst = ttir::ConstantOp::create(rewriter, loc, idxType, idxAttr); // Step 4: EmbeddingOp to reorder rows. current = - rewriter.create(loc, flatType, idxConst, current); + ttir::EmbeddingOp::create(rewriter, loc, flatType, idxConst, current); // Step 5: Reshape back to permuted shape. 
auto permShapeAttr = rewriter.getI32ArrayAttr( SmallVector(permutedShape.begin(), permutedShape.end())); auto permType = RankedTensorType::get( permutedShape, inputType.getElementType(), inputType.getEncoding()); - current = - rewriter.create(loc, permType, current, permShapeAttr); + current = ttir::ReshapeOp::create(rewriter, loc, permType, current, + permShapeAttr); // Step 6: Inverse permute back to original shape. SmallVector invPerm = ttmlir::utils::inversePermutation(permutation); - current = rewriter.create( - loc, + current = ttir::PermuteOp::create( + rewriter, loc, RankedTensorType::get(shape, inputType.getElementType(), inputType.getEncoding()), current, invPerm); @@ -408,8 +408,8 @@ struct GatherToEmbeddingConversionPattern auto embeddingOutputType = mlir::RankedTensorType::get( newOutputShape, input.getType().getElementType(), input.getType().getEncoding()); - ttir::EmbeddingOp embeddingOp = rewriter.create( - op.getLoc(), embeddingOutputType, startIndices, input); + ttir::EmbeddingOp embeddingOp = ttir::EmbeddingOp::create( + rewriter, op.getLoc(), embeddingOutputType, startIndices, input); rewriter.replaceOp(op, reshapeAndPermuteOutput(rewriter, embeddingOp, startIndexMap[0], op)); @@ -436,8 +436,8 @@ struct GatherToEmbeddingConversionPattern })); auto permutedInputShape = ttmlir::utils::applyPermutation(inputType.getShape(), inputPermutation); - return rewriter.create( - loc, + return ttir::PermuteOp::create( + rewriter, loc, RankedTensorType::get(permutedInputShape, inputType.getElementType(), inputType.getEncoding()), input, inputPermutation); @@ -489,20 +489,22 @@ struct GatherToEmbeddingConversionPattern auto permutedStartIndicesShape = ttmlir::utils::applyPermutation( startIndicesType.getShape(), startIndicesPermutation); auto startIndicesPermuted = - rewriter - .create( - ttmlir::utils::appendLocationSuffix(op.getLoc(), - "_permuteStartIndices"), - RankedTensorType::get(permutedStartIndicesShape, - startIndicesType.getElementType(), - 
startIndicesType.getEncoding()), - startIndices, startIndicesPermutation) + ttir::PermuteOp::create( + rewriter, + + ttmlir::utils::appendLocationSuffix(op.getLoc(), + "_permuteStartIndices"), + RankedTensorType::get(permutedStartIndicesShape, + startIndicesType.getElementType(), + startIndicesType.getEncoding()), + startIndices, startIndicesPermutation) .getResult(); // Typecast op because matmul needs float operands. auto typecastResultType = startIndicesPermuted.getType().clone( mlir::Float32Type::get(op.getContext())); - ttir::TypecastOp typecastOp = rewriter.create( + ttir::TypecastOp typecastOp = ttir::TypecastOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op->getLoc(), "_typecast"), typecastResultType, startIndicesPermuted); @@ -518,7 +520,8 @@ struct GatherToEmbeddingConversionPattern mlir::Float32Type::get(op.getContext())); auto denseAttr = mlir::DenseElementsAttr::get(tensorType, llvm::ArrayRef(strides)); - ttir::ConstantOp constantOp = rewriter.create( + ttir::ConstantOp constantOp = ttir::ConstantOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op->getLoc(), "_constant"), tensorType, denseAttr); @@ -528,8 +531,8 @@ struct GatherToEmbeddingConversionPattern auto matmulResultType = mlir::RankedTensorType::get( matmulResultShape, Float32Type::get(op.getContext())); - return rewriter.create(op.getLoc(), matmulResultType, - typecastOp.getResult(), constantOp); + return ttir::MatmulOp::create(rewriter, op.getLoc(), matmulResultType, + typecastOp.getResult(), constantOp); } // If startIndicesShape[indexVectorDim] > 1, but we are actually slicing only @@ -558,8 +561,8 @@ struct GatherToEmbeddingConversionPattern llvm::SmallVector resultShape(startIndicesShape); resultShape[indexVectorDim] = 1; - return rewriter.create( - loc, + return ttir::SliceStaticOp::create( + rewriter, loc, RankedTensorType::get(resultShape, startIndicesType.getElementType(), startIndicesType.getEncoding()), startIndices, rewriter.getI32ArrayAttr(begins), @@ 
-610,21 +613,23 @@ struct GatherToEmbeddingConversionPattern startIndicesType.getEncoding()); auto offsetAttr = mlir::DenseElementsAttr::get(expandedType, llvm::ArrayRef(matrixData)); - auto offsetConstant = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_offsetConstant"), + auto offsetConstant = ttir::ConstantOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_offsetConstant"), expandedType, offsetAttr); // Create broadcast dimensions - all dimensions map directly except the // expanded one. llvm::SmallVector broadcastDimensions = {sliceSize, 1}; // Broadcast the original startIndices to the expanded shape. - auto broadcastedStartIndices = rewriter.create( + auto broadcastedStartIndices = ttir::BroadcastOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_broadcastStartIndices"), expandedType, startIndices, rewriter.getDenseI64ArrayAttr(broadcastDimensions)); // Add the broadcasted tensors to get the final expanded indices. - return rewriter.create( + return ttir::AddOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_expandedStartIndices"), expandedType, broadcastedStartIndices, offsetConstant); } @@ -687,7 +692,8 @@ struct GatherToEmbeddingConversionPattern ttmlir::utils::appendLocationSuffix(op->getLoc(), "_reshapeOutput"), output, permutedOutputShape); - return rewriter.create( + return ttir::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op->getLoc(), "_permuteOutput"), RankedTensorType::get(expectedOutputShape, expectedOutputType.getElementType(), @@ -702,9 +708,10 @@ struct GatherToEmbeddingConversionPattern auto shapeAttr = rewriter.getI32ArrayAttr(llvm::SmallVector(targetShape)); - return rewriter.create( - loc, inputType.cloneWith(targetShape, inputType.getElementType()), - input, shapeAttr); + return ttir::ReshapeOp::create( + rewriter, loc, + inputType.cloneWith(targetShape, inputType.getElementType()), input, + shapeAttr); } }; } // namespace @@ -805,8 +812,8 @@ struct 
GatherToSliceRepeatConcatConversionPattern slicesToConcat.append(createSlices(ends, indexedDim, sliceSize, inputShape, inputType, rewriter, op)); - Value result = rewriter.create( - op.getLoc(), op.getType(), slicesToConcat, + Value result = ttir::ConcatOp::create( + rewriter, op.getLoc(), op.getType(), slicesToConcat, rewriter.getSI32IntegerAttr(static_cast(indexedDim))); rewriter.replaceOp(op, result); @@ -846,8 +853,8 @@ struct GatherToSliceRepeatConcatConversionPattern auto sliceType = RankedTensorType::get( sliceShape, inputType.getElementType(), inputType.getEncoding()); - Value slice = rewriter.create( - op.getLoc(), sliceType, op.getInput(), + Value slice = ttir::SliceStaticOp::create( + rewriter, op.getLoc(), sliceType, op.getInput(), rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(endsArr), rewriter.getI32ArrayAttr(step)); @@ -858,9 +865,8 @@ struct GatherToSliceRepeatConcatConversionPattern SmallVector repeatDims(inputShape.size(), 1); repeatDims[indexedDim] = numberOfRepeats; - slice = rewriter.create( - op.getLoc(), repeatType, slice, - rewriter.getDenseI64ArrayAttr(repeatDims)); + slice = ttir::RepeatOp::create(rewriter, op.getLoc(), repeatType, slice, + rewriter.getDenseI64ArrayAttr(repeatDims)); slices.push_back(slice); } @@ -964,8 +970,9 @@ struct DotGeneralToMatmulConversionPattern computeProductOfDims(rhsType.getShape(), rhsResultDims)); // Perform matmul operation. 
- auto matmulOp = rewriter.create( - op.getLoc(), RankedTensorType::get(matmulDestinationShape, elementType), + auto matmulOp = ttir::MatmulOp::create( + rewriter, op.getLoc(), + RankedTensorType::get(matmulDestinationShape, elementType), lhsMatmulInput, rhsMatmulInput); // Reshape the result by unrolling the prod(lhsResultDims) to original @@ -1047,8 +1054,8 @@ struct DotGeneralToMatmulConversionPattern SmallVector destinationShape = ttmlir::utils::applyPermutation(inputType.getShape(), permutation); - auto permuteOp = rewriter.create( - loc, + auto permuteOp = ttir::PermuteOp::create( + rewriter, loc, RankedTensorType::get(destinationShape, inputType.getElementType(), inputType.getEncoding()), input, permutation); @@ -1087,8 +1094,8 @@ struct DotGeneralToMatmulConversionPattern llvm::SmallVector finalShapeI32(finalShape.begin(), finalShape.end()); - auto reshapeOp = rewriter.create( - loc, + auto reshapeOp = ttir::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(finalShape, type.getElementType(), type.getEncoding()), input, rewriter.getI32ArrayAttr(finalShapeI32)); @@ -1143,9 +1150,9 @@ class PoolingToFullOp : public OpConversionPattern { FloatAttr::get(Float32Type::get(rewriter.getContext()), std::get(fillValue))); - rewriter.replaceOp( - op, rewriter.create(op.getLoc(), op.getResult().getType(), - fillValueAttr)); + rewriter.replaceOp(op, ttir::FullOp::create(rewriter, op.getLoc(), + op.getResult().getType(), + fillValueAttr)); return success(); } @@ -1238,8 +1245,8 @@ struct SelectToSliceConversionPattern begins[dim] = newBegin; ends[dim] = newEnd; - auto newOp = rewriter.create( - op.getLoc(), + auto newOp = ttir::SliceStaticOp::create( + rewriter, op.getLoc(), RankedTensorType::get(resultShape, inputType.getElementType(), inputType.getEncoding()), adaptor.getInput(), rewriter.getI32ArrayAttr(begins), @@ -1249,8 +1256,8 @@ struct SelectToSliceConversionPattern assert(!slices.empty()); if (slices.size() > 1) { - auto concatOp = - 
rewriter.create(op.getLoc(), outputType, slices, dim); + auto concatOp = ttir::ConcatOp::create(rewriter, op.getLoc(), outputType, + slices, dim); rewriter.replaceOp(op, concatOp); } else { rewriter.replaceOp(op, slices[0]); @@ -1301,12 +1308,12 @@ struct ArangeForceLastDimensionPattern RankedTensorType arangeOutputType = RankedTensorType::get( requiredShape, outputType.getElementType(), outputType.getEncoding()); - Value output = - rewriter - .create( // perform arange on the last dimension to - // match how ttnn behaves - op.getLoc(), arangeOutputType, start, end, step, 0) - .getResult(); + Value output = ttir::ArangeOp::create(rewriter, + // perform arange on the last + // dimension to match how ttnn behaves + op.getLoc(), arangeOutputType, start, + end, step, 0) + .getResult(); std::vector outputShape = arangeOutputType.getShape().vec(); @@ -1320,7 +1327,8 @@ struct ArangeForceLastDimensionPattern : reshapeShape.push_back(1); } - output = rewriter.create( + output = ttir::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_reshapeOutput"), RankedTensorType::get(reshapeShape, outputType.getElementType(), outputType.getEncoding()), @@ -1350,7 +1358,8 @@ struct ArangeForceLastDimensionPattern ttmlir::utils::getBroadcastDimensions(inputShape, outputShape); - output = rewriter.create( + output = ttir::BroadcastOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_broadcastOutput"), broadcastType, output, broadcastShape); @@ -1425,9 +1434,9 @@ struct ReductionOrPattern : public OpConversionPattern { RankedTensorType reduceOutputType = mlir::cast( getTypeConverter()->convertType(op.getResult().getType())); - mlir::Value sumOp = rewriter.create( - op.getLoc(), reduceOutputType, adaptor.getInput(), op.getKeepDim(), - op.getDimArgAttr()); + mlir::Value sumOp = ttir::SumOp::create( + rewriter, op.getLoc(), reduceOutputType, adaptor.getInput(), + op.getKeepDim(), op.getDimArgAttr()); // Create zero constant. 
auto elementType = reduceOutputType.getElementType(); @@ -1444,13 +1453,14 @@ struct ReductionOrPattern : public OpConversionPattern { ElementsAttr zeroConstantAttr = DenseElementsAttr::get(reduceOutputType, zerAttr); - mlir::Value zeroConstant = rewriter.create( + mlir::Value zeroConstant = ttir::ConstantOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_zeroConstant"), reduceOutputType, zeroConstantAttr); // Compare sum != 0. - mlir::Value cmpOp = rewriter.create( - op.getLoc(), reduceOutputType, sumOp, zeroConstant); + mlir::Value cmpOp = ttir::NotEqualOp::create( + rewriter, op.getLoc(), reduceOutputType, sumOp, zeroConstant); // Typecast boolean result to float type. rewriter.replaceOpWithNewOp(op, reduceOutputType, cmpOp); @@ -1486,8 +1496,9 @@ normalizeToNCHW(mlir::Value input, uint64_t featureIndex, llvm::SmallVector permutedShape = ttmlir::utils::applyPermutation( llvm::ArrayRef(currentShape), llvm::ArrayRef(permutation)); - newInput = rewriter.create( - loc, RankedTensorType::get(permutedShape, inputType.getElementType()), + newInput = mlir::tt::ttir::PermuteOp::create( + rewriter, loc, + RankedTensorType::get(permutedShape, inputType.getElementType()), newInput, rewriter.getDenseI64ArrayAttr(permutation)); currentShape = permutedShape; } @@ -1502,8 +1513,8 @@ normalizeToNCHW(mlir::Value input, uint64_t featureIndex, currentShape[3] * currentShape[4]}; llvm::SmallVector reshapedShapeI32(reshapedShape.begin(), reshapedShape.end()); - newInput = rewriter.create( - loc, + newInput = mlir::tt::ttir::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(reshapedShape, inputType.getElementType(), inputType.getEncoding()), newInput, rewriter.getI32ArrayAttr(reshapedShapeI32)); @@ -1526,8 +1537,8 @@ normalizeToNCHW(mlir::Value input, uint64_t featureIndex, } llvm::SmallVector reshapedShapeI32(reshapedShape.begin(), reshapedShape.end()); - newInput = rewriter.create( - loc, + newInput = mlir::tt::ttir::ReshapeOp::create( + rewriter, 
loc, RankedTensorType::get(reshapedShape, inputType.getElementType(), inputType.getEncoding()), newInput, rewriter.getI32ArrayAttr(reshapedShapeI32)); @@ -1563,8 +1574,8 @@ static mlir::Value denormalizeFromNCHW(mlir::Value output, llvm::SmallVector shapeAfterPermuteI32(shapeAfterPermute.begin(), shapeAfterPermute.end()); - result = rewriter.create( - loc, + result = mlir::tt::ttir::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(shapeAfterPermute, outputType.getElementType(), outputType.getEncoding()), result, rewriter.getI32ArrayAttr(shapeAfterPermuteI32)); @@ -1578,8 +1589,8 @@ static mlir::Value denormalizeFromNCHW(mlir::Value output, std::iota(permutation.begin(), permutation.end(), 0); std::swap(permutation[1], permutation[originalFeatureIndex]); - result = rewriter.create( - loc, + result = mlir::tt::ttir::PermuteOp::create( + rewriter, loc, RankedTensorType::get(originalShape, outputType.getElementType(), outputType.getEncoding()), result, rewriter.getDenseI64ArrayAttr(permutation)); @@ -1614,8 +1625,8 @@ static mlir::Value getBatchNorm4DTensor(PatternRewriter &rewriter, Location loc, llvm::SmallVector shape32(newShape.begin(), newShape.end()); auto shapeAttr = rewriter.getI32ArrayAttr(shape32); - return rewriter.create( - loc, + return ttir::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(newShape, inputType.getElementType(), inputType.getEncoding()), batchNormInput, shapeAttr); @@ -1637,8 +1648,8 @@ static mlir::Value reshapeBatchNorm4DTo1D(PatternRewriter &rewriter, llvm::SmallVector shape1D = { static_cast(target1DType.getDimSize(0))}; - return rewriter.create( - loc, + return ttir::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(target1DType.getShape(), target1DType.getElementType(), target1DType.getEncoding()), @@ -1721,10 +1732,9 @@ struct BatchNormInferencePattern IntegerAttr dimensionAttr = mlir::IntegerAttr::get(integerType, 1); // Create the BatchNorm op with normalized input and 4D weight tensors - auto 
batchNormInferenceOp = - rewriter.create( - loc, normalizedOutputType, normalizedInput, scale4D, offset4D, - mean4D, variance4D, adaptor.getEpsilonAttr(), dimensionAttr); + auto batchNormInferenceOp = mlir::tt::ttir::BatchNormInferenceOp::create( + rewriter, loc, normalizedOutputType, normalizedInput, scale4D, offset4D, + mean4D, variance4D, adaptor.getEpsilonAttr(), dimensionAttr); // Denormalize output back to original layout mlir::Value result = @@ -1814,8 +1824,9 @@ struct BatchNormTrainingPattern // Create new BatchNormTrainingOp with normalized input and all 4D weight // tensors - auto batchNormTrainingOp = rewriter.create( - loc, TypeRange{normalizedOutputType, mean4DType, variance4DType}, + auto batchNormTrainingOp = ttir::BatchNormTrainingOp::create( + rewriter, loc, + TypeRange{normalizedOutputType, mean4DType, variance4DType}, normalizedInput, scale4D, offset4D, mean4D, variance4D, adaptor.getEpsilonAttr(), dimensionAttr, adaptor.getMomentumAttr()); @@ -1859,7 +1870,7 @@ getScaleAndZeroPoint(mlir::quant::QuantizedType elementType, mlir::DenseFPElementsAttr scaleDenseAttr = mlir::DenseFPElementsAttr::get(scaleType, scaleValue); ttir::ConstantOp scaleConstant = - rewriter.create(loc, scaleType, scaleDenseAttr); + ttir::ConstantOp::create(rewriter, loc, scaleType, scaleDenseAttr); // Create ttir::ConstantOp for zero point. 
int32_t zeroPoint = static_cast(quantPerTensorType.getZeroPoint()); @@ -1867,8 +1878,8 @@ getScaleAndZeroPoint(mlir::quant::QuantizedType elementType, {1}, IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed)); mlir::DenseIntElementsAttr zeroPointDenseAttr = mlir::DenseIntElementsAttr::get(zeroPointType, zeroPoint); - ttir::ConstantOp zeroPointConstant = rewriter.create( - loc, zeroPointType, zeroPointDenseAttr); + ttir::ConstantOp zeroPointConstant = ttir::ConstantOp::create( + rewriter, loc, zeroPointType, zeroPointDenseAttr); return {scaleConstant.getResult(), zeroPointConstant.getResult()}; } @@ -1884,7 +1895,7 @@ getScaleAndZeroPoint(mlir::quant::QuantizedType elementType, mlir::DenseFPElementsAttr scaleDenseAttr = mlir::DenseFPElementsAttr::get(scaleType, scales); ttir::ConstantOp scaleConstant = - rewriter.create(loc, scaleType, scaleDenseAttr); + ttir::ConstantOp::create(rewriter, loc, scaleType, scaleDenseAttr); // Create ttir::ConstantOp for zero point. SmallVector zeroPoints( @@ -1894,8 +1905,8 @@ getScaleAndZeroPoint(mlir::quant::QuantizedType elementType, IntegerType::get(rewriter.getContext(), 32, IntegerType::Signed)); mlir::DenseIntElementsAttr zeroPointDenseAttr = mlir::DenseIntElementsAttr::get(zeroPointType, zeroPoints); - ttir::ConstantOp zeroPointConstant = rewriter.create( - loc, zeroPointType, zeroPointDenseAttr); + ttir::ConstantOp zeroPointConstant = ttir::ConstantOp::create( + rewriter, loc, zeroPointType, zeroPointDenseAttr); return {scaleConstant.getResult(), zeroPointConstant.getResult()}; } @@ -2079,8 +2090,9 @@ struct ReductionProdPattern : public OpConversionPattern { } RankedTensorType outputType = RankedTensorType::get(shape, elementType); - runningProdOp = rewriter.create( - op.getLoc(), outputType, runningProdOp, op.getKeepDimAttr(), dimArg); + runningProdOp = + ttir::ProdOp::create(rewriter, op.getLoc(), outputType, runningProdOp, + op.getKeepDimAttr(), dimArg); } rewriter.replaceOp(op, runningProdOp); @@ -2129,8 
+2141,8 @@ struct ConvChannelLastDecompositionPattern auto permutedInputType = RankedTensorType::get(permutedInputShape, inputType.getElementType(), inputType.getEncoding()); - auto permutedInput = rewriter.create( - op.getLoc(), permutedInputType, adaptor.getInput(), toNhwc); + auto permutedInput = ttir::PermuteOp::create( + rewriter, op.getLoc(), permutedInputType, adaptor.getInput(), toNhwc); // Compute output shape in NHWC format. auto permutedOutputShape = @@ -2147,22 +2159,23 @@ struct ConvChannelLastDecompositionPattern ttmlir::utils::applyPermutation(biasType.getShape(), toNhwc); auto permutedBiasType = RankedTensorType::get( permutedBiasShape, biasType.getElementType(), biasType.getEncoding()); - permutedBias = rewriter.create( - op.getLoc(), permutedBiasType, adaptor.getBias(), toNhwc); + permutedBias = ttir::PermuteOp::create( + rewriter, op.getLoc(), permutedBiasType, adaptor.getBias(), toNhwc); } ConvOpType newConv; if constexpr (std::is_same_v) { - newConv = rewriter.create( - op.getLoc(), permutedOutputType, permutedInput, adaptor.getWeight(), - permutedBias, adaptor.getStride(), adaptor.getPadding(), - adaptor.getOutputPadding(), adaptor.getDilation(), op.getGroups(), + newConv = ttir::ConvTranspose2dOp::create( + rewriter, op.getLoc(), permutedOutputType, permutedInput, + adaptor.getWeight(), permutedBias, adaptor.getStride(), + adaptor.getPadding(), adaptor.getOutputPadding(), + adaptor.getDilation(), op.getGroups(), adaptor.getFlattenedCompatInfo()); } else if constexpr (std::is_same_v) { - newConv = rewriter.create( - op.getLoc(), permutedOutputType, permutedInput, adaptor.getWeight(), - permutedBias, adaptor.getStride(), adaptor.getPadding(), - adaptor.getDilation(), op.getGroups(), + newConv = ttir::Conv2dOp::create( + rewriter, op.getLoc(), permutedOutputType, permutedInput, + adaptor.getWeight(), permutedBias, adaptor.getStride(), + adaptor.getPadding(), adaptor.getDilation(), op.getGroups(), adaptor.getFlattenedCompatInfo()); } else { 
static_assert(ttmlir::utils::always_false(), @@ -2170,8 +2183,8 @@ struct ConvChannelLastDecompositionPattern } // Permute output from NHWC back to original layout. - auto outputPermute = rewriter.create( - op.getLoc(), outputType, newConv.getResult(), fromNhwc); + auto outputPermute = ttir::PermuteOp::create( + rewriter, op.getLoc(), outputType, newConv.getResult(), fromNhwc); rewriter.replaceOp(op, outputPermute.getResult()); return success(); @@ -2251,8 +2264,8 @@ struct ArgMaxPattern : public OpConversionPattern { for (int32_t dimIdx : reduceDims) { permutation.push_back(dimIdx); } - auto permuteOp = rewriter.create( - op.getLoc(), permuteOpResultType, adaptor.getInput(), + auto permuteOp = ttir::PermuteOp::create( + rewriter, op.getLoc(), permuteOpResultType, adaptor.getInput(), rewriter.getDenseI64ArrayAttr(permutation)); // Step 2. Reshape the permuted tensor to make all reduction dimensions into @@ -2264,8 +2277,8 @@ struct ArgMaxPattern : public OpConversionPattern { tempArgMaxShape.push_back(reshapedDimSize); // RankedTensor need int64_t shape - auto reshapeOp = rewriter.create( - op.getLoc(), + auto reshapeOp = ttir::ReshapeOp::create( + rewriter, op.getLoc(), RankedTensorType::get(tempArgMaxShape, inputType.getElementType()), permuteOp.getResult(), rewriter.getI32ArrayAttr(llvm::to_vector_of(tempArgMaxShape))); @@ -2278,8 +2291,8 @@ struct ArgMaxPattern : public OpConversionPattern { } else { tempArgMaxShape.pop_back(); // else remove the reduced dim } - auto argMaxOp = rewriter.create( - op.getLoc(), + auto argMaxOp = ttir::ArgMaxOp::create( + rewriter, op.getLoc(), RankedTensorType::get(tempArgMaxShape, outputType.getElementType()), reshapeOp.getResult(), rewriter.getBoolAttr(keepDim), rewriter.getI32ArrayAttr(newDim)); @@ -2295,8 +2308,8 @@ struct ArgMaxPattern : public OpConversionPattern { finalArgMaxShape.push_back(1); } } - auto result = rewriter.create( - op.getLoc(), + auto result = ttir::ReshapeOp::create( + rewriter, op.getLoc(), 
RankedTensorType::get(finalArgMaxShape, outputType.getElementType()), argMaxOp.getResult(), rewriter.getI32ArrayAttr( @@ -2495,8 +2508,8 @@ struct SplitQueryKeyValueAndSplitHeadsDecompositionPattern stepsAttr.push_back(rewriter.getI32IntegerAttr(1)); } - return rewriter.create( - loc, resultType, input, rewriter.getArrayAttr(beginsAttr), + return ttir::SliceStaticOp::create( + rewriter, loc, resultType, input, rewriter.getArrayAttr(beginsAttr), rewriter.getArrayAttr(endsAttr), rewriter.getArrayAttr(stepsAttr)); } @@ -2505,16 +2518,16 @@ struct SplitQueryKeyValueAndSplitHeadsDecompositionPattern Value input, RankedTensorType resultType) const { auto newShape = resultType.getShape(); llvm::SmallVector newShapeI32(newShape.begin(), newShape.end()); - return rewriter.create( - loc, resultType, input, rewriter.getI32ArrayAttr(newShapeI32)); + return ttir::ReshapeOp::create(rewriter, loc, resultType, input, + rewriter.getI32ArrayAttr(newShapeI32)); } // Helper to create a permute operation using ttir.permute. 
Value createPermute(ConversionPatternRewriter &rewriter, Location loc, Value input, RankedTensorType resultType, ArrayRef permutation) const { - return rewriter.create(loc, resultType, input, - permutation); + return ttir::PermuteOp::create(rewriter, loc, resultType, input, + permutation); } }; } // namespace @@ -2570,8 +2583,8 @@ struct NegativePadOpDecompositionPattern sliceResultType = RankedTensorType::get(slicedShape, inputType.getElementType()); - ttir::SliceStaticOp sliceOp = rewriter.create( - op.getLoc(), sliceResultType, input, + ttir::SliceStaticOp sliceOp = ttir::SliceStaticOp::create( + rewriter, op.getLoc(), sliceResultType, input, rewriter.getI32ArrayAttr(sliceBegins), rewriter.getI32ArrayAttr(sliceEnds), rewriter.getI32ArrayAttr(sliceSteps)); @@ -2582,8 +2595,8 @@ struct NegativePadOpDecompositionPattern SmallVector posPadding = llvm::to_vector( llvm::map_range(padding, [](int32_t p) { return std::max(p, 0); })); - ttir::PadOp padOp = rewriter.create( - op.getLoc(), op.getType(), input, + ttir::PadOp padOp = ttir::PadOp::create( + rewriter, op.getLoc(), op.getType(), input, rewriter.getDenseI32ArrayAttr(posPadding), adaptor.getValue()); input = padOp.getResult(); } diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 3b9e2f2a956..3f329b94f3b 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -432,8 +432,8 @@ class EmbeddingBackwardOpConversionPattern auto paddedIndicesType = ttnn::utils::RankedTensorTypeFactory::create( indicesType, paddedIndicesShape); - inputIndices = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_pad_indices"), + inputIndices = ttnn::PadOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_pad_indices"), paddedIndicesType, inputIndices, rewriter.getDenseI32ArrayAttr(indicesPadding), rewriter.getF32FloatAttr(0.0), rewriter.getBoolAttr(true), nullptr); @@ -448,8 +448,8 @@ class 
EmbeddingBackwardOpConversionPattern auto paddedGradType = ttnn::utils::RankedTensorTypeFactory::create( gradTensor, paddedGradShape); - reshapedGrad = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_pad_gradient"), + reshapedGrad = ttnn::PadOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_pad_gradient"), paddedGradType, adaptor.getInGradient(), rewriter.getDenseI32ArrayAttr(gradPadding), rewriter.getF32FloatAttr(0.0), rewriter.getBoolAttr(true), nullptr); @@ -719,9 +719,9 @@ class UpdateCacheOpConversionPattern op, "UpdateCacheOp cache argument must have exactly one user"); } - rewriter.create( - op.getLoc(), adaptor.getCache(), adaptor.getInput(), - adaptor.getUpdateIndex(), adaptor.getBatchOffset()); + ttnn::UpdateCacheOp::create(rewriter, op.getLoc(), adaptor.getCache(), + adaptor.getInput(), adaptor.getUpdateIndex(), + adaptor.getBatchOffset()); rewriter.replaceOp(op, adaptor.getCache()); return success(); @@ -745,8 +745,8 @@ class PagedUpdateCacheOpConversionPattern op, "PagedUpdateCacheOp cache argument must have exactly one user"); } - rewriter.create( - op.getLoc(), adaptor.getCache(), adaptor.getInput(), + ttnn::PagedUpdateCacheOp::create( + rewriter, op.getLoc(), adaptor.getCache(), adaptor.getInput(), adaptor.getUpdateIndex(), adaptor.getShareCache(), adaptor.getPageTable()); @@ -772,9 +772,9 @@ class PagedFillCacheOpConversionPattern op, "PagedFillCacheOp cache argument must have exactly one user"); } - rewriter.create( - op.getLoc(), adaptor.getCache(), adaptor.getInput(), - adaptor.getPageTable(), adaptor.getBatchIdxTensor()); + ttnn::PagedFillCacheOp::create(rewriter, op.getLoc(), adaptor.getCache(), + adaptor.getInput(), adaptor.getPageTable(), + adaptor.getBatchIdxTensor()); rewriter.replaceOp(op, adaptor.getCache()); return success(); @@ -810,9 +810,8 @@ class FillCacheOpConversionPattern op, "FillCacheOp must have exactly one user"); } - rewriter.create(op.getLoc(), adaptor.getCache(), - adaptor.getInput(), - 
adaptor.getBatchOffset()); + ttnn::FillCacheOp::create(rewriter, op.getLoc(), adaptor.getCache(), + adaptor.getInput(), adaptor.getBatchOffset()); rewriter.replaceOp(op, adaptor.getCache()); return success(); @@ -1202,10 +1201,11 @@ class BatchNormTrainingOpConversionPattern auto resultType = this->getTypeConverter()->convertType(op.getResult().getType()); - auto batchNormTrainingOp = rewriter.create( - op.getLoc(), resultType, adaptor.getOperand(), adaptor.getRunningMean(), - adaptor.getRunningVariance(), adaptor.getEpsilon(), - adaptor.getMomentum(), adaptor.getScale(), adaptor.getOffset(), + auto batchNormTrainingOp = ttnn::BatchNormTrainingOp::create( + rewriter, op.getLoc(), resultType, adaptor.getOperand(), + adaptor.getRunningMean(), adaptor.getRunningVariance(), + adaptor.getEpsilon(), adaptor.getMomentum(), adaptor.getScale(), + adaptor.getOffset(), /*memoryConfig*/ nullptr); // TTIR expects the running mean and variance to be returned as separate @@ -1734,9 +1734,9 @@ class Conv3dOpConversionPattern : public OpConversionPattern { outputType, permutedOutputShape); } - auto convOp = rewriter.create( - op.getLoc(), outputType, input, reshapedWeight, reshapedBias, device, - inChannelsAttr, outChannelsAttr, batchSizeAttr, inputDepthAttr, + auto convOp = ttnn::Conv3dOp::create( + rewriter, op.getLoc(), outputType, input, reshapedWeight, reshapedBias, + device, inChannelsAttr, outChannelsAttr, batchSizeAttr, inputDepthAttr, inputHeightAttr, inputWidthAttr, kernelSizeAttr, *strideAttr, *paddingAttr, paddingModeAttr, groupsAttr, outputDtypeAttr, nullptr, nullptr); @@ -1834,9 +1834,10 @@ class Conv3dOpConversionPattern : public OpConversionPattern { auto curTy = mlir::cast(result.getType()); auto blockedTy = ttnn::utils::RankedTensorTypeFactory::create(curTy, blockedShape); - result = rewriter.create( - loc, blockedTy, result, rewriter.getI32ArrayAttr(blockedShapeI32), - /*memory_config=*/nullptr); + result = + ttnn::ReshapeOp::create(rewriter, loc, blockedTy, 
result, + rewriter.getI32ArrayAttr(blockedShapeI32), + /*memory_config=*/nullptr); // Permute 6D: (K_D, K_H, K_W, num_blocks, C_in_block, O) // → (num_blocks, K_D, K_H, K_W, C_in_block, O) @@ -1856,9 +1857,9 @@ class Conv3dOpConversionPattern : public OpConversionPattern { RankedTensorType outputType = ttnn::utils::RankedTensorTypeFactory::create(resultTy, finalShape); - return rewriter.create( - loc, outputType, result, rewriter.getI32ArrayAttr(finalShapeI32), - /*memory_config=*/nullptr); + return ttnn::ReshapeOp::create(rewriter, loc, outputType, result, + rewriter.getI32ArrayAttr(finalShapeI32), + /*memory_config=*/nullptr); } // Transforms bias tensor to 2D: (1, 1, 1, 1, O) → (1, O) @@ -1876,9 +1877,9 @@ class Conv3dOpConversionPattern : public OpConversionPattern { RankedTensorType outputType = ttnn::utils::RankedTensorTypeFactory::create(biasTy, newShape); - return rewriter.create( - loc, outputType, bias, rewriter.getI32ArrayAttr(newShapeI32), - /*memory_config=*/nullptr); + return ttnn::ReshapeOp::create(rewriter, loc, outputType, bias, + rewriter.getI32ArrayAttr(newShapeI32), + /*memory_config=*/nullptr); } }; } // namespace @@ -2506,10 +2507,11 @@ class CollectivePermuteOpConversionPattern ConversionPatternRewriter &rewriter) const { FloatAttr zeroAttr = FloatAttr::get(Float32Type::get(rewriter.getContext()), 0.0f); - return rewriter - .create( - op.getLoc(), this->getTypeConverter()->convertType(op.getType()), - inputTensor, zeroAttr, zeroAttr, memoryConfigAttr) + return ttnn::ClampScalarOp::create( + rewriter, + + op.getLoc(), this->getTypeConverter()->convertType(op.getType()), + inputTensor, zeroAttr, zeroAttr, memoryConfigAttr) .getResult(); } @@ -2546,11 +2548,11 @@ class CollectivePermuteOpConversionPattern // Create a cloned tensor to skip P2P ops for self-mapped // source_target_pairs. 
mlir::Value resultTensor = - rewriter - .create( - op.getLoc(), - this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), memoryConfigAttr, dTypeAttr) + ttnn::AssignOp::create( + rewriter, + + op.getLoc(), this->getTypeConverter()->convertType(op.getType()), + adaptor.getInput(), memoryConfigAttr, dTypeAttr) .getResult(); auto meshDevice = ttcore::lookupDevice(op); @@ -2576,11 +2578,11 @@ class CollectivePermuteOpConversionPattern DenseI64ArrayAttr receiveCoord = rewriter.getDenseI64ArrayAttr( ttmlir::utils::linearIdToCoord(targetDevice, meshShape)); resultTensor = - rewriter - .create( - op.getLoc(), - this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), sendCoord, receiveCoord, resultTensor) + ttnn::PointToPointOp::create( + rewriter, + + op.getLoc(), this->getTypeConverter()->convertType(op.getType()), + adaptor.getInput(), sendCoord, receiveCoord, resultTensor) .getResult(); } @@ -2599,13 +2601,13 @@ class CollectivePermuteOpConversionPattern ttmlir::utils::linearIdToCoord(sourceDevice, meshShape)); DenseI64ArrayAttr receiveCoord = rewriter.getDenseI64ArrayAttr( ttmlir::utils::linearIdToCoord(idx, meshShape)); - resultTensor = - rewriter - .create( - op.getLoc(), - this->getTypeConverter()->convertType(op.getType()), - zerosTensor, sendCoord, receiveCoord, resultTensor) - .getResult(); + resultTensor = ttnn::PointToPointOp::create( + rewriter, + + op.getLoc(), + this->getTypeConverter()->convertType(op.getType()), + zerosTensor, sendCoord, receiveCoord, resultTensor) + .getResult(); } } @@ -2825,9 +2827,9 @@ class CollectiveBroadcastOpConversionPattern ttcore::DataTypeAttr dTypeAttr = ttcore::DataTypeAttr::get(op.getContext(), layoutAttr.getDataType()); - Value finalValue = rewriter.create( - op.getLoc(), inputType, adaptor.getInput(), memoryConfigAttr, - dTypeAttr); + Value finalValue = + ttnn::AssignOp::create(rewriter, op.getLoc(), inputType, + adaptor.getInput(), memoryConfigAttr, dTypeAttr); auto replicaGroups = 
ttmlir::utils::denseElementsAttrTo2D( adaptor.getReplicaGroups()); @@ -2839,8 +2841,8 @@ class CollectiveBroadcastOpConversionPattern for (size_t idx = 1; idx < group.size(); idx++) { // Skip the first device in the group because the buffer is already // cloned - finalValue = rewriter.create( - op.getLoc(), inputType, adaptor.getInput(), sourceCoord, + finalValue = ttnn::PointToPointOp::create( + rewriter, op.getLoc(), inputType, adaptor.getInput(), sourceCoord, rewriter.getDenseI64ArrayAttr( ttmlir::utils::linearIdToCoord(group[idx], meshShape)), finalValue); @@ -2940,8 +2942,8 @@ class SplitQueryKeyValueAndSplitHeadsOpConversionPattern this->getTypeConverter()->convertType(op.getValue().getType()); // Create the TTNN op with 3 results - auto ttnnOp = rewriter.create( - op.getLoc(), TypeRange{queryType, keyType, valueType}, + auto ttnnOp = ttnn::SplitQueryKeyValueAndSplitHeadsOp::create( + rewriter, op.getLoc(), TypeRange{queryType, keyType, valueType}, adaptor.getInputTensor(), adaptor.getKvInputTensor(), adaptor.getNumHeadsAttr(), adaptor.getNumKvHeadsAttr(), adaptor.getTransposeKeyAttr(), @@ -3049,7 +3051,8 @@ class ScaledDotProductAttentionOpConversionPattern maskType.getShape(), broadcastShape); auto shapeAttr = ttnn::ShapeAttr::get(rewriter.getContext(), broadcastDims); - return rewriter.create(loc, broadcastType, mask, shapeAttr); + return ttnn::RepeatOp::create(rewriter, loc, broadcastType, mask, + shapeAttr); } // Lower to SDPA decode op with necessary permutations. 
@@ -3070,9 +3073,9 @@ class ScaledDotProductAttentionOpConversionPattern Value attentionMask = broadcastMaskForDecode( adaptor.getAttentionMask(), numHeads, rewriter, op.getLoc()); - auto decodeOp = rewriter.create( - op.getLoc(), permutedQuery.getType(), permutedQuery, adaptor.getKey(), - adaptor.getValue(), op.getIsCausal(), attentionMask, + auto decodeOp = ttnn::ScaledDotProductAttentionDecodeOp::create( + rewriter, op.getLoc(), permutedQuery.getType(), permutedQuery, + adaptor.getKey(), adaptor.getValue(), op.getIsCausal(), attentionMask, /*cur_pos_tensor=*/Value(), /*attention_sink=*/Value(), adaptor.getScaleAttr(), /*memory_config=*/nullptr, /*program_config=*/nullptr); @@ -3152,9 +3155,10 @@ class AllToAllOpConversionPattern ends[splitDim] = (sliceIdx + 1) * splitSize; // Create a slice for this range - ttnn::SliceStaticOp sliceOp = rewriter.create( - loc, sliceOutputType, op.getInput(), rewriter.getI32ArrayAttr(begins), - rewriter.getI32ArrayAttr(ends), rewriter.getI32ArrayAttr(steps)); + ttnn::SliceStaticOp sliceOp = ttnn::SliceStaticOp::create( + rewriter, loc, sliceOutputType, op.getInput(), + rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(ends), + rewriter.getI32ArrayAttr(steps)); sliceOpResults.push_back(sliceOp.getResult()); } // Step 2: Reorganize sliced data using PointToPoint communication. 
@@ -3175,9 +3179,9 @@ class AllToAllOpConversionPattern llvm::SmallVector reorgBuffers(splitCount); for (int32_t i = 0; i < splitCount; i++) { - reorgBuffers[i] = rewriter.create( - loc, sliceOpResults[i].getType(), sliceOpResults[i], memoryConfigAttr, - dTypeAttr); + reorgBuffers[i] = ttnn::AssignOp::create( + rewriter, loc, sliceOpResults[i].getType(), sliceOpResults[i], + memoryConfigAttr, dTypeAttr); } auto meshShape = ttcore::lookupDevice(op).getMeshShape(); @@ -3196,8 +3200,8 @@ class AllToAllOpConversionPattern } auto receiverCoord = rewriter.getDenseI64ArrayAttr( ttmlir::utils::linearIdToCoord(group[receiverIdx], meshShape)); - reorgBuffers[senderIdx] = rewriter.create( - loc, sliceOpResults[senderIdx].getType(), + reorgBuffers[senderIdx] = ttnn::PointToPointOp::create( + rewriter, loc, sliceOpResults[senderIdx].getType(), sliceOpResults[receiverIdx], senderCoord, receiverCoord, reorgBuffers[senderIdx]); } diff --git a/lib/Conversion/TTIRToTTNN/Utils.cpp b/lib/Conversion/TTIRToTTNN/Utils.cpp index ecd9be58872..cf002ee2d3a 100644 --- a/lib/Conversion/TTIRToTTNN/Utils.cpp +++ b/lib/Conversion/TTIRToTTNN/Utils.cpp @@ -26,9 +26,9 @@ ttnn::ReshapeOp generateReshape(mlir::TypedValue input, ttnn::utils::RankedTensorTypeFactory::create(inputType, newShape); llvm::SmallVector newShapeI32(newShape.begin(), newShape.end()); - return rewriter.create(newLoc, outputType, input, - rewriter.getI32ArrayAttr(newShapeI32), - /* memory_config */ nullptr); + return ttnn::ReshapeOp::create(rewriter, newLoc, outputType, input, + rewriter.getI32ArrayAttr(newShapeI32), + /* memory_config */ nullptr); } ttnn::ReshapeOp @@ -53,9 +53,10 @@ ttnn::PermuteOp generatePermute(mlir::TypedValue input, RankedTensorType outputType = ttnn::utils::RankedTensorTypeFactory::create(inputType, outputShape); - return rewriter.create( - newLoc, outputType, input, rewriter.getDenseI64ArrayAttr(permutation), - /* memory_config */ nullptr, /* pad_value */ mlir::FloatAttr()); + return 
ttnn::PermuteOp::create(rewriter, newLoc, outputType, input, + rewriter.getDenseI64ArrayAttr(permutation), + /* memory_config */ nullptr, + /* pad_value */ mlir::FloatAttr()); } ttnn::PadOp generatePad(mlir::TypedValue input, @@ -75,10 +76,11 @@ ttnn::PadOp generatePad(mlir::TypedValue input, RankedTensorType outputType = ttnn::utils::RankedTensorTypeFactory::create(inputType, outputShape); - return rewriter.create( - newLoc, outputType, input, rewriter.getDenseI32ArrayAttr(padding), - rewriter.getF32FloatAttr(0.0f), rewriter.getBoolAttr(true), - /*memory_config=*/nullptr); + return ttnn::PadOp::create(rewriter, newLoc, outputType, input, + rewriter.getDenseI32ArrayAttr(padding), + rewriter.getF32FloatAttr(0.0f), + rewriter.getBoolAttr(true), + /*memory_config=*/nullptr); } } // namespace ttir_to_ttnn::utils } // namespace tt diff --git a/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp b/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp index 66a4d25195c..af927779e02 100644 --- a/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp +++ b/lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp @@ -205,8 +205,8 @@ class TTKernelStoreToL1OpToEmitCOpRewriter matchAndRewrite(ttkernel::StoreToL1Op op, ttkernel::StoreToL1Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { - auto subscriptOp = rewriter.create( - op->getLoc(), + auto subscriptOp = emitc::SubscriptOp::create( + rewriter, op->getLoc(), emitc::LValueType::get( op.getContext(), mlir::cast(adaptor.getL1Ptr().getType()) @@ -216,8 +216,8 @@ class TTKernelStoreToL1OpToEmitCOpRewriter // Cast rhs to volatile tt_l1_ptr uint32_t to match the pointed type. // This is because assignment requires the types to match. This compiles // in metal, but it looks ugly. 
- auto casted = rewriter.create( - op->getLoc(), + auto casted = emitc::CastOp::create( + rewriter, op->getLoc(), emitc::OpaqueType::get(op.getContext(), "volatile tt_l1_ptr uint32_t"), adaptor.getValue()); rewriter.replaceOpWithNewOp(op, subscriptOp, casted); @@ -469,10 +469,12 @@ class TTKernelToEmitCDPrintRewriter StringRef fmt = op.getFmt(); auto stringlit = [&](StringRef str) { - return rewriter - .create( - op.getLoc(), rewriter.getType("const char[]"), - (Twine("\"") + str + "\"").str()) + return emitc::LiteralOp::create( + rewriter, + + op.getLoc(), + rewriter.getType("const char[]"), + (Twine("\"") + str + "\"").str()) .getResult(); }; @@ -493,12 +495,13 @@ class TTKernelToEmitCDPrintRewriter ttkernel::ThreadTypeAttr::name) .getValue() == ttkernel::ThreadType::Compute) { auto cbPrinter = - rewriter - .create( - op.getLoc(), - rewriter.getType("ttmlir::CBPrinter"), - "ttmlir::CBPrinter", nullptr, nullptr, - ValueRange{*operandsIter++}) + emitc::CallOpaqueOp::create( + rewriter, + + op.getLoc(), + rewriter.getType("ttmlir::CBPrinter"), + "ttmlir::CBPrinter", nullptr, nullptr, + ValueRange{*operandsIter++}) .getResult(0); vargs.push_back(cbPrinter); } else { @@ -573,9 +576,9 @@ class TTKernelInvokeSFPIOpRewriter ttkernel::InvokeSFPIOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { assert(op.getRegion().hasOneBlock()); - rewriter.create(op->getLoc(), - "experimental::invoke_sfpi([=]() {"); - auto endScope = rewriter.create(op->getLoc(), "});"); + emitc::VerbatimOp::create(rewriter, op->getLoc(), + "experimental::invoke_sfpi([=]() {"); + auto endScope = emitc::VerbatimOp::create(rewriter, op->getLoc(), "});"); rewriter.inlineBlockBefore(&op.getRegion().front(), endScope); rewriter.eraseOp(op); return success(); @@ -623,34 +626,35 @@ class TTKernelGetInterleavedAddrGenFastOpRewriter mlir::Type lvalueType = emitc::LValueType::get(opaqueStructType); // Declare the struct variable and then assign to its members - auto varOp = rewriter.create( 
- op->getLoc(), lvalueType, + auto varOp = emitc::VariableOp::create( + rewriter, op->getLoc(), lvalueType, emitc::OpaqueAttr::get(op.getContext(), "")); // Create an lvalue for all struct field accesses - auto lvalueBankBaseAddr = rewriter.create( - op->getLoc(), + auto lvalueBankBaseAddr = emitc::MemberOp::create( + rewriter, op->getLoc(), emitc::LValueType::get(adaptor.getBankBaseAddress().getType()), "bank_base_address", varOp); - auto lvaluePageSize = rewriter.create( - op->getLoc(), emitc::LValueType::get(adaptor.getPageSize().getType()), - "page_size", varOp); - auto lvalueDataFormat = rewriter.create( - op->getLoc(), + auto lvaluePageSize = emitc::MemberOp::create( + rewriter, op->getLoc(), + emitc::LValueType::get(adaptor.getPageSize().getType()), "page_size", + varOp); + auto lvalueDataFormat = emitc::MemberOp::create( + rewriter, op->getLoc(), emitc::LValueType::get(adaptor.getDataFormat().getType()), "data_format", varOp); // Assign corresponding values to the struct members - rewriter.create(op->getLoc(), lvalueBankBaseAddr, - adaptor.getBankBaseAddress()); - rewriter.create(op->getLoc(), lvaluePageSize, - adaptor.getPageSize()); - rewriter.create(op->getLoc(), lvalueDataFormat, - adaptor.getDataFormat()); + emitc::AssignOp::create(rewriter, op->getLoc(), lvalueBankBaseAddr, + adaptor.getBankBaseAddress()); + emitc::AssignOp::create(rewriter, op->getLoc(), lvaluePageSize, + adaptor.getPageSize()); + emitc::AssignOp::create(rewriter, op->getLoc(), lvalueDataFormat, + adaptor.getDataFormat()); // Load the value from the lvalue variable - auto loadOp = - rewriter.create(op->getLoc(), opaqueStructType, varOp); + auto loadOp = emitc::LoadOp::create(rewriter, op->getLoc(), + opaqueStructType, varOp); // Replace the original operation with the loaded value so it can be used. 
rewriter.replaceOp(op, loadOp.getResult()); @@ -718,14 +722,14 @@ class TTKernelTensorAccessorArgsOpRewriter // crtaArg>(); std::string code = "auto " + varName + " = TensorAccessorArgs<" + ctaArg + ", " + crtaArg + ">();"; - rewriter.create(op.getLoc(), code); + emitc::VerbatimOp::create(rewriter, op.getLoc(), code); // Create literal to reference the variable (pattern from // TTKernelClassMethodRewriter). auto resultType = this->getTypeConverter()->convertType(op->getResultTypes()[0]); auto literalOp = - rewriter.create(op.getLoc(), resultType, varName); + emitc::LiteralOp::create(rewriter, op.getLoc(), resultType, varName); rewriter.replaceOp(op, literalOp.getResult()); return success(); @@ -756,13 +760,13 @@ class TTKernelCreateFabricConnectionManagerOpRewriter mlir::Type lvalueType = emitc::LValueType::get(opaqueStructType); // Declare the struct variable - auto varOp = rewriter.create( - op->getLoc(), lvalueType, + auto varOp = emitc::VariableOp::create( + rewriter, op->getLoc(), lvalueType, emitc::OpaqueAttr::get(op.getContext(), "")); // Load the value from the lvalue variable - auto loadOp = - rewriter.create(op->getLoc(), opaqueStructType, varOp); + auto loadOp = emitc::LoadOp::create(rewriter, op->getLoc(), + opaqueStructType, varOp); // Replace the original operation with the loaded value so it can be used. rewriter.replaceOp(op, loadOp.getResult()); @@ -851,12 +855,12 @@ class TTKernelClassMethodRewriter : public OpConversionPattern { } callStr += ");"; - rewriter.create( - op->getLoc(), rewriter.getStringAttr(callStr), operands); + emitc::VerbatimOp::create(rewriter, op->getLoc(), + rewriter.getStringAttr(callStr), operands); // create a literal referencing the temp variable to be used later. 
auto literalOp = - rewriter.create(op->getLoc(), resultTypes, varName); + emitc::LiteralOp::create(rewriter, op->getLoc(), resultTypes, varName); rewriter.replaceOp(op, literalOp.getResult()); @@ -960,9 +964,9 @@ class TTKernelScalarUnaryTileOpRewriter : public OpConversionPattern { // Note that apparently "{{" produces "{" but "}" is not escaped in EmitC. std::string code = "{{ volatile int32_t __s = {}; " + getOpName(op).str() + "({}, __s); }"; - rewriter.create(op->getLoc(), - rewriter.getStringAttr(code), - ValueRange{scalarParam, dstIndex}); + emitc::VerbatimOp::create(rewriter, op->getLoc(), + rewriter.getStringAttr(code), + ValueRange{scalarParam, dstIndex}); rewriter.eraseOp(op); return success(); @@ -980,8 +984,8 @@ class TTKernelToEmitCPackReconfigL1AccToEmitCRewriter matchAndRewrite(ttkernel::PackReconfigL1AccOp op, ttkernel::PackReconfigL1AccOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { - rewriter.create( - op->getLoc(), + emitc::VerbatimOp::create( + rewriter, op->getLoc(), rewriter.getStringAttr("PACK((llk_pack_reconfig_l1_acc({})));"), ValueRange{adaptor.getL1AccEn()}); rewriter.eraseOp(op); diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index ac62d139da4..00897e82a1d 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -973,8 +973,8 @@ class MaxPool2dWithIndicesOpConversionPattern using ReturnTy = std::vector<::ttnn::Tensor>; - auto maxPool2dWithIndicesOp = rewriter.create( - srcOp.getLoc(), + auto maxPool2dWithIndicesOp = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), rewriter.getType(ttnn_to_emitc::TypeNameV), convertOpName(srcOp), rewriter.getArrayAttr(args), /*template_args=*/nullptr, adaptor.getOperands()); @@ -983,8 +983,8 @@ class MaxPool2dWithIndicesOpConversionPattern for (unsigned i = 0; i < srcOp.getNumResults(); ++i) { // Create index to access i-th element. 
auto indexType = rewriter.getIndexType(); - auto indexOp = rewriter.create( - srcOp.getLoc(), indexType, std::to_string(i)); + auto indexOp = emitc::LiteralOp::create(rewriter, srcOp.getLoc(), + indexType, std::to_string(i)); Value indexVal = indexOp.getResult(); // Create LValue type for the tensor reference. @@ -993,13 +993,13 @@ class MaxPool2dWithIndicesOpConversionPattern ttnn_to_emitc::TypeNameV)); // Get reference to the i-th element in the result vector. - auto subscriptOp = rewriter.create( - srcOp.getLoc(), lvalueType, maxPool2dWithIndicesOp.getResult(0), - indexVal); + auto subscriptOp = emitc::SubscriptOp::create( + rewriter, srcOp.getLoc(), lvalueType, + maxPool2dWithIndicesOp.getResult(0), indexVal); // Load the actual tensor value from the reference. - auto loadOp = rewriter.create( - srcOp.getLoc(), + auto loadOp = emitc::LoadOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get( rewriter.getContext(), ttnn_to_emitc::TypeNameV), @@ -2402,8 +2402,8 @@ class GetTupleElementOpConversionPattern // SubscriptOp requires a Value object as index, which is created by // invoking the emitc::LiteralOp. 
// - Value indexAsVal = rewriter.create( - getTupleElementOp->getLoc(), rewriter.getIndexType(), + Value indexAsVal = emitc::LiteralOp::create( + rewriter, getTupleElementOp->getLoc(), rewriter.getIndexType(), std::to_string(adaptor.getIndex())); // SubscriptOp also returns an emitc::LValueType, so we wrap the @@ -2412,9 +2412,9 @@ class GetTupleElementOpConversionPattern emitc::LValueType lvalueReturnType = emitc::LValueType::get(emitc::OpaqueType::get( rewriter.getContext(), ttnn_to_emitc::TypeNameV<::ttnn::Tensor>)); - Value subscript = rewriter.create( - getTupleElementOp->getLoc(), lvalueReturnType, adaptor.getOperand(), - indexAsVal); + Value subscript = emitc::SubscriptOp::create( + rewriter, getTupleElementOp->getLoc(), lvalueReturnType, + adaptor.getOperand(), indexAsVal); // As SubscriptOp returns an LValueType, we need to convert it to an // OpaqueType - this is done by invoking the emitc::LoadOp. @@ -2495,8 +2495,9 @@ class LoadCachedOpConversionPattern rewriter.setInsertionPoint(funcOp); // Create the global variable using EmitC's GlobalOp - rewriter.create( - srcOp.getLoc(), StringAttr::get(rewriter.getContext(), globalVarName), + emitc::GlobalOp::create( + rewriter, srcOp.getLoc(), + StringAttr::get(rewriter.getContext(), globalVarName), TypeAttr::get(tupleType), /*initialValue=*/nullptr, /*extern_specifier=*/UnitAttr(), @@ -2519,40 +2520,41 @@ class LoadCachedOpConversionPattern ":std::vector<::ttnn::Tensor>)>"); auto addressAttr = emitc::OpaqueAttr::get(rewriter.getContext(), "&" + callee.str()); - auto funcPtrValue = rewriter.create( - srcOp.getLoc(), funcPtrType, addressAttr); + auto funcPtrValue = emitc::ConstantOp::create(rewriter, srcOp.getLoc(), + funcPtrType, addressAttr); - auto tupleOp = rewriter.create( - srcOp.getLoc(), tupleType, + auto tupleOp = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), tupleType, mlir::tt::ttnn_to_emitc::kCreateVectorFunctionName, nullptr, nullptr, adaptor.getInputs()); Value tupleValue = 
tupleOp.getResult(0); // Get a reference to the global variable using GetGlobalOp - auto globalVar = rewriter.create( - srcOp.getLoc(), emitc::LValueType::get(tupleType), globalSym); + auto globalVar = emitc::GetGlobalOp::create( + rewriter, srcOp.getLoc(), emitc::LValueType::get(tupleType), globalSym); // Create a pointer type for the output parameter auto ptrType = emitc::PointerType::get(rewriter.getContext(), tupleType); // Get the address of the global variable - auto addressOfOp = rewriter.create(srcOp.getLoc(), ptrType, - "&", globalVar); + auto addressOfOp = emitc::ApplyOp::create(rewriter, srcOp.getLoc(), ptrType, + "&", globalVar); // Call the wrapper function with the pointer if (isZeroArgWrapper) { - rewriter.create( - srcOp.getLoc(), TypeRange{}, "ttnn::constEvalFuncWrapperZeroArg", - ValueRange{funcPtrValue, addressOfOp}, ArrayAttr{}); + emitc::CallOpaqueOp::create(rewriter, srcOp.getLoc(), TypeRange{}, + "ttnn::constEvalFuncWrapperZeroArg", + ValueRange{funcPtrValue, addressOfOp}, + ArrayAttr{}); } else { - rewriter.create( - srcOp.getLoc(), TypeRange{}, "ttnn::constEvalFuncWrapper", + emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), TypeRange{}, "ttnn::constEvalFuncWrapper", ValueRange{funcPtrValue, tupleValue, addressOfOp}, ArrayAttr{}); } // Load the value from the global variable auto resultVar = - rewriter.create(srcOp.getLoc(), tupleType, globalVar); + emitc::LoadOp::create(rewriter, srcOp.getLoc(), tupleType, globalVar); // Unpack the tuple result - extract each element from the tuple SmallVector results; @@ -2560,8 +2562,8 @@ class LoadCachedOpConversionPattern for (unsigned i = 0; i < srcOp.getNumResults(); ++i) { // Create index value auto indexType = rewriter.getIndexType(); - auto indexOp = rewriter.create( - srcOp.getLoc(), indexType, std::to_string(i)); + auto indexOp = emitc::LiteralOp::create(rewriter, srcOp.getLoc(), + indexType, std::to_string(i)); Value indexVal = indexOp.getResult(); // Create LValue type for the tensor 
reference @@ -2570,12 +2572,13 @@ class LoadCachedOpConversionPattern // Get reference to the i-th element in the static cache result // Use the variable that references our global result - auto subscriptOp = rewriter.create( - srcOp.getLoc(), lvalueType, resultVar.getResult(), indexVal); + auto subscriptOp = + emitc::SubscriptOp::create(rewriter, srcOp.getLoc(), lvalueType, + resultVar.getResult(), indexVal); // Load the actual tensor value from the reference - auto loadOp = rewriter.create( - srcOp.getLoc(), + auto loadOp = emitc::LoadOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), "::ttnn::Tensor"), subscriptOp.getResult()); results.push_back(loadOp.getResult()); @@ -2657,8 +2660,8 @@ class MeshShardOpConversionPattern mlir::tt::ttnn::MeshShardOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.create( - srcOp.getLoc(), + emitc::VerbatimOp::create( + rewriter, srcOp.getLoc(), "assert(0 && \"Mesh shard operation is " "not supported in emitc yet.\"); // ::ttnn::mesh_shard"); ttnn_to_emitc::EmitCTTNNEmitter emitter( @@ -2736,8 +2739,8 @@ class DistributeTensorOpConversionPattern rewriter.getContext(), kDistributedNs + "MeshMapperConfig"); auto mapperConfigOpaqueAttr = rewriter.getAttr(configCode); - auto mapperConfigOp = rewriter.create( - srcOp.getLoc(), mapperConfigType, mapperConfigOpaqueAttr); + auto mapperConfigOp = emitc::ConstantOp::create( + rewriter, srcOp.getLoc(), mapperConfigType, mapperConfigOpaqueAttr); auto meshDeviceRef = emitter.dereferenceToRef( adaptor.getMeshDevice(), kDistributedNs + "MeshDevice&"); @@ -2745,8 +2748,8 @@ class DistributeTensorOpConversionPattern auto meshMapperType = emitc::OpaqueType::get( rewriter.getContext(), "::std::unique_ptr<" + kDistributedNs + "TensorToMesh>"); - auto createMapperOp = rewriter.create( - srcOp.getLoc(), meshMapperType, + auto createMapperOp = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), meshMapperType, 
rewriter.getStringAttr(kDistributedNs + "create_mesh_mapper"), nullptr, nullptr, llvm::SmallVector{meshDeviceRef, @@ -2818,8 +2821,8 @@ class AggregateTensorOpConversionPattern rewriter.getContext(), kDistributedNs + "MeshComposerConfig"); auto composerConfigOpaqueAttr = rewriter.getAttr(configCode); - auto composerConfigOp = rewriter.create( - srcOp.getLoc(), composerConfigType, composerConfigOpaqueAttr); + auto composerConfigOp = emitc::ConstantOp::create( + rewriter, srcOp.getLoc(), composerConfigType, composerConfigOpaqueAttr); auto meshDeviceRef = emitter.dereferenceToRef( adaptor.getMeshDevice(), kDistributedNs + "MeshDevice&"); @@ -2827,8 +2830,8 @@ class AggregateTensorOpConversionPattern auto meshComposerType = emitc::OpaqueType::get( rewriter.getContext(), "::std::unique_ptr<" + kDistributedNs + "MeshToTensor>"); - auto createComposerOp = rewriter.create( - srcOp.getLoc(), meshComposerType, + auto createComposerOp = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), meshComposerType, rewriter.getStringAttr(kDistributedNs + "create_mesh_composer"), nullptr, nullptr, llvm::SmallVector{meshDeviceRef, @@ -3024,15 +3027,15 @@ class SliceStaticOpConversionPattern // Create SmallVector variable for begins auto beginsAttr = emitter.emit<::ttsl::SmallVector>(srcOp.getBegins()); - auto beginsVar = rewriter.create( - srcOp.getLoc(), + auto beginsVar = emitc::ConstantOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), "::ttsl::SmallVector"), beginsAttr); // Create span from SmallVector variable using CallOpaqueOp - auto beginsSpanVar = rewriter.create( - srcOp.getLoc(), + auto beginsSpanVar = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), "::ttsl::Span"), "ttsl::make_const_span", mlir::ArrayAttr{}, nullptr, @@ -3040,15 +3043,15 @@ class SliceStaticOpConversionPattern // Create SmallVector variable for ends auto endsAttr = emitter.emit<::ttsl::SmallVector>(srcOp.getEnds()); 
- auto endsVar = rewriter.create( - srcOp.getLoc(), + auto endsVar = emitc::ConstantOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), "::ttsl::SmallVector"), endsAttr); // Create span from SmallVector variable using CallOpaqueOp - auto endsSpanVar = rewriter.create( - srcOp.getLoc(), + auto endsSpanVar = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), "::ttsl::Span"), "ttsl::make_const_span", mlir::ArrayAttr{}, nullptr, @@ -3056,15 +3059,15 @@ class SliceStaticOpConversionPattern // Create SmallVector variable for step auto stepAttr = emitter.emit<::ttsl::SmallVector>(srcOp.getStep()); - auto stepVar = rewriter.create( - srcOp.getLoc(), + auto stepVar = emitc::ConstantOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), "::ttsl::SmallVector"), stepAttr); // Create span from SmallVector variable using CallOpaqueOp - auto stepSpanVar = rewriter.create( - srcOp.getLoc(), + auto stepSpanVar = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), "::ttsl::Span"), "ttsl::make_const_span", mlir::ArrayAttr{}, nullptr, @@ -3261,8 +3264,8 @@ class BatchNormTrainingOpConversionPattern auto resultType = this->getTypeConverter()->convertType(srcOp.getResult().getType()); - auto callOp = rewriter.create( - srcOp.getLoc(), resultType, "ttnn::batch_norm", + auto callOp = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), resultType, "ttnn::batch_norm", rewriter.getArrayAttr(args), /*template_args=*/nullptr, adaptor.getOperands()); @@ -3550,8 +3553,8 @@ class PointToPointOpConversionPattern mlir::tt::ttnn::PointToPointOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.create( - srcOp.getLoc(), + emitc::VerbatimOp::create( + rewriter, srcOp.getLoc(), "assert(0 && \"PointToPoint operation is " "not supported in emitc yet.\"); // ::ttnn::PointToPoint"); 
ttnn_to_emitc::EmitCTTNNEmitter emitter( @@ -3600,21 +3603,23 @@ class AllToAllDispatchOpConversionPattern static constexpr llvm::StringLiteral kReturnTypeName = "::std::array<::ttnn::Tensor, 2>"; static constexpr llvm::StringLiteral kElemTypeName = "::ttnn::Tensor"; - auto callOp = rewriter.create( - srcOp.getLoc(), rewriter.getType(kReturnTypeName), + auto callOp = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), + rewriter.getType(kReturnTypeName), this->convertOpName(srcOp), rewriter.getArrayAttr(args), /*template_args=*/nullptr, adaptor.getOperands()); SmallVector results; for (unsigned i = 0; i < srcOp.getNumResults(); ++i) { - auto indexOp = rewriter.create( - srcOp.getLoc(), rewriter.getIndexType(), std::to_string(i)); + auto indexOp = emitc::LiteralOp::create( + rewriter, srcOp.getLoc(), rewriter.getIndexType(), std::to_string(i)); auto lvalueType = emitc::LValueType::get( emitc::OpaqueType::get(rewriter.getContext(), kElemTypeName)); - auto subscriptOp = rewriter.create( - srcOp.getLoc(), lvalueType, callOp.getResult(0), indexOp.getResult()); - auto loadOp = rewriter.create( - srcOp.getLoc(), + auto subscriptOp = + emitc::SubscriptOp::create(rewriter, srcOp.getLoc(), lvalueType, + callOp.getResult(0), indexOp.getResult()); + auto loadOp = emitc::LoadOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get(rewriter.getContext(), kElemTypeName), subscriptOp.getResult()); results.push_back(loadOp.getResult()); @@ -3685,23 +3690,24 @@ class MoeExpertTokenRemapOpConversionPattern // Multi-result: returns std::vector with 2 elements. 
using ReturnTy = std::vector<::ttnn::Tensor>; - auto callOp = rewriter.create( - srcOp.getLoc(), + auto callOp = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), rewriter.getType(ttnn_to_emitc::TypeNameV), this->convertOpName(srcOp), rewriter.getArrayAttr(args), /*template_args=*/nullptr, adaptor.getOperands()); SmallVector results; for (unsigned i = 0; i < srcOp.getNumResults(); ++i) { - auto indexOp = rewriter.create( - srcOp.getLoc(), rewriter.getIndexType(), std::to_string(i)); + auto indexOp = emitc::LiteralOp::create( + rewriter, srcOp.getLoc(), rewriter.getIndexType(), std::to_string(i)); auto lvalueType = emitc::LValueType::get(emitc::OpaqueType::get( rewriter.getContext(), ttnn_to_emitc::TypeNameV)); - auto subscriptOp = rewriter.create( - srcOp.getLoc(), lvalueType, callOp.getResult(0), indexOp.getResult()); - auto loadOp = rewriter.create( - srcOp.getLoc(), + auto subscriptOp = + emitc::SubscriptOp::create(rewriter, srcOp.getLoc(), lvalueType, + callOp.getResult(0), indexOp.getResult()); + auto loadOp = emitc::LoadOp::create( + rewriter, srcOp.getLoc(), emitc::OpaqueType::get( rewriter.getContext(), ttnn_to_emitc::TypeNameV), @@ -3964,8 +3970,8 @@ class NLPCreateQKVHeadsDecodeOpConversionPattern operands.push_back(adaptor.getBatchOffset()); } - auto nlpCreateQKVHeadsDecodeOp = rewriter.create( - srcOp.getLoc(), + auto nlpCreateQKVHeadsDecodeOp = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), rewriter.getType( ttnn_to_emitc::TypeNameV), convertOpName(srcOp), rewriter.getArrayAttr(args), @@ -3973,8 +3979,8 @@ class NLPCreateQKVHeadsDecodeOpConversionPattern llvm::SmallVector results; for (std::size_t i = 0; i < srcOp.getNumResults(); ++i) { - auto tupleGetResult = rewriter.create( - srcOp.getLoc(), + auto tupleGetResult = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), rewriter.getType( ttnn_to_emitc::TypeNameV<::ttnn::Tensor>), "::std::get", /*args=*/nullptr, @@ -4030,18 +4036,17 @@ class 
SplitQueryKeyValueAndSplitHeadsOpConversionPattern using OpReturnType = std::tuple<::ttnn::Tensor, ::ttnn::Tensor, ::ttnn::Tensor>; - auto splitQueryKeyValueAndSplitHeadsOp = - rewriter.create( - srcOp.getLoc(), - rewriter.getType( - ttnn_to_emitc::TypeNameV), - convertOpName(srcOp), rewriter.getArrayAttr(args), - /*template_args=*/nullptr, adaptor.getOperands()); + auto splitQueryKeyValueAndSplitHeadsOp = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), + rewriter.getType( + ttnn_to_emitc::TypeNameV), + convertOpName(srcOp), rewriter.getArrayAttr(args), + /*template_args=*/nullptr, adaptor.getOperands()); llvm::SmallVector results; for (std::size_t i = 0; i < srcOp.getNumResults(); ++i) { - auto tupleGetResult = rewriter.create( - srcOp.getLoc(), + auto tupleGetResult = emitc::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), rewriter.getType( ttnn_to_emitc::TypeNameV<::ttnn::Tensor>), "::std::get", /*args=*/nullptr, @@ -4223,13 +4228,13 @@ class CaptureOrExecuteTraceOpConversionPattern rewriter.setInsertionPoint(&*insertionPointOp); // Define global variable to track if this is the first call. - auto isFirstCall = rewriter.create( - loc, /*sym_name=*/"is_first_call", rewriter.getI1Type(), + auto isFirstCall = emitc::GlobalOp::create( + rewriter, loc, /*sym_name=*/"is_first_call", rewriter.getI1Type(), rewriter.getBoolAttr(true)); // Define global variable to track the trace id. - auto traceId = rewriter.create( - loc, /*sym_name=*/"trace_id", + auto traceId = emitc::GlobalOp::create( + rewriter, loc, /*sym_name=*/"trace_id", getTypeConverter()->convertType( mlir::tt::ttnn::utils::getTraceIdType(srcOp.getContext())), /*initial_value=*/nullptr); @@ -4237,8 +4242,8 @@ class CaptureOrExecuteTraceOpConversionPattern // Define global variables to track the inputs. 
llvm::SmallVector traceInputVariable; for (auto [i, input] : llvm::enumerate(adaptor.getInputs())) { - traceInputVariable.emplace_back(rewriter.create( - loc, /*sym_name=*/"input_" + std::to_string(i), + traceInputVariable.emplace_back(emitc::GlobalOp::create( + rewriter, loc, /*sym_name=*/"input_" + std::to_string(i), getTypeConverter()->convertType(input.getType()), /*initial_value=*/nullptr)); } @@ -4246,8 +4251,8 @@ class CaptureOrExecuteTraceOpConversionPattern // Define global variables to track the outputs. llvm::SmallVector traceOutputVariable; for (auto [i, output] : llvm::enumerate(srcOp.getResults())) { - traceOutputVariable.emplace_back(rewriter.create( - loc, /*sym_name=*/"output_" + std::to_string(i), + traceOutputVariable.emplace_back(emitc::GlobalOp::create( + rewriter, loc, /*sym_name=*/"output_" + std::to_string(i), getTypeConverter()->convertType(output.getType()), /*initial_value=*/nullptr)); } @@ -4266,7 +4271,7 @@ class CaptureOrExecuteTraceOpConversionPattern auto funcType = rewriter.getFunctionType(inputTypes, resultTypes); rewriter.setInsertionPoint(srcOp->getParentOfType()); - auto funcOp = rewriter.create(loc, funcName, funcType); + auto funcOp = emitc::FuncOp::create(rewriter, loc, funcName, funcType); auto *block = funcOp.addEntryBlock(); @@ -4275,15 +4280,16 @@ class CaptureOrExecuteTraceOpConversionPattern // Define local variables that represent the return values. 
llvm::SmallVector returnVariable = llvm::map_to_vector(resultTypes, [&](auto resultType) { - return rewriter.create( - loc, emitc::LValueType::get(ttnnTensorType), + return emitc::VariableOp::create( + rewriter, loc, emitc::LValueType::get(ttnnTensorType), emitc::OpaqueAttr::get(rewriter.getContext(), "::ttnn::Tensor()")); }); // Create if statement with then/else blocks - auto ifOp = rewriter.create( - loc, loadGlobalVariable(rewriter, srcOp.getLoc(), isFirstCall), + auto ifOp = emitc::IfOp::create( + rewriter, loc, + loadGlobalVariable(rewriter, srcOp.getLoc(), isFirstCall), /*add_then_block=*/true, /*add_else_block=*/true); @@ -4293,25 +4299,25 @@ class CaptureOrExecuteTraceOpConversionPattern rewriter.setInsertionPointToStart(&thenBlock); // is_first_call = false; - auto falseC = rewriter.create( - loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); - rewriter.create( - loc, getGlobalVariable(rewriter, loc, isFirstCall), falseC); + auto falseC = mlir::arith::ConstantOp::create( + rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); + emitc::AssignOp::create( + rewriter, loc, getGlobalVariable(rewriter, loc, isFirstCall), falseC); // v = capture_callee(args...) // First block argument is the device, so we drop it. 
auto autoTy = emitc::OpaqueType::get(ctx, "auto"); - auto captureTuple = rewriter.create( - loc, autoTy, srcOp.getCaptureCallee(), nullptr, nullptr, + auto captureTuple = emitc::CallOpaqueOp::create( + rewriter, loc, autoTy, srcOp.getCaptureCallee(), nullptr, nullptr, block->getArguments().drop_front()); // trace_id = std::get<0>(v); - auto getTraceId = rewriter.create( - loc, traceId.getType(), "::std::get<0>", nullptr, nullptr, + auto getTraceId = emitc::CallOpaqueOp::create( + rewriter, loc, traceId.getType(), "::std::get<0>", nullptr, nullptr, captureTuple.getResult(0)); - rewriter.create( - loc, getGlobalVariable(rewriter, loc, traceId), - getTraceId.getResult(0)); + emitc::AssignOp::create(rewriter, loc, + getGlobalVariable(rewriter, loc, traceId), + getTraceId.getResult(0)); // local_output_0 = std::get(v); // ... @@ -4320,11 +4326,11 @@ class CaptureOrExecuteTraceOpConversionPattern for (size_t i = 0; i < returnVariable.size(); ++i) { std::string getName = "::std::get<" + std::to_string(outputBaseIndex + i) + ">"; - auto getResult = rewriter.create( - loc, ttnnTensorType, getName, nullptr, nullptr, + auto getResult = emitc::CallOpaqueOp::create( + rewriter, loc, ttnnTensorType, getName, nullptr, nullptr, captureTuple.getResult(0)); - rewriter.create(loc, returnVariable[i].getResult(), - getResult.getResult(0)); + emitc::AssignOp::create(rewriter, loc, returnVariable[i].getResult(), + getResult.getResult(0)); } // input_i = std::get(v) @@ -4332,11 +4338,12 @@ class CaptureOrExecuteTraceOpConversionPattern for (size_t i = 0; i < traceInputVariable.size(); ++i) { std::string getName = "::std::get<" + std::to_string(inputBaseIndex + i) + ">"; - auto getResult = rewriter.create( - loc, traceInputVariable[i].getType(), getName, nullptr, nullptr, - ValueRange{captureTuple.getResult(0)}); - rewriter.create( - loc, getGlobalVariable(rewriter, loc, traceInputVariable[i]), + auto getResult = emitc::CallOpaqueOp::create( + rewriter, loc, 
traceInputVariable[i].getType(), getName, nullptr, + nullptr, ValueRange{captureTuple.getResult(0)}); + emitc::AssignOp::create( + rewriter, loc, + getGlobalVariable(rewriter, loc, traceInputVariable[i]), getResult.getResult(0)); } @@ -4345,15 +4352,16 @@ class CaptureOrExecuteTraceOpConversionPattern for (size_t i = 0; i < traceOutputVariable.size(); ++i) { std::string getName = "::std::get<" + std::to_string(traceOutputBaseIndex + i) + ">"; - auto getResult = rewriter.create( - loc, traceOutputVariable[i].getType(), getName, nullptr, nullptr, - captureTuple.getResult(0)); - rewriter.create( - loc, getGlobalVariable(rewriter, loc, traceOutputVariable[i]), + auto getResult = emitc::CallOpaqueOp::create( + rewriter, loc, traceOutputVariable[i].getType(), getName, nullptr, + nullptr, captureTuple.getResult(0)); + emitc::AssignOp::create( + rewriter, loc, + getGlobalVariable(rewriter, loc, traceOutputVariable[i]), getResult.getResult(0)); } - rewriter.create(loc); + emitc::YieldOp::create(rewriter, loc); } // ELSE: execute path @@ -4362,17 +4370,17 @@ class CaptureOrExecuteTraceOpConversionPattern rewriter.setInsertionPointToStart(&elseBlock); // execute_callee(trace_id); - rewriter.create( - loc, TypeRange{}, srcOp.getExecuteCallee(), nullptr, nullptr, - loadGlobalVariable(rewriter, loc, traceId)); + emitc::CallOpaqueOp::create(rewriter, loc, TypeRange{}, + srcOp.getExecuteCallee(), nullptr, nullptr, + loadGlobalVariable(rewriter, loc, traceId)); // Load the result to the return variable. assert(returnVariable.size() == 1 && "expected one return variable"); - rewriter.create( - loc, returnVariable[0], + emitc::AssignOp::create( + rewriter, loc, returnVariable[0], loadGlobalVariable(rewriter, loc, traceOutputVariable[0])); - rewriter.create(loc); + emitc::YieldOp::create(rewriter, loc); } // After the if-then-else, load the result and return it from the function. 
@@ -4380,8 +4388,8 @@ class CaptureOrExecuteTraceOpConversionPattern assert(returnVariable.size() == 1 && "expected one return variable"); auto result = - rewriter.create(loc, ttnnTensorType, returnVariable[0]); - rewriter.create(loc, result.getResult()); + emitc::LoadOp::create(rewriter, loc, ttnnTensorType, returnVariable[0]); + emitc::ReturnOp::create(rewriter, loc, result.getResult()); // Replace the original operation with a call to our new function. auto resultType = @@ -4399,8 +4407,8 @@ class CaptureOrExecuteTraceOpConversionPattern mlir::Value getGlobalVariable(mlir::PatternRewriter &rewriter, mlir::Location loc, emitc::GlobalOp globalOp) const { - return rewriter.create( - loc, emitc::LValueType::get(globalOp.getType()), + return emitc::GetGlobalOp::create( + rewriter, loc, emitc::LValueType::get(globalOp.getType()), globalOp.getSymNameAttr()); } @@ -4408,7 +4416,8 @@ class CaptureOrExecuteTraceOpConversionPattern mlir::Location loc, emitc::GlobalOp globalOp) const { auto getGlobalOp = getGlobalVariable(rewriter, loc, globalOp); - return rewriter.create(loc, globalOp.getType(), getGlobalOp); + return emitc::LoadOp::create(rewriter, loc, globalOp.getType(), + getGlobalOp); } }; } // namespace @@ -4463,12 +4472,13 @@ class UpdateCacheOpConversionPattern // The `update_index` is modeled as a tensor in the IR, but the // `ttnn::update_cache` expects a `uint32_t` scalar. 
mlir::Value updateIndex = - rewriter - .create( - srcOp.getLoc(), rewriter.getI32Type(), - ttnn_to_emitc::kGetScalarFromTensorFunctionName, - /*args=*/nullptr, - /*template_args=*/nullptr, adaptor.getUpdateIndex()) + emitc::CallOpaqueOp::create( + rewriter, + + srcOp.getLoc(), rewriter.getI32Type(), + ttnn_to_emitc::kGetScalarFromTensorFunctionName, + /*args=*/nullptr, + /*template_args=*/nullptr, adaptor.getUpdateIndex()) .getResult(0); llvm::SmallVector args{ diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp index 805dd556f55..c5253340215 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp @@ -91,8 +91,8 @@ struct ConvertTTNNToEmitCPass // Include headers // - builder.create(module.getLoc(), "ttnn-precompiled.hpp", - /*isStandard=*/false); + emitc::IncludeOp::create(builder, module.getLoc(), "ttnn-precompiled.hpp", + /*isStandard=*/false); } // TTNN -> EmitC diff --git a/lib/Conversion/TTNNToEmitPy/EmitPyConstEvalCaching.cpp b/lib/Conversion/TTNNToEmitPy/EmitPyConstEvalCaching.cpp index 442e39ca0f6..d0083262505 100644 --- a/lib/Conversion/TTNNToEmitPy/EmitPyConstEvalCaching.cpp +++ b/lib/Conversion/TTNNToEmitPy/EmitPyConstEvalCaching.cpp @@ -145,9 +145,9 @@ class EmitPyConstEvalCaching // Create if-guard and move the caching ops into the if body. 
builder.setInsertionPoint(opsToGuardChain.front()); - auto ifOp = builder.create(funcOp.getLoc(), - builder.getStringAttr("not {}"), - ValueRange{cacheDict}); + auto ifOp = emitpy::IfOp::create(builder, funcOp.getLoc(), + builder.getStringAttr("not {}"), + ValueRange{cacheDict}); auto *ifBody = builder.createBlock(&ifOp.getThenRegion()); for (Operation *op : opsToGuardChain) { diff --git a/lib/Conversion/TTNNToEmitPy/EmitPyLinkModules.cpp b/lib/Conversion/TTNNToEmitPy/EmitPyLinkModules.cpp index 8029b638056..adedf3fb313 100644 --- a/lib/Conversion/TTNNToEmitPy/EmitPyLinkModules.cpp +++ b/lib/Conversion/TTNNToEmitPy/EmitPyLinkModules.cpp @@ -60,8 +60,8 @@ class EmitPyLinkModulesPass } builder.setInsertionPointToStart(&mainFile.getBodyRegion().front()); - builder.create( - mainFile.getLoc(), builder.getStringAttr(kConstevalFileName), + emitpy::ImportOp::create( + builder, mainFile.getLoc(), builder.getStringAttr(kConstevalFileName), /*module_alias=*/nullptr, /*members_to_import=*/builder.getArrayAttr(memberNames), /*member_aliases=*/builder.getArrayAttr(emptyAliases), diff --git a/lib/Conversion/TTNNToEmitPy/TTNNToEmitPy.cpp b/lib/Conversion/TTNNToEmitPy/TTNNToEmitPy.cpp index f6bf04f5391..6d030b98664 100644 --- a/lib/Conversion/TTNNToEmitPy/TTNNToEmitPy.cpp +++ b/lib/Conversion/TTNNToEmitPy/TTNNToEmitPy.cpp @@ -2512,13 +2512,14 @@ class UpdateCacheOpConversionPattern // The `update_index` is modeled as a tensor in the IR, but the // `ttnn.update_cache` expects a `int` scalar. 
- auto updateIndex = rewriter - .create( - srcOp.getLoc(), rewriter.getI32Type(), - ttnn_to_emitpy::kGetScalarFromTensorFunctionName, - adaptor.getUpdateIndex(), - /*args=*/nullptr, - /*keyword_args=*/nullptr) + auto updateIndex = emitpy::CallOpaqueOp::create( + rewriter, + + srcOp.getLoc(), rewriter.getI32Type(), + ttnn_to_emitpy::kGetScalarFromTensorFunctionName, + adaptor.getUpdateIndex(), + /*args=*/nullptr, + /*keyword_args=*/nullptr) .getResult(0); llvm::SmallVector args{ @@ -2590,8 +2591,8 @@ class GetTupleElementOpConversionPattern // Create an expression op to inline the subscript operation auto loc = getTupleElementOp->getLoc(); - auto exprOp = rewriter.create( - loc, resultType, ValueRange{adaptor.getOperand()}); + auto exprOp = emitpy::ExpressionOp::create( + rewriter, loc, resultType, ValueRange{adaptor.getOperand()}); // Setup the expression body { @@ -2605,15 +2606,16 @@ class GetTupleElementOpConversionPattern rewriter.setInsertionPointToStart(bodyBlock); // Create literal for the index - Value indexAsVal = rewriter.create( - loc, rewriter.getIndexType(), std::to_string(adaptor.getIndex())); + Value indexAsVal = + emitpy::LiteralOp::create(rewriter, loc, rewriter.getIndexType(), + std::to_string(adaptor.getIndex())); // Create subscript operation inside the expression - Value subscriptResult = rewriter.create( - loc, resultType, tupleArg, indexAsVal); + Value subscriptResult = emitpy::SubscriptOp::create( + rewriter, loc, resultType, tupleArg, indexAsVal); // Yield the result - rewriter.create(loc, subscriptResult); + emitpy::YieldOp::create(rewriter, loc, subscriptResult); } // Replace the original op with the expression op @@ -2675,9 +2677,9 @@ class LoadCachedOpConversionPattern // Pack inputs into a list if present. 
llvm::SmallVector callOperands; if (!adaptor.getInputs().empty()) { - auto inputList = rewriter.create( - loc, tensorListType, ttnn_to_emitpy::kCreateListFunctionName, - adaptor.getInputs()); + auto inputList = emitpy::CallOpaqueOp::create( + rewriter, loc, tensorListType, + ttnn_to_emitpy::kCreateListFunctionName, adaptor.getInputs()); callOperands.push_back(inputList.getResult(0)); } @@ -2695,18 +2697,18 @@ class LoadCachedOpConversionPattern // Call the const-eval function. Add discardable attribute to easily // identify that the result is a const-eval in the caching pass afterwards. - auto callOp = rewriter.create( - loc, tensorListType, calleeName.str(), callOperands); + auto callOp = emitpy::CallOpaqueOp::create(rewriter, loc, tensorListType, + calleeName.str(), callOperands); callOp->setDiscardableAttr(ttnn_to_emitpy::kConstEvaledAttr, rewriter.getUnitAttr()); // Subscript individual results from the returned tensor list. llvm::SmallVector results; for (unsigned i = 0; i < loadCachedOp.getNumResults(); ++i) { - auto index = rewriter.create( - loc, rewriter.getIndexType(), std::to_string(i)); - auto sub = rewriter.create( - loc, tensorType, callOp.getResult(0), index.getResult()); + auto index = emitpy::LiteralOp::create( + rewriter, loc, rewriter.getIndexType(), std::to_string(i)); + auto sub = emitpy::SubscriptOp::create( + rewriter, loc, tensorType, callOp.getResult(0), index.getResult()); results.push_back(sub.getResult()); } rewriter.replaceOp(loadCachedOp, results); @@ -2781,16 +2783,14 @@ static Value emitDictKey(ConversionPatternRewriter &rewriter, Location loc, Attribute keyAttr) { auto *ctx = rewriter.getContext(); if (auto strAttr = dyn_cast(keyAttr)) { - return rewriter - .create( - loc, emitpy::StringType::get(ctx), - emitpy::OpaqueAttr::get(ctx, "\"" + strAttr.str() + "\"")) + return emitpy::ConstantOp::create( + rewriter, loc, emitpy::StringType::get(ctx), + emitpy::OpaqueAttr::get(ctx, "\"" + strAttr.str() + "\"")) .getResult(); } auto intAttr 
= cast(keyAttr); - return rewriter - .create(loc, rewriter.getIndexType(), - std::to_string(intAttr.getInt())) + return emitpy::LiteralOp::create(rewriter, loc, rewriter.getIndexType(), + std::to_string(intAttr.getInt())) .getResult(); } @@ -2827,16 +2827,16 @@ class TTCoreSetKeyValueOpConversionPattern // Pack values to set into a list. auto tensorListType = emitpy::OpaqueType::get(ctx, "[ttnn.Tensor]"); - auto tensorListOp = rewriter.create( - loc, tensorListType, ttnn_to_emitpy::kCreateListFunctionName, + auto tensorListOp = emitpy::CallOpaqueOp::create( + rewriter, loc, tensorListType, ttnn_to_emitpy::kCreateListFunctionName, adaptor.getValues()); auto value = tensorListOp.getResult(0); SmallVector exprOperands = {adaptor.getDict(), key, value}; SmallVector exprOperandTypes = {adaptor.getDict().getType(), key.getType(), value.getType()}; - auto exprOp = rewriter.create( - loc, dummyExpressionResultType, exprOperands); + auto exprOp = emitpy::ExpressionOp::create( + rewriter, loc, dummyExpressionResultType, exprOperands); Block *expressionBodyBlock = rewriter.createBlock(&exprOp.getBody()); for (Type type : exprOperandTypes) { expressionBodyBlock->addArgument(type, loc); @@ -2847,14 +2847,15 @@ class TTCoreSetKeyValueOpConversionPattern auto keyArg = expressionBodyBlock->getArgument(1); auto valArg = expressionBodyBlock->getArgument(2); - auto subOp = rewriter.create(loc, valArg.getType(), - dictArg, keyArg); - rewriter.create(loc, subOp.getResult(), valArg); + auto subOp = emitpy::SubscriptOp::create(rewriter, loc, valArg.getType(), + dictArg, keyArg); + emitpy::AssignOp::create(rewriter, loc, subOp.getResult(), valArg); - auto dummyExpressionResultValue = rewriter.create( - loc, dummyExpressionResultType, emitpy::OpaqueAttr::get(ctx, "None")); - rewriter.create(loc, - dummyExpressionResultValue.getResult()); + auto dummyExpressionResultValue = + emitpy::ConstantOp::create(rewriter, loc, dummyExpressionResultType, + emitpy::OpaqueAttr::get(ctx, "None")); + 
emitpy::YieldOp::create(rewriter, loc, + dummyExpressionResultValue.getResult()); rewriter.eraseOp(setKVOp); return success(); @@ -2887,15 +2888,14 @@ class TTCoreGetKeyValueOpConversionPattern } llvm::SmallVector results; - auto value = rewriter - .create(loc, tensorListType, - adaptor.getDict(), key) + auto value = emitpy::SubscriptOp::create(rewriter, loc, tensorListType, + adaptor.getDict(), key) .getResult(); for (unsigned i = 0; i < getKVOp.getNumResults(); ++i) { - auto index = rewriter.create( - loc, rewriter.getIndexType(), std::to_string(i)); - auto sub = rewriter.create(loc, convertedTypes[i], - value, index.getResult()); + auto index = emitpy::LiteralOp::create( + rewriter, loc, rewriter.getIndexType(), std::to_string(i)); + auto sub = emitpy::SubscriptOp::create(rewriter, loc, convertedTypes[i], + value, index.getResult()); results.push_back(sub.getResult()); } rewriter.replaceOp(getKVOp, results); @@ -3329,8 +3329,8 @@ class DistributeTensorOpConversionPattern configKeywordArgs.push_back(rewriter.getStringAttr("")); } - auto mapperConfigOp = rewriter.create( - srcOp.getLoc(), + auto mapperConfigOp = emitpy::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), emitpy::OpaqueType::get(rewriter.getContext(), "ttnn.MeshMapperConfig"), "ttnn.MeshMapperConfig", llvm::SmallVector{}, rewriter.getArrayAttr(configArgs), @@ -3338,8 +3338,8 @@ class DistributeTensorOpConversionPattern auto meshMapperType = emitpy::OpaqueType::get(rewriter.getContext(), "ttnn.TensorToMesh"); - auto createMapperOp = rewriter.create( - srcOp.getLoc(), meshMapperType, "ttnn.create_mesh_mapper", + auto createMapperOp = emitpy::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), meshMapperType, "ttnn.create_mesh_mapper", llvm::SmallVector{adaptor.getMeshDevice(), mapperConfigOp.getResult(0)}); @@ -3396,8 +3396,8 @@ class AggregateTensorOpConversionPattern configKeywordArgs.push_back(rewriter.getStringAttr("")); } - auto composerConfigOp = rewriter.create( - srcOp.getLoc(), + auto 
composerConfigOp = emitpy::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), emitpy::OpaqueType::get(rewriter.getContext(), "ttnn.MeshComposerConfig"), "ttnn.MeshComposerConfig", llvm::SmallVector{}, @@ -3406,8 +3406,8 @@ class AggregateTensorOpConversionPattern auto meshComposerType = emitpy::OpaqueType::get(rewriter.getContext(), "ttnn.MeshToTensor"); - auto createComposerOp = rewriter.create( - srcOp.getLoc(), meshComposerType, "ttnn.create_mesh_composer", + auto createComposerOp = emitpy::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), meshComposerType, "ttnn.create_mesh_composer", llvm::SmallVector{adaptor.getMeshDevice(), composerConfigOp.getResult(0)}); @@ -3723,8 +3723,8 @@ class DistributedRMSNormOpConversionPattern auto opaqueType = emitpy::OpaqueType::get(rewriter.getContext(), "ttnn.Tensor"); - auto globalSemaphoreOp = rewriter.create( - srcOp.getLoc(), opaqueType, "utils.create_global_semaphore", + auto globalSemaphoreOp = emitpy::CallOpaqueOp::create( + rewriter, srcOp.getLoc(), opaqueType, "utils.create_global_semaphore", llvm::SmallVector{adaptor.getInput()}); llvm::SmallVector args{ diff --git a/lib/Conversion/TTNNToEmitPy/TTNNToEmitPyPass.cpp b/lib/Conversion/TTNNToEmitPy/TTNNToEmitPyPass.cpp index 7278504a164..01bff890bff 100644 --- a/lib/Conversion/TTNNToEmitPy/TTNNToEmitPyPass.cpp +++ b/lib/Conversion/TTNNToEmitPy/TTNNToEmitPyPass.cpp @@ -62,9 +62,9 @@ void enableTorchConversion(func::FuncOp funcOp) { for (BlockArgument arg : funcOp.getArguments()) { // Create ttnn.to_torch call. // - auto toTorchOp = builder.create( - funcOp.getLoc(), arg.getType(), "ttnn.to_torch", ValueRange{arg}, - nullptr, nullptr); + auto toTorchOp = emitpy::CallOpaqueOp::create( + builder, funcOp.getLoc(), arg.getType(), "ttnn.to_torch", + ValueRange{arg}, nullptr, nullptr); // Replace all uses of the original argument with the to_torch result, // except for the to_torch op itself. 
@@ -81,8 +81,8 @@ void enableTorchConversion(func::FuncOp funcOp) { for (Value returnValue : returnOp.getOperands()) { // Create ttnn.from_torch call. // - auto fromTorchOp = builder.create( - returnOp.getLoc(), returnValue.getType(), "ttnn.from_torch", + auto fromTorchOp = emitpy::CallOpaqueOp::create( + builder, returnOp.getLoc(), returnValue.getType(), "ttnn.from_torch", ValueRange{returnValue}, nullptr, nullptr); newReturnOperands.push_back(fromTorchOp.getResult(0)); @@ -124,10 +124,10 @@ struct ConvertTTNNToEmitPyPass // Include headers // - builder.create(module->getLoc(), "ttnn", nullptr, nullptr, - nullptr, nullptr); - builder.create(module->getLoc(), "utils", nullptr, - nullptr, nullptr, nullptr); + emitpy::ImportOp::create(builder, module->getLoc(), "ttnn", nullptr, + nullptr, nullptr, nullptr); + emitpy::ImportOp::create(builder, module->getLoc(), "utils", nullptr, + nullptr, nullptr, nullptr); // If we are in the module-export path (i.e., `target-module=true`), // const-eval functions must also take `device` as an explicit argument so diff --git a/lib/Dialect/D2M/IR/D2MGenericRegionOps.cpp b/lib/Dialect/D2M/IR/D2MGenericRegionOps.cpp index bbbd6ec2df5..60b8bcd7c1d 100644 --- a/lib/Dialect/D2M/IR/D2MGenericRegionOps.cpp +++ b/lib/Dialect/D2M/IR/D2MGenericRegionOps.cpp @@ -32,8 +32,8 @@ bufferizeCBOp(OpTy op, mlir::RewriterBase &rewriter, mlir::cast(op.getCbType()) .getBufferType(options, [&]() { return op.emitOpError(); }); assert(succeeded(cbBufferType)); - auto toBuffer = rewriter.create( - op.getLoc(), *cbBufferType, op.getCb()); + auto toBuffer = bufferization::ToBufferOp::create(rewriter, op.getLoc(), + *cbBufferType, op.getCb()); mlir::bufferization::replaceOpWithNewBufferizedOp(rewriter, op, toBuffer.getResult()); return mlir::success(); @@ -974,9 +974,9 @@ mlir::LogicalResult ArangeBlockOp::bufferize( } // Create new op with memref operands. 
- auto newOp = rewriter.create( - getLoc(), *maybeIndexTileBuffer, *maybeOutputBuffer, getNumElements(), - getStart(), getStep()); + auto newOp = ArangeBlockOp::create(rewriter, getLoc(), *maybeIndexTileBuffer, + *maybeOutputBuffer, getNumElements(), + getStart(), getStep()); // Replace uses and erase (DPS pattern - result aliases output buffer). mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(), @@ -1081,13 +1081,13 @@ mlir::LogicalResult RemoteLoadOp::bufferize( RemoteLoadOp newOp; if (isHighLevelMcast()) { // High-level mcast form: use mcastDims builder - newOp = rewriter.create(getLoc(), resultBufferType, - *localBufferBuffer, *memrefBuffer, - getIndices(), getMcastDims()); + newOp = RemoteLoadOp::create(rewriter, getLoc(), resultBufferType, + *localBufferBuffer, *memrefBuffer, + getIndices(), getMcastDims()); } else { // Low-level mcast form or no mcast: use mcastStartIndex/mcastShape builder - newOp = rewriter.create( - getLoc(), resultBufferType, *localBufferBuffer, *memrefBuffer, + newOp = RemoteLoadOp::create( + rewriter, getLoc(), resultBufferType, *localBufferBuffer, *memrefBuffer, getIndices(), getMcastStartIndex(), getMcastShape()); } @@ -1095,8 +1095,8 @@ mlir::LogicalResult RemoteLoadOp::bufferize( // ops. This ensures that operations like linalg.generic still see tensors // until they are bufferized. When they call getBuffer() during bufferization, // they'll get the underlying memref (*localBufferBuffer). 
- auto toTensor = rewriter.create( - getLoc(), result.getType(), *localBufferBuffer); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, getLoc(), result.getType(), *localBufferBuffer); rewriter.replaceAllUsesWith(result, toTensor.getResult()); rewriter.eraseOp(*this); @@ -1405,8 +1405,8 @@ mlir::LogicalResult TileTilizeBlockOp::bufferize( out = *maybe; } - rewriter.create(getLoc(), out.getType(), in, - out); + mlir::tt::d2m::TileTilizeBlockOp::create(rewriter, getLoc(), out.getType(), + in, out); // DPS-style op: replace uses of result with the output buffer, not the new // op's result. This ensures downstream ops correctly use the original buffer // allocation. @@ -1514,8 +1514,8 @@ mlir::LogicalResult TileUntilizeBlockOp::bufferize( out = *maybe; } - rewriter.create(getLoc(), out.getType(), - in, out); + mlir::tt::d2m::TileUntilizeBlockOp::create(rewriter, getLoc(), out.getType(), + in, out); // DPS-style op: replace uses of result with the output buffer, not the new // op's result. This ensures downstream ops correctly use the original buffer // allocation. 
@@ -1696,9 +1696,9 @@ BlockMaskOp::bufferize(mlir::RewriterBase &rewriter, colMaskCb = *maybe; } - rewriter.create( - getLoc(), out.getType(), in, out, rowMaskCb, colMaskCb, getLogicalRows(), - getLogicalCols(), getFillValue()); + mlir::tt::d2m::BlockMaskOp::create(rewriter, getLoc(), out.getType(), in, out, + rowMaskCb, colMaskCb, getLogicalRows(), + getLogicalCols(), getFillValue()); rewriter.replaceAllUsesWith(getResult(), out); rewriter.eraseOp(*this); return mlir::success(); diff --git a/lib/Dialect/D2M/IR/D2MOps.cpp b/lib/Dialect/D2M/IR/D2MOps.cpp index a9c47d04337..fcbdda19bfb 100644 --- a/lib/Dialect/D2M/IR/D2MOps.cpp +++ b/lib/Dialect/D2M/IR/D2MOps.cpp @@ -153,7 +153,7 @@ mlir::LogicalResult d2m::EmptyOp::bufferize( ::llvm::SmallVector invocationStack; auto bufferType = mlir::cast( *getBufferType(getResult(), options, state, invocationStack)); - auto allocOp = rewriter.create(getLoc(), bufferType); + auto allocOp = memref::AllocOp::create(rewriter, getLoc(), bufferType); // Propagate virtualGridInverseMapping (inverse) and virtualGridForwardMapping // (forward) as discardable attributes on memref::AllocOp (we don't own @@ -666,7 +666,8 @@ ToLayoutOp::bufferize(mlir::RewriterBase &rewriter, // ToLayoutOp is now only for device-to-device transfers. Host transfers // use ToDeviceOp and ToHostOp instead. 
- rewriter.create(getLoc(), TypeRange(), *maybeInput, *maybeOutput); + ToLayoutOp::create(rewriter, getLoc(), TypeRange(), *maybeInput, + *maybeOutput); mlir::bufferization::replaceOpWithBufferizedValues(rewriter, *this, *maybeOutput); @@ -823,13 +824,13 @@ ToDeviceOp::bufferize(mlir::RewriterBase &rewriter, if (mlir::cast(alignedHostMemref.getLayout()) .isPadded()) { auto alignedHostTensor = - rewriter.create(getLoc(), alignedHostMemref); - rewriter.create(getLoc(), *maybeInput, alignedHostTensor); + memref::AllocOp::create(rewriter, getLoc(), alignedHostMemref); + memref::CopyOp::create(rewriter, getLoc(), *maybeInput, alignedHostTensor); maybeInput = alignedHostTensor.getResult(); } - rewriter.create(getLoc(), TypeRange(), *maybeInput, *maybeOutput, - getLayout()); + ToDeviceOp::create(rewriter, getLoc(), TypeRange(), *maybeInput, *maybeOutput, + getLayout()); mlir::bufferization::replaceOpWithBufferizedValues(rewriter, *this, *maybeOutput); @@ -953,15 +954,15 @@ ToHostOp::bufferize(mlir::RewriterBase &rewriter, mlir::dyn_cast(alignedHostMemref.getLayout()); if (hostLayout && hostLayout.isPadded()) { auto alignedHostTensor = - rewriter.create(getLoc(), alignedHostMemref); + memref::AllocOp::create(rewriter, getLoc(), alignedHostMemref); - rewriter.create(getLoc(), TypeRange(), *maybeInput, - alignedHostTensor, getLayout()); + ToHostOp::create(rewriter, getLoc(), TypeRange(), *maybeInput, + alignedHostTensor, getLayout()); - rewriter.create(getLoc(), alignedHostTensor, *maybeOutput); + memref::CopyOp::create(rewriter, getLoc(), alignedHostTensor, *maybeOutput); } else { - rewriter.create(getLoc(), TypeRange(), *maybeInput, *maybeOutput, - getLayout()); + ToHostOp::create(rewriter, getLoc(), TypeRange(), *maybeInput, *maybeOutput, + getLayout()); } mlir::bufferization::replaceOpWithBufferizedValues(rewriter, *this, @@ -1027,9 +1028,10 @@ mlir::LogicalResult d2m::StreamLayoutOp::bufferize( } ::llvm::SmallVector invocationStack; - Value result = rewriter.create( - 
getLoc(), *getBufferType(getResult(), options, state, invocationStack), - *maybeInput, getRemapping(), *maybeStorage); + Value result = d2m::StreamLayoutOp::create( + rewriter, getLoc(), + *getBufferType(getResult(), options, state, invocationStack), *maybeInput, + getRemapping(), *maybeStorage); mlir::bufferization::replaceOpWithBufferizedValues(rewriter, *this, result); return success(); } @@ -1242,9 +1244,9 @@ mlir::LogicalResult d2m::ViewLayoutOp::bufferize( } auto outMemrefType = mlir::cast(*outMemrefTypeOr); - auto newOp = rewriter.create(getLoc(), outMemrefType, - *maybeInput, getRemapping(), - getReinterpretLayout()); + auto newOp = + d2m::ViewLayoutOp::create(rewriter, getLoc(), outMemrefType, *maybeInput, + getRemapping(), getReinterpretLayout()); mlir::bufferization::replaceOpWithBufferizedValues(rewriter, *this, newOp.getResult()); @@ -1516,8 +1518,8 @@ void d2m::GenericOp::build( assert(layout && "Expected MetalLayoutAttr or ViewLayoutAttr with StreamLayoutOp"); auto shardShape = layout.getShardShape(tensorType); - auto emptyOp = builder.create( - state.location, shardShape, tensorType.getElementType()); + auto emptyOp = mlir::tensor::EmptyOp::create( + builder, state.location, shardShape, tensorType.getElementType()); operandAllocs.push_back(emptyOp.getResult()); } @@ -2441,10 +2443,11 @@ mlir::LogicalResult d2m::GenericOp::bufferize( } bufferOutputs.push_back(*maybeValue); } - auto bufferGeneric = rewriter.create( - getLoc(), ValueRange(), bufferInputs, bufferOutputs, getAdditionalArgs(), - getGrid(), getBlockFactors(), getIndexingMaps(), getIteratorTypes(), - getThreads(), getScratchInputsAttr(), getNumRegions()); + auto bufferGeneric = d2m::GenericOp::create( + rewriter, getLoc(), ValueRange(), bufferInputs, bufferOutputs, + getAdditionalArgs(), getGrid(), getBlockFactors(), getIndexingMaps(), + getIteratorTypes(), getThreads(), getScratchInputsAttr(), + getNumRegions()); for (mlir::Region ®ion : bufferGeneric.getRegions()) { 
region.takeBody(getRegion(region.getRegionNumber())); } diff --git a/lib/Dialect/D2M/Transforms/AddScratchInputs.cpp b/lib/Dialect/D2M/Transforms/AddScratchInputs.cpp index 398e15b8e72..196ae3e72ee 100644 --- a/lib/Dialect/D2M/Transforms/AddScratchInputs.cpp +++ b/lib/Dialect/D2M/Transforms/AddScratchInputs.cpp @@ -142,7 +142,7 @@ static LogicalResult addScratchToGeneric(GenericOp genericOp) { // Create memref.alloc for the scratch buffer before the generic. OpBuilder builder(genericOp); auto scratchAlloc = - builder.create(genericOp.getLoc(), scratchMemRefType); + memref::AllocOp::create(builder, genericOp.getLoc(), scratchMemRefType); // Build new inputs with scratch added at the end. unsigned numOldInputs = genericOp.getInputs().size(); @@ -175,8 +175,8 @@ static LogicalResult addScratchToGeneric(GenericOp genericOp) { auto scratchInputsAttr = builder.getDenseI64ArrayAttr({scratchInputIndex}); // Create the new GenericOp with empty regions. - auto newGenericOp = builder.create( - genericOp.getLoc(), genericOp.getResultTypes(), newInputs, + auto newGenericOp = GenericOp::create( + builder, genericOp.getLoc(), genericOp.getResultTypes(), newInputs, genericOp.getOutputs(), genericOp.getAdditionalArgs(), genericOp.getGrid(), genericOp.getBlockFactors(), builder.getArrayAttr(newIndexingMaps), genericOp.getIteratorTypes(), @@ -232,9 +232,9 @@ static LogicalResult addScratchToGeneric(GenericOp genericOp) { } // Insert the scratch tensor.empty with the shard shape. - builder.create( - genericOp.getLoc(), scratchShardMemRefType.getShape(), - scratchShardMemRefType.getElementType()); + mlir::tensor::EmptyOp::create(builder, genericOp.getLoc(), + scratchShardMemRefType.getShape(), + scratchShardMemRefType.getElementType()); // Clone remaining ops (output tensor.empties and all other ops). 
for (unsigned i = clonedUpTo; i < oldOps.size(); ++i) { diff --git a/lib/Dialect/D2M/Transforms/Allocate.cpp b/lib/Dialect/D2M/Transforms/Allocate.cpp index ea7c08e67d8..37fd366d888 100644 --- a/lib/Dialect/D2M/Transforms/Allocate.cpp +++ b/lib/Dialect/D2M/Transforms/Allocate.cpp @@ -1555,8 +1555,8 @@ class D2MAllocate final : public impl::D2MAllocateBase { numStreamBuffers, operandGrid); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(oldAllocOp); - auto newAllocOp = rewriter.create( - oldTensor.getLoc(), + auto newAllocOp = memref::AllocOp::create( + rewriter, oldTensor.getLoc(), MemRefType::get(shardShape, operandType.getElementType(), cbLayout, oldMemRefType.getMemorySpace())); // Transfer address and alignment from the planner. @@ -1588,7 +1588,7 @@ class D2MAllocate final : public impl::D2MAllocateBase { rewriter.setInsertionPoint(op); auto bufferAllocOp = - rewriter.create(op.getLoc(), bufferType); + memref::AllocOp::create(rewriter, op.getLoc(), bufferType); if (req) { assignAddressAndAlignment(rewriter, bufferAllocOp, req->offset, info); @@ -1604,9 +1604,10 @@ class D2MAllocate final : public impl::D2MAllocateBase { getStreamType(bufferType.getShape(), reblockingMap, oldOperandType.getElementType(), remappedMemspace); - auto streamOp = rewriter.create( - op.getLoc(), /* result */ streamType, /* input */ operand.get(), - AffineMapAttr::get(reblockingMap), /* storage */ bufferAllocOp); + auto streamOp = d2m::StreamLayoutOp::create( + rewriter, op.getLoc(), /* result */ streamType, + /* input */ operand.get(), AffineMapAttr::get(reblockingMap), + /* storage */ bufferAllocOp); rewriter.startOpModification(op); { @@ -1657,8 +1658,8 @@ class D2MAllocate final : public impl::D2MAllocateBase { streamType.getContext(), shardShape, ttcore::getElementSizeBytes(streamType.getElementType()), numStreamBuffers, operandGrid); - auto newAllocOp = rewriter.create( - oldTensor.getLoc(), + auto newAllocOp = memref::AllocOp::create( + rewriter, 
oldTensor.getLoc(), MemRefType::get(shardShape, streamType.getElementType(), cbLayout, oldMemRefType.getMemorySpace())); // Transfer address and alignment from the old alloc (assigned @@ -1671,8 +1672,9 @@ class D2MAllocate final : public impl::D2MAllocateBase { } newValue = newAllocOp.getResult(); } else { - auto newEmptyOp = rewriter.create( - oldTensor.getLoc(), shardShape, streamType.getElementType()); + auto newEmptyOp = mlir::tensor::EmptyOp::create( + rewriter, oldTensor.getLoc(), shardShape, + streamType.getElementType()); newValue = newEmptyOp.getResult(); } rewriter.replaceAllUsesWith(oldTensor, newValue); @@ -1694,8 +1696,8 @@ class D2MAllocate final : public impl::D2MAllocateBase { OpBuilder::InsertionGuard guard(rewriter); { rewriter.setInsertionPointAfter(lastOp); - rewriter.create(lastOp->getLoc(), - allocOp.getResult()); + memref::DeallocOp::create(rewriter, lastOp->getLoc(), + allocOp.getResult()); } } } diff --git a/lib/Dialect/D2M/Transforms/ConvertLocalLoadStoreOpsToAliasedCBs.cpp b/lib/Dialect/D2M/Transforms/ConvertLocalLoadStoreOpsToAliasedCBs.cpp index b1454a71256..ef412a0a722 100644 --- a/lib/Dialect/D2M/Transforms/ConvertLocalLoadStoreOpsToAliasedCBs.cpp +++ b/lib/Dialect/D2M/Transforms/ConvertLocalLoadStoreOpsToAliasedCBs.cpp @@ -262,9 +262,9 @@ class D2MConvertLocalLoadStoreOpsToAliasedCBs rewriter.setInsertionPoint(allocOp); // Create reserve, push, and wait operations - rewriter.create(loc, assocCb); - rewriter.create(loc, assocCb); - auto waitOp = rewriter.create(loc, assocCb); + ReserveOp::create(rewriter, loc, assocCb); + PushOp::create(rewriter, loc, assocCb); + auto waitOp = WaitOp::create(rewriter, loc, assocCb); // Replace all uses of the alloc result and remote_load result with the // wait result @@ -281,7 +281,7 @@ class D2MConvertLocalLoadStoreOpsToAliasedCBs // No uses found, insert pop immediately after wait rewriter.setInsertionPointAfter(waitOp); } - rewriter.create(loc, assocCb); + PopOp::create(rewriter, loc, assocCb); 
// Erase the original remote_load operation rewriter.eraseOp(remoteLoad); @@ -330,15 +330,15 @@ class D2MConvertLocalLoadStoreOpsToAliasedCBs // Replace memref.alloc with reserve rewriter.setInsertionPoint(allocOp); - auto reserveOp = rewriter.create(loc, assocCb); + auto reserveOp = ReserveOp::create(rewriter, loc, assocCb); rewriter.replaceAllUsesWith(allocOp.getResult(), reserveOp.getResult()); rewriter.eraseOp(allocOp); // At remote_store location, insert: push, wait, pop rewriter.setInsertionPoint(remoteStore); - rewriter.create(loc, assocCb); - rewriter.create(loc, assocCb); - rewriter.create(loc, assocCb); + PushOp::create(rewriter, loc, assocCb); + WaitOp::create(rewriter, loc, assocCb); + PopOp::create(rewriter, loc, assocCb); // Erase the original remote_store operation rewriter.eraseOp(remoteStore); @@ -352,8 +352,8 @@ class D2MConvertLocalLoadStoreOpsToAliasedCBs for (GetScratchFromCBOp getScratchOp : scratchOpsToConvert) { rewriter.setInsertionPoint(getScratchOp); - auto reserveOp = rewriter.create(getScratchOp.getLoc(), - getScratchOp.getCb()); + auto reserveOp = ReserveOp::create(rewriter, getScratchOp.getLoc(), + getScratchOp.getCb()); rewriter.replaceAllUsesWith(getScratchOp.getResult(), reserveOp.getResult()); rewriter.eraseOp(getScratchOp); diff --git a/lib/Dialect/D2M/Transforms/DecomposeArange.cpp b/lib/Dialect/D2M/Transforms/DecomposeArange.cpp index 5eac3ec6943..de36858b852 100644 --- a/lib/Dialect/D2M/Transforms/DecomposeArange.cpp +++ b/lib/Dialect/D2M/Transforms/DecomposeArange.cpp @@ -54,38 +54,38 @@ struct DecomposeArangeBlockPattern : OpRewritePattern { // Total tiles across all cores. 
int64_t totalTileCols = numTileCols * gridShape[gridShape.size() - 1]; - Value zeroIdx = rewriter.create(loc, 0); - Value oneIdx = rewriter.create(loc, 1); + Value zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value oneIdx = arith::ConstantIndexOp::create(rewriter, loc, 1); Value numTileRowsVal = - rewriter.create(loc, numTileRows); + arith::ConstantIndexOp::create(rewriter, loc, numTileRows); Value numTileColsVal = - rewriter.create(loc, numTileCols); + arith::ConstantIndexOp::create(rewriter, loc, numTileCols); // === STEP 1: Write the scratch tile === TT_assert(indexTileMemref); - rewriter.create(loc, indexTileMemref); + FillArangeTileOp::create(rewriter, loc, indexTileMemref); // === STEP 2: Scalar constants for arange start and step === - Value startF = rewriter.create( - loc, elemType, + Value startF = arith::ConstantOp::create( + rewriter, loc, elemType, rewriter.getFloatAttr(elemType, static_cast(start))); - Value stepF = rewriter.create( - loc, elemType, + Value stepF = arith::ConstantOp::create( + rewriter, loc, elemType, rewriter.getFloatAttr(elemType, static_cast(step))); // === STEP 3: Create nested loops over tiles === // Get this core's coordinates. 
- Value coreY = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getI64IntegerAttr(0), nullptr); - Value coreX = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getI64IntegerAttr(1), nullptr); + Value coreY = CoreIndexOp::create(rewriter, loc, rewriter.getIndexType(), + rewriter.getI64IntegerAttr(0), nullptr); + Value coreX = CoreIndexOp::create(rewriter, loc, rewriter.getIndexType(), + rewriter.getI64IntegerAttr(1), nullptr); auto outerLoop = - rewriter.create(loc, zeroIdx, numTileRowsVal, oneIdx); + scf::ForOp::create(rewriter, loc, zeroIdx, numTileRowsVal, oneIdx); rewriter.setInsertionPointToStart(outerLoop.getBody()); auto innerLoop = - rewriter.create(loc, zeroIdx, numTileColsVal, oneIdx); + scf::ForOp::create(rewriter, loc, zeroIdx, numTileColsVal, oneIdx); // Mark the INNER loop as the compute root, since that's where // the actual compute operations are emitted. This ensures DST // syncs are placed inside the inner loop body, not the outer. @@ -100,62 +100,63 @@ struct DecomposeArangeBlockPattern : OpRewritePattern { // === STEP 4: Load scratch tile === Value localIndexTile = - rewriter - .create(loc, indexTileMemref, - ValueRange{zeroIdx, zeroIdx}) + memref::LoadOp::create(rewriter, loc, indexTileMemref, + ValueRange{zeroIdx, zeroIdx}) .getResult(); // === STEP 5: Compute tile offset as scalar === Value shardTileRowsIdx = - rewriter.create(loc, numTileRows); + arith::ConstantIndexOp::create(rewriter, loc, numTileRows); Value shardTileColsIdx = - rewriter.create(loc, numTileCols); + arith::ConstantIndexOp::create(rewriter, loc, numTileCols); Value totalTileColsIdx = - rewriter.create(loc, totalTileCols); - Value const32Idx = rewriter.create(loc, 32); + arith::ConstantIndexOp::create(rewriter, loc, totalTileCols); + Value const32Idx = arith::ConstantIndexOp::create(rewriter, loc, 32); // globalTileRow = coreY * shardTileRows + localTileRow - Value globalTileRow = rewriter.create( - loc, rewriter.create(loc, coreY, shardTileRowsIdx), + 
Value globalTileRow = arith::AddIOp::create( + rewriter, loc, + arith::MulIOp::create(rewriter, loc, coreY, shardTileRowsIdx), tileRowIdx); // globalTileCol = coreX * shardTileCols + localTileCol - Value globalTileCol = rewriter.create( - loc, rewriter.create(loc, coreX, shardTileColsIdx), + Value globalTileCol = arith::AddIOp::create( + rewriter, loc, + arith::MulIOp::create(rewriter, loc, coreX, shardTileColsIdx), tileColIdx); // Row contribution: globalTileRow * totalTileCols * 32 * 32 - Value rowContrib = rewriter.create( - loc, - rewriter.create( - loc, - rewriter.create(loc, globalTileRow, - totalTileColsIdx), - const32Idx), + Value rowContrib = arith::MulIOp::create( + rewriter, loc, + arith::MulIOp::create(rewriter, loc, + arith::MulIOp::create(rewriter, loc, + globalTileRow, + totalTileColsIdx), + const32Idx), const32Idx); // Column contribution: globalTileCol * 32 Value colContrib = - rewriter.create(loc, globalTileCol, const32Idx); + arith::MulIOp::create(rewriter, loc, globalTileCol, const32Idx); // Total offset (index type) Value tileOffsetIdx = - rewriter.create(loc, rowContrib, colContrib); - Value tileOffsetI64 = rewriter.create( - loc, rewriter.getI64Type(), tileOffsetIdx); + arith::AddIOp::create(rewriter, loc, rowContrib, colContrib); + Value tileOffsetI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), tileOffsetIdx); Value tileOffsetF = - rewriter.create(loc, elemType, tileOffsetI64); + arith::SIToFPOp::create(rewriter, loc, elemType, tileOffsetI64); // === STEP 6: Tile arithmetic with scalar RHS === Value globalIndexTile = - rewriter.create(loc, tileType, localIndexTile, tileOffsetF) + TileAddOp::create(rewriter, loc, tileType, localIndexTile, tileOffsetF) .getResult(); Value scaledTile = - rewriter.create(loc, tileType, globalIndexTile, stepF) + TileMulOp::create(rewriter, loc, tileType, globalIndexTile, stepF) .getResult(); Value resultTile = - rewriter.create(loc, tileType, scaledTile, startF) + 
TileAddOp::create(rewriter, loc, tileType, scaledTile, startF) .getResult(); // === STEP 7: Store result tile to output === - rewriter.create(loc, resultTile, output, - ValueRange{tileRowIdx, tileColIdx}); + memref::StoreOp::create(rewriter, loc, resultTile, output, + ValueRange{tileRowIdx, tileColIdx}); rewriter.setInsertionPointAfter(outerLoop); diff --git a/lib/Dialect/D2M/Transforms/DecomposeMasking.cpp b/lib/Dialect/D2M/Transforms/DecomposeMasking.cpp index 5170600ad07..33d45227692 100644 --- a/lib/Dialect/D2M/Transforms/DecomposeMasking.cpp +++ b/lib/Dialect/D2M/Transforms/DecomposeMasking.cpp @@ -72,35 +72,35 @@ struct DecomposeBlockMaskPattern : OpRewritePattern { Value coreIdx, int64_t shardSize) { Value shardSizeVal = - rewriter.create(loc, shardSize); + arith::ConstantIndexOp::create(rewriter, loc, shardSize); Value globalCoreStart = - rewriter.create(loc, coreIdx, shardSizeVal); + arith::MulIOp::create(rewriter, loc, coreIdx, shardSizeVal); Value globalRegionStartVal = - rewriter.create(loc, globalRegionStart); + arith::ConstantIndexOp::create(rewriter, loc, globalRegionStart); Value globalRegionEndVal = - rewriter.create(loc, globalRegionEnd); + arith::ConstantIndexOp::create(rewriter, loc, globalRegionEnd); // We define localStart = max(globalRegionStart - globalCoreStart, 0); in // turn this can be rewritten as localStart = globalRegionStart - // min(globalRegionStart, globalCoreStart). - Value clampedStart = rewriter.create( - loc, globalRegionStartVal, globalCoreStart); - Value localStart = - rewriter.create(loc, globalRegionStartVal, clampedStart); + Value clampedStart = arith::MinUIOp::create( + rewriter, loc, globalRegionStartVal, globalCoreStart); + Value localStart = arith::SubIOp::create( + rewriter, loc, globalRegionStartVal, clampedStart); // Similarly, we define localEnd = min(globalRegionEnd - globalCoreStart, // shardSize). 
However, to avoid underflow on unsigned, we re-express it as // clampedEnd = max(min(globalRegionEnd, globalCoreEnd), globalCoreStart), // and localEnd = clampedEnd - globalCoreStart, which is equivalent. Value globalCoreEnd = - rewriter.create(loc, globalCoreStart, shardSizeVal); - Value clampedEnd = - rewriter.create(loc, globalRegionEndVal, globalCoreEnd); + arith::AddIOp::create(rewriter, loc, globalCoreStart, shardSizeVal); + Value clampedEnd = arith::MinUIOp::create(rewriter, loc, globalRegionEndVal, + globalCoreEnd); clampedEnd = - rewriter.create(loc, clampedEnd, globalCoreStart); + arith::MaxUIOp::create(rewriter, loc, clampedEnd, globalCoreStart); Value localEnd = - rewriter.create(loc, clampedEnd, globalCoreStart); + arith::SubIOp::create(rewriter, loc, clampedEnd, globalCoreStart); return {localStart, localEnd}; } @@ -186,91 +186,93 @@ struct DecomposeBlockMaskPattern : OpRewritePattern { int64_t totalTileRows = shardTileRows * gridShape[gridShape.size() - 2]; int64_t totalTileCols = shardTileCols * gridShape[gridShape.size() - 1]; - Value zeroIdx = rewriter.create(loc, 0); - Value oneIdx = rewriter.create(loc, 1); + Value zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value oneIdx = arith::ConstantIndexOp::create(rewriter, loc, 1); double fillValueDouble = getFillValueAsDouble(fillOOBVal); - Value fillScalar = rewriter.create( - loc, elemType, rewriter.getFloatAttr(elemType, fillValueDouble)); + Value fillScalar = arith::ConstantOp::create( + rewriter, loc, elemType, + rewriter.getFloatAttr(elemType, fillValueDouble)); // Get this core's coordinates. 
- Value coreY = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getI64IntegerAttr(0), nullptr); - Value coreX = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getI64IntegerAttr(1), nullptr); + Value coreY = CoreIndexOp::create(rewriter, loc, rewriter.getIndexType(), + rewriter.getI64IntegerAttr(0), nullptr); + Value coreX = CoreIndexOp::create(rewriter, loc, rewriter.getIndexType(), + rewriter.getI64IntegerAttr(1), nullptr); // Write the mask tiles. Value validRowsVal = - rewriter.create(loc, validRowsInLastTile); + arith::ConstantIndexOp::create(rewriter, loc, validRowsInLastTile); Value validColsVal = - rewriter.create(loc, validColsInLastTile); + arith::ConstantIndexOp::create(rewriter, loc, validColsInLastTile); TT_assert(rowMaskCB); - rewriter.create(loc, validRowsVal, rowMaskCB); + WriteRowMaskTileOp::create(rewriter, loc, validRowsVal, rowMaskCB); TT_assert(colMaskCB); - rewriter.create(loc, validColsVal, colMaskCB); + WriteColMaskTileOp::create(rewriter, loc, validColsVal, colMaskCB); // === Tile operation helpers === auto createFillTile = [&]() { - return rewriter.create(loc, tileType, fillScalar).getResult(); + return TileFillOp::create(rewriter, loc, tileType, fillScalar) + .getResult(); }; auto emitPassthrough = [&](Value localRowIdx, Value localColIdx) { - auto inputTile = rewriter.create( - loc, input, ValueRange{localRowIdx, localColIdx}); - rewriter.create(loc, inputTile.getResult(), output, - ValueRange{localRowIdx, localColIdx}); + auto inputTile = memref::LoadOp::create( + rewriter, loc, input, ValueRange{localRowIdx, localColIdx}); + memref::StoreOp::create(rewriter, loc, inputTile.getResult(), output, + ValueRange{localRowIdx, localColIdx}); }; auto emitRowMasked = [&](Value localRowIdx, Value localColIdx) { - auto inputTile = rewriter.create( - loc, input, ValueRange{localRowIdx, localColIdx}); + auto inputTile = memref::LoadOp::create( + rewriter, loc, input, ValueRange{localRowIdx, localColIdx}); auto fillTile = 
createFillTile(); - auto rowMaskTile = rewriter.create( - loc, rowMaskCB, ValueRange{zeroIdx, zeroIdx}); + auto rowMaskTile = memref::LoadOp::create(rewriter, loc, rowMaskCB, + ValueRange{zeroIdx, zeroIdx}); auto result = - rewriter.create(loc, tileType, rowMaskTile.getResult(), - inputTile.getResult(), fillTile); - rewriter.create(loc, result.getResult(), output, - ValueRange{localRowIdx, localColIdx}); + TileWhereOp::create(rewriter, loc, tileType, rowMaskTile.getResult(), + inputTile.getResult(), fillTile); + memref::StoreOp::create(rewriter, loc, result.getResult(), output, + ValueRange{localRowIdx, localColIdx}); }; auto emitColMasked = [&](Value localRowIdx, Value localColIdx) { - auto inputTile = rewriter.create( - loc, input, ValueRange{localRowIdx, localColIdx}); + auto inputTile = memref::LoadOp::create( + rewriter, loc, input, ValueRange{localRowIdx, localColIdx}); auto fillTile = createFillTile(); - auto colMaskTile = rewriter.create( - loc, colMaskCB, ValueRange{zeroIdx, zeroIdx}); + auto colMaskTile = memref::LoadOp::create(rewriter, loc, colMaskCB, + ValueRange{zeroIdx, zeroIdx}); auto result = - rewriter.create(loc, tileType, colMaskTile.getResult(), - inputTile.getResult(), fillTile); - rewriter.create(loc, result.getResult(), output, - ValueRange{localRowIdx, localColIdx}); + TileWhereOp::create(rewriter, loc, tileType, colMaskTile.getResult(), + inputTile.getResult(), fillTile); + memref::StoreOp::create(rewriter, loc, result.getResult(), output, + ValueRange{localRowIdx, localColIdx}); }; auto emitCornerMasked = [&](Value localRowIdx, Value localColIdx) { - auto inputTile = rewriter.create( - loc, input, ValueRange{localRowIdx, localColIdx}); + auto inputTile = memref::LoadOp::create( + rewriter, loc, input, ValueRange{localRowIdx, localColIdx}); auto fillTile1 = createFillTile(); - auto rowMaskTile = rewriter.create( - loc, rowMaskCB, ValueRange{zeroIdx, zeroIdx}); + auto rowMaskTile = memref::LoadOp::create(rewriter, loc, rowMaskCB, + 
ValueRange{zeroIdx, zeroIdx}); auto rowMaskedResult = - rewriter.create(loc, tileType, rowMaskTile.getResult(), - inputTile.getResult(), fillTile1); + TileWhereOp::create(rewriter, loc, tileType, rowMaskTile.getResult(), + inputTile.getResult(), fillTile1); auto fillTile2 = createFillTile(); - auto colMaskTile = rewriter.create( - loc, colMaskCB, ValueRange{zeroIdx, zeroIdx}); + auto colMaskTile = memref::LoadOp::create(rewriter, loc, colMaskCB, + ValueRange{zeroIdx, zeroIdx}); auto finalResult = - rewriter.create(loc, tileType, colMaskTile.getResult(), - rowMaskedResult.getResult(), fillTile2); - rewriter.create(loc, finalResult.getResult(), output, - ValueRange{localRowIdx, localColIdx}); + TileWhereOp::create(rewriter, loc, tileType, colMaskTile.getResult(), + rowMaskedResult.getResult(), fillTile2); + memref::StoreOp::create(rewriter, loc, finalResult.getResult(), output, + ValueRange{localRowIdx, localColIdx}); }; auto emitFill = [&](Value localRowIdx, Value localColIdx) { auto fillTile = createFillTile(); - rewriter.create(loc, fillTile, output, - ValueRange{localRowIdx, localColIdx}); + memref::StoreOp::create(rewriter, loc, fillTile, output, + ValueRange{localRowIdx, localColIdx}); }; // Helper to create a nested loop over local coordinates. @@ -278,11 +280,11 @@ struct DecomposeBlockMaskPattern : OpRewritePattern { Value colEnd, std::function emitBody) { auto outerLoop = - rewriter.create(loc, rowStart, rowEnd, oneIdx); + scf::ForOp::create(rewriter, loc, rowStart, rowEnd, oneIdx); rewriter.setInsertionPointToStart(outerLoop.getBody()); auto innerLoop = - rewriter.create(loc, colStart, colEnd, oneIdx); + scf::ForOp::create(rewriter, loc, colStart, colEnd, oneIdx); // Mark the INNER loop as the compute root, since that's where // the actual compute operations are emitted. This ensures DST // syncs are placed inside the inner loop body, not the outer. 
diff --git a/lib/Dialect/D2M/Transforms/ElementwiseFusion.cpp b/lib/Dialect/D2M/Transforms/ElementwiseFusion.cpp index 2dec7a73b5d..a586030d05e 100644 --- a/lib/Dialect/D2M/Transforms/ElementwiseFusion.cpp +++ b/lib/Dialect/D2M/Transforms/ElementwiseFusion.cpp @@ -273,8 +273,8 @@ static GenericOp createFusedGeneric(OpOperand *fusedOperand, GenericOp producer, ///////////////////////////////////////////////////////////////////////////// auto fusedResultTypes = TypeRange(fusedOutputs); - auto fusedOp = rewriter.create( - consumer.getLoc(), fusedResultTypes, fusedInputs, fusedOutputs, + auto fusedOp = GenericOp::create( + rewriter, consumer.getLoc(), fusedResultTypes, fusedInputs, fusedOutputs, mergedAdditionalArgs, consumer.getGrid(), consumer.getBlockFactors(), rewriter.getAffineMapArrayAttr(fusedMaps), consumer.getIteratorTypes(), consumer.getThreads(), consumer.getScratchInputsAttr(), /*regions=*/1); @@ -326,8 +326,9 @@ static GenericOp createFusedGeneric(OpOperand *fusedOperand, GenericOp producer, SmallVector fusedTensorEmpties; for (Type emptyType : fusedEmptyTypes) { auto shapedType = mlir::cast(emptyType); - auto emptyOp = rewriter.create( - fusedOp.getLoc(), shapedType.getShape(), shapedType.getElementType()); + auto emptyOp = mlir::tensor::EmptyOp::create(rewriter, fusedOp.getLoc(), + shapedType.getShape(), + shapedType.getElementType()); fusedTensorEmpties.push_back(emptyOp.getResult()); } @@ -540,7 +541,7 @@ static GenericOp createFusedGeneric(OpOperand *fusedOperand, GenericOp producer, fusedYields.push_back(irMap.lookupOrDefault(y)); } rewriter.setInsertionPointToEnd(&fusedBlock); - rewriter.create(fusedOp.getLoc(), fusedYields); + YieldOp::create(rewriter, fusedOp.getLoc(), fusedYields); return fusedOp; } diff --git a/lib/Dialect/D2M/Transforms/GenerateOuterLoops.cpp b/lib/Dialect/D2M/Transforms/GenerateOuterLoops.cpp index 9e3ddc56c67..ed6350d150b 100644 --- a/lib/Dialect/D2M/Transforms/GenerateOuterLoops.cpp +++ 
b/lib/Dialect/D2M/Transforms/GenerateOuterLoops.cpp @@ -27,7 +27,7 @@ class D2MGenerateOuterLoopsRewriter : public OpRewritePattern { SmallVector ubs; for (unsigned i = 0; i < numDims; ++i) { ubs.push_back( - rewriter.create(loc, static_cast(i))); + GetBlockFactorOp::create(rewriter, loc, static_cast(i))); } // Upper bound map: ()[s0] -> (s0). @@ -52,7 +52,7 @@ class D2MGenerateOuterLoopsRewriter : public OpRewritePattern { rewriter.eraseOp(innerBody->getTerminator()); rewriter.mergeBlocks(regionBlock, innerBody, loopedBlock->getArguments()); rewriter.setInsertionPointToEnd(innerBody); - rewriter.create(loc); + affine::AffineYieldOp::create(rewriter, loc); return loops; } @@ -71,10 +71,10 @@ class D2MGenerateOuterLoopsRewriter : public OpRewritePattern { for (BlockIndexOp blockIndex : blockIndices) { rewriter.setInsertionPoint(blockIndex); int64_t dim = blockIndex.getDim(); - Value offset = rewriter.create(loc, dim); - Value iterIndex = rewriter.create(loc, dim); - Value index = rewriter.create( - loc, addMap, ValueRange{iterIndex, offset}); + Value offset = BlockOffsetOp::create(rewriter, loc, dim); + Value iterIndex = IterIndexOp::create(rewriter, loc, dim); + Value index = affine::AffineApplyOp::create( + rewriter, loc, addMap, ValueRange{iterIndex, offset}); rewriter.replaceOp(blockIndex, index); } } @@ -133,9 +133,10 @@ class D2MGenerateOuterLoopsRewriter : public OpRewritePattern { // Create a new GenericOp with the same structure // After generating loops, preserve all attributes including block_factors // (needed by LowerLoadStoreOpsToDMA for stream index computation). 
- auto loopedGeneric = rewriter.create( - generic->getLoc(), generic.getResultTypes(), generic.getInputs(), - generic.getOutputs(), generic.getAdditionalArgs(), generic.getGrid(), + auto loopedGeneric = GenericOp::create( + rewriter, generic->getLoc(), generic.getResultTypes(), + generic.getInputs(), generic.getOutputs(), generic.getAdditionalArgs(), + generic.getGrid(), /* block_factors */ generic.getBlockFactors(), /* indexing_maps */ generic.getIndexingMaps(), /* iterator_types */ generic.getIteratorTypes(), generic.getThreads(), diff --git a/lib/Dialect/D2M/Transforms/GenericLinearizeMemref.cpp b/lib/Dialect/D2M/Transforms/GenericLinearizeMemref.cpp index 99a4d179c48..74d927e11ed 100644 --- a/lib/Dialect/D2M/Transforms/GenericLinearizeMemref.cpp +++ b/lib/Dialect/D2M/Transforms/GenericLinearizeMemref.cpp @@ -70,8 +70,8 @@ struct D2MLinearizeMemrefAccessRewriter final collapsedDims) && "linearizeAffineMap assumes that the shape is collapsible aka " "has contiguous memory layout"); - linearizedArg = rewriter.create(op.getLoc(), val, - collapsedDims); + linearizedArg = memref::CollapseShapeOp::create(rewriter, op.getLoc(), + val, collapsedDims); collapseOps->insert({val, linearizedArg}); } @@ -80,8 +80,8 @@ struct D2MLinearizeMemrefAccessRewriter final rewriter.setInsertionPoint(op); - Value linearIndex = - rewriter.create(op.getLoc(), linearMap, indices); + Value linearIndex = affine::AffineApplyOp::create(rewriter, op.getLoc(), + linearMap, indices); // Create new load/store with linearized access if constexpr (std::is_same_v) { diff --git a/lib/Dialect/D2M/Transforms/GenericRegionsToFuncs.cpp b/lib/Dialect/D2M/Transforms/GenericRegionsToFuncs.cpp index 6c3ee78809a..4c4d7806413 100644 --- a/lib/Dialect/D2M/Transforms/GenericRegionsToFuncs.cpp +++ b/lib/Dialect/D2M/Transforms/GenericRegionsToFuncs.cpp @@ -38,7 +38,7 @@ static void rewriteOperand(OpBuilder &builder, DMAOpInterface dma, applyViews(dmaOperand.get().getDefiningOp()); } Operation *globalOperand = - 
builder.create(dma.getLoc(), memref, operandIndex); + GetGlobalOperandOp::create(builder, dma.getLoc(), memref, operandIndex); dmaOperand.set(globalOperand->getResult(0)); } @@ -69,8 +69,9 @@ static void rewriteAdditionalArgOperands(OpBuilder &builder, generic->isAncestor(use.getOwner())) { builder.setInsertionPoint(use.getOwner()); //  And insert a get_global_operand op where the generic operand is being used. - auto globalOperand = builder.create( - use.getOwner()->getLoc(), operand.getType(), capturedOperandIndex); + auto globalOperand = + GetGlobalOperandOp::create(builder, use.getOwner()->getLoc(), + operand.getType(), capturedOperandIndex); use.set(globalOperand.getResult()); } } @@ -112,8 +113,8 @@ class D2MGenericRegionsToFuncs Location loc = region.getNumArguments() > 0 ? region.getArgument(0).getLoc() : generic.getLoc(); - auto func = builder.create( - loc, symbolName, + auto func = func::FuncOp::create( + builder, loc, symbolName, FunctionType::get(builder.getContext(), region.getArgumentTypes(), {})); func.setPrivate(); @@ -122,14 +123,15 @@ class D2MGenericRegionsToFuncs ttmlir::utils::setFunctionType(func, ttmlir::utils::FunctionType::Kernel); builder.setInsertionPointToEnd(&func.getBody().front()); - builder.create(generic.getLoc()); + func::ReturnOp::create(builder, generic.getLoc()); threads.push_back(threadAttrWithSym); } builder.setInsertionPoint(generic); - auto symbolicGeneric = builder.create( - generic->getLoc(), generic.getResultTypes(), generic.getInputs(), - generic.getOutputs(), generic.getAdditionalArgs(), generic.getGrid(), + auto symbolicGeneric = GenericOp::create( + builder, generic->getLoc(), generic.getResultTypes(), + generic.getInputs(), generic.getOutputs(), + generic.getAdditionalArgs(), generic.getGrid(), generic.getBlockFactors(), generic.getIndexingMaps(), generic.getIteratorTypes(), builder.getArrayAttr(threads), generic.getScratchInputsAttr(), diff --git a/lib/Dialect/D2M/Transforms/GridSelection.cpp 
b/lib/Dialect/D2M/Transforms/GridSelection.cpp index a5672ba6615..1ffb57f7cad 100644 --- a/lib/Dialect/D2M/Transforms/GridSelection.cpp +++ b/lib/Dialect/D2M/Transforms/GridSelection.cpp @@ -409,13 +409,13 @@ static void optimizeToLayoutGrid(d2m::ToLayoutOp toLayoutOp, virtualGridForwardMapping = AffineMapAttr::get(forwardMap); } - auto newEmptyOp = builder.create( - emptyOp.getLoc(), newTensorType, virtualGridInverseMapping, + auto newEmptyOp = d2m::EmptyOp::create( + builder, emptyOp.getLoc(), newTensorType, virtualGridInverseMapping, virtualGridForwardMapping); builder.setInsertionPoint(toLayoutOp); - auto newToLayoutOp = builder.create( - toLayoutOp.getLoc(), toLayoutOp.getInput(), newEmptyOp); + auto newToLayoutOp = d2m::ToLayoutOp::create( + builder, toLayoutOp.getLoc(), toLayoutOp.getInput(), newEmptyOp); // Reblock it back to original shape to preserve IR correctness. // The view chain that applyViews composes through depends on this @@ -426,8 +426,8 @@ static void optimizeToLayoutGrid(d2m::ToLayoutOp toLayoutOp, auto reblockMap = ttmlir::utils::calculateReblockMap( newTensorType.getShape(), viewOutputType.getShape(), builder.getContext()); - auto view = builder.create( - toLayoutOp.getLoc(), viewOutputType, newToLayoutOp.getResult(0), + auto view = d2m::ViewLayoutOp::create( + builder, toLayoutOp.getLoc(), viewOutputType, newToLayoutOp.getResult(0), reblockMap, /*reinterpretLayout=*/false); // We expect the ToLayout to be used in one of two ways: @@ -542,9 +542,9 @@ static void insertViewForTTNNDRAMTensor(Value operand, fakeShardedShape, metalTensor.getElementType(), viewOutputLayout); builder.setInsertionPointAfter(castOp); - auto viewOp = builder.create( - castOp.getLoc(), viewOutputTensor, castOp.getResult(), - AffineMapAttr::get(reblockMap)); + auto viewOp = d2m::ViewLayoutOp::create(builder, castOp.getLoc(), + viewOutputTensor, castOp.getResult(), + AffineMapAttr::get(reblockMap)); castOp.getResult().replaceAllUsesExcept(viewOp.getResult(), viewOp); } 
@@ -570,8 +570,9 @@ static void optimizeTTNNMetalLayoutCastOpGrid( builder.setInsertionPointAfter(castOp); - auto newViewLayoutOp = builder.create( - castOp.getLoc(), newTensorType, castOp.getResult(), gridRemapping); + auto newViewLayoutOp = + d2m::ViewLayoutOp::create(builder, castOp.getLoc(), newTensorType, + castOp.getResult(), gridRemapping); // Reblock it back to original shape to preserve IR correctness. auto viewOutputType = utils::reblockTensor( @@ -579,9 +580,10 @@ static void optimizeTTNNMetalLayoutCastOpGrid( auto reblockMap = ttmlir::utils::calculateReblockMap( newTensorType.getShape(), viewOutputType.getShape(), builder.getContext()); - auto revertingView = builder.create( - castOp.getLoc(), viewOutputType, newViewLayoutOp.getResult(), reblockMap, - /*reinterpretLayout=*/false); + auto revertingView = + d2m::ViewLayoutOp::create(builder, castOp.getLoc(), viewOutputType, + newViewLayoutOp.getResult(), reblockMap, + /*reinterpretLayout=*/false); castOp.getResult().replaceAllUsesExcept(revertingView.getResult(), newViewLayoutOp); @@ -938,8 +940,8 @@ updateStreamLayoutOps(ArrayRef streamLayoutsToUpdate, } } - auto newStorageEmpty = builder.create( - storageEmpty.getLoc(), + auto newStorageEmpty = d2m::EmptyOp::create( + builder, storageEmpty.getLoc(), RankedTensorType::get(newStorageShape, elementType, newStorageLayout), virtualGridInverseMapping, virtualGridForwardMapping); @@ -961,9 +963,10 @@ updateStreamLayoutOps(ArrayRef streamLayoutsToUpdate, newStorageShape, outputStreamType.getElementType(), newOutputLayout); builder.setInsertionPoint(streamLayout); - auto newStreamLayout = builder.create( - streamLayout.getLoc(), newStreamOutputType, streamLayout.getInput(), - AffineMapAttr::get(newOutputMap), newStorageEmpty); + auto newStreamLayout = d2m::StreamLayoutOp::create( + builder, streamLayout.getLoc(), newStreamOutputType, + streamLayout.getInput(), AffineMapAttr::get(newOutputMap), + newStorageEmpty); // We expect the StreamLayout to be used only by 
the GenericOp we're // optimizing. Check that all uses are either the GenericOp itself or @@ -1024,8 +1027,8 @@ static void updateEmptyOps(ArrayRef emptyOpsToUpdate, } } - auto newEmptyOp = builder.create( - emptyOp.getLoc(), newTensorType, virtualGridInverseMapping, + auto newEmptyOp = d2m::EmptyOp::create( + builder, emptyOp.getLoc(), newTensorType, virtualGridInverseMapping, virtualGridForwardMapping); emptyOp.getResult().replaceAllUsesWith(newEmptyOp.getResult()); emptyOp.erase(); @@ -1083,8 +1086,8 @@ recreateGenericOp(d2m::GenericOp genericOp, auto viewTensorType = utils::reblockTensor(tensorType, optimalGrid); auto reblockMap = ttmlir::utils::calculateReblockMap( tensorType.getShape(), viewTensorType.getShape(), builder.getContext()); - auto view = builder.create( - genericOp.getLoc(), viewTensorType, operand.get(), reblockMap, + auto view = d2m::ViewLayoutOp::create( + builder, genericOp.getLoc(), viewTensorType, operand.get(), reblockMap, /*reinterpretLayout=*/false); newOperands.push_back(view.getResult()); } @@ -1100,8 +1103,8 @@ recreateGenericOp(d2m::GenericOp genericOp, Region &oldRegion = genericOp.getRegion(0); auto newAdditionalArgs = genericOp.getAdditionalArgs(); - auto newGenericOp = builder.create( - genericOp.getLoc(), newInputs, newOutputs, newAdditionalArgs, + auto newGenericOp = d2m::GenericOp::create( + builder, genericOp.getLoc(), newInputs, newOutputs, newAdditionalArgs, genericOp.getIndexingMaps(), genericOp.getIteratorTypes(), [&](OpBuilder &b, Location loc, ValueRange blockArgs) { IRMapping mapping; diff --git a/lib/Dialect/D2M/Transforms/HoistCBAllocs.cpp b/lib/Dialect/D2M/Transforms/HoistCBAllocs.cpp index 9265039e091..8101ffde4d2 100644 --- a/lib/Dialect/D2M/Transforms/HoistCBAllocs.cpp +++ b/lib/Dialect/D2M/Transforms/HoistCBAllocs.cpp @@ -59,7 +59,7 @@ class D2MHoistCBAllocs : public impl::D2MHoistCBAllocsBase { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(genericOp); auto externalAlloc = - 
rewriter.create(genericOp.getLoc(), allocType); + memref::AllocOp::create(rewriter, genericOp.getLoc(), allocType); // Transfer address and alignment. if (auto addressAttr = allocOp->getAttrOfType("address")) { diff --git a/lib/Dialect/D2M/Transforms/InsertDstRegisterAccess.cpp b/lib/Dialect/D2M/Transforms/InsertDstRegisterAccess.cpp index 9caf0d28dee..0792de781dd 100644 --- a/lib/Dialect/D2M/Transforms/InsertDstRegisterAccess.cpp +++ b/lib/Dialect/D2M/Transforms/InsertDstRegisterAccess.cpp @@ -179,8 +179,8 @@ static bool hasTileMatmul(linalg::GenericOp linalgGenericOp) { // Falls back to constant 1 when loop metadata is unavailable. static Value getSecondIterationValue(PatternRewriter &rewriter, Location loc, Value loopIV) { - auto one = rewriter.create( - loc, rewriter.getIndexType(), + auto one = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); auto ivBlockArg = mlir::dyn_cast(loopIV); @@ -199,31 +199,31 @@ static Value getSecondIterationValue(PatternRewriter &rewriter, Location loc, } if (auto scfFor = mlir::dyn_cast(ownerOp)) { - return rewriter.create(loc, scfFor.getLowerBound(), - scfFor.getStep()); + return arith::AddIOp::create(rewriter, loc, scfFor.getLowerBound(), + scfFor.getStep()); } if (auto affineFor = mlir::dyn_cast(ownerOp)) { Value lb = nullptr; if (affineFor.hasConstantLowerBound()) { - lb = rewriter.create( - loc, rewriter.getIndexType(), + lb = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIntegerAttr(rewriter.getIndexType(), affineFor.getConstantLowerBound())); } else { AffineMap lowerBoundMap = affineFor.getLowerBoundMap(); if (lowerBoundMap.getNumResults() == 1) { - lb = rewriter.create( - loc, lowerBoundMap, affineFor.getLowerBoundOperands()); + lb = affine::AffineApplyOp::create(rewriter, loc, lowerBoundMap, + affineFor.getLowerBoundOperands()); } } if (lb) { - Value step = rewriter.create( - loc, rewriter.getIndexType(), + Value 
step = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIntegerAttr(rewriter.getIndexType(), affineFor.getStepAsInt())); - return rewriter.create(loc, lb, step); + return arith::AddIOp::create(rewriter, loc, lb, step); } } @@ -553,13 +553,13 @@ struct D2MInsertDstRegisterAccessRewriter final AcquireDstOp acquireDst, Value loopIV) { rewriter.setInsertionPointAfter(acquireDst); Value secondIterationValue = getSecondIterationValue(rewriter, loc, loopIV); - Value cond = rewriter.create(loc, arith::CmpIPredicate::eq, - loopIV, secondIterationValue); - auto ifOp = rewriter.create(loc, cond); + Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + loopIV, secondIterationValue); + auto ifOp = scf::IfOp::create(rewriter, loc, cond); rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); - Value enableFlag = rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); - rewriter.create(loc, enableFlag); + Value enableFlag = arith::ConstantOp::create( + rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); + SetL1AccumulateOp::create(rewriter, loc, enableFlag); } static bool @@ -750,7 +750,7 @@ struct D2MInsertDstRegisterAccessRewriter final rewriter.getAttr( ttcore::MemorySpace::RegisterDst)); - return rewriter.create(loc, dstType); + return AcquireDstOp::create(rewriter, loc, dstType); } // Walk all compute ops in the region and collect: @@ -981,8 +981,8 @@ struct D2MInsertDstRegisterAccessRewriter final for (Operation *loopOp : llvm::reverse(linalgLoops.value())) { rewriter.eraseOp(loopOp); } - rewriter.create(gOp.getLoc(), inputAMemref, - inputBMemref, outputCMemref); + d2m::TileMatmulBlockOp::create(rewriter, gOp.getLoc(), inputAMemref, + inputBMemref, outputCMemref); } return true; @@ -1011,8 +1011,8 @@ struct D2MInsertDstRegisterAccessRewriter final auto loc = record.loadStore.getLoc(); Value cb = record.loadStore.getMemref(); - auto cbLoad = rewriter.create( - loc, cb, 
l1AccessMap, l1AccessIndices); + auto cbLoad = affine::AffineLoadOp::create( + rewriter, loc, cb, l1AccessMap, l1AccessIndices); Value valueToStore = cbLoad.getResult(); if (record.bcast.has_value()) { @@ -1023,8 +1023,8 @@ struct D2MInsertDstRegisterAccessRewriter final valueToStore = clonedBcast->getResult(0); } - rewriter.create( - loc, valueToStore, dst, dstAccessMap, dstAccessIndices); + affine::AffineStoreOp::create(rewriter, loc, valueToStore, dst, + dstAccessMap, dstAccessIndices); }; // Replace the original load with one from the DST. @@ -1032,8 +1032,9 @@ struct D2MInsertDstRegisterAccessRewriter final [&](PatternRewriter &rewriter, LoadStoreRecord record, AffineMap dstAccessMap, ValueRange dstAccessIndices) { - auto dstLoad = rewriter.create( - record.loadStore.getLoc(), dst, dstAccessMap, dstAccessIndices); + auto dstLoad = affine::AffineLoadOp::create( + rewriter, record.loadStore.getLoc(), dst, dstAccessMap, + dstAccessIndices); if (record.bcast.has_value()) { // Keep the original load in case another bcastOp uses it. record.bcast->getResult().replaceAllUsesWith(dstLoad.getResult()); @@ -1058,22 +1059,23 @@ struct D2MInsertDstRegisterAccessRewriter final AffineMap dstAccessMap, ValueRange dstAccessIndices) { auto loc = record.loadStore.getLoc(); Value cb = record.loadStore.getMemref(); - auto dstLoad = rewriter.create( - loc, dst, dstAccessMap, dstAccessIndices); + auto dstLoad = affine::AffineLoadOp::create( + rewriter, loc, dst, dstAccessMap, dstAccessIndices); Value valueToStore = dstLoad.getResult(); // Insert DST reinterpret cast if destination CB type differs from // DST type. 
auto cbType = mlir::cast(cb.getType()); if (valueToStore.getType() != cbType.getElementType()) { - valueToStore = rewriter - .create( - loc, cbType.getElementType(), valueToStore) + valueToStore = d2m::DstReinterpretCastOp::create( + rewriter, + + loc, cbType.getElementType(), valueToStore) .getResult(); } - rewriter.create( - loc, valueToStore, cb, l1AccessMap, l1AccessIndices); + affine::AffineStoreOp::create(rewriter, loc, valueToStore, cb, + l1AccessMap, l1AccessIndices); }; // Replace the original store with one to the DST. @@ -1085,10 +1087,11 @@ struct D2MInsertDstRegisterAccessRewriter final // Insert DST reinterpret cast if value type differs from DST type. auto dstType = mlir::cast(dst.getType()); if (valueToStore.getType() != dstType.getElementType()) { - valueToStore = rewriter - .create( - record.loadStore.getLoc(), - dstType.getElementType(), valueToStore) + valueToStore = d2m::DstReinterpretCastOp::create( + rewriter, + + record.loadStore.getLoc(), + dstType.getElementType(), valueToStore) .getResult(); } rewriter.replaceOpWithNewOp( @@ -1117,11 +1120,9 @@ struct D2MInsertDstRegisterAccessRewriter final // Initial condition: // - Bcast: load-if-all, start with true and disable when false shows up. // - Accum: skip-unless-any, start with false and enable when true shows up. - Value guard = - rewriter - .create(loc, rewriter.getI1Type(), - rewriter.getBoolAttr(isBcastGuard)) - .getResult(); + Value guard = arith::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), + rewriter.getBoolAttr(isBcastGuard)) + .getResult(); // Check: // - Bcast: IS 1st iter? @@ -1129,24 +1130,24 @@ struct D2MInsertDstRegisterAccessRewriter final const auto cmpPredicate = isBcastGuard ? 
arith::CmpIPredicate::eq : arith::CmpIPredicate::ne; - auto zero = rewriter.create( - loc, rewriter.getIndexType(), + auto zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); for (Value guardIV : guardIVs) { Value cmp = - rewriter.create(loc, cmpPredicate, guardIV, zero); + arith::CmpIOp::create(rewriter, loc, cmpPredicate, guardIV, zero); // Aggregation: if (isBcastGuard) { // - Bcast: load if ALL(&&) bcast dims ARE at the 1st iter. - guard = rewriter.create(loc, guard, cmp).getResult(); + guard = arith::AndIOp::create(rewriter, loc, guard, cmp).getResult(); } else { // - Accum: reload if ANY(||) reduce dims is NOT at the 1st iter. - guard = rewriter.create(loc, guard, cmp).getResult(); + guard = arith::OrIOp::create(rewriter, loc, guard, cmp).getResult(); } } - return rewriter.create(loc, guard); + return scf::IfOp::create(rewriter, loc, guard); } template @@ -1353,25 +1354,25 @@ struct D2MInsertDstRegisterAccessRewriter final bool needsTypeCast = (originalType != dstType.getElementType()); if (needsTypeCast) { - auto cast = rewriter.create( - loc, dstType.getElementType(), valueToStore); + auto cast = d2m::DstReinterpretCastOp::create( + rewriter, loc, dstType.getElementType(), valueToStore); valueToStore = cast.getResult(); castOp = cast.getOperation(); } - auto storeOp = rewriter.create( - loc, valueToStore, dst, storeMap, storeIndices); + auto storeOp = affine::AffineStoreOp::create(rewriter, loc, valueToStore, + dst, storeMap, storeIndices); - auto loadedResult = rewriter.create( - loc, dst, storeMap, storeIndices); + auto loadedResult = affine::AffineLoadOp::create(rewriter, loc, dst, + storeMap, storeIndices); // If we cast for storage, we need to cast back to the original type // after loading, since downstream ops expect the original type. 
Value replacementValue = loadedResult.getResult(); Operation *castBackOp = nullptr; if (needsTypeCast) { - auto castBack = rewriter.create( - loc, originalType, replacementValue); + auto castBack = d2m::DstReinterpretCastOp::create( + rewriter, loc, originalType, replacementValue); replacementValue = castBack.getResult(); castBackOp = castBack.getOperation(); } @@ -1736,10 +1737,10 @@ struct D2MInsertDstRegisterAccessRewriter final [&](PatternRewriter &rewriter, Location loc, Value cb, AffineMap l1AccessMap, ValueRange l1AccessIndices, AffineMap dstAccessMap, ValueRange dstAccessIndices) { - auto l1Load = rewriter.create( - loc, cb, l1AccessMap, l1AccessIndices); - rewriter.create( - loc, l1Load.getResult(), dst, dstAccessMap, dstAccessIndices); + auto l1Load = affine::AffineLoadOp::create( + rewriter, loc, cb, l1AccessMap, l1AccessIndices); + affine::AffineStoreOp::create(rewriter, loc, l1Load.getResult(), + dst, dstAccessMap, dstAccessIndices); }, // Replacement of the original load with one from dst. [&](PatternRewriter &rewriter, affine::AffineLoadOp op, @@ -1761,16 +1762,17 @@ struct D2MInsertDstRegisterAccessRewriter final // When L1 accumulation is enabled, skip CB->DST copy but still // rewrite the original load to use DST. if (!enableL1Acc) { - auto cbLoad = rewriter.create( - loadOp.getLoc(), loadOp.getMemRef(), loadOp.getIndices()); - rewriter.create(loadOp.getLoc(), - cbLoad.getResult(), dst, - dstAccessMap, ValueRange{}); + auto cbLoad = + memref::LoadOp::create(rewriter, loadOp.getLoc(), + loadOp.getMemRef(), loadOp.getIndices()); + affine::AffineStoreOp::create(rewriter, loadOp.getLoc(), + cbLoad.getResult(), dst, dstAccessMap, + ValueRange{}); } // Replace original load with DST load. 
- auto dstLoad = rewriter.create( - loadOp.getLoc(), dst, dstAccessMap, ValueRange{}); + auto dstLoad = affine::AffineLoadOp::create( + rewriter, loadOp.getLoc(), dst, dstAccessMap, ValueRange{}); rewriter.replaceOp(loadOp, dstLoad.getResult()); } @@ -1783,22 +1785,23 @@ struct D2MInsertDstRegisterAccessRewriter final [&](PatternRewriter &rewriter, Location loc, Value cb, AffineMap l1AccessMap, ValueRange l1AccessIndices, AffineMap dstAccessMap, ValueRange dstAccessIndices) { - auto dstLoad = rewriter.create( - loc, dst, dstAccessMap, dstAccessIndices); + auto dstLoad = affine::AffineLoadOp::create( + rewriter, loc, dst, dstAccessMap, dstAccessIndices); Value valueToStore = dstLoad.getResult(); // Insert dst reinterpret cast if destination CB type differs // from dst type auto cbType = mlir::cast(cb.getType()); if (valueToStore.getType() != cbType.getElementType()) { - valueToStore = rewriter - .create( - loc, cbType.getElementType(), valueToStore) + valueToStore = d2m::DstReinterpretCastOp::create( + rewriter, + + loc, cbType.getElementType(), valueToStore) .getResult(); } - rewriter.create( - loc, dstLoad.getResult(), cb, l1AccessMap, l1AccessIndices); + affine::AffineStoreOp::create(rewriter, loc, dstLoad.getResult(), + cb, l1AccessMap, l1AccessIndices); }, // Replacement of the original store with one from dst. 
[&](PatternRewriter &rewriter, affine::AffineStoreOp op, @@ -1809,9 +1812,10 @@ struct D2MInsertDstRegisterAccessRewriter final auto dstType = mlir::cast(dst.getType()); if (valueToStore.getType() != dstType.getElementType()) { valueToStore = - rewriter - .create( - op.getLoc(), dstType.getElementType(), valueToStore) + d2m::DstReinterpretCastOp::create( + rewriter, + + op.getLoc(), dstType.getElementType(), valueToStore) .getResult(); } @@ -1832,27 +1836,28 @@ struct D2MInsertDstRegisterAccessRewriter final auto dstType = mlir::cast(dst.getType()); if (valueToStore.getType() != dstType.getElementType()) { valueToStore = - rewriter - .create( - storeOp.getLoc(), dstType.getElementType(), valueToStore) + d2m::DstReinterpretCastOp::create( + rewriter, + + storeOp.getLoc(), dstType.getElementType(), valueToStore) .getResult(); } // Store to DST. - rewriter.create(storeOp.getLoc(), valueToStore, - dst, dstAccessMap, ValueRange{}); + affine::AffineStoreOp::create(rewriter, storeOp.getLoc(), valueToStore, + dst, dstAccessMap, ValueRange{}); // Load from DST and store to CB. - auto dstLoad = rewriter.create( - storeOp.getLoc(), dst, dstAccessMap, ValueRange{}); + auto dstLoad = affine::AffineLoadOp::create( + rewriter, storeOp.getLoc(), dst, dstAccessMap, ValueRange{}); Value packValue = dstLoad.getResult(); auto cbType = mlir::cast(storeOp.getMemRef().getType()); if (packValue.getType() != cbType.getElementType()) { - packValue = - rewriter - .create( - storeOp.getLoc(), cbType.getElementType(), packValue) - .getResult(); + packValue = d2m::DstReinterpretCastOp::create( + rewriter, + + storeOp.getLoc(), cbType.getElementType(), packValue) + .getResult(); } // Replace original store with CB store. 
diff --git a/lib/Dialect/D2M/Transforms/InsertStreams.cpp b/lib/Dialect/D2M/Transforms/InsertStreams.cpp index 4b2d02ab190..b83ab51e45c 100644 --- a/lib/Dialect/D2M/Transforms/InsertStreams.cpp +++ b/lib/Dialect/D2M/Transforms/InsertStreams.cpp @@ -70,9 +70,10 @@ class D2MInsertStreamsRewriter final : public OpRewritePattern { MemRefType::get(memref.getShape(), memref.getElementType(), storageAttr, rewriter.getAttr( ttcore::MemorySpace::DeviceL1)); - auto storage = rewriter.create(op.getLoc(), storageMemref); - auto streamLayout = rewriter.create( - op.getLoc(), streamMemref, operand.get(), + auto storage = + memref::AllocOp::create(rewriter, op.getLoc(), storageMemref); + auto streamLayout = d2m::StreamLayoutOp::create( + rewriter, op.getLoc(), streamMemref, operand.get(), AffineMapAttr::get(rewriter.getMultiDimIdentityMap(memref.getRank())), storage); rewriter.modifyOpInPlace( diff --git a/lib/Dialect/D2M/Transforms/LowerDMAToFullyIndexedForm.cpp b/lib/Dialect/D2M/Transforms/LowerDMAToFullyIndexedForm.cpp index d412b0508dc..9cc23ccb8ec 100644 --- a/lib/Dialect/D2M/Transforms/LowerDMAToFullyIndexedForm.cpp +++ b/lib/Dialect/D2M/Transforms/LowerDMAToFullyIndexedForm.cpp @@ -28,14 +28,14 @@ namespace mlir::tt::d2m { static std::tuple, SmallVector, SmallVector> getLoopBounds(OpBuilder &builder, Location loc, ArrayRef shardShape) { - Value zero = builder.create(loc, builder.getIndexType(), - builder.getIndexAttr(0)); - Value one = builder.create(loc, builder.getIndexType(), - builder.getIndexAttr(1)); + Value zero = arith::ConstantOp::create(builder, loc, builder.getIndexType(), + builder.getIndexAttr(0)); + Value one = arith::ConstantOp::create(builder, loc, builder.getIndexType(), + builder.getIndexAttr(1)); SmallVector lbs(shardShape.size(), zero); SmallVector ubs(llvm::map_range(shardShape, [&](int64_t dim) { - return builder.create(loc, builder.getIndexType(), - builder.getIndexAttr(dim)); + return arith::ConstantOp::create(builder, loc, builder.getIndexType(), + 
builder.getIndexAttr(dim)); })); SmallVector step(shardShape.size(), one); return std::make_tuple(lbs, ubs, step); @@ -93,7 +93,7 @@ static SmallVector applyMap(Builder &builder, Location loc, AffineMap map, ValueRange index, bool isRemote) { auto affineApply = [&](AffineMap map, ValueRange index) { - return builder.template create(loc, map, index); + return affine::AffineApplyOp::create(builder, loc, map, index); }; if (isRemote) { @@ -217,8 +217,8 @@ static Value generateFullyIndexedDMAOps( SmallVector remoteIndices = gridIndices; SmallVector localIndices; - Value zero = builder.create(loc, builder.getIndexType(), - builder.getIndexAttr(0)); + Value zero = arith::ConstantOp::create(builder, loc, builder.getIndexType(), + builder.getIndexAttr(0)); for (size_t i = 0; i < shardShape.size(); ++i) { remoteIndices.push_back(zero); localIndices.push_back(zero); @@ -234,7 +234,7 @@ static Value generateFullyIndexedDMAOps( // Strided/non-contiguous: generate loops with guarded DMAs. auto [lbs, ubs, steps] = getLoopBounds(builder, loc, shardShape); - auto nullDmaTx = builder.create(loc); + auto nullDmaTx = NullTxOp::create(builder, loc); scf::LoopNest loopNest = scf::buildLoopNest( builder, loc, lbs, ubs, steps, ValueRange(nullDmaTx), @@ -252,50 +252,49 @@ static Value generateFullyIndexedDMAOps( localIndices, false); // Create guarded DMA operation based on coalescing factor. 
- Value cfExpr = loopBuilder.create( - innerLoc, loopBuilder.getIndexType(), + Value cfExpr = arith::ConstantOp::create( + loopBuilder, innerLoc, loopBuilder.getIndexType(), loopBuilder.getIndexAttr(coalescingFactor)); - Value zero = loopBuilder.create( - innerLoc, loopBuilder.getIndexType(), + Value zero = arith::ConstantOp::create( + loopBuilder, innerLoc, loopBuilder.getIndexType(), loopBuilder.getIntegerAttr(loopBuilder.getIndexType(), 0)); // Construct guard function: flat_index(iters) % coalescingFactor == 0 auto totalIterCount = zero; size_t currStride = 1; for (int i = iters.size() - 1; i >= 0; i--) { - Value currStrideExpr = loopBuilder.create( - innerLoc, loopBuilder.getIndexType(), + Value currStrideExpr = arith::ConstantOp::create( + loopBuilder, innerLoc, loopBuilder.getIndexType(), loopBuilder.getIndexAttr(currStride)); - auto scaledCount = - loopBuilder - .create(innerLoc, currStrideExpr, iters[i]) - .getResult(); - totalIterCount = - loopBuilder - .create(innerLoc, scaledCount, totalIterCount) - .getResult(); + auto scaledCount = arith::MulIOp::create(loopBuilder, innerLoc, + currStrideExpr, iters[i]) + .getResult(); + totalIterCount = arith::AddIOp::create(loopBuilder, innerLoc, + scaledCount, totalIterCount) + .getResult(); currStride *= shardShape[i]; } - auto moduloIterCount = - loopBuilder.create(innerLoc, totalIterCount, cfExpr) - .getResult(); - auto predicate = loopBuilder.create( - innerLoc, arith::CmpIPredicate::eq, moduloIterCount, zero); + auto moduloIterCount = arith::RemSIOp::create(loopBuilder, innerLoc, + totalIterCount, cfExpr) + .getResult(); + auto predicate = arith::CmpIOp::create(loopBuilder, innerLoc, + arith::CmpIPredicate::eq, + moduloIterCount, zero); - auto nulltx = loopBuilder.create(innerLoc); + auto nulltx = NullTxOp::create(loopBuilder, innerLoc); // Build guarded DMA. 
- auto ifExpr = loopBuilder.create( - innerLoc, TypeRange(SmallVector{nulltx}), predicate, - true /*addThenBlock*/, true /*addElseBlock*/); + auto ifExpr = scf::IfOp::create( + loopBuilder, innerLoc, TypeRange(SmallVector{nulltx}), + predicate, true /*addThenBlock*/, true /*addElseBlock*/); auto thenBuilder = ifExpr.getThenBodyBuilder(); Value dmaTx = createDMAOp(thenBuilder, innerLoc, remoteIndices, localIndices, coalescingFactor); - thenBuilder.create(innerLoc, dmaTx); + scf::YieldOp::create(thenBuilder, innerLoc, dmaTx); auto elseBuilder = ifExpr.getElseBodyBuilder(); - elseBuilder.create(innerLoc, args[0]); + scf::YieldOp::create(elseBuilder, innerLoc, args[0]); return SmallVector{ifExpr.getResult(0)}; }); @@ -356,8 +355,8 @@ class D2MLowerDMAReadToFullyIndexed : public OpRewritePattern { coalescingFactor, shardVolume, [&](OpBuilder &b, Location l, SmallVector &remoteIdx, SmallVector &localIdx, size_t cf) { - return b.create(l, remoteMemref, remoteIdx, localMemref, - localIdx, b.getI64IntegerAttr(cf)); + return DMAReadOp::create(b, l, remoteMemref, remoteIdx, localMemref, + localIdx, b.getI64IntegerAttr(cf)); }); rewriter.replaceOp(op, newTx); @@ -398,16 +397,16 @@ class D2MLowerDMAWriteToFullyIndexed : public OpRewritePattern { size_t shardVolume = ttmlir::utils::volume(shardShape); SmallVector localIndices; - Value zero = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + Value zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); for (size_t i = 0; i < shardShape.size(); ++i) { localIndices.push_back(zero); } localIndices = applyMap(rewriter, loc, localMemoryMap, localIndices, false); - Value newTx = rewriter.create( - loc, localMemref, localIndices, dstMemref, localIndices, + Value newTx = DMAWriteOp::create( + rewriter, loc, localMemref, localIndices, dstMemref, localIndices, op.getMcastStartIndex(), op.getMcastShape(), shardVolume); rewriter.replaceOp(op, newTx); return success(); 
@@ -443,8 +442,8 @@ class D2MLowerDMAWriteToFullyIndexed : public OpRewritePattern { coalescingFactor, shardVolume, [&](OpBuilder &b, Location l, SmallVector &remoteIdx, SmallVector &localIdx, size_t cf) { - return b.create(l, localMemref, localIdx, dstMemref, - remoteIdx, cf); + return DMAWriteOp::create(b, l, localMemref, localIdx, dstMemref, + remoteIdx, cf); }); rewriter.replaceOp(op, newTx); diff --git a/lib/Dialect/D2M/Transforms/LowerLoadStoreOpsToDMA.cpp b/lib/Dialect/D2M/Transforms/LowerLoadStoreOpsToDMA.cpp index 11fe65e1e98..70149deaf4a 100644 --- a/lib/Dialect/D2M/Transforms/LowerLoadStoreOpsToDMA.cpp +++ b/lib/Dialect/D2M/Transforms/LowerLoadStoreOpsToDMA.cpp @@ -92,10 +92,10 @@ class D2MLowerRemoteLoadRewritePattern : public OpRewritePattern { mcastVolume *= dimSize; } - Value zero = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); - Value one = rewriter.create(loc, rewriter.getIndexType(), - rewriter.getIndexAttr(1)); + Value zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + Value one = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(1)); // Get pre-allocated semaphores for synchronization. // These must have been set by D2MPreallocateMcastSemaphores pass. @@ -105,8 +105,9 @@ class D2MLowerRemoteLoadRewritePattern : public OpRewritePattern { // Number of receivers is mcastVolume - 1 (excluding sender itself). // The sender waits for this many semaphore increments before multicasting. - Value numReceiversVal = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(mcastVolume - 1)); + Value numReceiversVal = + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexType(), + rewriter.getIndexAttr(mcastVolume - 1)); // Determine if this core is the sender. 
// The sender is at position mcastStartIndex[i] for each multicast @@ -117,13 +118,13 @@ class D2MLowerRemoteLoadRewritePattern : public OpRewritePattern { ValueRange mcastStartIndex = remoteLoad.getMcastStartIndex(); for (size_t i = 0; i < isMcastDim.size(); ++i) { if (isMcastDim[i]) { - Value coreIdx = rewriter.create( - loc, static_cast(i), gridMapping); - Value condition = rewriter.create( - loc, rewriter.getI1Type(), arith::CmpIPredicate::eq, coreIdx, - mcastStartIndex[i]); + Value coreIdx = CoreIndexOp::create( + rewriter, loc, static_cast(i), gridMapping); + Value condition = arith::CmpIOp::create( + rewriter, loc, rewriter.getI1Type(), arith::CmpIPredicate::eq, + coreIdx, mcastStartIndex[i]); if (isSender) { - isSender = rewriter.create(loc, isSender, condition) + isSender = arith::AndIOp::create(rewriter, loc, isSender, condition) .getResult(); } else { isSender = condition; @@ -134,45 +135,45 @@ class D2MLowerRemoteLoadRewritePattern : public OpRewritePattern { // Reserve CB unconditionally before branching - both sender and receiver // need to reserve to maintain proper circular buffer semantics. - Value localMemref = rewriter.create(loc, cb).getResult(); + Value localMemref = ReserveOp::create(rewriter, loc, cb).getResult(); SmallVector gridIndices = remoteLoad.getIndices(); - rewriter.create( - loc, isSender, + scf::IfOp::create( + rewriter, loc, isSender, [&](OpBuilder &builder, Location loc) { // Sender: shard-level DMA read from remote. - Value dmaTx = builder.create(loc, remoteMemref, - gridIndices, localMemref); - builder.create(loc, dmaTx); + Value dmaTx = DMAReadOp::create(builder, loc, remoteMemref, + gridIndices, localMemref); + DMAWaitOp::create(builder, loc, dmaTx); // Wait for all receivers to be ready (mcastVolume - 1, excluding // sender). 
- builder.create(loc, receiversReadySemaphore, - numReceiversVal, zero); + SemaphoreWaitOp::create(builder, loc, receiversReadySemaphore, + numReceiversVal, zero); // Perform shard-level multicast DMA write: from local CB to local CB // with multicast parameters. The multicast parameters specify that // the data should be sent to other cores. We use localMemref (from // ReserveOp) as both source and destination - this is the Producer // buffer that was just filled by the DMA read above. - Value mcastTx = builder.create( - loc, localMemref, localMemref, remoteLoad.getMcastStartIndex(), - remoteLoad.getMcastShape()); - builder.create(loc, mcastTx); + Value mcastTx = DMAWriteOp::create( + builder, loc, localMemref, localMemref, + remoteLoad.getMcastStartIndex(), remoteLoad.getMcastShape()); + DMAWaitOp::create(builder, loc, mcastTx); // Signal receivers that sender is finished. - builder.create(loc, senderFinishedSemaphore, one, - remoteLoad.getMcastStartIndex(), - remoteLoad.getMcastShape()); + SemaphoreSetOp::create(builder, loc, senderFinishedSemaphore, one, + remoteLoad.getMcastStartIndex(), + remoteLoad.getMcastShape()); - builder.create(loc); + scf::YieldOp::create(builder, loc); }, [&](OpBuilder &builder, Location loc) { // Receiver: signal ready and wait for sender to finish. SmallVector senderCoreIndex; - Value zeroIdx = builder.create( - loc, builder.getIndexType(), builder.getIndexAttr(0)); + Value zeroIdx = arith::ConstantOp::create( + builder, loc, builder.getIndexType(), builder.getIndexAttr(0)); // Build sender core index by reading actual core positions. // For dimensions that are multicast, sender is at mcastStartIndex. @@ -184,23 +185,23 @@ class D2MLowerRemoteLoadRewritePattern : public OpRewritePattern { senderCoreIndex.push_back(mcastStartIndex[i]); } else { // Non-multicast dimension - use current core's position. 
- Value currentCoreIdx = builder.create( - loc, static_cast(i), gridMapping); + Value currentCoreIdx = CoreIndexOp::create( + builder, loc, static_cast(i), gridMapping); senderCoreIndex.push_back(currentCoreIdx); } } - builder.create(loc, receiversReadySemaphore, one, - senderCoreIndex); - builder.create(loc, senderFinishedSemaphore, one, - zeroIdx); + SemaphoreIncOp::create(builder, loc, receiversReadySemaphore, one, + senderCoreIndex); + SemaphoreWaitOp::create(builder, loc, senderFinishedSemaphore, one, + zeroIdx); // Note: CB already reserved before the if/else, so receiver has // proper access to the multicast data. - builder.create(loc); + scf::YieldOp::create(builder, loc); }); - rewriter.create(loc, cb); + PushOp::create(rewriter, loc, cb); rewriter.eraseOp(remoteLoad); return success(); @@ -237,15 +238,15 @@ class D2MLowerRemoteLoadRewritePattern : public OpRewritePattern { Value remoteMemref = remoteLoad.getMemref(); SmallVector gridIndices = remoteLoad.getIndices(); - Value localMemref = rewriter.create(loc, cb).getResult(); - Value dmaTx = - rewriter.create(loc, remoteMemref, gridIndices, localMemref); + Value localMemref = ReserveOp::create(rewriter, loc, cb).getResult(); + Value dmaTx = DMAReadOp::create(rewriter, loc, remoteMemref, gridIndices, + localMemref); rewriter.eraseOp(remoteLoad); // Wait for DMA to complete. 
- rewriter.create(loc, dmaTx); - rewriter.create(loc, cb); + DMAWaitOp::create(rewriter, loc, dmaTx); + PushOp::create(rewriter, loc, cb); return success(); } }; @@ -282,16 +283,16 @@ class D2MLowerRemoteStoreRewritePattern SmallVector gridIndices = remoteStore.getIndices(); // Wait on CB, emit shard-level dma_write, wait, pop - Value localMemref = rewriter.create(loc, cb).getResult(); - Value dmaTx = rewriter.create(loc, localMemref, remoteMemref, - gridIndices); + Value localMemref = WaitOp::create(rewriter, loc, cb).getResult(); + Value dmaTx = DMAWriteOp::create(rewriter, loc, localMemref, remoteMemref, + gridIndices); rewriter.eraseOp(remoteStore); // Wait for DMA to complete. - rewriter.create(loc, dmaTx); + DMAWaitOp::create(rewriter, loc, dmaTx); // Pop the circular buffer to signal consumption. - rewriter.create(loc, cb); + PopOp::create(rewriter, loc, cb); return success(); } }; diff --git a/lib/Dialect/D2M/Transforms/LowerLoadStoreOpsToExplicitCBForm.cpp b/lib/Dialect/D2M/Transforms/LowerLoadStoreOpsToExplicitCBForm.cpp index 6fa8b94bb5d..97434033e32 100644 --- a/lib/Dialect/D2M/Transforms/LowerLoadStoreOpsToExplicitCBForm.cpp +++ b/lib/Dialect/D2M/Transforms/LowerLoadStoreOpsToExplicitCBForm.cpp @@ -153,12 +153,12 @@ static void simplifyLoadStorePairs(ModuleOp moduleOp, IRRewriter &rewriter, if (!isRemoteStore) { // Create the explicit CB form of remote_load (no localBuffer, has CB // operand) - rewriter.create(loc, loadMemref, loadOp.getIndices(), - outputCB, loadOp.getMcastStartIndex(), - loadOp.getMcastShape()); + RemoteLoadOp::create(rewriter, loc, loadMemref, loadOp.getIndices(), + outputCB, loadOp.getMcastStartIndex(), + loadOp.getMcastShape()); } else { - rewriter.create(loc, storeMemref, loadOp.getIndices(), - inputCB); + RemoteStoreOp::create(rewriter, loc, storeMemref, loadOp.getIndices(), + inputCB); } // Get the shared localBuffer before erasing operations @@ -240,13 +240,13 @@ static PushPopInfo convertToExplicitCBForm(ModuleOp moduleOp, // 
Create the explicit CB form of remote_load (no localBuffer, no result, // has CB operand) d2m.remote_load %memref[indices] into %cb - rewriter.create(loc, memref, remoteLoad.getIndices(), assocCb, - remoteLoad.getMcastStartIndex(), - remoteLoad.getMcastShape()); + RemoteLoadOp::create(rewriter, loc, memref, remoteLoad.getIndices(), + assocCb, remoteLoad.getMcastStartIndex(), + remoteLoad.getMcastShape()); // Create wait operation to produce the result value // %in = d2m.wait %cb - auto waitOp = rewriter.create(loc, assocCb); + auto waitOp = WaitOp::create(rewriter, loc, assocCb); // Move any operations that use localBuffer and come before remoteLoad // to after waitOp. This handles cases like collapse_shape ops that were @@ -344,7 +344,7 @@ static PushPopInfo convertToExplicitCBForm(ModuleOp moduleOp, // Create reserve operation // %out = d2m.reserve %cb - auto reserveOp = rewriter.create(loc, assocCb); + auto reserveOp = ReserveOp::create(rewriter, loc, assocCb); // Replace all uses of memref.alloc result with reserve result rewriter.replaceAllUsesWith(allocOp.getResult(), reserveOp.getResult()); @@ -404,8 +404,8 @@ static PushPopInfo convertToExplicitCBForm(ModuleOp moduleOp, // Create the explicit CB form of remote_store (no local buffer, has CB) // d2m.remote_store %memref[indices] from %cb - rewriter.create(loc, memref, remoteStore.getIndices(), - assocCb); + RemoteStoreOp::create(rewriter, loc, memref, remoteStore.getIndices(), + assocCb); // Track the reserve op for push insertion (avoid duplicates). if (reserveOp && cbsWithReserveOps.insert(assocCb).second) { @@ -415,7 +415,7 @@ static PushPopInfo convertToExplicitCBForm(ModuleOp moduleOp, // Create right after the get_cb so it dominates all uses. 
OpBuilder::InsertionGuard reserveGuard(rewriter); rewriter.setInsertionPointAfterValue(assocCb); - auto newReserve = rewriter.create(loc, assocCb); + auto newReserve = ReserveOp::create(rewriter, loc, assocCb); info.reserveOpsNeedingPush.push_back({newReserve, assocCb}); // Replace uses of the old localBuffer with the reserve result // inside the generic only. @@ -443,7 +443,7 @@ static PushPopInfo convertToExplicitCBForm(ModuleOp moduleOp, // Create reserve op before the store (avoid duplicates) if (cbsWithReserveOps.insert(assocCb).second) { rewriter.setInsertionPoint(remoteStore); - auto reserveOp = rewriter.create(loc, assocCb); + auto reserveOp = ReserveOp::create(rewriter, loc, assocCb); info.reserveOpsNeedingPush.push_back({reserveOp, assocCb}); } } @@ -480,7 +480,7 @@ static void insertPushAndPopOps(ModuleOp moduleOp, IRRewriter &rewriter, } else { rewriter.setInsertionPointToEnd(waitBlock); } - rewriter.create(loc, assocCb); + PopOp::create(rewriter, loc, assocCb); } // Insert push ops for each reserve op @@ -505,7 +505,7 @@ static void insertPushAndPopOps(ModuleOp moduleOp, IRRewriter &rewriter, } else { rewriter.setInsertionPointToEnd(topLevelBlock); } - rewriter.create(loc, assocCb); + PushOp::create(rewriter, loc, assocCb); } else { reserveOp.emitWarning( "could not find top-level region block for push insertion"); diff --git a/lib/Dialect/D2M/Transforms/LowerMulticastLoads.cpp b/lib/Dialect/D2M/Transforms/LowerMulticastLoads.cpp index 43649b07736..866ac9df92c 100644 --- a/lib/Dialect/D2M/Transforms/LowerMulticastLoads.cpp +++ b/lib/Dialect/D2M/Transforms/LowerMulticastLoads.cpp @@ -156,8 +156,8 @@ class LowerMulticastLoadsRewriter : public OpRewritePattern { } // Build low-level multicast arguments. 
- Value zero = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + Value zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); SmallVector mcastStartIndex; SmallVector mcastShapeInt64; @@ -169,8 +169,8 @@ class LowerMulticastLoadsRewriter : public OpRewritePattern { auto dimPos = *maybeDimPos; if (mcastDimSet.contains(dimPos)) { // for parallel dim specified by multicast, extent is 0 - Value coreIdx = rewriter.create( - loc, static_cast(dim), grid.getMapping()); + Value coreIdx = CoreIndexOp::create( + rewriter, loc, static_cast(dim), grid.getMapping()); mcastStartIndex.push_back(coreIdx); mcastShapeInt64.push_back(1); } else { @@ -218,8 +218,9 @@ class LowerMulticastLoadsRewriter : public OpRewritePattern { SmallVector mcastShape; mcastShape.reserve(mcastShapeInt64.size()); for (int64_t dimSize : mcastShapeInt64) { - mcastShape.push_back(rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(dimSize))); + mcastShape.push_back( + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexType(), + rewriter.getIndexAttr(dimSize))); } // Create replacement RemoteLoadOp with low-level multicast form. diff --git a/lib/Dialect/D2M/Transforms/LowerScratchAllocate.cpp b/lib/Dialect/D2M/Transforms/LowerScratchAllocate.cpp index 51dd430a5a8..52f7d362412 100644 --- a/lib/Dialect/D2M/Transforms/LowerScratchAllocate.cpp +++ b/lib/Dialect/D2M/Transforms/LowerScratchAllocate.cpp @@ -111,7 +111,7 @@ class D2MLowerScratchAllocatePass if (mlir::isa(scratchValue.getType())) { // CB form: unwrap via get_scratch_from_cb. auto scratchFromCBOp = - builder.create(genericOp.getLoc(), scratchValue); + GetScratchFromCBOp::create(builder, genericOp.getLoc(), scratchValue); scratchMemRef = scratchFromCBOp.getResult(); } else { // New form: memref.alloc is already the scratch buffer. 
@@ -166,9 +166,9 @@ class D2MLowerScratchAllocatePass SmallVector strides = {builder.getIndexAttr(1), builder.getIndexAttr(1)}; - auto subviewOp = builder.create( - loc, mlir::cast(inferredType), scratchMemRef, offsets, - sizes, strides); + auto subviewOp = memref::SubViewOp::create( + builder, loc, mlir::cast(inferredType), scratchMemRef, + offsets, sizes, strides); allocOp.getResult().replaceAllUsesWith(subviewOp.getResult()); allocOp.erase(); diff --git a/lib/Dialect/D2M/Transforms/LowerToExplicitForm.cpp b/lib/Dialect/D2M/Transforms/LowerToExplicitForm.cpp index 75287023b45..0e12d8fff8b 100644 --- a/lib/Dialect/D2M/Transforms/LowerToExplicitForm.cpp +++ b/lib/Dialect/D2M/Transforms/LowerToExplicitForm.cpp @@ -133,9 +133,9 @@ static void lowerBlockOffsetOps(IRRewriter &rewriter, GenericOp generic, Value blockFactorConstant = arith::ConstantIndexOp::create( rewriter, op.getLoc(), blockFactors[dim]); Value coreIndex = - rewriter.create(op.getLoc(), gridDim, gridMapping); - Value blockOffset = rewriter.create( - op.getLoc(), blockFactorConstant, coreIndex); + CoreIndexOp::create(rewriter, op.getLoc(), gridDim, gridMapping); + Value blockOffset = arith::MulIOp::create(rewriter, op.getLoc(), + blockFactorConstant, coreIndex); rewriter.replaceOp(op, blockOffset); } } diff --git a/lib/Dialect/D2M/Transforms/LowerToLayout.cpp b/lib/Dialect/D2M/Transforms/LowerToLayout.cpp index f51f29987ce..cab09b1e3a2 100644 --- a/lib/Dialect/D2M/Transforms/LowerToLayout.cpp +++ b/lib/Dialect/D2M/Transforms/LowerToLayout.cpp @@ -127,10 +127,10 @@ static Value createRemoteLoad(OpBuilder &builder, Location loc, Type shardType, Value source, ArrayRef indices) { // Create a buffer for the load result auto tensorType = mlir::cast(shardType); - auto bufferOp = builder.create(loc, tensorType.getShape(), - tensorType.getElementType()); + auto bufferOp = tensor::EmptyOp::create(builder, loc, tensorType.getShape(), + tensorType.getElementType()); Value buffer = bufferOp.getResult(); - return 
builder.create(loc, shardType, buffer, source, indices) + return RemoteLoadOp::create(builder, loc, shardType, buffer, source, indices) .getResult(); } @@ -138,9 +138,8 @@ static Value createRemoteLoad(OpBuilder &builder, Location loc, Type shardType, static Value createTensorEmpty(OpBuilder &builder, Location loc, Type shardType) { auto tensorType = mlir::cast(shardType); - return builder - .create(loc, tensorType.getShape(), - tensorType.getElementType()) + return tensor::EmptyOp::create(builder, loc, tensorType.getShape(), + tensorType.getElementType()) .getResult(); } @@ -148,9 +147,8 @@ static Value createTensorEmpty(OpBuilder &builder, Location loc, static Value createRemoteStore(OpBuilder &builder, Location loc, Value destination, ArrayRef indices, Value localBuffer) { - return builder - .create(loc, destination.getType(), destination, indices, - localBuffer) + return RemoteStoreOp::create(builder, loc, destination.getType(), destination, + indices, localBuffer) .getResult(); } @@ -449,8 +447,8 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { outputInfo.type.getElementType(), newLayout); // Pass the transformation map via the remapping attribute. - Value viewOp = rewriter.create(loc, viewType, input, viewMap, - /*reinterpretLayout=*/false); + Value viewOp = ViewLayoutOp::create(rewriter, loc, viewType, input, viewMap, + /*reinterpretLayout=*/false); // Materialize L1→L1 transformations with a DMA generic that performs the // actual data movement according to the view's affine map. 
@@ -467,24 +465,26 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { auto indexingMapAttr = mlir::cast(indexingMaps[0]); AffineMap indexingMap = indexingMapAttr.getValue(); - return rewriter - .create( - loc, viewOp, output, /*additionalArgs=*/ValueRange(), - [&](OpBuilder &builder, Location innerLoc, ValueRange blockArgs) { - // Load from input, store to output (load+store pair for proper - // CB association) - Type inputShardType = getShardTypeFromCB(blockArgs[0]); - SmallVector indices = d2m::utils::buildGridIndices( - builder, innerLoc, indexingMap); - - // Load-store idiom - Value loadedData = createRemoteLoad( - builder, innerLoc, inputShardType, viewOp, indices); - Value storeResult = createRemoteStore(builder, innerLoc, output, - indices, loadedData); - builder.create(innerLoc, storeResult); - }, - ThreadType::Unified) + return GenericOp::create( + rewriter, + + loc, viewOp, output, /*additionalArgs=*/ValueRange(), + [&](OpBuilder &builder, Location innerLoc, + ValueRange blockArgs) { + // Load from input, store to output (load+store pair for + // proper CB association) + Type inputShardType = getShardTypeFromCB(blockArgs[0]); + SmallVector indices = d2m::utils::buildGridIndices( + builder, innerLoc, indexingMap); + + // Load-store idiom + Value loadedData = createRemoteLoad( + builder, innerLoc, inputShardType, viewOp, indices); + Value storeResult = createRemoteStore( + builder, innerLoc, output, indices, loadedData); + YieldOp::create(builder, innerLoc, storeResult); + }, + ThreadType::Unified) .getResult(0); } // DRAM operations use the view directly without immediate @@ -513,11 +513,11 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { // Emit dedicated host transfer ops based on direction. if (inputInfo.isSystem()) { // Host → Device: use ToDeviceOp. - return rewriter.create(loc, input, output, *deviceLayout) + return ToDeviceOp::create(rewriter, loc, input, output, *deviceLayout) .getResult(0); } // Device → Host: use ToHostOp. 
- return rewriter.create(loc, input, output, *deviceLayout) + return ToHostOp::create(rewriter, loc, input, output, *deviceLayout) .getResult(0); } @@ -574,9 +574,8 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { baseLayout.getMemorySpace(), baseLayout.getMemoryLayout()); auto resultTy = RankedTensorType::get(toTy.getShape(), toTy.getElementType(), enc); - return rewriter - .create(loc, resultTy, fromVal, map, - /*reinterpretLayout=*/false) + return ViewLayoutOp::create(rewriter, loc, resultTy, fromVal, map, + /*reinterpretLayout=*/false) .getResult(); }; @@ -604,23 +603,22 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { AffineMap indexingMap = indexingMapAttr.getValue(); auto result = - rewriter - .create( - loc, viewInput, viewOutput, /*additionalArgs=*/ValueRange(), - [&](OpBuilder &builder, Location innerLoc, - ValueRange blockArgs) { - Type inputShardType = getShardTypeFromCB(blockArgs[0]); - SmallVector indices = d2m::utils::buildGridIndices( - builder, innerLoc, indexingMap); - - // Use load+store idiom for proper CB association - Value loadedData = createRemoteLoad( - builder, innerLoc, inputShardType, viewInput, indices); - Value storeResult = createRemoteStore( - builder, innerLoc, viewOutput, indices, loadedData); - builder.create(innerLoc, storeResult); - }, - ThreadType::Unified, grid) + GenericOp::create( + rewriter, loc, viewInput, viewOutput, + /*additionalArgs=*/ValueRange(), + [&](OpBuilder &builder, Location innerLoc, ValueRange blockArgs) { + Type inputShardType = getShardTypeFromCB(blockArgs[0]); + SmallVector indices = + d2m::utils::buildGridIndices(builder, innerLoc, indexingMap); + + // Use load+store idiom for proper CB association + Value loadedData = createRemoteLoad( + builder, innerLoc, inputShardType, viewInput, indices); + Value storeResult = createRemoteStore( + builder, innerLoc, viewOutput, indices, loadedData); + YieldOp::create(builder, innerLoc, storeResult); + }, + ThreadType::Unified, grid) 
.getResult(0); return result; } @@ -634,32 +632,34 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { assert(inputTiled != outputTiled && "one of input or output must be tiled for now"); - return rewriter - .create( - loc, input, output, /*additionalArgs=*/ValueRange(), - [=](OpBuilder &builder, Location innerLoc, ValueRange blockArgs) { - auto [src, dst, indices] = - buildIdentityLoadStore(builder, innerLoc, blockArgs[0], - blockArgs[1], input, output, 1); - - Value result; - if (inputTiled) { - result = builder - .create( - innerLoc, dst.getType(), src, dst) - .getResult(); - } else { - result = builder - .create(innerLoc, dst.getType(), - src, dst) - .getResult(); - } + return GenericOp::create( + rewriter, - Value storeResult = - createRemoteStore(builder, innerLoc, output, indices, result); - builder.create(innerLoc, storeResult); - }, - ThreadType::Unified) + loc, input, output, /*additionalArgs=*/ValueRange(), + [=](OpBuilder &builder, Location innerLoc, + ValueRange blockArgs) { + auto [src, dst, indices] = + buildIdentityLoadStore(builder, innerLoc, blockArgs[0], + blockArgs[1], input, output, 1); + + Value result; + if (inputTiled) { + result = TileUntilizeBlockOp::create(builder, + + innerLoc, dst.getType(), + src, dst) + .getResult(); + } else { + result = TileTilizeBlockOp::create(builder, innerLoc, + dst.getType(), src, dst) + .getResult(); + } + + Value storeResult = createRemoteStore(builder, innerLoc, + output, indices, result); + YieldOp::create(builder, innerLoc, storeResult); + }, + ThreadType::Unified) .getResult(0); } @@ -720,10 +720,10 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { auto maskShape = maskLayout.getDeviceShape(unitGrid, tileShape); Value rowMaskTensor = - rewriter.create(loc, maskShape, elemType, maskLayout) + d2m::EmptyOp::create(rewriter, loc, maskShape, elemType, maskLayout) .getResult(); Value colMaskTensor = - rewriter.create(loc, maskShape, elemType, maskLayout) + d2m::EmptyOp::create(rewriter, 
loc, maskShape, elemType, maskLayout) .getResult(); // Input list includes scratch mask CBs. @@ -749,8 +749,8 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { ArrayAttr iteratorTypesAttr = rewriter.getArrayAttr(SmallVector(shardRank, parallel)); - auto genericOp = rewriter.create( - loc, ValueRange(allInputs), ValueRange(allOutputs), + auto genericOp = GenericOp::create( + rewriter, loc, ValueRange(allInputs), ValueRange(allOutputs), /*additionalArgs=*/ValueRange(), indexingMapsAttr, iteratorTypesAttr, [&](OpBuilder &builder, Location innerLoc, ValueRange blockArgs) { // blockArgs: [inputCB, rowMaskCB, colMaskCB, outputCB]. @@ -771,7 +771,7 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { Type rowMaskType = getShardTypeFromCB(blockArgs[1]); Type colMaskType = getShardTypeFromCB(blockArgs[2]); SmallVector zeroIndices( - gridRank, builder.create(innerLoc, 0)); + gridRank, arith::ConstantIndexOp::create(builder, innerLoc, 0)); Value rowMaskLocal = createRemoteLoad(builder, innerLoc, rowMaskType, rowMaskTensor, zeroIndices); Value colMaskLocal = createRemoteLoad(builder, innerLoc, colMaskType, @@ -781,23 +781,22 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { Value dst = createTensorEmpty(builder, innerLoc, outputShardType); Value logicalRowsVal = - builder.create(innerLoc, logicalRows); + arith::ConstantIndexOp::create(builder, innerLoc, logicalRows); Value logicalColsVal = - builder.create(innerLoc, logicalCols); + arith::ConstantIndexOp::create(builder, innerLoc, logicalCols); // BlockMaskOp with mask tensors - the mask writes will be handled // in DecomposeMasking, which runs after bufferization. 
- Value masked = builder - .create(innerLoc, dst.getType(), src, - dst, rowMaskLocal, - colMaskLocal, logicalRowsVal, - logicalColsVal, fillValue) - .getResult(); + Value masked = + BlockMaskOp::create(builder, innerLoc, dst.getType(), src, dst, + rowMaskLocal, colMaskLocal, logicalRowsVal, + logicalColsVal, fillValue) + .getResult(); // Store the masked result to output. Value storeResult = createRemoteStore(builder, innerLoc, output, indices, masked); - builder.create(innerLoc, storeResult); + YieldOp::create(builder, innerLoc, storeResult); }, ThreadType::Unified); @@ -811,9 +810,9 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { Value input, RankedTensorType desiredType) const { auto layout = mlir::cast(desiredType.getEncoding()); - auto output = rewriter.create( - loc, desiredType.getShape(), desiredType.getElementType(), layout); - return rewriter.create(loc, input, output); + auto output = d2m::EmptyOp::create(rewriter, loc, desiredType.getShape(), + desiredType.getElementType(), layout); + return d2m::ToLayoutOp::create(rewriter, loc, input, output); } Value bounce(PatternRewriter &rewriter, ToLayoutOp op, @@ -852,9 +851,9 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { } auto layout = mlir::dyn_cast(type.getEncoding()); - auto emptyOp = rewriter.create(op.getLoc(), type.getShape(), - type.getElementType(), - layout, targetGridShape); + auto emptyOp = + d2m::EmptyOp::create(rewriter, op.getLoc(), type.getShape(), + type.getElementType(), layout, targetGridShape); return emptyOp.getResult(); }; @@ -956,8 +955,8 @@ class D2MLowerToLayoutRewriter : public OpRewritePattern { // buffers via createEmpty(). 
auto layout = mlir::dyn_cast( currentInfo.type.getEncoding()); - auto maskedEmptyOp = rewriter.create( - op.getLoc(), currentInfo.type.getShape(), + auto maskedEmptyOp = d2m::EmptyOp::create( + rewriter, op.getLoc(), currentInfo.type.getShape(), currentInfo.type.getElementType(), layout, targetGridShape); auto maskedEmpty = maskedEmptyOp.getResult(); currentValue = diff --git a/lib/Dialect/D2M/Transforms/MaterializeViewReturns.cpp b/lib/Dialect/D2M/Transforms/MaterializeViewReturns.cpp index 63c5ea94b47..b71eade1b82 100644 --- a/lib/Dialect/D2M/Transforms/MaterializeViewReturns.cpp +++ b/lib/Dialect/D2M/Transforms/MaterializeViewReturns.cpp @@ -59,8 +59,8 @@ Value materializeView(OpBuilder &builder, Location loc, Value viewResult) { builder.getContext(), layout.getLogicalShape(), layout.getDimAlignments(), layout.getCollapsedIntervals(), layout.getOobVal(), layout.getMemorySpace(), layout.getMemoryLayout()); - auto emptyOp = builder.create( - loc, tensorType.getShape(), tensorType.getElementType(), newLayout); + auto emptyOp = d2m::EmptyOp::create(builder, loc, tensorType.getShape(), + tensorType.getElementType(), newLayout); // Extract the grid from the tensor's layout to determine core distribution. ttcore::GridAttr grid = getGridFromType(tensorType); @@ -76,8 +76,9 @@ Value materializeView(OpBuilder &builder, Location loc, Value viewResult) { // Create a datamovement generic op that materializes the view. 
auto indexingMapAttr = mlir::cast(indexingMaps[0]); AffineMap indexingMap = indexingMapAttr.getValue(); - auto genericOp = builder.create( - loc, viewResult, emptyOp.getResult(), /*additionalArgs=*/ValueRange(), + auto genericOp = GenericOp::create( + builder, loc, viewResult, emptyOp.getResult(), + /*additionalArgs=*/ValueRange(), [&](OpBuilder &builder, Location innerLoc, ValueRange blockArgs) { SmallVector indices = utils::buildGridIndices(builder, innerLoc, indexingMap); @@ -87,16 +88,14 @@ Value materializeView(OpBuilder &builder, Location loc, Value viewResult) { Value inputBuffer = blockArgs[0]; Value loadedData = - builder - .create(innerLoc, inputShardType, inputBuffer, - viewResult, indices) + RemoteLoadOp::create(builder, innerLoc, inputShardType, inputBuffer, + viewResult, indices) .getResult(); Value storeResult = - builder - .create(innerLoc, emptyOp.getType(), - emptyOp.getResult(), indices, loadedData) + RemoteStoreOp::create(builder, innerLoc, emptyOp.getType(), + emptyOp.getResult(), indices, loadedData) .getResult(); - builder.create(innerLoc, storeResult); + d2m::YieldOp::create(builder, innerLoc, storeResult); }, ThreadType::Unified, grid, SmallVector{1, 1}); diff --git a/lib/Dialect/D2M/Transforms/ScalarizeConstTensors.cpp b/lib/Dialect/D2M/Transforms/ScalarizeConstTensors.cpp index 1813a409a11..5e34ffe73bc 100644 --- a/lib/Dialect/D2M/Transforms/ScalarizeConstTensors.cpp +++ b/lib/Dialect/D2M/Transforms/ScalarizeConstTensors.cpp @@ -240,8 +240,8 @@ static linalg::GenericOp rebuildLinalgGenericWithoutScalarizedInputs( } rewriter.setInsertionPoint(linalgOp); - auto newLinalgOp = rewriter.create( - linalgOp.getLoc(), linalgOp.getResultTypes(), newInputs, + auto newLinalgOp = linalg::GenericOp::create( + rewriter, linalgOp.getLoc(), linalgOp.getResultTypes(), newInputs, linalgOp.getOutputs(), newIndexingMaps, linalgOp.getIteratorTypesArray()); Block *linalgBlock = linalgOp.getBody(); @@ -304,9 +304,9 @@ static GenericOp 
rebuildD2MGenericWithoutScalarizedInputs( } rewriter.setInsertionPoint(genericOp); - auto newGenericOp = rewriter.create( - genericOp.getLoc(), genericOp.getResultTypes(), newGenericInputs, - genericOp.getOutputs(), genericOp.getAdditionalArgs(), + auto newGenericOp = GenericOp::create( + rewriter, genericOp.getLoc(), genericOp.getResultTypes(), + newGenericInputs, genericOp.getOutputs(), genericOp.getAdditionalArgs(), genericOp.getGrid(), genericOp.getBlockFactors(), rewriter.getArrayAttr(newIndexingMaps), genericOp.getIteratorTypes(), genericOp.getThreads(), genericOp.getScratchInputsAttr(), @@ -396,10 +396,10 @@ class ScalarizeFullOpPattern : public OpRewritePattern { Value scalarConst; if (auto floatAttr = dyn_cast(splatValue)) { scalarConst = - rewriter.create(fullOp.getLoc(), floatAttr); + arith::ConstantOp::create(rewriter, fullOp.getLoc(), floatAttr); } else if (auto intAttr = dyn_cast(splatValue)) { scalarConst = - rewriter.create(fullOp.getLoc(), intAttr); + arith::ConstantOp::create(rewriter, fullOp.getLoc(), intAttr); } if (!scalarConst) { diff --git a/lib/Dialect/D2M/Transforms/ScheduleDMA.cpp b/lib/Dialect/D2M/Transforms/ScheduleDMA.cpp index 416ff89afbc..32f07add9a4 100644 --- a/lib/Dialect/D2M/Transforms/ScheduleDMA.cpp +++ b/lib/Dialect/D2M/Transforms/ScheduleDMA.cpp @@ -288,10 +288,10 @@ class D2MScheduleDMARewriter : public OpRewritePattern { threads.push_back(rewriter.getAttr(ThreadType::Compute)); // Create new generic op with N+1 regions. 
- auto newGeneric = rewriter.create( - generic.getLoc(), generic.getResultTypes(), generic.getInputs(), - generic.getOutputs(), generic.getAdditionalArgs(), generic.getGrid(), - generic.getBlockFactors(), generic.getIndexingMaps(), + auto newGeneric = GenericOp::create( + rewriter, generic.getLoc(), generic.getResultTypes(), + generic.getInputs(), generic.getOutputs(), generic.getAdditionalArgs(), + generic.getGrid(), generic.getBlockFactors(), generic.getIndexingMaps(), generic.getIteratorTypes(), rewriter.getArrayAttr(threads), generic.getScratchInputsAttr(), /*numRegions*/ numThreadsToUse + 1); diff --git a/lib/Dialect/D2M/Transforms/SplitUnifiedThread.cpp b/lib/Dialect/D2M/Transforms/SplitUnifiedThread.cpp index 14e55526992..d04d43facc7 100644 --- a/lib/Dialect/D2M/Transforms/SplitUnifiedThread.cpp +++ b/lib/Dialect/D2M/Transforms/SplitUnifiedThread.cpp @@ -39,10 +39,10 @@ class D2MSplitUnifiedThreadRewriter : public OpRewritePattern { threads.push_back(rewriter.getAttr(ThreadType::Datamovement)); threads.push_back(rewriter.getAttr(ThreadType::Compute)); - auto newGeneric = rewriter.create( - generic->getLoc(), generic.getResultTypes(), generic.getInputs(), - generic.getOutputs(), generic.getAdditionalArgs(), generic.getGrid(), - generic.getBlockFactors(), generic.getIndexingMaps(), + auto newGeneric = GenericOp::create( + rewriter, generic->getLoc(), generic.getResultTypes(), + generic.getInputs(), generic.getOutputs(), generic.getAdditionalArgs(), + generic.getGrid(), generic.getBlockFactors(), generic.getIndexingMaps(), generic.getIteratorTypes(), rewriter.getArrayAttr(threads), generic.getScratchInputsAttr(), /*numRegions*/ 2); diff --git a/lib/Dialect/D2M/Utils/CBUtils.cpp b/lib/Dialect/D2M/Utils/CBUtils.cpp index 1e115abfb8f..b1df10e4b91 100644 --- a/lib/Dialect/D2M/Utils/CBUtils.cpp +++ b/lib/Dialect/D2M/Utils/CBUtils.cpp @@ -93,8 +93,8 @@ Value getOrCreateCB(GenericOp generic, Region ®ion, unsigned operandIndex, OpBuilder::InsertionGuard 
guard(rewriter); rewriter.setInsertionPointToStart(®ion.front()); - auto getCBOp = rewriter.create( - generic.getLoc(), cbType, port, + auto getCBOp = GetCBOp::create( + rewriter, generic.getLoc(), cbType, port, rewriter.getI64IntegerAttr(static_cast(operandIndex))); Value result = getCBOp.getResult(); diff --git a/lib/Dialect/D2M/Utils/Utils.cpp b/lib/Dialect/D2M/Utils/Utils.cpp index 2f138e3113c..008329204e9 100644 --- a/lib/Dialect/D2M/Utils/Utils.cpp +++ b/lib/Dialect/D2M/Utils/Utils.cpp @@ -139,7 +139,7 @@ SmallVector buildGridIndices(OpBuilder &builder, Location loc, SmallVector dimValues; for (unsigned i = 0; i < indexingMap.getNumDims(); ++i) { dimValues.push_back( - builder.create(loc, static_cast(i))); + BlockIndexOp::create(builder, loc, static_cast(i))); } // For each result expression, use expandAffineExpr to translate to arith ops diff --git a/lib/Dialect/LLVM/Transforms/EmitWrapperFuncs.cpp b/lib/Dialect/LLVM/Transforms/EmitWrapperFuncs.cpp index e9b8925abf1..73a1d8cf7ec 100644 --- a/lib/Dialect/LLVM/Transforms/EmitWrapperFuncs.cpp +++ b/lib/Dialect/LLVM/Transforms/EmitWrapperFuncs.cpp @@ -70,8 +70,8 @@ void generateLLVMWrappersForArgRanks(ModuleOp moduleOp) { auto helperFuncType = LLVM::LLVMFunctionType::get( helperReturnType, {LLVM::LLVMPointerType::get(context)}, false); - auto helperFunc = builder.create( - func.getLoc(), helperName, helperFuncType); + auto helperFunc = LLVM::LLVMFuncOp::create(builder, func.getLoc(), + helperName, helperFuncType); Block *entryBlock = helperFunc.addEntryBlock(builder); builder.setInsertionPointToStart(entryBlock); @@ -93,59 +93,61 @@ void generateLLVMWrappersForArgRanks(ModuleOp moduleOp) { for (auto rankAttr : argRanksAttr) { // Compute the offset for the current tensor (as index * size of // wrapped_tensor). 
- Value tensorIndex = builder.create( - func.getLoc(), builder.getI64Type(), - builder.getI64IntegerAttr(tensorIdx++)); + Value tensorIndex = + LLVM::ConstantOp::create(builder, func.getLoc(), builder.getI64Type(), + builder.getI64IntegerAttr(tensorIdx++)); // Calculate the ptr-width offset for the tensor; 3 pointers and one i64 // = 4. constexpr auto wrappedTensorSize = 4; - Value offset = builder.create( - func.getLoc(), tensorIndex, - builder.create( - func.getLoc(), builder.getI64Type(), + Value offset = LLVM::MulOp::create( + builder, func.getLoc(), tensorIndex, + LLVM::ConstantOp::create( + builder, func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(wrappedTensorSize))); // Get pointer to the struct for this offset-th tensor in input array. - Value structPtr = builder.create( - func.getLoc(), ptrTy, ptrTy, structArrayPtr, ValueRange(offset), - LLVM::GEPNoWrapFlags::inbounds); + Value structPtr = LLVM::GEPOp::create( + builder, func.getLoc(), ptrTy, ptrTy, structArrayPtr, + ValueRange(offset), LLVM::GEPNoWrapFlags::inbounds); // Load actual tensor object from pointer so we can extract its members. 
- Value tensorStruct = builder.create( - func.getLoc(), wrappedTensorTy, structPtr); + Value tensorStruct = LLVM::LoadOp::create(builder, func.getLoc(), + wrappedTensorTy, structPtr); - Value tensorBase = builder.create( - func.getLoc(), ptrTy, tensorStruct, + Value tensorBase = LLVM::ExtractValueOp::create( + builder, func.getLoc(), ptrTy, tensorStruct, builder.getDenseI64ArrayAttr({0})); originalCallArgs.push_back(tensorBase); - Value alignedBase = builder.create( - func.getLoc(), LLVM::LLVMPointerType::get(context), tensorStruct, - builder.getDenseI64ArrayAttr({1})); + Value alignedBase = LLVM::ExtractValueOp::create( + builder, func.getLoc(), LLVM::LLVMPointerType::get(context), + tensorStruct, builder.getDenseI64ArrayAttr({1})); originalCallArgs.push_back(alignedBase); - Value startIdx = builder.create( - func.getLoc(), builder.getI64Type(), tensorStruct, + Value startIdx = LLVM::ExtractValueOp::create( + builder, func.getLoc(), builder.getI64Type(), tensorStruct, builder.getDenseI64ArrayAttr({2})); originalCallArgs.push_back(startIdx); - Value sizesAndStrides = builder.create( - func.getLoc(), LLVM::LLVMPointerType::get(context), tensorStruct, - builder.getDenseI64ArrayAttr({3})); + Value sizesAndStrides = LLVM::ExtractValueOp::create( + builder, func.getLoc(), LLVM::LLVMPointerType::get(context), + tensorStruct, builder.getDenseI64ArrayAttr({3})); // The sizesAndStrides field is an array itself, so we need to step into // it and extract elements. 
int64_t rank = mlir::cast(rankAttr).getInt(); for (int i = 0; i < 2 * rank; i++) { - Value idx = builder.create( - func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(i)); + Value idx = LLVM::ConstantOp::create(builder, func.getLoc(), + builder.getI64Type(), + builder.getI64IntegerAttr(i)); - Value elementPtr = builder.create( - func.getLoc(), ptrTy, ptrTy, sizesAndStrides, ValueRange{idx}); + Value elementPtr = + LLVM::GEPOp::create(builder, func.getLoc(), ptrTy, ptrTy, + sizesAndStrides, ValueRange{idx}); - Value strideOrSize = builder.create( - func.getLoc(), builder.getI64Type(), elementPtr); + Value strideOrSize = LLVM::LoadOp::create( + builder, func.getLoc(), builder.getI64Type(), elementPtr); originalCallArgs.push_back(strideOrSize); } @@ -157,15 +159,14 @@ void generateLLVMWrappersForArgRanks(ModuleOp moduleOp) { // We need to wrap these back into WrappedTensor structs for the caller. if (!hasOutputs) { - builder.create(func.getLoc(), TypeRange(), func.getName(), - originalCallArgs); - builder.create(func.getLoc(), ValueRange()); + LLVM::CallOp::create(builder, func.getLoc(), TypeRange(), func.getName(), + originalCallArgs); + LLVM::ReturnOp::create(builder, func.getLoc(), ValueRange()); } else { // Call original function and pack results into WrappedTensors. auto returnType = func.getFunctionType().getReturnType(); - Value result = builder - .create(func.getLoc(), returnType, - func.getName(), originalCallArgs) + Value result = LLVM::CallOp::create(builder, func.getLoc(), returnType, + func.getName(), originalCallArgs) .getResult(); auto mallocFunc = moduleOp.lookupSymbol("malloc"); @@ -173,8 +174,8 @@ void generateLLVMWrappersForArgRanks(ModuleOp moduleOp) { auto i64Ty = builder.getI64Type(); auto makeConst = [&](int64_t val) { - return builder.create(loc, i64Ty, - builder.getI64IntegerAttr(val)); + return LLVM::ConstantOp::create(builder, loc, i64Ty, + builder.getI64IntegerAttr(val)); }; // Allocate output array and sizesAndStrides buffer. 
@@ -186,16 +187,17 @@ void generateLLVMWrappersForArgRanks(ModuleOp moduleOp) { } Value outputArrayPtr = - builder - .create( - loc, ptrTy, mallocFunc.getName(), - ValueRange{makeConst(numOutputs * kWrappedTensorBytes)}) + LLVM::CallOp::create( + builder, + + loc, ptrTy, mallocFunc.getName(), + ValueRange{makeConst(numOutputs * kWrappedTensorBytes)}) .getResult(); Value sizesStridesBase = - builder - .create( - loc, ptrTy, mallocFunc.getName(), - ValueRange{makeConst(totalSizesStridesBytes)}) + LLVM::CallOp::create(builder, + + loc, ptrTy, mallocFunc.getName(), + ValueRange{makeConst(totalSizesStridesBytes)}) .getResult(); int64_t sizesStridesOffset = 0; @@ -203,57 +205,60 @@ void generateLLVMWrappersForArgRanks(ModuleOp moduleOp) { int64_t rank = resultRanks[outIdx]; // For single output, result is the descriptor; otherwise extract it. - Value desc = - (numOutputs == 1) - ? result - : builder.create( - loc, result, builder.getDenseI64ArrayAttr({outIdx})); + Value desc = (numOutputs == 1) + ? result + : LLVM::ExtractValueOp::create( + builder, loc, result, + builder.getDenseI64ArrayAttr({outIdx})); // Extract memref descriptor fields. - Value basePtr = builder.create( - loc, ptrTy, desc, builder.getDenseI64ArrayAttr({0})); - Value alignedPtr = builder.create( - loc, ptrTy, desc, builder.getDenseI64ArrayAttr({1})); - Value offset = builder.create( - loc, i64Ty, desc, builder.getDenseI64ArrayAttr({2})); + Value basePtr = LLVM::ExtractValueOp::create( + builder, loc, ptrTy, desc, builder.getDenseI64ArrayAttr({0})); + Value alignedPtr = LLVM::ExtractValueOp::create( + builder, loc, ptrTy, desc, builder.getDenseI64ArrayAttr({1})); + Value offset = LLVM::ExtractValueOp::create( + builder, loc, i64Ty, desc, builder.getDenseI64ArrayAttr({2})); // Get pointer to this output's sizesAndStrides array. 
- Value sizesStridesPtr = builder.create( - loc, ptrTy, builder.getI8Type(), sizesStridesBase, + Value sizesStridesPtr = LLVM::GEPOp::create( + builder, loc, ptrTy, builder.getI8Type(), sizesStridesBase, ValueRange{makeConst(sizesStridesOffset)}); // Copy sizes and strides from descriptor to sizesAndStrides array. for (int64_t i = 0; i < 2 * rank; ++i) { int64_t structIdx = (i < rank) ? 3 : 4; int64_t arrayIdx = (i < rank) ? i : i - rank; - Value val = builder.create( - loc, i64Ty, desc, + Value val = LLVM::ExtractValueOp::create( + builder, loc, i64Ty, desc, builder.getDenseI64ArrayAttr({structIdx, arrayIdx})); - Value destPtr = builder.create( - loc, ptrTy, i64Ty, sizesStridesPtr, ValueRange{makeConst(i)}); - builder.create(loc, val, destPtr); + Value destPtr = + LLVM::GEPOp::create(builder, loc, ptrTy, i64Ty, sizesStridesPtr, + ValueRange{makeConst(i)}); + LLVM::StoreOp::create(builder, loc, val, destPtr); } // Build and store WrappedTensor struct. - Value wrapped = builder.create(loc, wrappedTensorTy); - wrapped = builder.create( - loc, wrapped, basePtr, builder.getDenseI64ArrayAttr({0})); - wrapped = builder.create( - loc, wrapped, alignedPtr, builder.getDenseI64ArrayAttr({1})); - wrapped = builder.create( - loc, wrapped, offset, builder.getDenseI64ArrayAttr({2})); - wrapped = builder.create( - loc, wrapped, sizesStridesPtr, builder.getDenseI64ArrayAttr({3})); - - Value outPtr = builder.create( - loc, ptrTy, wrappedTensorTy, outputArrayPtr, - ValueRange{makeConst(outIdx)}); - builder.create(loc, wrapped, outPtr); + Value wrapped = LLVM::UndefOp::create(builder, loc, wrappedTensorTy); + wrapped = LLVM::InsertValueOp::create( + builder, loc, wrapped, basePtr, builder.getDenseI64ArrayAttr({0})); + wrapped = + LLVM::InsertValueOp::create(builder, loc, wrapped, alignedPtr, + builder.getDenseI64ArrayAttr({1})); + wrapped = LLVM::InsertValueOp::create( + builder, loc, wrapped, offset, builder.getDenseI64ArrayAttr({2})); + wrapped = + 
LLVM::InsertValueOp::create(builder, loc, wrapped, sizesStridesPtr, + builder.getDenseI64ArrayAttr({3})); + + Value outPtr = + LLVM::GEPOp::create(builder, loc, ptrTy, wrappedTensorTy, + outputArrayPtr, ValueRange{makeConst(outIdx)}); + LLVM::StoreOp::create(builder, loc, wrapped, outPtr); sizesStridesOffset += 2 * rank * 8; } - builder.create(loc, ValueRange{outputArrayPtr}); + LLVM::ReturnOp::create(builder, loc, ValueRange{outputArrayPtr}); } } diff --git a/lib/Dialect/StableHLO/Transforms/ComplexDataTypeConversion.cpp b/lib/Dialect/StableHLO/Transforms/ComplexDataTypeConversion.cpp index 36be0c9d341..013378b13dd 100644 --- a/lib/Dialect/StableHLO/Transforms/ComplexDataTypeConversion.cpp +++ b/lib/Dialect/StableHLO/Transforms/ComplexDataTypeConversion.cpp @@ -59,9 +59,8 @@ static Value transposeTrailingToLeading(Location loc, Value input, } auto newType = RankedTensorType::get(newShape, type.getElementType()); - return builder - .create(loc, newType, input, - builder.getDenseI64ArrayAttr(perm)) + return mlir::stablehlo::TransposeOp::create( + builder, loc, newType, input, builder.getDenseI64ArrayAttr(perm)) .getResult(); } @@ -85,9 +84,8 @@ static Value transposeLeadingToTrailing(Location loc, Value input, newShape.push_back(type.getShape()[0]); auto newType = RankedTensorType::get(newShape, type.getElementType()); - return builder - .create(loc, newType, input, - builder.getDenseI64ArrayAttr(perm)) + return mlir::stablehlo::TransposeOp::create( + builder, loc, newType, input, builder.getDenseI64ArrayAttr(perm)) .getResult(); } @@ -130,10 +128,10 @@ class StablehloComplexToDecomposedPattern } auto unsqueezedType = RankedTensorType::get(unsqueezedShape, lhsType.getElementType()); - auto reshapedLhs = rewriter.create( - loc, unsqueezedType, adaptor.getLhs()); - auto reshapedRhs = rewriter.create( - loc, unsqueezedType, adaptor.getRhs()); + auto reshapedLhs = mlir::stablehlo::ReshapeOp::create( + rewriter, loc, unsqueezedType, adaptor.getLhs()); + auto reshapedRhs 
= mlir::stablehlo::ReshapeOp::create( + rewriter, loc, unsqueezedType, adaptor.getRhs()); SmallVector concatShape; concatShape.push_back(2); @@ -142,8 +140,8 @@ class StablehloComplexToDecomposedPattern } auto concatType = RankedTensorType::get(concatShape, lhsType.getElementType()); - auto concatOp = rewriter.create( - loc, concatType, + auto concatOp = mlir::stablehlo::ConcatenateOp::create( + rewriter, loc, concatType, ValueRange{reshapedLhs.getResult(), reshapedRhs.getResult()}, /*dimension=*/0); @@ -198,8 +196,9 @@ class StablehloRealImagToDecomposedPattern : public OpConversionPattern { SmallVector sliceShape(transposedShape.begin(), transposedShape.end()); sliceShape[0] = 1; - auto sliceOp = rewriter.create( - loc, RankedTensorType::get(sliceShape, transposedType.getElementType()), + auto sliceOp = mlir::stablehlo::SliceOp::create( + rewriter, loc, + RankedTensorType::get(sliceShape, transposedType.getElementType()), transposed, rewriter.getDenseI64ArrayAttr(begins), rewriter.getDenseI64ArrayAttr(ends), rewriter.getDenseI64ArrayAttr(steps)); diff --git a/lib/Dialect/StableHLO/Transforms/ShardyCCLCanonicalization.cpp b/lib/Dialect/StableHLO/Transforms/ShardyCCLCanonicalization.cpp index 83c3ed08301..43d1187e9b8 100644 --- a/lib/Dialect/StableHLO/Transforms/ShardyCCLCanonicalization.cpp +++ b/lib/Dialect/StableHLO/Transforms/ShardyCCLCanonicalization.cpp @@ -81,9 +81,10 @@ class AllReduceAllSliceToReduceScatterPattern // All users validated - create reduce_scatter for each all_slice. 
for (mlir::sdy::AllSliceOp allSliceOp : allSliceUsers) { - auto reduceScatterOp = rewriter.create( - allReduceOp.getLoc(), allSliceOp.getType(), allReduceOp.getOperand(), - allSliceOp.getSlicingAxes(), allSliceOp.getOutSharding()); + auto reduceScatterOp = mlir::sdy::ReduceScatterOp::create( + rewriter, allReduceOp.getLoc(), allSliceOp.getType(), + allReduceOp.getOperand(), allSliceOp.getSlicingAxes(), + allSliceOp.getOutSharding()); rewriter.replaceOp(allSliceOp, reduceScatterOp.getResult()); } diff --git a/lib/Dialect/StableHLO/Transforms/ShardyCCLToStableHLOCCLPatterns.cpp b/lib/Dialect/StableHLO/Transforms/ShardyCCLToStableHLOCCLPatterns.cpp index 96673bddc74..610cada96af 100644 --- a/lib/Dialect/StableHLO/Transforms/ShardyCCLToStableHLOCCLPatterns.cpp +++ b/lib/Dialect/StableHLO/Transforms/ShardyCCLToStableHLOCCLPatterns.cpp @@ -130,9 +130,9 @@ static void addReductionBlock(PatternRewriter &rewriter, SrcOp &srcOp, mlir::Block *block = rewriter.createBlock(&srcOp.getRegion(), /*insertPt*/ {}, {reductionType, reductionType}, {loc, loc}); - ReductionOp reductionOp = rewriter.create( - loc, block->getArgument(0), block->getArgument(1)); - rewriter.create(loc, reductionOp.getResult()); + ReductionOp reductionOp = ReductionOp::create( + rewriter, loc, block->getArgument(0), block->getArgument(1)); + mlir::stablehlo::ReturnOp::create(rewriter, loc, reductionOp.getResult()); } // AllGatherOp @@ -182,8 +182,8 @@ class ShardyToStableHLOAllGatherOpRewritePattern newShape, prevOutputType.getElementType()); mlir::stablehlo::AllGatherOp allGatherOp = - rewriter.create( - srcOp.getLoc(), newOutputType, result, allGatherDim, + mlir::stablehlo::AllGatherOp::create( + rewriter, srcOp.getLoc(), newOutputType, result, allGatherDim, createDenseAttrFromReplicaGroups( context, populateReplicaGroups(meshMap, meshAxis)), channelHandleAttr); @@ -241,8 +241,9 @@ class ShardyToStableHLOReduceScatterOpRewritePattern newShape, prevOutputType.getElementType()); 
mlir::stablehlo::ReduceScatterOp reduceScatterOp = - rewriter.create( - srcOp.getLoc(), newOutputType, result, reduceScatterDim, + mlir::stablehlo::ReduceScatterOp::create( + rewriter, srcOp.getLoc(), newOutputType, result, + reduceScatterDim, createDenseAttrFromReplicaGroups( context, populateReplicaGroups(meshMap, meshAxis)), channelHandleAttr); @@ -292,8 +293,8 @@ class ShardyToStableHLOAllReduceOpRewritePattern mlir::RankedTensorType newOutputType = mlir::cast(result.getType()); mlir::stablehlo::AllReduceOp allReduceOp = - rewriter.create( - srcOp.getLoc(), newOutputType, result, + mlir::stablehlo::AllReduceOp::create( + rewriter, srcOp.getLoc(), newOutputType, result, createDenseAttrFromReplicaGroups( context, populateReplicaGroups(meshMap, meshAxis)), channelHandleAttr); @@ -360,9 +361,9 @@ class ShardyToStableHLOAllToAllOpRewritePattern mlir::RankedTensorType newOutputType = mlir::RankedTensorType::get( newShape, prevOutputType.getElementType()); mlir::stablehlo::AllToAllOp allToAllOp = - rewriter.create( - srcOp.getLoc(), newOutputType, result, sliceDim, concatDim, - meshMap[meshAxis.str()], + mlir::stablehlo::AllToAllOp::create( + rewriter, srcOp.getLoc(), newOutputType, result, sliceDim, + concatDim, meshMap[meshAxis.str()], createDenseAttrFromReplicaGroups( context, populateReplicaGroups(meshMap, meshAxis)), channelHandleAttr); @@ -465,8 +466,8 @@ class ShardyToStableHLOAllSliceOpRewritePattern } auto reshapedType = mlir::RankedTensorType::get( reshapeShape, prevType.getElementType()); - auto reshaped = rewriter.create( - srcOp.getLoc(), reshapedType, result); + auto reshaped = mlir::stablehlo::ReshapeOp::create( + rewriter, srcOp.getLoc(), reshapedType, result); ops_to_outline.push_back(reshaped); // 2) Replica groups for the target mesh axis. @@ -476,8 +477,8 @@ class ShardyToStableHLOAllSliceOpRewritePattern // 3) AllToAll along the inserted "parts" axis (split and concat on same // axis). 
int64_t partsDim = sliceDim; // "parts" axis is inserted at sliceDim - auto allToAll = rewriter.create( - srcOp.getLoc(), + auto allToAll = mlir::stablehlo::AllToAllOp::create( + rewriter, srcOp.getLoc(), reshapedType, // result type is identical to input reshaped.getResult(), /*split_dimension=*/partsDim, @@ -505,9 +506,9 @@ class ShardyToStableHLOAllSliceOpRewritePattern mlir::RankedTensorType::get(slicedShape, prevType.getElementType()); mlir::Value allToAllOut = allToAll.getResult(0); - auto slice = rewriter.create( - srcOp.getLoc(), slicedType, allToAllOut, startAttr, limitAttr, - stridesAttr); + auto slice = mlir::stablehlo::SliceOp::create( + rewriter, srcOp.getLoc(), slicedType, allToAllOut, startAttr, + limitAttr, stridesAttr); ops_to_outline.push_back(slice); // 5) Remove the singleton "parts" axis → final shape with chunkLen at @@ -515,8 +516,8 @@ class ShardyToStableHLOAllSliceOpRewritePattern shape[sliceDim] = chunkLen; auto finalType = mlir::RankedTensorType::get(shape, prevType.getElementType()); - auto squeezed = rewriter.create( - srcOp.getLoc(), finalType, slice.getResult()); + auto squeezed = mlir::stablehlo::ReshapeOp::create( + rewriter, srcOp.getLoc(), finalType, slice.getResult()); ops_to_outline.push_back(squeezed); // Thread through for the next axis (if compound) or next dimension. diff --git a/lib/Dialect/StableHLO/Transforms/StableHLOFusing.cpp b/lib/Dialect/StableHLO/Transforms/StableHLOFusing.cpp index c4c419a3e0d..0945634c78c 100644 --- a/lib/Dialect/StableHLO/Transforms/StableHLOFusing.cpp +++ b/lib/Dialect/StableHLO/Transforms/StableHLOFusing.cpp @@ -69,8 +69,8 @@ class ConcatenateToBroadcastInDimFusionPattern // Create broadcast_in_dim op with concatenate input and broadcast dims. // Replace reshape op with broadcast_in_dim op. 
::mlir::stablehlo::BroadcastInDimOp broadcastInDimOp = - rewriter.create<::mlir::stablehlo::BroadcastInDimOp>( - concatOp.getLoc(), reshapeOp.getResult().getType(), + ::mlir::stablehlo::BroadcastInDimOp::create( + rewriter, concatOp.getLoc(), reshapeOp.getResult().getType(), concatOp.getInputs()[0], broadcastDims); rewriter.replaceOp(reshapeOp, broadcastInDimOp.getResult()); return success(); diff --git a/lib/Dialect/StableHLO/Transforms/WrapUnderManualComputation.cpp b/lib/Dialect/StableHLO/Transforms/WrapUnderManualComputation.cpp index 191081a2182..aa9e0e6ede0 100644 --- a/lib/Dialect/StableHLO/Transforms/WrapUnderManualComputation.cpp +++ b/lib/Dialect/StableHLO/Transforms/WrapUnderManualComputation.cpp @@ -32,9 +32,10 @@ static mlir::LogicalResult wrapFunctionBodyInManualComputationOp( mlir::Block &entryBlock = funcOp.getBody().front(); builder.setInsertionPointToStart(&entryBlock); mlir::sdy::ManualComputationOp manualComputationOp = - builder.create( - builder.getUnknownLoc(), funcType.getResults(), funcOp.getArguments(), - inShardings, outShardings, llvm::SmallVector()); + mlir::sdy::ManualComputationOp::create( + builder, builder.getUnknownLoc(), funcType.getResults(), + funcOp.getArguments(), inShardings, outShardings, + llvm::SmallVector()); // Determine the argumentTypes and argumentLocations that need to get // added to the new region in manualComputationOp. @@ -70,8 +71,8 @@ static mlir::LogicalResult wrapFunctionBodyInManualComputationOp( // Create a new func.ReturnOp in the original func.funcOp that takes the // manualComputationOp as it's operand. builder.setInsertionPointAfter(manualComputationOp); - builder.create(builder.getUnknownLoc(), - manualComputationOp->getResults()); + mlir::func::ReturnOp::create(builder, builder.getUnknownLoc(), + manualComputationOp->getResults()); // Update old arguments with new arguments inside of the // manualComputationBlock. 
@@ -98,8 +99,8 @@ static mlir::LogicalResult wrapFunctionBodyInManualComputationOp( mlir::func::ReturnOp returnOp = mlir::cast(op); builder.setInsertionPoint(returnOp); - builder.create(builder.getUnknownLoc(), - returnOp->getOperands()); + mlir::sdy::ReturnOp::create(builder, builder.getUnknownLoc(), + returnOp->getOperands()); returnOp->erase(); } diff --git a/lib/Dialect/StableHLO/Utils/ShardyUtils.cpp b/lib/Dialect/StableHLO/Utils/ShardyUtils.cpp index b19710c45b2..6ed9cd9821a 100644 --- a/lib/Dialect/StableHLO/Utils/ShardyUtils.cpp +++ b/lib/Dialect/StableHLO/Utils/ShardyUtils.cpp @@ -92,8 +92,8 @@ void addMeshToModule(mlir::ModuleOp &module, std::string meshName, mlir::sdy::MeshAttr sdyMeshAttr = shardy_utils::createMeshAttrFromMeshMap(context, meshMap); builder.setInsertionPoint(&(module.getBody()->front())); - builder.create( - builder.getUnknownLoc(), builder.getStringAttr(meshName), sdyMeshAttr); + mlir::sdy::MeshOp::create(builder, builder.getUnknownLoc(), + builder.getStringAttr(meshName), sdyMeshAttr); } // Create a TTMeshAttr from a sdy::meshOp. @@ -954,8 +954,8 @@ convertCustomCallToShardingConstraint(mlir::ModuleOp &rootModule, // Create sdy.sharding_constraint op and replace it in place of custom call. 
builder.setInsertionPointAfter(customCallOp); - auto shardingConstraintOp = builder.create( - customCallOp->getLoc(), customCallOp.getResult(0).getType(), + auto shardingConstraintOp = mlir::sdy::ShardingConstraintOp::create( + builder, customCallOp->getLoc(), customCallOp.getResult(0).getType(), customCallOp.getOperand(0), tensorShardingAttr); customCallOp.getResult(0).replaceAllUsesWith( shardingConstraintOp.getResult()); diff --git a/lib/Dialect/StableHLO/Utils/StableHLOUtils.cpp b/lib/Dialect/StableHLO/Utils/StableHLOUtils.cpp index 066c501795b..f5cc6105610 100644 --- a/lib/Dialect/StableHLO/Utils/StableHLOUtils.cpp +++ b/lib/Dialect/StableHLO/Utils/StableHLOUtils.cpp @@ -69,7 +69,7 @@ mlir::func::FuncOp createPrivateFunction( mlir::Value escVal = mapping.lookupOrNull(esc); retVals.push_back(escVal); } - internalBuilder.create(func.getLoc(), retVals); + mlir::func::ReturnOp::create(internalBuilder, func.getLoc(), retVals); return func; } diff --git a/lib/Dialect/TTCore/IR/Utils.cpp b/lib/Dialect/TTCore/IR/Utils.cpp index 367fe9b7e4d..995cd5a5ee0 100644 --- a/lib/Dialect/TTCore/IR/Utils.cpp +++ b/lib/Dialect/TTCore/IR/Utils.cpp @@ -94,8 +94,8 @@ mlir::memref::GlobalOp createGlobal(ModuleOp moduleOp, StringRef name, auto symbolName = getUniqueSymbolName(); OpBuilder builder(moduleOp.getRegion()); - auto global = builder.create( - moduleOp->getLoc(), symbolName, + auto global = memref::GlobalOp::create( + builder, moduleOp->getLoc(), symbolName, /*sym_visibility*/ builder.getStringAttr(privateVisibility ? 
"private" : "public"), type, value, constant, diff --git a/lib/Dialect/TTCore/Transforms/TTCoreModuleWrap.cpp b/lib/Dialect/TTCore/Transforms/TTCoreModuleWrap.cpp index dd44fe77ba7..d615c5aef51 100644 --- a/lib/Dialect/TTCore/Transforms/TTCoreModuleWrap.cpp +++ b/lib/Dialect/TTCore/Transforms/TTCoreModuleWrap.cpp @@ -42,7 +42,7 @@ class TTCoreWrapDeviceModulePass innerModule.getBodyRegion().takeBody(rootModule.getBodyRegion()); rootModule.getRegion().emplaceBlock(); builder.setInsertionPointToStart(&rootModule.getBodyRegion().front()); - auto deviceModule = builder.create(rootModule.getLoc()); + auto deviceModule = DeviceModuleOp::create(builder, rootModule.getLoc()); builder.setInsertionPointToStart(&deviceModule.getBodyRegion().front()); builder.clone(*innerModule); innerModule->erase(); diff --git a/lib/Dialect/TTCore/Transforms/TTCoreRegisterDevice.cpp b/lib/Dialect/TTCore/Transforms/TTCoreRegisterDevice.cpp index ef1ca0471bc..e11f5b30c74 100644 --- a/lib/Dialect/TTCore/Transforms/TTCoreRegisterDevice.cpp +++ b/lib/Dialect/TTCore/Transforms/TTCoreRegisterDevice.cpp @@ -35,8 +35,8 @@ registerDeviceInSymbolTable(ModuleOp moduleOp, ArrayRef meshShape, return failure(); } OpBuilder builder(moduleOp.getBodyRegion()); - symbolTable.insert(builder.create( - moduleOp.getLoc(), getDefaultDeviceName(), + symbolTable.insert(DeviceOp::create( + builder, moduleOp.getLoc(), getDefaultDeviceName(), DeviceAttr::get(context, systemDesc, *finalMeshShape, meshTopology))); } return success(); diff --git a/lib/Dialect/TTIR/IR/TTIRDialect.cpp b/lib/Dialect/TTIR/IR/TTIRDialect.cpp index 0e6b6e40bca..bcebc15a8b6 100644 --- a/lib/Dialect/TTIR/IR/TTIRDialect.cpp +++ b/lib/Dialect/TTIR/IR/TTIRDialect.cpp @@ -154,16 +154,16 @@ ::mlir::Operation *TTIRDialect::materializeConstant(OpBuilder &builder, llvm::to_vector_of(elementsAttr.getShapedType().getShape()); auto splatValue = elementsAttr.getSplatValue(); if (isZeroAttr(splatValue)) { - return builder.create(loc, type, shape); + return 
ttir::ZerosOp::create(builder, loc, type, shape); } if (isOneAttr(splatValue)) { - return builder.create(loc, type, shape); + return ttir::OnesOp::create(builder, loc, type, shape); } if (isValidFullValueType(elementsAttr.getElementType())) { - return builder.create(loc, type, shape, splatValue); + return ttir::FullOp::create(builder, loc, type, shape, splatValue); } } - return builder.create(loc, type, elementsAttr); + return ttir::ConstantOp::create(builder, loc, type, elementsAttr); } return {}; } diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index edbbed0659f..72a40414388 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -105,7 +105,7 @@ mlir::Operation *mlir::tt::ttir::AddOp::rewriteWithQuantizedInputs( quantElemQ, quantType.getEncoding()); auto quantizedInput = - rewriter.create(getLoc(), newType, dequantVal); + ttir::QuantizeOp::create(rewriter, getLoc(), newType, dequantVal); // Update operands. if (lhsElemQ) { @@ -121,7 +121,8 @@ mlir::Operation *mlir::tt::ttir::AddOp::rewriteWithQuantizedInputs( oldType.getShape(), lhsElemQ, oldType.getEncoding()); // Emit new AddOp with quantized types. 
- auto newAdd = rewriter.create(getLoc(), newResultType, lhs, rhs); + auto newAdd = + ttir::AddOp::create(rewriter, getLoc(), newResultType, lhs, rhs); return newAdd.getOperation(); } @@ -923,12 +924,12 @@ mlir::Operation *mlir::tt::ttir::Conv2dOp::rewriteWithQuantizedInputs( RankedTensorType quantBiasType = RankedTensorType::get( getBias().getType().getShape(), quantConvOutputType, getBias().getType().getEncoding()); - quantBias = rewriter.create( - getLoc(), quantBiasType, quantBias); + quantBias = mlir::tt::ttir::QuantizeOp::create(rewriter, getLoc(), + quantBiasType, quantBias); } - auto quantConv = rewriter.create( - getLoc(), newType, sourceOperands[0], sourceOperands[1], quantBias, - getStrideAttr(), getPaddingAttr(), getDilationAttr(), + auto quantConv = mlir::tt::ttir::Conv2dOp::create( + rewriter, getLoc(), newType, sourceOperands[0], sourceOperands[1], + quantBias, getStrideAttr(), getPaddingAttr(), getDilationAttr(), rewriter.getI32IntegerAttr(getGroups()), getBatchDimAttr(), getHeightDimAttr(), getWidthDimAttr(), getChannelDimAttr(), /*flattenedCompatInfo=*/nullptr); @@ -1368,9 +1369,9 @@ mlir::Operation *mlir::tt::ttir::ConvTranspose2dOp::rewriteWithQuantizedInputs( RankedTensorType newType = RankedTensorType::get(oldConvOutputType.getShape(), quantConvOutputType, oldConvOutputType.getEncoding()); - auto quantConv = rewriter.create( - getLoc(), newType, sourceOperands[0], sourceOperands[1], getBias(), - getStrideAttr(), getPaddingAttr(), getOutputPaddingAttr(), + auto quantConv = mlir::tt::ttir::ConvTranspose2dOp::create( + rewriter, getLoc(), newType, sourceOperands[0], sourceOperands[1], + getBias(), getStrideAttr(), getPaddingAttr(), getOutputPaddingAttr(), getDilationAttr(), getGroupsAttr(), /*flattenedCompatInfo=*/nullptr); return quantConv.getOperation(); @@ -1706,10 +1707,11 @@ ::mlir::Operation *mlir::tt::ttir::MaxPool2dOp::rewriteWithQuantizedInputs( mlir::cast(input.getType()).getElementType(), outType.getEncoding()); - return rewriter - 
.create( - getLoc(), newOutType, input, getKernelAttr(), getStrideAttr(), - getDilationAttr(), getPaddingAttr(), getCeilModeAttr()) + return mlir::tt::ttir::MaxPool2dOp::create( + rewriter, + + getLoc(), newOutType, input, getKernelAttr(), getStrideAttr(), + getDilationAttr(), getPaddingAttr(), getCeilModeAttr()) .getOperation(); // NOLINTEND(clang-analyzer-core.StackAddressEscape) } @@ -4432,9 +4434,9 @@ void mlir::tt::ttir::UpdateCacheOp::getCanonicalizationPatterns( auto newInputType = RankedTensorType::get( newInputShape, newInput.getType().getElementType(), newInput.getType().getEncoding()); - newInput = rewriter.create( - op.getLoc(), newInputType, newInput, - rewriter.getDenseI64ArrayAttr({0, 2, 1, 3})); + newInput = + PermuteOp::create(rewriter, op.getLoc(), newInputType, newInput, + rewriter.getDenseI64ArrayAttr({0, 2, 1, 3})); } // If the update index shape is [1] then repeat to num users @@ -4445,8 +4447,9 @@ void mlir::tt::ttir::UpdateCacheOp::getCanonicalizationPatterns( newUpdateIndexShape, newUpdateIndex.getType().getElementType(), newUpdateIndex.getType().getEncoding()); auto repeatDims = rewriter.getDenseI64ArrayAttr({numUsers}); - newUpdateIndex = rewriter.create( - op.getLoc(), newUpdateIndexType, newUpdateIndex, repeatDims); + newUpdateIndex = + RepeatOp::create(rewriter, op.getLoc(), newUpdateIndexType, + newUpdateIndex, repeatDims); } rewriter.replaceOpWithNewOp( diff --git a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp index 687f8716249..61791d9d0b3 100644 --- a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp +++ b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp @@ -43,7 +43,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" -#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include 
"mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" diff --git a/lib/Dialect/TTIR/Transforms/Broadcast.cpp b/lib/Dialect/TTIR/Transforms/Broadcast.cpp index 7920d9777da..1a7c83b2938 100644 --- a/lib/Dialect/TTIR/Transforms/Broadcast.cpp +++ b/lib/Dialect/TTIR/Transforms/Broadcast.cpp @@ -45,8 +45,8 @@ static bool addOutputBroadcastIfNeeded(Operation *op, rewriter.setInsertionPointAfter(op); auto broadcastDimensions = ttmlir::utils::getBroadcastDimensions( implicitShape, resultShape); - auto broadcastOp = rewriter.create( - op->getLoc(), + auto broadcastOp = ttir::BroadcastOp::create( + rewriter, op->getLoc(), RankedTensorType::get(resultShape, newResultType.getElementType(), newResultType.getEncoding()), op->getResult(0), broadcastDimensions); @@ -170,8 +170,8 @@ class FullToScalarRewriter : public CreationToScalarRewriter { Value convertToScalar(ttir::FullOp op, RankedTensorType scalarType, PatternRewriter &rewriter) const override { - return rewriter - .create(op.getLoc(), scalarType, op.getFillValueAttr()) + return ttir::FullOp::create(rewriter, op.getLoc(), scalarType, + op.getFillValueAttr()) .getResult(); } }; @@ -182,9 +182,8 @@ class NamedFullToScalarRewriter : public CreationToScalarRewriter { Value convertToScalar(OpTy op, RankedTensorType scalarType, PatternRewriter &rewriter) const override { - return rewriter - .create(op.getLoc(), scalarType, - SmallVector(scalarType.getRank(), 1)) + return OpTy::create(rewriter, op.getLoc(), scalarType, + SmallVector(scalarType.getRank(), 1)) .getResult(); } }; diff --git a/lib/Dialect/TTIR/Transforms/DecomposeComplexPermute.cpp b/lib/Dialect/TTIR/Transforms/DecomposeComplexPermute.cpp index 5d98e30ce9b..2273bed4e59 100644 --- a/lib/Dialect/TTIR/Transforms/DecomposeComplexPermute.cpp +++ b/lib/Dialect/TTIR/Transforms/DecomposeComplexPermute.cpp @@ -168,8 +168,8 @@ class TTIRDecomposeComplexPermute ttmlir::utils::applyPermutation(inputType.getShape(), perm); auto outputType = 
RankedTensorType::get( outputShape, inputType.getElementType(), inputType.getEncoding()); - auto permuteOp = rewriter.create(loc, outputType, - currentInput, perm); + auto permuteOp = ttir::PermuteOp::create(rewriter, loc, outputType, + currentInput, perm); permuteOp->setAttr("decomposed", rewriter.getUnitAttr()); currentInput = permuteOp.getResult(); inputType = mlir::cast(currentInput.getType()); diff --git a/lib/Dialect/TTIR/Transforms/DecomposeComplexReshape.cpp b/lib/Dialect/TTIR/Transforms/DecomposeComplexReshape.cpp index 2c547e114a7..0ff7a601e64 100644 --- a/lib/Dialect/TTIR/Transforms/DecomposeComplexReshape.cpp +++ b/lib/Dialect/TTIR/Transforms/DecomposeComplexReshape.cpp @@ -79,10 +79,11 @@ Value createReshape(IRRewriter &rewriter, Location loc, Value input, RankedTensorType refType) { auto resultType = RankedTensorType::get(newShape, refType.getElementType(), refType.getEncoding()); - return rewriter - .create( - loc, resultType, input, - rewriter.getI32ArrayAttr(llvm::to_vector_of(newShape))) + return ttir::ReshapeOp::create( + rewriter, + + loc, resultType, input, + rewriter.getI32ArrayAttr(llvm::to_vector_of(newShape))) .getResult(); } @@ -97,7 +98,7 @@ Value createPermuteSwapLastTwoDims(IRRewriter &rewriter, Location loc, } auto resultType = RankedTensorType::get(resultShape, refType.getElementType(), refType.getEncoding()); - return rewriter.create(loc, resultType, input, perm) + return ttir::PermuteOp::create(rewriter, loc, resultType, input, perm) .getResult(); } @@ -239,7 +240,7 @@ class TTIRDecomposeComplexReshape input = createReshape(rewriter, loc, input, outputSwapped, inputType); } auto perm = buildLastTwoSwapPerm(outputShape.size()); - return rewriter.create(loc, outputType, input, perm) + return ttir::PermuteOp::create(rewriter, loc, outputType, input, perm) .getResult(); } @@ -279,7 +280,7 @@ class TTIRDecomposeComplexReshape input = createReshape(rewriter, loc, input, outputSwapped, inputType); } auto perm = 
buildLastTwoSwapPerm(outputShape.size()); - return rewriter.create(loc, outputType, input, perm) + return ttir::PermuteOp::create(rewriter, loc, outputType, input, perm) .getResult(); } diff --git a/lib/Dialect/TTIR/Transforms/ElementTypeNormalization.cpp b/lib/Dialect/TTIR/Transforms/ElementTypeNormalization.cpp index f62a9f2bb35..142956a9703 100644 --- a/lib/Dialect/TTIR/Transforms/ElementTypeNormalization.cpp +++ b/lib/Dialect/TTIR/Transforms/ElementTypeNormalization.cpp @@ -210,7 +210,7 @@ class FuncBodyTypeCast : public mlir::ConversionPattern { mlir::Location loc) -> mlir::Value { mlir::RankedTensorType rankedType = mlir::cast(type); - return builder.create(loc, rankedType, inputs); + return ttir::TypecastOp::create(builder, loc, rankedType, inputs); }; addSourceMaterialization(materializeFunc); diff --git a/lib/Dialect/TTIR/Transforms/EraseInverseOps/BroadcastCommutePatterns.cpp b/lib/Dialect/TTIR/Transforms/EraseInverseOps/BroadcastCommutePatterns.cpp index 5995a04cf59..a226d8d31d4 100644 --- a/lib/Dialect/TTIR/Transforms/EraseInverseOps/BroadcastCommutePatterns.cpp +++ b/lib/Dialect/TTIR/Transforms/EraseInverseOps/BroadcastCommutePatterns.cpp @@ -221,15 +221,16 @@ class TTIRCommuteReshapeThroughBroadcast RankedTensorType::get(newReshapeShape, tmResultType.getElementType(), tmResultType.getEncoding()); - auto newReshape = rewriter.create( - reshapeUser.getLoc(), newTMResultType, op.getInput(), + auto newReshape = ttir::ReshapeOp::create( + rewriter, reshapeUser.getLoc(), newTMResultType, op.getInput(), rewriter.getI32ArrayAttr(SmallVector(newReshapeShape.begin(), newReshapeShape.end()))); assert(newBroadcastDimensions.size() == static_cast(tmResultType.getRank())); - auto newBroadcast = rewriter.create( - op->getLoc(), tmResultType, newReshape, newBroadcastDimensions); + auto newBroadcast = + ttir::BroadcastOp::create(rewriter, op->getLoc(), tmResultType, + newReshape, newBroadcastDimensions); // All users must be identical TMs. 
// We must not reference `reshapeUser` during/after replacements, as it will @@ -331,16 +332,17 @@ class TTIRCommutePermuteThroughBroadcast ttmlir::utils::applyPermutation(op.getBroadcastDimensions(), permutation); - auto newPermute = rewriter.create( - permuteUser->getLoc(), + auto newPermute = ttir::PermuteOp::create( + rewriter, permuteUser->getLoc(), RankedTensorType::get(newShape, tmResultType.getElementType(), tmResultType.getEncoding()), operand, permutation); assert(newBroadcastDimensions.size() == static_cast(tmResultType.getRank())); - auto newBroadcast = rewriter.create( - op->getLoc(), tmResultType, newPermute, newBroadcastDimensions); + auto newBroadcast = + ttir::BroadcastOp::create(rewriter, op->getLoc(), tmResultType, + newPermute, newBroadcastDimensions); // All users must be identical TMs. // We must not reference `permuteUser` during/after replacements, as it will diff --git a/lib/Dialect/TTIR/Transforms/EraseInverseOps/ConcatCommutePatterns.cpp b/lib/Dialect/TTIR/Transforms/EraseInverseOps/ConcatCommutePatterns.cpp index e12142f36dd..c25816aad6f 100644 --- a/lib/Dialect/TTIR/Transforms/EraseInverseOps/ConcatCommutePatterns.cpp +++ b/lib/Dialect/TTIR/Transforms/EraseInverseOps/ConcatCommutePatterns.cpp @@ -60,9 +60,8 @@ class TTIRCommutePermuteThroughConcat operandType.getElementType()); newConcatOperands.push_back( - rewriter - .create(op->getLoc(), permuteOperandType, operand, - permuteUser.getPermutation()) + PermuteOp::create(rewriter, op->getLoc(), permuteOperandType, operand, + permuteUser.getPermutation()) ->getResult(0)); } @@ -70,8 +69,8 @@ class TTIRCommutePermuteThroughConcat ttmlir::utils::applyPermutation(op.getType().getShape(), permuteUser.getPermutation()), op.getType().getElementType()); - ConcatOp newConcat = rewriter.create( - op->getLoc(), newConcatType, newConcatOperands, newConcatDim); + ConcatOp newConcat = ConcatOp::create(rewriter, op->getLoc(), newConcatType, + newConcatOperands, newConcatDim); // All users must be 
identical TMs. // We must not reference `permuteUser` during/after replacements, as it will @@ -131,16 +130,16 @@ class TTIRCommutePermuteThroughConcat ttmlir::utils::inversePermutation(permuteOperand.getPermutation())), op.getType().getElementType()); int64_t newConcatDim = permuteOperand.getPermutation()[currentConcatDim]; - ConcatOp newConcat = rewriter.create( - op->getLoc(), newConcatType, newConcatOperands, newConcatDim); + ConcatOp newConcat = ConcatOp::create(rewriter, op->getLoc(), newConcatType, + newConcatOperands, newConcatDim); RankedTensorType newPermuteType = RankedTensorType::get( ttmlir::utils::applyPermutation(newConcatType.getShape(), permuteOperand.getPermutation()), newConcatType.getElementType()); PermuteOp newPerm = - rewriter.create(op->getLoc(), newPermuteType, newConcat, - permuteOperand.getPermutation()); + PermuteOp::create(rewriter, op->getLoc(), newPermuteType, newConcat, + permuteOperand.getPermutation()); rewriter.replaceOp(op, newPerm); } @@ -229,15 +228,15 @@ class TTIRCommuteReshapeThroughConcat RankedTensorType newOperandType = RankedTensorType::get( SmallVector(newOperandShape.begin(), newOperandShape.end()), operandType.getElementType()); - newConcatOperands.push_back(rewriter.create( - op->getLoc(), newOperandType, operand, - rewriter.getI32ArrayAttr(newOperandShape))); + newConcatOperands.push_back( + ReshapeOp::create(rewriter, op->getLoc(), newOperandType, operand, + rewriter.getI32ArrayAttr(newOperandShape))); } RankedTensorType newConcatType = RankedTensorType::get(newConcatShape, op.getType().getElementType()); - ConcatOp newConcat = rewriter.create( - op->getLoc(), newConcatType, newConcatOperands, newConcatDim); + ConcatOp newConcat = ConcatOp::create(rewriter, op->getLoc(), newConcatType, + newConcatOperands, newConcatDim); // All users must be identical TMs. 
// We must not reference `reshapeUser` during/after replacements, as it will diff --git a/lib/Dialect/TTIR/Transforms/EraseInverseOps/ElementwiseCommutePatterns.cpp b/lib/Dialect/TTIR/Transforms/EraseInverseOps/ElementwiseCommutePatterns.cpp index 8110867edd3..e5142c65985 100644 --- a/lib/Dialect/TTIR/Transforms/EraseInverseOps/ElementwiseCommutePatterns.cpp +++ b/lib/Dialect/TTIR/Transforms/EraseInverseOps/ElementwiseCommutePatterns.cpp @@ -49,9 +49,9 @@ class TTIRCommuteTmsThroughElementwiseRewriter mlir::Location newLoc = ttmlir::utils::appendLocationSuffix( tmUser->getLoc(), "_tm" + std::to_string(operandIdx)); - auto newTM = rewriter.create( - newLoc, newTMResultTypes[operandIdx], op->getOperand(operandIdx), - tmUser->getAttrs()); + auto newTM = + TMOpType::create(rewriter, newLoc, newTMResultTypes[operandIdx], + op->getOperand(operandIdx), tmUser->getAttrs()); newEltwiseOperands.push_back(newTM); } @@ -102,9 +102,9 @@ class TTIRCommuteTmsThroughElementwiseRewriter RankedTensorType newTMType = cast(op->getResult(0).getType()) .clone(tmOperand.getType().getShape()); - TMOpType newUserTM = rewriter.create(op->getLoc(), newTMType, - newEltwise->getResult(0), - tmOperand->getAttrs()); + TMOpType newUserTM = + TMOpType::create(rewriter, op->getLoc(), newTMType, + newEltwise->getResult(0), tmOperand->getAttrs()); rewriter.replaceOp(op, newUserTM); } diff --git a/lib/Dialect/TTIR/Transforms/EraseInverseOps/RMSNormCommutePatterns.cpp b/lib/Dialect/TTIR/Transforms/EraseInverseOps/RMSNormCommutePatterns.cpp index cde12a07062..a0039a096e3 100644 --- a/lib/Dialect/TTIR/Transforms/EraseInverseOps/RMSNormCommutePatterns.cpp +++ b/lib/Dialect/TTIR/Transforms/EraseInverseOps/RMSNormCommutePatterns.cpp @@ -42,12 +42,13 @@ class TTIRCommuteReshapeThroughRMSNorm // Create reshape before rms_norm auto newInputReshape = - rewriter.create(op.getLoc(), newInputType, op.getInput(), - rewriter.getI32ArrayAttr(newInputShape)); + ReshapeOp::create(rewriter, op.getLoc(), newInputType, 
op.getInput(), + rewriter.getI32ArrayAttr(newInputShape)); - auto newRmsNorm = rewriter.create( - op.getLoc(), outputReshapeType, newInputReshape, op.getWeight(), - op.getBias(), op.getNormalizedShapeAttr(), op.getEpsilonAttr()); + auto newRmsNorm = + RMSNormOp::create(rewriter, op.getLoc(), outputReshapeType, + newInputReshape, op.getWeight(), op.getBias(), + op.getNormalizedShapeAttr(), op.getEpsilonAttr()); // All users must be identical TMs. We must not reference `reshapeUser` // during/after replacements, as it will be erased on its turn. diff --git a/lib/Dialect/TTIR/Transforms/EraseInverseOps/ReduceCommutePatterns.cpp b/lib/Dialect/TTIR/Transforms/EraseInverseOps/ReduceCommutePatterns.cpp index a83f2f44abe..2d27a5304e2 100644 --- a/lib/Dialect/TTIR/Transforms/EraseInverseOps/ReduceCommutePatterns.cpp +++ b/lib/Dialect/TTIR/Transforms/EraseInverseOps/ReduceCommutePatterns.cpp @@ -97,7 +97,7 @@ class TTIRCommutePermuteThroughReduce outputShape, inputType.getElementType(), inputType.getEncoding()); // Create and return the new PermuteOp - return rewriter.create(loc, outputType, input, permutation); + return PermuteOp::create(rewriter, loc, outputType, input, permutation); } ArrayAttr permuteDims(ArrayAttr dimArg, ArrayRef permutation, @@ -134,8 +134,8 @@ class TTIRCommutePermuteThroughReduce RankedTensorType::get(newReduceShape, op.getType().getElementType(), op.getType().getEncoding()); - return rewriter.create(op->getLoc(), newReduceType, newInput, - op.getKeepDimAttr(), newDimArgAttrs); + return ReduceOpType::create(rewriter, op->getLoc(), newReduceType, newInput, + op.getKeepDimAttr(), newDimArgAttrs); } bool isCommuteUpwardsViable(ReduceOpType op, PermuteOp) const override { diff --git a/lib/Dialect/TTIR/Transforms/EraseInverseOps/SliceCommutePatterns.cpp b/lib/Dialect/TTIR/Transforms/EraseInverseOps/SliceCommutePatterns.cpp index bc53490c073..f48a2312e6d 100644 --- a/lib/Dialect/TTIR/Transforms/EraseInverseOps/SliceCommutePatterns.cpp +++ 
b/lib/Dialect/TTIR/Transforms/EraseInverseOps/SliceCommutePatterns.cpp @@ -44,8 +44,8 @@ class TTIRCommutePermuteThroughSlice sliceOperandType.getElementType(), sliceOperandType.getEncoding()); PermuteOp newPerm = - rewriter.create(permuteUser->getLoc(), newPermuteType, - op.getInput(), permuteUser.getPermutation()); + PermuteOp::create(rewriter, permuteUser->getLoc(), newPermuteType, + op.getInput(), permuteUser.getPermutation()); SmallVector newSliceStarts = ttmlir::utils::applyPermutation( op.getBegins().getValue(), permuteUser.getPermutation()); @@ -60,10 +60,10 @@ class TTIRCommutePermuteThroughSlice op.getType().getElementType(), op.getType().getEncoding()); SliceStaticOp newSlice = - rewriter.create(op->getLoc(), newSliceType, newPerm, - rewriter.getArrayAttr(newSliceStarts), - rewriter.getArrayAttr(newSliceEnds), - rewriter.getArrayAttr(newSliceSteps)); + SliceStaticOp::create(rewriter, op->getLoc(), newSliceType, newPerm, + rewriter.getArrayAttr(newSliceStarts), + rewriter.getArrayAttr(newSliceEnds), + rewriter.getArrayAttr(newSliceSteps)); // All users must be identical TMs. 
// We must not reference `permuteUser` during/after replacements, as it will @@ -111,8 +111,8 @@ class TTIRCommutePermuteThroughSlice SmallVector newSliceSteps = ttmlir::utils::applyPermutation( op.getStep().getValue(), inversePermutation); - SliceStaticOp newSlice = rewriter.create( - op->getLoc(), newSliceType, permuteOperand.getInput(), + SliceStaticOp newSlice = SliceStaticOp::create( + rewriter, op->getLoc(), newSliceType, permuteOperand.getInput(), rewriter.getArrayAttr(newSliceStarts), rewriter.getArrayAttr(newSliceEnds), rewriter.getArrayAttr(newSliceSteps)); @@ -121,8 +121,8 @@ class TTIRCommutePermuteThroughSlice op.getType().getShape(), newSlice.getType().getElementType(), newSlice.getType().getEncoding()); PermuteOp newPerm = - rewriter.create(permuteOperand->getLoc(), newPermuteType, - newSlice, permuteOperand.getPermutation()); + PermuteOp::create(rewriter, permuteOperand->getLoc(), newPermuteType, + newSlice, permuteOperand.getPermutation()); rewriter.replaceOp(op, newPerm); } @@ -220,8 +220,8 @@ class TTIRCommuteReshapeThroughSlice // The reshape should produce the same output type as the original slice SmallVector reshapeTargetShape(op.getType().getShape()); - ReshapeOp newReshape = rewriter.create( - reshapeOperand->getLoc(), op.getType(), newSlice, + ReshapeOp newReshape = ReshapeOp::create( + rewriter, reshapeOperand->getLoc(), op.getType(), newSlice, rewriter.getI32ArrayAttr(reshapeTargetShape)); rewriter.replaceOp(op, newReshape); } @@ -336,9 +336,10 @@ class TTIRCommuteReshapeThroughSlice RankedTensorType::get(newOutputShape, op.getType().getElementType(), op.getType().getEncoding()); - SliceStaticOp newSlice = rewriter.create( - op->getLoc(), newSliceType, input, rewriter.getArrayAttr(newBegins), - rewriter.getArrayAttr(newEnds), rewriter.getArrayAttr(newSteps)); + SliceStaticOp newSlice = SliceStaticOp::create( + rewriter, op->getLoc(), newSliceType, input, + rewriter.getArrayAttr(newBegins), rewriter.getArrayAttr(newEnds), + 
rewriter.getArrayAttr(newSteps)); return newSlice; } @@ -374,9 +375,9 @@ class TTIRCommuteReshapeThroughSlice op.getInput().getType().getEncoding()); SmallVector targetShapeInt32(targetShape.begin(), targetShape.end()); - return rewriter.create( - reshapeUser->getLoc(), targetType, op.getInput(), - rewriter.getI32ArrayAttr(targetShapeInt32)); + return ReshapeOp::create(rewriter, reshapeUser->getLoc(), targetType, + op.getInput(), + rewriter.getI32ArrayAttr(targetShapeInt32)); } bool isCommuteUpwardsViable(SliceStaticOp op, diff --git a/lib/Dialect/TTIR/Transforms/EraseInverseOps/SoftmaxCommutePatterns.cpp b/lib/Dialect/TTIR/Transforms/EraseInverseOps/SoftmaxCommutePatterns.cpp index a14aa8aae49..1397c71baa7 100644 --- a/lib/Dialect/TTIR/Transforms/EraseInverseOps/SoftmaxCommutePatterns.cpp +++ b/lib/Dialect/TTIR/Transforms/EraseInverseOps/SoftmaxCommutePatterns.cpp @@ -104,15 +104,15 @@ class TTIRCommuteReshapeThroughSoftmax SmallVector newReshapeShape(outputReshapeShape.begin(), outputReshapeShape.end()); - auto newInputReshape = rewriter.create( - reshapeUser->getLoc(), newInputReshapeType, op.getInput(), + auto newInputReshape = ReshapeOp::create( + rewriter, reshapeUser->getLoc(), newInputReshapeType, op.getInput(), rewriter.getI32ArrayAttr(newReshapeShape)); auto newSoftmaxDimAttr = rewriter.getSI32IntegerAttr(newSoftmaxDim); // Create a new Softmax Op with updated dimension attribute - auto newSoftmaxOp = rewriter.create( - op->getLoc(), outputReshapeType, newInputReshape.getResult(), + auto newSoftmaxOp = SoftmaxOp::create( + rewriter, op->getLoc(), outputReshapeType, newInputReshape.getResult(), newSoftmaxDimAttr, op.getNumericStableAttr()); // All users must be identical TMs. 
@@ -159,8 +159,8 @@ class TTIRCommuteReshapeThroughSoftmax auto newSoftmaxDimAttr = rewriter.getSI32IntegerAttr(newSoftmaxDim); // Create new Softmax Op with updated dimension attribute - auto newSoftmaxOp = rewriter.create( - op->getLoc(), newSoftmaxInputType, reshapeOperand.getInput(), + auto newSoftmaxOp = SoftmaxOp::create( + rewriter, op->getLoc(), newSoftmaxInputType, reshapeOperand.getInput(), newSoftmaxDimAttr, op.getNumericStableAttr()); // Create new reshape Op @@ -169,9 +169,9 @@ class TTIRCommuteReshapeThroughSoftmax auto originalSoftmaxOpShape = originalSoftmaxOpType.getShape(); SmallVector reshapeTargetShape(originalSoftmaxOpShape.begin(), originalSoftmaxOpShape.end()); - auto newReshapeOp = rewriter.create( - reshapeOperand->getLoc(), op.getType(), newSoftmaxOp.getResult(), - rewriter.getI32ArrayAttr(reshapeTargetShape)); + auto newReshapeOp = ReshapeOp::create( + rewriter, reshapeOperand->getLoc(), op.getType(), + newSoftmaxOp.getResult(), rewriter.getI32ArrayAttr(reshapeTargetShape)); rewriter.replaceOp(op, newReshapeOp.getResult()); } diff --git a/lib/Dialect/TTIR/Transforms/ExplicateTMs.cpp b/lib/Dialect/TTIR/Transforms/ExplicateTMs.cpp index 1a61821e86f..d7c7720e05a 100644 --- a/lib/Dialect/TTIR/Transforms/ExplicateTMs.cpp +++ b/lib/Dialect/TTIR/Transforms/ExplicateTMs.cpp @@ -43,7 +43,8 @@ class ExplicateRankChangeRewriter operandType.getShape().end()); // Create a new reshape operation. - auto reshapeOp = rewriter.create( + auto reshapeOp = ttir::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_reshape"), RankedTensorType::get(newShape, operandType.getElementType(), operandType.getEncoding()), @@ -104,7 +105,8 @@ class ExplicateBroadcastsRewriter } // Create a new broadcast operation. 
- auto broadcastOp = rewriter.create( + auto broadcastOp = ttir::BroadcastOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_broadcast"), RankedTensorType::get(broadcastedShape, operandType.getElementType(), operandType.getEncoding()), diff --git a/lib/Dialect/TTIR/Transforms/FlattenSlidingWindow.cpp b/lib/Dialect/TTIR/Transforms/FlattenSlidingWindow.cpp index e013e14081d..8e6d43db81e 100644 --- a/lib/Dialect/TTIR/Transforms/FlattenSlidingWindow.cpp +++ b/lib/Dialect/TTIR/Transforms/FlattenSlidingWindow.cpp @@ -40,8 +40,8 @@ ttir::ReshapeOp generateReshape(mlir::TypedValue input, // We cannot pass the shape directly as the attribute as ttir::ReshapeOp // requires that the shape attribute is a 32-bit integer array attribute. // Construction the SmallVector allows us to cast it. - return rewriter.create( - ttmlir::utils::appendLocationSuffix(input.getLoc(), "_reshape"), + return ttir::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(input.getLoc(), "_reshape"), outputType, input, rewriter.getI32ArrayAttr(SmallVector( outputType.getShape().begin(), outputType.getShape().end()))); @@ -70,17 +70,17 @@ class ConvertToFlattenedConv2dPattern Conv2dOpType newConv; if constexpr (std::is_same_v) { - newConv = rewriter.create( - op.getLoc(), getNHWFlattenedType(outputType), flattenedInput, - adaptor.getWeight(), adaptor.getBias(), adaptor.getStride(), - adaptor.getPadding(), adaptor.getOutputPadding(), + newConv = ttir::ConvTranspose2dOp::create( + rewriter, op.getLoc(), getNHWFlattenedType(outputType), + flattenedInput, adaptor.getWeight(), adaptor.getBias(), + adaptor.getStride(), adaptor.getPadding(), adaptor.getOutputPadding(), adaptor.getDilation(), adaptor.getGroups(), flattenedCompatInfoAttr); } else if constexpr (std::is_same_v) { - newConv = rewriter.create( - op.getLoc(), getNHWFlattenedType(outputType), flattenedInput, - adaptor.getWeight(), adaptor.getBias(), adaptor.getStride(), - adaptor.getPadding(), 
adaptor.getDilation(), adaptor.getGroups(), - flattenedCompatInfoAttr); + newConv = ttir::Conv2dOp::create( + rewriter, op.getLoc(), getNHWFlattenedType(outputType), + flattenedInput, adaptor.getWeight(), adaptor.getBias(), + adaptor.getStride(), adaptor.getPadding(), adaptor.getDilation(), + adaptor.getGroups(), flattenedCompatInfoAttr); } else { static_assert(ttmlir::utils::always_false(), "Unsupported Conv2dOpType"); @@ -117,15 +117,16 @@ class Pooling2dFlattenedCompatOpConversionPattern Pooling2dOp newPool; if constexpr (std::is_same_v) { - newPool = rewriter.create( - op.getLoc(), getNHWFlattenedType(outputType), flattenedInput, - adaptor.getKernel(), adaptor.getStride(), adaptor.getDilation(), - adaptor.getPadding(), adaptor.getCeilMode(), flattenedCompatInfoAttr); + newPool = ttir::MaxPool2dOp::create( + rewriter, op.getLoc(), getNHWFlattenedType(outputType), + flattenedInput, adaptor.getKernel(), adaptor.getStride(), + adaptor.getDilation(), adaptor.getPadding(), adaptor.getCeilMode(), + flattenedCompatInfoAttr); } else if constexpr (std::is_same_v) { - newPool = rewriter.create( - op.getLoc(), getNHWFlattenedType(outputType), flattenedInput, - adaptor.getKernel(), adaptor.getStride(), adaptor.getDilation(), - adaptor.getPadding(), adaptor.getCeilMode(), + newPool = ttir::AvgPool2dOp::create( + rewriter, op.getLoc(), getNHWFlattenedType(outputType), + flattenedInput, adaptor.getKernel(), adaptor.getStride(), + adaptor.getDilation(), adaptor.getPadding(), adaptor.getCeilMode(), adaptor.getCountIncludePad(), flattenedCompatInfoAttr); } else if constexpr (std::is_same_v) { @@ -141,16 +142,17 @@ class Pooling2dFlattenedCompatOpConversionPattern // TODO (umales): Migrate to createDPSOp when it supports multiple // outputs. 
See https://github.com/tenstorrent/tt-mlir/issues/5497 auto resultEmpty = - rewriter.create(op.getLoc(), flattenedOutputType); + ttir::EmptyOp::create(rewriter, op.getLoc(), flattenedOutputType); auto indicesEmpty = - rewriter.create(op.getLoc(), flattenedIndicesType); + ttir::EmptyOp::create(rewriter, op.getLoc(), flattenedIndicesType); // Create the MaxPool2dWithIndicesOp - auto newPoolWithIndices = rewriter.create( - op.getLoc(), TypeRange{flattenedOutputType, flattenedIndicesType}, - flattenedInput, ValueRange{resultEmpty, indicesEmpty}, - adaptor.getKernel(), adaptor.getStride(), adaptor.getDilation(), - adaptor.getPadding(), adaptor.getCeilMode(), flattenedCompatInfoAttr); + auto newPoolWithIndices = ttir::MaxPool2dWithIndicesOp::create( + rewriter, op.getLoc(), + TypeRange{flattenedOutputType, flattenedIndicesType}, flattenedInput, + ValueRange{resultEmpty, indicesEmpty}, adaptor.getKernel(), + adaptor.getStride(), adaptor.getDilation(), adaptor.getPadding(), + adaptor.getCeilMode(), flattenedCompatInfoAttr); auto pooledOutputVal = newPoolWithIndices.getResult(); auto indicesOutputVal = newPoolWithIndices.getResultIndices(); diff --git a/lib/Dialect/TTIR/Transforms/GlobalDataFormatConversion.cpp b/lib/Dialect/TTIR/Transforms/GlobalDataFormatConversion.cpp index e4532eace61..3152023e5b0 100644 --- a/lib/Dialect/TTIR/Transforms/GlobalDataFormatConversion.cpp +++ b/lib/Dialect/TTIR/Transforms/GlobalDataFormatConversion.cpp @@ -52,7 +52,7 @@ struct GlobalDataFormatBodyConverter : mlir::TypeConverter { mlir::Location loc) -> mlir::Value { mlir::RankedTensorType rankedType = mlir::cast(type); - return builder.create(loc, rankedType, inputs); + return ttir::TypecastOp::create(builder, loc, rankedType, inputs); }; addSourceMaterialization(materializeFunc); // Input conversions diff --git a/lib/Dialect/TTIR/Transforms/HoistCPUOps/HoistCPUOps.cpp b/lib/Dialect/TTIR/Transforms/HoistCPUOps/HoistCPUOps.cpp index ac3f94b109d..481b9e332d0 100644 --- 
a/lib/Dialect/TTIR/Transforms/HoistCPUOps/HoistCPUOps.cpp +++ b/lib/Dialect/TTIR/Transforms/HoistCPUOps/HoistCPUOps.cpp @@ -182,12 +182,13 @@ performInputArgumentsConversion(mlir::OpBuilder &opBuilder, if (tensorType != convertedType) { // Create converted tensor value. - auto emptyTensor = opBuilder.create( - argument.getLoc(), tensorType.getShape(), + auto emptyTensor = mlir::tt::ttir::EmptyOp::create( + opBuilder, argument.getLoc(), tensorType.getShape(), convertedType.getElementType()); - auto convertedArgument = opBuilder - .create( - argument.getLoc(), argument, emptyTensor) + auto convertedArgument = mlir::tt::ttir::ToLayoutOp::create( + opBuilder, + + argument.getLoc(), argument, emptyTensor) ->getResult(0); convertedArguments.push_back(convertedArgument); @@ -230,11 +231,11 @@ convertResultsBackToOriginalTypes(mlir::OpBuilder &opBuilder, llvm::dyn_cast_or_null(callOpOutput.getType()); if (originalResultType != convertedResultType) { - auto emptyTensor = opBuilder.create( - sourceModule->getLoc(), originalResultType.getShape(), + auto emptyTensor = mlir::tt::ttir::EmptyOp::create( + opBuilder, sourceModule->getLoc(), originalResultType.getShape(), originalResultType.getElementType()); - auto toOriginal = opBuilder.create( - sourceModule->getLoc(), callOpOutput, emptyTensor); + auto toOriginal = mlir::tt::ttir::ToLayoutOp::create( + opBuilder, sourceModule->getLoc(), callOpOutput, emptyTensor); // Replace all uses of the output value with the converted one. originalOutput.replaceAllUsesWith(toOriginal->getResult(0)); } else { @@ -355,7 +356,7 @@ static func::FuncOp createCPUHoistedFunctionDefinition( returnValues.push_back(mapping.lookup(outputValue)); } - builder.create(loc, returnValues); + mlir::func::ReturnOp::create(builder, loc, returnValues); // Add bufferization access attributes to function arguments. 
for (auto [index, argument] : @@ -513,8 +514,9 @@ static void hoistOperationsToFunction(CPUHoistedOpsDescriptor &descriptor, } // Create the call using already converted inputs. - auto callOp = opBuilder.create( - deviceModule->getLoc(), funcDeclaration, convertedInputArguments); + auto callOp = + mlir::func::CallOp::create(opBuilder, deviceModule->getLoc(), + funcDeclaration, convertedInputArguments); // Add the hoisted_call attribute. callOp->setAttr(CPUHoistedCallAttr::name, UnitAttr::get(context)); @@ -592,7 +594,7 @@ bool canLowerTTIRToLinalg(mlir::Operation *op) { // Build the return op from the cloned op's results. llvm::SmallVector returnValues(clonedOp->getResults()); - builder.create(op->getLoc(), returnValues); + mlir::func::ReturnOp::create(builder, op->getLoc(), returnValues); // Run TTIRToTTIRDecomposition (CPUFallback mode) followed by // TTIRToLinalg conversion on the temporary module, mirroring the @@ -682,9 +684,9 @@ void runCPUHoist(mlir::ModuleOp rootModule, // If no CPU module exists, create one. if (!cpuModule) { rewriter.setInsertionPointToEnd(rootModule.getBody()); - cpuModule = rewriter.create(loc); + cpuModule = ttcore::CPUModuleOp::create(rewriter, loc); rewriter.setInsertionPointToStart(&cpuModule.getBodyRegion().front()); - cpuInnerModule = rewriter.create(loc); + cpuInnerModule = mlir::ModuleOp::create(rewriter, loc); } // Hoist each set of ops into a new function in the CPU module. 
diff --git a/lib/Dialect/TTIR/Transforms/MoveReshapeToConstant.cpp b/lib/Dialect/TTIR/Transforms/MoveReshapeToConstant.cpp index 2017e932ceb..f0c71e38b2a 100644 --- a/lib/Dialect/TTIR/Transforms/MoveReshapeToConstant.cpp +++ b/lib/Dialect/TTIR/Transforms/MoveReshapeToConstant.cpp @@ -112,9 +112,9 @@ class MoveReshapeToConstantPattern auto newConstType = RankedTensorType::get(preReshapeType.getShape(), constType.getElementType()); - auto constReshape = rewriter.create( - constOperand.getLoc(), newConstType, constOperand, - rewriter.getI32ArrayAttr(newShape)); + auto constReshape = + ReshapeOp::create(rewriter, constOperand.getLoc(), newConstType, + constOperand, rewriter.getI32ArrayAttr(newShape)); // Create the new elementwise op with the original activation and reshaped // constant. Preserve operand order from the original op. @@ -134,9 +134,9 @@ class MoveReshapeToConstantPattern // Replace the original op result with the new op result. // We need a reshape to match the original output shape for downstream // users. - auto outputReshape = rewriter.create( - op->getLoc(), op->getResult(0).getType(), newOp->getResult(0), - reshapeOp.getShapeAttr()); + auto outputReshape = + ReshapeOp::create(rewriter, op->getLoc(), op->getResult(0).getType(), + newOp->getResult(0), reshapeOp.getShapeAttr()); rewriter.replaceOp(op, outputReshape.getResult()); diff --git a/lib/Dialect/TTIR/Transforms/QuantDequantConversion.cpp b/lib/Dialect/TTIR/Transforms/QuantDequantConversion.cpp index b7b7d7f0a4c..85f08ce405d 100644 --- a/lib/Dialect/TTIR/Transforms/QuantDequantConversion.cpp +++ b/lib/Dialect/TTIR/Transforms/QuantDequantConversion.cpp @@ -102,8 +102,8 @@ class CommuteDequantizeBelowQuantizableOpRewriter if (mlir::dyn_cast( newResultType.getElementType())) { // It's quantized, so insert a DequantizeOp. 
- auto newDequant = rewriter.create( - op->getLoc(), oldType, newResult); + auto newDequant = mlir::tt::ttir::DequantizeOp::create( + rewriter, op->getLoc(), oldType, newResult); newResults.push_back(newDequant); } else { // It's already floating-point or non-quantized. @@ -130,12 +130,12 @@ class CommuteDequantizeBelowQuantizableOpRewriter originalType.getShape(), quantType, originalType.getEncoding()); // Create quantize op. mlir::tt::ttir::QuantizeOp quantize = - rewriter.create(op->getLoc(), - quantizeType, result); + mlir::tt::ttir::QuantizeOp::create(rewriter, op->getLoc(), + quantizeType, result); // Now dequantize op, effectively commuting the original dequantize. mlir::tt::ttir::DequantizeOp dequantize = - rewriter.create( - op->getLoc(), originalType, quantize); + mlir::tt::ttir::DequantizeOp::create(rewriter, op->getLoc(), + originalType, quantize); newResults.push_back(dequantize); } } @@ -157,9 +157,8 @@ struct RewriteDQToRequantize auto dequantizeOp = mlir::dyn_cast( op.getInput().getDefiningOp()); if (dequantizeOp) { - ttir::RequantizeOp requantize = - rewriter.create( - op->getLoc(), op.getType(), dequantizeOp.getInput()); + ttir::RequantizeOp requantize = mlir::tt::ttir::RequantizeOp::create( + rewriter, op->getLoc(), op.getType(), dequantizeOp.getInput()); rewriter.replaceOp(op, requantize); return mlir::success(); } diff --git a/lib/Dialect/TTIR/Transforms/RankNormalization.cpp b/lib/Dialect/TTIR/Transforms/RankNormalization.cpp index a95ffcdd87d..0a0a15e9b55 100644 --- a/lib/Dialect/TTIR/Transforms/RankNormalization.cpp +++ b/lib/Dialect/TTIR/Transforms/RankNormalization.cpp @@ -71,7 +71,8 @@ class RankNormalizationTypeConverter : public TypeConverter { static Value materializeCast(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { assert(inputs.size() == 1 && "Expected single input."); - return builder.create(loc, type, inputs.front()) + return UnrealizedConversionCastOp::create(builder, loc, type, + inputs.front()) 
.getResult(0); } }; diff --git a/lib/Dialect/TTIR/Transforms/ReductionForceKeepDim.cpp b/lib/Dialect/TTIR/Transforms/ReductionForceKeepDim.cpp index 631135c113c..6c6ad3f3355 100644 --- a/lib/Dialect/TTIR/Transforms/ReductionForceKeepDim.cpp +++ b/lib/Dialect/TTIR/Transforms/ReductionForceKeepDim.cpp @@ -54,8 +54,8 @@ class ForceKeepDimPattern : public mlir::OpRewritePattern { RankedTensorType::get(keepDimShape, originalType.getElementType(), originalType.getEncoding()); - auto newReduction = rewriter.create( - reductionOp.getLoc(), decomposedType, reductionOp.getInput(), + auto newReduction = ReductionOpTy::create( + rewriter, reductionOp.getLoc(), decomposedType, reductionOp.getInput(), /*keep_dim=*/rewriter.getBoolAttr(true), reductionOp.getDimArgAttr()); llvm::SmallVector outputShapeI32(originalType.getShape().begin(), diff --git a/lib/Dialect/TTIR/Transforms/TTIRFusing.cpp b/lib/Dialect/TTIR/Transforms/TTIRFusing.cpp index 17ccdb632cf..75d7df2bf46 100644 --- a/lib/Dialect/TTIR/Transforms/TTIRFusing.cpp +++ b/lib/Dialect/TTIR/Transforms/TTIRFusing.cpp @@ -49,7 +49,8 @@ class ConvAddBias : public mlir::OpRewritePattern { // We can do it like this because we already checked that bias has valid // shape. 
if (convOp.getBias()) { - bias = rewriter.create( + bias = AddOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(bias.getLoc(), "_bias_add"), mlir::cast(bias.getType()), convOp.getBias(), bias); } @@ -506,7 +507,7 @@ class HardsigmoidFusionPattern : public mlir::OpRewritePattern { } auto hardsigmoidOp = - rewriter.create(divOp.getLoc(), input.getType(), input); + HardsigmoidOp::create(rewriter, divOp.getLoc(), input.getType(), input); rewriter.replaceAllOpUsesWith(divOp, hardsigmoidOp); @@ -561,13 +562,13 @@ class SiluFusionPattern : public mlir::OpRewritePattern { auto inputType = sigmoidInput.getType(); auto outputType = multiplyOp.getResult().getType(); auto siluOp = - rewriter.create(multiplyOp->getLoc(), inputType, otherOperand); + SiluOp::create(rewriter, multiplyOp->getLoc(), inputType, otherOperand); // If multiply inputs and output are typecasted, we need to add a typecast // after silu to convert back to the multiply output type. if (inputType != outputType) { - auto typecastOp = rewriter.create( - multiplyOp->getLoc(), outputType, siluOp.getResult()); + auto typecastOp = TypecastOp::create(rewriter, multiplyOp->getLoc(), + outputType, siluOp.getResult()); rewriter.replaceAllOpUsesWith(multiplyOp, typecastOp); } else { rewriter.replaceAllOpUsesWith(multiplyOp, siluOp); @@ -653,13 +654,13 @@ class MishFusingPattern : public mlir::OpRewritePattern { auto inputType = originalInput.getType(); auto outputType = multiplyOp.getResult().getType(); auto mishOp = - rewriter.create(multiplyOp.getLoc(), inputType, originalInput); + MishOp::create(rewriter, multiplyOp.getLoc(), inputType, originalInput); // If multiply inputs and output are typecasted, we need to add a typecast // after mish to convert it back to the original multiply Op's output type. 
if (inputType != outputType) { - auto typecastOp = rewriter.create( - multiplyOp->getLoc(), outputType, mishOp.getResult()); + auto typecastOp = TypecastOp::create(rewriter, multiplyOp->getLoc(), + outputType, mishOp.getResult()); rewriter.replaceOp(multiplyOp, typecastOp.getResult()); } else { rewriter.replaceOp(multiplyOp, mishOp.getResult()); @@ -909,8 +910,8 @@ class ConvWithMultiply : public mlir::OpRewritePattern { llvm::SmallVector newShapeI32(newShape.begin(), newShape.end()); // Create the reshape operation. - auto reshapedScale = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_reshape"), + auto reshapedScale = ttir::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_reshape"), RankedTensorType::get(newShape, scaleType.getElementType(), scaleType.getEncoding()), reshapeInput, rewriter.getI32ArrayAttr(newShapeI32)); @@ -925,8 +926,8 @@ class ConvWithMultiply : public mlir::OpRewritePattern { SmallVector broadcastDims = ttmlir::utils::getBroadcastDimensions( reshapedScale.getType().getShape(), weightType.getShape()); - return rewriter.create(scaleValue.getLoc(), weightType, - reshapedScale, broadcastDims); + return ttir::BroadcastOp::create(rewriter, scaleValue.getLoc(), weightType, + reshapedScale, broadcastDims); } /// Create pre-multiplied weights. @@ -934,8 +935,8 @@ class ConvWithMultiply : public mlir::OpRewritePattern { Location loc, Value weightValue, Value reshapedScale) { // Create a multiplication of the weights by the reshaped scale. - return rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_multiply"), + return MultiplyOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_multiply"), mlir::cast(weightValue.getType()), weightValue, reshapedScale); } @@ -944,11 +945,11 @@ class ConvWithMultiply : public mlir::OpRewritePattern { Value biasValue, ConvOpType convOp, Value scaleValue) { // Create a multiplication of the bias by the scale. 
- return rewriter.create( - ttmlir::utils::appendLocationSuffix(biasValue.getLoc(), - "_bias_multiply"), - mlir::cast(biasValue.getType()), biasValue, - scaleValue); + return MultiplyOp::create(rewriter, + ttmlir::utils::appendLocationSuffix( + biasValue.getLoc(), "_bias_multiply"), + mlir::cast(biasValue.getType()), + biasValue, scaleValue); } }; @@ -1057,26 +1058,27 @@ class BatchNormDecomposition variance.getType().getEncoding()); // Convert epsilon to a tensor - auto epsilonTensor = rewriter.create( - loc, scalarType, rewriter.getF32FloatAttr(epsilon.convertToFloat())); + auto epsilonTensor = + FullOp::create(rewriter, loc, scalarType, + rewriter.getF32FloatAttr(epsilon.convertToFloat())); // variance + epsilon - auto variancePlusEpsilon = rewriter.create(loc, variance.getType(), - variance, epsilonTensor); + auto variancePlusEpsilon = AddOp::create(rewriter, loc, variance.getType(), + variance, epsilonTensor); // std = sqrt(variance + epsilon) auto std = - rewriter.create(loc, variance.getType(), variancePlusEpsilon); + SqrtOp::create(rewriter, loc, variance.getType(), variancePlusEpsilon); // alpha = scale / std - auto alpha = rewriter.create(loc, scale.getType(), scale, std); + auto alpha = DivOp::create(rewriter, loc, scale.getType(), scale, std); // alphaMean = alpha * mean auto alphaMean = - rewriter.create(loc, mean.getType(), mean, alpha); + MultiplyOp::create(rewriter, loc, mean.getType(), mean, alpha); // beta = offset - alphaMean auto beta = - rewriter.create(loc, offset.getType(), offset, alphaMean); + SubtractOp::create(rewriter, loc, offset.getType(), offset, alphaMean); // Reshape alpha and beta along the specified dimension to match the input // shape: for dimension = 3 and input shape (N, H, W, C), reshape from (C) @@ -1091,8 +1093,8 @@ class BatchNormDecomposition reshapeShape[dimension] = alpha.getType().getShape()[0]; SmallVector reshapeShapeI32(reshapeShape.begin(), reshapeShape.end()); - alphaReshaped = rewriter.create( - loc, + 
alphaReshaped = ReshapeOp::create( + rewriter, loc, RankedTensorType::get(reshapeShape, alpha.getType().getElementType(), alpha.getType().getEncoding()), alpha, rewriter.getI32ArrayAttr(reshapeShapeI32)); @@ -1103,18 +1105,19 @@ class BatchNormDecomposition reshapeShape[dimension] = beta.getType().getShape()[0]; SmallVector reshapeShapeI32(reshapeShape.begin(), reshapeShape.end()); - betaReshaped = rewriter.create( - loc, + betaReshaped = ReshapeOp::create( + rewriter, loc, RankedTensorType::get(reshapeShape, beta.getType().getElementType(), beta.getType().getEncoding()), beta, rewriter.getI32ArrayAttr(reshapeShapeI32)); } // alpha * x - auto scaled = - rewriter.create(loc, input.getType(), input, alphaReshaped); + auto scaled = MultiplyOp::create(rewriter, loc, input.getType(), input, + alphaReshaped); // alpha * x + beta - auto result = rewriter.create(loc, resultType, scaled, betaReshaped); + auto result = + AddOp::create(rewriter, loc, resultType, scaled, betaReshaped); rewriter.replaceOp(batchNormOp, result); @@ -1240,8 +1243,8 @@ class RepVGGConvSumFusionPattern : public mlir::OpRewritePattern { // Pad the 1x1 weight to 3x3 by adding zeros around it SmallVector paddingValues = {0, 0, 0, 0, 1, 1, 1, 1}; - return rewriter.create( - ttmlir::utils::appendLocationSuffix(conv.getLoc(), "_pad"), + return PadOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(conv.getLoc(), "_pad"), weight3x3Type, weight1x1, rewriter.getDenseI32ArrayAttr(paddingValues), rewriter.getF32FloatAttr(0.0)); } @@ -1257,7 +1260,8 @@ class RepVGGConvSumFusionPattern : public mlir::OpRewritePattern { // Move additional weight UD chain before conv to ensure it is before addOp. 
utils::moveUDChainBefore(additionalWeight, conv); - auto combinedWeight = rewriter.create( + auto combinedWeight = AddOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(conv.getLoc(), "_weight_add"), existingWeight.getType(), additionalWeight, existingWeight); rewriter.modifyOpInPlace( @@ -1277,7 +1281,8 @@ class RepVGGConvSumFusionPattern : public mlir::OpRewritePattern { // Move bias2 UD chain before conv1 to ensure it is before addOp. utils::moveUDChainBefore(bias2, conv1); - auto combinedBias = rewriter.create( + auto combinedBias = AddOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(conv1.getLoc(), "_bias_add"), bias1.getType(), bias1, bias2); rewriter.modifyOpInPlace( @@ -1331,8 +1336,8 @@ class MatmulWithBiasFusionPattern : public mlir::OpRewritePattern { auto scalarType = bias.getType(); auto reshapedType = RankedTensorType::get( {1}, scalarType.getElementType(), scalarType.getEncoding()); - bias = rewriter.create(bias.getLoc(), reshapedType, bias, - rewriter.getI32ArrayAttr({1})); + bias = ttir::ReshapeOp::create(rewriter, bias.getLoc(), reshapedType, + bias, rewriter.getI32ArrayAttr({1})); } ArrayRef biasShape = bias.getType().getShape(); @@ -1361,18 +1366,19 @@ class MatmulWithBiasFusionPattern : public mlir::OpRewritePattern { broadcastShape, matmulOp.getType().getElementType(), matmulOp.getType().getEncoding()); - LinearOp linearOp = rewriter.create( - addOp.getLoc(), linearOutputType, matmulOp.getA(), matmulOp.getB(), - bias, matmulOp.getTransposeA(), matmulOp.getTransposeB()); + LinearOp linearOp = ttir::LinearOp::create( + rewriter, addOp.getLoc(), linearOutputType, matmulOp.getA(), + matmulOp.getB(), bias, matmulOp.getTransposeA(), + matmulOp.getTransposeB()); Value result = linearOp.getResult(); if (!llvm::equal(broadcastShape, addOutputShape)) { llvm::SmallVector addShapeI32(addOutputShape.begin(), addOutputShape.end()); - result = rewriter.create( - addOp.getLoc(), addOp.getType(), result, - 
rewriter.getI32ArrayAttr(addShapeI32)); + result = ttir::ReshapeOp::create(rewriter, addOp.getLoc(), + addOp.getType(), result, + rewriter.getI32ArrayAttr(addShapeI32)); } rewriter.replaceOp(addOp, result); @@ -1551,8 +1557,8 @@ class ScaledSumToMeanPattern : public mlir::OpRewritePattern { auto loc = sumOp.getLoc(); - auto meanOp = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_mean"), outputType, + auto meanOp = MeanOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_mean"), outputType, sumOp.getInput(), /*keep_dim=*/rewriter.getBoolAttr(sumOp.getKeepDim()), /*dim_arg=*/reduceDims); @@ -1676,9 +1682,9 @@ class SpatialMeanOptimizationPattern : public mlir::OpRewritePattern { SmallVector newShapeI32(newShape.begin(), newShape.end()); auto outputType = RankedTensorType::get( newShape, inputType.getElementType(), inputType.getEncoding()); - return rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_input_reshape"), outputType, - input, rewriter.getI32ArrayAttr(newShapeI32)); + return ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_input_reshape"), + outputType, input, rewriter.getI32ArrayAttr(newShapeI32)); } static MeanOp createMean(mlir::PatternRewriter &rewriter, Location loc, @@ -1692,8 +1698,9 @@ class SpatialMeanOptimizationPattern : public mlir::OpRewritePattern { auto outputType = RankedTensorType::get( outputShape, reshapedType.getElementType(), reshapedType.getEncoding()); - return rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_mean"), outputType, reshaped, + return MeanOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_mean"), outputType, + reshaped, /*keep_dim=*/rewriter.getBoolAttr(true), /*dim_arg=*/ rewriter.getArrayAttr({rewriter.getI32IntegerAttr(SPATIAL_WIDTH_DIM)})); @@ -1715,9 +1722,9 @@ class SpatialMeanOptimizationPattern : public mlir::OpRewritePattern { SmallVector newShapeI32(newShape.begin(), newShape.end()); auto outputType = 
RankedTensorType::get(newShape, meanType.getElementType(), meanType.getEncoding()); - return rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_output_reshape"), outputType, - meanResult, rewriter.getI32ArrayAttr(newShapeI32)); + return ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_output_reshape"), + outputType, meanResult, rewriter.getI32ArrayAttr(newShapeI32)); } }; @@ -1750,8 +1757,8 @@ class ConcatenateHeadsUpdatePattern : public mlir::OpRewritePattern { // input shape: [batch_size, num_heads, sequence_size, head_size] // output shape: [batch_size, sequence_size, num_heads * head_size // (hidden)]. - Value concatHeadsOp = rewriter.create( - reshapeOp.getLoc(), reshapeResultType, inputTensor); + Value concatHeadsOp = ConcatenateHeadsOp::create( + rewriter, reshapeOp.getLoc(), reshapeResultType, inputTensor); rewriter.replaceOp(reshapeOp, concatHeadsOp); return mlir::success(); } @@ -1775,8 +1782,8 @@ class ConcatenateHeadsUpdatePattern : public mlir::OpRewritePattern { newReshapeShape, reshapeResultType.getElementType()); // Create ConcatenateHeadsOp with the new shape. 
- Value concatHeadsOp = rewriter.create( - reshapeOp.getLoc(), newReshapeType, inputTensor); + Value concatHeadsOp = ConcatenateHeadsOp::create( + rewriter, reshapeOp.getLoc(), newReshapeType, inputTensor); rewriter.modifyOpInPlace(reshapeOp, [&]() { reshapeOp.getInputMutable().assign(concatHeadsOp); @@ -2003,8 +2010,8 @@ class SplitQueryKeyValueAndSplitHeadsUpdatePattern inputs = {reshapeOutput, concatenatedWeightMatrix}; } - MatMulOpType matrixMultOp = rewriter.create( - keyMatmulOp.getLoc(), TypeRange{linearOutputType}, inputs, + MatMulOpType matrixMultOp = MatMulOpType::create( + rewriter, keyMatmulOp.getLoc(), TypeRange{linearOutputType}, inputs, keyMatmulOp->getAttrs()); TT_assertv(matrixMultOp, "Expected valid matrix multiplication operation"); @@ -2040,9 +2047,10 @@ class SplitQueryKeyValueAndSplitHeadsUpdatePattern queryReshapeShape.end()); RankedTensorType queryReshapeTy = RankedTensorType::get( queryReshapeShape, queryReshapeElementType, queryReshapeEncoding); - ReshapeOp queryReshapeOp = rewriter.create( - matrixMultOp.getLoc(), queryReshapeTy, queryMatmulOp.getResult(), - rewriter.getI32ArrayAttr(queryReshapeShapeI32)); + ReshapeOp queryReshapeOp = + ReshapeOp::create(rewriter, matrixMultOp.getLoc(), queryReshapeTy, + queryMatmulOp.getResult(), + rewriter.getI32ArrayAttr(queryReshapeShapeI32)); SmallVector kvReshapeShape = {batchSize, sequenceLength, KVHiddenSize * 2}; @@ -2052,15 +2060,16 @@ class SplitQueryKeyValueAndSplitHeadsUpdatePattern kvReshapeShape.end()); RankedTensorType kvReshapeTy = RankedTensorType::get( kvReshapeShape, kvReshapeElementType, kvReshapeEncoding); - ReshapeOp kvReshapeOp = rewriter.create( - matrixMultOp.getLoc(), kvReshapeTy, matrixMultOp.getResult(), + ReshapeOp kvReshapeOp = ReshapeOp::create( + rewriter, matrixMultOp.getLoc(), kvReshapeTy, matrixMultOp.getResult(), rewriter.getI32ArrayAttr(kvReshapeShapeI32)); // Create split qkv op. // Determine if need to transpose key based on key and value. 
bool transposeKey = isKeyTransposed(keyShape, valueShape); - auto splitOp = rewriter.create( - matrixMultOp->getLoc(), ArrayRef{queryType, keyType, valueType}, + auto splitOp = SplitQueryKeyValueAndSplitHeadsOp::create( + rewriter, matrixMultOp->getLoc(), + ArrayRef{queryType, keyType, valueType}, queryReshapeOp.getResult(), kvReshapeOp.getResult(), rewriter.getUI32IntegerAttr(numQueryHeads), rewriter.getUI32IntegerAttr(numKVHeads), @@ -2134,8 +2143,8 @@ class SplitQueryKeyValueAndSplitHeadsUpdatePattern inputs = {reshapeOutput, concatenatedWeightMatrix}; } - MatMulOpType matrixMultOp = rewriter.create( - queryMatmulOp.getLoc(), TypeRange{linearOutputType}, inputs, + MatMulOpType matrixMultOp = MatMulOpType::create( + rewriter, queryMatmulOp.getLoc(), TypeRange{linearOutputType}, inputs, queryMatmulOp->getAttrs()); TT_assertv(matrixMultOp, "Expected valid matrix multiplication operation"); @@ -2162,17 +2171,18 @@ class SplitQueryKeyValueAndSplitHeadsUpdatePattern RankedTensorType reshapeTy = RankedTensorType::get( reshapeToSplitShape, reshapeElementType, reshapeEncoding); - ReshapeOp reshapeToSplit = rewriter.create( - matrixMultOp.getLoc(), reshapeTy, matrixMultOp, + ReshapeOp reshapeToSplit = ttir::ReshapeOp::create( + rewriter, matrixMultOp.getLoc(), reshapeTy, matrixMultOp, rewriter.getI32ArrayAttr(reshapeToSplitShapeI32)); // Determine if need to transpose key based on key and value. 
bool transposeKey = isKeyTransposed(keyShape, valueShape); - auto splitOp = rewriter.create( - matrixMultOp.getLoc(), ArrayRef{queryType, keyType, valueType}, - reshapeToSplit, Value(), rewriter.getUI32IntegerAttr(numHeads), - IntegerAttr(), rewriter.getBoolAttr(transposeKey) /*transpose_key*/); + auto splitOp = SplitQueryKeyValueAndSplitHeadsOp::create( + rewriter, matrixMultOp.getLoc(), + ArrayRef{queryType, keyType, valueType}, reshapeToSplit, Value(), + rewriter.getUI32IntegerAttr(numHeads), IntegerAttr(), + rewriter.getBoolAttr(transposeKey) /*transpose_key*/); rewriter.replaceOp(permuteOps[0], splitOp.getQuery()); rewriter.replaceOp(permuteOps[1], splitOp.getKey()); @@ -2294,9 +2304,8 @@ class SplitQueryKeyValueAndSplitHeadsUpdatePattern RankedTensorType::get(concatenatedShape, firstType.getElementType()); // Create concat op along given dimension. - return rewriter.create( - loc, concatenatedType, tensors, - rewriter.getSI32IntegerAttr(dim) /* axis */); + return ttir::ConcatOp::create(rewriter, loc, concatenatedType, tensors, + rewriter.getSI32IntegerAttr(dim) /* axis */); } void hoistPreprocessingOps(llvm::SmallVector permuteOps) const { @@ -3229,9 +3238,8 @@ class RMSNormFusionPattern : public mlir::OpRewritePattern { normalizedShape, gammaType.getElementType(), gammaType.getEncoding()); llvm::SmallVector targetShape(normalizedShape.begin(), normalizedShape.end()); - gamma = rewriter.create(outerMul.getLoc(), reshapedGammaType, - gamma, - rewriter.getI32ArrayAttr(targetShape)); + gamma = ReshapeOp::create(rewriter, outerMul.getLoc(), reshapedGammaType, + gamma, rewriter.getI32ArrayAttr(targetShape)); } // Although the op can work with different dtypes, this is @@ -3245,8 +3253,8 @@ class RMSNormFusionPattern : public mlir::OpRewritePattern { auto rmsNormOutputType = RankedTensorType::get(inputType.getShape(), inputType.getElementType(), inputType.getEncoding()); - auto rmsNorm = rewriter.create( - outerMul.getLoc(), rmsNormOutputType, x, gamma, + auto 
rmsNorm = RMSNormOp::create( + rewriter, outerMul.getLoc(), rmsNormOutputType, x, gamma, /*bias=*/nullptr, rewriter.getDenseI64ArrayAttr(normalizedShape), rewriter.getF32FloatAttr(epsAttr.getValue().convertToFloat())); diff --git a/lib/Dialect/TTIR/Utils/Utils.cpp b/lib/Dialect/TTIR/Utils/Utils.cpp index d779f1b3f9a..edd978cefab 100644 --- a/lib/Dialect/TTIR/Utils/Utils.cpp +++ b/lib/Dialect/TTIR/Utils/Utils.cpp @@ -24,8 +24,8 @@ llvm::SmallVector unsqueezeValue(mlir::PatternRewriter &rewriter, unsqueezeShape.end()); auto reshapeDimAttr = rewriter.getI32ArrayAttr(reshapeDim); - input = rewriter.create( - loc, + input = ttir::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(unsqueezeShape, desiredType.getElementType(), desiredType.getEncoding()), input, reshapeDimAttr); @@ -59,8 +59,8 @@ mlir::LogicalResult broadcastValue(mlir::PatternRewriter &rewriter, ttmlir::utils::getBroadcastDimensions(inputShape, desiredType.getShape()); - output = rewriter.create(loc, desiredType, input, - broadcastDims); + output = ttir::BroadcastOp::create(rewriter, loc, desiredType, input, + broadcastDims); return mlir::success(); } diff --git a/lib/Dialect/TTKernel/Transforms/ControlDstSection.cpp b/lib/Dialect/TTKernel/Transforms/ControlDstSection.cpp index 40eb751822c..c89ce4c508f 100644 --- a/lib/Dialect/TTKernel/Transforms/ControlDstSection.cpp +++ b/lib/Dialect/TTKernel/Transforms/ControlDstSection.cpp @@ -39,10 +39,10 @@ class TTKernelTileRegsRewriter : public OpRewritePattern { } rewriter.setInsertionPoint(parent); - rewriter.create(op->getLoc()); - rewriter.create(op->getLoc()); + ttkernel::TileRegsCommitOp::create(rewriter, op->getLoc()); + ttkernel::TileRegsWaitOp::create(rewriter, op->getLoc()); rewriter.setInsertionPointAfter(parent); - rewriter.create(op->getLoc()); + ttkernel::TileRegsReleaseOp::create(rewriter, op->getLoc()); return success(); }; diff --git a/lib/Dialect/TTKernel/Transforms/InsertDeviceZoneScopes.cpp 
b/lib/Dialect/TTKernel/Transforms/InsertDeviceZoneScopes.cpp index e1b2d17cf9c..abd3cefd677 100644 --- a/lib/Dialect/TTKernel/Transforms/InsertDeviceZoneScopes.cpp +++ b/lib/Dialect/TTKernel/Transforms/InsertDeviceZoneScopes.cpp @@ -29,15 +29,15 @@ class TTKernelInsertDeviceZoneScopes return; } OpBuilder builder(op); - builder.create(op->getLoc(), "{"); + emitc::VerbatimOp::create(builder, op->getLoc(), "{"); auto name = op->getName().getStringRef(); if (name.starts_with("ttkernel.")) { name = name.drop_front(9); } - builder.create(op->getLoc(), "DeviceZoneScopedN(\"" + - name.str() + "\");"); + emitc::VerbatimOp::create(builder, op->getLoc(), + "DeviceZoneScopedN(\"" + name.str() + "\");"); builder.setInsertionPointAfter(op); - builder.create(op->getLoc(), "}"); + emitc::VerbatimOp::create(builder, op->getLoc(), "}"); }); } }; diff --git a/lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp b/lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp index 1b3f235d1c0..f0df87e729a 100644 --- a/lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp +++ b/lib/Dialect/TTMetal/Pipelines/TTMetalPipelines.cpp @@ -13,7 +13,7 @@ #include "ttmlir/Transforms/Passes.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Transforms/Passes.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/EmitC/Transforms/Passes.h" diff --git a/lib/Dialect/TTNN/Transforms/Fusing/RoPEFusingPattern.cpp b/lib/Dialect/TTNN/Transforms/Fusing/RoPEFusingPattern.cpp index 08c76b36821..b57cc5a4acd 100644 --- a/lib/Dialect/TTNN/Transforms/Fusing/RoPEFusingPattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Fusing/RoPEFusingPattern.cpp @@ -602,11 +602,12 @@ mlir::LogicalResult createFusedRoPEOp(mlir::PatternRewriter &rewriter, auto computeConfig = buildComputeConfig(rewriter.getContext(), components); - auto ropeOp = rewriter.create( - srcOp.getLoc(), 
inputs.x.getType(), inputs.x, inputs.cos, inputs.sin, - /*token_index=*/nullptr, - /*memory_config=*/nullptr, - /*compute_config=*/computeConfig); + auto ropeOp = + RotaryEmbeddingOp::create(rewriter, srcOp.getLoc(), inputs.x.getType(), + inputs.x, inputs.cos, inputs.sin, + /*token_index=*/nullptr, + /*memory_config=*/nullptr, + /*compute_config=*/computeConfig); // Validate the fused op. If validation fails, try the workaround-padded // version since the workaround pass (seq_len tile alignment) hasn't run yet. @@ -643,8 +644,8 @@ mlir::LogicalResult createFusedRoPEOp(mlir::PatternRewriter &rewriter, llvm::seq(0, inputs.outPermutation.size()))) { DenseI64ArrayAttr permutationAttr = rewriter.getDenseI64ArrayAttr(inputs.outPermutation); - auto permuted = rewriter.create( - srcOp.getLoc(), srcOp.getType(), result, permutationAttr, + auto permuted = ttnn::PermuteOp::create( + rewriter, srcOp.getLoc(), srcOp.getType(), result, permutationAttr, ttnn::MemoryConfigAttr(), mlir::FloatAttr()); result = permuted.getResult(); } @@ -728,8 +729,8 @@ RoPEDecodeFusing::matchAndRewrite(PermuteOp permuteOp, auto tokenIndex = rewriter.getIntegerAttr( rewriter.getIntegerType(32, /*isSigned=*/false), 0); - auto newRope = rewriter.create( - ropeOp.getLoc(), prePermute.getType(), prePermute.getResult(), + auto newRope = RotaryEmbeddingOp::create( + rewriter, ropeOp.getLoc(), prePermute.getType(), prePermute.getResult(), ropeOp.getCosCache(), ropeOp.getSinCache(), tokenIndex, ropeOp.getMemoryConfigAttr(), ropeOp.getComputeConfigAttr()); diff --git a/lib/Dialect/TTNN/Transforms/Fusing/TopKFusingPattern.cpp b/lib/Dialect/TTNN/Transforms/Fusing/TopKFusingPattern.cpp index c5f8ff81c38..b38ff7693fa 100644 --- a/lib/Dialect/TTNN/Transforms/Fusing/TopKFusingPattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Fusing/TopKFusingPattern.cpp @@ -241,8 +241,8 @@ TopKFusing::matchAndRewrite(SortOp srcOp, } // Create the fused TopK operation (now that we know it's valid) - auto topkOp = rewriter.create( - 
srcOp.getLoc(), valuesResultType, indicesResultType, + auto topkOp = TopKOp::create( + rewriter, srcOp.getLoc(), valuesResultType, indicesResultType, srcOp.getInput(), // input tensor rewriter.getI32IntegerAttr(sliceResult->k), // k value rewriter.getI32IntegerAttr(sortDim), // dimension diff --git a/lib/Dialect/TTNN/Transforms/OptimizerPasses/DevicePassesWrapper.cpp b/lib/Dialect/TTNN/Transforms/OptimizerPasses/DevicePassesWrapper.cpp index 4c50f3890c7..91006c3b0b7 100644 --- a/lib/Dialect/TTNN/Transforms/OptimizerPasses/DevicePassesWrapper.cpp +++ b/lib/Dialect/TTNN/Transforms/OptimizerPasses/DevicePassesWrapper.cpp @@ -69,7 +69,7 @@ class DevicePassesWrapper populatePipeline(nestedPm); // Ensure closeInstance() gets called and backtrace env is unset. - auto guard = llvm::make_scope_exit([]() noexcept { + auto guard = llvm::scope_exit([]() noexcept { op_model::SingletonDeviceContext::closeInstance(); unsetenv("TT_METAL_DISABLE_BACKTRACE"); }); diff --git a/lib/Dialect/TTNN/Transforms/OptimizerPasses/OperationValidationAndFallback.cpp b/lib/Dialect/TTNN/Transforms/OptimizerPasses/OperationValidationAndFallback.cpp index 60e7ea5cb72..93f54d55825 100644 --- a/lib/Dialect/TTNN/Transforms/OptimizerPasses/OperationValidationAndFallback.cpp +++ b/lib/Dialect/TTNN/Transforms/OptimizerPasses/OperationValidationAndFallback.cpp @@ -905,8 +905,8 @@ ToLayoutOp createToLayoutOp(OpBuilder &builder, Location loc, RankedTensorType resultType = RankedTensorType::get( currentResultType.getShape(), scalarElementType, targetLayout); - return builder.create( - loc, resultType, inputValue, + return ToLayoutOp::create( + builder, loc, resultType, inputValue, LayoutAttr::get(builder.getContext(), targetLayout.getLayout()), ttcore::DataTypeAttr::get(builder.getContext(), targetLayout.getDataType()), diff --git a/lib/Dialect/TTNN/Transforms/OptimizerPasses/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/OptimizerPasses/Optimizer.cpp index 6764d828e1c..4d67824e972 100644 --- 
a/lib/Dialect/TTNN/Transforms/OptimizerPasses/Optimizer.cpp +++ b/lib/Dialect/TTNN/Transforms/OptimizerPasses/Optimizer.cpp @@ -132,9 +132,9 @@ void applyChosenLayoutToD2MSubgraphOp(D2MSubgraphOp dispatchOp, layoutAttr.getBufferType()), utils::createShardSpecIfNeeded(layoutAttr, deviceGrid)); Location loc = mainFunc.getLoc(); - ToLayoutOp toLayoutOp = - builder.create(loc, newTensorType, currentResultValue, - newLayout, dataType, memConfigAttr); + ToLayoutOp toLayoutOp = ToLayoutOp::create( + builder, loc, newTensorType, currentResultValue, newLayout, + dataType, memConfigAttr); returnOp.setOperand(0, toLayoutOp.getResult()); } } @@ -922,8 +922,8 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { OpBuilder builder(consumerOp); Location loc = ttmlir::utils::appendLocationSuffix(consumerOp->getLoc(), "_mem_reconfig"); - ToLayoutOp memoryReconfigOp = builder.create( - loc, + ToLayoutOp memoryReconfigOp = ToLayoutOp::create( + builder, loc, newTensorType, // output type consumerOp->getOperand(edge.operandIndex), // input value LayoutAttr::get(consumerOp->getContext(), @@ -999,9 +999,9 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { } // Step 2: Insert spilling to DRAM. - Operation *spillToDRAMOp = builder.create( - loc, newTensorType, spilledOp->getResult(0), newLayout, dataType, - memConfigAttr); + Operation *spillToDRAMOp = ToLayoutOp::create( + builder, loc, newTensorType, spilledOp->getResult(0), newLayout, + dataType, memConfigAttr); // Step 3: Reconnect uses. 
for (auto &use : uses) { @@ -1112,9 +1112,9 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { uses.emplace_back(use.getOwner(), use.getOperandNumber()); } - Operation *toLayoutOp = builder.create( - loc, newTensorType, spilledOp->getResult(0), newLayout, dataType, - memConfigAttr); + Operation *toLayoutOp = ToLayoutOp::create( + builder, loc, newTensorType, spilledOp->getResult(0), newLayout, + dataType, memConfigAttr); for (auto &[useOp, operandIdx] : uses) { useOp->setOperand(operandIdx, toLayoutOp->getResult(0)); diff --git a/lib/Dialect/TTNN/Transforms/OptimizerPasses/TTNNPrepareConv2dWeightsAndBias.cpp b/lib/Dialect/TTNN/Transforms/OptimizerPasses/TTNNPrepareConv2dWeightsAndBias.cpp index d2b69031540..dcba9ce7192 100644 --- a/lib/Dialect/TTNN/Transforms/OptimizerPasses/TTNNPrepareConv2dWeightsAndBias.cpp +++ b/lib/Dialect/TTNN/Transforms/OptimizerPasses/TTNNPrepareConv2dWeightsAndBias.cpp @@ -173,7 +173,8 @@ class TTNNPrepareConv2dWeightsAndBias mlir::tt::ttcore::DataTypeAttr outputDtypeAttr, ttnn::Conv2dConfigAttr conv2dConfig) { if constexpr (std::is_same_v) { - return rewriter.create( + return PrepareWeightsOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(convOp.getLoc(), "_prepare_conv2d_weight"), getPreparedWeightsType(convOp, conv2dConfig), convOp.getWeight(), @@ -190,7 +191,8 @@ class TTNNPrepareConv2dWeightsAndBias outputDtypeAttr, conv2dConfig, convOp.getComputeConfigAttr(), convOp.getConv2dSliceConfigAttr()); } else { - return rewriter.create( + return PrepareWeightsOp::create( + rewriter, ttmlir::utils::appendLocationSuffix( convOp.getLoc(), "_prepare_conv_transpose2d_weight"), getPreparedWeightsType(convOp, conv2dConfig), convOp.getWeight(), @@ -219,7 +221,8 @@ class TTNNPrepareConv2dWeightsAndBias mlir::tt::ttcore::DataTypeAttr outputDtypeAttr, ttnn::Conv2dConfigAttr conv2dConfig) { if constexpr (std::is_same_v) { - return rewriter.create( + return PrepareBiasOp::create( + rewriter, 
ttmlir::utils::appendLocationSuffix(convOp.getLoc(), "_prepare_conv2d_bias"), getPreparedBiasType(convOp.getBias(), inputElementType), @@ -233,7 +236,8 @@ class TTNNPrepareConv2dWeightsAndBias inputDtypeAttr, outputDtypeAttr, conv2dConfig, convOp.getComputeConfigAttr(), convOp.getConv2dSliceConfigAttr()); } else { - return rewriter.create( + return PrepareBiasOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(convOp.getLoc(), "_prepare_conv_transpose2d_bias"), getPreparedBiasType(convOp.getBias(), inputElementType), diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index 9718eef7778..1fb15d90f99 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -183,7 +183,7 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase { } rewriter.setInsertionPointAfter(lastOp); - rewriter.create(lastOp->getLoc(), value); + DeallocateOp::create(rewriter, lastOp->getLoc(), value); return success(); } @@ -337,7 +337,7 @@ class TTNNInputFunctionCreatorBase { // Create the function. // func::FuncOp inputFuncOp = - rewriter.create(loc, inputFuncName, functionType); + mlir::func::FuncOp::create(rewriter, loc, inputFuncName, functionType); // Mark this function as an input generator function. // @@ -376,12 +376,12 @@ class TTNNInputFunctionCreatorBase { // Create a tuple from the tensors. // ttcore::TupleOp tuple = - rewriter.create(loc, returnTypes, tensors); + ttcore::TupleOp::create(rewriter, loc, returnTypes, tensors); // Create ReturnOp. // - rewriter.create(forwardFuncOp.getLoc(), - tuple->getResults()); + func::ReturnOp::create(rewriter, forwardFuncOp.getLoc(), + tuple->getResults()); return inputFuncOp; } @@ -405,8 +405,8 @@ class TTNNInputFunctionCreatorBase { // Create the main function. 
// - func::FuncOp mainFuncOp = rewriter.create( - moduleOp.getLoc(), mainFuncName, functionType); + func::FuncOp mainFuncOp = mlir::func::FuncOp::create( + rewriter, moduleOp.getLoc(), mainFuncName, functionType); // Mark this function as a main function. // @@ -426,8 +426,8 @@ class TTNNInputFunctionCreatorBase { // inputFuncOp will be null if the forward function has no inputs. // if (inputFuncOp) { - func::CallOp tensors = rewriter.create( - forwardFuncOp.getLoc(), inputFuncOp, + func::CallOp tensors = mlir::func::CallOp::create( + rewriter, forwardFuncOp.getLoc(), inputFuncOp, /*operands=*/ValueRange()); operands = tensors->getResults(); } @@ -435,8 +435,8 @@ class TTNNInputFunctionCreatorBase { // Call a forward function. If there are input tensors, pass them as // operands. // - rewriter.create(forwardFuncOp.getLoc(), forwardFuncOp, - operands); + mlir::func::CallOp::create(rewriter, forwardFuncOp.getLoc(), + forwardFuncOp, operands); } // Return 0 @@ -444,10 +444,10 @@ class TTNNInputFunctionCreatorBase { // func::ReturnOp requires a Value to be returned, which means that an SSA // needs to be returned, hence create a constant 0 via arith::ConstantOp. // - Value constantZero = rewriter.create( - rewriter.getUnknownLoc(), rewriter.getI32Type(), + Value constantZero = arith::ConstantOp::create( + rewriter, rewriter.getUnknownLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); - rewriter.create(mainFuncOp->getLoc(), constantZero); + func::ReturnOp::create(rewriter, mainFuncOp->getLoc(), constantZero); } }; @@ -501,9 +501,10 @@ class TTNNCreateInputGenerators } // Create a new tensor of ones. 
// - ttnn::OnesOp onesOp = rewriter.create( - loc, tensorType, device, shapeAttr, dTypeAttr, tensorLayoutAttr, - /*memory_config=*/nullptr); + ttnn::OnesOp onesOp = + ttnn::OnesOp::create(rewriter, loc, tensorType, device, shapeAttr, + dTypeAttr, tensorLayoutAttr, + /*memory_config=*/nullptr); return onesOp; } @@ -562,8 +563,8 @@ class TTNNLoadInputTensors } // Create LoadTensorOp to load tensor from disk. // - ttnn::LoadTensorOp loadTensorOp = rewriter.create( - loc, tensorType, filePathAttr, device); + ttnn::LoadTensorOp loadTensorOp = ttnn::LoadTensorOp::create( + rewriter, loc, tensorType, filePathAttr, device); return loadTensorOp; } @@ -681,9 +682,9 @@ class TTNNTuplifyTensors rewriter.setInsertionPointToStart(&entryBlock); for (size_t idx = 0; idx < originalFuncType.getNumInputs(); idx++) { ttcore::GetTupleElementOp getTupleElementOp = - rewriter.create( - targetFuncOpInput.getLoc(), targetFuncOpInput.getArgument(0), - idx); + ttcore::GetTupleElementOp::create( + rewriter, targetFuncOpInput.getLoc(), + targetFuncOpInput.getArgument(0), idx); // Replace all uses of the original tensor arguments with the // GetTupleElementOp results. 
@@ -727,8 +728,8 @@ class TTNNTuplifyTensors targetFuncOpResult.walk( [&](mlir::func::ReturnOp returnOp) { rewriter.setInsertionPoint(returnOp); - ttcore::TupleOp tupleOp = rewriter.create( - returnOp.getLoc(), returnOp.getOperands()); + ttcore::TupleOp tupleOp = ttcore::TupleOp::create( + rewriter, returnOp.getLoc(), returnOp.getOperands()); rewriter.modifyOpInPlace(returnOp, [&]() { returnOp.getOperandsMutable().assign(tupleOp); }); diff --git a/lib/Dialect/TTNN/Transforms/TTNNConstEvalInputsToSystemMemory.cpp b/lib/Dialect/TTNN/Transforms/TTNNConstEvalInputsToSystemMemory.cpp index b1e0c8f38e7..2fa0bc5f5de 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNConstEvalInputsToSystemMemory.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNConstEvalInputsToSystemMemory.cpp @@ -177,8 +177,8 @@ static void convertArgumentOfConstEvalFunc(func::FuncOp constEvalFuncOp, auto originalDataTypeAttr = mlir::tt::ttcore::DataTypeAttr::get( constEvalFuncOp.getContext(), deviceTensorLayout.getDataType()); - auto toLayoutOp = builder.create( - blockArgument.getLoc(), deviceTensorType, blockArgument, + auto toLayoutOp = ttnn::ToLayoutOp::create( + builder, blockArgument.getLoc(), deviceTensorType, blockArgument, deviceTensorLayout.getLayout(), originalDataTypeAttr, MemoryConfigAttr::get(deviceTensorLayout, deviceGrid)); diff --git a/lib/Dialect/TTNN/Transforms/TTNND2MFusing.cpp b/lib/Dialect/TTNN/Transforms/TTNND2MFusing.cpp index 7f47741d44f..56bd96b6ab9 100644 --- a/lib/Dialect/TTNN/Transforms/TTNND2MFusing.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNND2MFusing.cpp @@ -262,9 +262,9 @@ class TTNND2MFusingPass : public impl::TTNND2MFusingBase { LayoutAttr::get(rewriter.getContext(), layoutAttr.getLayout()); auto memoryConfigAttr = MemoryConfigAttr::get(layoutAttr, deviceGrid); - auto emptyOp = rewriter.create( - loc, tensorType, device, shapeAttr, dtypeAttr, tensorLayoutAttr, - memoryConfigAttr); + auto emptyOp = + EmptyOp::create(rewriter, loc, tensorType, device, shapeAttr, + dtypeAttr, 
tensorLayoutAttr, memoryConfigAttr); outputBuffers.push_back(emptyOp.getResult()); lastEmptyOp = emptyOp.getOperation(); } @@ -301,7 +301,7 @@ class TTNND2MFusingPass : public impl::TTNND2MFusingBase { llvm::SmallVector returnValues; llvm::transform(outputs, std::back_inserter(returnValues), [&](Value v) { return mapping.lookup(v); }); - rewriter.create(loc, returnValues); + func::ReturnOp::create(rewriter, loc, returnValues); // Place subgraph after all its operands: after last input definer and after // the empty output buffers we just created (so output buffers dominate). @@ -312,8 +312,8 @@ class TTNND2MFusingPass : public impl::TTNND2MFusingBase { } else { rewriter.setInsertionPoint(firstOp); } - auto dispatchOp = rewriter.create( - loc, outputTypes, inputs.getArrayRef(), outputBuffers, + auto dispatchOp = D2MSubgraphOp::create( + rewriter, loc, outputTypes, inputs.getArrayRef(), outputBuffers, SymbolRefAttr::get(rewriter.getContext(), funcName)); for (auto [origOutput, dispatchResult] : diff --git a/lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp b/lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp index 7dbc18e758b..fdf6d618ea2 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNDecomposeLayouts.cpp @@ -269,8 +269,8 @@ class TTNNDecomposeLayouts mlir::Value currentInput, Args &&...args) const { rewriter.setInsertionPoint(op); - return rewriter.create(op.getLoc(), op.getType(), currentInput, - std::forward(args)...); + return OpType::create(rewriter, op.getLoc(), op.getType(), currentInput, + std::forward(args)...); } template @@ -278,8 +278,8 @@ class TTNNDecomposeLayouts RankedTensorType newResultType, mlir::Value currentInput, Args &&...args) const { rewriter.setInsertionPoint(op); - return rewriter.create(op.getLoc(), newResultType, currentInput, - std::forward(args)...); + return OpType::create(rewriter, op.getLoc(), newResultType, currentInput, + std::forward(args)...); } mlir::Value 
createToDeviceOpIfNeeded(ttnn::ToLayoutOp op, diff --git a/lib/Dialect/TTNN/Transforms/TTNNFileSplit.cpp b/lib/Dialect/TTNN/Transforms/TTNNFileSplit.cpp index f8bb7ed1dcd..26083ea6964 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNFileSplit.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNFileSplit.cpp @@ -127,9 +127,9 @@ class TTNNFileSplit : public impl::TTNNFileSplitBase { // Create the file containers. builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front()); - auto mainFile = builder.create(moduleOp.getLoc(), kMainFileName); + auto mainFile = FileOpTy::create(builder, moduleOp.getLoc(), kMainFileName); auto constevalFile = - builder.create(moduleOp.getLoc(), kConstevalFileName); + FileOpTy::create(builder, moduleOp.getLoc(), kConstevalFileName); // Move const-eval functions to the consteval file. Clone // CPU-hoisted declarations into both files so that func.call ops @@ -199,8 +199,8 @@ class TTNNFileSplit : public impl::TTNNFileSplitBase { auto wrapperFuncType = builder.getFunctionType(wrapperArgTypes, {dictType}); builder.setInsertionPointToEnd(&constevalFile.getBodyRegion().front()); - auto wrapperFunc = builder.create( - forwardFunc.getLoc(), wrapperName, wrapperFuncType); + auto wrapperFunc = func::FuncOp::create(builder, forwardFunc.getLoc(), + wrapperName, wrapperFuncType); wrapperFunc.addEntryBlock(); Block &forwardBody = forwardFunc.getBody().front(); @@ -240,9 +240,8 @@ class TTNNFileSplit : public impl::TTNNFileSplitBase { } // Create the return operation in the wrapper function. - builder.create( - forwardFunc.getLoc(), - ValueRange{mapping.lookup(cacheDict.getResult())}); + func::ReturnOp::create(builder, forwardFunc.getLoc(), + ValueRange{mapping.lookup(cacheDict.getResult())}); // Insert a call to the wrapper function after the cache dictionary // retrieval in the forward function. 
@@ -251,8 +250,9 @@ class TTNNFileSplit : public impl::TTNNFileSplitBase { callArgs.push_back(cacheDict.getResult()); callArgs.append(forwardBody.getArguments().begin(), forwardBody.getArguments().end()); - auto callOp = builder.create( - forwardFunc.getLoc(), wrapperName, TypeRange{dictType}, callArgs); + auto callOp = + func::CallOp::create(builder, forwardFunc.getLoc(), wrapperName, + TypeRange{dictType}, callArgs); // Update dictionary lookup operations to use the dictionary that wrapper // function returns. @@ -271,8 +271,8 @@ class TTNNFileSplit : public impl::TTNNFileSplitBase { // Create a declaration of the wrapper function in the main file so that // func.call op can resolve the symbol. builder.setInsertionPointToEnd(&mainFile.getBodyRegion().front()); - auto privateDecl = builder.create( - forwardFunc.getLoc(), wrapperName, wrapperFuncType); + auto privateDecl = func::FuncOp::create(builder, forwardFunc.getLoc(), + wrapperName, wrapperFuncType); privateDecl.setPrivate(); return success(); diff --git a/lib/Dialect/TTNN/Transforms/TTNNFusing.cpp b/lib/Dialect/TTNN/Transforms/TTNNFusing.cpp index ccd21190909..7b94cd42f0d 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNFusing.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNFusing.cpp @@ -550,8 +550,8 @@ class SDPAFusing : public mlir::OpRewritePattern { auto dataType = ttcore::DataType::BFloat16; auto castType = utils::RankedTensorTypeFactory::create(vType, dataType); - return rewriter.create( - v.getLoc(), castType, v, + return TypecastOp::create( + rewriter, v.getLoc(), castType, v, ttcore::DataTypeAttr::get(rewriter.getContext(), dataType)); } @@ -567,8 +567,8 @@ class SDPAFusing : public mlir::OpRewritePattern { // Create new tensor type with correctly updated encoding. 
auto castType = utils::RankedTensorTypeFactory::create(vType, dataType); - return rewriter.create( - v.getLoc(), castType, v, + return TypecastOp::create( + rewriter, v.getLoc(), castType, v, ttcore::DataTypeAttr::get(rewriter.getContext(), dataType)); } @@ -732,8 +732,8 @@ class SDPAFusing : public mlir::OpRewritePattern { auto resultType = utils::RankedTensorTypeFactory::create(maskType, resultShape); - return rewriter.create( - loc, resultType, mask, rewriter.getI32ArrayAttr(begins), + return SliceStaticOp::create( + rewriter, loc, resultType, mask, rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(ends), rewriter.getI32ArrayAttr(steps)); } @@ -992,7 +992,7 @@ class SDPAFusing : public mlir::OpRewritePattern { maskType.getShape(), targetShape); auto shapeAttr = ShapeAttr::get(rewriter.getContext(), broadcastDims); - return rewriter.create(loc, broadcastType, mask, shapeAttr); + return RepeatOp::create(rewriter, loc, broadcastType, mask, shapeAttr); } mlir::LogicalResult createSDPAOp(mlir::PatternRewriter &rewriter, @@ -1038,9 +1038,9 @@ class SDPAFusing : public mlir::OpRewritePattern { llvm::to_vector(kToDecodePermutation), rewriter, c.attentionMatmul.getLoc()); - auto decodeOp = rewriter.create( - c.attentionMatmul.getLoc(), permutedQuery.getType(), permutedQuery, - c.key, c.value, + auto decodeOp = ScaledDotProductAttentionDecodeOp::create( + rewriter, c.attentionMatmul.getLoc(), permutedQuery.getType(), + permutedQuery, c.key, c.value, /*is_causal=*/rewriter.getBoolAttr(false), attentionMask, /*cur_pos_tensor=*/Value(), /*attention_sink=*/Value(), scaleAttr, @@ -1074,9 +1074,9 @@ class SDPAFusing : public mlir::OpRewritePattern { rewriter.replaceOp(c.attentionMatmul, finalResult); } else { - auto sdpaOp = rewriter.create( - c.attentionMatmul.getLoc(), c.query.getType(), c.query, c.key, - c.value, attentionMask, + auto sdpaOp = ScaledDotProductAttentionOp::create( + rewriter, c.attentionMatmul.getLoc(), c.query.getType(), c.query, + c.key, c.value, 
attentionMask, /*is_causal=*/rewriter.getBoolAttr(false), scaleAttr, /*sliding_window_size=*/IntegerAttr(), /*memory_config=*/MemoryConfigAttr()); @@ -1180,8 +1180,8 @@ class NLPConcatHeadsDecodeFusing : public mlir::OpRewritePattern { op_model::ScopedSingletonDeviceGuard deviceGuard(reshapeOp); - auto nlpConcatHeadsDecodeOp = rewriter.create( - reshapeOp.getLoc(), concatHeadsResultType, input, + auto nlpConcatHeadsDecodeOp = NLPConcatHeadsDecodeOp::create( + rewriter, reshapeOp.getLoc(), concatHeadsResultType, input, rewriter.getUI32IntegerAttr(static_cast(numHeads)), /*memory_config=*/MemoryConfigAttr()); @@ -1196,8 +1196,9 @@ class NLPConcatHeadsDecodeFusing : public mlir::OpRewritePattern { auto shardedResultType = utils::RankedTensorTypeFactory::create( shardedInputType, concatHeadsOutputShape); - auto validationOp = rewriter.create( - reshapeOp.getLoc(), shardedResultType, workaround->getResult(), + auto validationOp = NLPConcatHeadsDecodeOp::create( + rewriter, reshapeOp.getLoc(), shardedResultType, + workaround->getResult(), rewriter.getUI32IntegerAttr(static_cast(numHeads)), /*memory_config=*/MemoryConfigAttr()); @@ -1219,8 +1220,8 @@ class NLPConcatHeadsDecodeFusing : public mlir::OpRewritePattern { rewriter.setInsertionPointAfter(nlpConcatHeadsDecodeOp); - auto newReshapeOp = rewriter.create( - reshapeOp.getLoc(), reshapeOp.getType(), + auto newReshapeOp = ReshapeOp::create( + rewriter, reshapeOp.getLoc(), reshapeOp.getType(), nlpConcatHeadsDecodeOp.getResult(), reshapeOp.getShapeAttr(), /*memory_config=*/MemoryConfigAttr()); diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index 4c0de93dec3..8fbac00fe52 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -327,9 +327,9 @@ class TTNNLayoutHoistedFuncCallRewriter func::FuncOp funcOp = dyn_cast( SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr())); // Create the original CallOp with 
the new inputs on host. - auto newCallOp = rewriter.create( - callOp.getLoc(), callOp.getCallee(), funcOp.getResultTypes(), - fromDeviceOperands); + auto newCallOp = + func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(), + funcOp.getResultTypes(), fromDeviceOperands); newCallOp->setAttr(ttmlir::utils::g_cpuHoistFuncCallAttrName, mlir::UnitAttr::get(rewriter.getContext())); diff --git a/lib/Dialect/TTNN/Transforms/TTNNPrepareConstEvalCaching.cpp b/lib/Dialect/TTNN/Transforms/TTNNPrepareConstEvalCaching.cpp index c2e08fe99b9..db349623a37 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNPrepareConstEvalCaching.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNPrepareConstEvalCaching.cpp @@ -52,15 +52,15 @@ class TTNNPrepareConstEvalCaching // Create the global caching dictionary before the function. builder.setInsertionPoint(funcOp); - builder.create(funcOp.getLoc(), - llvm::StringRef(cacheName), dictType, - /*index=*/IntegerAttr()); + ttcore::GlobalOp::create(builder, funcOp.getLoc(), + llvm::StringRef(cacheName), dictType, + /*index=*/IntegerAttr()); // Retrieve the caching dictionary at the top of the function body. Block &entryBlock = funcOp.getBody().front(); builder.setInsertionPointToStart(&entryBlock); - auto dictVal = builder.create(funcOp.getLoc(), - dictType, cacheName); + auto dictVal = ttcore::GetGlobalOp::create(builder, funcOp.getLoc(), + dictType, cacheName); dictVal->setDiscardableAttr(kCachingDictAttr, builder.getUnitAttr()); // For each LoadCachedOp, store its results under one key in the caching @@ -68,12 +68,12 @@ class TTNNPrepareConstEvalCaching // Replace LoadCachedOp results with dictionary lookups. 
for (auto loadCachedOp : loadCachedOps) { builder.setInsertionPointAfter(loadCachedOp); - auto setKVOp = builder.create( - loadCachedOp.getLoc(), dictVal.getResult(), + auto setKVOp = ttcore::SetKeyValueOp::create( + builder, loadCachedOp.getLoc(), dictVal.getResult(), builder.getStringAttr(loadCachedOp.getCallee()), loadCachedOp.getResults()); - auto getKVOp = builder.create( - loadCachedOp.getLoc(), loadCachedOp.getResultTypes(), + auto getKVOp = ttcore::GetKeyValueOp::create( + builder, loadCachedOp.getLoc(), loadCachedOp.getResultTypes(), dictVal.getResult(), builder.getStringAttr(loadCachedOp.getCallee())); for (unsigned i = 0; i < loadCachedOp->getNumResults(); ++i) { diff --git a/lib/Dialect/TTNN/Transforms/TTNNRecoverStructure.cpp b/lib/Dialect/TTNN/Transforms/TTNNRecoverStructure.cpp index e237457a463..8f2c74a17cf 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNRecoverStructure.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNRecoverStructure.cpp @@ -572,8 +572,8 @@ class TTNNRecoverStructure // Create the new function. // - func::FuncOp newFunc = rewriter.create( - candidateFn.getLoc(), uniqueName, funcType); + func::FuncOp newFunc = func::FuncOp::create( + rewriter, candidateFn.getLoc(), uniqueName, funcType); // Mark as private for now. // @@ -651,7 +651,7 @@ class TTNNRecoverStructure // Add return statement. // - rewriter.create(funcOp.getLoc(), returnValues); + func::ReturnOp::create(rewriter, funcOp.getLoc(), returnValues); } } @@ -841,8 +841,8 @@ class TTNNRecoverStructure // Create the call. // Operation *lastOp = opsInFunction.back(); - func::CallOp callOp = rewriter.create( - lastOp->getLoc(), targetFunc, callOperands); + func::CallOp callOp = func::CallOp::create(rewriter, lastOp->getLoc(), + targetFunc, callOperands); // Track this call for the next insertion. 
// diff --git a/lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp b/lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp index cbcb39ff8ec..f0605b62463 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp @@ -17,7 +17,7 @@ namespace mlir::tt::ttnn { LogicalResult emitTTNNAsCpp(ModuleOp origOp, llvm::raw_ostream &os) { ModuleOp op = cast(origOp->clone()); - auto cleanupDispatchClone = llvm::make_scope_exit([&op] { op->erase(); }); + auto cleanupDispatchClone = llvm::scope_exit([&op] { op->erase(); }); auto pm = PassManager::on(op.getContext()); pm.addPass(createConvertTTNNToEmitCPass()); diff --git a/lib/Dialect/TTNN/Transforms/TTNNToPython.cpp b/lib/Dialect/TTNN/Transforms/TTNNToPython.cpp index 51a6acd4f8f..42d891da49a 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNToPython.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNToPython.cpp @@ -17,7 +17,7 @@ namespace mlir::tt::ttnn { LogicalResult emitTTNNAsPython(ModuleOp origOp, llvm::raw_ostream &os) { ModuleOp op = cast(origOp->clone()); - auto cleanupDispatchClone = llvm::make_scope_exit([&op] { op->erase(); }); + auto cleanupDispatchClone = llvm::scope_exit([&op] { op->erase(); }); auto pm = PassManager::on(op.getContext()); pm.addPass(createConvertTTNNToEmitPyPass()); diff --git a/lib/Dialect/TTNN/Transforms/TTNNTraceHoistTransform.cpp b/lib/Dialect/TTNN/Transforms/TTNNTraceHoistTransform.cpp index 55f98d35ad4..3071f1def0d 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNTraceHoistTransform.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNTraceHoistTransform.cpp @@ -231,8 +231,8 @@ class TTNNTraceHoistTransform // Create the function builder.setInsertionPoint(funcOp); - auto traceFuncOp = builder.create( - funcOp.getLoc(), traceFuncName, traceFuncType); + auto traceFuncOp = func::FuncOp::create(builder, funcOp.getLoc(), + traceFuncName, traceFuncType); ttmlir::utils::setFunctionType(traceFuncOp, ttmlir::utils::FunctionType::TraceMain); @@ -291,7 +291,7 @@ class TTNNTraceHoistTransform "Could not map 
output value in hoisted function"); } } - builder.create(funcOp.getLoc(), returnValues); + func::ReturnOp::create(builder, funcOp.getLoc(), returnValues); return mlir::success(); } @@ -343,8 +343,8 @@ class TTNNTraceHoistTransform getCaptureTraceFuncName(funcOp, traceFuncIndex); builder.setInsertionPoint(funcOp); - auto runAndCaptureTraceFunc = builder.create( - funcOp.getLoc(), runAndCaptureTraceFuncName, + auto runAndCaptureTraceFunc = func::FuncOp::create( + builder, funcOp.getLoc(), runAndCaptureTraceFuncName, runAndCaptureTraceFuncType); ttmlir::utils::setFunctionType( runAndCaptureTraceFunc, @@ -390,8 +390,8 @@ class TTNNTraceHoistTransform utils::createShardSpecIfNeeded(ttnnLayoutAttr, device.getWorkerGrid())); - auto emptyOp = builder.create( - runAndCaptureTraceFunc.getLoc(), inputTensorType, deviceOp, + auto emptyOp = ttnn::EmptyOp::create( + builder, runAndCaptureTraceFunc.getLoc(), inputTensorType, deviceOp, ttnn::ShapeAttr::get(context, inputTensorType.getShape()), ttcore::DataTypeAttr::get(context, ttnnLayoutAttr.getDataType()), ttnn::LayoutAttr::get(context, ttnnLayoutAttr.getLayout()), @@ -412,30 +412,30 @@ class TTNNTraceHoistTransform RankedTensorType newResultType = utils::RankedTensorTypeFactory::create( currentInputType, ttnn::BufferType::SystemMemory); - auto fromDeviceOp = builder.create( - runAndCaptureTraceFunc.getLoc(), newResultType, input); + auto fromDeviceOp = ttnn::FromDeviceOp::create( + builder, runAndCaptureTraceFunc.getLoc(), newResultType, input); - builder.create(runAndCaptureTraceFunc.getLoc(), - fromDeviceOp, inputSlots[i], - /*blocking=*/false, /*cq_id=*/0); + ttnn::WriteTensorOp::create(builder, runAndCaptureTraceFunc.getLoc(), + fromDeviceOp, inputSlots[i], + /*blocking=*/false, /*cq_id=*/0); } // call the trace function on the input slots - auto traceFuncCall = builder.create( - runAndCaptureTraceFunc.getLoc(), traceFunc, inputSlots); + auto traceFuncCall = func::CallOp::create( + builder, runAndCaptureTraceFunc.getLoc(), 
traceFunc, inputSlots); // now, we can capture the trace - auto beginTraceCaptureOp = builder.create( - runAndCaptureTraceFunc.getLoc(), utils::getTraceIdType(context), - deviceOp, + auto beginTraceCaptureOp = ttnn::BeginTraceCaptureOp::create( + builder, runAndCaptureTraceFunc.getLoc(), + utils::getTraceIdType(context), deviceOp, /*cq_id=*/0); - auto captureTraceCall = builder.create( - runAndCaptureTraceFunc.getLoc(), traceFunc, inputSlots); + auto captureTraceCall = func::CallOp::create( + builder, runAndCaptureTraceFunc.getLoc(), traceFunc, inputSlots); - builder.create(runAndCaptureTraceFunc.getLoc(), - deviceOp, beginTraceCaptureOp, - /*cq_id=*/0); + ttnn::EndTraceCaptureOp::create(builder, runAndCaptureTraceFunc.getLoc(), + deviceOp, beginTraceCaptureOp, + /*cq_id=*/0); // create the return op llvm::SmallVector returnValues; @@ -451,8 +451,8 @@ class TTNNTraceHoistTransform returnValues.push_back(outputSlot); } - builder.create(runAndCaptureTraceFunc.getLoc(), - returnValues); + func::ReturnOp::create(builder, runAndCaptureTraceFunc.getLoc(), + returnValues); return mlir::success(); } @@ -484,8 +484,8 @@ class TTNNTraceHoistTransform getExecuteTraceFuncName(funcOp, traceFuncIndex); builder.setInsertionPoint(funcOp); - auto executeTraceFunc = builder.create( - funcOp.getLoc(), executeTraceFuncName, executeTraceFuncType); + auto executeTraceFunc = func::FuncOp::create( + builder, funcOp.getLoc(), executeTraceFuncName, executeTraceFuncType); ttmlir::utils::setFunctionType(executeTraceFunc, ttmlir::utils::FunctionType::TraceExecute); executeTraceFunc.setPrivate(); @@ -497,10 +497,10 @@ class TTNNTraceHoistTransform auto deviceOp = utils::getOrInsertDevice(rewriter, executeTraceFuncEntryBlock); mlir::Value traceId = executeTraceFunc.getArgument(0); - builder.create(funcOp.getLoc(), deviceOp, traceId, - /*cq_id=*/0, /*blocking=*/false); + ttnn::ExecuteTraceOp::create(builder, funcOp.getLoc(), deviceOp, traceId, + /*cq_id=*/0, /*blocking=*/false); - 
builder.create(funcOp.getLoc()); + func::ReturnOp::create(builder, funcOp.getLoc()); return mlir::success(); } @@ -555,8 +555,8 @@ class TTNNTraceHoistTransform auto device = utils::getOrInsertDevice(rewriter, firstOp); - auto traceOp = builder.create( - funcOp.getLoc(), outputTypes, device, captureTraceSymbolAttr, + auto traceOp = ttnn::CaptureOrExecuteTraceOp::create( + builder, funcOp.getLoc(), outputTypes, device, captureTraceSymbolAttr, executeTraceSymbolAttr, inputs); // Replace uses of original outputs with the output of the trace op function diff --git a/lib/Dialect/TTNN/Transforms/TTNNWeightDtypeConversion.cpp b/lib/Dialect/TTNN/Transforms/TTNNWeightDtypeConversion.cpp index 7590103bdba..11325481b60 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNWeightDtypeConversion.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNWeightDtypeConversion.cpp @@ -62,8 +62,8 @@ class WeightDtypeConversionPattern : public mlir::OpRewritePattern { ttnn::utils::RankedTensorTypeFactory::create(weightType, targetDtype); // Insert typecast operation to convert weight to target dtype. - auto typecastOp = rewriter.create( - op.getLoc(), newWeightType, weight, + auto typecastOp = TypecastOp::create( + rewriter, op.getLoc(), newWeightType, weight, ttcore::DataTypeAttr::get(rewriter.getContext(), targetDtype)); // Update op to use the typecast result. 
diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/AllGatherOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/AllGatherOpRewritePattern.cpp index a17e1e82f3e..92102f7f6b0 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/AllGatherOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/AllGatherOpRewritePattern.cpp @@ -49,7 +49,8 @@ TTNNAllGatherWorkarounds::matchAndRewrite(ttnn::AllGatherOp op, paddedInputShape.end()); RankedTensorType reshapeInputType = ttnn::utils::RankedTensorTypeFactory::create(inputType, paddedInputShape); - auto reshapeInput = rewriter.create( + auto reshapeInput = ttnn::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_reshape_to_4d"), reshapeInputType, op.getInput(), rewriter.getI32ArrayAttr(paddedShapeI32), ttnn::MemoryConfigAttr()); @@ -61,7 +62,8 @@ TTNNAllGatherWorkarounds::matchAndRewrite(ttnn::AllGatherOp op, // Create the reduce gather operation on 4D tensors with adjusted // all_gather_dim - auto allGather4D = rewriter.create( + auto allGather4D = ttnn::AllGatherOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_all_gather_4d"), paddedOutputType, reshapeInput.getResult(), adjustedGatherDim, op.getClusterAxis(), diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.cpp index 43d407b7a10..5420063ba6f 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.cpp @@ -62,8 +62,9 @@ ArgMaxOpRewritePattern::matchAndRewrite(ttnn::ArgMaxOp srcOp, mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 32), dim); } // Create new ttnn.argmax op with updated input tensor, dimension, etc. 
- ArgMaxOp argMaxOp = rewriter.create( - srcOp->getLoc(), newOutputType, preReshapeOp, dimAttr, srcOp.getKeepDim(), + ArgMaxOp argMaxOp = mlir::tt::ttnn::ArgMaxOp::create( + rewriter, srcOp->getLoc(), newOutputType, preReshapeOp, dimAttr, + srcOp.getKeepDim(), /*use_multicore=*/false, /*memoryConfig=*/nullptr); // Create ttnn.reshape op after performing ttnn.argmax op. diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ConcatenateHeadsOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ConcatenateHeadsOpRewritePattern.cpp index 9a1bda3bea6..4b4a699ce2c 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ConcatenateHeadsOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ConcatenateHeadsOpRewritePattern.cpp @@ -44,7 +44,8 @@ LogicalResult ConcatenateHeadsOpRewritePattern::matchAndRewrite( utils::RankedTensorTypeFactory::create(outputType, permutedShape); auto input = srcOp.getInput(); - PermuteOp permuteOp = rewriter.create( + PermuteOp permuteOp = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_concat_heads"), permutedType, input, permutationAttr, ttnn::MemoryConfigAttr(), mlir::FloatAttr()); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/Conv3dDepthPaddingRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/Conv3dDepthPaddingRewritePattern.cpp index 86817643a18..ae1f84ef20c 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/Conv3dDepthPaddingRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/Conv3dDepthPaddingRewritePattern.cpp @@ -57,12 +57,13 @@ LogicalResult Conv3dDepthPaddingRewritePattern::matchAndRewrite( .withTensorShape(paddedInputShape)); auto paddedInput = - rewriter.create(ttmlir::utils::appendLocationSuffix( - srcOp.getInput().getLoc(), "_pad_conv3d"), - paddedInputType, srcOp.getInput(), inputPadding, - /*pad_value=*/mlir::APFloat(0.0f), - 
/*use_multicore=*/false, - /*memory_config=*/nullptr); + PadOp::create(rewriter, + ttmlir::utils::appendLocationSuffix( + srcOp.getInput().getLoc(), "_pad_conv3d"), + paddedInputType, srcOp.getInput(), inputPadding, + /*pad_value=*/mlir::APFloat(0.0f), + /*use_multicore=*/false, + /*memory_config=*/nullptr); rewriter.modifyOpInPlace(srcOp, [&]() { srcOp.getInputMutable().assign(paddedInput); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/Conv3dPadOutputChannelsRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/Conv3dPadOutputChannelsRewritePattern.cpp index d5bafacb2f0..9a21e6b95d7 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/Conv3dPadOutputChannelsRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/Conv3dPadOutputChannelsRewritePattern.cpp @@ -42,12 +42,13 @@ LogicalResult Conv3dPadOutputChannelsRewritePattern::matchAndRewrite( utils::RankedTensorTypeFactory::create(weightType, paddedWeightShape); auto paddedWeight = - rewriter.create(ttmlir::utils::appendLocationSuffix( - srcOp.getWeight().getLoc(), "_pad_out_ch"), - paddedWeightType, srcOp.getWeight(), weightPadding, - /*pad_value=*/mlir::APFloat(0.0f), - /*use_multicore=*/false, - /*memory_config=*/nullptr); + PadOp::create(rewriter, + ttmlir::utils::appendLocationSuffix( + srcOp.getWeight().getLoc(), "_pad_out_ch"), + paddedWeightType, srcOp.getWeight(), weightPadding, + /*pad_value=*/mlir::APFloat(0.0f), + /*use_multicore=*/false, + /*memory_config=*/nullptr); // Pad bias tensor if present: (1, O) -> (1, O_padded) Value paddedBias = srcOp.getBias(); @@ -63,8 +64,8 @@ LogicalResult Conv3dPadOutputChannelsRewritePattern::matchAndRewrite( auto paddedBiasType = utils::RankedTensorTypeFactory::create(biasType, paddedBiasShape); - paddedBias = - rewriter.create(ttmlir::utils::appendLocationSuffix( + paddedBias = PadOp::create(rewriter, + ttmlir::utils::appendLocationSuffix( srcOp.getBias().getLoc(), "_pad_out_ch"), 
paddedBiasType, srcOp.getBias(), biasPadding, /*pad_value=*/mlir::APFloat(0.0f), @@ -93,9 +94,9 @@ LogicalResult Conv3dPadOutputChannelsRewritePattern::matchAndRewrite( } // Create new conv3d with padded output channels. - auto paddedConvOp = rewriter.create( - srcOp.getLoc(), paddedOutputType, srcOp.getInput(), paddedWeight, - paddedBias, srcOp.getDevice(), + auto paddedConvOp = Conv3dOp::create( + rewriter, srcOp.getLoc(), paddedOutputType, srcOp.getInput(), + paddedWeight, paddedBias, srcOp.getDevice(), rewriter.getI32IntegerAttr(srcOp.getInChannels()), rewriter.getI32IntegerAttr(paddedOutChannels), rewriter.getI32IntegerAttr(srcOp.getBatchSize()), @@ -113,7 +114,8 @@ LogicalResult Conv3dPadOutputChannelsRewritePattern::matchAndRewrite( SmallVector ends(outputType.getShape()); SmallVector steps(rank, 1); - auto sliceOp = rewriter.create( + auto sliceOp = SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_slice_out_ch"), outputType, paddedConvOp, rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(ends), rewriter.getI32ArrayAttr(steps)); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpDimRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpDimRewritePattern.cpp index bff55439397..4783cf73734 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpDimRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpDimRewritePattern.cpp @@ -36,8 +36,8 @@ CumSumOpDimRewritePattern::matchAndRewrite(ttnn::MorehCumSumOp srcOp, ttmlir::utils::applyPermutation(originalShape, permutation); RankedTensorType adaptedInputType = utils::RankedTensorTypeFactory::create(inputType, adaptedShape); - auto adaptedInput = rewriter.create( - ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_permute"), + auto adaptedInput = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_permute"), adaptedInputType, 
srcOp.getInput(), rewriter.getDenseI64ArrayAttr(permutation), /*memory_config=*/ttnn::MemoryConfigAttr(), @@ -46,8 +46,8 @@ CumSumOpDimRewritePattern::matchAndRewrite(ttnn::MorehCumSumOp srcOp, mlir::RankedTensorType outputType = srcOp.getResult().getType(); RankedTensorType adaptedOutputType = utils::RankedTensorTypeFactory::create(outputType, adaptedShape); - auto adaptedCumSumOp = rewriter.create( - srcOp->getLoc(), adaptedOutputType, adaptedInput, /*dim=*/0, + auto adaptedCumSumOp = mlir::tt::ttnn::MorehCumSumOp::create( + rewriter, srcOp->getLoc(), adaptedOutputType, adaptedInput, /*dim=*/0, /*memory_config=*/nullptr); auto permute = rewriter.replaceOpWithNewOp( diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRankRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRankRewritePattern.cpp index eda24a4b53e..f0bd632396c 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRankRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRankRewritePattern.cpp @@ -39,10 +39,10 @@ CumSumOpRankRewritePattern::matchAndRewrite(ttnn::MorehCumSumOp srcOp, RankedTensorType outputType = srcOp.getResult().getType(); RankedTensorType adaptedOutputType = utils::RankedTensorTypeFactory::create(outputType, adaptedShape); - MorehCumSumOp adaptedCumSumOp = - rewriter.create( - srcOp->getLoc(), adaptedOutputType, adaptedInput, srcOp.getDim(), - /*memory_config=*/nullptr); + MorehCumSumOp adaptedCumSumOp = mlir::tt::ttnn::MorehCumSumOp::create( + rewriter, srcOp->getLoc(), adaptedOutputType, adaptedInput, + srcOp.getDim(), + /*memory_config=*/nullptr); ReshapeOp cumsumOutput = ttir_to_ttnn::utils::generateReshape( adaptedCumSumOp, srcOp.getResult().getType().getShape(), rewriter, diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/DistributedRMSNormWidthShardInputRewritePattern.cpp 
b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/DistributedRMSNormWidthShardInputRewritePattern.cpp index 93557b5eb67..f1bc26d510b 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/DistributedRMSNormWidthShardInputRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/DistributedRMSNormWidthShardInputRewritePattern.cpp @@ -106,8 +106,8 @@ LogicalResult DistributedRMSNormWidthShardInputRewritePattern::matchAndRewrite( ttnn::MemoryConfigAttr::get(desiredInputLayout, grid); RankedTensorType memoryConfigedInputType = inputType.cloneWithEncoding(desiredInputLayout); - auto inputToLayoutOp = rewriter.create( - op.getLoc(), memoryConfigedInputType, op.getInput(), + auto inputToLayoutOp = ttnn::ToLayoutOp::create( + rewriter, op.getLoc(), memoryConfigedInputType, op.getInput(), tt::ttnn::Layout::Tile, ttcore::DataTypeAttr::get( rewriter.getContext(), @@ -139,8 +139,8 @@ LogicalResult DistributedRMSNormWidthShardInputRewritePattern::matchAndRewrite( reshapedShape.end()); RankedTensorType reshapedWeightType = ttnn::utils::RankedTensorTypeFactory::create(weightType, reshapedShape); - auto reshapeOp = rewriter.create( - op.getLoc(), reshapedWeightType, weight, + auto reshapeOp = ttnn::ReshapeOp::create( + rewriter, op.getLoc(), reshapedWeightType, weight, rewriter.getI32ArrayAttr(reshapedShapeI32), ttnn::MemoryConfigAttr()); weight = reshapeOp.getResult(); @@ -160,8 +160,9 @@ LogicalResult DistributedRMSNormWidthShardInputRewritePattern::matchAndRewrite( ttnn::BufferTypeAttr::get(rewriter.getContext(), weightLayout.getBufferType()), /*shardSpec=*/std::nullopt); - auto weightToLayoutOp = rewriter.create( - op.getLoc(), rowMajorWeightType, weight, tt::ttnn::Layout::RowMajor, + auto weightToLayoutOp = ttnn::ToLayoutOp::create( + rewriter, op.getLoc(), rowMajorWeightType, weight, + tt::ttnn::Layout::RowMajor, ttcore::DataTypeAttr::get( rewriter.getContext(), ttcore::elementTypeToDataType(weightElementType)), @@ -181,8 +182,9 @@ 
LogicalResult DistributedRMSNormWidthShardInputRewritePattern::matchAndRewrite( if (residualLayout != desiredInputLayout) { RankedTensorType shardedResidualType = residualType.cloneWithEncoding(desiredInputLayout); - auto residualToLayoutOp = rewriter.create( - op.getLoc(), shardedResidualType, residual, tt::ttnn::Layout::Tile, + auto residualToLayoutOp = ttnn::ToLayoutOp::create( + rewriter, op.getLoc(), shardedResidualType, residual, + tt::ttnn::Layout::Tile, ttcore::DataTypeAttr::get( rewriter.getContext(), ttcore::elementTypeToDataType(inputElementType)), @@ -243,9 +245,9 @@ LogicalResult DistributedRMSNormWidthShardInputRewritePattern::matchAndRewrite( RankedTensorType::get(statsShape, statsElementType, statsLayout); auto device = ttnn::utils::getOrInsertDevice(rewriter, op); - auto statsEmptyOp = rewriter.create( - op.getLoc(), statsResultType, device, statsShapeAttr, statsDtypeAttr, - statsLayoutAttr, statsMemConfig); + auto statsEmptyOp = ttnn::EmptyOp::create( + rewriter, op.getLoc(), statsResultType, device, statsShapeAttr, + statsDtypeAttr, statsLayoutAttr, statsMemConfig); // The fused kernel output shape == input shape (only stats are all-gathered, // not data). 
Use the input's width-sharded memory config for the output too @@ -267,9 +269,9 @@ LogicalResult DistributedRMSNormWidthShardInputRewritePattern::matchAndRewrite( ttnn::CoreCoordAttr::get(rewriter.getContext(), gridW, gridH), /*subblock_w=*/1, blockH, blockW, /*inplace=*/false); - auto newOp = rewriter.create( - op.getLoc(), shardedOutputType, inputToLayoutOp.getResult(), weight, - residual, statsEmptyOp.getResult(), op.getDevice(), + auto newOp = ttnn::DistributedRMSNormOp::create( + rewriter, op.getLoc(), shardedOutputType, inputToLayoutOp.getResult(), + weight, residual, statsEmptyOp.getResult(), op.getDevice(), static_cast(op.getClusterAxis()), op.getEpsilon(), op.getSubDeviceIdAttr(), inputMemoryConfig, op.getNumLinksAttr(), op.getTopologyAttr(), computeConfigAttr, programConfigAttr); @@ -283,8 +285,9 @@ LogicalResult DistributedRMSNormWidthShardInputRewritePattern::matchAndRewrite( if (originalOutputLayout && originalOutputLayout != desiredInputLayout) { auto originalMemConfig = ttnn::MemoryConfigAttr::get( originalOutputLayout, originalOutputLayout.getGrid()); - auto toMemConfigOp = rewriter.create( - op.getLoc(), originalOutputType, newOp.getResult(), originalMemConfig); + auto toMemConfigOp = ttnn::ToMemoryConfigOp::create( + rewriter, op.getLoc(), originalOutputType, newOp.getResult(), + originalMemConfig); rewriter.replaceOp(op, toMemConfigOp.getResult()); } else { rewriter.replaceOp(op, newOp.getResult()); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ExplicateOperandBroadcastsRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ExplicateOperandBroadcastsRewritePattern.cpp index 59c3116ce2d..f7a57b12a19 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ExplicateOperandBroadcastsRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ExplicateOperandBroadcastsRewritePattern.cpp @@ -32,8 +32,8 @@ LogicalResult ExplicateOperandBroadcastsRewritePattern::matchAndRewrite( auto 
broadcastDims = ttmlir::utils::getBroadcastDimensions( operandShape, resultShape); auto shapeAttr = ttnn::ShapeAttr::get(rewriter.getContext(), broadcastDims); - auto repeatOp = rewriter.create( - srcOp->getLoc(), newOutputType, operand, shapeAttr); + auto repeatOp = ttnn::RepeatOp::create(rewriter, srcOp->getLoc(), + newOutputType, operand, shapeAttr); rewriter.modifyOpInPlace(srcOp, [&]() { srcOp->setOperand(i, repeatOp); }); hasChanged = true; diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/LinearOpOutputShapeRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/LinearOpOutputShapeRewritePattern.cpp index 0af31cfe55b..8f0537bf678 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/LinearOpOutputShapeRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/LinearOpOutputShapeRewritePattern.cpp @@ -105,8 +105,8 @@ LogicalResult LinearOpOutputShapeRewritePattern::matchAndRewrite( auto matmulOutputType = utils::RankedTensorTypeFactory::create(currentOutputType, matmulShape); - auto newLinearOp = rewriter.create( - srcOp.getLoc(), matmulOutputType, srcOp.getA(), srcOp.getB(), + auto newLinearOp = ttnn::LinearOp::create( + rewriter, srcOp.getLoc(), matmulOutputType, srcOp.getA(), srcOp.getB(), srcOp.getBias(), srcOp.getTransposeA(), srcOp.getTransposeB(), /*matmul_program_config=*/nullptr, srcOp.getActivationAttr(), /*compute_config=*/srcOp.getComputeConfigAttr()); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/LinearOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/LinearOpRewritePattern.cpp index f5225c07d86..dad07d98435 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/LinearOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/LinearOpRewritePattern.cpp @@ -163,7 +163,8 @@ LinearOpRewritePattern::matchAndRewrite(ttnn::LinearOp srcOp, auto dataTypeAttr = mlir::tt::ttcore::DataTypeAttr::get( 
rewriter.getContext(), outputEncoding.getDataType()); - MatmulOp matmulOp = rewriter.create( + MatmulOp matmulOp = ttnn::MatmulOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_decomp_matmul"), matmulOutputType, srcOp.getA(), srcOp.getB(), srcOp.getTransposeA(), srcOp.getTransposeB(), /*matmul_program_config=*/nullptr, @@ -177,7 +178,8 @@ LinearOpRewritePattern::matchAndRewrite(ttnn::LinearOp srcOp, addShape); auto addOutputType = utils::RankedTensorTypeFactory::create(outputType, addShape); - AddOp addOp = rewriter.create( + AddOp addOp = ttnn::AddOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_decomp_add"), addOutputType, matmulOp.getResult(), srcOp.getBias(), /*dtype=*/dataTypeAttr, diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/MultiplyOpDecompositionRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/MultiplyOpDecompositionRewritePattern.cpp index f4e5749f344..cda64ab527a 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/MultiplyOpDecompositionRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/MultiplyOpDecompositionRewritePattern.cpp @@ -57,9 +57,10 @@ LogicalResult MultiplyOpDecompositionRewritePattern::matchAndRewrite( RankedTensorType lhsPermutedType = utils::RankedTensorTypeFactory::create(lhsType, lhsPermutedShape); - PermuteOp lhsPermuted = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_lhs_permute"), lhsPermutedType, - lhs, permutationAttr, ttnn::MemoryConfigAttr(), mlir::FloatAttr()); + PermuteOp lhsPermuted = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_lhs_permute"), + lhsPermutedType, lhs, permutationAttr, ttnn::MemoryConfigAttr(), + mlir::FloatAttr()); // Apply permutation to rhs input llvm::SmallVector rhsPermutedShape = @@ -67,9 +68,10 @@ LogicalResult MultiplyOpDecompositionRewritePattern::matchAndRewrite( RankedTensorType rhsPermutedType = 
utils::RankedTensorTypeFactory::create(rhsType, rhsPermutedShape); - PermuteOp rhsPermuted = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_rhs_permute"), rhsPermutedType, - rhs, permutationAttr, ttnn::MemoryConfigAttr(), mlir::FloatAttr()); + PermuteOp rhsPermuted = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_rhs_permute"), + rhsPermutedType, rhs, permutationAttr, ttnn::MemoryConfigAttr(), + mlir::FloatAttr()); // Create the multiply operation on permuted inputs llvm::SmallVector permutedOutputShape = @@ -77,8 +79,8 @@ LogicalResult MultiplyOpDecompositionRewritePattern::matchAndRewrite( RankedTensorType permutedOutputType = utils::RankedTensorTypeFactory::create(outputType, permutedOutputShape); - MultiplyOp permutedMultiply = rewriter.create( - loc, permutedOutputType, lhsPermuted.getResult(), + MultiplyOp permutedMultiply = ttnn::MultiplyOp::create( + rewriter, loc, permutedOutputType, lhsPermuted.getResult(), rhsPermuted.getResult()); // Apply reverse permutation to output (which is the same as forward: (2, 3, diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/PagedUpdateCacheOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/PagedUpdateCacheOpRewritePattern.cpp index ac18f89045c..7e20807d20f 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/PagedUpdateCacheOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/PagedUpdateCacheOpRewritePattern.cpp @@ -71,8 +71,8 @@ LogicalResult PagedUpdateCacheOpRewritePattern::matchAndRewrite( ttnn::MemoryConfigAttr::get(desiredInputLayout, grid); RankedTensorType memoryConfigedInputType = inputType.cloneWithEncoding(desiredInputLayout); - auto toLayoutOp = rewriter.create( - op.getLoc(), memoryConfigedInputType, op.getInput(), + auto toLayoutOp = ttnn::ToLayoutOp::create( + rewriter, op.getLoc(), memoryConfigedInputType, op.getInput(), tt::ttnn::Layout::Tile, ttcore::DataTypeAttr::get( 
rewriter.getContext(), @@ -81,9 +81,9 @@ LogicalResult PagedUpdateCacheOpRewritePattern::matchAndRewrite( // Replace the original PagedUpdateCacheOp with one which takes our properly // configured input tensor. - auto pagedUpdateCacheOp = rewriter.create( - op.getLoc(), op.getCache(), toLayoutOp.getResult(), op.getUpdateIndex(), - op.getShareCache(), op.getPageTable()); + auto pagedUpdateCacheOp = ttnn::PagedUpdateCacheOp::create( + rewriter, op.getLoc(), op.getCache(), toLayoutOp.getResult(), + op.getUpdateIndex(), op.getShareCache(), op.getPageTable()); rewriter.replaceOp(op, pagedUpdateCacheOp); return success(); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/PointToPointOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/PointToPointOpRewritePattern.cpp index 158a0cc79ff..ef6bb86e45f 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/PointToPointOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/PointToPointOpRewritePattern.cpp @@ -35,13 +35,15 @@ PointToPointOpRewritePattern::matchAndRewrite(ttnn::PointToPointOp srcOp, receiverCoord[1]}; ::llvm::ArrayRef intermediateCoord = intermediateCoordVec; - auto p2pOp1 = rewriter.create( + auto p2pOp1 = ttnn::PointToPointOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_p2p_to_intermediate"), srcOp.getResult().getType(), srcOp.getInput(), senderCoord, intermediateCoord, /*optional_output_tensor=*/nullptr); auto optionalOutputTensor = srcOp.getOptionalOutputTensor(); - auto p2pOp2 = rewriter.create( + auto p2pOp2 = ttnn::PointToPointOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_p2p_from_intermediate"), srcOp.getResult().getType(), p2pOp1.getResult(), intermediateCoord, diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/RMSNormConfigRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/RMSNormConfigRewritePattern.cpp index 
91bb69263e8..2ce9c0dc7cc 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/RMSNormConfigRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/RMSNormConfigRewritePattern.cpp @@ -36,8 +36,8 @@ RMSNormConfigRewritePattern::matchAndRewrite(RMSNormOp srcOp, /*dstFullSyncEn=*/nullptr); // Create a new operation with the compute config set - auto newOp = rewriter.create( - srcOp.getLoc(), srcOp.getResult().getType(), srcOp.getInput(), + auto newOp = RMSNormOp::create( + rewriter, srcOp.getLoc(), srcOp.getResult().getType(), srcOp.getInput(), srcOp.getWeight(), srcOp.getBias(), srcOp.getEpsilonAttr(), srcOp.getMemoryConfigAttr(), computeConfig); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceScatterOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceScatterOpRewritePattern.cpp index 76053ca4b4f..eb7c72c1e0e 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceScatterOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceScatterOpRewritePattern.cpp @@ -49,7 +49,8 @@ TTNNReduceScatterWorkarounds::matchAndRewrite(ttnn::ReduceScatterOp op, paddedInputShape.end()); RankedTensorType reshapeInputType = ttnn::utils::RankedTensorTypeFactory::create(inputType, paddedInputShape); - auto reshapeInput = rewriter.create( + auto reshapeInput = ttnn::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_reshape_to_4d"), reshapeInputType, op.getInput(), rewriter.getI32ArrayAttr(paddedShapeI32), ttnn::MemoryConfigAttr()); @@ -61,7 +62,8 @@ TTNNReduceScatterWorkarounds::matchAndRewrite(ttnn::ReduceScatterOp op, // Create the reduce scatter operation on 4D tensors with adjusted // scatter_dim - auto reduceScatter4D = rewriter.create( + auto reduceScatter4D = ttnn::ReduceScatterOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(op.getLoc(), "_reduce_scatter_4d"), paddedOutputType, reshapeInput.getResult(), 
op.getReduceType(), adjustedScatterDim, op.getClusterAxis(), diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/RotaryEmbeddingOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/RotaryEmbeddingOpRewritePattern.cpp index 465ac03acdc..16c507986b5 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/RotaryEmbeddingOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/RotaryEmbeddingOpRewritePattern.cpp @@ -32,9 +32,9 @@ getWorkaroundedOp(RotaryEmbeddingOp ropeOp, PatternRewriter &rewriter) { auto paddedType = utils::RankedTensorTypeFactory::create(resultType, paddedResultShape); - auto paddedOp = rewriter.create( - ropeOp.getLoc(), paddedType, ropeOp.getInput(), ropeOp.getCosCache(), - ropeOp.getSinCache(), ropeOp.getTokenIndexAttr(), + auto paddedOp = RotaryEmbeddingOp::create( + rewriter, ropeOp.getLoc(), paddedType, ropeOp.getInput(), + ropeOp.getCosCache(), ropeOp.getSinCache(), ropeOp.getTokenIndexAttr(), ropeOp.getMemoryConfigAttr(), ropeOp.getComputeConfigAttr()); // Slice to original shape. 
@@ -43,8 +43,8 @@ getWorkaroundedOp(RotaryEmbeddingOp ropeOp, PatternRewriter &rewriter) { SmallVector steps(resultShape.size(), 1); ends[ends.size() - 2] = originalSeqLen; - auto sliceOp = rewriter.create( - ropeOp.getLoc(), resultType, paddedOp.getResult(), + auto sliceOp = ttnn::SliceStaticOp::create( + rewriter, ropeOp.getLoc(), resultType, paddedOp.getResult(), rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(ends), rewriter.getI32ArrayAttr(steps)); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ScaledDotProductAttentionPadTileDimsRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ScaledDotProductAttentionPadTileDimsRewritePattern.cpp index 13eae2b1f93..272068ac894 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ScaledDotProductAttentionPadTileDimsRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ScaledDotProductAttentionPadTileDimsRewritePattern.cpp @@ -32,12 +32,11 @@ Value padDimension(Value tensor, int64_t targetLen, int64_t dim, auto paddedType = utils::RankedTensorTypeFactory::create(tensorType, paddedShape); - return rewriter - .create(loc, paddedType, tensor, - rewriter.getDenseI32ArrayAttr(padding), - rewriter.getF32FloatAttr(padValue), - /*use_multicore=*/rewriter.getBoolAttr(true), - /*memory_config=*/nullptr) + return PadOp::create(rewriter, loc, paddedType, tensor, + rewriter.getDenseI32ArrayAttr(padding), + rewriter.getF32FloatAttr(padValue), + /*use_multicore=*/rewriter.getBoolAttr(true), + /*memory_config=*/nullptr) .getResult(); } @@ -62,10 +61,11 @@ Value sliceDimension(Value tensor, int64_t originalLen, int64_t dim, auto slicedType = utils::RankedTensorTypeFactory::create(tensorType, slicedShape); - return rewriter - .create( - loc, slicedType, tensor, rewriter.getI32ArrayAttr(begins), - rewriter.getI32ArrayAttr(ends), rewriter.getI32ArrayAttr(steps)) + return SliceStaticOp::create( + rewriter, + + loc, slicedType, tensor, 
rewriter.getI32ArrayAttr(begins), + rewriter.getI32ArrayAttr(ends), rewriter.getI32ArrayAttr(steps)) .getResult(); } @@ -157,8 +157,8 @@ ScaledDotProductAttentionPadTileDimsRewritePattern::matchAndRewrite( } auto resultType = paddedQuery.getType(); - auto sdpaOp = rewriter.create( - srcOp.getLoc(), resultType, paddedQuery, paddedKey, paddedValue, + auto sdpaOp = ScaledDotProductAttentionOp::create( + rewriter, srcOp.getLoc(), resultType, paddedQuery, paddedKey, paddedValue, paddedMask, srcOp.getIsCausal(), srcOp.getScaleAttr(), srcOp.getSlidingWindowSizeAttr(), srcOp.getMemoryConfigAttr()); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ScatterOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ScatterOpRewritePattern.cpp index c381dba1c3c..28bf3546f18 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ScatterOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ScatterOpRewritePattern.cpp @@ -82,7 +82,8 @@ TTNNScatterWorkarounds::matchAndRewrite(ttnn::ScatterOp op, // Slice index tensor for this chunk. 
RankedTensorType chunkIndexType = ttnn::utils::RankedTensorTypeFactory::create(indexType, chunkShape); - auto chunkIndex = rewriter.create( + auto chunkIndex = ttnn::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix( op.getLoc(), "_chunk_" + std::to_string(chunkIdx) + "_index"), chunkIndexType, op.getIndex(), rewriter.getI32ArrayAttr(begins), @@ -96,14 +97,16 @@ TTNNScatterWorkarounds::matchAndRewrite(ttnn::ScatterOp op, RankedTensorType chunkSourceType = ttnn::utils::RankedTensorTypeFactory::create(sourceType, chunkSourceShape); - auto chunkSource = rewriter.create( + auto chunkSource = ttnn::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix( op.getLoc(), "_chunk_" + std::to_string(chunkIdx) + "_source"), chunkSourceType, op.getSource(), rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(ends), rewriter.getI32ArrayAttr(steps)); // Perform scatter operation for this chunk. - auto chunkScatter = rewriter.create( + auto chunkScatter = ttnn::ScatterOp::create( + rewriter, ttmlir::utils::appendLocationSuffix( op.getLoc(), "_chunk_" + std::to_string(chunkIdx) + "_scatter"), currentResult.getType(), currentResult, chunkIndex.getResult(), diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/SplitQueryKeyValueAndSplitHeadsOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/SplitQueryKeyValueAndSplitHeadsOpRewritePattern.cpp index 808f6a52ae0..d8456ed6e20 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/SplitQueryKeyValueAndSplitHeadsOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/SplitQueryKeyValueAndSplitHeadsOpRewritePattern.cpp @@ -69,7 +69,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( RankedTensorType reshapedQueryType = utils::RankedTensorTypeFactory::create(inputType, reshapedQueryShape); - auto reshapeQuery = rewriter.create( + auto reshapeQuery = ttnn::ReshapeOp::create( + rewriter, 
ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_reshape_query"), reshapedQueryType, srcOp.getInputTensor(), reshapedQueryShapeAttr, ttnn::MemoryConfigAttr()); @@ -83,7 +84,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( llvm::ArrayRef(reshapedQueryShape), permutation); RankedTensorType queryOutputType = utils::RankedTensorTypeFactory::create(queryType, permutedQueryShape); - auto permuteQ = rewriter.create( + auto permuteQ = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_permute_query"), queryOutputType, reshapeQuery.getResult(), permutationAttr, ttnn::MemoryConfigAttr(), mlir::FloatAttr()); @@ -105,7 +107,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( mlir::cast(srcOp.getKvInputTensor().getType()), kvIntermediateShape); - auto sliceK = rewriter.create( + auto sliceK = ttnn::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_slice_k"), kvIntermediateType, srcOp.getKvInputTensor(), rewriter.getI32ArrayAttr(beginsK), rewriter.getI32ArrayAttr(endsK), @@ -116,7 +119,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( SmallVector endsV = {static_cast(batchSize), static_cast(sequenceSize), static_cast(2 * kvHiddenSize)}; - auto sliceV = rewriter.create( + auto sliceV = ttnn::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_slice_v"), kvIntermediateType, srcOp.getKvInputTensor(), rewriter.getI32ArrayAttr(beginsV), rewriter.getI32ArrayAttr(endsV), @@ -133,7 +137,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( rewriter.getI32ArrayAttr(reshapedKShapeI32); RankedTensorType reshapedKType = utils::RankedTensorTypeFactory::create( kvIntermediateType, reshapedKShape); - auto reshapeK = rewriter.create( + auto reshapeK = ttnn::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), 
"_reshape_k"), reshapedKType, sliceK.getResult(), reshapedKShapeAttr, ttnn::MemoryConfigAttr()); @@ -144,7 +149,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( llvm::ArrayRef(reshapedKShape), permutation); RankedTensorType keyOutputType = utils::RankedTensorTypeFactory::create( srcOp.getKey().getType(), permutedKShape); - auto permuteK = rewriter.create( + auto permuteK = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_permute_k"), keyOutputType, reshapeK.getResult(), permutationAttr, ttnn::MemoryConfigAttr(), mlir::FloatAttr()); @@ -160,7 +166,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( rewriter.getI32ArrayAttr(reshapedVShapeI32); RankedTensorType reshapedVType = utils::RankedTensorTypeFactory::create( kvIntermediateType, reshapedVShape); - auto reshapeV = rewriter.create( + auto reshapeV = ttnn::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_reshape_v"), reshapedVType, sliceV.getResult(), reshapedVShapeAttr, ttnn::MemoryConfigAttr()); @@ -171,7 +178,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( llvm::ArrayRef(reshapedVShape), permutation); RankedTensorType valueOutputType = utils::RankedTensorTypeFactory::create( srcOp.getValue().getType(), permutedVShape); - auto permuteV = rewriter.create( + auto permuteV = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_permute_v"), valueOutputType, reshapeV.getResult(), permutationAttr, ttnn::MemoryConfigAttr(), mlir::FloatAttr()); @@ -187,7 +195,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( llvm::ArrayRef(permutedKShape), transposePermutation); RankedTensorType transposedKType = utils::RankedTensorTypeFactory::create( srcOp.getKey().getType(), transposedKShape); - auto transposeK = rewriter.create( + auto transposeK = ttnn::PermuteOp::create( + 
rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_transpose_k"), transposedKType, permuteK.getResult(), transposePermutationAttr, ttnn::MemoryConfigAttr(), mlir::FloatAttr()); @@ -227,8 +236,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( utils::RankedTensorTypeFactory::create(queryType, qkvIntermediateShape); // Slice for Q - auto sliceQ = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_split_q"), + auto sliceQ = ttnn::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_split_q"), qkvIntermediateType, input, rewriter.getI32ArrayAttr(begins_q), rewriter.getI32ArrayAttr(ends_q), rewriter.getI32ArrayAttr(step)); @@ -237,8 +246,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( SmallVector ends_k = {static_cast(batchSize), static_cast(sequenceSize), static_cast(hiddenSize * 2)}; - auto sliceK = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_split_k"), + auto sliceK = ttnn::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_split_k"), qkvIntermediateType, input, rewriter.getI32ArrayAttr(begins_k), rewriter.getI32ArrayAttr(ends_k), rewriter.getI32ArrayAttr(step)); @@ -248,8 +257,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( SmallVector ends_v = {static_cast(batchSize), static_cast(sequenceSize), static_cast(hiddenSize * 3)}; - auto sliceV = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_split_v"), + auto sliceV = ttnn::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_split_v"), qkvIntermediateType, input, rewriter.getI32ArrayAttr(begins_v), rewriter.getI32ArrayAttr(ends_v), rewriter.getI32ArrayAttr(step)); @@ -265,17 +274,20 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( RankedTensorType reshapedType = utils::RankedTensorTypeFactory::create(queryType, reshapedShape); - auto reshapeQ = 
rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_reshape_q"), reshapedType, - sliceQ.getResult(), reshapedShapeAttr, ttnn::MemoryConfigAttr()); + auto reshapeQ = ttnn::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_reshape_q"), + reshapedType, sliceQ.getResult(), reshapedShapeAttr, + ttnn::MemoryConfigAttr()); - auto reshapeK = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_reshape_k"), reshapedType, - sliceK.getResult(), reshapedShapeAttr, ttnn::MemoryConfigAttr()); + auto reshapeK = ttnn::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_reshape_k"), + reshapedType, sliceK.getResult(), reshapedShapeAttr, + ttnn::MemoryConfigAttr()); - auto reshapeV = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_reshape_v"), reshapedType, - sliceV.getResult(), reshapedShapeAttr, ttnn::MemoryConfigAttr()); + auto reshapeV = ttnn::ReshapeOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_reshape_v"), + reshapedType, sliceV.getResult(), reshapedShapeAttr, + ttnn::MemoryConfigAttr()); // Step 3: Permute from [batch, seq, num_heads, head_size] to // [batch, num_heads, seq, head_size]. 
@@ -287,20 +299,20 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( RankedTensorType queryOutputType = utils::RankedTensorTypeFactory::create(queryType, permutedShape); - auto permuteQ = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_permute_q"), queryOutputType, - reshapeQ.getResult(), permutationAttr, ttnn::MemoryConfigAttr(), - mlir::FloatAttr()); + auto permuteQ = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_permute_q"), + queryOutputType, reshapeQ.getResult(), permutationAttr, + ttnn::MemoryConfigAttr(), mlir::FloatAttr()); - auto permuteK = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_permute_k"), queryOutputType, - reshapeK.getResult(), permutationAttr, ttnn::MemoryConfigAttr(), - mlir::FloatAttr()); + auto permuteK = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_permute_k"), + queryOutputType, reshapeK.getResult(), permutationAttr, + ttnn::MemoryConfigAttr(), mlir::FloatAttr()); - auto permuteV = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_permute_v"), queryOutputType, - reshapeV.getResult(), permutationAttr, ttnn::MemoryConfigAttr(), - mlir::FloatAttr()); + auto permuteV = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_permute_v"), + queryOutputType, reshapeV.getResult(), permutationAttr, + ttnn::MemoryConfigAttr(), mlir::FloatAttr()); // Step 4: If transpose_key is true, additionally permute K // from [batch, num_heads, seq, head_size] to [batch, num_heads, head_size, @@ -315,8 +327,8 @@ LogicalResult SplitQueryKeyValueAndSplitHeadsOpRewritePattern::matchAndRewrite( RankedTensorType keyOutputType = utils::RankedTensorTypeFactory::create( srcOp.getKey().getType(), transposedShape); - auto transposeK = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_transpose_k"), + auto transposeK = ttnn::PermuteOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, 
"_transpose_k"), keyOutputType, permuteK.getResult(), transposePermutationAttr, ttnn::MemoryConfigAttr(), mlir::FloatAttr()); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/UpsampleOpRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/UpsampleOpRewritePattern.cpp index 6c374f17235..65a41b586fe 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/UpsampleOpRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/UpsampleOpRewritePattern.cpp @@ -63,7 +63,8 @@ LogicalResult UpsampleOpBilinearPaddingRewritePattern::matchAndRewrite( mlir::cast(inputType.getEncoding()) .withTensorShape(paddedShape)); - auto padOp = rewriter.create( + auto padOp = ttnn::PadOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getInput().getLoc(), "pad"), paddedType, srcOp.getInput(), padding, /*pad_value=*/mlir::APFloat(0.0f), /*use_multicore=*/false, @@ -78,18 +79,19 @@ LogicalResult UpsampleOpBilinearPaddingRewritePattern::matchAndRewrite( mlir::cast(outputType.getEncoding()) .withTensorShape(upsamplePaddedShape)); - auto paddedUpsampleOp = rewriter.create( - srcOp.getLoc(), upsampledPaddedType, padOp, srcOp.getScaleFactorAttr(), - srcOp.getModeAttr(), /*memory_config=*/nullptr); + auto paddedUpsampleOp = + ttnn::UpsampleOp::create(rewriter, srcOp.getLoc(), upsampledPaddedType, + padOp, srcOp.getScaleFactorAttr(), + srcOp.getModeAttr(), /*memory_config=*/nullptr); // Create SliceStaticOp to remove padding from the upsampled result. 
SmallVector begins(/*size=*/DIM_COUNT, /*value=*/0); SmallVector ends(outputType.getShape()); SmallVector steps(/*size=*/DIM_COUNT, /*value=*/1); - auto sliceOp = rewriter.create( - ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "slice"), outputType, - paddedUpsampleOp, rewriter.getI32ArrayAttr(begins), + auto sliceOp = ttnn::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "slice"), + outputType, paddedUpsampleOp, rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(ends), rewriter.getI32ArrayAttr(steps)); rewriter.replaceOp(srcOp, sliceOp); diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkaroundsPatterns.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkaroundsPatterns.cpp index 388e1d10ac9..f3a1920f9b0 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkaroundsPatterns.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkaroundsPatterns.cpp @@ -443,7 +443,8 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern { auto paddedType = ttnn::utils::RankedTensorTypeFactory::create(inputType, paddedShape); - reduceScatterInput = rewriter.create( + reduceScatterInput = ttnn::PadOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_pad_for_reduce_scatter"), paddedType, op.getInput(), padding, /*pad_value=*/mlir::APFloat(0.0f), /*use_multicore=*/false, @@ -465,18 +466,16 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern { reduceScatterInputType, reduceScatterShape); // Create a new reducer scatter op. 
- ttnn::ReduceScatterOp reduceScatterOp = - rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_reduce_scatter"), - reduceScatterOutputType, reduceScatterInput, op.getReduceType(), - selectedDim, clusterAxis, nullptr, nullptr, nullptr, nullptr, - nullptr); + ttnn::ReduceScatterOp reduceScatterOp = ttnn::ReduceScatterOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_reduce_scatter"), + reduceScatterOutputType, reduceScatterInput, op.getReduceType(), + selectedDim, clusterAxis, nullptr, nullptr, nullptr, nullptr, nullptr); // all_gather restores the reduce_scatter input shape. auto allGatherOutputType = ttnn::utils::RankedTensorTypeFactory::create( reduceScatterInputType, reduceScatterInputType.getShape()); - ttnn::AllGatherOp allGatherOp = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_all_gather"), + ttnn::AllGatherOp allGatherOp = ttnn::AllGatherOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_all_gather"), allGatherOutputType, reduceScatterOp.getResult(), selectedDim, clusterAxis, nullptr /*sub_device_id*/, nullptr /*memory_config*/, nullptr /*num_links*/, nullptr /*topology*/); @@ -486,7 +485,8 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern { llvm::SmallVector begins(inputShape.size(), 0); llvm::SmallVector ends(inputShape.begin(), inputShape.end()); llvm::SmallVector steps(inputShape.size(), 1); - auto sliceOp = rewriter.create( + auto sliceOp = ttnn::SliceStaticOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_slice_for_reduce_scatter_pad"), op.getType(), allGatherOp.getResult(), @@ -520,9 +520,9 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern { ttnn::utils::RankedTensorTypeFactory::create(inputType, expandedInputShape); - ttnn::ReshapeOp leadingReshapeOp = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_reshape"), reshapedInputType, - op.getInput(), reshapedInputShapeAttr, + ttnn::ReshapeOp leadingReshapeOp = ttnn::ReshapeOp::create( + rewriter, 
ttmlir::utils::appendLocationSuffix(loc, "_reshape"), + reshapedInputType, op.getInput(), reshapedInputShapeAttr, /* memory_config */ nullptr); // Create a new all gather op. @@ -530,8 +530,8 @@ class TTNNAllReduceWorkarounds : public OpRewritePattern { RankedTensorType allGatherOutputType = ttnn::utils::RankedTensorTypeFactory::create(reshapedInputType, expandedInputShape); - ttnn::AllGatherOp allGatherOp = rewriter.create( - ttmlir::utils::appendLocationSuffix(loc, "_allGather"), + ttnn::AllGatherOp allGatherOp = ttnn::AllGatherOp::create( + rewriter, ttmlir::utils::appendLocationSuffix(loc, "_allGather"), allGatherOutputType, leadingReshapeOp.getResult(), 0, clusterAxis, nullptr /*sub_device_id*/, nullptr /*memory_config*/, nullptr /*num_links*/, nullptr /*topology*/); diff --git a/lib/Dialect/TTNN/Utils/PassOverrides.cpp b/lib/Dialect/TTNN/Utils/PassOverrides.cpp index 6d31164eb93..60b58b713b7 100644 --- a/lib/Dialect/TTNN/Utils/PassOverrides.cpp +++ b/lib/Dialect/TTNN/Utils/PassOverrides.cpp @@ -8,6 +8,7 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" #include "llvm/Support/raw_ostream.h" #include diff --git a/lib/Dialect/TTNN/Utils/TransformUtils.cpp b/lib/Dialect/TTNN/Utils/TransformUtils.cpp index eadfd81e9ca..f484d8e7a70 100644 --- a/lib/Dialect/TTNN/Utils/TransformUtils.cpp +++ b/lib/Dialect/TTNN/Utils/TransformUtils.cpp @@ -21,8 +21,8 @@ static GetDeviceOp insertGetDeviceOp(RewriterBase &rewriter, // TODO (jnie): Currently hardcoding the mesh offset to 0x0 // Need a proper plan to dynamically determine this. 
llvm::SmallVector meshOffset{0, 0}; - return rewriter.create( - loc, rewriter.getType(), + return ttnn::GetDeviceOp::create( + rewriter, loc, rewriter.getType(), ttnn::MeshShapeAttr::get(rewriter.getContext(), meshShape[0], meshShape[1]), ttnn::MeshOffsetAttr::get(rewriter.getContext(), meshOffset[0], @@ -118,8 +118,8 @@ ToLayoutOp createToLayoutOp(Operation *op, Location loc = ttmlir::utils::appendLocationSuffix(op->getLoc(), locSuffix); // Create a ToLayoutOp to convert the input operand to the desired // tensor layout, buffer type and memory layout. - return rewriter.create( - loc, toLayoutOpResultType, inputValue, + return ttnn::ToLayoutOp::create( + rewriter, loc, toLayoutOpResultType, inputValue, LayoutAttr::get(rewriter.getContext(), targetTensorLayout), ttcore::DataTypeAttr::get(rewriter.getContext(), targetTensorDataType), outputMemConfigAttr); diff --git a/lib/RegisterAll.cpp b/lib/RegisterAll.cpp index 5dfb9410648..36f891cbd38 100644 --- a/lib/RegisterAll.cpp +++ b/lib/RegisterAll.cpp @@ -57,8 +57,6 @@ #include "mlir/Transforms/Passes.h" #if TTMLIR_ENABLE_STABLEHLO -#include "shardy/dialect/mpmd/ir/register.h" -#include "shardy/dialect/mpmd/transforms/passes.h" #include "shardy/dialect/sdy/ir/register.h" #include "shardy/dialect/sdy/transforms/passes.h" #include "stablehlo/dialect/Register.h" @@ -100,7 +98,6 @@ void mlir::tt::registerAllDialects(mlir::DialectRegistry ®istry) { #if TTMLIR_ENABLE_STABLEHLO mlir::stablehlo::registerAllDialects(registry); mlir::sdy::registerAllDialects(registry); - mlir::mpmd::registerAllDialects(registry); #endif } diff --git a/lib/Target/TTKernel/TTKernelToCpp.cpp b/lib/Target/TTKernel/TTKernelToCpp.cpp index ffb74d4cadb..e23f80eb95c 100644 --- a/lib/Target/TTKernel/TTKernelToCpp.cpp +++ b/lib/Target/TTKernel/TTKernelToCpp.cpp @@ -44,138 +44,146 @@ class ScopedModuleHelper { if (!originalSymbolName.empty()) { emitComment(originalSymbolName); } - builder->create(loc, "cstdint", - /*isStandard=*/true); + 
emitc::IncludeOp::create(*builder, loc, "cstdint", + /*isStandard=*/true); - builder->create(loc, "tools/profiler/kernel_profiler.hpp", - /*isStandard=*/false); - builder->create(loc, "internal/firmware_common.h", - /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "tools/profiler/kernel_profiler.hpp", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "internal/firmware_common.h", + /*isStandard=*/false); if (threadType == ThreadType::Noc) { - builder->create(loc, "api/dataflow/dataflow_api.h", - /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/dataflow/dataflow_api.h", + /*isStandard=*/false); emitExperimentalLLKs(); emitDebugPrint(threadType); } if (threadType == ThreadType::Compute) { - builder->create(loc, "llk_defs.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/binary_max_min.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/common.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/matmul.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/bcast.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/tilize.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/untilize.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/transpose_wh.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/eltwise_binary.h", - /*isStandard=*/false); - builder->create(loc, - "api/compute/eltwise_binary_sfpu.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/add_int_sfpu.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/sub_int_sfpu.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/mul_int_sfpu.h", - /*isStandard=*/false); - builder->create( - loc, "api/compute/compute_kernel_api.h", // max ops - /*isStandard=*/false); - builder->create(loc, "api/compute/copy_dest_values.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/tile_move_copy.h", - /*isStandard=*/false); - 
builder->create( - loc, "api/compute/eltwise_unary/activations.h", - /*isStandard=*/false); - builder->create( - loc, "api/compute/eltwise_unary/eltwise_unary.h", - /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "llk_defs.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/binary_max_min.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/common.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/matmul.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/bcast.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/tilize.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/untilize.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/transpose_wh.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/eltwise_binary.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_binary_sfpu.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/add_int_sfpu.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/sub_int_sfpu.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/mul_int_sfpu.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/compute_kernel_api.h", // max ops + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/copy_dest_values.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/tile_move_copy.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/activations.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/eltwise_unary.h", + /*isStandard=*/false); // TODO (kmitrovic) exp.h is an 
ExpOp-specific include. Every op has one, // should be handled in general, not like this. // Issue: https://github.com/tenstorrent/tt-mlir/issues/772 - builder->create(loc, "api/compute/eltwise_unary/exp.h", - /*isStandard=*/false); - builder->create( - loc, "api/compute/eltwise_unary/sfpu_split_includes.h", - /*isStandard=*/false); - builder->create(loc, - "api/compute/eltwise_unary/recip.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/eltwise_unary/fill.h", - /*isStandard=*/false); - builder->create(loc, - "api/compute/eltwise_unary/negative.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/eltwise_unary/sqrt.h", - /*isStandard=*/false); - builder->create(loc, - "api/compute/eltwise_unary/rounding.h", - /*isStandard=*/false); - builder->create( - loc, "api/compute/eltwise_unary/trigonometry.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/eltwise_unary/gelu.h", - /*isStandard=*/false); - builder->create(loc, - "api/compute/eltwise_unary/erf_erfc.h", - /*isStandard=*/false); - builder->create( - loc, "api/compute/eltwise_unary/logical_not.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/eltwise_unary/comp.h", - /*isStandard=*/false); - builder->create(loc, - "api/compute/eltwise_unary/rsqrt.h", - /*isStandard=*/false); - builder->create(loc, - "api/compute/eltwise_unary/typecast.h", - /*isStandard=*/false); - builder->create(loc, - "api/compute/binary_bitwise_sfpu.h", - /*isStandard=*/false); - builder->create( - loc, "api/compute/eltwise_unary/bitwise_not.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/eltwise_unary/relu.h", - /*isStandard=*/false); - builder->create( - loc, "api/compute/eltwise_unary/binop_with_scalar.h", + emitc::IncludeOp::create(*builder, loc, "api/compute/eltwise_unary/exp.h", + /*isStandard=*/false); + emitc::IncludeOp::create( + *builder, loc, "api/compute/eltwise_unary/sfpu_split_includes.h", /*isStandard=*/false); - builder->create(loc, - 
"api/compute/eltwise_unary/where.h", - /*isStandard=*/false); - builder->create(loc, - "api/compute/eltwise_unary/clamp.h", - /*isStandard=*/false); - builder->create(loc, "api/compute/pack_untilize.h", - /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/recip.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/fill.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/negative.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/sqrt.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/rounding.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/trigonometry.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/gelu.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/erf_erfc.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/logical_not.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/comp.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/rsqrt.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/typecast.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/binary_bitwise_sfpu.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/bitwise_not.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/relu.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/binop_with_scalar.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, 
loc, + "api/compute/eltwise_unary/where.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, + "api/compute/eltwise_unary/clamp.h", + /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/compute/pack_untilize.h", + /*isStandard=*/false); // Helper for float-to-uint32 bit reinterpretation (used by scalar tile // ops). - builder->create( - loc, "inline uint32_t float_to_bits(float f) { " - "uint32_t r; __builtin_memcpy(&r, &f, sizeof(r)); return r; }"); + emitc::VerbatimOp::create( + *builder, loc, + "inline uint32_t float_to_bits(float f) { " + "uint32_t r; __builtin_memcpy(&r, &f, sizeof(r)); return r; }"); // Define INFINITY if not available (needed for OOB masking with inf // fill). - builder->create( - loc, "#ifndef INFINITY\n#define INFINITY __builtin_inff()\n#endif"); + emitc::VerbatimOp::create( + *builder, loc, + "#ifndef INFINITY\n#define INFINITY __builtin_inff()\n#endif"); // Must define macros REDUCE_OP and REDUCE_DIM before including reduce.h // because they are default template parameters values in reduce api. 
- builder->create(loc, - "#define REDUCE_OP PoolType::SUM"); - builder->create( - loc, "#define REDUCE_DIM ReduceDim::REDUCE_COL"); - builder->create(loc, "api/compute/reduce.h", - /*isStandard=*/false); + emitc::VerbatimOp::create(*builder, loc, + "#define REDUCE_OP PoolType::SUM"); + emitc::VerbatimOp::create(*builder, loc, + "#define REDUCE_DIM ReduceDim::REDUCE_COL"); + emitc::IncludeOp::create(*builder, loc, "api/compute/reduce.h", + /*isStandard=*/false); emitExperimentalLLKs(); emitDebugPrint(threadType); } @@ -184,7 +192,7 @@ class ScopedModuleHelper { ~ScopedModuleHelper() = default; void emitComment(StringRef str) { - builder->create(loc, (Twine("// ") + str).str()); + emitc::VerbatimOp::create(*builder, loc, (Twine("// ") + str).str()); } void emitDebugPrint(ThreadType threadType) { @@ -194,10 +202,10 @@ class ScopedModuleHelper { return; } - builder->create(loc, "api/debug/dprint.h", - /*isStandard=*/false); + emitc::IncludeOp::create(*builder, loc, "api/debug/dprint.h", + /*isStandard=*/false); - builder->create(loc, R""""( + emitc::VerbatimOp::create(*builder, loc, R""""( namespace ttmlir { template void dprint(Arg &&arg) { @@ -214,7 +222,7 @@ void dprint(Arg &&arg, ArgV&&... argv) { )""""); if (threadType == ThreadType::Compute) { - builder->create(loc, R""""( + emitc::VerbatimOp::create(*builder, loc, R""""( namespace ttmlir { inline void print_cb_details_(DebugPrinter dp, uint32_t cb_id) { dp << "cb_id " << cb_id << ": { "; @@ -250,28 +258,28 @@ void dprint(Arg &&arg, ArgV&&... 
argv) { auto experimentalTilizeLLKs = StringRef(experimental_tilize_llks_generated, experimental_tilize_llks_generated_len); - builder->create(loc, experimentalTilizeLLKs); + emitc::VerbatimOp::create(*builder, loc, experimentalTilizeLLKs); } if (hasCall("experimental::untilize")) { auto experimentalUntilizeLLKs = StringRef(experimental_untilize_llks_generated, experimental_untilize_llks_generated_len); - builder->create(loc, experimentalUntilizeLLKs); + emitc::VerbatimOp::create(*builder, loc, experimentalUntilizeLLKs); } if (hasCall("experimental::pack_untilize_block")) { auto experimentalPackUntilizeLLKs = StringRef(experimental_pack_untilize_llks_generated, experimental_pack_untilize_llks_generated_len); - builder->create(loc, experimentalPackUntilizeLLKs); + emitc::VerbatimOp::create(*builder, loc, experimentalPackUntilizeLLKs); } if (hasCall("experimental::get_noc_multicast_addr")) { auto experimentalDataflowLLKs = StringRef(experimental_dataflow_api_generated, experimental_dataflow_api_generated_len); - builder->create(loc, experimentalDataflowLLKs); + emitc::VerbatimOp::create(*builder, loc, experimentalDataflowLLKs); } if (hasCall("experimental::convert_logical_x_to_translated") || @@ -279,7 +287,8 @@ void dprint(Arg &&arg, ArgV&&... argv) { auto experimentalCoordTranslationLLKs = StringRef(experimental_coord_translation_generated, experimental_coord_translation_generated_len); - builder->create(loc, experimentalCoordTranslationLLKs); + emitc::VerbatimOp::create(*builder, loc, + experimentalCoordTranslationLLKs); } if (hasCall("experimental::close_fabric_connections") || @@ -295,32 +304,32 @@ void dprint(Arg &&arg, ArgV&&... argv) { auto experimentalFabricTopologyInfoLLKs = StringRef(experimental_fabric_topology_info_generated, experimental_fabric_topology_info_generated_len); - builder->create(loc, - experimentalFabricTopologyInfoLLKs); + emitc::VerbatimOp::create(*builder, loc, + experimentalFabricTopologyInfoLLKs); // 2. 
Routing functions auto experimentalFabric1DRoutingLLKs = StringRef(experimental_fabric_1d_routing_generated, experimental_fabric_1d_routing_generated_len); - builder->create(loc, experimentalFabric1DRoutingLLKs); + emitc::VerbatimOp::create(*builder, loc, experimentalFabric1DRoutingLLKs); auto experimentalFabric2DRoutingLLKs = StringRef(experimental_fabric_2d_routing_generated, experimental_fabric_2d_routing_generated_len); - builder->create(loc, experimentalFabric2DRoutingLLKs); + emitc::VerbatimOp::create(*builder, loc, experimentalFabric2DRoutingLLKs); // 3. Fabric APIs auto experimentalFabricAPILLKs = StringRef(experimental_fabric_api_generated, experimental_fabric_api_generated_len); - builder->create(loc, experimentalFabricAPILLKs); + emitc::VerbatimOp::create(*builder, loc, experimentalFabricAPILLKs); } if (hasCall("experimental::matmul_block")) { auto experimentalMatmulLLKs = StringRef(experimental_matmul_llks_generated, experimental_matmul_llks_generated_len); - builder->create(loc, experimentalMatmulLLKs); + emitc::VerbatimOp::create(*builder, loc, experimentalMatmulLLKs); } if (hasCall("experimental::tile_fill") || @@ -330,13 +339,14 @@ void dprint(Arg &&arg, ArgV&&... argv) { auto experimentalPaddingLLKs = StringRef(experimental_padding_llks_generated, experimental_padding_llks_generated_len); - builder->create(loc, experimentalPaddingLLKs); + emitc::VerbatimOp::create(*builder, loc, experimentalPaddingLLKs); } if (hasVerbatim("experimental::invoke_sfpi")) { - builder->create( - loc, StringRef(experimental_invoke_sfpi_llks_generated, - experimental_invoke_sfpi_llks_generated_len)); + emitc::VerbatimOp::create( + *builder, loc, + StringRef(experimental_invoke_sfpi_llks_generated, + experimental_invoke_sfpi_llks_generated_len)); } } @@ -384,7 +394,7 @@ cloneEntryIntoStandaloneModule(func::FuncOp origEntry, ThreadType threadType) { // We will wrap everything in a standalone module op so that we can run the // translation. 
- auto moduleWrapper = builder.create(loc, "module_wrapper"); + auto moduleWrapper = mlir::ModuleOp::create(builder, loc, "module_wrapper"); builder.setInsertionPointToStart(moduleWrapper.getBody()); Region *kernelMainRegion; @@ -393,8 +403,8 @@ cloneEntryIntoStandaloneModule(func::FuncOp origEntry, ThreadType threadType) { origEntry.getName()); // Clone 'region' into a new func op nested inside 'moduleWrapper': - auto kernelMain = builder.create( - loc, "kernel_main", + auto kernelMain = func::FuncOp::create( + builder, loc, "kernel_main", builder.getType(region->getArgumentTypes(), TypeRange())); kernelMainRegion = &kernelMain.getBody(); } @@ -417,7 +427,7 @@ LogicalResult translateKernelFuncToCpp(func::FuncOp entry, if (failed(kernelModule)) { return failure(); } - auto moduleCleanup = llvm::make_scope_exit([&]() { kernelModule->erase(); }); + auto moduleCleanup = llvm::scope_exit([&]() { kernelModule->erase(); }); return emitc::translateToCpp(*kernelModule, os); } diff --git a/lib/Transforms/CollapseParallelLoops.cpp b/lib/Transforms/CollapseParallelLoops.cpp index 08eea74bc64..448ad658feb 100644 --- a/lib/Transforms/CollapseParallelLoops.cpp +++ b/lib/Transforms/CollapseParallelLoops.cpp @@ -45,10 +45,10 @@ class CollapseParallelLoopPattern : public OpRewritePattern { SmallVector newLowerBounds, newUpperBounds, newSteps; - newLowerBounds.push_back(rewriter.create(loc, 0)); + newLowerBounds.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0)); newUpperBounds.push_back( - rewriter.create(loc, collapsedSize)); - newSteps.push_back(rewriter.create(loc, 1)); + arith::ConstantIndexOp::create(rewriter, loc, collapsedSize)); + newSteps.push_back(arith::ConstantIndexOp::create(rewriter, loc, 1)); newLowerBounds.push_back(lowerBounds[1]); newUpperBounds.push_back(upperBounds[1]); @@ -58,8 +58,8 @@ class CollapseParallelLoopPattern : public OpRewritePattern { newUpperBounds.push_back(upperBounds[0]); newSteps.push_back(steps[0]); - auto newParallelOp = 
rewriter.create( - loc, newLowerBounds, newUpperBounds, newSteps, initVals); + auto newParallelOp = scf::ParallelOp::create( + rewriter, loc, newLowerBounds, newUpperBounds, newSteps, initVals); return newParallelOp; } @@ -147,12 +147,13 @@ class CollapseParallelLoopPattern : public OpRewritePattern { int64_t productOfRemainingDims = calculateProductOfRemainingDims(i + 1, endIdx, dimSizes); Value divisor = - rewriter.create(loc, productOfRemainingDims); - Value quotient = rewriter.create(loc, remaining, divisor); + arith::ConstantIndexOp::create(rewriter, loc, productOfRemainingDims); + Value quotient = + arith::DivUIOp::create(rewriter, loc, remaining, divisor); decomposedVars.push_back(quotient); - Value product = rewriter.create(loc, quotient, divisor); - remaining = rewriter.create(loc, remaining, product); + Value product = arith::MulIOp::create(rewriter, loc, quotient, divisor); + remaining = arith::SubIOp::create(rewriter, loc, remaining, product); } decomposedVars.push_back(remaining); diff --git a/lib/Transforms/ConstEvalHoist.cpp b/lib/Transforms/ConstEvalHoist.cpp index 681abbcaa08..51946bb21a7 100644 --- a/lib/Transforms/ConstEvalHoist.cpp +++ b/lib/Transforms/ConstEvalHoist.cpp @@ -544,8 +544,8 @@ class ConstEvalHoistTransform // Create the const-eval function before the parent function // This ensures proper ordering in the generated EmitC code. builder.setInsertionPoint(originalFunc); - auto newFuncOp = builder.create(originalFunc.getLoc(), - newFuncName, funcType); + auto newFuncOp = func::FuncOp::create(builder, originalFunc.getLoc(), + newFuncName, funcType); // Mark the new function as const-eval and private. 
ttmlir::utils::setFunctionType(newFuncOp, ttmlir::utils::FunctionType::ConstEval); @@ -601,7 +601,7 @@ class ConstEvalHoistTransform returnValues.push_back(it->second); } - builder.create(originalFunc.getLoc(), returnValues); + func::ReturnOp::create(builder, originalFunc.getLoc(), returnValues); auto &originalEntryBlock = originalFunc.getBody().front(); // Manually order LoadCachedOp as first n ops in original func--we may @@ -615,8 +615,9 @@ class ConstEvalHoistTransform mlir::SymbolRefAttr::get(builder.getContext(), newFuncName); // Create the LoadCachedOp with the correct argument order - auto callOp = builder.create( - originalFunc.getLoc(), outputTypes, calleeAttr, ValueRange(inputs)); + auto callOp = ttcore::LoadCachedOp::create(builder, originalFunc.getLoc(), + outputTypes, calleeAttr, + ValueRange(inputs)); // Replace uses of original outputs with call results. for (size_t i = 0; i < outputs.size(); ++i) { diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index d2aeede9fed..4a4aca0102a 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -144,7 +144,7 @@ declare_mlir_python_extension(TTMLIRPythonExtensions.Main MLIRTestToLLVMIRTranslation MLIRVCIXToLLVMIRTranslation MLIRX86VectorTransforms - PYTHON_BINDINGS_LIBRARY nanobind + ) # Create a separate library for PassTracker compiled with -fno-rtti to match MLIR @@ -210,8 +210,6 @@ if(TTMLIR_ENABLE_STABLEHLO) StablehloPythonExtensions SdyPythonSources SdyPythonExtensions - MpmdPythonSources - MpmdPythonExtensions ) endif() diff --git a/test/unittests/Allocation/TestGenericOpBufferAnalysis.cpp b/test/unittests/Allocation/TestGenericOpBufferAnalysis.cpp index 279eb125659..8ce0b9c9033 100644 --- a/test/unittests/Allocation/TestGenericOpBufferAnalysis.cpp +++ b/test/unittests/Allocation/TestGenericOpBufferAnalysis.cpp @@ -39,8 +39,8 @@ class GenericOpBufferAnalysisTest : public ::testing::Test { builder.setInsertionPointToEnd(module->getBody()); auto funcType = builder.getFunctionType({}, 
{}); - func = builder.create(builder.getUnknownLoc(), - "test_func", funcType); + func = mlir::func::FuncOp::create(builder, builder.getUnknownLoc(), + "test_func", funcType); mlir::Block *entryBlock = func.addEntryBlock(); builder.setInsertionPointToStart(entryBlock); @@ -102,13 +102,13 @@ GenericOp createGenericOp(mlir::OpBuilder &builder, mlir::MLIRContext &context, // Use memref alloc to create concrete memref values from types. Value input1 = - builder.create(builder.getUnknownLoc(), memrefType1) + ttir::EmptyOp::create(builder, builder.getUnknownLoc(), memrefType1) ->getResult(0); Value input2 = - builder.create(builder.getUnknownLoc(), memrefType2) + ttir::EmptyOp::create(builder, builder.getUnknownLoc(), memrefType2) ->getResult(0); Value output = - builder.create(builder.getUnknownLoc(), memrefType3) + ttir::EmptyOp::create(builder, builder.getUnknownLoc(), memrefType3) ->getResult(0); ValueRange additionalArgs = {}; @@ -116,9 +116,9 @@ GenericOp createGenericOp(mlir::OpBuilder &builder, mlir::MLIRContext &context, SmallVector outputs = {output}; // Create the GenericOp. 
- auto genericOp = builder.create( - builder.getUnknownLoc(), inputs, outputs, additionalArgs, indexingMaps, - iteratorTypes); + auto genericOp = mlir::tt::d2m::GenericOp::create( + builder, builder.getUnknownLoc(), inputs, outputs, additionalArgs, + indexingMaps, iteratorTypes); return genericOp; } diff --git a/test/unittests/OpModel/TTNN/Op/TestD2MOpCostModel.cpp b/test/unittests/OpModel/TTNN/Op/TestD2MOpCostModel.cpp index 7c6c997d5a2..798b7d01083 100644 --- a/test/unittests/OpModel/TTNN/Op/TestD2MOpCostModel.cpp +++ b/test/unittests/OpModel/TTNN/Op/TestD2MOpCostModel.cpp @@ -39,9 +39,9 @@ class D2MOpCostModelTest : public OpModelFixture { } auto rankedTensorType = createRankedTensorType(tensorShape, elementType, layout); - return builder.create( - builder.getUnknownLoc(), rankedTensorType, nullptr, - ShapeAttr::get(&context, tensorShape), nullptr, nullptr, nullptr); + return OnesOp::create(builder, builder.getUnknownLoc(), rankedTensorType, + nullptr, ShapeAttr::get(&context, tensorShape), + nullptr, nullptr, nullptr); } std::vector getInputLayoutsFromOperands(mlir::Operation *op) { @@ -75,8 +75,8 @@ TEST_F(D2MOpCostModelTest, AddOp) { auto rhs = createEmptyTensor(shape, builder.getBF16Type(), layout); auto outputType = createRankedTensorType(shape, builder.getBF16Type(), layout); - auto addOp = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{lhs, rhs}); + auto addOp = AddOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{lhs, rhs}); std::vector inputs = getInputLayoutsFromOperands(addOp); OpConfig opConfig = getOutputConfig(addOp); @@ -102,8 +102,8 @@ TEST_F(D2MOpCostModelTest, ReluOp) { auto input = createEmptyTensor(shape, builder.getBF16Type(), layout); auto outputType = createRankedTensorType(shape, builder.getBF16Type(), layout); - auto reluOp = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{input}); + auto reluOp = ReluOp::create(builder, builder.getUnknownLoc(), outputType, + 
mlir::ValueRange{input}); std::vector inputs = getInputLayoutsFromOperands(reluOp); OpConfig opConfig = getOutputConfig(reluOp); @@ -132,8 +132,8 @@ TEST_F(D2MOpCostModelTest, SumOp) { auto input = createEmptyTensor(inputShape, builder.getBF16Type(), layoutIn); auto outputType = createRankedTensorType(outputShape, builder.getBF16Type(), layoutOut); - auto sumOp = builder.create( - builder.getUnknownLoc(), outputType, input, /*keep_dim=*/true, + auto sumOp = SumOp::create( + builder, builder.getUnknownLoc(), outputType, input, /*keep_dim=*/true, builder.getArrayAttr( llvm::SmallVector{builder.getI64IntegerAttr(1)})); @@ -168,8 +168,8 @@ TEST_F(D2MOpCostModelTest, MatmulOp) { auto inputB = createEmptyTensor(shapeB, builder.getBF16Type(), layoutB); auto outputType = createRankedTensorType(shapeO, builder.getBF16Type(), layoutO); - auto matmulOp = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{inputA, inputB}); + auto matmulOp = MatmulOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{inputA, inputB}); std::vector inputs = getInputLayoutsFromOperands(matmulOp); OpConfig opConfig = getOutputConfig(matmulOp); diff --git a/test/unittests/OpModel/TTNN/Op/TestMatmulBlockShardedConstraint.cpp b/test/unittests/OpModel/TTNN/Op/TestMatmulBlockShardedConstraint.cpp index 2573b3fb335..0b4f5c6740e 100644 --- a/test/unittests/OpModel/TTNN/Op/TestMatmulBlockShardedConstraint.cpp +++ b/test/unittests/OpModel/TTNN/Op/TestMatmulBlockShardedConstraint.cpp @@ -35,9 +35,9 @@ class OpModelTest : public OpModelFixture { } RankedTensorType rankedTensorType = createRankedTensorType(tensorShape, elementType, layout); - return builder.create( - builder.getUnknownLoc(), rankedTensorType, nullptr, - ShapeAttr::get(&context, tensorShape), nullptr, nullptr, nullptr); + return OnesOp::create(builder, builder.getUnknownLoc(), rankedTensorType, + nullptr, ShapeAttr::get(&context, tensorShape), + nullptr, nullptr, nullptr); } }; @@ -67,8 +67,8 @@ 
TEST_F(OpModelTest, MatmulBlockShardedInputWithPadding) { auto outputType = createRankedTensorType(outputShape, builder.getBF16Type(), nullptr); - auto matmul = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{inputA, inputB}); + auto matmul = MatmulOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{inputA, inputB}); auto deviceGrid = CreateWorkerGrid(); @@ -110,8 +110,8 @@ TEST_F(OpModelTest, MatmulActivationWithIgnorePhysicalLayout) { createRankedTensorType(outputShape, builder.getBF16Type(), nullptr); // Create matmul with activation attribute - auto matmul = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{inputA, inputB}); + auto matmul = MatmulOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{inputA, inputB}); matmul.setActivation(llvm::StringRef("sigmoid")); // Create sharded output layout with IgnorePhysicalLayout=true diff --git a/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp index 6b3cdeeda17..6ccddb69ad1 100644 --- a/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp +++ b/test/unittests/OpModel/TTNN/Op/TestOpModelInterface.cpp @@ -109,9 +109,9 @@ class OpModelBase : public OpModelFixture { } RankedTensorType rankedTensorType = createRankedTensorType(tensorShape, elementType, layout); - return builder.create( - builder.getUnknownLoc(), rankedTensorType, nullptr, - ShapeAttr::get(&context, tensorShape), nullptr, nullptr, nullptr); + return OnesOp::create(builder, builder.getUnknownLoc(), rankedTensorType, + nullptr, ShapeAttr::get(&context, tensorShape), + nullptr, nullptr, nullptr); } }; struct ExpectedResult { @@ -201,107 +201,107 @@ const ExpectedResult expected{true}; //===---------------------------------------------------------=== const auto createRelu = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return 
ReluOp::create(b, loc, type, ops).getOperation(); }; const auto createRelu6 = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return Relu6Op::create(b, loc, type, ops).getOperation(); }; const auto createHardsigmoid = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return HardsigmoidOp::create(b, loc, type, ops).getOperation(); }; const auto createSilu = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return SiluOp::create(b, loc, type, ops).getOperation(); }; const auto createSin = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return SinOp::create(b, loc, type, ops).getOperation(); }; const auto createCos = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return CosOp::create(b, loc, type, ops).getOperation(); }; const auto createExp = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return ExpOp::create(b, loc, type, ops).getOperation(); }; const auto createTanh = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return TanhOp::create(b, loc, type, ops).getOperation(); }; const auto createLog = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return LogOp::create(b, loc, type, ops).getOperation(); }; const auto createAbs = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return AbsOp::create(b, loc, type, ops).getOperation(); }; const auto createCeil = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return CeilOp::create(b, loc, type, 
ops).getOperation(); }; const auto createSign = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return SignOp::create(b, loc, type, ops).getOperation(); }; const auto createErf = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return ErfOp::create(b, loc, type, ops).getOperation(); }; const auto createErfc = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return ErfcOp::create(b, loc, type, ops).getOperation(); }; const auto createFloor = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return FloorOp::create(b, loc, type, ops).getOperation(); }; const auto createGelu = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return GeluOp::create(b, loc, type, ops).getOperation(); }; const auto createIsFinite = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return IsFiniteOp::create(b, loc, type, ops).getOperation(); }; const auto createLogicalNot = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return LogicalNotOp::create(b, loc, type, ops).getOperation(); }; const auto createNeg = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return NegOp::create(b, loc, type, ops).getOperation(); }; const auto createTan = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return TanOp::create(b, loc, type, ops).getOperation(); }; const auto createAtan = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return AtanOp::create(b, loc, type, ops).getOperation(); }; 
const auto createRsqrt = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return RsqrtOp::create(b, loc, type, ops).getOperation(); }; const auto createLog1p = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return Log1pOp::create(b, loc, type, ops).getOperation(); }; const auto createExpm1 = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return Expm1Op::create(b, loc, type, ops).getOperation(); }; const auto createReciprocal = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return ReciprocalOp::create(b, loc, type, ops).getOperation(); }; const auto createCbrt = [](OpBuilder &b, Location loc, Type type, ValueRange ops) { - return b.create(loc, type, ops).getOperation(); + return CbrtOp::create(b, loc, type, ops).getOperation(); }; //===---------------------------------------------------------=== @@ -483,71 +483,71 @@ const ExpectedResult binaryExpected{true}; //===---------------------------------------------------------=== // Lambda functions for creating binary operations const auto createAdd = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return AddOp::create(b, l, t, r).getOperation(); }; const auto createSubtract = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return SubtractOp::create(b, l, t, r).getOperation(); }; const auto createMultiply = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return MultiplyOp::create(b, l, t, r).getOperation(); }; const auto createDivide = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return DivideOp::create(b, l, t, r).getOperation(); }; const auto createEqual = 
[](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return EqualOp::create(b, l, t, r).getOperation(); }; const auto createNotEqual = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return NotEqualOp::create(b, l, t, r).getOperation(); }; const auto createGE = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return GreaterEqualOp::create(b, l, t, r).getOperation(); }; const auto createGT = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return GreaterThanOp::create(b, l, t, r).getOperation(); }; const auto createLE = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return LessEqualOp::create(b, l, t, r).getOperation(); }; const auto createLT = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return LessThanOp::create(b, l, t, r).getOperation(); }; const auto createAnd = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return LogicalAndOp::create(b, l, t, r).getOperation(); }; const auto createOr = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return LogicalOrOp::create(b, l, t, r).getOperation(); }; const auto createXor = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return LogicalXorOp::create(b, l, t, r).getOperation(); }; const auto createMax = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return MaximumOp::create(b, l, t, r).getOperation(); }; const auto createMin = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return MinimumOp::create(b, l, t, r).getOperation(); }; const auto createPow = [](OpBuilder &b, Location l, Type 
t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return PowTensorOp::create(b, l, t, r).getOperation(); }; const auto createBitwiseAnd = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return BitwiseAndOp::create(b, l, t, r).getOperation(); }; const auto createBitwiseOr = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return BitwiseOrOp::create(b, l, t, r).getOperation(); }; const auto createBitwiseXor = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return BitwiseXorOp::create(b, l, t, r).getOperation(); }; const auto createRemainder = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return RemainderOp::create(b, l, t, r).getOperation(); }; const auto createAtan2 = [](OpBuilder &b, Location l, Type t, ValueRange r) { - return b.create(l, t, r).getOperation(); + return Atan2Op::create(b, l, t, r).getOperation(); }; //===---------------------------------------------------------=== @@ -607,8 +607,8 @@ TEST_F(OpModelBase, PowScalarOp) { // Input params const auto exponent = builder.getF32FloatAttr(2.0f); - PowScalarOp powScalarOp = builder.create( - builder.getUnknownLoc(), outputType, input, exponent); + PowScalarOp powScalarOp = PowScalarOp::create( + builder, builder.getUnknownLoc(), outputType, input, exponent); powScalarOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(powScalarOp.getOperation()); @@ -654,12 +654,12 @@ TEST_F(OpModelBase, BitwiseNotOpInterface) { auto outputType = createRankedTensorType(tensorShape, intType, int32Layout); // Create input tensor using OnesOp with Int32 layout - auto input = builder.create( - builder.getUnknownLoc(), inputType, nullptr, - ShapeAttr::get(&context, tensorShape), nullptr, nullptr, nullptr); + auto input = OnesOp::create(builder, builder.getUnknownLoc(), 
inputType, + nullptr, ShapeAttr::get(&context, tensorShape), + nullptr, nullptr, nullptr); - auto bitwiseNot = builder.create( - builder.getUnknownLoc(), outputType, ::mlir::ValueRange{input}); + auto bitwiseNot = BitwiseNotOp::create(builder, builder.getUnknownLoc(), + outputType, ::mlir::ValueRange{input}); // Test BitwiseNot interface auto constraintsExp = getOpConstraints(bitwiseNot.getOperation()); @@ -707,15 +707,16 @@ TEST_F(OpModelBase, LogicalRightShiftOpInterface) { auto outputType = createRankedTensorType(tensorShape, intType, int32Layout); // Create input tensors using OnesOp with Int32 layout - auto input1 = builder.create( - builder.getUnknownLoc(), input1Type, nullptr, - ShapeAttr::get(&context, tensorShape), nullptr, nullptr, nullptr); - auto input2 = builder.create( - builder.getUnknownLoc(), input2Type, nullptr, - ShapeAttr::get(&context, tensorShape), nullptr, nullptr, nullptr); + auto input1 = OnesOp::create(builder, builder.getUnknownLoc(), input1Type, + nullptr, ShapeAttr::get(&context, tensorShape), + nullptr, nullptr, nullptr); + auto input2 = OnesOp::create(builder, builder.getUnknownLoc(), input2Type, + nullptr, ShapeAttr::get(&context, tensorShape), + nullptr, nullptr, nullptr); - auto logicalRightShift = builder.create( - builder.getUnknownLoc(), outputType, ::mlir::ValueRange{input1, input2}); + auto logicalRightShift = + LogicalRightShiftOp::create(builder, builder.getUnknownLoc(), outputType, + ::mlir::ValueRange{input1, input2}); // Test LogicalRightShift interface auto constraintsExp = getOpConstraints(logicalRightShift.getOperation()); @@ -762,15 +763,16 @@ TEST_F(OpModelBase, LogicalLeftShiftOpInterface) { auto outputType = createRankedTensorType(tensorShape, intType, int32Layout); // Create input tensors using OnesOp with Int32 layout - auto input1 = builder.create( - builder.getUnknownLoc(), input1Type, nullptr, - ShapeAttr::get(&context, tensorShape), nullptr, nullptr, nullptr); - auto input2 = builder.create( - 
builder.getUnknownLoc(), input2Type, nullptr, - ShapeAttr::get(&context, tensorShape), nullptr, nullptr, nullptr); + auto input1 = OnesOp::create(builder, builder.getUnknownLoc(), input1Type, + nullptr, ShapeAttr::get(&context, tensorShape), + nullptr, nullptr, nullptr); + auto input2 = OnesOp::create(builder, builder.getUnknownLoc(), input2Type, + nullptr, ShapeAttr::get(&context, tensorShape), + nullptr, nullptr, nullptr); - auto logicalLeftShift = builder.create( - builder.getUnknownLoc(), outputType, ::mlir::ValueRange{input1, input2}); + auto logicalLeftShift = + LogicalLeftShiftOp::create(builder, builder.getUnknownLoc(), outputType, + ::mlir::ValueRange{input1, input2}); // Test LogicalLeftShift interface auto constraintsExp = getOpConstraints(logicalLeftShift.getOperation()); @@ -802,8 +804,8 @@ TEST_F(OpModelBase, SqrtOpInterface) { auto input = createEmptyTensor(tensorShape); auto outputType = createRankedTensorType(tensorShape); - auto sqrt = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{input}); + auto sqrt = SqrtOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{input}); // test SqrtOp interface auto constraintsExp = getOpConstraints(sqrt.getOperation()); @@ -834,8 +836,8 @@ TEST_F(OpModelBase, SigmoidOpInterface) { auto input = createEmptyTensor(tensorShape); auto outputType = createRankedTensorType(tensorShape); - auto sigmoid = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{input}); + auto sigmoid = SigmoidOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{input}); // test SigmoidOp interface auto constraintsExp = getOpConstraints(sigmoid.getOperation()); @@ -867,7 +869,7 @@ TEST_F(OpModelBase, SoftmaxOpInterface) { auto output = createRankedTensorType(tensorShape); auto softmax = - builder.create(builder.getUnknownLoc(), output, input, -1); + SoftmaxOp::create(builder, builder.getUnknownLoc(), output, input, -1); // test SoftmaxOp interface auto 
constraintsExp = getOpConstraints(softmax.getOperation()); @@ -903,9 +905,8 @@ TEST_F(OpModelBase, LinearOpInterface) { auto bias = createEmptyTensor(biasShape); auto outputType = createRankedTensorType(tensorShapeO); - auto linear = - builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{inputA, inputB, bias}); + auto linear = LinearOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{inputA, inputB, bias}); // test LinearOp interface auto constraintsExp = getOpConstraints(linear.getOperation()); @@ -941,9 +942,8 @@ TEST_F(OpModelBase, LinearOpInterfaceNullOutput) { auto bias = createEmptyTensor(biasShape); auto outputType = createRankedTensorType(tensorShapeO); - auto linear = - builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{inputA, inputB, bias}); + auto linear = LinearOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{inputA, inputB, bias}); // test LinearOp interface OpModel backend = dyn_cast(linear.getOperation()); @@ -977,9 +977,8 @@ TEST_F(OpModelBase, LinearOpInterfacePartialOutput) { auto outputLayout = CreateTiledLayout(tensorShapeO, BufferType::L1, TensorMemoryLayout::BlockSharded) .withIgnorePhysicalLayout(true); - auto linear = - builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{inputA, inputB, bias}); + auto linear = LinearOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{inputA, inputB, bias}); // test LinearOp interface OpModel backend = dyn_cast(linear.getOperation()); @@ -1008,8 +1007,8 @@ TEST_F(OpModelBase, MatmulOpInterface) { auto inputB = createEmptyTensor(tensorShapeB); auto outputType = createRankedTensorType(tensorShapeO); - auto matmul = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{inputA, inputB}); + auto matmul = MatmulOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{inputA, inputB}); // test MatmulOp interface auto constraintsExp = 
getOpConstraints(matmul.getOperation()); @@ -1043,8 +1042,8 @@ TEST_F(OpModelBase, MatmulOpInterfaceNullOutput) { auto inputB = createEmptyTensor(tensorShapeB); auto outputType = createRankedTensorType(tensorShapeO); - auto matmul = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{inputA, inputB}); + auto matmul = MatmulOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{inputA, inputB}); // test MatmulOp interface OpModel backend = dyn_cast(matmul.getOperation()); @@ -1076,8 +1075,8 @@ TEST_F(OpModelBase, MatmulOpInterfacePartialOutput) { auto outputLayout = CreateTiledLayout(tensorShapeO, BufferType::L1, TensorMemoryLayout::BlockSharded) .withIgnorePhysicalLayout(true); - auto matmul = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{inputA, inputB}); + auto matmul = MatmulOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{inputA, inputB}); // test MatmulOp interface OpModel backend = dyn_cast(matmul.getOperation()); @@ -1108,8 +1107,8 @@ void testReductionOp(OpModelBase *testFixture, mlir::OpBuilder &builder, OpConstraintsFn getOpConstraintsFn, OpRuntimeFn getOpRuntimeFn) { // Create the reduction operation - auto op = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{input}); + auto op = OpType::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{input}); op.setKeepDim(true); op.setDimArgAttr(builder.getArrayAttr( llvm::SmallVector{builder.getI64IntegerAttr(1)})); @@ -1197,9 +1196,9 @@ TEST_F(OpModelBase, ArgMaxOpInterface) { auto outputType = createRankedTensorType(tensorShapeA, builder.getBF16Type(), outputLayout); - auto argMax = builder.create(builder.getUnknownLoc(), outputType, - input, builder.getI32IntegerAttr(1), - false, false, nullptr); + auto argMax = + ArgMaxOp::create(builder, builder.getUnknownLoc(), outputType, input, + builder.getI32IntegerAttr(1), false, false, nullptr); // getOutputLayout() hardcodes tiled L1 
layout, so we cannot use it OpModel backend = dyn_cast(argMax.getOperation()); @@ -1233,9 +1232,9 @@ TEST_F(OpModelBase, ProdOpInterface) { auto input = createEmptyTensor(tensorShapeA); auto output = createEmptyTensor(tensorShapeA); - auto prod = builder.create(builder.getUnknownLoc(), output.getType(), - input, builder.getI64IntegerAttr(0), - builder.getBoolAttr(false), nullptr); + auto prod = ProdOp::create(builder, builder.getUnknownLoc(), output.getType(), + input, builder.getI64IntegerAttr(0), + builder.getBoolAttr(false), nullptr); // test prod Op interface auto constraintsExp = getOpConstraints(prod.getOperation()); @@ -1269,8 +1268,8 @@ TEST_F(OpModelBase, ScatterOpInterface) { ttcore::ReduceType::Sum); const int32_t dim = 0; - auto scatter = builder.create( - builder.getUnknownLoc(), output.getType(), input, index, source, + auto scatter = ScatterOp::create( + builder, builder.getUnknownLoc(), output.getType(), input, index, source, builder.getI32IntegerAttr(dim), reduceType, nullptr); // test ScatterOp interface @@ -1304,8 +1303,8 @@ TEST_F(OpModelBase, ReshapeOpInterface) { auto input = createEmptyTensor(tensorShapeA); auto output = createEmptyTensor(tensorShapeO); - auto reshape = builder.create( - builder.getUnknownLoc(), output.getType(), mlir::ValueRange{input}); + auto reshape = ReshapeOp::create(builder, builder.getUnknownLoc(), + output.getType(), mlir::ValueRange{input}); reshape.setShapeAttr(builder.getArrayAttr(llvm::SmallVector{ builder.getI64IntegerAttr(64 * 4), builder.getI64IntegerAttr(1024 / 4)})); @@ -1374,8 +1373,9 @@ TEST_F(OpModelBase, SliceStaticOpInterface) { llvm::SmallVector endsArray = {1, 56, 56, 95}; llvm::SmallVector stepArray = {1, 2, 1, 1}; - auto sliceStaticOp = builder.create( - builder.getUnknownLoc(), output.getType(), mlir::ValueRange{input}); + auto sliceStaticOp = + SliceStaticOp::create(builder, builder.getUnknownLoc(), output.getType(), + mlir::ValueRange{input}); 
sliceStaticOp.setBeginsAttr(builder.getI64ArrayAttr(beginsArray)); sliceStaticOp.setEndsAttr(builder.getI64ArrayAttr(endsArray)); @@ -1423,8 +1423,8 @@ TEST_F(OpModelBase, SliceDynamicOpInterface) { builder.getI32IntegerAttr(1)}; auto sliceDynamicOp = - builder.create(builder.getUnknownLoc(), output.getType(), - mlir::ValueRange{input, begins, ends}); + SliceDynamicOp::create(builder, builder.getUnknownLoc(), output.getType(), + mlir::ValueRange{input, begins, ends}); sliceDynamicOp.setStepAttr(builder.getArrayAttr(stepAttrs)); // test SliceDynamicOp interface @@ -1453,14 +1453,14 @@ TEST_F(OpModelBase, SliceDynamicOpInterface) { TEST_F(OpModelBase, toLayoutOp) { llvm::SmallVector tensorShape = {64, 1024}; RankedTensorType rankedTensorType = createRankedTensorType(tensorShape); - auto tensor = builder.create( - builder.getUnknownLoc(), rankedTensorType, nullptr, - ShapeAttr::get(&context, tensorShape), nullptr, - LayoutAttr::get(&context, Layout::RowMajor), nullptr); + auto tensor = + OnesOp::create(builder, builder.getUnknownLoc(), rankedTensorType, + nullptr, ShapeAttr::get(&context, tensorShape), nullptr, + LayoutAttr::get(&context, Layout::RowMajor), nullptr); ToLayoutOp toLayout = - builder.create(builder.getUnknownLoc(), tensor.getType(), - tensor, Layout::Tile, nullptr, nullptr); + ToLayoutOp::create(builder, builder.getUnknownLoc(), tensor.getType(), + tensor, Layout::Tile, nullptr, nullptr); // Manually create the operand layouts for calling the backend to make sure // the layouts are propagated all the way @@ -1510,8 +1510,8 @@ TEST_F(OpModelBase, toMemoryConfigOp) { auto inputTensor = createEmptyTensor(tensorShape, nullptr, inputLayout_L1Tiled); - ToMemoryConfigOp toMemoryConfig = builder.create( - builder.getUnknownLoc(), inputTensor.getType(), inputTensor, + ToMemoryConfigOp toMemoryConfig = ToMemoryConfigOp::create( + builder, builder.getUnknownLoc(), inputTensor.getType(), inputTensor, memoryConfig); OpModel backend = 
dyn_cast(toMemoryConfig.getOperation()); @@ -1554,8 +1554,8 @@ TEST_F(OpModelBase, concatOp) { mlir::Value inputTensor3 = createEmptyTensor(tensorShape3); mlir::Value output = createEmptyTensor(tensorShapeO); - auto concatOp = builder.create( - builder.getUnknownLoc(), output.getType(), + auto concatOp = ConcatOp::create( + builder, builder.getUnknownLoc(), output.getType(), mlir::ValueRange{inputTensor1, inputTensor2, inputTensor3}, 2, nullptr); // test concat Op interface @@ -1588,8 +1588,8 @@ TEST_F(OpModelBase, transposeOp) { auto input = createEmptyTensor(tensorShapeA); auto output = createEmptyTensor(tensorShapeO); - auto transpose = builder.create(builder.getUnknownLoc(), - output.getType(), input, 0, 1); + auto transpose = TransposeOp::create(builder, builder.getUnknownLoc(), + output.getType(), input, 0, 1); // test transpose Op interface auto constraintsExp = getOpConstraints(transpose.getOperation()); @@ -1621,9 +1621,9 @@ TEST_F(OpModelBase, morehCumSumOp) { auto input = createEmptyTensor(tensorShapeA); auto output = createEmptyTensor(tensorShapeO); - auto morehCumSum = builder.create( - builder.getUnknownLoc(), output.getType(), input, - builder.getI64IntegerAttr(0), nullptr); + auto morehCumSum = + MorehCumSumOp::create(builder, builder.getUnknownLoc(), output.getType(), + input, builder.getI64IntegerAttr(0), nullptr); // test morehCumSum Op interface auto constraintsExp = getOpConstraints(morehCumSum.getOperation()); @@ -1657,8 +1657,8 @@ TEST_F(OpModelBase, TopKOp) { auto topKValues = createEmptyTensor(tensorShapeO); auto indices = createEmptyTensor(tensorShapeO); // TopKOp returns 2 tensors: top k values and their indices - auto topK = builder.create( - builder.getUnknownLoc(), + auto topK = TopKOp::create( + builder, builder.getUnknownLoc(), mlir::TypeRange{topKValues.getType(), indices.getType()}, // 2 result types input, k, /*dim=*/-1, /*largest=*/false, /*sorted=*/true, nullptr); @@ -1704,8 +1704,8 @@ TEST_F(OpModelBase, 
ConcatenateHeadsOpInterface) { auto input = createEmptyTensor(inputShape); auto outputType = createRankedTensorType(outputShape); - auto concatenateHeads = builder.create( - builder.getUnknownLoc(), outputType, input); + auto concatenateHeads = ConcatenateHeadsOp::create( + builder, builder.getUnknownLoc(), outputType, input); // test ConcatenateHeadsOp interface auto constraintsExp = getOpConstraints(concatenateHeads.getOperation()); @@ -1749,8 +1749,8 @@ TEST_F(OpModelBase, RotaryEmbeddingLlamaOpInterface) { auto outputType = createRankedTensorType(shape); bool isDecodeMode = false; - auto rotaryEmbeddingLlama = builder.create( - builder.getUnknownLoc(), outputType, input, cos, sin, transMat, + auto rotaryEmbeddingLlama = RotaryEmbeddingLlamaOp::create( + builder, builder.getUnknownLoc(), outputType, input, cos, sin, transMat, isDecodeMode, /*memory_config=*/nullptr, /*compute_config=*/nullptr); auto constraintsExp = getOpConstraints(rotaryEmbeddingLlama.getOperation()); @@ -1791,8 +1791,8 @@ TEST_F(OpModelBase, RotaryEmbeddingOpInterface) { auto sin = createEmptyTensor(rotationShape); auto outputType = createRankedTensorType(inputShape); - auto rotaryEmbedding = builder.create( - builder.getUnknownLoc(), outputType, input, cos, sin, + auto rotaryEmbedding = RotaryEmbeddingOp::create( + builder, builder.getUnknownLoc(), outputType, input, cos, sin, /*tokenIndex=*/nullptr, /*memory_config=*/nullptr, /*compute_config=*/nullptr); @@ -1847,8 +1847,8 @@ TEST_F(OpModelBase, NLPCreateQKVHeadsDecodeOpInterface) { IntegerAttr numKVHeadsAttr = builder.getUI32IntegerAttr(numHeads); BoolAttr overlapQKCoregridAttr = builder.getBoolAttr(overlapQKCoregrid); - auto nlpCreateQKVHeadsDecode = builder.create( - builder.getUnknownLoc(), TypeRange(returnTypes), input, + auto nlpCreateQKVHeadsDecode = NLPCreateQKVHeadsDecodeOp::create( + builder, builder.getUnknownLoc(), TypeRange(returnTypes), input, /*batchOffset=*/nullptr, numHeads, numKVHeadsAttr, overlapQKCoregridAttr, 
/*sliceSize=*/nullptr, /*memory_config=*/nullptr); @@ -1896,8 +1896,8 @@ TEST_F(OpModelBase, SplitQueryKeyValueAndSplitHeadsOpInterface) { BoolAttr transposeKeyAttr = builder.getBoolAttr(false); auto splitQueryKeyValueAndSplitHeads = - builder.create( - builder.getUnknownLoc(), + SplitQueryKeyValueAndSplitHeadsOp::create( + builder, builder.getUnknownLoc(), TypeRange({outputQuery, outputKey, outputValue}), input, /*kv_input_tensor=*/nullptr, numHeadsAttr, /*num_kv_heads*/ nullptr, transposeKeyAttr, /*memory_config=*/nullptr); @@ -1974,8 +1974,8 @@ TEST_F(OpModelBase, ScaledDotProductAttentionDecodeOpInterface) { auto outputType = createRankedTensorType(queryShape, tiledElemType, queryLayout); - auto sdpAttentionDecode = builder.create( - builder.getUnknownLoc(), outputType, query, key, value, + auto sdpAttentionDecode = ScaledDotProductAttentionDecodeOp::create( + builder, builder.getUnknownLoc(), outputType, query, key, value, /*is_causal=*/false, /*attention_mask=*/attentionMask, /*cur_pos_tensor=*/curPos, @@ -2065,15 +2065,15 @@ TEST_F(OpModelBase, DISABLED_PagedScaledDotProductAttentionDecodeOpInterface) { auto outputType = createRankedTensorType(queryShape, tiledElemType, queryLayout); - auto sdpAttentionDecode = - builder.create( - builder.getUnknownLoc(), outputType, query, key, value, pageTable, - /*is_causal=*/true, - /*attention_mask*/ nullptr, - /*cur_pos_tensor=*/curPos, - /*attention_sink=*/nullptr, - /*scale=*/builder.getF32FloatAttr(0.125f), - /*memory_config=*/nullptr); + auto sdpAttentionDecode = PagedScaledDotProductAttentionDecodeOp::create( + builder, builder.getUnknownLoc(), outputType, query, key, value, + pageTable, + /*is_causal=*/true, + /*attention_mask*/ nullptr, + /*cur_pos_tensor=*/curPos, + /*attention_sink=*/nullptr, + /*scale=*/builder.getF32FloatAttr(0.125f), + /*memory_config=*/nullptr); OpModel backend = dyn_cast(sdpAttentionDecode.getOperation()); auto constraintsExp = backend.getOpConstraints( @@ -2156,8 +2156,8 @@ 
TEST_F(OpModelBase, ScaledDotProductAttentionOpInterface) { auto outputType = createRankedTensorType(queryShape, tiledElemType, queryLayout); - auto sdpAttention = builder.create( - builder.getUnknownLoc(), outputType, query, key, value, + auto sdpAttention = ScaledDotProductAttentionOp::create( + builder, builder.getUnknownLoc(), outputType, query, key, value, /*attention_mask=*/attentionMask, /*is_causal=*/false, /*scale=*/nullptr, @@ -2208,8 +2208,8 @@ TEST_F(OpModelBase, NLPConcatHeadsOpInterface) { auto input = createEmptyTensor(inputShape); auto outputType = createRankedTensorType(outputShape); - auto nlpConcatHeads = builder.create( - builder.getUnknownLoc(), outputType, input); + auto nlpConcatHeads = NLPConcatHeadsOp::create( + builder, builder.getUnknownLoc(), outputType, input); auto constraintsExp = getOpConstraints(nlpConcatHeads.getOperation()); if (constraintsExp) { @@ -2242,8 +2242,8 @@ TEST_F(OpModelBase, repeatInterleaveOp) { auto input = createEmptyTensor(tensorShapeA); auto output = createEmptyTensor(tensorShapeO); - auto repeatInterleave = builder.create( - builder.getUnknownLoc(), output.getType(), input, 2, 0, nullptr); + auto repeatInterleave = RepeatInterleaveOp::create( + builder, builder.getUnknownLoc(), output.getType(), input, 2, 0, nullptr); // test repeatInterleave Op interface auto constraintsExp = getOpConstraints(repeatInterleave.getOperation()); @@ -2279,8 +2279,8 @@ TEST_F(OpModelBase, repeatOp) { llvm::ArrayRef repeatDims(repeatDimsVec); auto repeatDimsAttr = ShapeAttr::get(&context, repeatDims); - auto repeat = builder.create( - builder.getUnknownLoc(), output.getType(), input, repeatDimsAttr); + auto repeat = RepeatOp::create(builder, builder.getUnknownLoc(), + output.getType(), input, repeatDimsAttr); // test repeat Op interface auto constraintsExp = getOpConstraints(repeat.getOperation()); @@ -2315,9 +2315,8 @@ TEST_F(OpModelBase, padOp) { std::vector paddingVec = {0, 2, 0, 2}; llvm::ArrayRef padding(paddingVec); - auto pad 
= - builder.create(builder.getUnknownLoc(), output.getType(), input, - padding, llvm::APFloat(0.0f), false, nullptr); + auto pad = PadOp::create(builder, builder.getUnknownLoc(), output.getType(), + input, padding, llvm::APFloat(0.0f), false, nullptr); // test pad Op interface auto constraintsExp = getOpConstraints(pad.getOperation()); @@ -2350,11 +2349,11 @@ TEST_F(OpModelBase, sortOp) { auto indices = createEmptyTensor(tensorShapeA); // SortOp returns 2 tensors: sorted values and indices - auto sort = builder.create( - builder.getUnknownLoc(), - mlir::TypeRange{sortedValues.getType(), - indices.getType()}, // 2 result types - input, 0, false, false, nullptr); + auto sort = + SortOp::create(builder, builder.getUnknownLoc(), + mlir::TypeRange{sortedValues.getType(), + indices.getType()}, // 2 result types + input, 0, false, false, nullptr); // test sort Op interface auto constraintsExp = getOpConstraints(sort.getOperation()); @@ -2422,8 +2421,8 @@ TEST_F(OpModelBase, maxPool2dWithIndicesOp) { llvm::SmallVector dilation = {dilationHeight, dilationWidth}; // MaxPool2dWithIndicesOp returns 2 tensors: pooled values and indices - auto maxPool2dWithIndices = builder.create( - builder.getUnknownLoc(), + auto maxPool2dWithIndices = MaxPool2dWithIndicesOp::create( + builder, builder.getUnknownLoc(), mlir::TypeRange{pooledValues.getType(), indices.getType()}, input, batchSize, inputHeight, inputWidth, numChannels, kernelSize, stride, padding, dilation, memoryConfigAttr, appliedShardScheme, ceilMode, @@ -2452,16 +2451,16 @@ TEST_F(OpModelBase, typecastOp) { RankedTensorType rankedTensorTypeBF16 = RankedTensorType::get(tensorShape, builder.getBF16Type()); - auto input = builder.create( - builder.getUnknownLoc(), rankedTensorTypeBF16, nullptr, + auto input = OnesOp::create( + builder, builder.getUnknownLoc(), rankedTensorTypeBF16, nullptr, ShapeAttr::get(&context, tensorShape), ttcore::DataTypeAttr::get(&context, ttcore::DataType::BFloat16), nullptr, nullptr); RankedTensorType 
rankedTensorTypeF32 = RankedTensorType::get(tensorShape, builder.getF32Type()); - auto typecast = builder.create( - builder.getUnknownLoc(), rankedTensorTypeF32, input, + auto typecast = TypecastOp::create( + builder, builder.getUnknownLoc(), rankedTensorTypeF32, input, ttcore::DataTypeAttr::get(&context, ttcore::DataType::Float32)); auto constraintsExp = getOpConstraints(typecast.getOperation()); @@ -2501,33 +2500,34 @@ TEST_F(OpModelBase, Conv2dInterface) { auto outputDtype = ttcore::DataTypeAttr::get( &context, ttcore::elementTypeToDataType(outputType.getElementType())); - GetDeviceOp deviceOp = builder.create( - builder.getUnknownLoc(), builder.getType(), + GetDeviceOp deviceOp = GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), MeshShapeAttr::get(builder.getContext(), 1, 1), MeshOffsetAttr::get(builder.getContext(), 0, 0)); - Conv2dOp conv2d = builder.create( - builder.getUnknownLoc(), // Location - outputType, // Output type - input, // Input tensor - weight, // Weight tensor - nullptr, // Bias tensor (optional) - deviceOp, // Device operation - 3, // Input channels - 64, // Output channels - 1, // Batch size - 224, // Input height - 224, // Input width - llvm::ArrayRef({7, 7}), // Kernel size [H, W] - llvm::ArrayRef({2, 2}), // Stride [H, W] - llvm::ArrayRef({3, 3}), // Padding [H, W] - llvm::ArrayRef({1, 1}), // Dilation [H, W] - 1, // Groups - outputDtype, // OutputDtype - nullptr, // Conv2dConfig (optional) - nullptr, // ComputeKernelConfig (optional) - nullptr // Conv2dSliceConfig (optional) - ); + Conv2dOp conv2d = + Conv2dOp::create(builder, + builder.getUnknownLoc(), // Location + outputType, // Output type + input, // Input tensor + weight, // Weight tensor + nullptr, // Bias tensor (optional) + deviceOp, // Device operation + 3, // Input channels + 64, // Output channels + 1, // Batch size + 224, // Input height + 224, // Input width + llvm::ArrayRef({7, 7}), // Kernel size [H, W] + llvm::ArrayRef({2, 2}), // Stride [H, W] 
+ llvm::ArrayRef({3, 3}), // Padding [H, W] + llvm::ArrayRef({1, 1}), // Dilation [H, W] + 1, // Groups + outputDtype, // OutputDtype + nullptr, // Conv2dConfig (optional) + nullptr, // ComputeKernelConfig (optional) + nullptr // Conv2dSliceConfig (optional) + ); // test Conv2dOp interface auto constraintsExp = getOpConstraints(conv2d.getOperation()); @@ -2563,33 +2563,34 @@ TEST_F(OpModelBase, Conv2dInterfaceNullOutput) { auto outputDtype = ttcore::DataTypeAttr::get( &context, ttcore::elementTypeToDataType(outputType.getElementType())); - GetDeviceOp deviceOp = builder.create( - builder.getUnknownLoc(), builder.getType(), + GetDeviceOp deviceOp = GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), MeshShapeAttr::get(builder.getContext(), 1, 1), MeshOffsetAttr::get(builder.getContext(), 0, 0)); - Conv2dOp conv2d = builder.create( - builder.getUnknownLoc(), // Location - outputType, // Output type - input, // Input tensor - weight, // Weight tensor - nullptr, // Bias tensor (optional) - deviceOp, // Device operation - 3, // Input channels - 64, // Output channels - 1, // Batch size - 224, // Input height - 224, // Input width - llvm::ArrayRef({7, 7}), // Kernel size [H, W] - llvm::ArrayRef({2, 2}), // Stride [H, W] - llvm::ArrayRef({3, 3}), // Padding [H, W] - llvm::ArrayRef({1, 1}), // Dilation [H, W] - 1, // Groups - outputDtype, // OutputDtype - nullptr, // Conv2dConfig (optional) - nullptr, // ComputeKernelConfig (optional) - nullptr // Conv2dSliceConfig (optional) - ); + Conv2dOp conv2d = + Conv2dOp::create(builder, + builder.getUnknownLoc(), // Location + outputType, // Output type + input, // Input tensor + weight, // Weight tensor + nullptr, // Bias tensor (optional) + deviceOp, // Device operation + 3, // Input channels + 64, // Output channels + 1, // Batch size + 224, // Input height + 224, // Input width + llvm::ArrayRef({7, 7}), // Kernel size [H, W] + llvm::ArrayRef({2, 2}), // Stride [H, W] + llvm::ArrayRef({3, 3}), // Padding 
[H, W] + llvm::ArrayRef({1, 1}), // Dilation [H, W] + 1, // Groups + outputDtype, // OutputDtype + nullptr, // Conv2dConfig (optional) + nullptr, // ComputeKernelConfig (optional) + nullptr // Conv2dSliceConfig (optional) + ); // test Conv2dOp interface OpModel backend = dyn_cast(conv2d.getOperation()); @@ -2632,14 +2633,14 @@ TEST_F(OpModelBase, PrepareConv2dWeightsOutput) { auto outputDtype = ttcore::DataTypeAttr::get( &context, ttcore::elementTypeToDataType(outputType.getElementType())); - GetDeviceOp deviceOp = builder.create( - builder.getUnknownLoc(), builder.getType(), + GetDeviceOp deviceOp = GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), MeshShapeAttr::get(builder.getContext(), 1, 1), MeshOffsetAttr::get(builder.getContext(), 0, 0)); - Conv2dOp conv2d = builder.create( - builder.getUnknownLoc(), outputType, input, weight, nullptr, deviceOp, 3, - 64, 1, 224, 224, llvm::ArrayRef({7, 7}), + Conv2dOp conv2d = Conv2dOp::create( + builder, builder.getUnknownLoc(), outputType, input, weight, nullptr, + deviceOp, 3, 64, 1, 224, 224, llvm::ArrayRef({7, 7}), llvm::ArrayRef({2, 2}), llvm::ArrayRef({3, 3}), llvm::ArrayRef({1, 1}), 1, outputDtype, nullptr, nullptr, nullptr); @@ -2686,14 +2687,14 @@ TEST_F(OpModelBase, Conv2dInterfaceConfigs) { auto outputDtype = ttcore::DataTypeAttr::get( &context, ttcore::elementTypeToDataType(outputType.getElementType())); - GetDeviceOp deviceOp = builder.create( - builder.getUnknownLoc(), builder.getType(), + GetDeviceOp deviceOp = GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), MeshShapeAttr::get(builder.getContext(), 1, 1), MeshOffsetAttr::get(builder.getContext(), 0, 0)); - Conv2dOp conv2d = builder.create( - builder.getUnknownLoc(), outputType, input, weight, nullptr, deviceOp, 3, - 64, 1, 224, 224, llvm::ArrayRef({7, 7}), + Conv2dOp conv2d = Conv2dOp::create( + builder, builder.getUnknownLoc(), outputType, input, weight, nullptr, + deviceOp, 3, 64, 1, 224, 224, 
llvm::ArrayRef({7, 7}), llvm::ArrayRef({2, 2}), llvm::ArrayRef({3, 3}), llvm::ArrayRef({1, 1}), 1, outputDtype, nullptr, nullptr, nullptr); @@ -2796,14 +2797,14 @@ TEST_F(OpModelBase, conv2dInterfaceComputeKernelConfig) { auto outputDtype = ttcore::DataTypeAttr::get( &context, ttcore::elementTypeToDataType(outputType.getElementType())); - GetDeviceOp deviceOp = builder.create( - builder.getUnknownLoc(), builder.getType(), + GetDeviceOp deviceOp = GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), MeshShapeAttr::get(builder.getContext(), 1, 1), MeshOffsetAttr::get(builder.getContext(), 0, 0)); - Conv2dOp conv2d = builder.create( - builder.getUnknownLoc(), outputType, input, weight, nullptr, deviceOp, 3, - 64, 1, 224, 224, llvm::ArrayRef({7, 7}), + Conv2dOp conv2d = Conv2dOp::create( + builder, builder.getUnknownLoc(), outputType, input, weight, nullptr, + deviceOp, 3, 64, 1, 224, 224, llvm::ArrayRef({7, 7}), llvm::ArrayRef({2, 2}), llvm::ArrayRef({3, 3}), llvm::ArrayRef({1, 1}), 1, outputDtype, nullptr, nullptr, nullptr); @@ -2859,12 +2860,13 @@ TEST_F(OpModelBase, Conv3dInterface) { createEmptyTensor(weightShape, builder.getBF16Type(), weightLayout); auto outputType = createRankedTensorType(outputShape); - GetDeviceOp deviceOp = builder.create( - builder.getUnknownLoc(), builder.getType(), + GetDeviceOp deviceOp = GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), MeshShapeAttr::get(builder.getContext(), 1, 1), MeshOffsetAttr::get(builder.getContext(), 0, 0)); - Conv3dOp conv3d = builder.create( + Conv3dOp conv3d = Conv3dOp::create( + builder, builder.getUnknownLoc(), // Location outputType, // Output type input, // Input tensor @@ -2937,14 +2939,14 @@ TEST_F(OpModelBase, ConvTranspose2dInterfaceConfigs) { auto outputDtype = ttcore::DataTypeAttr::get( &context, ttcore::elementTypeToDataType(outputType.getElementType())); - GetDeviceOp deviceOp = builder.create( - builder.getUnknownLoc(), builder.getType(), + 
GetDeviceOp deviceOp = GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), MeshShapeAttr::get(builder.getContext(), 1, 1), MeshOffsetAttr::get(builder.getContext(), 0, 0)); - ConvTranspose2dOp convTranspose2d = builder.create( - builder.getUnknownLoc(), outputType, input, weight, nullptr, deviceOp, 3, - 64, 1, 224, 224, llvm::ArrayRef({7, 7}), + ConvTranspose2dOp convTranspose2d = ConvTranspose2dOp::create( + builder, builder.getUnknownLoc(), outputType, input, weight, nullptr, + deviceOp, 3, 64, 1, 224, 224, llvm::ArrayRef({7, 7}), llvm::ArrayRef({2, 2}), llvm::ArrayRef({3, 3}), llvm::ArrayRef({0, 0}), llvm::ArrayRef({1, 1}), 1, outputDtype, nullptr, nullptr, nullptr, nullptr); @@ -3016,14 +3018,14 @@ TEST_F(OpModelBase, PrepareConv2dWeightsTest) { auto outputDtype = ttcore::DataTypeAttr::get( &context, ttcore::elementTypeToDataType(outputType.getElementType())); - GetDeviceOp deviceOp = builder.create( - builder.getUnknownLoc(), builder.getType(), + GetDeviceOp deviceOp = GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), MeshShapeAttr::get(builder.getContext(), 1, 1), MeshOffsetAttr::get(builder.getContext(), 0, 0)); - Conv2dOp conv2d = builder.create( - builder.getUnknownLoc(), outputType, input, weight, nullptr, deviceOp, 3, - 64, 1, 224, 224, llvm::ArrayRef({7, 7}), + Conv2dOp conv2d = Conv2dOp::create( + builder, builder.getUnknownLoc(), outputType, input, weight, nullptr, + deviceOp, 3, 64, 1, 224, 224, llvm::ArrayRef({7, 7}), llvm::ArrayRef({2, 2}), llvm::ArrayRef({3, 3}), llvm::ArrayRef({1, 1}), 1, outputDtype, nullptr, nullptr, nullptr); @@ -3048,32 +3050,32 @@ TEST_F(OpModelBase, PrepareConv2dWeightsTest) { auto preparedWeightOutputType = op_model::getPreparedConv2dWeightsOutputTensor(&conv2d, conv2dConfig); - PrepareConv2dWeightsOp prepareConv2dWeights = - builder.create( - builder.getUnknownLoc(), // Location - preparedWeightOutputType, // Output type (derived from conv2d) - conv2d.getWeight(), // 
Weight tensor from conv2d - inputMemConfigAttr, // Input memory config - inputLayoutAttr, // Input tensor layout - builder.getStringAttr("OIHW"), // Weights format - conv2d.getInChannelsAttr(), // Input channels from conv2d - conv2d.getOutChannelsAttr(), // Output channels from conv2d - conv2d.getBatchSizeAttr(), // Batch size from conv2d - conv2d.getInputHeightAttr(), // Input height from conv2d - conv2d.getInputWidthAttr(), // Input width from conv2d - conv2d.getKernelSizeAttr(), // Kernel size from conv2d - conv2d.getStrideAttr(), // Stride from conv2d - conv2d.getPaddingAttr(), // Padding from conv2d - conv2d.getDilationAttr(), // Dilation from conv2d - builder.getBoolAttr(conv2d.getBias() != nullptr), // has_bias - conv2d.getGroupsAttr(), // Groups from conv2d - conv2d.getDevice(), // Device from conv2d - inputDtypeAttr, // Input dtype - outputDtype, // Output dtype - conv2d.getConv2dConfigAttr(), // Conv2dConfig from conv2d - conv2d.getComputeConfigAttr(), // ComputeKernelConfig from conv2d - conv2d.getConv2dSliceConfigAttr() // Conv2dSliceConfig from conv2d - ); + PrepareConv2dWeightsOp prepareConv2dWeights = PrepareConv2dWeightsOp::create( + builder, + builder.getUnknownLoc(), // Location + preparedWeightOutputType, // Output type (derived from conv2d) + conv2d.getWeight(), // Weight tensor from conv2d + inputMemConfigAttr, // Input memory config + inputLayoutAttr, // Input tensor layout + builder.getStringAttr("OIHW"), // Weights format + conv2d.getInChannelsAttr(), // Input channels from conv2d + conv2d.getOutChannelsAttr(), // Output channels from conv2d + conv2d.getBatchSizeAttr(), // Batch size from conv2d + conv2d.getInputHeightAttr(), // Input height from conv2d + conv2d.getInputWidthAttr(), // Input width from conv2d + conv2d.getKernelSizeAttr(), // Kernel size from conv2d + conv2d.getStrideAttr(), // Stride from conv2d + conv2d.getPaddingAttr(), // Padding from conv2d + conv2d.getDilationAttr(), // Dilation from conv2d + 
builder.getBoolAttr(conv2d.getBias() != nullptr), // has_bias + conv2d.getGroupsAttr(), // Groups from conv2d + conv2d.getDevice(), // Device from conv2d + inputDtypeAttr, // Input dtype + outputDtype, // Output dtype + conv2d.getConv2dConfigAttr(), // Conv2dConfig from conv2d + conv2d.getComputeConfigAttr(), // ComputeKernelConfig from conv2d + conv2d.getConv2dSliceConfigAttr() // Conv2dSliceConfig from conv2d + ); auto constraintsExp = getOpConstraints(prepareConv2dWeights.getOperation()); ASSERT_TRUE(static_cast(constraintsExp)); @@ -3126,8 +3128,8 @@ TEST_F(OpModelBase, PrepareConv2dBiasTest) { auto outputDtype = ttcore::DataTypeAttr::get( &context, ttcore::elementTypeToDataType(outputType.getElementType())); - GetDeviceOp deviceOp = builder.create( - builder.getUnknownLoc(), builder.getType(), + GetDeviceOp deviceOp = GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), MeshShapeAttr::get(builder.getContext(), 1, 1), MeshOffsetAttr::get(builder.getContext(), 0, 0)); Conv2dConfigAttr configAttr = Conv2dConfigAttr::get(&context); @@ -3135,9 +3137,9 @@ TEST_F(OpModelBase, PrepareConv2dBiasTest) { // get_cb_info expects conv_config.weights_dtype to be set otherwise it // issues an error. See conv2d_op_program_factory_common.cpp in tt-metal. 
- Conv2dOp conv2d = builder.create( - builder.getUnknownLoc(), outputType, input, weight, bias, deviceOp, 3, 64, - 1, 224, 224, llvm::ArrayRef({7, 7}), + Conv2dOp conv2d = Conv2dOp::create( + builder, builder.getUnknownLoc(), outputType, input, weight, bias, + deviceOp, 3, 64, 1, 224, 224, llvm::ArrayRef({7, 7}), llvm::ArrayRef({2, 2}), llvm::ArrayRef({3, 3}), llvm::ArrayRef({1, 1}), 1, outputDtype, configAttr, nullptr, nullptr); @@ -3166,7 +3168,8 @@ TEST_F(OpModelBase, PrepareConv2dBiasTest) { auto preparedBiasOutputType = mlir::RankedTensorType::get( oldBiasType.getShape(), oldBiasType.getElementType(), newBiasLayout); - PrepareConv2dBiasOp prepareConv2dBias = builder.create( + PrepareConv2dBiasOp prepareConv2dBias = PrepareConv2dBiasOp::create( + builder, builder.getUnknownLoc(), // Location preparedBiasOutputType, // Output type (derived from bias) conv2d.getBias(), // Bias tensor from conv2d @@ -3250,10 +3253,11 @@ TEST_F(OpModelBase, maxPool2DOp) { llvm::SmallVector padding = {paddingHeight, paddingWidth}; llvm::SmallVector dilation = {dilationHeight, dilationWidth}; - auto maxPool2DOp = builder.create( - builder.getUnknownLoc(), output.getType(), input, batchSize, inputHeight, - inputWidth, numChannels, kernelSize, stride, padding, dilation, - memoryConfigAttr, appliedShardScheme, ceilMode, reallocateHaloOutput, + auto maxPool2DOp = MaxPool2dOp::create( + builder, builder.getUnknownLoc(), output.getType(), input, batchSize, + inputHeight, inputWidth, numChannels, kernelSize, stride, padding, + dilation, memoryConfigAttr, appliedShardScheme, ceilMode, + reallocateHaloOutput, /*config_tensors_in_dram=*/nullptr); maxPool2DOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); @@ -3321,10 +3325,11 @@ TEST_F(OpModelBase, avgPool2DOp) { llvm::SmallVector padding = {paddingHeight, paddingWidth}; llvm::SmallVector dilation = {dilationHeight, dilationWidth}; - auto avgPool2DOp = builder.create( - builder.getUnknownLoc(), output.getType(), input, batchSize, 
inputHeight, - inputWidth, numChannels, kernelSize, stride, padding, dilation, - memoryConfigAttr, appliedShardScheme, ceilMode, reallocateHaloOutput, + auto avgPool2DOp = AvgPool2dOp::create( + builder, builder.getUnknownLoc(), output.getType(), input, batchSize, + inputHeight, inputWidth, numChannels, kernelSize, stride, padding, + dilation, memoryConfigAttr, appliedShardScheme, ceilMode, + reallocateHaloOutput, /*count_include_pad=*/true, /*config_tensors_in_dram=*/nullptr); avgPool2DOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); @@ -3366,7 +3371,7 @@ TEST_F(OpModelBase, globalAvgPool2dOp) { CreateRowMajorLayout(tensorShapeO, BufferType::DRAM, TensorMemoryLayout::Interleaved)); - auto globalAvgPool2dOp = builder.create( + auto globalAvgPool2dOp = GlobalAvgPool2dOp::create(builder, builder.getUnknownLoc(), output.getType(), input); globalAvgPool2dOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); @@ -3407,8 +3412,8 @@ TEST_F(OpModelBase, LeakyReluOp) { // Convert float value to APFloat object llvm::APFloat slopeAPF(slope); - LeakyReluOp leakyReluOp = builder.create( - builder.getUnknownLoc(), outputType, input, slopeAPF); + LeakyReluOp leakyReluOp = LeakyReluOp::create( + builder, builder.getUnknownLoc(), outputType, input, slopeAPF); leakyReluOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(leakyReluOp.getOperation()); @@ -3438,9 +3443,9 @@ TEST_F(OpModelBase, GeluBackwardOp) { auto gradNone = createEmptyTensor(tensorShape); auto outputTypeNone = createRankedTensorType(tensorShape); - GeluBackwardOp geluBackwardOpNone = builder.create( - builder.getUnknownLoc(), outputTypeNone, gradNone, inputNone, nullptr, - nullptr, builder.getStringAttr("none")); + GeluBackwardOp geluBackwardOpNone = GeluBackwardOp::create( + builder, builder.getUnknownLoc(), outputTypeNone, gradNone, inputNone, + nullptr, nullptr, builder.getStringAttr("none")); geluBackwardOpNone->setAttr(ttcore::DeviceAttr::name, 
getFakeDeviceAttr()); auto constraintsExpNone = getOpConstraints(geluBackwardOpNone.getOperation()); @@ -3465,9 +3470,9 @@ TEST_F(OpModelBase, GeluBackwardOp) { auto gradTanh = createEmptyTensor(tensorShape); auto outputTypeTanh = createRankedTensorType(tensorShape); - GeluBackwardOp geluBackwardOpTanh = builder.create( - builder.getUnknownLoc(), outputTypeTanh, gradTanh, inputTanh, nullptr, - nullptr, builder.getStringAttr("tanh")); + GeluBackwardOp geluBackwardOpTanh = GeluBackwardOp::create( + builder, builder.getUnknownLoc(), outputTypeTanh, gradTanh, inputTanh, + nullptr, nullptr, builder.getStringAttr("tanh")); geluBackwardOpTanh->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExpTanh = getOpConstraints(geluBackwardOpTanh.getOperation()); @@ -3504,8 +3509,9 @@ TEST_F(OpModelBase, clampScalarOp) { llvm::APFloat minValAPF(minVal); llvm::APFloat maxValAPF(maxVal); - ClampScalarOp clampScalarOp = builder.create( - builder.getUnknownLoc(), outputType, input, minValAPF, maxValAPF); + ClampScalarOp clampScalarOp = + ClampScalarOp::create(builder, builder.getUnknownLoc(), outputType, input, + minValAPF, maxValAPF); clampScalarOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(clampScalarOp.getOperation()); @@ -3536,8 +3542,8 @@ TEST_F(OpModelBase, clampTensorOp) { auto max = createEmptyTensor(tensorShape); auto outputType = createRankedTensorType(tensorShape); - ClampTensorOp clampTensorOp = builder.create( - builder.getUnknownLoc(), outputType, input, min, max); + ClampTensorOp clampTensorOp = ClampTensorOp::create( + builder, builder.getUnknownLoc(), outputType, input, min, max); clampTensorOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(clampTensorOp.getOperation()); @@ -3566,8 +3572,8 @@ TEST_F(OpModelBase, permuteOp) { auto input = createEmptyTensor(inputShape); auto outputType = createRankedTensorType(outputShape); - PermuteOp permuteOp = 
builder.create( - builder.getUnknownLoc(), outputType, input, + PermuteOp permuteOp = PermuteOp::create( + builder, builder.getUnknownLoc(), outputType, input, llvm::ArrayRef({0, 3, 1, 2}), nullptr, llvm::APFloat(0.0f)); permuteOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); @@ -3614,8 +3620,8 @@ TEST_F(OpModelBase, upsampleOp) { mlir::StringAttr modeAttr = builder.getStringAttr(mode); UpsampleOp upsampleOp = - builder.create(builder.getUnknownLoc(), outputType, input, - scaleFactorAttr, modeAttr, nullptr); + UpsampleOp::create(builder, builder.getUnknownLoc(), outputType, input, + scaleFactorAttr, modeAttr, nullptr); upsampleOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); // getOutputLayout() hardcodes L1, so we cannot use it @@ -3654,8 +3660,9 @@ TEST_F(OpModelBase, EmbeddingOpInterface) { auto outputType = createRankedTensorType(outputShape); // Create EmbeddingOp - auto embedding = builder.create( - builder.getUnknownLoc(), outputType, mlir::ValueRange{input, weight}); + auto embedding = + EmbeddingOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{input, weight}); // Test EmbeddingOp interface constraints auto constraintsExp = getOpConstraints(embedding.getOperation()); @@ -3694,8 +3701,9 @@ TEST_F(OpModelBase, EmbeddingOpNullOutputLayout) { auto outputType = createRankedTensorType(outputShape); // Create EmbeddingOp - auto embedding = builder.create( - builder.getUnknownLoc(), outputType, ::mlir::ValueRange{input, weight}); + auto embedding = + EmbeddingOp::create(builder, builder.getUnknownLoc(), outputType, + ::mlir::ValueRange{input, weight}); // Test EmbeddingOp interface constraints auto constraintsExp = embedding.getOpConstraints( @@ -3741,8 +3749,8 @@ TEST_F(OpModelBase, EmbeddingBackwardOp) { CreateTiledLayout(inGradientShape, BufferType::L1, TensorMemoryLayout::Interleaved)); - auto embeddingBackward = builder.create( - builder.getUnknownLoc(), outputType, + auto embeddingBackward = 
EmbeddingBackwardOp::create( + builder, builder.getUnknownLoc(), outputType, ::mlir::ValueRange{input, weight, inGradient}); auto constraintsExp = getOpConstraints(embeddingBackward.getOperation()); @@ -3777,8 +3785,8 @@ TEST_F(OpModelBase, CacheOpConstraintsTest) { auto input2 = createEmptyTensor(tensorShape); auto outputType = createRankedTensorType(tensorShape); - auto sub = builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{input1, input2}); + auto sub = SubtractOp::create(builder, builder.getUnknownLoc(), outputType, + mlir::ValueRange{input1, input2}); // test SubtractOp interface auto constraintsExp = getOpConstraints(sub.getOperation()); @@ -3836,15 +3844,15 @@ TEST_F(OpModelBase, CacheOpConstraintsMissesTest) { auto input1 = createEmptyTensor(tensorShape1); auto input2 = createEmptyTensor(tensorShape1); auto outputType1 = createRankedTensorType(tensorShape1); - auto add1 = builder.create(builder.getUnknownLoc(), outputType1, - mlir::ValueRange{input1, input2}); + auto add1 = AddOp::create(builder, builder.getUnknownLoc(), outputType1, + mlir::ValueRange{input1, input2}); llvm::SmallVector tensorShape2 = {workerCoresN300, 512}; auto input3 = createEmptyTensor(tensorShape2); auto input4 = createEmptyTensor(tensorShape2); auto outputType2 = createRankedTensorType(tensorShape2); - auto add2 = builder.create(builder.getUnknownLoc(), outputType2, - mlir::ValueRange{input3, input4}); + auto add2 = AddOp::create(builder, builder.getUnknownLoc(), outputType2, + mlir::ValueRange{input3, input4}); // test AddOp interface auto constraintsExp1 = getOpConstraints(add1.getOperation()); @@ -3867,9 +3875,8 @@ TEST_F(OpModelBase, WhereOpInterface) { auto input2 = createEmptyTensor(tensorShape); auto input3 = createEmptyTensor(tensorShape); auto outputType = createRankedTensorType(tensorShape); - auto where = - builder.create(builder.getUnknownLoc(), outputType, - mlir::ValueRange{input1, input2, input3}); + auto where = WhereOp::create(builder, 
builder.getUnknownLoc(), outputType, + mlir::ValueRange{input1, input2, input3}); // test WhereOp interface auto constraintsExp = getOpConstraints(where.getOperation()); @@ -3911,9 +3918,9 @@ TEST_F(OpModelBase, batchNormOp) { // BatchNormInference parameters llvm::APFloat epsilon(1e-05f); - BatchNormInferenceOp batchNormOp = builder.create( - builder.getUnknownLoc(), outputType, input, runningMean, runningVar, - epsilon, weight, bias, nullptr); + BatchNormInferenceOp batchNormOp = BatchNormInferenceOp::create( + builder, builder.getUnknownLoc(), outputType, input, runningMean, + runningVar, epsilon, weight, bias, nullptr); batchNormOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(batchNormOp.getOperation()); @@ -3965,9 +3972,9 @@ TEST_F(OpModelBase, batchNormOpL1Memory) { // BatchNorm parameters llvm::APFloat epsilon(1e-05f); - BatchNormInferenceOp batchNormOp = builder.create( - builder.getUnknownLoc(), outputType, input, runningMean, runningVar, - epsilon, weight, bias, nullptr); + BatchNormInferenceOp batchNormOp = BatchNormInferenceOp::create( + builder, builder.getUnknownLoc(), outputType, input, runningMean, + runningVar, epsilon, weight, bias, nullptr); batchNormOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(batchNormOp.getOperation()); @@ -4009,9 +4016,9 @@ TEST_F(OpModelBase, batchNormOpTraining) { llvm::APFloat epsilon(1e-05f); llvm::APFloat momentum(0.1f); - BatchNormTrainingOp batchNormTrainingOp = builder.create( - builder.getUnknownLoc(), outputType, input, runningMean, runningVar, - epsilon, momentum, weight, bias, nullptr); + BatchNormTrainingOp batchNormTrainingOp = BatchNormTrainingOp::create( + builder, builder.getUnknownLoc(), outputType, input, runningMean, + runningVar, epsilon, momentum, weight, bias, nullptr); batchNormTrainingOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = 
getOpConstraints(batchNormTrainingOp.getOperation()); @@ -4044,9 +4051,9 @@ TEST_F(OpModelBase, batchNormOpTrainingMinimal) { llvm::APFloat epsilon(1e-05f); llvm::APFloat momentum(0.1f); - BatchNormTrainingOp batchNormTrainingOp = builder.create( - builder.getUnknownLoc(), outputType, input, nullptr, nullptr, epsilon, - momentum, nullptr, nullptr, nullptr); + BatchNormTrainingOp batchNormTrainingOp = BatchNormTrainingOp::create( + builder, builder.getUnknownLoc(), outputType, input, nullptr, nullptr, + epsilon, momentum, nullptr, nullptr, nullptr); batchNormTrainingOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(batchNormTrainingOp.getOperation()); @@ -4099,9 +4106,9 @@ TEST_F(OpModelBase, batchNormOpTrainingL1Memory) { llvm::APFloat epsilon(1e-05f); llvm::APFloat momentum(0.1f); - BatchNormTrainingOp batchNormTrainingOp = builder.create( - builder.getUnknownLoc(), outputType, input, runningMean, runningVar, - epsilon, momentum, weight, bias, nullptr); + BatchNormTrainingOp batchNormTrainingOp = BatchNormTrainingOp::create( + builder, builder.getUnknownLoc(), outputType, input, runningMean, + runningVar, epsilon, momentum, weight, bias, nullptr); batchNormTrainingOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(batchNormTrainingOp.getOperation()); @@ -4141,8 +4148,8 @@ TEST_F(OpModelBase, rmsNormOp) { llvm::APFloat epsilon(1e-12f); RMSNormOp rmsNormOp = - builder.create(builder.getUnknownLoc(), outputType, input, - weight, bias, epsilon, nullptr, nullptr); + RMSNormOp::create(builder, builder.getUnknownLoc(), outputType, input, + weight, bias, epsilon, nullptr, nullptr); rmsNormOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(rmsNormOp.getOperation()); @@ -4176,8 +4183,8 @@ TEST_F(OpModelBase, rmsNormOpMinimal) { llvm::APFloat epsilon(1e-12f); RMSNormOp rmsNormOp = - builder.create(builder.getUnknownLoc(), 
outputType, input, - nullptr, nullptr, epsilon, nullptr, nullptr); + RMSNormOp::create(builder, builder.getUnknownLoc(), outputType, input, + nullptr, nullptr, epsilon, nullptr, nullptr); rmsNormOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(rmsNormOp.getOperation()); @@ -4226,8 +4233,8 @@ TEST_F(OpModelBase, rmsNormOpL1Memory) { llvm::APFloat epsilon(1e-12f); RMSNormOp rmsNormOp = - builder.create(builder.getUnknownLoc(), outputType, input, - weight, bias, epsilon, nullptr, nullptr); + RMSNormOp::create(builder, builder.getUnknownLoc(), outputType, input, + weight, bias, epsilon, nullptr, nullptr); rmsNormOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(rmsNormOp.getOperation()); @@ -4267,8 +4274,8 @@ TEST_F(OpModelBase, layerNormOp) { llvm::APFloat epsilon(1e-12f); LayerNormOp layerNormOp = - builder.create(builder.getUnknownLoc(), outputType, input, - weight, bias, epsilon, nullptr); + LayerNormOp::create(builder, builder.getUnknownLoc(), outputType, input, + weight, bias, epsilon, nullptr); layerNormOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(layerNormOp.getOperation()); @@ -4302,8 +4309,8 @@ TEST_F(OpModelBase, layerNormOpMinimal) { llvm::APFloat epsilon(1e-12f); LayerNormOp layerNormOp = - builder.create(builder.getUnknownLoc(), outputType, input, - nullptr, nullptr, epsilon, nullptr); + LayerNormOp::create(builder, builder.getUnknownLoc(), outputType, input, + nullptr, nullptr, epsilon, nullptr); layerNormOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(layerNormOp.getOperation()); @@ -4352,8 +4359,8 @@ TEST_F(OpModelBase, layerNormOpL1Memory) { llvm::APFloat epsilon(1e-12f); LayerNormOp layerNormOp = - builder.create(builder.getUnknownLoc(), outputType, input, - weight, bias, epsilon, nullptr); + LayerNormOp::create(builder, builder.getUnknownLoc(), 
outputType, input, + weight, bias, epsilon, nullptr); layerNormOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); auto constraintsExp = getOpConstraints(layerNormOp.getOperation()); @@ -4398,14 +4405,14 @@ TEST_F(OpModelBase, EmptyOpInterface) { std::nullopt); // No sharding for this test // Create a device value (required for EmptyOp) - auto device = builder.create( - builder.getUnknownLoc(), builder.getType(), + auto device = ttnn::GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), ttnn::MeshShapeAttr::get(&context, 1, 1), ttnn::MeshOffsetAttr::get(&context, 0, 0)); // Create the EmptyOp with all required parameters - auto empty = builder.create( - builder.getUnknownLoc(), inputType, device, + auto empty = ttnn::EmptyOp::create( + builder, builder.getUnknownLoc(), inputType, device, ttnn::ShapeAttr::get(&context, inputTensorType.getShape()), ttcore::DataTypeAttr::get(&context, ttnnLayoutAttr.getDataType()), ttnn::LayoutAttr::get(&context, ttnnLayoutAttr.getLayout()), @@ -4438,9 +4445,9 @@ TEST_F(OpModelBase, ArangeOpInterface) { auto endAttr = builder.getI64IntegerAttr(10); auto stepAttr = builder.getI64IntegerAttr(2); - auto arange = builder.create( - builder.getUnknownLoc(), resultType, /*device=*/nullptr, startAttr, - endAttr, stepAttr, /*dtype=*/nullptr, /*layout=*/nullptr, + auto arange = ArangeOp::create( + builder, builder.getUnknownLoc(), resultType, /*device=*/nullptr, + startAttr, endAttr, stepAttr, /*dtype=*/nullptr, /*layout=*/nullptr, /*memoryConfig=*/nullptr); // test ArangeOp interface @@ -4507,16 +4514,16 @@ TEST_P(NamedFullOpModelTest, TestOpInterface) { const auto createZeros = [](OpBuilder &b, Location loc, Type type, ttnn::ShapeAttr shape) { - return b - .create(loc, type, /*device=*/nullptr, shape, /*dtype=*/nullptr, - /*layout=*/nullptr, /*memoryConfig=*/nullptr) + return ZerosOp::create(b, loc, type, /*device=*/nullptr, shape, + /*dtype=*/nullptr, + /*layout=*/nullptr, /*memoryConfig=*/nullptr) 
.getOperation(); }; const auto createOnes = [](OpBuilder &b, Location loc, Type type, ttnn::ShapeAttr shape) { - return b - .create(loc, type, /*device=*/nullptr, shape, /*dtype=*/nullptr, - /*layout=*/nullptr, /*memoryConfig=*/nullptr) + return OnesOp::create(b, loc, type, /*device=*/nullptr, shape, + /*dtype=*/nullptr, + /*layout=*/nullptr, /*memoryConfig=*/nullptr) .getOperation(); }; @@ -4539,8 +4546,8 @@ TEST_F(OpModelBase, FullOpInterface) { TensorMemoryLayout::Interleaved); auto outputType = createRankedTensorType(tensorShape, builder.getBF16Type(), layout); - auto fullInt = builder.create( - builder.getUnknownLoc(), outputType, /*device=*/nullptr, + auto fullInt = FullOp::create( + builder, builder.getUnknownLoc(), outputType, /*device=*/nullptr, ttnn::ShapeAttr::get(&context, tensorShape), builder.getI32IntegerAttr(42), /*dtype=*/nullptr, /*layout=*/nullptr, /*memoryConfig=*/nullptr); @@ -4562,8 +4569,8 @@ TEST_F(OpModelBase, FullOpInterface) { } // test FullOp interface with float fill value: - auto fullF = builder.create( - builder.getUnknownLoc(), outputType, /*device=*/nullptr, + auto fullF = FullOp::create( + builder, builder.getUnknownLoc(), outputType, /*device=*/nullptr, ttnn::ShapeAttr::get(&context, tensorShape), builder.getF32FloatAttr(0.5), /*dtype=*/nullptr, /*layout=*/nullptr, /*memoryConfig=*/nullptr); auto backendF = dyn_cast(fullF.getOperation()); @@ -4597,8 +4604,8 @@ TEST_F(OpModelBase, ConstantOpInterface) { mlir::DenseElementsAttr attr = mlir::DenseElementsAttr::get(tensorType, dataRef); - auto constant = builder.create( - builder.getUnknownLoc(), outputType, /*device=*/nullptr, attr, + auto constant = ConstantOp::create( + builder, builder.getUnknownLoc(), outputType, /*device=*/nullptr, attr, /*dtype=*/nullptr, /*layout=*/nullptr, /*memoryConfig=*/nullptr); auto backend = dyn_cast(constant.getOperation()); @@ -4636,8 +4643,8 @@ TEST_F(OpModelBase, ConstantOpInterfaceBF16) { mlir::DenseElementsAttr attr = mlir::DenseElementsAttr::get( 
tensorType, llvm::ArrayRef(bfloats)); - auto constant = builder.create( - builder.getUnknownLoc(), outputType, /*device=*/nullptr, attr, + auto constant = ConstantOp::create( + builder, builder.getUnknownLoc(), outputType, /*device=*/nullptr, attr, /*dtype=*/nullptr, /*layout=*/nullptr, /*memoryConfig=*/nullptr); auto backend = dyn_cast(constant.getOperation()); @@ -4672,8 +4679,8 @@ TEST_F(OpModelBase, ConstantOpInterfaceNullOutputLayout) { mlir::DenseElementsAttr attr = mlir::DenseElementsAttr::get(tensorType, dataRef); - auto constant = builder.create( - builder.getUnknownLoc(), outputType, /*device=*/nullptr, attr, + auto constant = ConstantOp::create( + builder, builder.getUnknownLoc(), outputType, /*device=*/nullptr, attr, /*dtype=*/nullptr, /*layout=*/nullptr, /*memoryConfig=*/nullptr); auto backend = dyn_cast(constant.getOperation()); @@ -4701,14 +4708,14 @@ TEST_F(OpModelBase, RandOpInterface) { createRankedTensorType(tensorShape, builder.getBF16Type(), layout); // Create device value using GetDeviceOp - auto device = builder.create( - builder.getUnknownLoc(), builder.getType(), + auto device = ttnn::GetDeviceOp::create( + builder, builder.getUnknownLoc(), builder.getType(), ttnn::MeshShapeAttr::get(&context, 1, 1), ttnn::MeshOffsetAttr::get(&context, 0, 0)); // Create RandOp with default parameters (low=0.0, high=1.0, seed=0) - auto randOp = builder.create( - builder.getUnknownLoc(), outputType, device, + auto randOp = RandOp::create( + builder, builder.getUnknownLoc(), outputType, device, ttnn::ShapeAttr::get(&context, tensorShape), /*low=*/nullptr, /*high=*/nullptr, /*seed=*/nullptr, /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); @@ -4730,8 +4737,8 @@ TEST_F(OpModelBase, RandOpInterface) { } // Test RandOp with custom parameters - auto randOpCustom = builder.create( - builder.getUnknownLoc(), outputType, device, + auto randOpCustom = RandOp::create( + builder, builder.getUnknownLoc(), outputType, device, 
ttnn::ShapeAttr::get(&context, tensorShape), builder.getF32FloatAttr(-1.0), // low builder.getF32FloatAttr(2.0), // high @@ -4762,8 +4769,8 @@ TEST_F(OpModelBase, DISABLED_DeallocateOpInterface) { llvm::SmallVector tensorShape = {workerCoresN300, 1024}; auto inputTensor = createEmptyTensor(tensorShape); auto deallocate = - builder.create(builder.getUnknownLoc(), inputTensor, - /*force=*/false); + DeallocateOp::create(builder, builder.getUnknownLoc(), inputTensor, + /*force=*/false); auto backend = dyn_cast(deallocate.getOperation()); auto constraintsExp = backend.getOpConstraints( @@ -4791,8 +4798,8 @@ TEST_F(OpModelBase, FillCacheOpInterface) { // Create FillCacheOp with batch_offset = 0 (no result type - it's in-place) auto fillCache = - builder.create(builder.getUnknownLoc(), cacheTensor, - inputTensor, builder.getI32IntegerAttr(0)); + FillCacheOp::create(builder, builder.getUnknownLoc(), cacheTensor, + inputTensor, builder.getI32IntegerAttr(0)); // Test OpModel interface auto backend = dyn_cast(fillCache.getOperation()); @@ -4840,9 +4847,9 @@ TEST_F(OpModelBase, UpdateCacheOpInterface) { TensorMemoryLayout::Interleaved)); // Create UpdateCacheOp with batch_offset = 0 (no result type - it's in-place) - auto updateCache = builder.create( - builder.getUnknownLoc(), cacheTensor, inputTensor, updateIndexTensor, - builder.getI32IntegerAttr(0)); + auto updateCache = UpdateCacheOp::create( + builder, builder.getUnknownLoc(), cacheTensor, inputTensor, + updateIndexTensor, builder.getI32IntegerAttr(0)); // Test OpModel interface auto backend = dyn_cast(updateCache.getOperation()); @@ -4902,9 +4909,9 @@ TEST_F(OpModelBase, PagedUpdateCacheOpInterface) { pageTableShape, mlir::IntegerType::get(&context, 32, IntegerType::Signed), pageTableLayout); - auto pagedUpdateCacheOp = builder.create( - builder.getUnknownLoc(), cacheTensor, inputTensor, updateIndexTensor, - false, pageTableTensor); + auto pagedUpdateCacheOp = PagedUpdateCacheOp::create( + builder, 
builder.getUnknownLoc(), cacheTensor, inputTensor, + updateIndexTensor, false, pageTableTensor); auto backend = dyn_cast(pagedUpdateCacheOp.getOperation()); ASSERT_TRUE(backend); @@ -4961,9 +4968,9 @@ TEST_F(OpModelBase, PagedFillCacheOpInterface) { pageTableShape, mlir::IntegerType::get(&context, 32, IntegerType::Signed), pageTableLayout); - auto pagedFillCacheOp = builder.create( - builder.getUnknownLoc(), cacheTensor, inputTensor, pageTableTensor, - batchOffsetTensor); + auto pagedFillCacheOp = + PagedFillCacheOp::create(builder, builder.getUnknownLoc(), cacheTensor, + inputTensor, pageTableTensor, batchOffsetTensor); auto backend = dyn_cast(pagedFillCacheOp.getOperation()); ASSERT_TRUE(backend); @@ -5015,8 +5022,8 @@ TEST_F(OpModelBase, QuantizeOpInterface) { auto outputType = mlir::RankedTensorType::get(outputShape, intType, int32Layout); - auto quantizeOp = builder.create( - builder.getUnknownLoc(), outputType, input, scale, zeroPoint, + auto quantizeOp = QuantizeOp::create( + builder, builder.getUnknownLoc(), outputType, input, scale, zeroPoint, builder.getI32IntegerAttr(1), // axis = 1 ttcore::DataTypeAttr::get( &context, @@ -5072,8 +5079,8 @@ TEST_F(OpModelBase, QuantizeOpInterfaceNullOutput) { auto outputType = mlir::RankedTensorType::get(outputShape, intType, int32Layout); - auto quantizeOp = builder.create( - builder.getUnknownLoc(), outputType, input, scale, zeroPoint, + auto quantizeOp = QuantizeOp::create( + builder, builder.getUnknownLoc(), outputType, input, scale, zeroPoint, builder.getI32IntegerAttr(1), // axis = 1 ttcore::DataTypeAttr::get( &context, @@ -5141,8 +5148,8 @@ TEST_F(OpModelBase, RequantizeOpInterface) { auto outputType = mlir::RankedTensorType::get(outputShape, intType, int32Layout); - auto requantizeOp = builder.create( - builder.getUnknownLoc(), outputType, input, inScale, inZeroPoint, + auto requantizeOp = RequantizeOp::create( + builder, builder.getUnknownLoc(), outputType, input, inScale, inZeroPoint, outScale, outZeroPoint, 
builder.getI32IntegerAttr(1), // axis = 1 ttcore::DataTypeAttr::get( @@ -5212,8 +5219,8 @@ TEST_F(OpModelBase, RequantizeOpInterfaceNullOutput) { auto outputType = mlir::RankedTensorType::get(outputShape, intType, int32Layout); - auto requantizeOp = builder.create( - builder.getUnknownLoc(), outputType, input, inScale, inZeroPoint, + auto requantizeOp = RequantizeOp::create( + builder, builder.getUnknownLoc(), outputType, input, inScale, inZeroPoint, outScale, outZeroPoint, builder.getI32IntegerAttr(1), // axis = 1 ttcore::DataTypeAttr::get( @@ -5270,8 +5277,8 @@ TEST_F(OpModelBase, DequantizeOpInterface) { // Create output type with BF16 data type auto outputType = createRankedTensorType(outputShape, builder.getBF16Type()); - auto dequantizeOp = builder.create( - builder.getUnknownLoc(), outputType, input, scale, zeroPoint, + auto dequantizeOp = DequantizeOp::create( + builder, builder.getUnknownLoc(), outputType, input, scale, zeroPoint, builder.getI32IntegerAttr(1), // axis = 1 ttcore::DataTypeAttr::get( &context, @@ -5328,8 +5335,8 @@ TEST_F(OpModelBase, DequantizeOpInterfaceNullOutput) { auto outputType = createRankedTensorType(outputShape, builder.getBF16Type()); - auto dequantizeOp = builder.create( - builder.getUnknownLoc(), outputType, input, scale, zeroPoint, + auto dequantizeOp = DequantizeOp::create( + builder, builder.getUnknownLoc(), outputType, input, scale, zeroPoint, builder.getI32IntegerAttr(1), // axis = 1 ttcore::DataTypeAttr::get( &context, @@ -5385,8 +5392,8 @@ TEST_F(OpModelBase, AssignOpInterface) { BufferTypeAttr::get(&context, outputLayout.getBufferType()), std::nullopt /*shardSpec*/); - auto assign = builder.create(builder.getUnknownLoc(), outputType, - input, memoryConfig, nullptr); + auto assign = AssignOp::create(builder, builder.getUnknownLoc(), outputType, + input, memoryConfig, nullptr); OpModel backend = dyn_cast(assign.getOperation()); auto constraintsExp = @@ -5431,8 +5438,8 @@ TEST_F(OpModelBase, AssignOpInterfaceL1Output) { 
BufferTypeAttr::get(&context, outputLayout.getBufferType()), std::nullopt /*shardSpec*/); - auto assign = builder.create(builder.getUnknownLoc(), outputType, - input, memoryConfig, nullptr); + auto assign = AssignOp::create(builder, builder.getUnknownLoc(), outputType, + input, memoryConfig, nullptr); OpModel backend = dyn_cast(assign.getOperation()); auto constraintsExp = @@ -5480,8 +5487,8 @@ TEST_F(OpModelBase, AssignOpInterfaceWithOutputDtype) { auto outputDtype = ttcore::DataTypeAttr::get(&context, ttcore::DataType::BFloat16); - auto assign = builder.create(builder.getUnknownLoc(), outputType, - input, memoryConfig, outputDtype); + auto assign = AssignOp::create(builder, builder.getUnknownLoc(), outputType, + input, memoryConfig, outputDtype); OpModel backend = dyn_cast(assign.getOperation()); auto constraintsExp = @@ -5512,8 +5519,8 @@ TEST_F(OpModelBase, DropoutOpInterface) { auto input = createEmptyTensor(tensorShape); auto outputType = createRankedTensorType(tensorShape); - auto dropoutOp = builder.create( - builder.getUnknownLoc(), outputType, input, + auto dropoutOp = DropoutOp::create( + builder, builder.getUnknownLoc(), outputType, input, /*prob=*/nullptr, /*scale=*/nullptr, /*seed=*/nullptr, /*use_per_device_seed=*/nullptr, /*memory_config=*/nullptr); dropoutOp->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); @@ -5534,10 +5541,10 @@ TEST_F(OpModelBase, DropoutOpInterface) { } // Test DropoutOp with custom parameters - auto dropoutOpCustom = builder.create( - builder.getUnknownLoc(), outputType, input, builder.getF32FloatAttr(0.2), - builder.getF32FloatAttr(1.25), builder.getUI32IntegerAttr(21), - builder.getBoolAttr(true), + auto dropoutOpCustom = DropoutOp::create( + builder, builder.getUnknownLoc(), outputType, input, + builder.getF32FloatAttr(0.2), builder.getF32FloatAttr(1.25), + builder.getUI32IntegerAttr(21), builder.getBoolAttr(true), /*memory_config=*/nullptr); dropoutOpCustom->setAttr(ttcore::DeviceAttr::name, getFakeDeviceAttr()); @@ 
-5611,8 +5618,8 @@ TEST_P(OpModelMeshPartitionInterfaceRuntimeTest, createEmptyTensor(inputShape, builder.getBF16Type(), inputLayout); auto outputType = createRankedTensorType(outputShape); - auto meshPartitionOp = builder.create( - builder.getUnknownLoc(), outputType, input, + auto meshPartitionOp = MeshPartitionOp::create( + builder, builder.getUnknownLoc(), outputType, input, /*dim=*/builder.getSI32IntegerAttr(p.dim), /*cluster_axis=*/builder.getUI32IntegerAttr(p.clusterAxis), /*memory_config=*/nullptr); diff --git a/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp b/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp index e5525c789d8..0aa3d75d81f 100644 --- a/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp +++ b/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp @@ -61,9 +61,9 @@ class GreedyL1InterleavedPolicyBase : public ::testing::Test { mlir::Value createEmptyTensor() { ShapeAttr shapeAttr = ShapeAttr::get(&context, getTensorShape()); - return builder.create(builder.getUnknownLoc(), - getTensorRankedType(), nullptr, shapeAttr, - nullptr, nullptr, nullptr); + return OnesOp::create(builder, builder.getUnknownLoc(), + getTensorRankedType(), nullptr, shapeAttr, nullptr, + nullptr, nullptr); } mlir::func::FuncOp createFuncOp() { @@ -75,8 +75,8 @@ class GreedyL1InterleavedPolicyBase : public ::testing::Test { auto funcType = builder.getType( mlir::TypeRange(input), mlir::TypeRange(output)); - func = builder.create(builder.getUnknownLoc(), "test", - funcType); + func = mlir::func::FuncOp::create(builder, builder.getUnknownLoc(), "test", + funcType); mlir::Block *block = func.addEntryBlock(); block->addArgument(getTensorRankedType(), builder.getUnknownLoc()); @@ -149,7 +149,7 @@ TEST_F(GreedyL1InterleavedPolicyBase, VerifyGreedyPolicy) { mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0); mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1); mlir::Operation *opA = - 
builder.create(builder.getUnknownLoc(), lhs.getType(), lhs, rhs); + AddOp::create(builder, builder.getUnknownLoc(), lhs.getType(), lhs, rhs); uint64_t outputL1Usage = 2; uint64_t requiredL1Usage = 8; prepareOpForGreedyConfigPicker(opA, outputL1Usage, requiredL1Usage, @@ -159,7 +159,7 @@ TEST_F(GreedyL1InterleavedPolicyBase, VerifyGreedyPolicy) { lhs = func.getBody().getBlocks().front().getArgument(0); rhs = func.getBody().getBlocks().front().getArgument(1); mlir::Operation *opB = - builder.create(builder.getUnknownLoc(), lhs.getType(), lhs, rhs); + AddOp::create(builder, builder.getUnknownLoc(), lhs.getType(), lhs, rhs); outputL1Usage = 3; requiredL1Usage = 7; prepareOpForGreedyConfigPicker(opB, outputL1Usage, requiredL1Usage, @@ -169,7 +169,7 @@ TEST_F(GreedyL1InterleavedPolicyBase, VerifyGreedyPolicy) { lhs = func.getBody().getBlocks().front().getArgument(0); rhs = func.getBody().getBlocks().front().getArgument(1); mlir::Operation *opC = - builder.create(builder.getUnknownLoc(), lhs.getType(), lhs, rhs); + AddOp::create(builder, builder.getUnknownLoc(), lhs.getType(), lhs, rhs); outputL1Usage = 1; requiredL1Usage = 9; prepareOpForGreedyConfigPicker(opC, outputL1Usage, requiredL1Usage, @@ -179,7 +179,7 @@ TEST_F(GreedyL1InterleavedPolicyBase, VerifyGreedyPolicy) { lhs = func.getBody().getBlocks().front().getArgument(0); rhs = func.getBody().getBlocks().front().getArgument(1); mlir::Operation *opD = - builder.create(builder.getUnknownLoc(), lhs.getType(), lhs, rhs); + AddOp::create(builder, builder.getUnknownLoc(), lhs.getType(), lhs, rhs); outputL1Usage = 4; requiredL1Usage = 0; prepareOpForGreedyConfigPicker(opD, outputL1Usage, requiredL1Usage, diff --git a/test/unittests/Optimizer/TestLegalLayoutAnalysis.cpp b/test/unittests/Optimizer/TestLegalLayoutAnalysis.cpp index 20bd48c1fef..a765158b477 100644 --- a/test/unittests/Optimizer/TestLegalLayoutAnalysis.cpp +++ b/test/unittests/Optimizer/TestLegalLayoutAnalysis.cpp @@ -53,8 +53,8 @@ class LegalLayoutAnalysisTest 
// Create a function auto funcType = builder.getFunctionType({}, {}); - func = builder.create(builder.getUnknownLoc(), - "test_func", funcType); + func = mlir::func::FuncOp::create(builder, builder.getUnknownLoc(), + "test_func", funcType); // Create a basic block in the function mlir::Block *entryBlock = func.addEntryBlock(); @@ -98,8 +98,9 @@ class LegalLayoutAnalysisTest // Create function with a test tensor type of the parameterized shape auto tensorType = createTensorType(getTensorShape(), f32Type); - auto device = builder.create( - builder.getUnknownLoc(), builder.getType(), + auto device = mlir::tt::ttnn::GetDeviceOp::create( + builder, builder.getUnknownLoc(), + builder.getType(), mlir::tt::ttnn::MeshShapeAttr::get(builder.getContext(), 1, 1), mlir::tt::ttnn::MeshOffsetAttr::get(builder.getContext(), 0, 0)); @@ -113,20 +114,20 @@ class LegalLayoutAnalysisTest std::nullopt); // Create an empty tensor - auto empty = builder.create( - builder.getUnknownLoc(), tensorType, device, + auto empty = mlir::tt::ttnn::EmptyOp::create( + builder, builder.getUnknownLoc(), tensorType, device, mlir::tt::ttnn::ShapeAttr::get(&context, getTensorShape()), mlir::tt::ttcore::DataTypeAttr::get( &context, mlir::tt::ttcore::DataType::Float32), mlir::tt::ttnn::LayoutAttr::get(&context, Layout::Tile), memConfig); // Use that tensor in a ReluOp so we have a relevant op with a tensor result - auto relu = builder.create(builder.getUnknownLoc(), - empty.getResult()); + auto relu = mlir::tt::ttnn::ReluOp::create(builder, builder.getUnknownLoc(), + empty.getResult()); // Add return op - builder.create(builder.getUnknownLoc(), - relu.getResult()); + mlir::func::ReturnOp::create(builder, builder.getUnknownLoc(), + relu.getResult()); } }; diff --git a/test/unittests/Optimizer/TestLegalTensorLayoutAnalysis.cpp b/test/unittests/Optimizer/TestLegalTensorLayoutAnalysis.cpp index 636b5f6182f..af52849cced 100644 --- a/test/unittests/Optimizer/TestLegalTensorLayoutAnalysis.cpp +++ 
b/test/unittests/Optimizer/TestLegalTensorLayoutAnalysis.cpp @@ -50,8 +50,8 @@ class LegalTensorLayoutAnalysisTest // Create a function auto funcType = builder.getFunctionType({}, {}); - func = builder.create(builder.getUnknownLoc(), - "test_func", funcType); + func = mlir::func::FuncOp::create(builder, builder.getUnknownLoc(), + "test_func", funcType); // Create a basic block in the function mlir::Block *entryBlock = func.addEntryBlock(); @@ -91,8 +91,9 @@ class LegalTensorLayoutAnalysisTest // Create function with a test tensor type of the parameterized shape auto tensorType = createTensorType(getTensorShape(), f32Type); - auto device = builder.create( - builder.getUnknownLoc(), builder.getType(), + auto device = mlir::tt::ttnn::GetDeviceOp::create( + builder, builder.getUnknownLoc(), + builder.getType(), mlir::tt::ttnn::MeshShapeAttr::get(builder.getContext(), 1, 1), mlir::tt::ttnn::MeshOffsetAttr::get(builder.getContext(), 0, 0)); @@ -106,15 +107,15 @@ class LegalTensorLayoutAnalysisTest std::nullopt); // Create an empty op with all required parameters - builder.create( - builder.getUnknownLoc(), tensorType, device, + mlir::tt::ttnn::EmptyOp::create( + builder, builder.getUnknownLoc(), tensorType, device, mlir::tt::ttnn::ShapeAttr::get(&context, getTensorShape()), mlir::tt::ttcore::DataTypeAttr::get( &context, mlir::tt::ttcore::DataType::Float32), mlir::tt::ttnn::LayoutAttr::get(&context, Layout::Tile), memConfig); // Add return op - builder.create(builder.getUnknownLoc()); + mlir::func::ReturnOp::create(builder, builder.getUnknownLoc()); } }; diff --git a/test/unittests/Optimizer/TestShardSolver.cpp b/test/unittests/Optimizer/TestShardSolver.cpp index 05e6e351e9d..18058316e2b 100644 --- a/test/unittests/Optimizer/TestShardSolver.cpp +++ b/test/unittests/Optimizer/TestShardSolver.cpp @@ -58,9 +58,9 @@ class ShardSolverBase : public ::testing::Test { mlir::Value createEmptyTensor() { ShapeAttr shapeAttr = ShapeAttr::get(&context, getTensorShape()); - return 
builder.create(builder.getUnknownLoc(), - getTensorRankedType(), nullptr, shapeAttr, - nullptr, nullptr, nullptr); + return OnesOp::create(builder, builder.getUnknownLoc(), + getTensorRankedType(), nullptr, shapeAttr, nullptr, + nullptr, nullptr); } mlir::func::FuncOp createFuncOp() { @@ -72,8 +72,8 @@ class ShardSolverBase : public ::testing::Test { auto funcType = builder.getType( mlir::TypeRange(input), mlir::TypeRange(output)); - func = builder.create(builder.getUnknownLoc(), "test", - funcType); + func = mlir::func::FuncOp::create(builder, builder.getUnknownLoc(), "test", + funcType); mlir::Block *block = func.addEntryBlock(); block->addArgument(getTensorRankedType(), builder.getUnknownLoc()); block->addArgument(getTensorRankedType(), builder.getUnknownLoc()); @@ -157,7 +157,7 @@ TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) { mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0); mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1); mlir::Operation *op = - builder.create(builder.getUnknownLoc(), lhs.getType(), lhs, rhs); + AddOp::create(builder, builder.getUnknownLoc(), lhs.getType(), lhs, rhs); mlir::Operation *firstOp = op; prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps); @@ -169,7 +169,7 @@ TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) { TensorMemoryLayout::BlockSharded, 2, 2); rhs = op->getResult(0); - op = builder.create(builder.getUnknownLoc(), rhs.getType(), rhs); + op = ReluOp::create(builder, builder.getUnknownLoc(), rhs.getType(), rhs); prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps); addConfigForOp(op, legalConfigs, BufferType::L1, TensorMemoryLayout::WidthSharded, 1, 8); @@ -181,7 +181,7 @@ TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) { lhs = func.getBody().getBlocks().front().getArgument(0); rhs = op->getResult(0); - op = builder.create(builder.getUnknownLoc(), lhs.getType(), lhs, rhs); + op = AddOp::create(builder, builder.getUnknownLoc(), lhs.getType(), lhs, rhs); 
prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps); addConfigForOp(op, legalConfigs, BufferType::L1, TensorMemoryLayout::WidthSharded, 1, 4); @@ -190,7 +190,7 @@ TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) { addConfigForOp(op, legalConfigs, BufferType::L1, TensorMemoryLayout::BlockSharded, 1, 1); - op = builder.create(builder.getUnknownLoc(), lhs.getType(), lhs, rhs); + op = AddOp::create(builder, builder.getUnknownLoc(), lhs.getType(), lhs, rhs); prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps); addConfigForOp(op, legalConfigs, BufferType::L1, TensorMemoryLayout::WidthSharded, 1, 4); @@ -201,7 +201,7 @@ TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) { lhs = opL1MemSpecs[opL1MemSpecs.size() - 2].op->getResult(0); rhs = opL1MemSpecs[opL1MemSpecs.size() - 1].op->getResult(0); - op = builder.create(builder.getUnknownLoc(), lhs.getType(), lhs, rhs); + op = AddOp::create(builder, builder.getUnknownLoc(), lhs.getType(), lhs, rhs); prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps); addConfigForOp(op, legalConfigs, BufferType::L1, TensorMemoryLayout::WidthSharded, 1, 2); @@ -211,7 +211,7 @@ TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) { TensorMemoryLayout::BlockSharded, 1, 1); rhs = op->getResult(0); - op = builder.create(builder.getUnknownLoc(), rhs.getType(), rhs); + op = ReluOp::create(builder, builder.getUnknownLoc(), rhs.getType(), rhs); prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps); addConfigForOp(op, legalConfigs, BufferType::L1, TensorMemoryLayout::WidthSharded, 1, 2); diff --git a/test/unittests/Support/TTPrintIRInstrumentationTest.cpp b/test/unittests/Support/TTPrintIRInstrumentationTest.cpp index 00b72198f5f..71c8662e25e 100644 --- a/test/unittests/Support/TTPrintIRInstrumentationTest.cpp +++ b/test/unittests/Support/TTPrintIRInstrumentationTest.cpp @@ -80,7 +80,7 @@ class Utils { auto i32Type = builder.getI32Type(); auto funcType = builder.getFunctionType({i32Type, i32Type}, i32Type); auto funcOp = - 
builder.create(loc, "test_func", funcType); + mlir::func::FuncOp::create(builder, loc, "test_func", funcType); funcOp.setPrivate(); auto &block = funcOp.getBody().emplaceBlock(); @@ -92,8 +92,8 @@ class Utils { auto arg0 = block.getArgument(0); auto arg1 = block.getArgument(1); - auto addResult = builder.create(loc, arg0, arg1); - builder.create(loc, addResult.getResult()); + auto addResult = mlir::arith::AddIOp::create(builder, loc, arg0, arg1); + mlir::func::ReturnOp::create(builder, loc, addResult.getResult()); module.push_back(funcOp); return module; diff --git a/test/unittests/TestScheduler/TestScheduler.cpp b/test/unittests/TestScheduler/TestScheduler.cpp index 8545a17d8b1..8d352949a26 100644 --- a/test/unittests/TestScheduler/TestScheduler.cpp +++ b/test/unittests/TestScheduler/TestScheduler.cpp @@ -68,8 +68,9 @@ class SchedulerBase : public ::testing::Test { } mlir::Value createEmptyTensor() { - return builder.create( - builder.getUnknownLoc(), getTensorShape(), builder.getF32Type()); + return mlir::tt::ttir::EmptyOp::create(builder, builder.getUnknownLoc(), + getTensorShape(), + builder.getF32Type()); } mlir::func::FuncOp createFuncOp() { @@ -81,8 +82,8 @@ class SchedulerBase : public ::testing::Test { auto funcType = builder.getType( mlir::TypeRange(input), mlir::TypeRange(output)); - func = builder.create(builder.getUnknownLoc(), "test", - funcType); + func = mlir::func::FuncOp::create(builder, builder.getUnknownLoc(), "test", + funcType); mlir::Block *block = func.addEntryBlock(); block->addArgument(getTensorType(), builder.getUnknownLoc()); @@ -104,8 +105,8 @@ TEST_F(SchedulerBase, FixedSchedule) { mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1); // First operation has arg1 and arg2 (no DPS operand needed) - ttir::TTIROp op = builder.create(builder.getUnknownLoc(), - getTensorType(), lhs, rhs); + ttir::TTIROp op = ttir::AddOp::create(builder, builder.getUnknownLoc(), + getTensorType(), lhs, rhs); // Create a chain of operations by 
using the result of the previous operation llvm::SmallVector operands = {rhs, @@ -118,8 +119,8 @@ TEST_F(SchedulerBase, FixedSchedule) { for (std::size_t i = 1; i < NumberOfOps; i++) { mlir::Value lhs = operands[operands.size() - 2]; mlir::Value rhs = operands[operands.size() - 1]; - op = builder.create(builder.getUnknownLoc(), getTensorType(), - lhs, rhs); + op = ttir::AddOp::create(builder, builder.getUnknownLoc(), getTensorType(), + lhs, rhs); operands.push_back(op.getOperation()->getResult(0)); ops.push_back(op); } @@ -153,8 +154,8 @@ TEST_F(SchedulerBase, SingleOp) { mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1); // First operation has arg1 and arg2 (no DPS operand needed) - ttir::TTIROp op = builder.create(builder.getUnknownLoc(), - getTensorType(), lhs, rhs); + ttir::TTIROp op = ttir::AddOp::create(builder, builder.getUnknownLoc(), + getTensorType(), lhs, rhs); mlir::tt::scheduler::Scheduler scheduler(&func); ASSERT_TRUE(scheduler.hasUnscheduledOps()); @@ -176,8 +177,8 @@ TEST_F(SchedulerBase, VerifyFork) { // Create the first operation which works on arg1 and arg2 mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0); mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1); - ttir::TTIROp op = builder.create(builder.getUnknownLoc(), - getTensorType(), lhs, rhs); + ttir::TTIROp op = ttir::AddOp::create(builder, builder.getUnknownLoc(), + getTensorType(), lhs, rhs); std::vector ops; ops.push_back(op); @@ -187,19 +188,19 @@ TEST_F(SchedulerBase, VerifyFork) { // Create the second operation which works on the result of the first // operation and arg1 - op = builder.create(builder.getUnknownLoc(), getTensorType(), - lhs, rhs); + op = ttir::AddOp::create(builder, builder.getUnknownLoc(), getTensorType(), + lhs, rhs); ops.push_back(op); - op = builder.create(builder.getUnknownLoc(), getTensorType(), - lhs, rhs); + op = ttir::AddOp::create(builder, builder.getUnknownLoc(), getTensorType(), + lhs, rhs); 
ops.push_back(op); // Create the third operation which works on the result of the second and // third operation lhs = ops[ops.size() - 2].getOperation()->getResult(0); rhs = ops[ops.size() - 1].getOperation()->getResult(0); - op = builder.create(builder.getUnknownLoc(), getTensorType(), - lhs, rhs); + op = ttir::AddOp::create(builder, builder.getUnknownLoc(), getTensorType(), + lhs, rhs); ops.push_back(op); mlir::tt::scheduler::Scheduler scheduler(&func); @@ -241,8 +242,8 @@ TEST_F(SchedulerBase, SplitQueryKeyValueAndSplitHeadsOp) { llvm::SmallVector outputShape{batchSize, numHeads, sequenceSize, headDim}; - mlir::Value inputTensor = builder.create( - builder.getUnknownLoc(), inputShape, builder.getF32Type()); + mlir::Value inputTensor = ttir::EmptyOp::create( + builder, builder.getUnknownLoc(), inputShape, builder.getF32Type()); mlir::Type queryType = mlir::RankedTensorType::get(outputShape, builder.getF32Type()); @@ -251,20 +252,21 @@ TEST_F(SchedulerBase, SplitQueryKeyValueAndSplitHeadsOp) { mlir::Type valueType = mlir::RankedTensorType::get(outputShape, builder.getF32Type()); - auto splitOp = builder.create( - builder.getUnknownLoc(), queryType, keyType, valueType, inputTensor, + auto splitOp = ttir::SplitQueryKeyValueAndSplitHeadsOp::create( + builder, builder.getUnknownLoc(), queryType, keyType, valueType, + inputTensor, /*kv_input_tensor=*/nullptr, builder.getUI32IntegerAttr(numHeads), /*num_kv_heads=*/nullptr, builder.getBoolAttr(false)); auto outputType = mlir::RankedTensorType::get(getTensorShape(), builder.getF32Type()); mlir::Value arg0 = func.getBody().getBlocks().front().getArgument(0); - auto queryConsumerOp = builder.create( - builder.getUnknownLoc(), outputType, splitOp.getQuery(), arg0); - auto keyConsumerOp = builder.create( - builder.getUnknownLoc(), outputType, splitOp.getKey(), arg0); - auto valueConsumerOp = builder.create( - builder.getUnknownLoc(), outputType, splitOp.getValue(), arg0); + auto queryConsumerOp = ttir::AddOp::create( + 
builder, builder.getUnknownLoc(), outputType, splitOp.getQuery(), arg0); + auto keyConsumerOp = ttir::AddOp::create(builder, builder.getUnknownLoc(), + outputType, splitOp.getKey(), arg0); + auto valueConsumerOp = ttir::AddOp::create( + builder, builder.getUnknownLoc(), outputType, splitOp.getValue(), arg0); mlir::tt::scheduler::Scheduler scheduler(&func); llvm::SmallVector schedulableOps = diff --git a/test/unittests/Validation/TestOpConstraintValidation.cpp b/test/unittests/Validation/TestOpConstraintValidation.cpp index 0024d9bb7d8..ac1438bfada 100644 --- a/test/unittests/Validation/TestOpConstraintValidation.cpp +++ b/test/unittests/Validation/TestOpConstraintValidation.cpp @@ -95,19 +95,19 @@ class OpConstraintValidationTest : public ::testing::Test { mlir::RankedTensorType::get(inputShape, builder.getBF16Type(), layout); // Create two input tensors using OnesOp (simpler than EmptyOp) - auto input1 = builder.create( - builder.getUnknownLoc(), tensorType, + auto input1 = OnesOp::create( + builder, builder.getUnknownLoc(), tensorType, /*device=*/nullptr, ShapeAttr::get(&context, inputShape), /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); - auto input2 = builder.create( - builder.getUnknownLoc(), tensorType, + auto input2 = OnesOp::create( + builder, builder.getUnknownLoc(), tensorType, /*device=*/nullptr, ShapeAttr::get(&context, inputShape), /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); // Create AddOp - return builder.create(builder.getUnknownLoc(), tensorType, - input1.getResult(), input2.getResult()); + return AddOp::create(builder, builder.getUnknownLoc(), tensorType, + input1.getResult(), input2.getResult()); } // Helper to create OpConfig for testing @@ -174,8 +174,8 @@ TEST_F(OpConstraintValidationTest, UpdateCacheOpWithInvalidUpdateIndexType) { TensorMemoryLayout::Interleaved); auto cacheTensorType = mlir::RankedTensorType::get( cacheShape, builder.getBF16Type(), cacheLayout); - auto cacheOp = builder.create( - 
builder.getUnknownLoc(), cacheTensorType, + auto cacheOp = OnesOp::create( + builder, builder.getUnknownLoc(), cacheTensorType, /*device=*/nullptr, ShapeAttr::get(&context, cacheShape), /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); @@ -185,8 +185,8 @@ TEST_F(OpConstraintValidationTest, UpdateCacheOpWithInvalidUpdateIndexType) { TensorMemoryLayout::Interleaved); auto inputTensorType = mlir::RankedTensorType::get( inputShape, builder.getBF16Type(), inputLayout); - auto inputOp = builder.create( - builder.getUnknownLoc(), inputTensorType, + auto inputOp = OnesOp::create( + builder, builder.getUnknownLoc(), inputTensorType, /*device=*/nullptr, ShapeAttr::get(&context, inputShape), /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); @@ -197,15 +197,15 @@ TEST_F(OpConstraintValidationTest, UpdateCacheOpWithInvalidUpdateIndexType) { TensorMemoryLayout::Interleaved); auto updateIndexTensorType = mlir::RankedTensorType::get( updateIndexShape, builder.getBF16Type(), updateIndexLayout); - auto updateIndexOp = builder.create( - builder.getUnknownLoc(), updateIndexTensorType, + auto updateIndexOp = OnesOp::create( + builder, builder.getUnknownLoc(), updateIndexTensorType, /*device=*/nullptr, ShapeAttr::get(&context, updateIndexShape), /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); // Create UpdateCacheOp (inplace operation, no result type) - auto updateCacheOp = builder.create( - builder.getUnknownLoc(), cacheOp.getResult(), inputOp.getResult(), - updateIndexOp.getResult(), /*batch_offset=*/0); + auto updateCacheOp = ttnn::UpdateCacheOp::create( + builder, builder.getUnknownLoc(), cacheOp.getResult(), + inputOp.getResult(), updateIndexOp.getResult(), /*batch_offset=*/0); // Extract layouts and create config auto layouts = ttnn::utils::extractInputLayouts(updateCacheOp); @@ -228,15 +228,15 @@ TEST_F(OpConstraintValidationTest, UpdateCacheOpWithInvalidUpdateIndexType) { TensorMemoryLayout::Interleaved); auto 
uint32UpdateIndexTensorType = mlir::RankedTensorType::get( updateIndexShape, uint32Type, uint32UpdateIndexLayout); - auto uint32UpdateIndexOp = builder.create( - builder.getUnknownLoc(), uint32UpdateIndexTensorType, + auto uint32UpdateIndexOp = OnesOp::create( + builder, builder.getUnknownLoc(), uint32UpdateIndexTensorType, /*device=*/nullptr, ShapeAttr::get(&context, updateIndexShape), /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); // Create UpdateCacheOp with correct uint32 type - auto validUpdateCacheOp = builder.create( - builder.getUnknownLoc(), cacheOp.getResult(), inputOp.getResult(), - uint32UpdateIndexOp.getResult(), /*batch_offset=*/0); + auto validUpdateCacheOp = ttnn::UpdateCacheOp::create( + builder, builder.getUnknownLoc(), cacheOp.getResult(), + inputOp.getResult(), uint32UpdateIndexOp.getResult(), /*batch_offset=*/0); // Extract layouts and validate auto validLayouts = ttnn::utils::extractInputLayouts(validUpdateCacheOp); @@ -256,10 +256,10 @@ TEST_F(OpConstraintValidationTest, ValidationStatusNotImplemented) { auto tensorType = mlir::RankedTensorType::get(tensorShape, builder.getBF16Type(), layout); - auto allocOp = builder.create( - builder.getUnknownLoc(), tensorType, builder.getI64IntegerAttr(0), - builder.getI64IntegerAttr(2048), - BufferTypeAttr::get(&context, BufferType::L1)); + auto allocOp = AllocOp::create(builder, builder.getUnknownLoc(), tensorType, + builder.getI64IntegerAttr(0), + builder.getI64IntegerAttr(2048), + BufferTypeAttr::get(&context, BufferType::L1)); auto layouts = ttnn::utils::extractInputLayouts(allocOp); OpConfig config = createTestConfig(); @@ -303,8 +303,8 @@ TEST_F(OpConstraintValidationTest, ValidationStatusMetalBackendError) { auto inputTensorType = mlir::RankedTensorType::get( tensorShape, builder.getBF16Type(), inputLayout); - auto input = builder.create( - builder.getUnknownLoc(), inputTensorType, + auto input = OnesOp::create( + builder, builder.getUnknownLoc(), inputTensorType, 
/*device=*/nullptr, ShapeAttr::get(&context, tensorShape), /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); @@ -315,8 +315,8 @@ TEST_F(OpConstraintValidationTest, ValidationStatusMetalBackendError) { tensorShape, builder.getBF16Type(), outputLayout); // Create ToLayoutOp with incompatible input/output layouts - auto toLayoutOp = builder.create( - builder.getUnknownLoc(), outputTensorType, input.getResult(), + auto toLayoutOp = ToLayoutOp::create( + builder, builder.getUnknownLoc(), outputTensorType, input.getResult(), LayoutAttr::get(&context, Layout::RowMajor), // ttcore::DataTypeAttr::get(&context, ttcore::DataType::BFloat16), /*dtype=*/nullptr, @@ -351,18 +351,18 @@ TEST_F(OpConstraintValidationTest, ValidationStatusOutOfMemoryError) { auto tensorType = mlir::RankedTensorType::get(largeShape, builder.getBF16Type(), layout); - auto input1 = builder.create( - builder.getUnknownLoc(), tensorType, + auto input1 = OnesOp::create( + builder, builder.getUnknownLoc(), tensorType, /*device=*/nullptr, ShapeAttr::get(&context, largeShape), /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); - auto input2 = builder.create( - builder.getUnknownLoc(), tensorType, + auto input2 = OnesOp::create( + builder, builder.getUnknownLoc(), tensorType, /*device=*/nullptr, ShapeAttr::get(&context, largeShape), /*dtype=*/nullptr, /*layout=*/nullptr, /*memory_config=*/nullptr); - auto addOp = builder.create(builder.getUnknownLoc(), tensorType, - input1.getResult(), input2.getResult()); + auto addOp = AddOp::create(builder, builder.getUnknownLoc(), tensorType, + input1.getResult(), input2.getResult()); auto layouts = ttnn::utils::extractInputLayouts(addOp); OpConfig config(layout, OpConfig::OpSpecificAttrs{}); diff --git a/tools/builder/stablehlo/stablehlo_builder.py b/tools/builder/stablehlo/stablehlo_builder.py index 26c161b9f4c..20751a3d1a1 100644 --- a/tools/builder/stablehlo/stablehlo_builder.py +++ b/tools/builder/stablehlo/stablehlo_builder.py @@ 
-15,7 +15,7 @@ import math from ttmlir.ir import * -from ttmlir.dialects import stablehlo, sdy, mpmd, func +from ttmlir.dialects import stablehlo, sdy, func from builder.base.builder import * from builder.base.builder_utils import * @@ -8010,36 +8010,6 @@ def sdy_all_gather( return op_result - # ----- Experimental Mpmd Attribute Generators ---- - - def experimental_named_mesh_attr( - self, - name: str, - mesh_attr: sdy.MeshAttr, - ) -> mpmd.NamedMeshAttr: - return mpmd.NamedMeshAttr.get(name, mesh_attr) - - def experimental_topology_attr( - self, - meshes: List[mpmd.NamedMeshAttr], - ) -> mpmd.TopologyAttr: - return mpmd.TopologyAttr.get(meshes) - - def experimental_user_origin_attr( - self, - user_name: str, - transpose_count: int = 0, - ) -> mpmd.UserOriginAttr: - return mpmd.UserOriginAttr.get( - user_name=user_name, transpose_count=transpose_count - ) - - def experimental_origin_attr( - self, - origin_label: str, - ) -> mpmd.OriginAttr: - return mpmd.OriginAttr.get(origin_label=origin_label) - # ----- Parse stablehlo module ---- @staticmethod