     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
-    QuantizationSpec,
     SharedQuantizationSpec,
 )
 from torch.fx import Node
 
 
-def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
-    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
-        input_qspec_map = {}
-        input_act = node.args[0]
-        input_spec = quantization_config.input_activation
-        input_qspec_map[input_act] = input_spec
-
-        weight = node.args[1]
-        input_qspec_map[weight] = quantization_config.weight
-
-        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=quantization_config.output_activation,
-            _annotated=True,
-        )
-
-    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
-        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
-    )
-    for node in gm.graph.nodes:
-        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
-            if "nn_module_stack" in node.meta:
-                module_values_list = list(node.meta["nn_module_stack"].values())
-                full_qualified_name = module_values_list[-1][0]
-                if full_qualified_name == "output.conv":
-                    annotate_conv2d(
-                        node, quantization_config=quantization_config_16a8w_per_channel
-                    )
-
-
-def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            for index, prefill_output in enumerate(node.args[0]):
-                kv_quant_attr = kv_quant_attrs[index]
-                fixed_observer = FixedQParamsObserver.with_args(
-                    scale=kv_quant_attr[0],
-                    zero_point=kv_quant_attr[1],
-                    quant_min=kv_quant_attr[2],
-                    quant_max=kv_quant_attr[3],
-                    dtype=kv_quant_attr[4],
-                    qscheme=torch.torch.per_tensor_affine,
-                )
-
-                fixed_output_spec = QuantizationSpec(
-                    quant_min=kv_quant_attr[2],
-                    quant_max=kv_quant_attr[3],
-                    dtype=kv_quant_attr[4],
-                    ch_axis=0,
-                    observer_or_fake_quant_ctr=fixed_observer,
-                )
-
-                input_qspec_map = {}
-                for input in prefill_output.args:
-                    if isinstance(input, Node):
-                        input_qspec_map[input] = fixed_output_spec
-
-                prefill_output.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
-                    input_qspec_map=input_qspec_map,
-                    output_qspec=fixed_output_spec,
-                    _annotated=True,
-                )
-
-
-def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
+def annotate_matmul_16a8w(  # noqa: C901
+    gm: torch.fx.GraphModule, traverse_input1=True
+) -> None:
     """
     This function is specific for matmul op 16a8w.
     For k, we will tag such as the below, and
@@ -205,7 +142,8 @@ def annotate_matmul_input1(node: Node):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
-            annotate_matmul_input1(node.args[1])
+            if traverse_input1:
+                annotate_matmul_input1(node.args[1])
 
 
 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
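For illustration only, a minimal usage sketch of the new keyword argument. Only the `annotate_matmul_16a8w(gm, traverse_input1=...)` signature comes from this diff; the import path and the wrapper function below are assumptions made for the sketch, not part of the commit. The default call keeps the old behavior of also annotating the producer chain feeding each matmul's second input, while `traverse_input1=False` limits the annotation to the `aten.matmul.default` nodes themselves.

```python
# Minimal sketch, assuming the function lives in the custom_annotation module;
# only the annotate_matmul_16a8w signature is taken from this commit.
import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (  # assumed path
    annotate_matmul_16a8w,
)


def apply_matmul_annotations(gm: torch.fx.GraphModule, annotate_input1_chain: bool) -> None:
    if annotate_input1_chain:
        # Default behavior: also walks the graph feeding node.args[1] of each matmul.
        annotate_matmul_16a8w(gm)
    else:
        # New option added here: tag only the aten.matmul.default nodes as 16a8w.
        annotate_matmul_16a8w(gm, traverse_input1=False)
```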
|  | 