This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 797757b

liufengdb authored and tensorflower-gardener committed
Remove the constraint that min / max should straddle zero

Since we apply nudging to the zero point to make sure the nudged zero point stays within the range [qmin, qmax], the constraint that rmin / rmax should straddle zero isn't necessary. This also matches the documentation of TensorFlow's FakeQuantWithMinMaxArgs op, where min and max don't need to straddle zero: https://www.tensorflow.org/api_docs/python/tf/quantization/fake_quant_with_min_max_args

PiperOrigin-RevId: 268296285
1 parent 8066c22 commit 797757b

3 files changed: +50, −39 lines

lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp

Lines changed: 13 additions & 12 deletions

@@ -54,8 +54,17 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
   return false;
 }
 
-void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
-                          double &scale, int64_t &nudgedZeroPoint) {
+// This is a specific implementation of nudging:
+// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
+// to include 0.0, but the range width (rmax - rmin) isn't changed. The zero
+// point is derived from the shifted range, and the scale isn't changed. As a
+// consequence, some values that are supposed to be in the original [rmin, rmax]
+// range will be outside the shifted range and be clamped during quantization.
+// TODO(fengliuai): we should nudge the scale as well, but that requires the
+// fake quant op used in the training to use the nudged scale as well.
+void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
+                                double rmax, double &scale,
+                                int64_t &nudgedZeroPoint) {
   // Determine the scale.
   const double qminDouble = qmin;
   const double qmaxDouble = qmax;
@@ -100,14 +109,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
                                           double rmin, double rmax,
                                           bool narrowRange, Type expressedType,
                                           bool isSigned) {
-  // Range must straddle zero.
-  // TODO(b/140641593): remove this constraint.
-  if (rmin > 0.0 || rmax < 0.0) {
-    return (emitError(loc, "FakeQuant range must straddle zero: [")
-                << rmin << "," << rmax << "]",
-            nullptr);
-  }
-
   MLIRContext *ctx = expressedType.getContext();
   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
   Type storageType;
@@ -129,7 +130,7 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
 
   double scale;
   int64_t nudgedZeroPoint;
-  getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
+  getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
 
   return UniformQuantizedType::getChecked(flags, storageType, expressedType,
                                           scale, nudgedZeroPoint, qmin, qmax,
@@ -172,7 +173,7 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
 
     double scale;
     int64_t nudgedZeroPoint;
-    getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
+    getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
     scales.push_back(scale);
    zeroPoints.push_back(nudgedZeroPoint);
  }
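
The body of the renamed getNudgedScaleAndZeroPoint helper is not part of this diff, so the following is only a minimal, hypothetical sketch of the nudging the new comment describes; the name nudgedScaleAndZeroPointSketch and the exact rounding and clamping details are assumptions for illustration, not the repository's actual implementation.

#include <cmath>
#include <cstdint>

// Illustrative sketch only (assumed behavior, not code from this repository):
// derive the scale from the real range and nudge the zero point into
// [qmin, qmax]. Because the zero point is always clamped into the storage
// range, rmin/rmax no longer need to straddle zero.
void nudgedScaleAndZeroPointSketch(int64_t qmin, int64_t qmax, double rmin,
                                   double rmax, double &scale,
                                   int64_t &nudgedZeroPoint) {
  // The scale keeps the real range width; it is not nudged (see the TODO).
  scale = (rmax - rmin) / static_cast<double>(qmax - qmin);

  // The "ideal" zero point can fall outside [qmin, qmax] when the real range
  // does not contain 0.0 (e.g. rmin = 0.5 or rmax = -0.5).
  const double zeroPointFromMin = static_cast<double>(qmin) - rmin / scale;

  // Nudge: clamp into the storage range and round to the nearest integer.
  // Real values near the shifted-out end of [rmin, rmax] are simply clamped
  // later, during quantization.
  if (zeroPointFromMin < static_cast<double>(qmin))
    nudgedZeroPoint = qmin;
  else if (zeroPointFromMin > static_cast<double>(qmax))
    nudgedZeroPoint = qmax;
  else
    nudgedZeroPoint = static_cast<int64_t>(std::round(zeroPointFromMin));
}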

test/Dialect/QuantOps/convert-fakequant-invalid.mlir

Lines changed: 0 additions & 22 deletions

@@ -1,27 +1,5 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics -quant-convert-simulated-quantization
 
-// -----
-// Verify that a mismatched range errors.
-func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
-  // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.500000e+00]}}
-  %0 = "quant.const_fake_quant"(%arg0) {
-    min = 1.1 : f32, max = 1.5 : f32, num_bits = 8
-  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
-  return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verify that a valid range errors.
-func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
-  // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.000000e+00}}
-  %0 = "quant.const_fake_quant"(%arg0) {
-    min = 1.1 : f32, max = 1.0 : f32, num_bits = 8
-  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
-  return %0 : tensor<8x4x3xf32>
-}
-
 // -----
 // Unsupported quantizable type (i1 is currently not a supported element type).
 func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {

test/Dialect/QuantOps/convert-fakequant.mlir

Lines changed: 37 additions & 5 deletions

@@ -47,7 +47,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 
 // -----
 // Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
-// CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange
+// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange
 func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@@ -62,7 +62,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 
 // -----
 // Verifies a quint8 symmetric range of -1..127/128.
-// CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange
+// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange
 func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@@ -122,7 +122,7 @@ func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 
 // -----
 // Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
-// CHECK_LABEL: fakeQuantArgs_Qint8_NarrowRange
+// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange
 func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@@ -137,7 +137,7 @@ func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 
 // -----
 // Verifies a qint8 symmetric range of -1..127/128.
-// CHECK_LABEL: fakeQuantArgs_Qint8_SymmetricRange
+// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange
 func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@@ -181,9 +181,41 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
   return %0 : tensor<f32>
 }
 
+// -----
+// CHECK-LABEL: fakeQuantArgs_all_positive
+func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+
+  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
+  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = 0.5 : f32, max = 1.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// CHECK-LABEL: fakeQuantArgs_all_negative
+func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+
+  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>
+  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = -1.5 : f32, max = -0.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
 // -----
 // Verifies a qint8 per axis
-// CHECK_LABEL: fakeQuantPerAxis
+// CHECK-LABEL: fakeQuantPerAxis
 func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
 
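Working through the two new test cases by hand shows where the checked types come from, assuming the usual FakeQuant formulas (this arithmetic is illustrative and not taken from the pass itself): with signed 8-bit storage, qmin = -128 and qmax = 127, so scale = (rmax - rmin) / 255 = 1.0 / 255 ≈ 0.0039215686274509803 for both ranges. For min = 0.5, max = 1.5 the ideal zero point is -128 - 0.5 / scale = -255.5, which clamps to -128; for min = -1.5, max = -0.5 it is -128 + 1.5 / scale = 254.5, which clamps to 127. That matches the !quant.uniform<i8:f32, 0.0039215686274509803:-128> and ...:127> types in the CHECK lines above. A small self-contained C++ snippet of the same arithmetic (the names here are made up for illustration):

#include <algorithm>
#include <cstdio>

int main() {
  const double qmin = -128.0, qmax = 127.0;          // signed 8-bit storage
  const double scale = (1.5 - 0.5) / (qmax - qmin);  // 1/255; both test ranges have width 1.0

  // All-positive range [0.5, 1.5]: ideal zero point is -255.5, nudged (clamped) to qmin.
  const double zpAllPositive = std::clamp(qmin - 0.5 / scale, qmin, qmax);

  // All-negative range [-1.5, -0.5]: ideal zero point is 254.5, nudged (clamped) to qmax.
  const double zpAllNegative = std::clamp(qmin - (-1.5) / scale, qmin, qmax);

  std::printf("scale = %.19f, zp(all positive) = %g, zp(all negative) = %g\n",
              scale, zpAllPositive, zpAllNegative);
  return 0;
}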