@@ -72,6 +72,8 @@ def _get_rnn_cells(self, rnn_layer):
72
72
class QuantizeRegistry (quantize_registry .QuantizeRegistry , _RNNHelper ):
73
73
"""QuantizationRegistry for built-in Keras classes for default 8-bit scheme."""
74
74
75
+ # TODO(tfmot): expand layers test in quantize_functional_test.py
76
+ # to add more layers to whitelist.
75
77
_LAYER_QUANTIZE_INFO = [
76
78
77
79
# Activation Layers
@@ -84,16 +86,18 @@ class QuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):
84
86
# layers.ThresholdedReLU,
85
87
86
88
# Convolution Layers
87
- _QuantizeInfo (layers .Conv1D , ['kernel' ], ['activation' ]),
88
- _QuantizeInfo (layers .Conv3D , ['kernel' ], ['activation' ]),
89
- # TODO(pulkitb): Verify Transpose layers.
90
- _QuantizeInfo (layers .Conv2DTranspose , ['kernel' ], ['activation' ]),
91
- _QuantizeInfo (layers .Conv3DTranspose , ['kernel' ], ['activation' ]),
89
+ # _QuantizeInfo(layers.Conv1D, ['kernel'], ['activation']),
90
+
91
+ # layers.Conv2D is supported and handled in code below.
92
+
93
+ # _QuantizeInfo(layers.Conv3D, ['kernel'], ['activation']),
94
+ # _QuantizeInfo(layers.Conv2DTranspose, ['kernel'], ['activation']),
95
+ # _QuantizeInfo(layers.Conv3DTranspose, ['kernel'], ['activation']),
92
96
_no_quantize (layers .Cropping1D ),
93
97
_no_quantize (layers .Cropping2D ),
94
98
_no_quantize (layers .Cropping3D ),
95
99
_no_quantize (layers .UpSampling1D ),
96
- _no_quantize (layers .UpSampling2D ),
100
+ # _no_quantize(layers.UpSampling2D),
97
101
_no_quantize (layers .UpSampling3D ),
98
102
_no_quantize (layers .ZeroPadding1D ),
99
103
_no_quantize (layers .ZeroPadding2D ),
@@ -107,7 +111,7 @@ class QuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):
107
111
_QuantizeInfo (layers .Dense , ['kernel' ], ['activation' ]),
108
112
_no_quantize (layers .Dropout ),
109
113
_no_quantize (layers .Flatten ),
110
- _no_quantize (layers .Masking ),
114
+ # _no_quantize(layers.Masking),
111
115
_no_quantize (layers .Permute ),
112
116
_no_quantize (layers .RepeatVector ),
113
117
_no_quantize (layers .Reshape ),
@@ -119,7 +123,7 @@ class QuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):
119
123
# Pooling Layers
120
124
_QuantizeInfo (layers .AveragePooling1D , [], [], True ),
121
125
_QuantizeInfo (layers .AveragePooling2D , [], [], True ),
122
- _QuantizeInfo (layers .AveragePooling3D , [], [], True ),
126
+ # _QuantizeInfo(layers.AveragePooling3D, [], [], True),
123
127
_QuantizeInfo (layers .GlobalAveragePooling1D , [], [], True ),
124
128
_QuantizeInfo (layers .GlobalAveragePooling2D , [], [], True ),
125
129
_QuantizeInfo (layers .GlobalAveragePooling3D , [], [], True ),
@@ -128,11 +132,10 @@ class QuantizeRegistry(quantize_registry.QuantizeRegistry, _RNNHelper):
128
132
_no_quantize (layers .GlobalMaxPooling3D ),
129
133
_no_quantize (layers .MaxPooling1D ),
130
134
_no_quantize (layers .MaxPooling2D ),
131
- _no_quantize (layers .MaxPooling3D ),
135
+ # _no_quantize(layers.MaxPooling3D),
132
136
133
- # TODO(pulkitb): Verify Locally Connected layers.
134
- _QuantizeInfo (layers .LocallyConnected1D , ['kernel' ], ['activation' ]),
135
- _QuantizeInfo (layers .LocallyConnected2D , ['kernel' ], ['activation' ]),
137
+ # _QuantizeInfo(layers.LocallyConnected1D, ['kernel'], ['activation']),
138
+ # _QuantizeInfo(layers.LocallyConnected2D, ['kernel'], ['activation']),
136
139
_QuantizeInfo (layers .Add , [], [], True ),
137
140
138
141
# Enable once verified with TFLite behavior.
0 commit comments