Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/get-docker-tag.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@

# Calculate hash from the following files. This hash is used to tag the docker images.
# Any change in these files will result in a new docker image build
DOCKERFILE_HASH_FILES=".github/Dockerfile.base .github/Dockerfile.ci .github/Dockerfile.ird .github/Dockerfile.cibuildwheel env/CMakeLists.txt env/init_venv.sh env/build-requirements.txt env/ttnn-requirements.txt env/patches/shardy.patch env/patches/shardy_mpmd_pybinds.patch test/python/requirements.txt env/install-tt-triage.sh"
DOCKERFILE_HASH_FILES=".github/Dockerfile.base .github/Dockerfile.ci .github/Dockerfile.ird .github/Dockerfile.cibuildwheel env/CMakeLists.txt env/init_venv.sh env/build-requirements.txt env/ttnn-requirements.txt env/patches/shardy.patch test/python/requirements.txt env/install-tt-triage.sh"
DOCKERFILE_HASH=$(sha256sum $DOCKERFILE_HASH_FILES | sha256sum | cut -d ' ' -f 1)
echo dt-$DOCKERFILE_HASH
10 changes: 5 additions & 5 deletions env/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ cmake_minimum_required(VERSION 3.20.0)
project(ttmlir-toolchain LANGUAGES CXX C)

set(FLATBUFFERS_VERSION "fb9afbafc7dfe226b9db54d4923bfb8839635274")
set(LLVM_PROJECT_VERSION "4efe170d858eb54432f520abb4e7f0086236748b")
set(STABLEHLO_VERSION "0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d")
set(SHARDY_VERSION "edfd6730ddfc39da5fbea8b6b202357fdf1cdb90")
set(LLVM_PROJECT_VERSION "1053047a4be7d1fece3adaf5e7597f838058c947")
set(STABLEHLO_VERSION "43550117525e77084ac1f83ba50febc0688f7958")
set(SHARDY_VERSION "4069e470c94d7b08f129ef1607aa2be8f4c06b53")
set(LLVM_BUILD_TYPE MinSizeRel CACHE STRING "Build type for LLVM")

include(ExternalProject)
Expand Down Expand Up @@ -52,7 +52,7 @@ if(TTMLIR_BUILD_LLVM)
llvm-project
# Super hacky way to install the python dependencies before the build
# Sync nanobind with tt-metal's pinned version to avoid ODR violations
PATCH_COMMAND bash -c "source ${CMAKE_CURRENT_SOURCE_DIR}/activate && pip install -r mlir/python/requirements.txt && pip install --force-reinstall 'nanobind==2.10.2' && git config user.email \"tt-mlir@tenstorrent.com\" && git config user.name \"tenstorrent\" && git apply --index \"${CMAKE_CURRENT_LIST_DIR}/patches/affine-allow-symbol-vars.patch\" && git commit -m \"tt-mlir related patch\""
PATCH_COMMAND bash -c "source ${CMAKE_CURRENT_SOURCE_DIR}/activate && pip install -r mlir/python/requirements.txt && pip install --force-reinstall 'nanobind==2.10.2' && git config user.email \"tt-mlir@tenstorrent.com\" && git config user.name \"tenstorrent\" "
CMAKE_GENERATOR Ninja
CMAKE_ARGS
-DCMAKE_BUILD_TYPE=${LLVM_BUILD_TYPE}
Expand Down Expand Up @@ -97,7 +97,7 @@ ExternalProject_Add(shardy
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
PATCH_COMMAND git config user.email "tt-mlir@tenstorrent.com" && git config user.name "tenstorrent" && git apply --index "${CMAKE_CURRENT_LIST_DIR}/patches/shardy.patch" && git commit -m "tt-mlir related patch" && git apply --index "${CMAKE_CURRENT_LIST_DIR}/patches/shardy_mpmd_pybinds.patch" && git commit -m "tt-mlir related patch"
PATCH_COMMAND git config user.email "tt-mlir@tenstorrent.com" && git config user.name "tenstorrent" && git apply --index "${CMAKE_CURRENT_LIST_DIR}/patches/shardy.patch" && git commit -m "tt-mlir shardy.patch"
)

if(TTMLIR_BUILD_LLVM)
Expand Down
128 changes: 15 additions & 113 deletions env/patches/shardy.patch
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ new file mode 100755
index 0000000..f23330d
--- /dev/null
+++ b/shardy/dialect/sdy/transforms/export/CMakeLists.txt
@@ -0,0 +1,94 @@
@@ -0,0 +1,96 @@
+# Shardy MLIR Transform Export Passes and Pipeline
+
+set(LLVM_TARGET_DEFINITIONS passes.td)
Expand Down Expand Up @@ -318,6 +318,8 @@ index 0000000..f23330d
+ sharding_constraint_to_reshard.cc
+ sink_data_flow_edges.cc
+ update_non_divisible_input_output_shardings.cc
+ convert_global_to_local.cc
+ remove_sub_axes_in_input_output_shardings.cc
+
+ DEPENDS
+ SdyExplicitReshardsUtil
Expand Down Expand Up @@ -635,106 +637,6 @@ index 0000000..dc33501
+ SdyDialect
+ SdyTransformsPropagationShardingProjection
+)
diff --git a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
index a73917b..78cc827 100644
--- a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
+++ b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
@@ -346,6 +346,7 @@ void saveShardingOriginsOnModule(
// the case for the target of the edge, because if the source appears multiple
// times, then it's because it effects multiple other operands/results in the
// op.
+[[maybe_unused]]
bool insertSeenValue(Operation* op, const PropagationEdge& edge,
llvm::SmallDenseSet<Value>& seenValues) {
EdgeNode target = edge.target;
diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
index ab93067..b181797 100644
--- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
@@ -187,12 +187,12 @@ void addGatherScatterFactors(

// We add factors for all collapsed slice dimensions.
for (int64_t collapsedSliceDim : collapsedSliceDims) {
- // To keep the operand dim sharded for gather, we need an all-reduce on the
- // result.
+ // For gather: use kNeedReplication to all-gather operand before gather,
+ // instead of kReduction which would insert all-reduce after gather.
addUnblockedFactorFn(
collapsedSliceDim, /*indicesDim=*/kNullDim,
/*slicesDim=*/kNullDim, inputType.getDimSize(collapsedSliceDim),
- isScatter ? FactorType::kPassThrough : FactorType::kReduction);
+ isScatter ? FactorType::kPassThrough : FactorType::kNeedReplication);
}

// Add a factor for the index-vector-dim, if it's present.
@@ -303,6 +303,37 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
}
return builder.build();
})
+ .Case<stablehlo::BatchNormInferenceOp>(
+ [conservativePropagation](stablehlo::BatchNormInferenceOp bn) {
+ auto inTy = llvm::cast<mlir::RankedTensorType>(bn.getOperand().getType());
+ auto outTy = llvm::cast<mlir::RankedTensorType>(bn.getResult().getType());
+
+ OpShardingRuleBuilder builder(bn);
+
+ const int64_t numOperands = static_cast<int64_t>(bn->getNumOperands());
+ llvm::SmallVector<int64_t> opDims(numOperands, kNullDim);
+
+ for (auto [dU, dimSize] : llvm::enumerate(inTy.getShape())) {
+ const int64_t d = static_cast<int64_t>(dU);
+ std::fill(opDims.begin(), opDims.end(), kNullDim);
+ opDims[0] = d;
+ builder.addFactor(opDims, d, dimSize);
+ }
+
+ const int64_t featAxis = static_cast<int64_t>(bn.getFeatureIndex());
+ const int64_t C = outTy.getDimSize(featAxis);
+
+ for (int64_t paramIdx : {1LL, 2LL, 3LL, 4LL}) {
+ std::fill(opDims.begin(), opDims.end(), kNullDim);
+ opDims[paramIdx] = 0;
+ auto factorType = conservativePropagation ? FactorType::kNeedReplication
+ : FactorType::kPassThrough;
+ builder.addFactor(opDims, kNullDim, C,
+ factorType, true);
+ }
+
+ return builder.build();
+ })
.Case<stablehlo::BitcastConvertOp>(
[](stablehlo::BitcastConvertOp bitcastConvert) {
ArrayRef<int64_t> inShape =
@@ -685,6 +716,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
/*isBlocked=*/usedByRngBitGenerator)
.build();
}
+ // Check if the custom call implements the ShardingRuleOpInterface.
+ if (auto shardingRuleOp =
+ llvm::dyn_cast<ShardingRuleOpInterface>(customCall.getOperation())) {
+ return shardingRuleOp.getShardingRule();
+ }
+
// TODO(b/327191011): output unregistered op stats instead.
static llvm::once_flag onceFlag;
emitOpWarningOnce(
@@ -1093,6 +1130,16 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
return builder.build();
})
.Case<stablehlo::ScatterOp>([](stablehlo::ScatterOp scatter) {
+ // Check if the scatter op implements the ShardingRuleOpInterface.
+ if (auto shardingRuleOp =
+ llvm::dyn_cast<ShardingRuleOpInterface>(scatter.getOperation())) {
+ // Try to get custom rule - if it returns non-null, use it.
+ if (auto customRule = shardingRuleOp.getShardingRule()) {
+ return customRule;
+ }
+ // If custom rule returns null, fall through to default.
+ }
+
OpShardingRuleBuilder builder(scatter);

// Since all inputs and results have compatible shapes, we can look at
diff --git a/shardy/integrations/c/CMakeLists.txt b/shardy/integrations/c/CMakeLists.txt
new file mode 100644
index 0000000..fdd50c4
Expand Down Expand Up @@ -1191,10 +1093,10 @@ index 0000000..30a5cf9
+
+#endif
diff --git a/shardy/integrations/python/ir/sdy_module.cc b/shardy/integrations/python/ir/sdy_module.cc
index da451fa..44c0ea2 100644
index cd7fdc8..1b5aa5b 100644
--- a/shardy/integrations/python/ir/sdy_module.cc
+++ b/shardy/integrations/python/ir/sdy_module.cc
@@ -109,7 +109,15 @@ NB_MODULE(_sdy, m) {
@@ -110,7 +110,15 @@ NB_MODULE(_sdy, m) {
})
.def_property_readonly("size", [](MlirAttribute self) {
return sdyMeshAxisAttrGetSize(self);
Expand All @@ -1211,7 +1113,7 @@ index da451fa..44c0ea2 100644

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "MeshAttr", sdyAttributeIsAMeshAttr)
@@ -133,7 +141,15 @@ NB_MODULE(_sdy, m) {
@@ -134,7 +142,15 @@ NB_MODULE(_sdy, m) {
.def_property_readonly("axes", [](MlirAttribute self) {
return propertyVector<MlirAttribute>(self, sdyMeshAttrGetAxesSize,
sdyMeshAttrGetAxesElem);
Expand All @@ -1228,7 +1130,7 @@ index da451fa..44c0ea2 100644

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "SubAxisInfoAttr", sdyAttributeIsASubAxisInfoAttr)
@@ -150,7 +166,15 @@ NB_MODULE(_sdy, m) {
@@ -151,7 +167,15 @@ NB_MODULE(_sdy, m) {
[](MlirAttribute self) { return sdySubAxisInfoAttrGetPreSize(self); })
.def_property_readonly("size", [](MlirAttribute self) {
return sdySubAxisInfoAttrGetSize(self);
Expand All @@ -1245,7 +1147,7 @@ index da451fa..44c0ea2 100644

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "AxisRefAttr", sdyAttributeIsAnAxisRefAttr)
@@ -175,7 +199,15 @@ NB_MODULE(_sdy, m) {
@@ -176,7 +200,15 @@ NB_MODULE(_sdy, m) {
MlirAttribute subAxisInfo = sdyAxisRefAttrGetSubAxisInfo(self);
return subAxisInfo.ptr == nullptr ? std::nullopt
: std::optional(subAxisInfo);
Expand All @@ -1262,7 +1164,7 @@ index da451fa..44c0ea2 100644

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "DimensionShardingAttr", sdyAttributeIsADimensionShardingAttr)
@@ -205,7 +237,15 @@ NB_MODULE(_sdy, m) {
@@ -206,7 +238,15 @@ NB_MODULE(_sdy, m) {
.def_property_readonly("priority", [](MlirAttribute self) {
int64_t priority = sdyDimensionShardingAttrGetPriority(self);
return priority == -1 ? std::nullopt : std::optional(priority);
Expand All @@ -1279,7 +1181,7 @@ index da451fa..44c0ea2 100644

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "TensorShardingAttr", sdyAttributeIsATensorShardingAttr)
@@ -251,7 +291,15 @@ NB_MODULE(_sdy, m) {
@@ -252,7 +292,15 @@ NB_MODULE(_sdy, m) {
return propertyVector<MlirAttribute>(
self, sdyTensorShardingAttrGetUnreducedAxesSize,
sdyTensorShardingAttrGetUnreducedAxesElem);
Expand All @@ -1296,7 +1198,7 @@ index da451fa..44c0ea2 100644

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "TensorShardingPerValueAttr",
@@ -270,7 +318,15 @@ NB_MODULE(_sdy, m) {
@@ -271,7 +319,15 @@ NB_MODULE(_sdy, m) {
return propertyVector<MlirAttribute>(
self, sdyTensorShardingPerValueAttrGetShardingsSize,
sdyTensorShardingPerValueAttrGetShardingsElem);
Expand All @@ -1313,7 +1215,7 @@ index da451fa..44c0ea2 100644

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "DimMappingAttr", sdyAttributeIsADimMappingAttr)
@@ -288,7 +344,15 @@ NB_MODULE(_sdy, m) {
@@ -289,7 +345,15 @@ NB_MODULE(_sdy, m) {
return propertyVector<intptr_t>(self,
sdyDimMappingAttrGetFactorIndicesSize,
sdyDimMappingAttrGetFactorIndicesElem);
Expand All @@ -1330,7 +1232,7 @@ index da451fa..44c0ea2 100644

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "TensorMappingAttr", sdyAttributeIsATensorMappingAttr)
@@ -310,7 +374,15 @@ NB_MODULE(_sdy, m) {
@@ -311,7 +375,15 @@ NB_MODULE(_sdy, m) {
})
.def_property_readonly("rank", [](MlirAttribute self) {
return sdyTensorMappingAttrGetRank(self);
Expand All @@ -1347,7 +1249,7 @@ index da451fa..44c0ea2 100644

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
m, "OpShardingRuleAttr", sdyAttributeIsAOpShardingRuleAttr)
@@ -394,6 +466,14 @@ NB_MODULE(_sdy, m) {
@@ -395,6 +467,14 @@ NB_MODULE(_sdy, m) {
return propertyVector<intptr_t>(
self, sdyOpShardingRuleAttrGetBlockedPropagationFactorsSize,
sdyOpShardingRuleAttrGetBlockedPropagationFactorsElem);
Expand All @@ -1362,7 +1264,7 @@ index da451fa..44c0ea2 100644
});

mlir::python::nanobind_adaptors::mlir_attribute_subclass(
@@ -417,7 +497,67 @@ NB_MODULE(_sdy, m) {
@@ -418,7 +498,67 @@ NB_MODULE(_sdy, m) {
})
.def("__len__", [](MlirAttribute& self) {
return sdyManualAxesAttrGetAxesSize(self);
Expand Down
Loading
Loading