@@ -90,9 +90,9 @@ def __init__(self):
 
     self._config_quantizer_map = {
         'Default8BitQuantizeConfig':
-            PrunePerserveDefault8BitWeightsQuantizer(),
+            PrunePreserveDefault8BitWeightsQuantizer(),
         'Default8BitConvQuantizeConfig':
-            PrunePerserveDefault8BitConvWeightsQuantizer(),
+            PrunePreserveDefault8BitConvWeightsQuantizer(),
     }
 
   @classmethod
@@ -224,7 +224,7 @@ def get_quantize_config(self, layer):
     Returns:
       Returns the quantization config with sparsity preserve weight_quantizer.
     """
-    quantize_config = default_8bit_quantize_registry.QuantizeRegistry(
+    quantize_config = default_8bit_quantize_registry.Default8BitQuantizeRegistry(
     ).get_quantize_config(layer)
     prune_aware_quantize_config = super(
         Default8bitPrunePreserveQuantizeRegistry,
@@ -233,10 +233,10 @@ def get_quantize_config(self, layer):
     return prune_aware_quantize_config
 
 
-class PrunePerserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
+class PrunePreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
   """Quantize weights while preserve sparsity."""
   def __init__(self, num_bits, per_axis, symmetric, narrow_range):
-    """PrunePerserveDefaultWeightsQuantizer
+    """PrunePreserveDefaultWeightsQuantizer
 
     Args:
       num_bits: Number of bits for quantization
@@ -249,7 +249,7 @@ def __init__(self, num_bits, per_axis, symmetric, narrow_range):
         range has 0 as the centre.
     """
 
-    super(PrunePerserveDefaultWeightsQuantizer, self).__init__(
+    super(PrunePreserveDefaultWeightsQuantizer, self).__init__(
         num_bits=num_bits,
         per_axis=per_axis,
         symmetric=symmetric,
@@ -276,7 +276,7 @@ def build(self, tensor_shape, name, layer):
     """
     result = self._build_sparsity_mask(name, layer)
     result.update(
-        super(PrunePerserveDefaultWeightsQuantizer,
+        super(PrunePreserveDefaultWeightsQuantizer,
               self).build(tensor_shape, name, layer))
     return result
 
@@ -308,27 +308,27 @@ def __call__(self, inputs, training, weights, **kwargs):
     )
 
 
-class PrunePerserveDefault8BitWeightsQuantizer(
-    PrunePerserveDefaultWeightsQuantizer):
-  """PrunePerserveWeightsQuantizer for default 8bit weights"""
+class PrunePreserveDefault8BitWeightsQuantizer(
+    PrunePreserveDefaultWeightsQuantizer):
+  """PrunePreserveWeightsQuantizer for default 8bit weights"""
   def __init__(self):
-    super(PrunePerserveDefault8BitWeightsQuantizer,
+    super(PrunePreserveDefault8BitWeightsQuantizer,
           self).__init__(num_bits=8,
                          per_axis=False,
                          symmetric=True,
                          narrow_range=True)
 
 
-class PrunePerserveDefault8BitConvWeightsQuantizer(
-    PrunePerserveDefaultWeightsQuantizer,
+class PrunePreserveDefault8BitConvWeightsQuantizer(
+    PrunePreserveDefaultWeightsQuantizer,
     default_8bit_quantizers.Default8BitConvWeightsQuantizer,
 ):
-  """PrunePerserveWeightsQuantizer for default 8bit Conv2D/DepthwiseConv2D weights"""
+  """PrunePreserveWeightsQuantizer for default 8bit Conv2D/DepthwiseConv2D weights"""
   def __init__(self):
     default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)
 
   def build(self, tensor_shape, name, layer):
-    result = PrunePerserveDefaultWeightsQuantizer._build_sparsity_mask(
+    result = PrunePreserveDefaultWeightsQuantizer._build_sparsity_mask(
         self, name, layer)
     result.update(
         default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
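For context on the classes being renamed: as the hunks above show, each quantizer's build() assembles a binary sparsity mask via _build_sparsity_mask() alongside the usual quantizer variables, and that mask is applied to the weights so pruned positions stay exactly zero through quantization-aware training. Below is a minimal standalone sketch of that mask-then-quantize idea; the function name prune_preserving_fake_quant and the explicit mask argument are illustrative assumptions, not the library's API.

import tensorflow as tf

def prune_preserving_fake_quant(weights, mask, num_bits=8, narrow_range=True):
  """Illustrative sketch only: apply the pruning mask, then fake-quantize."""
  masked = weights * mask  # zero out pruned weights so sparsity survives
  max_abs = tf.reduce_max(tf.abs(masked))  # symmetric per-tensor range
  return tf.quantization.fake_quant_with_min_max_vars(
      masked, -max_abs, max_abs, num_bits=num_bits, narrow_range=narrow_range)

Because the range is symmetric and matches the renamed 8-bit quantizers' settings (num_bits=8, symmetric=True, narrow_range=True), a real value of 0.0 maps exactly onto the integer zero point, which is what keeps the masked weights at zero after quantization.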