
Commit f8fcb9e

Xia-Weiwen authored and pytorchmergebot committed
[Quant][Inductor][X86] Separate unary post op fusion and lowering for qlinear (pytorch#143903)
**Summary**
The current implementation fuses quantized ops with their post ops and lowers the fused op to the cpp backend in the same pass. It is better to separate post op fusion from lowering because
- it is better in terms of design
- the post op fusion pass is also needed for PT2E quantization eager mode

This PR is the first of a series of PRs that separate post op fusion and lowering for quantized linear and convolution. It moves unary post op fusion of qlinear out of the lowering pass and runs it right after the weight-prepack pass instead. The workflow is as follows (a toy sketch appears after this summary):
1. Weight prepack for qlinear, so that `dq - linear` patterns are replaced by `onednn.qlinear_pointwise`
2. Fuse `onednn.qlinear_pointwise` with its post ops
3. Lower to the cpp backend

This PR adds additional `PatternMatcherPass`'s to handle the post op fusion; the pattern matchers used for fusion are reused.

**Test plan**
Covered by existing UTs in `test_mkldnn_pattern_matcher.py` for post op fusion.

Pull Request resolved: pytorch#143903
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/jerryzh168
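For orientation, the three-stage flow above can be sketched in plain Python. This is a minimal, hypothetical illustration only; the stage functions and the list-of-strings "graph" below are not part of the Inductor code touched by this PR.

```python
def weight_prepack(graph):
    # Stage 1 (illustrative): replace a `dq - linear` pair with onednn.qlinear_pointwise.
    return ["onednn.qlinear_pointwise" if op == "dq_linear" else op for op in graph]

def fuse_post_ops(graph):
    # Stage 2 (illustrative): fold a following unary post op into the qlinear op.
    out = []
    for op in graph:
        if op == "relu" and out and out[-1] == "onednn.qlinear_pointwise":
            out[-1] = "onednn.qlinear_pointwise(post_op=relu)"
        else:
            out.append(op)
    return out

def lower_to_cpp(graph):
    # Stage 3 (illustrative): hand the fused op to the cpp backend.
    return [f"cpp::{op}" for op in graph]

graph = ["dq_linear", "relu"]
for stage in (weight_prepack, fuse_post_ops, lower_to_cpp):
    graph = stage(graph)
print(graph)  # ['cpp::onednn.qlinear_pointwise(post_op=relu)']
```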
1 parent 094ca31 commit f8fcb9e

File tree

4 files changed: +210 -103 lines changed


test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 6 additions & 2 deletions
@@ -2167,6 +2167,10 @@ def matcher_check_fn():
             # 2. QLinear Unary fusion in post-grad fusion pass
             self.assertEqual(
                 counters["inductor"]["qlinear_unary_matcher_count"],
+                2,
+            )
+            self.assertEqual(
+                counters["inductor"]["qlinear_unary_lower_count"],
                 0 if TEST_ACL else 2,
             )

@@ -2443,7 +2447,7 @@ def default_matcher_check_fn():
             # 3. QLinear Unary fusion in post-grad fusion pass * 1
             self.assertEqual(
                 counters["inductor"]["qlinear_unary_matcher_count"],
-                0 if TEST_ACL else 1,
+                1,
             )

             self._test_common(

@@ -3706,7 +3710,7 @@ def matcher_check_fn():
             )
             self.assertEqual(
                 counters["inductor"]["qlinear_unary_matcher_count"],
-                3 if annotate_matmul and not TEST_ACL else 0,
+                3 if annotate_matmul else 0,
             )

         quantizer = X86InductorQuantizer()
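The assertions above now distinguish fusion from lowering: `qlinear_unary_matcher_count` counts post-op fusions (no longer gated on `TEST_ACL`), while the new `qlinear_unary_lower_count` counts lowerings to the cpp backend. A small helper of this shape (hypothetical, not part of the test file) shows how the two counters would be read; the tests read `counters` from `torch._dynamo.utils`.

```python
from torch._dynamo.utils import counters

def check_qlinear_counters(expected_fused, expected_lowered):
    # Fusion and lowering are now tracked by separate counters.
    assert counters["inductor"]["qlinear_unary_matcher_count"] == expected_fused
    assert counters["inductor"]["qlinear_unary_lower_count"] == expected_lowered
```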

torch/_inductor/fx_passes/freezing_patterns.py

Lines changed: 2 additions & 0 deletions
@@ -102,6 +102,8 @@ def lazy_init():
 
 
 def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0):
+    while pass_number > len(pass_patterns) - 1:
+        pass_patterns.append(PatternMatcherPass())
     return register_graph_pattern(
         pattern,
         extra_check=extra_check,
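The two added lines let a caller register a pattern at a `pass_number` beyond the passes created up front. A stand-alone sketch of the same grow-on-demand idea (the names below are placeholders, not the actual Inductor objects):

```python
class PatternMatcherPass:
    """Placeholder for torch._inductor.pattern_matcher.PatternMatcherPass."""
    def __init__(self):
        self.patterns = []

pass_patterns = [PatternMatcherPass() for _ in range(3)]

def register_at(pass_number, pattern):
    # Grow the list so that index `pass_number` always exists.
    while pass_number > len(pass_patterns) - 1:
        pass_patterns.append(PatternMatcherPass())
    pass_patterns[pass_number].patterns.append(pattern)

register_at(4, "qlinear_unary_fusion")  # the new fusion passes use pass_number 3 and 4
print(len(pass_patterns))  # 5
```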

torch/_inductor/fx_passes/quantization.py

Lines changed: 201 additions & 100 deletions
@@ -401,7 +401,6 @@ def _register_quantized_linear_lowering(
     pattern,
     pass_number,
     computation_op,
-    unary_attr,
 ):
     @register_lowering_pattern(
         pattern,

@@ -427,11 +426,13 @@ def qlinear(match: Match, *args, **kwargs):
         b = kwargs["b"] if "b" in kwargs else None
 
         # Output QParams
-        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
-        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
-        assert (
-            kwargs["postop_name"] == "none"
-        )  # Expected no post op fused in weight prepack phase
+        o_inv_scale = kwargs["output_scale"]
+        o_zero_point = kwargs["output_zero_point"]
+
+        # post op
+        postop_name = kwargs["postop_name"]
+        postop_args = kwargs["postop_args"]
+        postop_algorithm = kwargs["postop_algorithm"]
 
         computation_args = (
             x,
@@ -444,12 +445,12 @@ def qlinear(match: Match, *args, **kwargs):
             o_inv_scale,
             o_zero_point,
             output_dtype,
-            unary_attr.op_name,
-            unary_attr.scalars_attr,
-            unary_attr.algorithm_attr,
+            postop_name,
+            postop_args,
+            postop_algorithm,
         )
-        counters["inductor"]["qlinear_unary_matcher_count"] += 1
-        counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
+        counters["inductor"]["qlinear_unary_lower_count"] += 1
+        counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes)
         return L[computation_op](*computation_args)
 
     return qlinear
@@ -704,13 +705,7 @@ def qconv_binary(match: Match, *args, **kwargs):
 
 
 def _register_quantization_unary_fusion():
-    from .mkldnn_fusion import (
-        _gelu_fusion_1 as _gelu_fusion_erf,
-        _gelu_fusion_2 as _gelu_fusion_tanh,
-        _hardswish_fusion,
-        _hardtanh_fusion,
-        _silu_fusion,
-    )
+    from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion
 
     class UnaryAttr:
         def __init__(

@@ -720,8 +715,8 @@ def __init__(
             self.scalars_attr = scalars_attr if scalars_attr else []
             self.algorithm_attr = algorithm_attr if algorithm_attr else ""
 
+    # QConv2d
     for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
-        # QConv2d
         # Priority 1 to match: QConv2d Unary pattern with int8 output
         # If a pattern1 is a sub-set of pattern2, we should try to match pattern2 firstly.
         # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
@@ -819,87 +814,19 @@ def __init__(
                 unary_attr,  # unary_attr
             )
 
-        # QLinear
-        for x_scale_zp_are_tensors in (False, True):
-            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
-            # Priority 1 to match: QLinear Unary pattern with int8 output
-            linear_unary_replace_patterns = {
-                UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
-                    qlinear_pattern,
-                ),
-                UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
-                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
-                ),
-                UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
-                    _unary_fusion_pattern(
-                        _gelu_fusion_erf,
-                        get_qlinear_pt2e_pattern(
-                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
-                        ),
-                        2,
-                        is_bf16,
-                    ),
-                    with_dtype_convert=is_bf16,
-                ),
-                UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
-                    _unary_fusion_pattern(
-                        _gelu_fusion_tanh,
-                        get_qlinear_pt2e_pattern(
-                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
-                        ),
-                        4,
-                        is_bf16,
-                    ),
-                    with_dtype_convert=is_bf16,
-                ),
-            }
-
-            for unary_attr, patterns in linear_unary_replace_patterns.items():
-                _register_quantized_linear_lowering(
-                    patterns,
-                    1,  # pass_number
-                    torch.ops.onednn.qlinear_pointwise,  # computation_op
-                    unary_attr,  # unary_attr
-                )
-
-            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
-            linear_unary_replace_float_out_patterns = {
-                UnaryAttr("relu", [], ""): generate_pattern_with_unary(
-                    qlinear_pattern, aten.relu.default
-                ),
-                UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
-                    _unary_fusion_pattern(
-                        _gelu_fusion_erf,
-                        get_qlinear_pt2e_pattern(
-                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
-                        ),
-                        2,
-                        is_bf16,
-                    ),
-                    Arg(),
-                    is_bf16,
-                ),
-                UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
-                    _unary_fusion_pattern(
-                        _gelu_fusion_tanh,
-                        get_qlinear_pt2e_pattern(
-                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
-                        ),
-                        4,
-                        is_bf16,
-                    ),
-                    Arg(),
-                    is_bf16,
-                ),
-            }
-
-            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
-                _register_quantized_linear_lowering(
-                    patterns,
-                    2,  # pass_number
-                    torch.ops.onednn.qlinear_pointwise,  # computation_op
-                    unary_attr,  # unary_attr
-                )
+    # QLinear
+    for x_scale_zp_are_tensors in (False, True):
+        qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
+        computation_op = (
+            torch.ops.onednn.qlinear_pointwise.tensor
+            if x_scale_zp_are_tensors
+            else torch.ops.onednn.qlinear_pointwise.default
+        )
+        _register_quantized_linear_lowering(
+            qlinear_pattern,
+            2,  # pass_number
+            computation_op,
+        )
 
 
 def _register_quantization_binary_fusion():
@@ -3059,6 +2986,177 @@ def _int_mm_weight_prepack(match: Match, *args, **kwargs):
     )
 
 
+def _register_qlinear_post_op_fusion_pass(
+    pattern,
+    pass_number,
+    computation_op,
+    unary_attr,
+):
+    @register_freezing_graph_pattern(
+        pattern,
+        extra_check=_is_valid_quantized_linear_optimization_pattern(),
+        pass_number=pass_number,
+    )
+    def qlinear_post_op_fusion(match: Match, *args, **kwargs):
+        """
+        Match the pattern:
+        qlinear - post op
+        """
+        output_dtype = _get_pattern_output_dtype(match)
+        # Activation QParams
+        x, x_scale, x_zp = (
+            kwargs["x"],
+            kwargs["x_scale"],
+            kwargs["x_zp"],
+        )
+        # Weight QParams
+        packed_weight, w_scale, w_zp = (
+            kwargs["packed_weight"],
+            kwargs["w_scale"],
+            kwargs["w_zp"],
+        )
+
+        # bias
+        b = kwargs["b"] if "b" in kwargs else None
+
+        # Output QParams
+        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
+        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
+        assert (
+            kwargs["postop_name"] == "none"
+        )  # Expected no post op fused in weight prepack phase
+
+        out_node = match.output_node()
+        with match.graph.inserting_before(out_node):
+            computation_args = (
+                x,
+                x_scale,
+                x_zp,
+                packed_weight,
+                w_scale,
+                w_zp,
+                b,
+                o_inv_scale,
+                o_zero_point,
+                output_dtype,
+                unary_attr.op_name,
+                unary_attr.scalars_attr,
+                unary_attr.algorithm_attr,
+            )
+            new_linear_node = match.graph.call_function(
+                computation_op, args=computation_args
+            )
+            out_node.replace_all_uses_with(new_linear_node)
+            new_linear_node.meta.update(out_node.meta)
+            for node in reversed(match.nodes):
+                match.graph.erase_node(node)
+            counters["inductor"]["qlinear_unary_matcher_count"] += 1
+            counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
+
+
+def _register_qlinear_unary_fusion():
+    from .mkldnn_fusion import (
+        _gelu_fusion_1 as _gelu_fusion_erf,
+        _gelu_fusion_2 as _gelu_fusion_tanh,
+    )
+
+    class UnaryAttr:
+        def __init__(
+            self, op_name: str, scalars_attr=None, algorithm_attr=None
+        ) -> None:
+            self.op_name = op_name
+            self.scalars_attr = scalars_attr if scalars_attr else []
+            self.algorithm_attr = algorithm_attr if algorithm_attr else ""
+
+    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
+        is_bf16 = original_pattern_output_dtype == torch.bfloat16
+        for x_scale_zp_are_tensors in (False, True):
+            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
+            computation_op = (
+                torch.ops.onednn.qlinear_pointwise.tensor
+                if x_scale_zp_are_tensors
+                else torch.ops.onednn.qlinear_pointwise.default
+            )
+            # Priority 1 to match: QLinear Unary pattern with int8 output
+            linear_unary_replace_patterns = {
+                UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
+                    qlinear_pattern,
+                ),
+                UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
+                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
+                ),
+                UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_erf,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
+                        ),
+                        2,
+                        is_bf16,
+                    ),
+                    with_dtype_convert=is_bf16,
+                ),
+                UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_tanh,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
+                        ),
+                        4,
+                        is_bf16,
+                    ),
+                    with_dtype_convert=is_bf16,
+                ),
+            }
+
+            for unary_attr, patterns in linear_unary_replace_patterns.items():
+                _register_qlinear_post_op_fusion_pass(
+                    patterns,
+                    3,  # pass_number
+                    computation_op,
+                    unary_attr,  # unary_attr
+                )
+
+            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
+            linear_unary_replace_float_out_patterns = {
+                UnaryAttr("relu", [], ""): generate_pattern_with_unary(
+                    qlinear_pattern, aten.relu.default
+                ),
+                UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_erf,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
+                        ),
+                        2,
+                        is_bf16,
+                    ),
+                    Arg(),
+                    is_bf16,
+                ),
+                UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
+                    _unary_fusion_pattern(
+                        _gelu_fusion_tanh,
+                        get_qlinear_pt2e_pattern(
+                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
+                        ),
+                        4,
+                        is_bf16,
+                    ),
+                    Arg(),
+                    is_bf16,
+                ),
+            }
+
+            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
+                _register_qlinear_post_op_fusion_pass(
+                    patterns,
+                    4,  # pass_number
+                    computation_op,
+                    unary_attr,  # unary_attr
+                )
+
+
 @functools.lru_cache(None)
 def _register_quantization_weight_pack_pass():
     # Step 1: Dequant promotion for int8-mixed-fp32/bf16
@@ -3074,6 +3172,9 @@ def _register_quantization_weight_pack_pass():
     # Step 4: weight prepack for SmoothQuant from Torchao
     _register_smooth_quant_int_mm_pattern()
 
+    # Step 5: QLinear post op Fusion
+    _register_qlinear_unary_fusion()
+
 
 def quant_lift_up(graph_module: torch.fx.GraphModule):
     """
