@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16+ #include < optional>
1617#include < utility>
1718
1819#include " mhlo/IR/hlo_ops.h"
@@ -22,9 +23,12 @@ limitations under the License.
2223#include " mlir/Dialect/Func/IR/FuncOps.h"
2324#include " mlir/Dialect/Shape/IR/Shape.h"
2425#include " mlir/Dialect/Tensor/IR/Tensor.h"
26+ #include " mlir/IR/BuiltinAttributes.h"
27+ #include " mlir/IR/BuiltinOps.h"
2528#include " mlir/IR/MLIRContext.h"
2629#include " mlir/IR/PatternMatch.h"
2730#include " mlir/Pass/PassManager.h"
31+ #include " mlir/Support/LLVM.h"
2832#include " mlir/Support/LogicalResult.h"
2933#include " mlir/Transforms/DialectConversion.h"
3034#include " stablehlo/dialect/ChloOps.h"
@@ -56,7 +60,8 @@ struct ChloLegalizeToHighLevelMhloPass
5660 // Consider the mhlo dialect legal for tests. Also add helper dialects
5761 // that are needed by the patterns.
5862 conversionTarget.addLegalDialect <chlo::ChloDialect, mhlo::MhloDialect>();
59- conversionTarget.addIllegalOp <chlo::TopKOp, chlo::ErfOp>();
63+ conversionTarget
64+ .addIllegalOp <chlo::TopKOp, chlo::ErfOp, chlo::RaggedDotOp>();
6065
6166 if (failed (applyPartialConversion (getOperation (), conversionTarget,
6267 std::move (conversionPatterns)))) {
@@ -93,6 +98,64 @@ struct ChloLegalizeToHloPass
9398 }
9499};
95100
101+ struct RaggedDotChloToMhlo : public OpRewritePattern <chlo::RaggedDotOp> {
102+ using OpRewritePattern<chlo::RaggedDotOp>::OpRewritePattern;
103+
104+ LogicalResult matchAndRewrite (chlo::RaggedDotOp raggedDotOp,
105+ PatternRewriter &rewriter) const override {
106+ auto moduleOp = raggedDotOp->getParentOfType <ModuleOp>();
107+
108+ OpBuilder builder (moduleOp.getBodyRegion ());
109+ builder.setInsertionPointToStart (&moduleOp.getBodyRegion ().front ());
110+
111+ auto chloRaggedDotDimNums = raggedDotOp.getRaggedDotDimensionNumbers ();
112+ auto dotDimNums = mhlo::DotDimensionNumbersAttr::get (
113+ builder.getContext (), chloRaggedDotDimNums.getLhsBatchingDimensions (),
114+ chloRaggedDotDimNums.getRhsBatchingDimensions (),
115+ chloRaggedDotDimNums.getLhsContractingDimensions (),
116+ chloRaggedDotDimNums.getRhsContractingDimensions ());
117+ auto raggedDotDimNums = mhlo::RaggedDotDimensionNumbersAttr::get (
118+ builder.getContext (), dotDimNums,
119+ chloRaggedDotDimNums.getLhsRaggedDimensions (),
120+ chloRaggedDotDimNums.getRhsGroupDimensions ());
121+
122+ auto mhloPrecision =
123+ [](chlo::Precision precision) -> std::optional<mhlo::Precision> {
124+ switch (precision) {
125+ case chlo::Precision::DEFAULT:
126+ return mhlo::Precision::DEFAULT;
127+ case chlo::Precision::HIGH:
128+ return mhlo::Precision::HIGH;
129+ case chlo::Precision::HIGHEST:
130+ return mhlo::Precision::HIGHEST;
131+ }
132+ };
133+ ArrayAttr precisionConfig = rewriter.getArrayAttr ({});
134+ if (raggedDotOp.getPrecisionConfig ().has_value ()) {
135+ SmallVector<Attribute> vector;
136+ for (auto configValue : raggedDotOp.getPrecisionConfig ()
137+ .value ()
138+ .getAsRange <chlo::PrecisionAttr>()) {
139+ vector.push_back (
140+ PrecisionAttr::get (raggedDotOp.getContext (),
141+ mhloPrecision (configValue.getValue ()).value ()));
142+ }
143+ precisionConfig = rewriter.getArrayAttr (vector);
144+ }
145+
146+ rewriter.replaceOp (
147+ raggedDotOp,
148+ rewriter
149+ .create <mhlo::RaggedDotOp>(
150+ raggedDotOp.getLoc (), raggedDotOp.getResult ().getType (),
151+ raggedDotOp.getLhs (), raggedDotOp.getRhs (),
152+ raggedDotOp.getGroupSizes (), raggedDotDimNums, precisionConfig)
153+ .getOperation ());
154+
155+ return success ();
156+ }
157+ };
158+
96159} // namespace
97160
98161} // namespace mhlo
@@ -105,6 +168,7 @@ namespace {
105168
106169void populateChloToHighLevelMhloOpPatterns (MLIRContext *,
107170 RewritePatternSet *patterns) {
171+ patterns->add <mhlo::RaggedDotChloToMhlo>(patterns->getContext ());
108172 populateWithGenerated (*patterns);
109173}
110174
0 commit comments