Skip to content

Commit a05a2cc

Browse files
Xhark authored and tensorflower-gardener committed
Move the point where _prevent_constant_folding is applied from after the algorithm function to before it.
This change makes the ReducesTFLiteModelSize tests pass. PiperOrigin-RevId: 338386984
1 parent 2e22094 commit a05a2cc

File tree

4 files changed

+17
-28
lines changed

4 files changed

+17
-28
lines changed

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ py_library(
1414

1515
py_test(
1616
name = "same_training_and_inference_test",
17+
timeout = "long",
1718
srcs = ["same_training_and_inference_test.py"],
1819
python_version = "PY3",
1920
deps = [
@@ -35,6 +36,7 @@ py_library(
3536

3637
py_test(
3738
name = "different_training_and_inference_test",
39+
timeout = "long",
3840
srcs = ["different_training_and_inference_test.py"],
3941
python_version = "PY3",
4042
deps = [

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference_test.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -170,12 +170,7 @@ def testSVD_HasReasonableAccuracy_TF(self):
170170

171171
self.assertGreater(results[1], 0.60)
172172

173-
# TODO(tfmot): currently fails - didn't hook up constant
174-
# folding prevention correctly.
175-
def testSVD_ReducesTFLiteModelSize_Fails(self):
176-
return
177-
178-
# pylint: disable=unreachable
173+
def testSVD_ReducesTFLiteModelSize(self):
179174
model = _build_model()
180175

181176
original_saved_model_dir = _save_as_saved_model(model)
@@ -192,7 +187,6 @@ def testSVD_ReducesTFLiteModelSize_Fails(self):
192187
compressed_size = os.path.getsize(compressed_tflite_file)
193188

194189
self.assertLess(compressed_size, original_size / 6)
195-
# pylint: enable=unreachable
196190

197191
def testSVD_HasReasonableAccuracy_TFLite(self):
198192
model = _build_model()

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference_test.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -170,12 +170,7 @@ def testSVD_HasReasonableAccuracy_TF(self):
170170

171171
self.assertGreater(results[1], 0.60)
172172

173-
# TODO(tfmot): currently fails - didn't hook up constant
174-
# folding prevention correctly.
175-
def testSVD_ReducesTFLiteModelSize_Fails(self):
176-
return
177-
178-
# pylint: disable=unreachable
173+
def testSVD_ReducesTFLiteModelSize(self):
179174
model = _build_model()
180175

181176
original_saved_model_dir = _save_as_saved_model(model)
@@ -191,7 +186,6 @@ def testSVD_ReducesTFLiteModelSize_Fails(self):
191186
compressed_size = os.path.getsize(compressed_tflite_file)
192187

193188
self.assertLess(compressed_size, original_size / 6)
194-
# pylint: enable=unreachable
195189

196190
def testSVD_HasReasonableAccuracy_TFLite(self):
197191
model = _build_model()

tensorflow_model_optimization/python/core/common/keras/compression/internal/optimize.py

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -112,15 +112,15 @@ def build(self, input_shape):
112112

113113
def call(self, inputs):
114114
for attr_name in self.compressible_weights:
115-
training_weight_tensors = [
116-
v.read_value() for v in self.training_weights[attr_name]
117-
]
115+
# TODO(tfmot): move constant folding prevention to the inference graph
116+
# only, since constant folding won't happen during training.
117+
training_weight_tensors = []
118+
for v in self.training_weights[attr_name]:
119+
training_weight_tensors.append(
120+
_prevent_constant_folding(v.read_value(), inputs))
121+
118122
weight_tensor = self.algorithm.training(training_weight_tensors)
119-
# TODO(tfmot): move this to the inference graph only, since
120-
# constant folding won't happen during training.
121-
non_const_foldable_weight_tensor = _prevent_constant_folding(
122-
weight_tensor, inputs)
123-
setattr(self.layer, attr_name, non_const_foldable_weight_tensor)
123+
setattr(self.layer, attr_name, weight_tensor)
124124

125125
# This assumes that all changes to the forward pass happen "prior" to
126126
# the nested layer's portion of the forward pass. This suffices since
@@ -198,13 +198,12 @@ def call(self, inputs, training=None):
198198
for attr_name in self.training_tensors:
199199
# TODO(tfmot): understand how read_value() is converted to
200200
# inference in TensorFlow Lite.
201-
compressed_weight_tensors = [
202-
v.read_value() for v in self.compressed_weights[attr_name]
203-
]
201+
compressed_weight_tensors = []
202+
for v in self.compressed_weights[attr_name]:
203+
compressed_weight_tensors.append(
204+
_prevent_constant_folding(v.read_value(), inputs))
204205
weight_tensor = self.algorithm.decompress(*compressed_weight_tensors)
205-
non_const_foldable_weight_tensor = _prevent_constant_folding(
206-
weight_tensor, inputs)
207-
setattr(self.layer, attr_name, non_const_foldable_weight_tensor)
206+
setattr(self.layer, attr_name, weight_tensor)
208207

209208
# TODO(tfmot): handle training arg if needed given this is inference only.
210209
return self.layer.call(inputs)

0 commit comments

Comments
 (0)