@@ -37,67 +37,111 @@ class ConvertSimulatedQuantPass
3737
3838} // end anonymous namespace
3939
// NOTE(review): this span is a rendered diff, not plain source. Every line
// starts with fused old/new line numbers; lines tagged `-` belong to the
// removed ConstFakeQuantRewrite, lines tagged `+` to its replacement
// FakeQuantRewrite, and untagged lines appear to be shared by both versions.
//
// FakeQuantRewrite (the `+` side) is a rewrite pattern templated on the
// concrete rewrite class and on the fake-quant op type it matches.
// matchAndRewrite delegates to failableRewrite; when that reports failure it
// sets *hadFailure and returns matchFailure(), otherwise matchSuccess().
// The concrete class has to provide
//   convertFakeQuantAttrsToType(FakeQuantOp, Type) const
// which failableRewrite reaches through a static_cast of `this`.
40- // / Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
41- class ConstFakeQuantRewrite : public RewritePattern {
40+ // / Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
41+ template <typename ConcretRewriteClass, typename FakeQuantOp>
42+ class FakeQuantRewrite : public OpRewritePattern <FakeQuantOp> {
4243public:
43- bool *hadFailure ;
44+ using OpRewritePattern<FakeQuantOp>::OpRewritePattern ;
4445
45- ConstFakeQuantRewrite (MLIRContext *context, bool *hadFailure)
46- : RewritePattern(ConstFakeQuant::getOperationName(), 1 , context),
47- hadFailure (hadFailure) {}
// Stores the externally supplied failure flag; it is written in
// matchAndRewrite whenever a rewrite fails.
46+ FakeQuantRewrite (MLIRContext *ctx, bool *hadFailure)
47+ : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
4848
49- PatternMatchResult matchAndRewrite (Operation * op,
49+ PatternMatchResult matchAndRewrite (FakeQuantOp op,
5050 PatternRewriter &rewriter) const override {
5151 // TODO: If this pattern comes up more frequently, consider adding core
5252 // support for failable rewrites.
5353 if (failableRewrite (op, rewriter)) {
5454 *hadFailure = true ;
55- return matchFailure ();
// NOTE(review): the `+` side qualifies the call with Pattern:: — presumably
// because unqualified names are not looked up in the dependent base
// OpRewritePattern<FakeQuantOp>; confirm.
55+ return Pattern:: matchFailure ();
5656 }
5757
58- return matchSuccess ();
58+ return Pattern:: matchSuccess ();
5959 }
6060
61- bool failableRewrite (Operation *op, PatternRewriter &rewriter) const {
62- auto fqOp = cast<ConstFakeQuant>(op) ;
61+ private:
// Flag passed in through the constructor; set to true on a failed rewrite.
62+ bool *hadFailure ;
6363
64- auto converter =
65- ExpressedToQuantizedConverter::forInputType (fqOp .getType ());
// Performs the actual rewrite. Returns true on FAILURE and false on success
// (matchAndRewrite records a failure when this returns true).
64+ bool failableRewrite (FakeQuantOp op, PatternRewriter &rewriter) const {
65+ auto converter = ExpressedToQuantizedConverter::forInputType (op .getType ());
6666 if (!converter) {
// Comma expression: emits the diagnostic, then yields `true` (failure).
67- return (op-> emitError (" unsupported quantized type conversion" ), true );
67+ return (op. emitError (" unsupported quantized type conversion" ), true );
6868 }
6969
70- UniformQuantizedType uniformElementType = fakeQuantAttrsToType (
71- fqOp.getLoc (), fqOp.num_bits ().getSExtValue (),
72- fqOp.min ().convertToFloat (), fqOp.max ().convertToFloat (),
73- fqOp.narrow_range (), converter.expressedType , fqOp.is_signed ());
// Ask the concrete rewrite class for the quantized element type; a null
// result is treated as failure just below.
70+ QuantizedType elementType =
71+ static_cast <const ConcretRewriteClass *>(this )
72+ ->convertFakeQuantAttrsToType (op, converter.expressedType );
7473
75- if (!uniformElementType ) {
74+ if (!elementType ) {
7675 // Note that the fakeQuantAttrsToType will have emitted the error.
7776 return true ;
7877 }
7978
80- Type quantizedType = converter.convert (uniformElementType );
79+ Type quantizedType = converter.convert (elementType );
8180 assert (quantizedType &&
8281 " Converter accepted a type that it did not convert" );
8382
8483 // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
8584 // this is a forced/hard-coded constraint.
// Build a QuantizeCastOp over the op's inputs, then replace the op with a
// DequantizeCastOp that produces converter.inputType from the cast's result.
86- auto qbarrier = rewriter.create <QuantizeCastOp>(op-> getLoc (), quantizedType,
87- fqOp .inputs ());
85+ auto qbarrier = rewriter.create <QuantizeCastOp>(op. getLoc (), quantizedType,
86+ op .inputs ());
8887 rewriter.replaceOpWithNewOp <DequantizeCastOp>(op, converter.inputType ,
8988 qbarrier.getResult ());
9089
9190 return false ;
9291 }
9392};
9493
// Concrete rewrite pattern for ConstFakeQuant ops. It derives from
// FakeQuantRewrite, passing itself as the first template argument and
// ConstFakeQuant as the op type, and forwards its constructor arguments
// (context and failure flag) to that base.
94+ class ConstFakeQuantRewrite
95+ : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
96+ public:
97+ using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
98+
99+ ConstFakeQuantRewrite (MLIRContext *ctx, bool *hadFailure)
100+ : BaseRewrite(ctx, hadFailure) {}
101+
// Builds the quantized element type for `fqOp` by forwarding the op's
// location, num_bits, min, max, narrow_range and is_signed values, together
// with `expressedType`, to fakeQuantAttrsToType. min and max are read with
// convertToFloat(). Returns whatever fakeQuantAttrsToType returns.
102+ QuantizedType convertFakeQuantAttrsToType (ConstFakeQuant fqOp,
103+ Type expressedType) const {
104+ return fakeQuantAttrsToType (
105+ fqOp.getLoc (), fqOp.num_bits ().getSExtValue (),
106+ fqOp.min ().convertToFloat (), fqOp.max ().convertToFloat (),
107+ fqOp.narrow_range (), expressedType, fqOp.is_signed ());
108+ }
109+ };
110+
// Concrete rewrite pattern for ConstFakeQuantPerAxis ops. It derives from
// FakeQuantRewrite, passing itself as the first template argument and
// ConstFakeQuantPerAxis as the op type, and forwards its constructor
// arguments (context and failure flag) to that base.
111+ class ConstFakeQuantPerAxisRewrite
112+ : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
113+ ConstFakeQuantPerAxis> {
114+ public:
115+ using BaseRewrite =
116+ FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
117+
118+ ConstFakeQuantPerAxisRewrite (MLIRContext *ctx, bool *hadFailure)
119+ : BaseRewrite(ctx, hadFailure) {}
120+
// Builds the quantized element type for `fqOp` from its per-axis attributes
// and `expressedType`.
121+ QuantizedType convertFakeQuantAttrsToType (ConstFakeQuantPerAxis fqOp,
122+ Type expressedType) const {
// Copy every element of the op's min and max attributes into vectors of
// double, reserving capacity from the attribute sizes first.
// NOTE(review): each element is cast to FloatAttr without a check here —
// presumably the op definition guarantees float elements; confirm.
123+ SmallVector<double , 4 > min, max;
124+ min.reserve (fqOp.min ().size ());
125+ max.reserve (fqOp.max ().size ());
126+ for (auto m : fqOp.min ())
127+ min.push_back (m.cast <FloatAttr>().getValueAsDouble ());
128+ for (auto m : fqOp.max ())
129+ max.push_back (m.cast <FloatAttr>().getValueAsDouble ());
130+
// Unlike the ConstFakeQuant call, this call also passes the op's axis and
// the min/max vectors rather than single float values.
131+ return fakeQuantAttrsToType (fqOp.getLoc (), fqOp.num_bits ().getSExtValue (),
132+ fqOp.axis ().getSExtValue (), min, max,
133+ fqOp.narrow_range (), expressedType,
134+ fqOp.is_signed ());
135+ }
136+ };
137+
95138void ConvertSimulatedQuantPass::runOnFunction () {
96139 bool hadFailure = false ;
97140 OwningRewritePatternList patterns;
98141 auto func = getFunction ();
99- auto *context = &getContext ();
100- patterns.insert <ConstFakeQuantRewrite>(context, &hadFailure);
142+ auto ctx = func.getContext ();
143+ patterns.insert <ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
144+ ctx, &hadFailure);
101145 applyPatternsGreedily (func, patterns);
102146 if (hadFailure)
103147 signalPassFailure ();
0 commit comments