@@ -90,9 +90,9 @@ def __init__(self):
 
     self._config_quantizer_map = {
        'Default8BitQuantizeConfig':
-            PrunePerserveDefault8BitWeightsQuantizer(),
+            PrunePreserveDefault8BitWeightsQuantizer(),
        'Default8BitConvQuantizeConfig':
-            PrunePerserveDefault8BitConvWeightsQuantizer(),
+            PrunePreserveDefault8BitConvWeightsQuantizer(),
    }
 
  @classmethod
@@ -224,7 +224,7 @@ def get_quantize_config(self, layer):
     Returns:
       Returns the quantization config with sparsity preserve weight_quantizer.
     """
-    quantize_config = default_8bit_quantize_registry.QuantizeRegistry(
+    quantize_config = default_8bit_quantize_registry.Default8BitQuantizeRegistry(
     ).get_quantize_config(layer)
     prune_aware_quantize_config = super(
         Default8bitPrunePreserveQuantizeRegistry,
@@ -233,10 +233,10 @@ def get_quantize_config(self, layer):
     return prune_aware_quantize_config
 
 
-class PrunePerserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
+class PrunePreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
   """Quantize weights while preserve sparsity."""
   def __init__(self, num_bits, per_axis, symmetric, narrow_range):
-    """PrunePerserveDefaultWeightsQuantizer
+    """PrunePreserveDefaultWeightsQuantizer
 
     Args:
       num_bits: Number of bits for quantization
@@ -249,7 +249,7 @@ def __init__(self, num_bits, per_axis, symmetric, narrow_range):
        range has 0 as the centre.
     """
 
-    super(PrunePerserveDefaultWeightsQuantizer, self).__init__(
+    super(PrunePreserveDefaultWeightsQuantizer, self).__init__(
         num_bits=num_bits,
         per_axis=per_axis,
         symmetric=symmetric,
@@ -276,7 +276,7 @@ def build(self, tensor_shape, name, layer):
     """
     result = self._build_sparsity_mask(name, layer)
     result.update(
-        super(PrunePerserveDefaultWeightsQuantizer,
+        super(PrunePreserveDefaultWeightsQuantizer,
               self).build(tensor_shape, name, layer))
     return result
 
@@ -308,27 +308,27 @@ def __call__(self, inputs, training, weights, **kwargs):
     )
 
 
-class PrunePerserveDefault8BitWeightsQuantizer(
-    PrunePerserveDefaultWeightsQuantizer):
-  """PrunePerserveWeightsQuantizer for default 8bit weights"""
+class PrunePreserveDefault8BitWeightsQuantizer(
+    PrunePreserveDefaultWeightsQuantizer):
+  """PrunePreserveWeightsQuantizer for default 8bit weights"""
   def __init__(self):
-    super(PrunePerserveDefault8BitWeightsQuantizer,
+    super(PrunePreserveDefault8BitWeightsQuantizer,
           self).__init__(num_bits=8,
                          per_axis=False,
                          symmetric=True,
                          narrow_range=True)
 
 
-class PrunePerserveDefault8BitConvWeightsQuantizer(
-    PrunePerserveDefaultWeightsQuantizer,
+class PrunePreserveDefault8BitConvWeightsQuantizer(
+    PrunePreserveDefaultWeightsQuantizer,
     default_8bit_quantizers.Default8BitConvWeightsQuantizer,
 ):
-  """PrunePerserveWeightsQuantizer for default 8bit Conv2D/DepthwiseConv2D weights"""
+  """PrunePreserveWeightsQuantizer for default 8bit Conv2D/DepthwiseConv2D weights"""
   def __init__(self):
     default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)
 
   def build(self, tensor_shape, name, layer):
-    result = PrunePerserveDefaultWeightsQuantizer._build_sparsity_mask(
+    result = PrunePreserveDefaultWeightsQuantizer._build_sparsity_mask(
         self, name, layer)
     result.update(
         default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
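
Conceptually, the quantizers renamed here combine a fixed binary sparsity mask (built by _build_sparsity_mask) with the standard fake-quant step, so zeros introduced by pruning survive quantization-aware training. A toy illustration of that idea, not the TFMOT implementation; the function name is hypothetical:

import tensorflow as tf

def sparsity_preserving_fake_quant(weights, mask, num_bits=8):
  """Toy prune-preserving quantization: apply the mask, then fake-quantize.

  `mask` is a 0/1 tensor marking surviving (unpruned) weights; multiplying
  by it before quantization keeps pruned positions at exactly zero.
  """
  masked = weights * mask
  max_abs = tf.reduce_max(tf.abs(masked))
  # Symmetric, narrow-range fake quantization, mirroring the num_bits=8,
  # symmetric=True, narrow_range=True defaults visible in the diff.
  return tf.quantization.fake_quant_with_min_max_vars(
      masked, min=-max_abs, max=max_abs, num_bits=num_bits, narrow_range=True)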
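
For context, a minimal usage sketch of the corrected names. The import path is an assumption based on the tensorflow_model_optimization source layout and may differ between releases; the registry class and get_quantize_config(layer) come from the diff itself:

# Hypothetical usage sketch; the module path below is an assumption.
from tensorflow import keras
from tensorflow_model_optimization.python.core.quantization.keras.collaborative_optimizations.prune_preserve import prune_preserve_quantize_registry

# get_quantize_config() looks up the standard 8-bit config via
# Default8BitQuantizeRegistry (the corrected name) and swaps in a
# sparsity-preserving weight quantizer such as
# PrunePreserveDefault8BitWeightsQuantizer.
registry = prune_preserve_quantize_registry.Default8bitPrunePreserveQuantizeRegistry()
quantize_config = registry.get_quantize_config(keras.layers.Dense(10))

The registry fix matters because the old lookup would raise an AttributeError wherever default_8bit_quantize_registry no longer exposes a QuantizeRegistry attribute, which is presumably why the diff switches to Default8BitQuantizeRegistry.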