Commit 66800eb

add max-autotune tutorial
1 parent be7f1b3 commit 66800eb

File tree

prototype_source/max_autotune_CPU_with_gemm_template_tutorial.rst
prototype_source/prototype_index.rst

2 files changed: +206 -0 lines changed
prototype_source/max_autotune_CPU_with_gemm_template_tutorial.rst

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
Max-autotune Support on CPU with GEMM Template Tutorial
==============================================================

**Author**: `Jiong Gong <https://github.com/jgong5>`__, `Leslie Fang <https://github.com/leslie-fang-intel>`__, `Chunyuan Wu <https://github.com/chunyuan-w>`__

Prerequisites:
----------------
- `torch.compile and TorchInductor concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__

Introduction
------------
``max-autotune`` mode for the Inductor CPU backend in ``torch.compile`` profiles multiple implementations of operations at compile time and selects the best-performing one,
trading longer compilation times for improved runtime performance. This enhancement is particularly beneficial for GEMM-related operations.
In the Inductor CPU backend, we've introduced a C++ template-based GEMM implementation as an alternative to the ATen-based approach that relies on the oneDNN and MKL libraries.
This is similar to max-autotune mode on CUDA, where implementations from ATen, Triton, and CUTLASS are considered.

We cover the most popular data types, including FP32, BF16, FP16, and INT8, with epilogue fusions for x86 CPUs.

How to activate ``max-autotune`` mode
-------------------------------------
To activate ``max-autotune`` mode in PyTorch, set the ``mode`` argument to ``max-autotune`` when compiling your model with ``torch.compile``.
If you prefer to bypass the tuning process and always use the C++ template implementations, you can configure this via an environment variable:
``export TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=CPP``.

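For illustration, here is a minimal sketch of both activation paths. It assumes the ``torch._inductor.config`` knobs ``max_autotune`` and ``max_autotune_gemm_backends``, which mirror the ``TORCHINDUCTOR_MAX_AUTOTUNE`` and ``TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS`` environment variables.

.. code:: python

    import torch
    from torch._inductor import config

    model = torch.nn.Linear(16, 32)
    x = torch.randn(8, 16)

    # Path 1: request autotuning through the compile mode.
    compiled = torch.compile(model, mode="max-autotune")

    # Path 2 (assumed equivalent to the environment variables above):
    # enable autotuning and restrict the GEMM backends to the C++ template,
    # bypassing the tuning between ATen and the C++ template.
    config.max_autotune = True
    config.max_autotune_gemm_backends = "CPP"

    with torch.no_grad():
        y = compiled(x)
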
Example code
------------
The code below is an example of using ``max-autotune`` mode on a simple neural network consisting of a linear layer followed by a ReLU activation.
To run the example code, set the environment variable: ``export TORCHINDUCTOR_FREEZING=1``.

.. code:: python

    import torch
    from torch._inductor import config

    config.trace.log_autotuning_results = True  # enable the log of autotuning results

    class M(torch.nn.Module):
        def __init__(
            self,
            in_features,
            out_features,
            bias,
            **kwargs,
        ):
            super().__init__()
            self.linear = torch.nn.Linear(
                in_features,
                out_features,
                bias,
                **kwargs,
            )
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.linear(x)
            x = self.relu(x)
            return x

    amp_enabled = True
    batch_size = 64
    in_features = 16
    out_features = 32
    bias = True

    x = torch.randn(batch_size, in_features)
    model = M(in_features, out_features, bias)

    with torch.no_grad(), torch.cpu.amp.autocast(enabled=amp_enabled):
        compiled = torch.compile(model, mode="max-autotune")  # turn on "max-autotune" mode
        y = compiled(x)

When you run the above code snippet, you will see the autotuning results (the performance numbers are for demonstration purposes).
In this case, the C++ template outperforms the ATen kernel, so it is selected.

.. code:: shell

    AUTOTUNE linear_unary(64x16, 32x16, 32)
      cpp_packed_gemm_0 0.2142 ms 100.0%
      _linear_pointwise 0.2441 ms 87.7%

We can check the generated output code by setting ``export TORCH_LOGS="+output_code"``.
When the C++ template is selected, the generated code no longer contains ``torch.ops.mkldnn._linear_pointwise.default`` (for bfloat16) or ``torch.ops.mkl._mkl_linear.default`` (for float32).
Instead, we find a kernel based on the C++ GEMM template, ``cpp_fused__to_copy_relu_1``
(only part of the code is shown below for simplicity), with the bias and ReLU epilogues fused inside the C++ GEMM template kernel.

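As a programmatic alternative to the environment variable, the same logging can be enabled in code; a minimal sketch, assuming ``torch._logging.set_logs`` accepts an ``output_code`` flag:

.. code:: python

    import torch

    # Equivalent to `export TORCH_LOGS="+output_code"`: log the output code
    # that Inductor generates when the model is compiled.
    torch._logging.set_logs(output_code=True)
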
.. code:: python

    cpp_fused__to_copy_relu_1 = async_compile.cpp_pybinding(['const bfloat16*', 'const bfloat16*', 'const bfloat16*', 'bfloat16*'], '''

    ...

    template <bool accum>
    inline void kernel_micro_gemm_amx_kernel_32_2(
        AMXState& amx_state,
        const bfloat16* __restrict__ A,
        const bfloat16* __restrict__ B,
        float* __restrict__ C,
        int64_t K,
        int64_t lda,
        int64_t ldb,
        int64_t ldc,
        uint8_t tilecfg_rows
    ) {
        ...
    }

    ...

    template <bool accum>
    inline void kernel_micro_gemm(
        AMXState& amx_state,
        const bfloat16* __restrict__ A,
        const bfloat16* __restrict__ B,
        float* __restrict__ C,
        int64_t M,
        int64_t N,
        int64_t K,
        int64_t lda,
        int64_t ldb,
        int64_t ldc
    ) {
        ...
    }

    extern "C"
    void kernel(const bfloat16* X, const bfloat16* W, const bfloat16* inp, bfloat16* Y)
    {
        constexpr int64_t num_threads = 40;
        constexpr int64_t N = 32;
        constexpr int64_t K = 16;
        constexpr int64_t M = static_cast<int64_t>(64L);
        ...
        #pragma omp parallel num_threads(40)
        {
            const int tid = omp_get_thread_num();
            ...
            for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
                ...
                for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                    ...
                    for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                        ...
                        for (int64_t nci = nc; nci < nc_block_end; nci++) {
                            if (kc == k_block_start) {
                                kernel_micro_gemm<static_cast<bool>(false)>(
                                    ...
                                );
                            } else {
                                kernel_micro_gemm<static_cast<bool>(true)>(
                                    ...
                                );
                            }
                        }
                    }
                    {
                        {
                            // Epilogue fusion here for bias and relu
                            #pragma GCC ivdep
                            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                            {
                                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                                {
                                    auto tmp0 = at::vec::Vectorized<bfloat16>::loadu(inp + static_cast<int64_t>(n_start + x1), static_cast<int64_t>(16));
                                    auto tmp2 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                    auto tmp1 = at::vec::convert<float>(tmp0);
                                    auto tmp3 = tmp1 + tmp2;
                                    auto tmp4 = at::vec::convert<bfloat16>(tmp3);
                                    auto tmp5 = static_cast<float>(0.0);
                                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                                    auto tmp7 = at::vec::maximum(tmp3, tmp6);
                                    auto tmp8 = at::vec::convert<bfloat16>(tmp7);
                                    tmp8.store(Y + static_cast<int64_t>(n_start + x1 + (32L*m_start) + (32L*x0)), static_cast<int64_t>(16));
                                }
                                ...
                            }
                        }
                    }
                }
            }
            ...
        }
    }
    ''')

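For reference, the fused kernel above computes the same result as the following eager-mode sketch (shapes taken from the ``AUTOTUNE linear_unary(64x16, 32x16, 32)`` log; the names ``X``, ``W``, and ``inp`` follow the kernel signature, where ``inp`` is the bias):

.. code:: python

    import torch

    X = torch.randn(64, 16, dtype=torch.bfloat16)   # activation
    W = torch.randn(32, 16, dtype=torch.bfloat16)   # linear weight
    inp = torch.randn(32, dtype=torch.bfloat16)     # bias ("inp" in the kernel)

    # GEMM plus the two fused epilogues: the bias add and the ReLU
    # (at::vec::maximum with 0.0 in the generated code above).
    Y = torch.relu(X @ W.T + inp)
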
Conclusion
------------
In this tutorial, we introduced max-autotune support on CPU with a GEMM template. We explained the API for activating this feature and demonstrated
the code generated by the GEMM template.

This feature is in the prototype stage. If you have any feature requests or run into any issues, please file a bug report at `GitHub issues <https://github.com/pytorch/pytorch/issues>`_.

prototype_source/prototype_index.rst

Lines changed: 8 additions & 0 deletions
@@ -217,6 +217,13 @@ Prototype features are not available as part of binary distributions like PyPI o
     :link: ../prototype/inductor_cpp_wrapper_tutorial.html
     :tags: Model-Optimization

+.. customcarditem::
+   :header: Max-autotune Support on CPU with GEMM Template Tutorial
+   :card_description: Tutorial for max-autotune mode support for torch.compile with GEMM template
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../prototype/max_autotune_CPU_with_gemm_template_tutorial.html
+   :tags: Model-Optimization
+
 .. Distributed
 .. customcarditem::
    :header: Flight Recorder Tutorial

@@ -265,3 +272,4 @@ Prototype features are not available as part of binary distributions like PyPI o
     prototype/maskedtensor_sparsity.html
     prototype/maskedtensor_advanced_semantics.html
     prototype/maskedtensor_adagrad.html
+    prototype/max_autotune_CPU_with_gemm_template_tutorial.html
