
Commit 6bfa4d7

refine
1 parent 19a568a commit 6bfa4d7

File tree

1 file changed: +70 -11 lines changed


prototype_source/pt2e_quant_xpu_inductor.rst

Lines changed: 70 additions & 11 deletions
@@ -1,7 +1,7 @@
 PyTorch 2 Export Quantization with Intel GPU Backend through Inductor
 ==================================================================
 
-** Author**: `Yan, Zhiwei`, `Wang, Eikan`, `Liu River`, `Cui, Yifeng`
+**Author**: `Yan Zhiwei <https://github.com/ZhiweiYan-96>`, `Wang Eikan <https://github.com/EikanWang>`, `Liu River <https://github.com/riverliuintel>`, `Cui Yifeng <https://github.com/CuiYifeng>`
 
 
 Prerequisites
@@ -19,7 +19,7 @@ utilize PyTorch 2 Export Quantization flow and lower the quantized model into the
 
 The PyTorch 2 export quantization flow uses ``torch.export`` to capture the model into a graph and performs quantization transformations on top of the ATen graph.
 This approach is expected to have significantly higher model coverage, better programmability, and a simplified UX.
-TorchInductor is the new compiler backend that compiles the FX Graphs generated by TorchDynamo into optimized C++/Triton kernels.
+TorchInductor is the compiler backend that compiles the FX Graphs generated by TorchDynamo into optimized C++/Triton kernels.
 
 The quantization flow mainly includes three steps:
 
@@ -28,9 +28,9 @@ The quantization flow mainly includes three steps:
   performing the prepared model's calibration or quantization-aware training, and converting the prepared model into the quantized model.
 - Step 3: Lower the quantized model into inductor with the API ``torch.compile``.
 
-During Step3, the inductor would decide which kernels are dispatched into. There are two kinds of kernels the Intel GPU would obtain benefits, oneDNN kernels and triton fusion. oneDNN libray contains
-highly-optimized kernels for quantized Conv/GEMM. Furthermore, oneDNN supports extra operator fusion on these operators, like quantized linear with eltwise activation function(ReLU) and binary operation(add, inplace sum).
-For other operators that does not call oneDNN or fallback to ATen implementation, triton would be responsible to generate kernels on our GPUs, like operators `quantize` and `dequantize`.
+During Step 3, Inductor decides which kernels each operator is dispatched to. There are two kinds of kernels from which the Intel GPU benefits: oneDNN kernels and Triton kernels. The `Intel oneAPI Deep Neural Network Library (oneDNN) <https://github.com/uxlfoundation/oneDNN>` contains
+highly-optimized quantized Conv/GEMM kernels for both CPU and GPU. Furthermore, oneDNN supports extra operator fusion on these operators, such as quantized linear with an eltwise activation function (ReLU) or a binary operation (add, in-place sum).
+Besides the oneDNN kernels, Triton is responsible for generating kernels for the remaining operators on the GPU, such as `quantize` and `dequantize`. The Triton kernels are optimized by the `Intel XPU Backend for Triton <https://github.com/intel/intel-xpu-backend-for-triton>`.
 
 
 The high-level architecture of this flow could look like this:
@@ -64,7 +64,7 @@ The high-level architecture of this flow could look like this:
                           Inductor
                              |
    —--------------------------------------------------------
-   |          oneDNN Kernels           Triton Kernels       |
+   |     oneDNN Kernels      ATen Ops      Triton Kernels   |
    —--------------------------------------------------------
 
 
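To see which kernels Inductor actually picks during Step 3, the generated wrapper code can be dumped. A minimal sketch, assuming ``converted_model`` and ``example_inputs`` from the later sections of this tutorial; the ``output_code`` artifact of ``torch._logging.set_logs`` (equivalently the ``TORCH_LOGS="output_code"`` environment variable) prints the code Inductor emits:

::

    import torch

    # Print the Inductor-generated code; the dump typically shows the oneDNN
    # quantized Conv/GEMM calls alongside Triton kernels for quantize/dequantize.
    torch._logging.set_logs(output_code=True)

    with torch.no_grad():
        optimized_model = torch.compile(converted_model)
        optimized_model(*example_inputs)
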
@@ -75,7 +75,10 @@ Post Training Quantization
 Static quantization is the only method we support currently. QAT and dynamic quantization will be available in later versions.
 
 Please install the dependency packages through the Intel GPU channel as follows:
-`pip install torchvision pytorch-triton-xpu --index-url https://download.pytorch.org/whl/nightly/xpu`
+
+::
+
+    pip install torchvision pytorch-triton-xpu --index-url https://download.pytorch.org/whl/nightly/xpu
 
 
 1. Capture FX Graph
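
The body of this section falls outside the diff context. A minimal sketch of the capture step, assuming ``torchvision``'s ``resnet18`` as the example model, the ``xpu`` device, and ``torch.export.export_for_training`` as the capture API:

::

    import torch
    import torchvision.models as models

    # Build an eager-mode FP32 model on the Intel GPU ("xpu") device.
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval().to("xpu")
    example_inputs = (torch.randn(1, 3, 224, 224, device="xpu"),)

    with torch.no_grad():
        # Capture the model into an FX GraphModule to be quantized.
        exported_model = torch.export.export_for_training(model, example_inputs).module()
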
@@ -117,7 +120,7 @@ Next, we will have the FX Module to be quantized.
 2. Apply Quantization
 ^^^^^^^^^^^^^^^^^^^^^^^
 
-After we capture the FX Module to be quantized, we will import the Backend Quantizer for X86 CPU and configure how to
+After we capture the FX Module to be quantized, we will import the Backend Quantizer for Intel GPU and configure how to
 quantize the model.
 
 ::
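
The contents of the literal block opened above are outside the diff context. A minimal sketch of the default-configuration path, assuming the prototype names ``XPUInductorQuantizer`` and ``get_default_xpu_inductor_quantization_config`` in ``torch.ao.quantization.quantizer.xpu_inductor_quantizer``:

::

    from torch.ao.quantization.quantizer.xpu_inductor_quantizer import (
        XPUInductorQuantizer,
        get_default_xpu_inductor_quantization_config,
    )

    # Apply the default configuration globally (the symmetric variant is shown below).
    quantizer = XPUInductorQuantizer()
    quantizer.set_global(get_default_xpu_inductor_quantization_config())
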
@@ -127,11 +130,66 @@ quantize the model.
 
 .. note::
 
-   The default quantization configuration in ``XPUInductorQuantizer`` uses signed 8-bits for both activations and weights. The tensor is per-tensor quantized, while weight is per-channel quantized.
+   The default quantization configuration in ``XPUInductorQuantizer`` uses signed 8-bit integers for both activations and weights. The activation is per-tensor quantized, while the weight is signed 8-bit per-channel quantized.
 
+   Besides the default quantization configuration, we also support signed 8-bit symmetric quantized activation, which has the potential
+   to provide better performance.
 
-After we import the backend-specific Quantizer, we will prepare the model for post-training quantization.
-``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model.
+::
+
+    import torch
+    from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
+    from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
+    from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
+    from typing import Any, Optional, TYPE_CHECKING
+
+    if TYPE_CHECKING:
+        from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+
+    def get_xpu_inductor_symm_quantization_config():
+        extra_args: dict[str, Any] = {"eps": 2**-12}
+        act_observer_or_fake_quant_ctr = HistogramObserver
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_tensor_symmetric,
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+                **extra_args
+            ),
+        )
+
+        weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+            PerChannelMinMaxObserver
+        )
+
+        weight_quantization_spec = QuantizationSpec(
+            dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            qscheme=torch.per_channel_symmetric,
+            ch_axis=0,  # 0 corresponds to the conv weight shape (oc, ic, kh, kw)
+            is_dynamic=False,
+            observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
+                **extra_args
+            ),
+        )
+
+        bias_quantization_spec = None  # will use placeholder observer by default
+        quantization_config = QuantizationConfig(
+            act_quantization_spec,
+            act_quantization_spec,
+            weight_quantization_spec,
+            bias_quantization_spec,
+            False,
+        )
+        return quantization_config
+
+Then, the user can set the quantization configuration on the quantizer.
+
+::
+
+    quantizer = XPUInductorQuantizer()
+    quantizer.set_global(get_xpu_inductor_symm_quantization_config())
+
+After we import the backend-specific Quantizer, we will prepare the model for post-training quantization.
+``prepare_pt2e`` folds BatchNorm operators into preceding Conv2d operators, and inserts observers in appropriate places in the model.
 
 ::
@@ -200,6 +258,7 @@ script within the BFloat16 Autocast context.
     # Running some benchmark
     optimized_model(*example_inputs)
 
+
 Put all these code snippets together, and we will have the toy example code.
 Please note that since the Inductor ``freeze`` feature is not turned on by default yet, run your example code with ``TORCHINDUCTOR_FREEZING=1``.
