@@ -401,7 +401,6 @@ def _register_quantized_linear_lowering(
     pattern,
     pass_number,
     computation_op,
-    unary_attr,
 ):
     @register_lowering_pattern(
         pattern,
@@ -427,11 +426,13 @@ def qlinear(match: Match, *args, **kwargs):
         b = kwargs["b"] if "b" in kwargs else None
 
         # Output QParams
-        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
-        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
-        assert (
-            kwargs["postop_name"] == "none"
-        )  # Expected no post op fused in weight prepack phase
+        o_inv_scale = kwargs["output_scale"]
+        o_zero_point = kwargs["output_zero_point"]
+
+        # post op
+        postop_name = kwargs["postop_name"]
+        postop_args = kwargs["postop_args"]
+        postop_algorithm = kwargs["postop_algorithm"]
 
         computation_args = (
             x,
@@ -444,12 +445,12 @@ def qlinear(match: Match, *args, **kwargs):
             o_inv_scale,
             o_zero_point,
             output_dtype,
-            unary_attr.op_name,
-            unary_attr.scalars_attr,
-            unary_attr.algorithm_attr,
+            postop_name,
+            postop_args,
+            postop_algorithm,
         )
-        counters["inductor"]["qlinear_unary_matcher_count"] += 1
-        counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
+        counters["inductor"]["qlinear_unary_lower_count"] += 1
+        counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes)
         return L[computation_op](*computation_args)
 
     return qlinear
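
For readers tracing where `postop_name` / `postop_args` / `postop_algorithm` come from: they are only readable here because the matched pattern exposes the post-op fields of the prepack-stage `onednn.qlinear_pointwise` node as keyword args. Below is a minimal sketch of such a pattern built from the inductor pattern-matcher primitives; the real helper is `get_qlinear_pt2e_pattern` in this file, and the exact argument layout shown is an assumption, not a copy of it.

```python
# Sketch only (assumed shape, scalar-qparams case): the KeywordArg names mirror the
# kwargs read in qlinear() above; the argument order follows the onednn.qlinear_pointwise
# usage elsewhere in this diff.
import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg

qlinear_pointwise_pattern_sketch = CallFunction(
    torch.ops.onednn.qlinear_pointwise.default,
    KeywordArg("x"),
    KeywordArg("x_scale"),
    KeywordArg("x_zp"),
    KeywordArg("packed_weight"),
    KeywordArg("w_scale"),
    KeywordArg("w_zp"),
    KeywordArg("b"),
    KeywordArg("output_scale"),
    KeywordArg("output_zero_point"),
    KeywordArg("output_dtype"),
    KeywordArg("postop_name"),
    KeywordArg("postop_args"),
    KeywordArg("postop_algorithm"),
)
```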
@@ -704,13 +705,7 @@ def qconv_binary(match: Match, *args, **kwargs):
 
 
 def _register_quantization_unary_fusion():
-    from .mkldnn_fusion import (
-        _gelu_fusion_1 as _gelu_fusion_erf,
-        _gelu_fusion_2 as _gelu_fusion_tanh,
-        _hardswish_fusion,
-        _hardtanh_fusion,
-        _silu_fusion,
-    )
+    from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion
 
     class UnaryAttr:
         def __init__(
@@ -720,8 +715,8 @@ def __init__(
             self.scalars_attr = scalars_attr if scalars_attr else []
             self.algorithm_attr = algorithm_attr if algorithm_attr else ""
 
+    # QConv2d
     for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
-        # QConv2d
         # Priority 1 to match: QConv2d Unary pattern with int8 output
         # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
         # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
@@ -819,87 +814,19 @@ def __init__(
                 unary_attr,  # unary_attr
             )
 
-        # QLinear
-        for x_scale_zp_are_tensors in (False, True):
-            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
-            # Priority 1 to match: QLinear Unary pattern with int8 output
-            linear_unary_replace_patterns = {
-                UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
-                    qlinear_pattern,
-                ),
-                UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
-                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
-                ),
-                UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
-                    _unary_fusion_pattern(
-                        _gelu_fusion_erf,
-                        get_qlinear_pt2e_pattern(
-                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
-                        ),
-                        2,
-                        is_bf16,
-                    ),
-                    with_dtype_convert=is_bf16,
-                ),
-                UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
-                    _unary_fusion_pattern(
-                        _gelu_fusion_tanh,
-                        get_qlinear_pt2e_pattern(
-                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
-                        ),
-                        4,
-                        is_bf16,
-                    ),
-                    with_dtype_convert=is_bf16,
-                ),
-            }
-
-            for unary_attr, patterns in linear_unary_replace_patterns.items():
-                _register_quantized_linear_lowering(
-                    patterns,
-                    1,  # pass_number
-                    torch.ops.onednn.qlinear_pointwise,  # computation_op
-                    unary_attr,  # unary_attr
-                )
-
-            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
-            linear_unary_replace_float_out_patterns = {
-                UnaryAttr("relu", [], ""): generate_pattern_with_unary(
-                    qlinear_pattern, aten.relu.default
-                ),
-                UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
-                    _unary_fusion_pattern(
-                        _gelu_fusion_erf,
-                        get_qlinear_pt2e_pattern(
-                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
-                        ),
-                        2,
-                        is_bf16,
-                    ),
-                    Arg(),
-                    is_bf16,
-                ),
-                UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
-                    _unary_fusion_pattern(
-                        _gelu_fusion_tanh,
-                        get_qlinear_pt2e_pattern(
-                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
-                        ),
-                        4,
-                        is_bf16,
-                    ),
-                    Arg(),
-                    is_bf16,
-                ),
-            }
-
-            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
-                _register_quantized_linear_lowering(
-                    patterns,
-                    2,  # pass_number
-                    torch.ops.onednn.qlinear_pointwise,  # computation_op
-                    unary_attr,  # unary_attr
-                )
+    # QLinear
+    for x_scale_zp_are_tensors in (False, True):
+        qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
+        computation_op = (
+            torch.ops.onednn.qlinear_pointwise.tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qlinear_pointwise.default
+        )
+        _register_quantized_linear_lowering(
+            qlinear_pattern,
+            2,  # pass_number
+            computation_op,
+        )
 
 
 def _register_quantization_binary_fusion():
@@ -3059,6 +2986,177 @@ def _int_mm_weight_prepack(match: Match, *args, **kwargs):
         )
 
 
+def _register_qlinear_post_op_fusion_pass(
+    pattern,
+    pass_number,
+    computation_op,
+    unary_attr,
+):
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_is_valid_quantized_linear_optimization_pattern(),
+        pass_number=pass_number,
+    )
+    def qlinear_post_op_fusion(match: Match, *args, **kwargs):
+        """
+        Match the pattern:
+        qlinear - post op
+        """
+        output_dtype = _get_pattern_output_dtype(match)
+        # Activation QParams
+        x, x_scale, x_zp = (
+            kwargs["x"],
+            kwargs["x_scale"],
+            kwargs["x_zp"],
+        )
+        # Weight QParams
+        packed_weight, w_scale, w_zp = (
+            kwargs["packed_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+
+        # bias
+        b = kwargs["b"] if "b" in kwargs else None
+
+        # Output QParams
+        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
+        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
+        assert (
+            kwargs["postop_name"] == "none"
+        )  # Expected no post op fused in weight prepack phase
+
+        out_node = match.output_node()
+        with match.graph.inserting_before(out_node):
+            computation_args = (
+                x,
+                x_scale,
+                x_zp,
+                packed_weight,
+                w_scale,
+                w_zp,
+                b,
+                o_inv_scale,
+                o_zero_point,
+                output_dtype,
+                unary_attr.op_name,
+                unary_attr.scalars_attr,
+                unary_attr.algorithm_attr,
+            )
+            new_linear_node = match.graph.call_function(
+                computation_op, args=computation_args
+            )
+            out_node.replace_all_uses_with(new_linear_node)
+            new_linear_node.meta.update(out_node.meta)
+            for node in reversed(match.nodes):
+                match.graph.erase_node(node)
+            counters["inductor"]["qlinear_unary_matcher_count"] += 1
+            counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
+
+
+def _register_qlinear_unary_fusion():
+    from .mkldnn_fusion import (
+        _gelu_fusion_1 as _gelu_fusion_erf,
+        _gelu_fusion_2 as _gelu_fusion_tanh,
+    )
+
+    class UnaryAttr:
+        def __init__(
+            self, op_name: str, scalars_attr=None, algorithm_attr=None
+        ) -> None:
+            self.op_name = op_name
+            self.scalars_attr = scalars_attr if scalars_attr else []
+            self.algorithm_attr = algorithm_attr if algorithm_attr else ""
+
+    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
+        is_bf16 = original_pattern_output_dtype == torch.bfloat16
+        for x_scale_zp_are_tensors in (False, True):
+            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
+            computation_op = (
+                torch.ops.onednn.qlinear_pointwise.tensor
+                if x_scale_zp_are_tensors
+                else torch.ops.onednn.qlinear_pointwise.default
+            )
+            # Priority 1 to match: QLinear Unary pattern with int8 output
+            linear_unary_replace_patterns = {
+                UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
+                    qlinear_pattern,
+                ),
+                UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
+                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
+                ),
+                UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_erf,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
+                        ),
+                        2,
+                        is_bf16,
+                    ),
+                    with_dtype_convert=is_bf16,
+                ),
+                UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_tanh,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
+                        ),
+                        4,
+                        is_bf16,
+                    ),
+                    with_dtype_convert=is_bf16,
+                ),
+            }
+
+            for unary_attr, patterns in linear_unary_replace_patterns.items():
+                _register_qlinear_post_op_fusion_pass(
+                    patterns,
+                    3,  # pass_number
+                    computation_op,
+                    unary_attr,  # unary_attr
+                )
+
+            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
+            linear_unary_replace_float_out_patterns = {
+                UnaryAttr("relu", [], ""): generate_pattern_with_unary(
+                    qlinear_pattern, aten.relu.default
+                ),
+                UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_erf,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
+                        ),
+                        2,
+                        is_bf16,
+                    ),
+                    Arg(),
+                    is_bf16,
+                ),
+                UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_tanh,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
+                        ),
+                        4,
+                        is_bf16,
+                    ),
+                    Arg(),
+                    is_bf16,
+                ),
+            }
+
+            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
+                _register_qlinear_post_op_fusion_pass(
+                    patterns,
+                    4,  # pass_number
+                    computation_op,
+                    unary_attr,  # unary_attr
+                )
+
+
 @functools.lru_cache(None)
 def _register_quantization_weight_pack_pass():
     # Step 1: Dequant promotion for int8-mixed-fp32/bf16
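
Unlike the lowering-stage `qlinear` handler earlier in this diff, the new `qlinear_post_op_fusion` pass edits the FX graph in place: insert the fused node before the matched output, redirect all users, copy the meta, then erase the matched nodes in reverse order. The snippet below is a self-contained toy illustration of that same torch.fx rewrite idiom on a made-up `neg_relu` fusion; it is not code from this file.

```python
# Toy illustration of the insert/replace/erase idiom used by qlinear_post_op_fusion.
import torch
import torch.fx as fx


def neg_relu(x):  # stand-in for a fused kernel such as qlinear_pointwise with a post op
    return torch.relu(torch.neg(x))


def f(x):
    return torch.relu(torch.neg(x))


gm = fx.symbolic_trace(f)
graph = gm.graph
for node in list(graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        neg_node = node.args[0]
        # Mirror `with match.graph.inserting_before(out_node):` above.
        with graph.inserting_before(node):
            fused = graph.call_function(neg_relu, args=(neg_node.args[0],))
        node.replace_all_uses_with(fused)
        fused.meta.update(node.meta)
        # Erase matched nodes users-first, like `for node in reversed(match.nodes):`.
        graph.erase_node(node)
        graph.erase_node(neg_node)
graph.lint()
gm.recompile()
x = torch.randn(8)
assert torch.equal(gm(x), f(x))
```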
@@ -3074,6 +3172,9 @@ def _register_quantization_weight_pack_pass():
     # Step 4: weight prepack for SmoothQuant from Torchao
     _register_smooth_quant_int_mm_pattern()
 
+    # Step 5: QLinear post op Fusion
+    _register_qlinear_unary_fusion()
+
 
 def quant_lift_up(graph_module: torch.fx.GraphModule):
     """
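
Because Step 5 lives in `_register_quantization_weight_pack_pass`, the new fusion only fires during inductor freezing. Below is a rough end-to-end sketch of exercising it and reading the counters this diff touches; the export API and the `freezing` config flag usage are assumptions and vary across PyTorch releases.

```python
# Rough usage sketch, not part of this commit. Assumes an x86 CPU build with oneDNN;
# older releases capture with capture_pre_autograd_graph instead of export_for_training.
import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import counters
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)


class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.nn.functional.relu(self.linear(x))


x = torch.randn(4, 16)
m = M().eval()
exported = torch.export.export_for_training(m, (x,)).module()
quantizer = X86InductorQuantizer().set_global(
    get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(exported, quantizer)
prepared(x)  # calibration
converted = convert_pt2e(prepared)

inductor_config.freezing = True  # prepack + post-op fusion passes run during freezing
counters.clear()
with torch.no_grad():
    torch.compile(converted)(x)
print(counters["inductor"]["qlinear_unary_matcher_count"])  # prepack-stage fusion hits
print(counters["inductor"]["qlinear_unary_lower_count"])  # renamed lowering counter
```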