@@ -241,8 +241,7 @@ def get_ptq_per_channel_quant_config(
241241     torch.int8,
242242     torch.int16,
243243 }
244-     # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
245-     supported_weight_dtypes = {"int4", torch.int8, torch.int16}
244+     supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
246245 assert (
247246     act_dtype in supported_act_types
248247 ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
@@ -276,9 +275,11 @@ def get_ptq_per_channel_quant_config(
276275 )
277276
278277 weight_quantization_spec = QuantizationSpec(
279-     dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
280-     quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
281-     quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
278+     dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
279+     quant_min=(
280+         -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
281+     ),
282+     quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
282283     qscheme=torch.per_channel_symmetric,
283284     ch_axis=0,
284285     observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
@@ -310,9 +311,11 @@ def get_ptq_per_block_quant_config(
310311     act_symmetric=act_symmetric,
311312 )
312313 weight_quantization_spec = QuantizationSpec(
313-     dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
314-     quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
315-     quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
314+     dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
315+     quant_min=(
316+         -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
317+     ),
318+     quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
316319     qscheme=torch.per_channel_symmetric,
317320     ch_axis=0,
318321     observer_or_fake_quant_ctr=PerBlockParamObserver.with_args(**extra_args),
@@ -463,8 +466,7 @@ def get_qat_per_channel_quant_config(
463466     torch.int8,
464467     torch.int16,
465468 }
466-     # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
467-     supported_weight_dtypes = {"int4", torch.int8, torch.int16}
469+     supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
468470 assert (
469471     act_dtype in supported_act_types
470472 ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
@@ -491,17 +493,21 @@ def get_qat_per_channel_quant_config(
491493 )
492494
493495 weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
494-     dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
495-     quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
496-     quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
496+     dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
497+     quant_min=(
498+         -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
499+     ),
500+     quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
497501     qscheme=torch.per_channel_symmetric,
498502     ch_axis=0,
499503     observer=MovingAveragePerChannelMinMaxObserver,
500504 )
501505 weight_quantization_spec = QuantizationSpec(
502-     dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
503-     quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
504-     quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
506+     dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
507+     quant_min=(
508+         -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
509+     ),
510+     quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
505511     qscheme=torch.per_channel_symmetric,
506512     ch_axis=0,
507513     observer_or_fake_quant_ctr=weight_fake_quant_ctr,
0 commit comments