Skip to content

Commit 110b6c9

Browse files
dengyinlintensorflower-gardener
authored andcommitted
Implements split_by_small_values and take_the_difference encoding stages.
PiperOrigin-RevId: 264075888
1 parent 4f2545a commit 110b6c9

File tree

4 files changed

+379
-0
lines changed

4 files changed

+379
-0
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/BUILD

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ py_library(
1111
deps = [
1212
":clipping",
1313
":kashin",
14+
":misc",
1415
":quantization",
1516
],
1617
)
@@ -83,3 +84,25 @@ py_test(
8384
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/testing:test_utils",
8485
],
8586
)
87+
88+
py_library(
89+
name = "misc",
90+
srcs = ["misc.py"],
91+
deps = [
92+
# tensorflow dep1,
93+
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/core:encoding_stage",
94+
],
95+
)
96+
97+
py_test(
98+
name = "misc_test",
99+
size = "small",
100+
srcs = ["misc_test.py"],
101+
deps = [
102+
":misc",
103+
# absl/testing:parameterized dep1,
104+
# numpy dep1,
105+
# tensorflow dep1,
106+
"//tensorflow_model_optimization/python/core/internal/tensor_encoding/testing:test_utils",
107+
],
108+
)

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research.clipping import ClipByNormEncodingStage
2525
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research.clipping import ClipByValueEncodingStage
2626
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research.kashin import KashinHadamardEncodingStage
27+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research.misc import DifferenceBetweenIntegersEncodingStage
28+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research.misc import SplitBySmallValueEncodingStage
2729
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research.quantization import PerChannelPRNGUniformQuantizationEncodingStage
2830
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research.quantization import PerChannelUniformQuantizationEncodingStage
2931
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research.quantization import PRNGUniformQuantizationEncodingStage
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2019, The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Misc."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import tensorflow as tf
21+
22+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.core import encoding_stage
23+
24+
25+
@encoding_stage.tf_style_encoding_stage
26+
class SplitBySmallValueEncodingStage(encoding_stage.EncodingStageInterface):
27+
"""Encoding stage splitting the input by small values.
28+
29+
This encoding stage will split the input into two outputs: the value and the
30+
indices of the elements whose absolute value is larger than a certain
31+
threshold. The elements smaller than the threshold is then decoded to zero.
32+
"""
33+
34+
ENCODED_INDICES_KEY = 'indices'
35+
ENCODED_VALUES_KEY = 'non_zero_floats'
36+
THRESHOLD_PARAMS_KEY = 'threshold'
37+
38+
def __init__(self, threshold=1e-8):
39+
"""Initializer for the SplitBySmallValueEncodingStage.
40+
41+
Args:
42+
threshold: The threshold of the small weights to be set to zero.
43+
"""
44+
self._threshold = threshold
45+
46+
@property
47+
def name(self):
48+
"""See base class."""
49+
return 'split_by_small_value'
50+
51+
@property
52+
def compressible_tensors_keys(self):
53+
"""See base class."""
54+
return [
55+
self.ENCODED_VALUES_KEY,
56+
self.ENCODED_INDICES_KEY,
57+
]
58+
59+
@property
60+
def commutes_with_sum(self):
61+
"""See base class."""
62+
return False
63+
64+
@property
65+
def decode_needs_input_shape(self):
66+
"""See base class."""
67+
return True
68+
69+
def get_params(self):
70+
"""See base class."""
71+
return {self.THRESHOLD_PARAMS_KEY: self._threshold}, {}
72+
73+
def encode(self, x, encode_params):
74+
"""See base class."""
75+
76+
threshold = tf.cast(encode_params[self.THRESHOLD_PARAMS_KEY], x.dtype)
77+
indices = tf.cast(tf.compat.v2.where(tf.abs(x) > threshold), tf.int32)
78+
non_zero_x = tf.gather_nd(x, indices)
79+
80+
return {
81+
self.ENCODED_INDICES_KEY: indices,
82+
self.ENCODED_VALUES_KEY: non_zero_x,
83+
}
84+
85+
def decode(self,
86+
encoded_tensors,
87+
decode_params,
88+
num_summands=None,
89+
shape=None):
90+
"""See base class."""
91+
del decode_params, num_summands # Unused.
92+
93+
indices = encoded_tensors[self.ENCODED_INDICES_KEY]
94+
non_zero_x = encoded_tensors[self.ENCODED_VALUES_KEY]
95+
96+
shape = tf.cast(shape, indices.dtype)
97+
decoded_x = tf.scatter_nd(indices=indices, updates=non_zero_x, shape=shape)
98+
99+
return decoded_x
100+
101+
102+
@encoding_stage.tf_style_encoding_stage
103+
class DifferenceBetweenIntegersEncodingStage(
104+
encoding_stage.EncodingStageInterface):
105+
"""Encoding stage taking the difference between a sequence of integers.
106+
107+
This encoding stage can be useful when the original integers can be large, but
108+
the difference of the integers are much smaller values and have a more compact
109+
representation. For example, it can be combined with the
110+
`SplitBySmallValueEncodingStage` to further compress the increasing sequence
111+
of indices.
112+
113+
The encode method expects a tensor with 1 dimension and with integer dtype.
114+
"""
115+
116+
ENCODED_VALUES_KEY = 'difference_between_integers'
117+
118+
@property
119+
def name(self):
120+
"""See base class."""
121+
return 'difference_between_integers'
122+
123+
@property
124+
def compressible_tensors_keys(self):
125+
"""See base class."""
126+
return [
127+
self.ENCODED_VALUES_KEY,
128+
]
129+
130+
@property
131+
def commutes_with_sum(self):
132+
"""See base class."""
133+
return False
134+
135+
@property
136+
def decode_needs_input_shape(self):
137+
"""See base class."""
138+
return False
139+
140+
def get_params(self):
141+
"""See base class."""
142+
return {}, {}
143+
144+
def encode(self, x, encode_params):
145+
"""See base class."""
146+
del encode_params # Unused.
147+
if x.shape.ndims != 1:
148+
raise ValueError('Number of dimensions must be 1. Shape of x: %s' %
149+
x.shape)
150+
if not x.dtype.is_integer:
151+
raise TypeError(
152+
'Unsupported input type: %s. Support only integer types.' % x.dtype)
153+
154+
diff_x = x - tf.concat([[0], x[:-1]], 0)
155+
return {
156+
self.ENCODED_VALUES_KEY: diff_x,
157+
}
158+
159+
def decode(self,
160+
encoded_tensors,
161+
decode_params,
162+
num_summands=None,
163+
shape=None):
164+
"""See base class."""
165+
del decode_params, num_summands, shape # Unused
166+
return tf.cumsum(encoded_tensors[self.ENCODED_VALUES_KEY])
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright 2019, The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import itertools
20+
21+
from absl.testing import parameterized
22+
import numpy as np
23+
import tensorflow as tf
24+
25+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import misc
26+
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils
27+
28+
29+
class SplitBySmallValueEncodingStageTest(test_utils.BaseEncodingStageTest):
30+
31+
def default_encoding_stage(self):
32+
"""See base class."""
33+
return misc.SplitBySmallValueEncodingStage()
34+
35+
def default_input(self):
36+
"""See base class."""
37+
return tf.random.uniform([50], minval=-1.0, maxval=1.0)
38+
39+
@property
40+
def is_lossless(self):
41+
"""See base class."""
42+
return False
43+
44+
def common_asserts_for_test_data(self, data):
45+
"""See base class."""
46+
self._assert_is_integer(
47+
data.encoded_x[misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])
48+
49+
def _assert_is_integer(self, indices):
50+
"""Asserts that indices values are integers."""
51+
assert indices.dtype == np.int32
52+
53+
@parameterized.parameters([tf.float32, tf.float64])
54+
def test_input_types(self, x_dtype):
55+
# Tests different input dtypes.
56+
x = tf.constant([1.0, 0.1, 0.01, 0.001, 0.0001], dtype=x_dtype)
57+
threshold = 0.05
58+
stage = misc.SplitBySmallValueEncodingStage(threshold=threshold)
59+
encode_params, decode_params = stage.get_params()
60+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
61+
decode_params)
62+
test_data = test_utils.TestData(x, encoded_x, decoded_x)
63+
test_data = self.evaluate_test_data(test_data)
64+
65+
self._assert_is_integer(test_data.encoded_x[
66+
misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])
67+
68+
# The numpy arrays must have the same dtype as the arrays from test_data.
69+
expected_encoded_values = np.array([1.0, 0.1], dtype=x.dtype.as_numpy_dtype)
70+
expected_encoded_indices = np.array([[0], [1]], dtype=np.int32)
71+
expected_decoded_x = np.array([1.0, 0.1, 0., 0., 0.],
72+
dtype=x_dtype.as_numpy_dtype)
73+
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY],
74+
expected_encoded_values)
75+
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
76+
expected_encoded_indices)
77+
self.assertAllEqual(test_data.decoded_x, expected_decoded_x)
78+
79+
def test_all_zero_input_works(self):
80+
# Tests that encoding does not blow up with all-zero input. With all-zero
81+
# input, both of the encoded values will be empty arrays.
82+
stage = misc.SplitBySmallValueEncodingStage()
83+
test_data = self.run_one_to_many_encode_decode(stage,
84+
lambda: tf.zeros([50]))
85+
86+
self.assertAllEqual(np.zeros((50)).astype(np.float32), test_data.decoded_x)
87+
88+
def test_all_below_threshold_works(self):
89+
# Tests that encoding does not blow up with all-below-threshold input. In
90+
# this case, both of the encoded values will be empty arrays.
91+
stage = misc.SplitBySmallValueEncodingStage(threshold=0.1)
92+
x = tf.random.uniform([50], minval=-0.01, maxval=0.01)
93+
encode_params, decode_params = stage.get_params()
94+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
95+
decode_params)
96+
test_data = test_utils.TestData(x, encoded_x, decoded_x)
97+
test_data = self.evaluate_test_data(test_data)
98+
99+
expected_encoded_indices = np.array([], dtype=np.int32).reshape([0, 1])
100+
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], [])
101+
self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
102+
expected_encoded_indices)
103+
self.assertAllEqual(test_data.decoded_x,
104+
np.zeros([50], dtype=x.dtype.as_numpy_dtype))
105+
106+
107+
class DifferenceBetweenIntegersEncodingStageTest(
108+
test_utils.BaseEncodingStageTest):
109+
110+
def default_encoding_stage(self):
111+
"""See base class."""
112+
return misc.DifferenceBetweenIntegersEncodingStage()
113+
114+
def default_input(self):
115+
"""See base class."""
116+
return tf.random.uniform([10], minval=0, maxval=10, dtype=tf.int64)
117+
118+
@property
119+
def is_lossless(self):
120+
"""See base class."""
121+
return True
122+
123+
def common_asserts_for_test_data(self, data):
124+
"""See base class."""
125+
self.assertAllEqual(data.x, data.decoded_x)
126+
127+
@parameterized.parameters(
128+
itertools.product([[1,], [2,], [10,]], [tf.int32, tf.int64]))
129+
def test_with_multiple_input_shapes(self, input_dims, dtype):
130+
131+
def x_fn():
132+
return tf.random.uniform(input_dims, minval=0, maxval=10, dtype=dtype)
133+
134+
test_data = self.run_one_to_many_encode_decode(
135+
self.default_encoding_stage(), x_fn)
136+
self.common_asserts_for_test_data(test_data)
137+
138+
def test_empty_input_static(self):
139+
# Tests that the encoding works when the input shape is [0].
140+
x = []
141+
x = tf.convert_to_tensor(x, dtype=tf.int32)
142+
assert x.shape.as_list() == [0]
143+
144+
stage = self.default_encoding_stage()
145+
encode_params, decode_params = stage.get_params()
146+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
147+
decode_params)
148+
149+
test_data = self.evaluate_test_data(
150+
test_utils.TestData(x, encoded_x, decoded_x))
151+
self.common_asserts_for_test_data(test_data)
152+
153+
def test_empty_input_dynamic(self):
154+
# Tests that the encoding works when the input shape is [0], but not
155+
# statically known.
156+
y = tf.zeros((10,))
157+
indices = tf.where_v2(tf.abs(y) > 1e-8)
158+
x = tf.gather_nd(y, indices)
159+
x = tf.cast(x, tf.int32) # Empty tensor.
160+
assert x.shape.as_list() == [None]
161+
stage = self.default_encoding_stage()
162+
encode_params, decode_params = stage.get_params()
163+
encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
164+
decode_params)
165+
166+
test_data = self.evaluate_test_data(
167+
test_utils.TestData(x, encoded_x, decoded_x))
168+
assert test_data.x.shape == (0,)
169+
assert test_data.encoded_x[stage.ENCODED_VALUES_KEY].shape == (0,)
170+
assert test_data.decoded_x.shape == (0,)
171+
172+
@parameterized.parameters([tf.bool, tf.float32])
173+
def test_encode_unsupported_type_raises(self, dtype):
174+
stage = self.default_encoding_stage()
175+
with self.assertRaisesRegexp(TypeError, 'Unsupported input type'):
176+
self.run_one_to_many_encode_decode(
177+
stage, lambda: tf.cast(self.default_input(), dtype))
178+
179+
def test_encode_unsupported_input_shape_raises(self):
180+
x = tf.random.uniform((3, 4), maxval=10, dtype=tf.int32)
181+
stage = self.default_encoding_stage()
182+
params, _ = stage.get_params()
183+
with self.assertRaisesRegexp(ValueError, 'Number of dimensions must be 1'):
184+
stage.encode(x, params)
185+
186+
187+
if __name__ == '__main__':
188+
tf.test.main()

0 commit comments

Comments
 (0)