Skip to content

Commit 20b2b38

Browse files
jaeyoo and tensorflower-gardener
authored and committed
Add scheduler into tfmot compression API
PiperOrigin-RevId: 341812516
1 parent c05ce9e commit 20b2b38

File tree

3 files changed

+263
-0
lines changed

3 files changed

+263
-0
lines changed

tensorflow_model_optimization/python/core/common/keras/compression/BUILD

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,26 @@ py_library(
1212
# tensorflow dep1,
1313
],
1414
)
15+
16+
# Scheduler library for the tfmot compression API (e.g. PolynomialDecay).
py_library(
    name = "schedules",
    srcs = ["schedules.py"],
    srcs_version = "PY3ONLY",
    deps = [
        # tensorflow dep1,
    ],
)

# Unit tests for the schedulers above.
py_test(
    name = "schedules_test",
    srcs = [
        "schedules_test.py",
    ],
    python_version = "PY3",
    deps = [
        ":schedules",
        # absl/testing:parameterized dep1,
        # numpy dep1,
        # tensorflow dep1,
    ],
)
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Compression Scheduler for tfmot compression."""
16+
import abc
17+
from typing import Union, Optional
18+
19+
import tensorflow as tf
20+
21+
22+
class Scheduler(metaclass=abc.ABCMeta):
  """Interface for compression schedulers.

  A scheduler maps a training-step number to a scheduled scalar value
  (for example, a decaying compression parameter). Concrete schedules
  implement `__call__`.
  """

  @abc.abstractmethod
  def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
    """Returns the scheduled value for the given training step.

    Args:
      step: Current training-loop step number; a Python int or a
        tf.int32/tf.int64 `tf.Tensor`.

    Returns:
      A `tf.Tensor` holding the scheduled value at `step`.
    """
    raise NotImplementedError()
37+
38+
39+
class PolynomialDecay(Scheduler):
  """Polynomial decay schedule.

  The scheduled value s(t) at training step t is:

    s(t) = start_value                                   for t < begin_step
    s(t) = end_value + (start_value - end_value)
           * (1 - (t - begin_step) / decay_steps) ** exponent
                                          for begin_step <= t <= end_step
    s(t) = end_value                                     for t > end_step

  where end_step = begin_step + decay_steps.
  """

  def __init__(self,
               start_value: Union[int, float],
               decay_steps: int,
               end_value: Union[int, float],
               begin_step: Optional[int] = 0,
               exponent: Optional[float] = 1.0,
               dtype: Optional[tf.dtypes.DType] = tf.float32,
               name: Optional[str] = None):
    """Initializes the polynomial decay schedule.

    Args:
      start_value: Value where decaying starts; also returned for every
        step before `begin_step`.
      decay_steps: A Python positive int, the duration of the decay.
      end_value: Value where decaying ends; also returned for every step
        at or beyond `begin_step + decay_steps`.
      begin_step: Step at which decaying starts. Defaults to 0, i.e. the
        decay begins as soon as training starts.
      exponent: Exponent of the polynomial. Defaults to 1.0, a linear decay.
      dtype: `tf.dtypes.DType` of the returned tensor. Defaults to
        tf.float32.
      name: Optional Python `str` used as the name scope of this scheduler.
    """
    self.start_value = start_value
    self.end_value = end_value
    self.begin_step = begin_step
    self.decay_steps = decay_steps
    # First step at which the schedule is constant at `end_value`.
    self.end_step = self.begin_step + self.decay_steps
    self.exponent = exponent
    self.dtype = dtype
    self.name = name

  def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
    """Returns the scheduled value at `step`, cast to `self.dtype`."""
    with tf.name_scope(self.name or "PolynomialDecay"):
      before_decay = tf.math.less(step, self.begin_step)
      return tf.cond(before_decay,
                     lambda: tf.cast(self.start_value, dtype=self.dtype),
                     lambda: self._after_begin_step(step),
                     name="start")

  def _after_begin_step(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
    """Handles steps at or past `begin_step`."""
    with tf.name_scope(self.name or "PolynomialDecay"):
      past_end = tf.math.greater(step, self.end_step)
      return tf.cond(past_end,
                     lambda: tf.cast(self.end_value, dtype=self.dtype),
                     lambda: self._during_decay(step),
                     name="end")

  def _during_decay(self, step: Union[int, tf.Tensor]) -> tf.Tensor:
    """Evaluates the polynomial for begin_step <= step <= end_step."""
    with tf.name_scope(self.name or "PolynomialDecay"):
      # Fraction of the decay interval elapsed so far, in [0, 1].
      elapsed = tf.cast(step - self.begin_step, dtype=tf.float32)
      decay_term = tf.math.divide(
          elapsed, tf.cast(self.decay_steps, dtype=tf.float32))
      delta = tf.cast(self.start_value - self.end_value, dtype=tf.float32)
      decayed = tf.cast(
          tf.math.multiply(delta, tf.pow(1 - decay_term, self.exponent)),
          dtype=self.dtype)
      # The scheduled value is a control signal, not a trainable quantity.
      return tf.stop_gradient(tf.math.add(self.end_value, decayed))
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for schedulers."""
16+
17+
import tensorflow as tf
18+
19+
from tensorflow_model_optimization.python.core.common.keras.compression import schedules
20+
21+
22+
class SimpleScheduler(schedules.Scheduler):
  """Minimal step-function scheduler used to exercise the interface."""

  def __call__(self, step: int) -> float:
    # Holds 0.6 until step 1000, then drops to 0.1.
    return 0.6 if step < 1000 else 0.1
26+
27+
28+
class SimpleSchedulerTest(tf.test.TestCase):
  """Checks that a Scheduler subclass is called as a plain function."""

  def testSimpleScheduler(self):
    scheduler = SimpleScheduler()
    steps = [0, 100, 1000, 2000]
    expected = [0.6, 0.6, 0.1, 0.1]
    output = [scheduler(step) for step in steps]
    self.assertAllEqual(output, expected)
35+
36+
37+
class CubicPolynomialDecayTest(tf.test.TestCase):
  """Tests PolynomialDecay with a cubic (exponent=3) schedule."""

  def testBeforeDecaying(self):
    # Every step before begin_step must return the initial value.
    init_value = 0.1
    final_value = 1.0
    begin_step = 10
    decaying_step = 10
    scheduler = schedules.PolynomialDecay(init_value, decaying_step,
                                          final_value, begin_step=begin_step,
                                          exponent=3)
    output = [scheduler(step) for step in range(begin_step)]
    expected = [init_value] * begin_step
    self.assertAllClose(output, expected)

  def testDecaying(self):
    # During the decay window, the output must follow the cubic polynomial.
    init_value = 0.1
    final_value = 1.0
    begin_step = 10
    decaying_step = 10
    exponent = 3
    scheduler = schedules.PolynomialDecay(init_value, decaying_step,
                                          final_value, begin_step=begin_step,
                                          exponent=exponent)
    expected = []
    for i in range(decaying_step):
      remaining = 1 - float(i) / decaying_step
      expected.append(final_value +
                      (init_value - final_value) * remaining ** exponent)
    output = [scheduler(begin_step + i) for i in range(decaying_step)]
    self.assertAllClose(output, expected)

  def testBeyondEnd(self):
    # Past begin_step + decay_steps, the output is pinned at the final value.
    init_value = 0.1
    final_value = 1.0
    begin_step = 10
    decaying_step = 10
    total_steps = 30
    beyond_end_steps = total_steps - decaying_step - begin_step
    scheduler = schedules.PolynomialDecay(init_value, decaying_step,
                                          final_value, begin_step=begin_step,
                                          exponent=3)
    end_step = begin_step + decaying_step
    output = [scheduler(end_step + i) for i in range(beyond_end_steps)]
    expected = [final_value] * beyond_end_steps
    self.assertAllClose(output, expected)
81+
82+
83+
class LinearPolynomialDecayTest(tf.test.TestCase):
  """Tests PolynomialDecay with the default linear (exponent=1) schedule."""

  def _decayed(self, start_lr, end_lr, step):
    # All cases share a 10-step linear decay beginning at step 0.
    scheduler = schedules.PolynomialDecay(start_lr, 10, end_lr)
    return scheduler(step)

  def testHalfWay(self):
    lr = 0.05
    end_lr = 0.0
    expected = lr * 0.5
    self.assertAllClose(self._decayed(lr, end_lr, step=5), expected, 1e-6)

  def testEnd(self):
    lr = 0.05
    end_lr = 0.001
    expected = end_lr
    self.assertAllClose(self._decayed(lr, end_lr, step=10), expected, 1e-6)

  def testHalfWayWithEnd(self):
    lr = 0.05
    end_lr = 0.001
    expected = (lr + end_lr) * 0.5
    self.assertAllClose(self._decayed(lr, end_lr, step=5), expected, 1e-6)

  def testBeyondEnd(self):
    lr = 0.05
    end_lr = 0.001
    expected = end_lr
    self.assertAllClose(self._decayed(lr, end_lr, step=15), expected, 1e-6)
116+
117+
# Run all test cases in this module when executed as a script.
if __name__ == '__main__':
  tf.test.main()

0 commit comments

Comments
 (0)