Skip to content

Commit e7b3c16

Browse files
pravnarTensorFlow MLIR Team
authored andcommitted
CHLO defns for a ragged dot that permits ragged batch and contraction.
PiperOrigin-RevId: 706767602
1 parent f32ade0 commit e7b3c16

File tree

7 files changed

+794
-1
lines changed

7 files changed

+794
-1
lines changed

mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16+
#include <optional>
1617
#include <utility>
1718

1819
#include "mhlo/IR/hlo_ops.h"
@@ -22,9 +23,12 @@ limitations under the License.
2223
#include "mlir/Dialect/Func/IR/FuncOps.h"
2324
#include "mlir/Dialect/Shape/IR/Shape.h"
2425
#include "mlir/Dialect/Tensor/IR/Tensor.h"
26+
#include "mlir/IR/BuiltinAttributes.h"
27+
#include "mlir/IR/BuiltinOps.h"
2528
#include "mlir/IR/MLIRContext.h"
2629
#include "mlir/IR/PatternMatch.h"
2730
#include "mlir/Pass/PassManager.h"
31+
#include "mlir/Support/LLVM.h"
2832
#include "mlir/Support/LogicalResult.h"
2933
#include "mlir/Transforms/DialectConversion.h"
3034
#include "stablehlo/dialect/ChloOps.h"
@@ -56,7 +60,8 @@ struct ChloLegalizeToHighLevelMhloPass
5660
// Consider the mhlo dialect legal for tests. Also add helper dialects
5761
// that are needed by the patterns.
5862
conversionTarget.addLegalDialect<chlo::ChloDialect, mhlo::MhloDialect>();
59-
conversionTarget.addIllegalOp<chlo::TopKOp, chlo::ErfOp>();
63+
conversionTarget
64+
.addIllegalOp<chlo::TopKOp, chlo::ErfOp, chlo::RaggedDotOp>();
6065

6166
if (failed(applyPartialConversion(getOperation(), conversionTarget,
6267
std::move(conversionPatterns)))) {
@@ -93,6 +98,64 @@ struct ChloLegalizeToHloPass
9398
}
9499
};
95100

101+
struct RaggedDotChloToMhlo : public OpRewritePattern<chlo::RaggedDotOp> {
102+
using OpRewritePattern<chlo::RaggedDotOp>::OpRewritePattern;
103+
104+
LogicalResult matchAndRewrite(chlo::RaggedDotOp raggedDotOp,
105+
PatternRewriter &rewriter) const override {
106+
auto moduleOp = raggedDotOp->getParentOfType<ModuleOp>();
107+
108+
OpBuilder builder(moduleOp.getBodyRegion());
109+
builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front());
110+
111+
auto chloRaggedDotDimNums = raggedDotOp.getRaggedDotDimensionNumbers();
112+
auto dotDimNums = mhlo::DotDimensionNumbersAttr::get(
113+
builder.getContext(), chloRaggedDotDimNums.getLhsBatchingDimensions(),
114+
chloRaggedDotDimNums.getRhsBatchingDimensions(),
115+
chloRaggedDotDimNums.getLhsContractingDimensions(),
116+
chloRaggedDotDimNums.getRhsContractingDimensions());
117+
auto raggedDotDimNums = mhlo::RaggedDotDimensionNumbersAttr::get(
118+
builder.getContext(), dotDimNums,
119+
chloRaggedDotDimNums.getLhsRaggedDimensions(),
120+
chloRaggedDotDimNums.getRhsGroupDimensions());
121+
122+
auto mhloPrecision =
123+
[](chlo::Precision precision) -> std::optional<mhlo::Precision> {
124+
switch (precision) {
125+
case chlo::Precision::DEFAULT:
126+
return mhlo::Precision::DEFAULT;
127+
case chlo::Precision::HIGH:
128+
return mhlo::Precision::HIGH;
129+
case chlo::Precision::HIGHEST:
130+
return mhlo::Precision::HIGHEST;
131+
}
132+
};
133+
ArrayAttr precisionConfig = rewriter.getArrayAttr({});
134+
if (raggedDotOp.getPrecisionConfig().has_value()) {
135+
SmallVector<Attribute> vector;
136+
for (auto configValue : raggedDotOp.getPrecisionConfig()
137+
.value()
138+
.getAsRange<chlo::PrecisionAttr>()) {
139+
vector.push_back(
140+
PrecisionAttr::get(raggedDotOp.getContext(),
141+
mhloPrecision(configValue.getValue()).value()));
142+
}
143+
precisionConfig = rewriter.getArrayAttr(vector);
144+
}
145+
146+
rewriter.replaceOp(
147+
raggedDotOp,
148+
rewriter
149+
.create<mhlo::RaggedDotOp>(
150+
raggedDotOp.getLoc(), raggedDotOp.getResult().getType(),
151+
raggedDotOp.getLhs(), raggedDotOp.getRhs(),
152+
raggedDotOp.getGroupSizes(), raggedDotDimNums, precisionConfig)
153+
.getOperation());
154+
155+
return success();
156+
}
157+
};
158+
96159
} // namespace
97160

98161
} // namespace mhlo
@@ -105,6 +168,7 @@ namespace {
105168

106169
void populateChloToHighLevelMhloOpPatterns(MLIRContext *,
107170
RewritePatternSet *patterns) {
171+
patterns->add<mhlo::RaggedDotChloToMhlo>(patterns->getContext());
108172
populateWithGenerated(*patterns);
109173
}
110174

stablehlo/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ cc_library(
315315
":chlo_attrs_inc_gen",
316316
":chlo_enums_inc_gen",
317317
":chlo_ops_inc_gen",
318+
":stablehlo_assembly_format",
318319
":stablehlo_type_inference",
319320
"@llvm-project//llvm:Support",
320321
"@llvm-project//mlir:BytecodeOpInterface",

stablehlo/stablehlo/dialect/ChloEnums.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,29 @@ def CHLO_ComparisonType : I32EnumAttr<"ComparisonType",
7070

7171
def CHLO_ComparisonTypeAttr : EnumAttr<CHLO_Dialect, CHLO_ComparisonType, "comparison_type">;
7272

73+
//===----------------------------------------------------------------------===//
74+
// Ragged dot op definitions.
75+
//===----------------------------------------------------------------------===//
76+
77+
// These mirror the XLA PrecisionConfig proto enum.
78+
def CHLO_PRECISION_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>;
79+
def CHLO_PRECISION_HIGH : I32EnumAttrCase<"HIGH", 1>;
80+
def CHLO_PRECISION_HIGHEST : I32EnumAttrCase<"HIGHEST", 2>;
81+
82+
def CHLO_Precision : I32EnumAttr<"Precision",
83+
"XLA precision for an operand. Has backend specific meaning.",
84+
[
85+
CHLO_PRECISION_DEFAULT,
86+
CHLO_PRECISION_HIGH,
87+
CHLO_PRECISION_HIGHEST
88+
]> {
89+
let genSpecializedAttr = 0;
90+
let cppNamespace = "::mlir::chlo";
91+
}
92+
93+
def CHLO_PrecisionAttr : EnumAttr<CHLO_Dialect, CHLO_Precision, "precision">;
94+
95+
def CHLO_PrecisionConfigAttr:
96+
TypedArrayAttrBase<CHLO_PrecisionAttr, "Precision Config attribute">;
97+
7398
#endif // STABLEHLO_DIALECT_CHLO_ENUMS

0 commit comments

Comments
 (0)