Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 8066c22

Browse files
liufengdb and tensorflower-gardener
authored and committed
Convert ConstFakeQuantPerAxis to qcast and dcast pair
This also adds a test for the per-channel fake quant overload of fakeQuantAttrsToType. PiperOrigin-RevId: 268260032
1 parent 083a0ae commit 8066c22

File tree

3 files changed

+89
-27
lines changed

3 files changed

+89
-27
lines changed

lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,67 +37,111 @@ class ConvertSimulatedQuantPass
3737

3838
} // end anonymous namespace
3939

40-
/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
41-
class ConstFakeQuantRewrite : public RewritePattern {
40+
/// Base class that rewrites a ConstFakeQuant op into a qbarrier/dbarrier pair.
41+
template <typename ConcretRewriteClass, typename FakeQuantOp>
42+
class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
4243
public:
43-
bool *hadFailure;
44+
using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
4445

45-
ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
46-
: RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
47-
hadFailure(hadFailure) {}
46+
FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
47+
: OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
4848

49-
PatternMatchResult matchAndRewrite(Operation *op,
49+
PatternMatchResult matchAndRewrite(FakeQuantOp op,
5050
PatternRewriter &rewriter) const override {
5151
// TODO: If this pattern comes up more frequently, consider adding core
5252
// support for failable rewrites.
5353
if (failableRewrite(op, rewriter)) {
5454
*hadFailure = true;
55-
return matchFailure();
55+
return Pattern::matchFailure();
5656
}
5757

58-
return matchSuccess();
58+
return Pattern::matchSuccess();
5959
}
6060

61-
bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
62-
auto fqOp = cast<ConstFakeQuant>(op);
61+
private:
62+
bool *hadFailure;
6363

64-
auto converter =
65-
ExpressedToQuantizedConverter::forInputType(fqOp.getType());
64+
bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
65+
auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
6666
if (!converter) {
67-
return (op->emitError("unsupported quantized type conversion"), true);
67+
return (op.emitError("unsupported quantized type conversion"), true);
6868
}
6969

70-
UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
71-
fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
72-
fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
73-
fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());
70+
QuantizedType elementType =
71+
static_cast<const ConcretRewriteClass *>(this)
72+
->convertFakeQuantAttrsToType(op, converter.expressedType);
7473

75-
if (!uniformElementType) {
74+
if (!elementType) {
7675
// Note that the fakeQuantAttrsToType will have emitted the error.
7776
return true;
7877
}
7978

80-
Type quantizedType = converter.convert(uniformElementType);
79+
Type quantizedType = converter.convert(elementType);
8180
assert(quantizedType &&
8281
"Converter accepted a type that it did not convert");
8382

8483
// TODO: Map to a qbarrier with an attribute like [Forced] to signal that
8584
// this is a forced/hard-coded constraint.
86-
auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
87-
fqOp.inputs());
85+
auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
86+
op.inputs());
8887
rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
8988
qbarrier.getResult());
9089

9190
return false;
9291
}
9392
};
9493

94+
class ConstFakeQuantRewrite
95+
: public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
96+
public:
97+
using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
98+
99+
ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
100+
: BaseRewrite(ctx, hadFailure) {}
101+
102+
QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
103+
Type expressedType) const {
104+
return fakeQuantAttrsToType(
105+
fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
106+
fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
107+
fqOp.narrow_range(), expressedType, fqOp.is_signed());
108+
}
109+
};
110+
111+
class ConstFakeQuantPerAxisRewrite
112+
: public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
113+
ConstFakeQuantPerAxis> {
114+
public:
115+
using BaseRewrite =
116+
FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
117+
118+
ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
119+
: BaseRewrite(ctx, hadFailure) {}
120+
121+
QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
122+
Type expressedType) const {
123+
SmallVector<double, 4> min, max;
124+
min.reserve(fqOp.min().size());
125+
max.reserve(fqOp.max().size());
126+
for (auto m : fqOp.min())
127+
min.push_back(m.cast<FloatAttr>().getValueAsDouble());
128+
for (auto m : fqOp.max())
129+
max.push_back(m.cast<FloatAttr>().getValueAsDouble());
130+
131+
return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
132+
fqOp.axis().getSExtValue(), min, max,
133+
fqOp.narrow_range(), expressedType,
134+
fqOp.is_signed());
135+
}
136+
};
137+
95138
void ConvertSimulatedQuantPass::runOnFunction() {
96139
bool hadFailure = false;
97140
OwningRewritePatternList patterns;
98141
auto func = getFunction();
99-
auto *context = &getContext();
100-
patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
142+
auto ctx = func.getContext();
143+
patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
144+
ctx, &hadFailure);
101145
applyPatternsGreedily(func, patterns);
102146
if (hadFailure)
103147
signalPassFailure();

lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
136136
loc);
137137
}
138138

139-
// TODO(fengliuai): test this method once the quantizeAttr method is fixed.
140139
UniformQuantizedPerAxisType
141140
fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
142141
ArrayRef<double> rmins, ArrayRef<double> rmaxs,
@@ -180,8 +179,8 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
180179

181180
unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
182181
return UniformQuantizedPerAxisType::getChecked(
183-
flags, storageType, expressedType, scales, zeroPoints, qmin, qmax,
184-
quantizedDimension, loc);
182+
flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
183+
qmin, qmax, loc);
185184
}
186185

187186
} // namespace quant

test/Dialect/QuantOps/convert-fakequant.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,22 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
180180
} : (tensor<f32>) -> tensor<f32>
181181
return %0 : tensor<f32>
182182
}
183+
184+
// -----
185+
// Verifies a qint8 per axis
186+
// CHECK-LABEL: fakeQuantPerAxis
187+
func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
188+
^bb0(%arg0: tensor<8x4x3xf32>):
189+
190+
// CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
191+
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>
192+
// CHECK: %[[d:.*]] = "quant.dcast"(%[[q]])
193+
// CHECK-SAME: (tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>)
194+
195+
%0 = "quant.const_fake_quant_per_axis"(%arg0) {
196+
min = [-1.0 : f32, 0.0 : f32, 0.0 : f32],
197+
max = [0.9921875 : f32, 0.0: f32, 1.0 : f32],
198+
num_bits = 8, narrow_range = false, is_signed = true, axis = 2
199+
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
200+
return %0 : tensor<8x4x3xf32>
201+
}

0 commit comments

Comments
 (0)