Skip to content

Commit 952fafd

Browse files
nutsiepullytensorflower-gardener
authored andcommitted
Add quantize wrapper test.
Verifies quantization is applied before layer is executed. PiperOrigin-RevId: 255689225
1 parent 56e9194 commit 952fafd

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

tensorflow_model_optimization/python/core/quantization/keras/BUILD

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,22 @@ py_library(
8383
],
8484
)
8585

86+
py_test(
87+
name = "quantize_emulate_wrapper_test",
88+
srcs = [
89+
"quantize_emulate_wrapper_test.py",
90+
],
91+
python_version = "PY3",
92+
srcs_version = "PY2AND3",
93+
visibility = ["//visibility:public"],
94+
deps = [
95+
":quantize_emulate_wrapper",
96+
# numpy dep1,
97+
# tensorflow dep1,
98+
# python/keras tensorflow dep2,
99+
],
100+
)
101+
86102
py_library(
87103
name = "quantize_emulate",
88104
srcs = [
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for QuantizeEmulateWrapper."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import numpy as np
22+
23+
from tensorflow.python import keras
24+
from tensorflow.python.platform import test
25+
from tensorflow_model_optimization.python.core.quantization.keras import quantize_emulate_wrapper
26+
27+
QuantizeEmulateWrapper = quantize_emulate_wrapper.QuantizeEmulateWrapper
28+
29+
30+
class QuantizeEmulateWrapperTest(test.TestCase):
31+
32+
def setUp(self):
33+
self.quant_params = {
34+
'num_bits': 8,
35+
'narrow_range': False,
36+
'symmetric': True
37+
}
38+
39+
def testQuantizesWeightsInLayer(self):
40+
weights = lambda shape, dtype: np.array([[-1.0, 0.0], [0.0, 1.0]])
41+
model = keras.Sequential([
42+
QuantizeEmulateWrapper(
43+
keras.layers.Dense(2, kernel_initializer=weights),
44+
input_shape=(2,),
45+
**self.quant_params)
46+
])
47+
48+
# FakeQuant([-1.0, 1.0]) = [-0.9882355, 0.9882355]
49+
# Obtained from tf.fake_quant_with_min_max_vars
50+
self.assertAllClose(
51+
np.array([[-0.9882355, 0.9882355]]),
52+
# Inputs are all ones, so result comes directly from weights.
53+
model.predict(np.ones((1, 2))))
54+
55+
56+
if __name__ == '__main__':
57+
test.main()

0 commit comments

Comments
 (0)