 )
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
 
+from .observers.concat_observer import ConcatObserver
+
 from .qconfig import (
     get_16a16w_qnn_ptq_config,
     get_16a4w_qnn_qat_config,
@@ -691,7 +693,7 @@ def annotate_sign(node: Node, quantization_config: QuantizationConfig) -> None:
 
 @register_annotator([torch.ops.aten.slice.Tensor])
 def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
-    annotate_single_in_single_out(node, quantization_config)
+    annotate_single_in_share_out(node, quantization_config)
 
 
 @register_annotator([torch.ops.aten.slice_scatter.default])
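For context on the one-line change above: `annotate_single_in_share_out` ties the slice output's qparams to its input edge instead of giving the output an independent activation spec, so slicing never introduces a requantize. Below is a minimal sketch of that pattern, assuming the helper mirrors the other annotators in this file; its actual body is not part of this hunk.

```python
# Sketch only -- annotate_single_in_share_out already exists in this module;
# the body below is an assumption about its shape, not the PR's code.
# It relies on this module's existing helpers (_is_annotated, QuantizationConfig, ...).
def annotate_single_in_share_out(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    if _is_annotated([node]):
        return

    input_act = node.args[0]
    assert isinstance(input_act, Node)
    node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={input_act: quantization_config.input_activation},
        # the output reuses the input's qparams via the (producer, consumer) edge
        output_qspec=SharedQuantizationSpec((input_act, node)),
        _annotated=True,
    )
```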
@@ -1277,31 +1279,40 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
 
 @register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
 def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
-    input_nodes = node.args[0]
     if _is_annotated([node]) or not _is_float_tensor(node):
         return
 
-    assert isinstance(input_nodes, Sequence)
-
-    first_input_node = input_nodes[0]
-    input_qspec_map = {}
-    assert isinstance(first_input_node, Node)
-    assert isinstance(node, Node)
-    if _is_float_tensor(first_input_node):
-        input_qspec_map[first_input_node] = quantization_config.input_activation
-        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
-            (first_input_node, node)
-        )
-
-    for input_node in input_nodes[1:]:
-        if input_node not in input_qspec_map:
-            assert isinstance(input_node, Node)
-            if _is_float_tensor(input_node):
-                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
-
+    input_qspec_map, input_nodes = {}, node.args[0]
+    for input in input_nodes:
+        input_qspec = input.meta.get(Q_ANNOTATION_KEY, None)
+        if (
+            # placeholder
+            input_qspec is None
+            or
+            # keep the shared qspec here to propagate the data range
+            # without introducing extra requantizations
+            not isinstance(input_qspec.output_qspec, SharedQuantizationSpec)
+        ):
+            input_qspec_map[input] = quantization_config.input_activation
+
+    output_qspec = QuantizationSpec(
+        dtype=quantization_config.output_activation.dtype,
+        qscheme=quantization_config.output_activation.qscheme,
+        quant_max=quantization_config.output_activation.quant_max,
+        quant_min=quantization_config.output_activation.quant_min,
+        observer_or_fake_quant_ctr=ConcatObserver.with_args(
+            # we need to know the concat node in order to hack all the input observers' data ranges;
+            # since deep copying the fake tensor (node.meta["val"]) is prohibited,
+            # we can currently only ship the graph & node name and post-process inside the observer
+            **{
+                "node_name": node.name,
+                "graph": node.graph,
+            }
+        ),
+    )
     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
-        output_qspec=share_qparams_with_input_act0_qspec,
+        output_qspec=output_qspec,
         _annotated=True,
     )
 
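The cat annotation above no longer shares qparams with the first input; it attaches a dedicated `ConcatObserver` to the output and passes it the graph plus the concat node's name so the observer can post-process the input observers' ranges after calibration. The observer itself is imported from `.observers.concat_observer` and is not shown in this section; below is a minimal, hypothetical sketch of how such a constructor could accept the `with_args` kwargs, assuming it extends `MinMaxObserver` (the base class, attribute names, and lookup helper are assumptions, not the PR's implementation).

```python
# Hypothetical sketch -- the real ConcatObserver lives in .observers.concat_observer
# and may differ substantially from this.
from torch.ao.quantization.observer import MinMaxObserver


class ConcatObserver(MinMaxObserver):
    def __init__(self, node_name=None, graph=None, **kwargs):
        # dtype/qscheme/quant_min/quant_max arrive from the QuantizationSpec;
        # node_name/graph arrive from ConcatObserver.with_args(...) above
        super().__init__(**kwargs)
        self.node_name = node_name
        self.graph = graph

    def _concat_node(self):
        # node.meta["val"] fake tensors cannot be deep-copied, so only the graph
        # and node name are stored and the node is looked up again on demand
        return next(n for n in self.graph.nodes if n.name == self.node_name)
```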
@@ -1345,6 +1356,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
     input_act = node.args[0]
     assert isinstance(input_act, Node)
     input_qspec_map[input_act] = quantization_config.input_activation
+    share_qparams_with_input_node_qspec = SharedQuantizationSpec((input_act, node))
 
     node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
@@ -1353,7 +1365,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
 
     for user in node.users:
         user.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
-            output_qspec=quantization_config.output_activation,
+            output_qspec=share_qparams_with_input_node_qspec,
             _annotated=True,
         )
 