@@ -187,6 +187,65 @@ def get_16a8w_qnn_ptq_config(
     return quantization_config


+def get_16a8w_qnn_qat_config(
+    act_observer=MovingAverageMinMaxObserver,
+) -> QuantizationConfig:
+    extra_args: Dict[str, Any] = {"eps": 2**-20}
+    act_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        reduce_range=True,
+        observer=act_observer.with_args(**extra_args),
+    )
+    act_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.uint16).min,
+        quant_max=torch.iinfo(torch.uint16).max,
+        qscheme=torch.per_tensor_affine,
+        observer_or_fake_quant_ctr=act_fake_quant_ctr,
+    )
+    weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    weight_quantization_spec = QuantizationSpec(
+        dtype=torch.int8,
+        quant_min=torch.iinfo(torch.int8).min + 1,
+        quant_max=torch.iinfo(torch.int8).max,
+        qscheme=torch.per_tensor_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=weight_fake_quant_ctr,
+    )
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )
+    quantization_config = QuantizationConfig(
+        input_activation=act_quantization_spec,
+        output_activation=act_quantization_spec,
+        weight=weight_quantization_spec,
+        bias=bias_quantization_spec,
+    )
+
+    return quantization_config
+
+
 def get_16a16w_qnn_ptq_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
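
For context, a minimal usage sketch of the new QAT config (assuming the module-level imports this file already has; observer keyword handling can vary slightly across PyTorch versions). The 16-bit activation spec stores values as int32 but fake-quantizes to the uint16 grid:

    import torch

    config = get_16a8w_qnn_qat_config()
    # with_args() returns a constructor; calling it builds the FakeQuantize module
    act_fq = config.input_activation.observer_or_fake_quant_ctr()
    x = torch.randn(4, 8)
    y = act_fq(x)  # observes running min/max, then fake-quantizes to [0, 65535]
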
@@ -459,6 +518,7 @@ def get_qat_per_channel_quant_config(
     act_dtype=torch.uint8,
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
+    act_symmetric=False,
 ) -> QuantizationConfig:
     supported_act_types = {
         torch.uint8,
@@ -476,21 +536,38 @@ def get_qat_per_channel_quant_config(
     ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"

     # torch does not support uint16 quantization, use int32 to bypass
-    act_fake_quant_ctr = FakeQuantize.with_args(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        reduce_range=True,
-        observer=act_observer,
-    )
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=act_fake_quant_ctr,
-    )
+    if act_symmetric:
+        # HTP can apply optimizations when zero_point is 128.
+        # Leaving quant_min/quant_max as None lets the observer default zero_point to 128.
+        # Passing the uint8 quant_min/max explicitly would yield 127, which is undesired.
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )
+    else:
+        act_fake_quant_ctr = FakeQuantize.with_args(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            reduce_range=True,
+            observer=act_observer,
+        )
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_fake_quant_ctr,
+        )

     weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
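
The zero-point behavior the comment in the symmetric branch relies on can be checked in isolation (a standalone sketch, not part of this diff): with a symmetric qscheme and an unsigned 8-bit dtype, PyTorch's observers default the zero point to 128, while a customized quant range shifts it to 127.

    import torch
    from torch.ao.quantization.observer import MovingAverageMinMaxObserver

    obs = MovingAverageMinMaxObserver(
        qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
    )
    obs(torch.randn(64))             # record running min/max
    _, zp = obs.calculate_qparams()
    print(zp)                        # tensor([128]) -- the value HTP prefers

    obs2 = MovingAverageMinMaxObserver(
        qscheme=torch.per_tensor_symmetric,
        dtype=torch.quint8,
        quant_min=0,
        quant_max=255,
    )
    obs2(torch.randn(64))
    _, zp2 = obs2.calculate_qparams()
    print(zp2)                       # tensor([127]) -- why quant_min/max are omitted above
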
@@ -513,7 +590,21 @@ def get_qat_per_channel_quant_config(
         observer_or_fake_quant_ctr=weight_fake_quant_ctr,
     )

-    bias_quantization_spec = _derived_bias_quant_spec
+    bias_fake_quant_ctr = FakeQuantize.with_args(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        reduce_range=True,
+        observer=MovingAverageMinMaxObserver,
+    )
+    bias_quantization_spec = QuantizationSpec(
+        dtype=torch.int32,
+        quant_min=torch.iinfo(torch.int32).min,
+        quant_max=torch.iinfo(torch.int32).max,
+        qscheme=torch.per_tensor_symmetric,
+        observer_or_fake_quant_ctr=bias_fake_quant_ctr,
+    )

     quantization_config = QuantizationConfig(
         input_activation=act_quantization_spec,
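
A sketch of how the new flag and the reworked bias spec might be exercised together (hypothetical usage, not part of this diff):

    # 8-bit QAT config with symmetric activations and per-channel int8 weights
    qat_config = get_qat_per_channel_quant_config(
        act_dtype=torch.uint8,
        weight_dtype=torch.int8,
        act_symmetric=True,
    )
    # The bias now carries a static int32 spec instead of one derived from the
    # input/weight scales via _derived_bias_quant_spec.
    print(qat_config.bias.dtype)  # torch.int32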