@@ -635,22 +635,37 @@ index 0000000..dc33501
635635+ SdyDialect
636636+ SdyTransformsPropagationShardingProjection
637637+ )
638- diff --git a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
639- index a73917b..78cc827 100644
640- --- a/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
641- +++ b/shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc
642- @@ -346,6 +346,7 @@ void saveShardingOriginsOnModule(
643- // the case for the target of the edge, because if the source appears multiple
644- // times, then it's because it effects multiple other operands/results in the
645- // op.
646- + [[maybe_unused]]
647- bool insertSeenValue(Operation* op, const PropagationEdge& edge,
648- llvm::SmallDenseSet<Value>& seenValues) {
649- EdgeNode target = edge.target;
650638diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
651- index ab93067..b181797 100644
639+ index e77a8ee..4779d09 100644
652640--- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
653641+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
642+ @@ -51,7 +51,7 @@ namespace {
643+
644+ // Helper to check if reduction window dim can be a passthrough dim.
645+ // When window size is 1, stride is 1 and there is no padding on the operand, it
646+ - // is a 1-1 mapping between operand and result.
647+ + // is a 1-1 mapping between operand and result, so it can be kPassthrough instead of kReduction which would insert all-reduce after gather.
648+ bool isWindowPassthroughDim(std::optional<DenseIntElementsAttr> operandPadding,
649+ ArrayRef<int64_t> windowDimensions,
650+ ArrayRef<int64_t> windowStrides, int64_t dim) {
651+ @@ -146,7 +146,7 @@ void addGatherScatterFactors(
652+ slicesDim, slicesDimSize,
653+ FactorType::kNeedReplication);
654+ } else if (slicesDimSize == 1) {
655+ - // To keep the operand dim sharded, we need an all-reduce on the result.
656+ + // To keep the operand dim sharded, we need an all-reduce on the result.
657+ addUnblockedFactorFn(inputDim, /*indicesDim=*/kNullDim,
658+ /*slicesDim=*/kNullDim, inputDimSize,
659+ FactorType::kReduction);
660+ @@ -166,7 +166,7 @@ void addGatherScatterFactors(
661+ assert(indicesDim < startIndices.getRank());
662+
663+ // If `indicesDim` is in `indicesBatchingDims`, This is an explicit batch
664+ - // dimension across input, indices, and result. Otherwise, it is an
665+ + // dimension across input, indices, and result. Otherwise, it is an
666+ // implicit batch dimension across input and result only.
667+ const auto* batchingDimIt = llvm::find(indicesBatchingDims, indicesDim);
668+ bool isExplicitBatchDim = batchingDimIt != indicesBatchingDims.end();
654669@@ -187,12 +187,12 @@ void addGatherScatterFactors(
655670
656671 // We add factors for all collapsed slice dimensions.
@@ -667,58 +682,28 @@ index ab93067..b181797 100644
667682 }
668683
669684 // Add a factor for the index-vector-dim, if it's present.
670- @@ -303,6 +303,37 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
671- }
672- return builder.build();
673- })
674- + .Case<stablehlo::BatchNormInferenceOp>(
675- + [conservativePropagation](stablehlo::BatchNormInferenceOp bn) {
676- + auto inTy = llvm::cast<mlir::RankedTensorType>(bn.getOperand().getType());
677- + auto outTy = llvm::cast<mlir::RankedTensorType>(bn.getResult().getType());
678- +
679- + OpShardingRuleBuilder builder(bn);
680- +
681- + const int64_t numOperands = static_cast<int64_t>(bn->getNumOperands());
682- + llvm::SmallVector<int64_t> opDims(numOperands, kNullDim);
683- +
684- + for (auto [dU, dimSize] : llvm::enumerate(inTy.getShape())) {
685- + const int64_t d = static_cast<int64_t>(dU);
686- + std::fill(opDims.begin(), opDims.end(), kNullDim);
687- + opDims[0] = d;
688- + builder.addFactor(opDims, d, dimSize);
689- + }
690- +
691- + const int64_t featAxis = static_cast<int64_t>(bn.getFeatureIndex());
692- + const int64_t C = outTy.getDimSize(featAxis);
693- +
694- + for (int64_t paramIdx : {1LL, 2LL, 3LL, 4LL}) {
695- + std::fill(opDims.begin(), opDims.end(), kNullDim);
696- + opDims[paramIdx] = 0;
697- + auto factorType = conservativePropagation ? FactorType::kNeedReplication
698- + : FactorType::kPassThrough;
699- + builder.addFactor(opDims, kNullDim, C,
700- + factorType, true);
701- + }
702- +
703- + return builder.build();
704- + })
705- .Case<stablehlo::BitcastConvertOp>(
706- [](stablehlo::BitcastConvertOp bitcastConvert) {
707- ArrayRef<int64_t> inShape =
708- @@ -685,6 +716,12 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
709- /*isBlocked=*/usedByRngBitGenerator)
710- .build();
685+ @@ -707,6 +707,11 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
686+ }
687+ return OpShardingRuleBuilder::buildPointwise(customCall);
711688 }
712689+ // Check if the custom call implements the ShardingRuleOpInterface.
713690+ if (auto shardingRuleOp =
714691+ llvm::dyn_cast<ShardingRuleOpInterface>(customCall.getOperation())) {
715692+ return shardingRuleOp.getShardingRule();
716693+ }
717- +
718694 // TODO(b/327191011): output unregistered op stats instead.
719695 static llvm::once_flag onceFlag;
720696 emitOpWarningOnce(
721- @@ -1093,6 +1130,16 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
697+ @@ -921,7 +926,7 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
698+ FactorType::kReduction);
699+ } else {
700+ // Not a reduced dimension. So have a mapping b/w the operand and
701+ - // result.
702+ + // result.
703+ assert(resultType.getDimSize(outDim) == dimSize);
704+ resultDims.assign(numInputs, outDim++);
705+ builder.addFactor(operandDims, resultDims, dimSize);
706+ @@ -1115,6 +1120,15 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
722707 return builder.build();
723708 })
724709 .Case<stablehlo::ScatterOp>([](stablehlo::ScatterOp scatter) {
@@ -731,7 +716,6 @@ index ab93067..b181797 100644
731716+ }
732717+ // If custom rule returns null, fall through to default.
733718+ }
734- +
735719 OpShardingRuleBuilder builder(scatter);
736720
737721 // Since all inputs and results have compatible shapes, we can look at
@@ -1191,10 +1175,10 @@ index 0000000..30a5cf9
11911175+
11921176+ #endif
11931177diff --git a/shardy/integrations/python/ir/sdy_module.cc b/shardy/integrations/python/ir/sdy_module.cc
1194- index da451fa..44c0ea2 100644
1178+ index cd7fdc8..1b5aa5b 100644
11951179--- a/shardy/integrations/python/ir/sdy_module.cc
11961180+++ b/shardy/integrations/python/ir/sdy_module.cc
1197- @@ -109,7 +109,15 @@ NB_MODULE(_sdy, m) {
1181+ @@ -110,7 +110,15 @@ NB_MODULE(_sdy, m) {
11981182 })
11991183 .def_property_readonly("size", [](MlirAttribute self) {
12001184 return sdyMeshAxisAttrGetSize(self);
@@ -1211,7 +1195,7 @@ index da451fa..44c0ea2 100644
12111195
12121196 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
12131197 m, "MeshAttr", sdyAttributeIsAMeshAttr)
1214- @@ -133,7 +141,15 @@ NB_MODULE(_sdy, m) {
1198+ @@ -134,7 +142,15 @@ NB_MODULE(_sdy, m) {
12151199 .def_property_readonly("axes", [](MlirAttribute self) {
12161200 return propertyVector<MlirAttribute>(self, sdyMeshAttrGetAxesSize,
12171201 sdyMeshAttrGetAxesElem);
@@ -1228,7 +1212,7 @@ index da451fa..44c0ea2 100644
12281212
12291213 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
12301214 m, "SubAxisInfoAttr", sdyAttributeIsASubAxisInfoAttr)
1231- @@ -150,7 +166,15 @@ NB_MODULE(_sdy, m) {
1215+ @@ -151,7 +167,15 @@ NB_MODULE(_sdy, m) {
12321216 [](MlirAttribute self) { return sdySubAxisInfoAttrGetPreSize(self); })
12331217 .def_property_readonly("size", [](MlirAttribute self) {
12341218 return sdySubAxisInfoAttrGetSize(self);
@@ -1245,7 +1229,7 @@ index da451fa..44c0ea2 100644
12451229
12461230 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
12471231 m, "AxisRefAttr", sdyAttributeIsAnAxisRefAttr)
1248- @@ -175,7 +199,15 @@ NB_MODULE(_sdy, m) {
1232+ @@ -176,7 +200,15 @@ NB_MODULE(_sdy, m) {
12491233 MlirAttribute subAxisInfo = sdyAxisRefAttrGetSubAxisInfo(self);
12501234 return subAxisInfo.ptr == nullptr ? std::nullopt
12511235 : std::optional(subAxisInfo);
@@ -1262,7 +1246,7 @@ index da451fa..44c0ea2 100644
12621246
12631247 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
12641248 m, "DimensionShardingAttr", sdyAttributeIsADimensionShardingAttr)
1265- @@ -205,7 +237,15 @@ NB_MODULE(_sdy, m) {
1249+ @@ -206,7 +238,15 @@ NB_MODULE(_sdy, m) {
12661250 .def_property_readonly("priority", [](MlirAttribute self) {
12671251 int64_t priority = sdyDimensionShardingAttrGetPriority(self);
12681252 return priority == -1 ? std::nullopt : std::optional(priority);
@@ -1279,7 +1263,7 @@ index da451fa..44c0ea2 100644
12791263
12801264 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
12811265 m, "TensorShardingAttr", sdyAttributeIsATensorShardingAttr)
1282- @@ -251,7 +291,15 @@ NB_MODULE(_sdy, m) {
1266+ @@ -252,7 +292,15 @@ NB_MODULE(_sdy, m) {
12831267 return propertyVector<MlirAttribute>(
12841268 self, sdyTensorShardingAttrGetUnreducedAxesSize,
12851269 sdyTensorShardingAttrGetUnreducedAxesElem);
@@ -1296,7 +1280,7 @@ index da451fa..44c0ea2 100644
12961280
12971281 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
12981282 m, "TensorShardingPerValueAttr",
1299- @@ -270,7 +318,15 @@ NB_MODULE(_sdy, m) {
1283+ @@ -271,7 +319,15 @@ NB_MODULE(_sdy, m) {
13001284 return propertyVector<MlirAttribute>(
13011285 self, sdyTensorShardingPerValueAttrGetShardingsSize,
13021286 sdyTensorShardingPerValueAttrGetShardingsElem);
@@ -1313,7 +1297,7 @@ index da451fa..44c0ea2 100644
13131297
13141298 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
13151299 m, "DimMappingAttr", sdyAttributeIsADimMappingAttr)
1316- @@ -288,7 +344,15 @@ NB_MODULE(_sdy, m) {
1300+ @@ -289,7 +345,15 @@ NB_MODULE(_sdy, m) {
13171301 return propertyVector<intptr_t>(self,
13181302 sdyDimMappingAttrGetFactorIndicesSize,
13191303 sdyDimMappingAttrGetFactorIndicesElem);
@@ -1330,7 +1314,7 @@ index da451fa..44c0ea2 100644
13301314
13311315 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
13321316 m, "TensorMappingAttr", sdyAttributeIsATensorMappingAttr)
1333- @@ -310,7 +374,15 @@ NB_MODULE(_sdy, m) {
1317+ @@ -311,7 +375,15 @@ NB_MODULE(_sdy, m) {
13341318 })
13351319 .def_property_readonly("rank", [](MlirAttribute self) {
13361320 return sdyTensorMappingAttrGetRank(self);
@@ -1347,7 +1331,7 @@ index da451fa..44c0ea2 100644
13471331
13481332 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
13491333 m, "OpShardingRuleAttr", sdyAttributeIsAOpShardingRuleAttr)
1350- @@ -394,6 +466,14 @@ NB_MODULE(_sdy, m) {
1334+ @@ -395,6 +467,14 @@ NB_MODULE(_sdy, m) {
13511335 return propertyVector<intptr_t>(
13521336 self, sdyOpShardingRuleAttrGetBlockedPropagationFactorsSize,
13531337 sdyOpShardingRuleAttrGetBlockedPropagationFactorsElem);
@@ -1362,7 +1346,7 @@ index da451fa..44c0ea2 100644
13621346 });
13631347
13641348 mlir::python::nanobind_adaptors::mlir_attribute_subclass(
1365- @@ -417,7 +497,67 @@ NB_MODULE(_sdy, m) {
1349+ @@ -418,7 +498,67 @@ NB_MODULE(_sdy, m) {
13661350 })
13671351 .def("__len__", [](MlirAttribute& self) {
13681352 return sdyManualAxesAttrGetAxesSize(self);
0 commit comments