Skip to content

Commit 7123df2

Browse files
fredrectensorflower-gardener
authored andcommitted
Reuse test_utils in compression tests.
PiperOrigin-RevId: 370829522
1 parent 9381edc commit 7123df2

File tree

6 files changed

+18
-145
lines changed

6 files changed

+18
-145
lines changed

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/BUILD

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ py_strict_test(
2222
python_version = "PY3",
2323
deps = [
2424
":same_training_and_inference",
25-
# numpy dep1,
2625
# tensorflow dep1,
26+
"//tensorflow_model_optimization/python/core/keras/testing:test_utils_mnist",
2727
],
2828
)
2929

@@ -44,8 +44,8 @@ py_strict_test(
4444
python_version = "PY3",
4545
deps = [
4646
":different_training_and_inference",
47-
# numpy dep1,
4847
# tensorflow dep1,
48+
"//tensorflow_model_optimization/python/core/keras/testing:test_utils_mnist",
4949
],
5050
)
5151

@@ -66,8 +66,8 @@ py_strict_test(
6666
python_version = "PY3",
6767
deps = [
6868
":bias_only",
69-
# numpy dep1,
7069
# tensorflow dep1,
70+
"//tensorflow_model_optimization/python/core/keras/testing:test_utils_mnist",
7171
],
7272
)
7373

@@ -112,7 +112,7 @@ py_strict_test(
112112
python_version = "PY3",
113113
deps = [
114114
":periodical_update_and_scheduling",
115-
# numpy dep1,
116115
# tensorflow dep1,
116+
"//tensorflow_model_optimization/python/core/keras/testing:test_utils_mnist",
117117
],
118118
)

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/bias_only_test.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import os
1818
import tempfile
1919

20-
import numpy as np
2120
import tensorflow as tf
2221

2322
from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import bias_only
23+
from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
2424

2525

2626
def _build_model():
@@ -81,38 +81,6 @@ def _convert_to_tflite(saved_model_dir):
8181
return tflite_file
8282

8383

84-
# TODO(tfmot): reuse test_utils_mnist.py.
85-
def _test_tflite(tflite_file):
86-
interpreter = tf.lite.Interpreter(model_path=tflite_file)
87-
interpreter.allocate_tensors()
88-
89-
input_index = interpreter.get_input_details()[0]['index']
90-
output_index = interpreter.get_output_details()[0]['index']
91-
92-
(_, _), (x_test, y_test) = _get_dataset()
93-
94-
# Testing the entire dataset is too slow. Verifying only 300 of 10k samples.
95-
x_test = x_test[0:300, :]
96-
y_test = y_test[0:300]
97-
98-
total_seen = 0
99-
num_correct = 0
100-
101-
for img, label in zip(x_test, y_test):
102-
batch_input_shape = (1, 28, 28)
103-
inp = img.reshape(batch_input_shape)
104-
inp = inp.astype(np.float32)
105-
total_seen += 1
106-
interpreter.set_tensor(input_index, inp)
107-
interpreter.invoke()
108-
predictions = interpreter.get_tensor(output_index)
109-
110-
if np.argmax(predictions) == label:
111-
num_correct += 1
112-
113-
return float(num_correct) / float(total_seen)
114-
115-
11684
def _get_directory_size_in_bytes(directory):
11785
total = 0
11886
try:
@@ -169,7 +137,7 @@ def testBiasOnly_HasReasonableAccuracy_TFLite(self):
169137
saved_model_dir = _save_as_saved_model(compressed_model)
170138
compressed_tflite_file = _convert_to_tflite(saved_model_dir)
171139

172-
accuracy = _test_tflite(compressed_tflite_file)
140+
accuracy = test_utils_mnist.eval_tflite(compressed_tflite_file)
173141

174142
self.assertGreater(accuracy, 0.60)
175143

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/different_training_and_inference_test.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
import os
1717
import tempfile
1818

19-
import numpy as np
2019
import tensorflow as tf
2120

2221
from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import different_training_and_inference as svd
22+
from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
2323

2424

2525
# TODO(tfmot): dedup.
@@ -81,38 +81,6 @@ def _convert_to_tflite(saved_model_dir):
8181
return tflite_file
8282

8383

84-
# TODO(tfmot): reuse test_utils_mnist.py.
85-
def _test_tflite(tflite_file):
86-
interpreter = tf.lite.Interpreter(model_path=tflite_file)
87-
interpreter.allocate_tensors()
88-
89-
input_index = interpreter.get_input_details()[0]['index']
90-
output_index = interpreter.get_output_details()[0]['index']
91-
92-
(_, _), (x_test, y_test) = _get_dataset()
93-
94-
# Testing the entire dataset is too slow. Verifying only 300 of 10k samples.
95-
x_test = x_test[0:300, :]
96-
y_test = y_test[0:300]
97-
98-
total_seen = 0
99-
num_correct = 0
100-
101-
for img, label in zip(x_test, y_test):
102-
batch_input_shape = (1, 28, 28)
103-
inp = img.reshape(batch_input_shape)
104-
inp = inp.astype(np.float32)
105-
total_seen += 1
106-
interpreter.set_tensor(input_index, inp)
107-
interpreter.invoke()
108-
predictions = interpreter.get_tensor(output_index)
109-
110-
if np.argmax(predictions) == label:
111-
num_correct += 1
112-
113-
return float(num_correct) / float(total_seen)
114-
115-
11684
def _get_directory_size_in_bytes(directory):
11785
total = 0
11886
try:
@@ -192,7 +160,7 @@ def testSVD_HasReasonableAccuracy_TFLite(self):
192160
saved_model_dir = _save_as_saved_model(model_for_inference)
193161
compressed_tflite_file = _convert_to_tflite(saved_model_dir)
194162

195-
accuracy = _test_tflite(compressed_tflite_file)
163+
accuracy = test_utils_mnist.eval_tflite(compressed_tflite_file)
196164

197165
self.assertGreater(accuracy, 0.60)
198166

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/periodical_update_and_scheduling_test.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import os
1818
import tempfile
1919

20-
import numpy as np
2120
import tensorflow as tf
2221

2322
from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import periodical_update_and_scheduling as svd
23+
from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
2424

2525

2626
def _build_model():
@@ -81,38 +81,6 @@ def _convert_to_tflite(saved_model_dir):
8181
return tflite_file
8282

8383

84-
# TODO(tfmot): reuse test_utils_mnist.py.
85-
def _test_tflite(tflite_file):
86-
interpreter = tf.lite.Interpreter(model_path=tflite_file)
87-
interpreter.allocate_tensors()
88-
89-
input_index = interpreter.get_input_details()[0]['index']
90-
output_index = interpreter.get_output_details()[0]['index']
91-
92-
(_, _), (x_test, y_test) = _get_dataset()
93-
94-
# Testing the entire dataset is too slow. Verifying only 300 of 10k samples.
95-
x_test = x_test[0:300, :]
96-
y_test = y_test[0:300]
97-
98-
total_seen = 0
99-
num_correct = 0
100-
101-
for img, label in zip(x_test, y_test):
102-
batch_input_shape = (1, 28, 28)
103-
inp = img.reshape(batch_input_shape)
104-
inp = inp.astype(np.float32)
105-
total_seen += 1
106-
interpreter.set_tensor(input_index, inp)
107-
interpreter.invoke()
108-
predictions = interpreter.get_tensor(output_index)
109-
110-
if np.argmax(predictions) == label:
111-
num_correct += 1
112-
113-
return float(num_correct) / float(total_seen)
114-
115-
11684
def _get_directory_size_in_bytes(directory):
11785
total = 0
11886
try:
@@ -204,7 +172,7 @@ def testSVD_HasReasonableAccuracy_TFLite(self):
204172
saved_model_dir = _save_as_saved_model(compressed_model)
205173
compressed_tflite_file = _convert_to_tflite(saved_model_dir)
206174

207-
accuracy = _test_tflite(compressed_tflite_file)
175+
accuracy = test_utils_mnist.eval_tflite(compressed_tflite_file)
208176

209177
self.assertGreater(accuracy, 0.60)
210178

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/same_training_and_inference_test.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import os
1818
import tempfile
1919

20-
import numpy as np
2120
import tensorflow as tf
2221

2322
from tensorflow_model_optimization.python.core.common.keras.compression.algorithms import same_training_and_inference as svd
23+
from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist
2424

2525

2626
def _build_model():
@@ -81,38 +81,6 @@ def _convert_to_tflite(saved_model_dir):
8181
return tflite_file
8282

8383

84-
# TODO(tfmot): reuse test_utils_mnist.py.
85-
def _test_tflite(tflite_file):
86-
interpreter = tf.lite.Interpreter(model_path=tflite_file)
87-
interpreter.allocate_tensors()
88-
89-
input_index = interpreter.get_input_details()[0]['index']
90-
output_index = interpreter.get_output_details()[0]['index']
91-
92-
(_, _), (x_test, y_test) = _get_dataset()
93-
94-
# Testing the entire dataset is too slow. Verifying only 300 of 10k samples.
95-
x_test = x_test[0:300, :]
96-
y_test = y_test[0:300]
97-
98-
total_seen = 0
99-
num_correct = 0
100-
101-
for img, label in zip(x_test, y_test):
102-
batch_input_shape = (1, 28, 28)
103-
inp = img.reshape(batch_input_shape)
104-
inp = inp.astype(np.float32)
105-
total_seen += 1
106-
interpreter.set_tensor(input_index, inp)
107-
interpreter.invoke()
108-
predictions = interpreter.get_tensor(output_index)
109-
110-
if np.argmax(predictions) == label:
111-
num_correct += 1
112-
113-
return float(num_correct) / float(total_seen)
114-
115-
11684
def _get_directory_size_in_bytes(directory):
11785
total = 0
11886
try:
@@ -194,7 +162,7 @@ def testSVD_HasReasonableAccuracy_TFLite(self):
194162
saved_model_dir = _save_as_saved_model(compressed_model)
195163
compressed_tflite_file = _convert_to_tflite(saved_model_dir)
196164

197-
accuracy = _test_tflite(compressed_tflite_file)
165+
accuracy = test_utils_mnist.eval_tflite(compressed_tflite_file)
198166

199167
self.assertGreater(accuracy, 0.60)
200168

tensorflow_model_optimization/python/core/keras/testing/test_utils_mnist.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
def layers_list():
2525
return [
2626
l.Conv2D(32, 5, padding='same', activation='relu',
27-
input_shape=input_shape()),
27+
input_shape=image_input_shape()),
2828
l.MaxPooling2D((2, 2), (2, 2), padding='same'),
2929
# TODO(pulkitb): Add BatchNorm when transformations are ready.
3030
# l.BatchNormalization(),
@@ -43,7 +43,7 @@ def sequential_model():
4343

4444
def functional_model():
4545
"""Builds an MNIST functional model."""
46-
inp = keras.Input(input_shape())
46+
inp = keras.Input(image_input_shape())
4747
x = l.Conv2D(32, 5, padding='same', activation='relu')(inp)
4848
x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
4949
# TODO(pulkitb): Add BatchNorm when transformations are ready.
@@ -58,7 +58,7 @@ def functional_model():
5858
return keras.models.Model([inp], [out])
5959

6060

61-
def input_shape(img_rows=28, img_cols=28):
61+
def image_input_shape(img_rows=28, img_cols=28):
6262
if tf.keras.backend.image_data_format() == 'channels_first':
6363
return 1, img_rows, img_cols
6464
else:
@@ -92,9 +92,11 @@ def preprocessed_data(img_rows=28,
9292

9393
def eval_tflite(model_path):
9494
"""Evaluate mnist in TFLite for accuracy."""
95+
9596
interpreter = tf.lite.Interpreter(model_path=model_path)
9697
interpreter.allocate_tensors()
9798
input_index = interpreter.get_input_details()[0]['index']
99+
input_shape = interpreter.get_input_details()[0]['shape']
98100
output_index = interpreter.get_output_details()[0]['index']
99101

100102
_, _, x_test, y_test = preprocessed_data()
@@ -106,8 +108,7 @@ def eval_tflite(model_path):
106108
num_correct = 0
107109

108110
for img, label in zip(x_test, y_test):
109-
batch_input_shape = (1,) + input_shape()
110-
inp = img.reshape(batch_input_shape)
111+
inp = img.reshape(input_shape)
111112
total_seen += 1
112113
interpreter.set_tensor(input_index, inp)
113114
interpreter.invoke()

0 commit comments

Comments
 (0)