File tree Expand file tree Collapse file tree 6 files changed +8
-15
lines changed
tensorflow_model_optimization/python/core/quantization/keras Expand file tree Collapse file tree 6 files changed +8
-15
lines changed Original file line number Diff line number Diff line change 18
18
from __future__ import division
19
19
from __future__ import print_function
20
20
21
- from absl .testing import parameterized
22
21
import tensorflow as tf
23
22
24
23
# TODO(b/139939526): move to public API.
25
- from tensorflow .python .keras import keras_parameterized
26
24
from tensorflow_model_optimization .python .core .keras import compat
27
25
from tensorflow_model_optimization .python .core .quantization .keras import quant_ops
28
26
29
27
_SYMMETRIC_RANGE_RATIO = 0.9921875 # 127 / 128
30
28
31
29
32
- @keras_parameterized .run_all_keras_modes
33
- class QuantOpsTest (tf .test .TestCase , parameterized .TestCase ):
30
+ class QuantOpsTest (tf .test .TestCase ):
34
31
35
32
def testAllValuesQuantiize_TrainingAssign (self ):
36
33
min_value , max_value = self ._GetMinMaxValues (
Original file line number Diff line number Diff line change 23
23
import numpy as np
24
24
import tensorflow as tf
25
25
26
- from tensorflow .python .keras import keras_parameterized
27
26
from tensorflow_model_optimization .python .core .quantization .keras import quantize_aware_activation
28
27
from tensorflow_model_optimization .python .core .quantization .keras import quantizers
29
28
37
36
MovingAverageQuantizer = quantizers .MovingAverageQuantizer
38
37
39
38
40
- @keras_parameterized .run_all_keras_modes
39
+ @tf .__internal__ .distribute .combinations .generate (
40
+ tf .__internal__ .test .combinations .combine (mode = ['graph' , 'eager' ]))
41
41
class QuantizeAwareQuantizationTest (tf .test .TestCase , parameterized .TestCase ):
42
42
43
43
def setUp (self ):
Original file line number Diff line number Diff line change 25
25
import tensorflow as tf
26
26
27
27
# TODO(b/139939526): move to public API.
28
- from tensorflow .python .keras import keras_parameterized
29
28
from tensorflow_model_optimization .python .core .keras import compat
30
29
from tensorflow_model_optimization .python .core .keras .testing import test_utils_mnist
31
30
from tensorflow_model_optimization .python .core .quantization .keras import quantize
34
33
layers = tf .keras .layers
35
34
36
35
37
- @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
38
- class QuantizeFunctionalTest (tf .test .TestCase , parameterized .TestCase ):
36
+ @tf .__internal__ .distribute .combinations .generate (
37
+ tf .__internal__ .test .combinations .combine (mode = ['graph' , 'eager' ]))
38
+ class QuantizeFunctionalTest (tf .test .TestCase ):
39
39
40
40
# TODO(pulkitb): Parameterize test and include functional mnist, and
41
41
# other RNN models.
Original file line number Diff line number Diff line change 26
26
import tensorflow as tf
27
27
28
28
# TODO(b/139939526): move to public API.
29
- from tensorflow .python .keras import keras_parameterized
30
29
31
30
from tensorflow_model_optimization .python .core .keras import compat
32
31
from tensorflow_model_optimization .python .core .keras import test_utils
44
43
# TODO(tfmot): enable for v1. Currently fails because the decorator
45
44
# on graph mode wraps everything in a graph, which is not compatible
46
45
# with the TFLite converter's call to clear_session().
47
- @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
46
+ @tf .__internal__ .distribute .combinations .generate (
47
+ tf .__internal__ .test .combinations .combine (mode = ['graph' , 'eager' ]))
48
48
class QuantizeIntegrationTest (tf .test .TestCase , parameterized .TestCase ):
49
49
50
50
def _batch (self , dims , batch_size ):
Original file line number Diff line number Diff line change 26
26
import numpy as np
27
27
import tensorflow as tf
28
28
29
- from tensorflow .python .keras import keras_parameterized
30
29
from tensorflow_model_optimization .python .core .quantization .keras import quantize
31
30
from tensorflow_model_optimization .python .core .quantization .keras import utils
32
31
33
32
34
- @keras_parameterized .run_all_keras_modes (always_skip_v1 = True )
35
33
class QuantizeModelsTest (tf .test .TestCase , parameterized .TestCase ):
36
34
37
35
# Derived using
Original file line number Diff line number Diff line change 23
23
import numpy as np
24
24
import tensorflow as tf
25
25
26
- from tensorflow .python .keras import keras_parameterized
27
26
from tensorflow_model_optimization .python .core .keras import compat
28
27
from tensorflow_model_optimization .python .core .quantization .keras import quantizers
29
28
30
29
deserialize_keras_object = tf .keras .utils .deserialize_keras_object
31
30
serialize_keras_object = tf .keras .utils .serialize_keras_object
32
31
33
32
34
- @keras_parameterized .run_all_keras_modes
35
33
@parameterized .parameters (
36
34
quantizers .LastValueQuantizer ,
37
35
quantizers .MovingAverageQuantizer ,
You can’t perform that action at this time.
0 commit comments