@@ -14,17 +14,80 @@
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import MinMaxObserver
+from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
+    QuantizationSpec,
     SharedQuantizationSpec,
 )
 from torch.fx import Node


-def annotate_matmul_16a8w(  # noqa: C901
-    gm: torch.fx.GraphModule, traverse_input1=True
-) -> None:
+def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
+    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        weight = node.args[1]
+        input_qspec_map[weight] = quantization_config.weight
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
+        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+    )
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
+            if "nn_module_stack" in node.meta:
+                module_values_list = list(node.meta["nn_module_stack"].values())
+                full_qualified_name = module_values_list[-1][0]
+                if full_qualified_name == "output.conv":
+                    annotate_conv2d(
+                        node, quantization_config=quantization_config_16a8w_per_channel
+                    )
+
+
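To adapt the `output.conv` filter above to a different model, a quick way to find the right name is to dump the `nn_module_stack` metadata of every conv2d call. A minimal sketch, assuming the graph was captured with torch.export so that metadata is present:

# Sketch: list conv2d call sites and their originating module names, so the
# "output.conv" check above can be adapted to another model layout.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
        stack = node.meta.get("nn_module_stack", {})
        if stack:
            # Each value is a (fully qualified name, module type) pair;
            # the last entry is the module that directly owns the call.
            fqn, module_type = list(stack.values())[-1]
            print(f"{node.name}: {fqn} ({module_type})")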
+def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            for index, prefill_output in enumerate(node.args[0]):
+                kv_quant_attr = kv_quant_attrs[index]
+                fixed_observer = FixedQParamsObserver.with_args(
+                    scale=kv_quant_attr[0],
+                    zero_point=kv_quant_attr[1],
+                    quant_min=kv_quant_attr[2],
+                    quant_max=kv_quant_attr[3],
+                    dtype=kv_quant_attr[4],
+                    qscheme=torch.per_tensor_affine,
+                )
+
+                fixed_output_spec = QuantizationSpec(
+                    quant_min=kv_quant_attr[2],
+                    quant_max=kv_quant_attr[3],
+                    dtype=kv_quant_attr[4],
+                    ch_axis=0,
+                    observer_or_fake_quant_ctr=fixed_observer,
+                )
+
+                input_qspec_map = {}
+                for input in prefill_output.args:
+                    if isinstance(input, Node):
+                        input_qspec_map[input] = fixed_output_spec
+
+                prefill_output.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+                    input_qspec_map=input_qspec_map,
+                    output_qspec=fixed_output_spec,
+                    _annotated=True,
+                )
+
+
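The quantizer's custom annotation hooks receive only the GraphModule, so `annotate_prefill_kv_output` has to be bound to its `kv_quant_attrs` argument up front. A minimal sketch with purely illustrative numbers; the real scale/zero-point pairs come from the already-quantized KV cache, and each tuple follows the indexing used above, keyed by the tensor's position in the graph output:

# Sketch: bind kv_quant_attrs before registering the hook (values are illustrative).
# Each entry is (scale, zero_point, quant_min, quant_max, dtype), keyed by the
# position of the corresponding tensor in the graph's output node.
from functools import partial

import torch

kv_quant_attrs = {
    0: (0.0038, 0, 0, 65535, torch.uint16),
    1: (0.0041, 0, 0, 65535, torch.uint16),
}
prefill_kv_annotator = partial(annotate_prefill_kv_output, kv_quant_attrs=kv_quant_attrs)
# prefill_kv_annotator can now be passed alongside the other custom annotations
# (see the registration sketch at the end of this section).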
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for matmul op 16a8w.
     For k, we will tag such as the below, and
@@ -142,8 +205,7 @@ def annotate_matmul_input1(node: Node): |
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
-            if traverse_input1:
-                annotate_matmul_input1(node.args[1])
+            annotate_matmul_input1(node.args[1])


 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
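To take effect, the annotators above are registered with the Qualcomm quantizer before calibration. A minimal sketch of that wiring, assuming `QnnQuantizer.add_custom_quant_annotations` as used in the Qualcomm llama flow; `model` and `example_inputs` are placeholders, and the capture step may need adapting to the export API your ExecuTorch version uses:

# Sketch: register the custom annotations and run the PT2E quantization flow.
import torch
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = QnnQuantizer()
quantizer.add_custom_quant_annotations(
    (
        annotate_matmul_16a8w,                  # 16a8w matmul in attention
        annotate_linear_16a8w_in_affine_layer,  # 16a8w output projection conv
        prefill_kv_annotator,                   # fixed qparams on prefill KV outputs
    )
)

# Capture, calibrate, and convert (adapt capture to your export flow).
graph_module = torch.export.export(model, example_inputs).module()
prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)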